【题解】HackerRank - Counting On a Tree

题目

https://www.hackerrank.com/challenges/counting-on-a-tree

题意

给一棵树,树上的每一个点都有一个数字,询问对于给定的两条树上路径,有多少对不同的且分别属于两条路径的点满足对应数字相等。

题解

  • 题目要求的是不同的点,先不考虑这个条件,最后减去路径交的长度即可。
  • 树上的任意路径可以分解成两条链(后文称之为规范链),使得每条链中的一个点是另一个点的祖先,这样便于计数。
  • 设路径 x 和 y 分别分解成 x1, x2 和 y1, y2,最后的答案就是把 (x1, y1), (x2, y1), (x1, y2), (x2, y2) 的答案加起来
  • 对于每一种颜色,按是否大于点数量的平方根,分为两种情况
    • 如果颜色数量大于 SQRT
      • 对于每一种这样的颜色,跑一边 dfs,计算出每个点到根的路径上有多少这种颜色的点(前缀和)。
      • 然后对于每一条规范链,可以做差得到链上这种颜色的点有几个。
      • 乘起来就是两条规范链的同色点对数了。
    • 如果颜色数量不大于 SQRT,设这样的颜色属于 集合c。
      • 首先预处理时记录 dfs 序
      • 在 dfs 过程中维护一个树状数组,记录的是树上的每一个点到根的路径上 颜色属于集合x 的点的个数,集合 x 为当前点到根的路径上出现过得所有属于 集合c 的颜色(可重复)。
      • 访问一个结点时,把所有相同颜色的点对应的 dfs序 区间 +1 即可。
      • 注意需要对询问预处理,分解询问分发到每个结点
      • 注意由于 dfs 时树状数组状态要回溯,所以要回滚操作
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args<< " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 1E5 + 100;
const int SQRT = 100;
struct L {
int bt, tp;
};
struct Q {
L l;
int sgn, idx;
};
vector<Q> qq[maxn];
L query[maxn][4];
int n, clk = 1;
int in[maxn], out[maxn], dep[maxn], pa[maxn][20], a[maxn], cnt[maxn];
LL ans[maxn];
vector<int> G[maxn];
vector<int> C[maxn];

struct BIT {
int c[maxn];
void init() { memset(c, 0, sizeof c); }
inline int lowbit(int x) { return x & -x; }
void add(int x, int v) {
for (int i = x; i <= n; i += lowbit(i))
c[i] += v;
}
int sum(int x) {
int ret = 0;
for (int i = x; i > 0; i -= lowbit(i))
ret += c[i];
return ret;
}
void add(int l, int r, int v) {
add(l, v); add(r + 1, -v);
}
} bit;


void dfs(int u, int fa, int d) {
in[u] = clk++;
pa[u][0] = fa;
dep[u] = d;
for (int v: G[u])
if (v != fa)
dfs(v, u, d + 1);
out[u] = clk - 1;
}

void lca_init() {
FOR (x, 1, 20)
FOR (i, 1, n + 1)
pa[i][x] = pa[pa[i][x - 1]][x - 1];
}

inline int pp(int x) {
if (x == 1) return 0;
return pa[x][0];
}
int lca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
FORD (i, 19, -1) {
int uu = pa[u][i];
if (dep[uu] >= dep[v]) u = uu;
}
if (u == v) return u;
FORD (i, 19, -1) {
int uu = pa[u][i], vv = pa[v][i];
if (uu != vv) { u = uu; v = vv; }
}
return pp(u);
}

int intersection(int x, int y, int xx, int yy) {
int t[4] = {lca(x, xx), lca(x, yy), lca(y, xx), lca(y, yy)};
sort(t, t + 4);
int r = lca(x, y), rr = lca(xx, yy);
if (dep[t[0]] < min(dep[r], dep[rr]) || dep[t[2]] < max(dep[r], dep[rr]))
return 0;
int tt = lca(t[2], t[3]);
int ret = 1 + dep[t[2]] + dep[t[3]] - dep[tt] * 2;
return ret;
}

