【题解】BZOJ-3572-Luogu-3233-LOJ-2206-[Hnoi2014]世界树

题目

https://www.luogu.org/problemnew/show/P3233

https://loj.ac/problem/2206

题意

每次询问一些点,求对于询问中的每一个点,求树上相对于询问中的其他点到这个点距离最近的点的数量。

题解

  • 虚树肯定是要的。
  • 难点在于树形 dp。
  • 首先求出虚树上的每一个点最近的关键点(询问中的点)。这个点可能是儿子,可能是父亲,也可能是兄弟。所以跑两遍 dfs,先向上更新,再向下更新,这样三种情况就都能考虑到了。
  • 接下来要计数。考虑每一个点,它的子树中有多少点属于离他最近的关键点。肯定是它的子树再减去若干个子树,但是减去的子树不一定以虚树中的结点为根,所以要在原树上倍增求出分界点。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (auto x: a) cout << x << ' '; err(args...); }
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int N = 3E5 + 100;
vector<int> G[N];
bool key[N];

int sz[N], fa[N], dep[N], in[N], out[N], clk, pa[N][20];
void predfs(int u, int d, int f) {
in[u] = ++clk;
dep[u] = d;
sz[u] = 1;
fa[u] = pa[u][0] = f;
FOR (i, 1, 20) pa[u][i] = pa[pa[u][i - 1]][i - 1];
for (int& v: G[u]) {
if (v == f) continue;
predfs(v, d + 1, u);
sz[u] += sz[v];
}
out[u] = clk;
}
int lca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
int d = dep[u] - dep[v];
FOR (i, 0, 20) if (d & (1 << i)) u = pa[u][i];
FORD (i, 19, -1)
if (pa[u][i] != pa[v][i]) { u = pa[u][i]; v = pa[v][i]; }
return u == v ? u : pa[u][0];
}

static vector<int> V;

int tmp;
int f[N];
inline int dis(int a, int b) { return dep[a] + dep[b] - 2 * dep[lca(a, b)]; }
inline bool cmp(int o, int a, int b) {
int da = dis(a, o), db = dis(b, o);
return da < db || (da == db && a < b);
}
void go1(int& k = tmp = 0) {
int u = V[k];
f[u] = key[u] ? u : -1;
while (k + 1 < V.size()) {
int v = V[k + 1];
if (in[v] > out[u]) break;
go1(++k);
if (f[u] == -1 || cmp(u, f[v], f[u])) f[u] = f[v];
}
}
void go2(int& k = tmp = 0) {
int u = V[k];
while (k + 1 < V.size()) {
int v = V[k + 1];
if (in[v] > out[u]) break;
if (cmp(v, f[u], f[v])) f[v] = f[u];
go2(++k);
}
}
int ans[N];
void go(int& k = tmp = 0) {
int u = V[k];
ans[f[u]] += sz[u];
while (k + 1 < V.size()) {
int v = V[k + 1];
if (in[v] > out[u]) break;
go(++k);
int t = v;
FORD (i, 19, -1) {
int tt = pa[t][i];
if (dep[tt] <= dep[u]) continue;
if (cmp(tt, f[v], f[u])) t = tt;
}
ans[f[v]] += sz[t] - sz[v];
ans[f[u]] -= sz[t];
}
}

void solve(vector<int>& X) {
static auto cmp = [](int a, int b) { return in[a] < in[b]; };
V.clear();
for (int& x: X) V.push_back(x);
sort(V.begin(), V.end(), cmp);
FOR (i, 1, V.size()) V.push_back(lca(V[i], V[i - 1]));
V.push_back(1);
sort(V.begin(), V.end(), cmp);
V.erase(unique(V.begin(), V.end()), V.end());
go1(); go2(); go();
}

int main() {
#ifdef zerol
freopen("in", "r", stdin);
#endif
int n; cin >> n;
FOR (_, 1, n) {
int u, v; scanf("%d%d", &u, &v);
G[u].push_back(v); G[v].push_back(u);
}
predfs(1, 1, 1);
int Qn; cin >> Qn;
while (Qn--) {
int t; scanf("%d", &t);
static vector<int> X; X.clear();
FOR (_, 0, t) { int x; scanf("%d", &x); X.push_back(x); key[x] = true; }
solve(X);
FOR (i, 0, X.size()) {
printf("%d%c", ans[X[i]], i == _i - 1 ? '\n' : ' ');
key[X[i]] = false;
ans[X[i]] = 0;
}
}
}