1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
#include <bits/stdc++.h>
using namespace std;
using LL = long long;

// Counting loops. The bound is evaluated once and cached in `_<var>`.
// FOR:  var runs lo, lo+1, ..., hi-1   (type taken from `hi`)
// FORD: var runs hi, hi-1, ..., lo+1   (type taken from `hi`)
#define FOR(var, lo, hi) for (decay<decltype(hi)>::type var = (lo), _##var = (hi); var < _##var; ++var)
#define FORD(var, hi, lo) for (decay<decltype(hi)>::type var = (hi), _##var = (lo); var > _##var; --var)

// Debug tracing: active only when compiled with -Dzerol; otherwise expands to nothing.
#ifdef zerol
#define dbg(...) do { cout << "\033[32;1m" << #__VA_ARGS__ << " -> "; err(__VA_ARGS__); } while (0)
#else
#define dbg(...)
#endif

// End of a debug line: emits the ANSI escape that restores default colours, then flushes.
void err() { cout << "\033[39;0m" << endl; }

// Debug printer for a single-parameter container template: prints each element.
// NOTE: declared before the generic overload on purpose; the recursive calls
// below only see the overloads declared above them.
template<template<typename...> class Container, typename Elem, typename... Rest>
void err(Container<Elem> range, Rest... rest) {
    for (auto item : range) cout << item << ' ';
    err(rest...);
}

// Debug printer for any other streamable value.
template<typename Head, typename... Rest>
void err(Head head, Rest... rest) {
    cout << head << ' ';
    err(rest...);
}

// -----------------------------------------------------------------------------

const int N = 1E5 + 100;      // capacity of every per-vertex / per-query array
const int B = 330;            // Mo's block size (roughly sqrt(N))

vector<int> G[N];             // adjacency lists
int c[N];                     // colour of each vertex
int clk;                      // Euler-tour timestamp counter
int in[N];                    // entry timestamp of a vertex
int out[N];                   // last timestamp inside a vertex's subtree
int rin[N];                   // timestamp -> vertex (inverse of in[])

// One offline query: half-open Euler range [l, r) and its position in the input.
struct P { int l, r, idx; };
P query[N];

int ans[N];                   // answer per original query index
int cnt[N];                   // multiplicity of each colour in the current window
// Euler tour of the tree rooted at u; fa is u's parent (-1 for the root).
// Issues pre-order timestamps from the global counter clk:
//   in[v]  - timestamp at which v is entered,
//   out[v] - largest timestamp inside v's subtree, so subtree(v) = [in[v], out[v]],
//   rin[t] - vertex entered at timestamp t (inverse of in[]).
// Uses an explicit stack rather than recursion. On a path-shaped tree the
// recursive form nests one call per vertex, which can overflow the call stack
// where the default stack is small. Neighbours are scanned in the same order
// as the recursive form, and every neighbour equal to the parent is skipped,
// so the numbering is identical.
void dfs(int u, int fa) {
    // One frame per vertex on the current root-to-vertex path.
    struct Frame {
        int v;              // vertex being expanded
        int parent;         // its parent (skipped while scanning neighbours)
        std::size_t next;   // index of the next neighbour of v to examine
    };
    std::vector<Frame> path;
    in[u] = ++clk;
    rin[clk] = u;
    path.push_back({u, fa, 0});
    while (!path.empty()) {
        Frame& top = path.back();
        if (top.next < G[top.v].size()) {
            const int w = G[top.v][top.next++];
            if (w == top.parent) continue;
            const int from = top.v;  // copy now: push_back may reallocate and invalidate `top`
            in[w] = ++clk;
            rin[clk] = w;
            path.push_back({w, from, 0});
        } else {
            // All children finished: every timestamp of v's subtree has been issued.
            out[top.v] = clk;
            path.pop_back();
        }
    }
}
int Ans; void mv(int p, int d) { dbg(p, d); int cc = c[rin[p]]; if (d == 1) Ans += cnt[cc]++ == 0; if (d == -1) Ans -= --cnt[cc] == 0; }
int main() { int n, Qn, rt; cin >> n >> Qn >> rt; FOR (_, 1, n) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); } vector<int> tt; FOR (i, 1, n + 1) { scanf("%d", &c[i]); tt.push_back(c[i]); } sort(tt.begin(), tt.end()); tt.erase(unique(tt.begin(), tt.end()), tt.end()); FOR (i, 1, n + 1) c[i] = lower_bound(tt.begin(), tt.end(), c[i]) - tt.begin(); dfs(rt, -1); FOR (i, 0, Qn) { int u; scanf("%d", &u); query[i] = {in[u], out[u] + 1, i}; } sort(query, query + Qn, [](const P& a, const P& b){ int ta = a.l / B, tb = b.l / B; if (ta != tb) return ta < tb; return a.r < b.r; }); int l = 1, r = 1; FOR (i, 0, Qn) { const P& q = query[i]; dbg(q.l, q.r); while (l > q.l) mv(--l, 1); while (r < q.r) mv(r++, 1); while (l < q.l) mv(l++, -1); while (r > q.r) mv(--r, -1); ans[q.idx] = Ans; } FOR (i, 0, Qn) printf("%d\n", ans[i]); }
|