void dfs_calc(int u) {
int c = a[u];
if (C[c].size() < SQRT)
for (int v: C[c])
bit.add(in[v], out[v], 1);
for (Q& q: qq[u])pingfangfenge
ans[q.idx] += q.sgn * (bit.sum(in[q.l.bt]) - bit.sum(in[q.l.tp]));
for (int v: G[u])
if (v != pp(u))
dfs_calc(v);
if (C[c].size() < SQRT)
for (int v: C[c])
bit.add(in[v], out[v], -1);
}

void dfs_cnt(int u, int c, int s) {
if (a[u] == c) ++s;
cnt[u] = s;
for (int v: G[u])
if (v != pp(u))
dfs_cnt(v, c, s);
}

int main() {
#ifdef zerol
freopen("in", "r", stdin);
#endif
int q_sz, mp_sz = 0;
map<int, int> mp;
cin >> n >> q_sz;
FOR (i, 1, n + 1) {
scanf("%d", &a[i]);
auto it = mp.find(a[i]);
if (it == mp.end()) a[i] = mp[a[i]] = mp_sz++;
else a[i] = it->second;
C[a[i]].push_back(i);
}
FOR (_, 1, n) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 1, 1);
lca_init();
FOR (i, 0, q_sz) {
int x, y, xx, yy;
scanf("%d%d%d%d", &x, &y, &xx, &yy);
ans[i] -= intersection(x, y, xx, yy);
int z = lca(x, y), zz = lca(xx, yy);
query[i][0] = {x, z};
query[i][1] = {y, pp(z)};
query[i][2] = {xx, zz};
query[i][3] = {yy, pp(zz)};
FOR (p, 0, 2)
FOR (q, 2, 4) {
qq[query[i][p].bt].push_back({query[i][q], 1, i});
qq[query[i][p].tp].push_back({query[i][q], -1, i});
}
}
FOR (i, 0, mp_sz) {
if (C[i].size() < SQRT) continue;
memset(cnt, 0, sizeof cnt);
dfs_cnt(1, i, 0);
FOR (j, 0, q_sz) {
int c1 = 0, c2 = 0;
FOR (k, 0, 2) c1 += cnt[query[j][k].bt] - cnt[query[j][k].tp];
FOR (k, 2, 4) c2 += cnt[query[j][k].bt] - cnt[query[j][k].tp];
ans[j] += c1 * c2;
}
}
dfs_calc(1);
FOR (i, 0, q_sz)
printf("%lld\n", ans[i]);
}

题解(假)

这个题解是官方的 editorial,但是我的实现会超时,但答案是对的。另外,这个算法时在线的,但难写很多,常数也大很多。总结一下,这个算法假得很,分块和不分块差不多。

  • 对于询问的路径,剖分成若干条链。
  • 把树上点值按树链剖分的下标数组进行平方分割。
  • 预处理一个 n × sqrt(n) 的数组,表示第 i 个元素在第 j 块中出现了几次。
  • 对于完整的块和连续的一段可以做到 O(1) 查询
  • 对于零散的,数组大小不超过 SQRT * log(n),进行桶排线性查询。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args<< " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 1E5 + 10;
const int M = 300;
const int B = 334;
int son[maxn], sz[maxn], fa[maxn], idx[maxn], dep[maxn], top[maxn];
int w[maxn], a[maxn], c[maxn][B + 2];
vector<int> G[maxn];
int clk;

void predfs(int u, int d) {
dep[u] = d;
int& maxs = son[u] = -1;
sz[u] = 1;
for (int v: G[u])
if (v != fa[u]) {
fa[v] = u;
predfs(v, d + 1);
sz[u] += sz[v];
if (maxs == -1 || son[maxs] < son[v])
maxs = v;
}
}

void dfs(int u, int tp) {
top[u] = tp;
idx[u] = ++clk;
w[clk - 1] = a[u];
if (son[u] == -1) return;
dfs(son[u], tp);
for (int v: G[u])
if (v != fa[u] && v != son[u])
dfs(v, v);
}

