【题解】BZOJ-3522/4543 Hotel 加强版

题目

http://www.lydsy.com/JudgeOnline/problem.php?id=3522

http://www.lydsy.com/JudgeOnline/problem.php?id=4543

题意

求一棵树中有多少无序三元组使得两两之间距离相等。

题解

假算法

  • 这样的三元组肯定对应一个中心点,这个中心点到这三个点距离相同。
  • 枚举中心点,处理出每棵子树不同深度的结点个数,然后对于每个深度,借助 dp 求不同子树中各挑出一个的方案数。
  • 当然这样会超时,或者说只能过 3522。
  • 优化:由于需要从三棵子树中挑选深度一样的,所以统计结点个数时只需统计深度不超过所有子树中深度第三大的。(深度再大的话就凑不到 3 个了。)
  • 实现的时候并不能立刻知道以每个点为根时它的子树中深度第三大的是多少。所以随便挑一个点 1 为根,处理出所有点的深度,以 r 为根的第三大子树深度改为在 1 为根的树中 r 的第二大子树深度。
  • 优化以后就能通过 4543 了。这么做虽然复杂度还是 $O(n^2)$,但是恐怕只有刻意构造的数据才能把它卡掉了。

某个强力数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
n = 20000
# |
# 0 - 1 - 2 - 3 - 4
e = list()
for i in range(n):
t = i * 5
e += [
(t, t + 1),
(t + 1, t + 2),
(t + 2, t + 3),
(t + 3, t + 4),
(t + 2, n * 5),
]

print(n * 5 + 1)
for u, v in e:
print(u + 1, v + 1)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args<< " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 1E5 + 100;
int T, n;
vector<int> G[maxn];
LL f1[maxn], f2[maxn], cnt[maxn];
int h[maxn];

void get_h(int u, int pr, int d) {
h[u] = 0;
for (int v: G[u]) {
if (v == pr) continue;
get_h(v, u, d + 1);
h[u] = max(h[u], h[v]);
}
h[u]++;
}

void dfs(int u, int pr, int max_dep, int d = 0) {
if (d > max_dep) return;
++cnt[d];
for (int v: G[u]) {
if (v == pr) continue;
dfs(v, u, max_dep, d + 1);
}
}

int main() {
#ifdef zerol
freopen("in", "r", stdin);
#endif
cin >> T;
while (T--) {
cin >> n;
FOR (i, 1, n + 1) G[i].clear();
FOR (_, 1, n) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v); G[v].push_back(u);
}
get_h(1, -1, 0);
LL ans = 0;
FOR (u, 1, n + 1) {
int M1 = 0, M2 = 0;
for (int v: G[u])
if (h[v] < h[u]) {
if (h[v] > M2) M2 = h[v];
if (M2 > M1) swap(M1, M2);
}
if (M2 == 0) continue;
int max_dep = M2 - 1;

fill(f1, f1 + max_dep + 1, 0);
fill(f2, f2 + max_dep + 1, 0);

for (int v: G[u]) {
fill(cnt, cnt + max_dep + 1, 0);
dfs(v, u, max_dep);
FOR (i, 0, max_dep + 1) {
ans += f2[i] * cnt[i];
f2[i] += f1[i] * cnt[i];
f1[i] += cnt[i];
}
}
}
cout << ans << endl;
}
}

真算法

  • 枚举符合条件的三个点的 LCA,其中两个点及其 LCA 在一棵子树中,另外一个点在另一棵子树上或者就是这三个点的 LCA 本身。
  • $f[u][d]$ 表示结点 $u$ 的子树中深度为 $d$ 的结点个数。
  • $g[u][d]$ 表示结点 $u$ 的子树中能和 $f[u][d]$ 配对的两个点的点对个数。也就是在 $u$ 的某个子树中,结点 $a$ 和 $b$ 与其 LCA 结点 $c$ 的距离为 $x$ ,那么 $g[u][d]$ 中的 $d$ 就是 $x-dist(u,c)$。
  • 状态转移看代码,和之前的假算法类似,边更新边累计答案防止重复计算。
  • 这样时间和空间都会有问题。观察转移的过程,相当于将子树的 dp 结果不断合并。因而采用启发式合并,直接采用子树中深度最大的 dp 结果,然后把其他子树的结果合并上去。
  • 每个点复杂度为 $\sum h(v)-h(u)+C$,总体复杂度为 $O(n)$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (auto x: a) cout << x << ' '; err(args...); }
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 1E5 + 100;
vector<int> G[maxn];
int n, fa[maxn], son[maxn], sd[maxn];
LL ans, T;

void predfs(int u) {
sd[u] = 0;
int& maxs = son[u] = -1;
for (int v: G[u]) {
if (v == fa[u]) continue;
fa[v] = u;
predfs(v);
sd[u] = max(sd[u], sd[v]);
if (maxs == -1 || sd[v] > sd[maxs])
maxs = v;
}
sd[u] += 1;
}

void mod(deque<LL>& f, deque<LL>& g) {
f.push_front(0); if (!g.empty()) g.pop_front();
}

void dfs(int u, deque<LL>& f, deque<LL>& g) {
if (son[u] == -1) { f.push_back(1); return; }
dfs(son[u], f, g);
mod(f, g); f.resize(sd[u]); g.resize(sd[u]);
f[0]++;
ans += g[0];
deque<LL> ff, gg;
for (int v: G[u])
if (v != son[u] && v != fa[u]) {
ff.clear(); gg.clear();
dfs(v, ff, gg);
mod(ff, gg);
FOR (i, 0, ff.size()) ans += ff[i] * g[i];
FOR (i, 0, gg.size()) ans += gg[i] * f[i];
FOR (i, 0, gg.size()) g[i] += gg[i];
FOR (i, 0, ff.size()) g[i] += f[i] * ff[i];
FOR (i, 0, ff.size()) f[i] += ff[i];
}
}

int main() {
#ifdef zerol
freopen("in", "r", stdin);
#endif
cin >> T;
while (T--) {
cin >> n;
FOR (i, 1, n + 1) G[i].clear();
FOR (_, 1, n) {
int u, v; scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
predfs(1);
ans = 0;
deque<LL> t, tt;
dfs(1, t, tt);
cout << ans << endl;
}
}