【题解】nowcoder-小H和圣诞树

题目

https://www.nowcoder.com/acm/contest/79/F

题意

树上的每一个结点都有一种颜色,每次询问给出两种颜色,求这两种颜色的结点两两之间的距离和。

题解

方法1

设 \(K=\sqrt[3]{n}\),查询的两种颜色 \(c_1, c_2\) 对应的结点个数分别为 \(a, b\)(\(dep\) 表示带权深度),则路径长度和为 \(\displaystyle \sum_{u\in c_1,v\in c_2} \bigl(dep[u]+dep[v]-2\cdot dep[\mathrm{lca}(u, v)]\bigr)\)。

  • 若 \(a \times b \leqslant K^2\),直接暴力枚举所有点对,用 LCA 求距离。
  • 否则(不妨设 \(a \leqslant b\))必有 \(b > K\),而结点数超过 \(K\) 的颜色至多 \(n/K = K^2\) 种。对每种这样的颜色用 \(O(n)\) 的时间预处理后,就能 \(O(1)\) 回答所有涉及它的询问,方法是计算每条边对答案的贡献。由于这道题卡常,需要把递归的 dfs 改成按拓扑序(或欧拉序)遍历才能通过;代码中使用的是按深度计数排序得到的顺序。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// FOR(i, x, y): i goes from x up to y - 1; y is evaluated once and i gets y's (decayed) type.
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
// FORD(i, x, y): i goes from x down to y + 1; y is evaluated once and i gets x's (decayed) type.
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
// dbg(args...) prints the argument text and values through err(), wrapped in ANSI colour
// escapes, when the macro zerol is defined; otherwise it expands to nothing.
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
// Debug printer behind dbg(...). Every overload is declared up front so each one can
// recurse into any other regardless of argument order. Without this, err(vec, 1) fails:
// inside the container overload the later scalar overload is not visible to ordinary
// lookup, and ADL finds nothing for fundamental types such as int.
void err();
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args);
template<typename T, typename... Args>
void err(T a, Args... args);

// Base case: reset the terminal colour and end the line.
void err() { cout << "\033[39;0m" << endl; }
// Container overload: print every element followed by a space, then the rest.
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (const auto& x: a) cout << x << ' '; err(args...); }
// Scalar overload: print the value followed by a space, then the rest.
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
// Capacity for per-vertex / per-colour / per-query arrays.
const int maxn = 1E5 + 100;
// w[u]: colour of vertex u.
int w[maxn];
// Adjacency-list edge: neighbour `to`, edge length `d`.
struct E {
int to, d;
};
// G[u]: undirected adjacency list of vertex u.
vector<E> G[maxn];
// c[col]: vertices of colour col; GG[u]: children of u once the tree is rooted at 1 (filled by dfs).
vector<int> c[maxn], GG[maxn];
// A deferred ("heavy") query: the other colour c and where to write its answer.
struct Q {
int c;
LL* ans;
};
// qc[col]: heavy queries answered by calc(col, ...).
vector<Q> qc[maxn];
// ans[i]: answer of the i-th query (before halving for equal colours).
LL ans[maxn];
// n: number of vertices; tp: vertices ordered by unweighted depth, 1-based (tree::rsort);
// pa[u]: parent of u in the rooted tree.
int n, tp[maxn], pa[maxn];

// Euler tour data (filled by dfs): dep[i] / vs[i] = unweighted depth / vertex at tour
// position i; clk = tour length; fd[v] = length of the edge from v to its parent;
// tf_dep[u] = unweighted depth of u.
int dep[maxn << 1], fd[maxn], vs[maxn << 1], clk, tf_dep[maxn];
// dd[u]: weighted distance from the root; cds[col]: sum of dd over vertices of colour col.
LL dd[maxn], cds[maxn];
int id[maxn]; // id[u]: first position of u in the Euler tour
void dfs(int u, int fa, int fake_d = 0, LL d = 0) {
dd[u] = d; cds[w[u]] += d;
vs[clk] = u; id[u] = clk;
tf_dep[u] = fake_d;
dep[clk++] = fake_d;
for (E& e: G[u]) {
int v = e.to;
if (v == fa) continue;
GG[u].push_back(v); pa[v] = u;
fd[v] = e.d;
dfs(v, u, fake_d + 1, d + e.d);
vs[clk] = u; dep[clk++] = fake_d;
}
}