inline int sum(int l, int r, int bl, int br) {
l--; r--; bl--; br--;
int ret = c[r][br];
if (l) ret -= c[l - 1][br];
if (bl) ret -= c[r][bl - 1];
if (l && bl) ret += c[l - 1][bl - 1];
return ret;
}


struct P {
int l, r;
};

typedef vector<P> VI;
void divide(int l, int r, VI& b, VI& t) {
int x = l / M, y = r / M;
if (x == y) t.push_back({l, r});
else {
t.push_back({l, (x + 1) * M - 1});
t.push_back({y * M, r});
if (x + 1 <= y - 1) b.push_back({x + 1, y - 1});
}
}

void query(int u, int v, VI& b, VI& t) {
b.clear(); t.clear();
int uu = top[u], vv = top[v];
while (uu != vv) {
if (dep[uu] < dep[vv]) { swap(uu, vv); swap(u, v); }
divide(idx[uu], idx[u], b, t);
u = fa[uu];
uu = top[u];
};
if (dep[u] < dep[v]) { swap(uu, vv); swap(u, v); }
divide(idx[v], idx[u], b, t);
}

int go(VI& b, VI& t, VI& bb, VI& tt) {
int ans = 0;
for (P& p: b)
for (P& q: bb) {
int l = p.l * M, r = (p.r + 1) * M - 1;
ans += sum(l, r, q.l, q.r);
}
for (P& p: b)
for (P& q: tt)
ans += sum(q.l, q.r, p.l, p.r);
for (P& p: bb)
for (P& q: t)
ans += sum(q.l, q.r, p.l, p.r);
static int cnt[maxn];
memset(cnt, 0, sizeof cnt);
for (P& p: t)
FOR (i, p.l, p.r + 1)
cnt[w[i - 1]]++;
for (P& p: tt)
FOR (i, p.l, p.r + 1)
ans += cnt[w[i - 1]];
return ans;
}

int lca(int u, int v) {
int uu = top[u], vv = top[v];
while (uu != vv) {
if (dep[uu] < dep[vv]) { swap(u, v); swap(uu, vv); }
u = fa[uu];
uu = top[u];
}
if (dep[u] < dep[v]) return u; else return v;
}

int intersection(int x, int y, int xx, int yy) {
int t[4] = {lca(x, xx), lca(x, yy), lca(y, xx), lca(y, yy)};
sort(t, t + 4);
int r = lca(x, y), rr = lca(xx, yy);
if (dep[t[0]] < min(dep[r], dep[rr]) || dep[t[2]] < max(dep[r], dep[rr]))
return 0;
int tt = lca(t[2], t[3]);
int ret = 1 + dep[t[2]] + dep[t[3]] - dep[tt] * 2;
return ret;
}

int cnt[maxn];
int main() {
int n, Q, u, v, ttt = 0;
map<int, int> mp;
cin >> n >> Q;
FOR (i, 1, n + 1) {
scanf("%d", &a[i]);
auto it = mp.find(a[i]);
if (it == mp.end()) a[i] = mp[a[i]] = ++ttt;
else a[i] = it->second;
}
FOR (i, 1, n) {
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
predfs(1, 1); dfs(1, 1);
FOR (b, 0, B) {
memset(cnt, 0, sizeof cnt);
FOR (i, b * M, (b + 1) * M)
if (i >= n) break;
else cnt[w[i]]++;
FOR (i, 0, n)
c[i][b] = cnt[w[i]];
}
FOR (i, 0, n)
FOR (b, 0, B) {
int& t = c[i][b];
if (i) t += c[i - 1][b];
if (b) t += c[i][b - 1];
if (i && b) t -= c[i - 1][b - 1];
}
while (Q--) {
static int x, y, xx, yy;
scanf("%d%d%d%d", &x, &y, &xx, &yy);
VI b, t, bb, tt;
query(x, y, b, t);
query(xx, yy, bb, tt);
printf("%d\n", go(b, t, bb, tt) - intersection(x, y, xx, yy));
}
}