【题解】HackerRank - Unique Colors

题目

https://www.hackerrank.com/challenges/unique-colors

题意

给定一棵树,树上每一个点都有一种颜色。定义一条路径的价值为路径上不同颜色的点的个数。对于树上的每个结点,输出从该结点出发的所有路径的价值之和。

题解

官方题解是用树上点分治,但是很遗憾题解看不懂。
神奇的是这道题和 2017 年多校第一场的某题极度相似(HDU 6035)。

多校官方题解:

单独考虑每一种颜色,答案就是对于每种颜色至少经过一次这种的路径条数之和。反过来思考只需要求有多少条路径没有经过这种颜色即可。直接做可以采用虚树的思想(不用真正建出来),对每种颜色的点按照 dfs 序列排个序,就能求出这些点把原来的树划分成的块的大小。这个过程实际上可以直接一次 dfs 求出。

  • 每种颜色分开考虑,计算每种颜色的贡献。
  • 考虑问题的反面,某种颜色没有贡献给一条路径等价于路径的两个端点在同一个没有该种颜色的联通块中。也就是说,对于每一块没有这种颜色的树上联通块,联通块中所有点的答案减去联通块的大小。
  • 难点在于对于每种颜色,复杂度不能与整棵树的大小相关,而是与该颜色结点个数相关。但是访问结点以及递归必须按顺序进行,于是需要按照 dfs序 进行递归。这就是多校题解中提到的虚树思想,当然这道题不用把树建出来。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args<< " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 1E5 + 100, maxc = 1E5 + 100;
int c[maxn], in[maxn], out[maxn], sz[maxn], pa[maxn];
vector<int> G[maxn], C[maxn];
int n, clk = 0;

struct BIT {
LL c[maxn];
inline int lowbit(int x) { return x & -x; }
void add(int x, int k) {
for (int i = x; i <= n; i += lowbit(i))
c[i] += k;
}
LL sum(int x) {
LL ret = 0;
for (int i = x; i > 0; i -= lowbit(i))
ret += c[i];
return ret;
}
void add(int l, int r, int v) {
add(l, v); add(r + 1, -v);
}
} bit;

void init_dfs(int u, int fa) {
in[u] = clk++;
C[c[u]].push_back(u);
pa[u] = fa;
sz[u] = 1;
for (int v: G[u]) {
if (v == fa) continue;
init_dfs(v, u);
sz[u] += sz[v];
}
out[u] = clk - 1;
}

void go(const vector<int>& V, int& k) {
vector<int> nxt;
int u = V[k];
for (int v: G[u]) {
if (v == pa[u]) continue;
int num = sz[v]; nxt.clear();
while (k + 1 < V.size()) {
int to = V[k + 1];
if (in[to] <= out[v]) {
nxt.push_back(to);
num -= sz[to];
go(V, ++k);
} else break;
}
bit.add(in[v], out[v], num);
for (int to: nxt)
bit.add(in[to], out[to], -num);
}
}

int main() {
#ifdef zerol
freopen("in", "r", stdin);
#endif
cin >> n;

FOR (i, 1, maxc) C[i].push_back(0);
G[0].push_back(1); pa[1] = 0; c[0] = 0;

FOR (i, 1, n + 1)
scanf("%d", &c[i]);
FOR (_, 1, n) {
static int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
init_dfs(0, -1);

int tmp;
LL color_cnt = 0;
FOR (k, 1, maxc) {
if (C[k].size() == 1) continue;
color_cnt++;
go(C[k], tmp = 0);
}
FOR (i, 1, n + 1)
printf("%lld\n", color_cnt * n - bit.sum(in[i]));
}