namespace tree {

pair<int, int> f[maxn << 1][20];
inline int highbit(int x) { return 31 - __builtin_clz(x); }
// arr has to be 0 based
void rmq_init(int *arr, int length) {
for (int x = 0; x <= highbit(length); ++x)
for (int i = 0; i <= length - (1 << x); ++i) {
if (!x) f[i][x] = {arr[i], i};
else f[i][x] = min(f[i][x - 1], f[i + (1 << (x - 1))][x - 1]);
}
}
pair<int, int> query(int x, int y) {
int p = highbit(y - x + 1);
return min(f[x][p], f[y - (1 << p) + 1][p]);
}

inline int lca(int u, int v) {
return vs[query(min(id[u], id[v]), max(id[u], id[v])).second];
}

int c[maxn] = {1};
void rsort() {
FOR (i, 1, n + 1) ++c[tf_dep[i]];
FOR (i, 1, n + 1) c[i] += c[i - 1];
FOR (i, 1, n + 1) tp[--c[tf_dep[i]]] = i;
}
}

LL sum[maxn], csum[maxn];
int ccnt[maxn];
void calc(int cc, vector<Q>& qr) {
memset(csum, 0, sizeof csum); memset(ccnt, 0, sizeof ccnt);
FORD (i, n, 0) {
int u = tp[i];
ccnt[u] = w[u] == cc;
for (int& v: GG[u]) ccnt[u] += ccnt[v];
sum[u] = 1LL * ccnt[u] * fd[u];
}
FOR (i, 1, n + 1) {
int u = tp[i];
sum[u] += sum[pa[u]];
csum[w[u]] += sum[u];
}
for (Q& q: qr) *q.ans = -2 * csum[q.c] + cds[q.c] * c[cc].size() + cds[cc] * c[q.c].size();
}

bool d2[maxn];
int main() {
cin >> n;
FOR (i, 1, n + 1) {
scanf("%d", &w[i]);
c[w[i]].push_back(i);
}
FOR (_, 1, n) {
int u, v, d; scanf("%d%d%d", &u, &v, &d);
G[u].push_back({v, d}); G[v].push_back({u, d});
}
dfs(1, 0); tree::rmq_init(dep, clk); tree::rsort();
int Qn; cin >> Qn;
FOR (i, 0, Qn) {
int x, y; scanf("%d%d", &x, &y);
if (c[x].size() > c[y].size()) swap(x, y);
if (c[x].size() * c[y].size() > 3000) {
qc[y].push_back({x, &ans[i]});
} else {
LL s = 0;
for (int& u: c[x])
for (int& v: c[y])
s += dd[u] + dd[v] - 2 * dd[tree::lca(u, v)];
ans[i] = s;
}
d2[i] = x == y;
}
FOR (i, 1, n + 1) if (!qc[i].empty()) calc(i, qc[i]);
FOR (i, 0, Qn) printf("%lld\n", ans[i] / (d2[i] + 1));
}

方法2

假设有算法能在 \(O(\log n)\) 的时间内求出一个结点到某种颜色所有结点的距离和(用树上点分治即可做到)。那么利用启发式,可以在 \(O(n^{1.5}\log n)\) 的复杂度内完成所有询问:让结点数较少的颜色中的每一个结点去查询另一种颜色,并对重复的询问做记忆化。但是很遗憾,不知道为什么就是过不了(90%,TLE)。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// FOR(i, x, y): i goes from x up to y - 1; y is evaluated once and i gets y's (decayed) type.
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
// FORD(i, x, y): i goes from x down to y + 1; y is evaluated once and i gets x's (decayed) type.
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
// dbg(args...) prints the argument text and values through err(), wrapped in ANSI colour
// escapes, when the macro zerol is defined; otherwise it expands to nothing.
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
// Debug printer behind dbg(...). Every overload is declared up front so each one can
// recurse into any other regardless of argument order. Without this, err(vec, 1) fails:
// inside the container overload the later scalar overload is not visible to ordinary
// lookup, and ADL finds nothing for fundamental types such as int.
void err();
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args);
template<typename T, typename... Args>
void err(T a, Args... args);

// Base case: reset the terminal colour and end the line.
void err() { cout << "\033[39;0m" << endl; }
// Container overload: print every element followed by a space, then the rest.
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (const auto& x: a) cout << x << ' '; err(args...); }
// Scalar overload: print the value followed by a space, then the rest.
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
// Capacity for per-vertex / per-colour / per-query arrays; INF seeds the centroid search in dfs.
const int maxn = 1E5 + 100, INF = 1E9;
// Adjacency-list edge: neighbour `to`, edge length `d`.
struct E {
int to, d;
};
// G[u]: undirected adjacency list of vertex u.
vector<E> G[maxn];
// w[u]: colour of vertex u.
int w[maxn];
// c[col]: vertices of colour col.
vector<int> c[maxn];

