【题解】HackerRank - Coprime Paths

题目

https://www.hackerrank.com/challenges/coprime-paths/problem

题意

给一棵树,每个点上有一个数,且该数至多有三个不同的素因子。每次询问给出两个点,求这两点之间路径上(含端点)互质的点对个数。

题解

  • 树上莫队 + 容斥
  • 用的树上莫队板子比较少见:移动端点时直接把新旧端点之间路径上的点(含 LCA)逐个翻转,这样恰好会多翻转一个点(新旧路径的交汇点),bug 变量记录的就是这个点,最后再把它单独翻转回来。好处是不用求 LCA,也不用对它单独处理。
  • 统计时加的是 cnt - 1(此时 cnt 已经包含该点自身),这样就排除了点与自身配对的情况;值为 1 的点没有素因子,根本不会被统计,因此不用特判。
  • 容斥部分代码重用有些麻烦。

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
// FOR(i, x, y): loop i upward over [x, y). The bound y is evaluated once (stored
// in _i), and i has the decayed type of y.
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
// FORD(i, x, y): loop i downward over (y, x]. The bound y is evaluated once, and
// i has the decayed type of x.
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
// dbg(args...): when compiled with -Dzerol, prints the argument expressions as
// text, then their values via err(), in bold green (ANSI escape). Otherwise it
// expands to nothing.
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args<< " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
// Debug printer behind dbg(): writes each value followed by a single space, then
// resets the terminal colour and ends the line (std::endl, so it also flushes).
void err() { cout << "\033[39;0m" << endl; }
template<typename Head, typename... Rest>
void err(Head first, Rest... others) {
    cout << first << ' ';
    err(others...);
}
// -----------------------------------------------------------------------------
// Capacity of node-indexed arrays (and of ans[], which is indexed by query
// number). NOTE(review): assumes n and q are both at most about 25000; confirm
// against the problem limits.
const int maxn = 25E3 + 100;
// Size of cnt[]. cnt is indexed by products of distinct prime factors of a node
// value (see add()). Each such product divides the value, so this assumes values
// are at most about 1e7; confirm.
const int maxp = 1E7 + 10;
// v: node value; dep: depth (root 1 has depth 0); fa: parent (fa[1] stays 0);
// blk: block id used for Mo ordering; in: DFS preorder index. All filled by dfs()/main.
int v[maxn], dep[maxn], fa[maxn], blk[maxn], in[maxn];
// G: adjacency lists; p[i]: distinct prime factors of v[i].
vector<int> G[maxn], p[maxn];
// B: block size (sqrt n); ans[i]: answer to the i-th query; bug: node flipped
// one time too many during mv() (0 = not found yet); Ans: number of pairs in the
// current set that share a prime factor; L: number of nodes in the current set.
int B, ans[maxn], bug, Ans, L;
// cnt[d]: number of nodes in the current set whose value is divisible by d,
// maintained only for products d of distinct prime factors (see add()).
int cnt[maxp];
// vis[k]: node k is currently in the set (i.e. on the current path).
bool vis[maxn];

// One offline query: path endpoints (u, v) and idx, its position in the input.
// Mo's ordering: primarily by the block of u, then by the DFS preorder index of v.
struct Q {
    int u, v, idx;
    bool operator < (const Q& rhs) const {
        return make_pair(blk[u], in[v]) < make_pair(blk[rhs.u], in[rhs.v]);
    }
};
vector<Q> query;

// Linear (Euler) sieve. After get_prime(), prime[0..p_sz) holds every prime below
// p_max in increasing order, and p_vis[c] is true for composites c < p_max.
// Each composite is marked exactly once, by its smallest prime factor.
const int p_max = 1E5 + 100;
int prime[p_max], p_sz;
bool p_vis[p_max];
void get_prime() {
    for (int x = 2; x < p_max; ++x) {
        if (!p_vis[x]) prime[p_sz++] = x;
        for (int k = 0; k < p_sz && prime[k] * x < p_max; ++k) {
            p_vis[prime[k] * x] = true;
            // Stop once prime[k] divides x: larger primes would not be the
            // smallest factor of prime[k] * x.
            if (x % prime[k] == 0) break;
        }
    }
}

// DFS from the root (node 1). It fills dep[] (depth), fa[] (parent) and in[]
// (preorder index), and partitions the nodes into blocks for Mo's ordering.
// Each node is pushed onto stack S only after its children are processed. Whenever
// the nodes pushed during this node's child subtrees (those above btm) number at
// least B, they are popped into a new block. Whatever is still on the stack when the
// root finishes joins block blk_cnt - 1; that is -1 if no block was ever closed,
// e.g. n == 1.
// All state is static, so this is meant to be called once. Recursion depth equals
// the tree height (up to n on a path-shaped tree).
void dfs(int u = 1, int d = 0) {
static int clk = 0, S[maxn], sz = 0, blk_cnt = 0;
dep[u] = d;
in[u] = clk++;
// Stack height before this node's children; only nodes above it may form a block here.
int btm = sz;
for (int v: G[u]) {
if (v == fa[u]) continue;
fa[v] = u;
dfs(v, d + 1);
if (sz - btm >= B) {
while (sz > btm) blk[S[--sz]] = blk_cnt;
++blk_cnt;
}
}
S[sz++] = u;
if (u == 1) while (sz) blk[S[--sz]] = blk_cnt - 1;
}

// Apply one inclusion–exclusion term for the divisor t.
// If cnt_upd == 0, add sgn * (cnt[t] - 1) to Ans. The -1 skips the node being
// processed, which the caller has already counted in cnt[t].
// Otherwise, only shift cnt[t] by cnt_upd.
inline void up(int t, int sgn, int cnt_upd) {
    if (!cnt_upd) {
        Ans += sgn * (cnt[t] - 1);
        return;
    }
    cnt[t] += cnt_upd;
}

// Inclusion–exclusion over the distinct prime factors of v[k].
// For every non-empty subset S of p[k], let d be the product of S. Call
// up(d, +sgn, cnt_upd) when |S| is odd and up(d, -sgn, cnt_upd) when |S| is even.
//  - cnt_upd != 0: shifts each cnt[d] by cnt_upd (k entering or leaving the set).
//  - cnt_upd == 0: adds sgn * (number of other set members whose value shares a
//    prime with v[k]) to Ans. flip() calls this while k is counted in cnt, and
//    up() subtracts 1 per term to exclude k itself.
// The previous version hard-coded the expansion for at most 3 factors and
// miscounted values with more. For |p[k]| <= 3 this issues the same set of up()
// calls. The order differs, which is harmless because the divisors are distinct.
// Every d divides v[k], so it cannot overflow int.
void add(int k, int sgn, int cnt_upd = 0) {
    const vector<int>& f = p[k];
    const int m = int(f.size());
    for (int mask = 1; mask < (1 << m); ++mask) {
        int d = 1;
        bool odd = false;
        for (int i = 0; i < m; ++i)
            if (mask >> i & 1) {
                d *= f[i];
                odd = !odd;
            }
        up(d, odd ? sgn : -sgn, cnt_upd);
    }
}

// Toggle node k's membership in the current set, keeping cnt, Ans and L in sync.
// Ans is always adjusted while k is counted in cnt; that is why up() uses cnt - 1.
void flip(int k) {
    if (vis[k]) {
        add(k, -1);      // drop k's non-coprime pairs while k is still in cnt
        add(k, 0, -1);   // then remove k from cnt
        --L;
    } else {
        add(k, 0, 1);    // put k into cnt first
        add(k, 1);       // then count its non-coprime pairs with the others
        ++L;
    }
    vis[k] = !vis[k];
}

// Flip node k (toggle its membership in the current set), then move k up to its
// parent. mv() walks paths with this.
// While the extra-flipped node is still unknown (bug == 0), look for the point
// where membership changes between k and its parent:
//  - k in the set, parent not:  bug = k
//  - k not in the set, parent in: bug = fa[k]
// Both checks run before k is flipped, so they see the pre-walk state of k.
// At the root, fa[k] is 0, so vis[0] is read.
void go(int& k) {
if (!bug) {
if (vis[k] && !vis[fa[k]]) bug = k;
if (!vis[k] && vis[fa[k]]) bug = fa[k];
}
flip(k);
k = fa[k];
}

// Move one endpoint of the current path from a to b. Afterwards the set (vis, cnt,
// Ans, L) is the path from b to the other, unchanged endpoint.
// Assumes a is an endpoint of the current path, as main() maintains.
// Every node of path(a, b), including their LCA (the final go(a) once a == b), is
// flipped exactly once. That is one flip too many: the node where path(a, b)
// leaves the current path must stay in the set. That node is the last in-set node
// when walking from a toward b. go() records it in `bug`, and go(bug) flips it back.
// If b is already in the set, that node is b itself, hence the initial check.
void mv(int a, int b) {
bug = 0;
if (vis[b]) bug = b;
if (dep[a] < dep[b]) swap(a, b);
while (dep[a] > dep[b]) go(a);
while (a != b) go(a), go(b);
go(a); go(bug);
}

int main() {
#ifdef zerol
freopen("in", "r", stdin);
#endif
get_prime();
int n, q_sz;
cin >> n >> q_sz; B = int(sqrt(n));
FOR (i, 1, n + 1) {
scanf("%d", &v[i]);
int t = v[i];
FOR (j, 0, p_sz)
if (t % prime[j] == 0) {
p[i].push_back(prime[j]);
while (t % prime[j] == 0) t /= prime[j];
}
if (t != 1) p[i].push_back(t);
}
FOR (_, 1, n) {
static int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v); G[v].push_back(u);
}
dfs();
FOR (i, 1, n + 1) dbg(i, in[i], blk[i], p[i].size(), v[i], fa[i]);
FOR (i, 0, q_sz) {
static int u, v;
scanf("%d%d", &u, &v);
query.push_back({u, v, i});
}
sort(query.begin(), query.end());
int u = 1, v = 1; flip(1);
for (Q& q: query) {
mv(u, q.u); u = q.u;
mv(v, q.v); v = q.v;
dbg(q.idx, q.u, q.v, L, Ans);
ans[q.idx] = L * (L - 1) / 2 - Ans;
}
FOR (i, 0, q_sz) printf("%d\n", ans[i]);
}