// vis[u]: u has already been chosen as a centroid; later traversals do not cross it.
bool vis[maxn];
// sz[u]: subtree size of u inside the current component (filled by get_sz).
int sz[maxn];
int get_sz(int u, int fa) {
int& s = sz[u] = 1;
for (E& e: G[u]) {
int v = e.to;
if (vis[v] || v == fa) continue;
s += get_sz(v, u);
}
return s;
}

void get_rt(int u, int fa, int s, int& m, int& rt) {
int t = s - sz[u];
for (E& e: G[u]) {
int v = e.to;
if (vis[v] || v == fa) continue;
get_rt(v, u, s, m, rt);
t = max(t, sz[v]);
}
if (t < m) { m = t; rt = u; }
}

LL dep[maxn], md[maxn];
void get_dep(int u, int fa, LL d) {
dep[u] = d; md[u] = 0;
for (E& e: G[u]) {
int v = e.to;
if (vis[v] || v == fa) continue;
get_dep(v, u, d + e.d);
md[u] = max(md[u], md[v] + 1);
}
}

// Id type of an accumulation bucket: one per centroid plus one per child subtree of a centroid.
using D = int;
// One record of a vertex: bucket id `rt`, sign `sgn` (+1 for a centroid's own bucket,
// -1 for a child-subtree bucket), and `dep`, the vertex's distance to that centroid.
struct R {
D rt;
int sgn;
LL dep;
};
// Next unused bucket id (ids start at 1).
D pit = 1;
// tr[u]: all bucket records of vertex u across the decomposition levels.
vector<R> tr[maxn];

void go(int u, int fa, D rt, D rt2) {
tr[u].push_back({rt, 1, dep[u]});
tr[u].push_back({rt2, -1, dep[u]});
for (E& e: G[u]) {
int v = e.to;
if (v == fa || vis[v]) continue;
go(v, u, rt, rt2);
}
}

void dfs(int u) {
int tmp = INF; get_rt(u, -1, get_sz(u, -1), tmp, u);
vis[u] = true;
get_dep(u, -1, 0);
D rt = pit++; tr[u].push_back({rt, 1, 0});
for (E& e: G[u]) {
int v = e.to;
if (vis[v]) continue;
go(v, u, rt, pit++);
dfs(v);
}
}

// Answer slots, one per distinct (c1, c2) pair; ans_p is the next free slot.
LL ans[maxn], *ans_p = ans;
// Memo: ordered colour pair (c1, c2) -> its answer slot, so repeated queries share work.
map<pair<int, int>, LL*> mp;
// A query after normalisation (c1 has no more vertices than c2); ans points to its slot.
struct Q {
int c1, c2;
LL* ans;
};
// query[i]: the i-th input query.
Q query[maxn];
// qc[c2]: distinct queries whose larger colour is c2, answered by solve(c2, ...).
vector<Q*> qc[maxn];

void solve(int c2, vector<Q*>& ask) {
static LL sum[maxn << 1];
static int cnt[maxn << 1];
for (int& u: c[c2])
for (R& x: tr[u]) {
sum[x.rt] += x.dep; ++cnt[x.rt];
}
for (Q*& q: ask) {
LL& s = *q->ans;
for (int& u: c[q->c1])
for (R& x: tr[u]) {
s += x.sgn * (sum[x.rt] + cnt[x.rt] * x.dep);
}
}
for (int& u: c[c2])
for (R& x: tr[u])
sum[x.rt] = cnt[x.rt] = 0;
}

int main() {
int n; cin >> n;
FOR (i, 1, n + 1) {
scanf("%d", &w[i]);
c[w[i]].push_back(i);
}
FOR (_, 1, n) {
int u, v, d; scanf("%d%d%d", &u, &v, &d);
G[u].push_back({v, d}); G[v].push_back({u, d});
}
dfs(1);
int Qn; cin >> Qn;
FOR (i, 0, Qn) {
int c1, c2; scanf("%d%d", &c1, &c2);
if (c[c1].size() > c[c2].size()) swap(c1, c2);
query[i] = {c1, c2, nullptr};
auto it = mp.find({c1, c2});
if (it != mp.end()) query[i].ans = it->second;
else {
mp[{c1, c2}] = query[i].ans = ans_p++;
qc[c2].push_back(&query[i]);
}
}
FOR (i, 1, n + 1) if (!qc[i].empty()) solve(i, qc[i]);
FOR (i, 0, Qn) printf("%lld\n", *query[i].ans / (query[i].c1 == query[i].c2 ? 2 : 1));
}