【题解】HackerRank - Self-Driving Bus

题目

https://www.hackerrank.com/challenges/self-driving-bus/problem

题意

给一棵树,结点编号为 1 ~ n,求有多少个编号区间使得对应结点连通。

题解

官方题解

  • \(R[x]\) 表示最大的 \(r\) 使得 \([x, x\sim r]\) 在去掉 \([1, x - 1]\) 的子图中都连通。
  • \(L[x]\) 表示最小的 \(l\) 使得 \([l\sim x,x]\) 在去掉 \([x + 1, n]\) 的子图中都连通。
  • 那么最后要求的就是有多少 \([l, r]\) 满足 \(R[l] >= r\) 以及 \(L[r] <= l\),这部分只要用树状数组统计一下就好了。
  • \(L[x]\) 以及 \(R[x]\) 用单调栈来求。以 \(R[x]\) 为例,当 \(r \rightarrow r+1\) 时,编号越大的元素越容易失败(因为删除的点的数目更多),失败的表现就是 \(r+1\)\(x\) 的路径上有已经被删除的点(编号小于 \(x\))。所以从 \(1\)\(n\) 枚举 \(r\),每次测试栈顶结点 \(x\)(编号较大),如果失败则记 \(R[x] = r - 1\) 并弹出,最后把 \(r\) 入栈。
  • 还需要做的工作就是支持查询一条路径上的最大、最小结点编号。这部分用 倍增LCA 或者 树链剖分+RMQ 搞一下。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (auto x: a) cout << x << ' '; err(args...); }
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 2E5 + 100, INF = 1E9;
int sz[maxn], son[maxn], dep[maxn], fa[maxn], top[maxn], idx[maxn], v[maxn];
int n, clk, R[maxn], L[maxn];
vector<int> G[maxn], RR[maxn];

namespace rmq {
int f[maxn][20], g[maxn][20];
inline int highbit(int x) { return 31 - __builtin_clz(x); }
void init(int *v, int n) {
FOR (i, 0, n) f[i][0] = g[i][0] = v[i];
FOR (x, 1, highbit(n) + 1)
FOR (i, 0, n - (1 << x) + 1) {
f[i][x] = min(f[i][x - 1], f[i + (1 << (x - 1))][x - 1]);
g[i][x] = max(g[i][x - 1], g[i + (1 << (x - 1))][x - 1]);
}
}
int get_min(int l, int r) {
assert(l <= r);
int t = highbit(r - l + 1);
return min(f[l][t], f[r - (1 << t) + 1][t]);
}
int get_max(int l, int r) {
assert(l <= r);
int t = highbit(r - l + 1);
return max(g[l][t], g[r - (1 << t) + 1][t]);
}
}

void predfs(int u, int d) {
dep[u] = d;
sz[u] = 1;
int& maxs = son[u] = -1;
for (int v: G[u])
if (v != fa[u]) {
fa[v] = u;
predfs(v, d + 1);
sz[u] += sz[v];
if (maxs == -1 || sz[v] > sz[maxs])
maxs = v;
}
}

void dfs(int u, int tp) {
top[u] = tp;
idx[u] = ++clk;
v[idx[u]] = u;
if (son[u] == -1) return;
dfs(son[u], tp);
for (int v: G[u])
if (v != son[u] && v != fa[u])
dfs(v, v);
}

int query(int u, int v, int q(int, int), const int& f(const int&, const int&), int init) {
int uu = top[u], vv = top[v], ret = init;
while (uu != vv) {
if (dep[uu] < dep[vv]) { swap(uu, vv); swap(u, v); }
ret = f(ret, q(idx[uu], idx[u]));
u = fa[uu];
uu = top[u];
}
if (dep[u] < dep[v]) swap(u, v);
ret = f(ret, q(idx[v], idx[u]));
return ret;
}

namespace BIT {
int c[maxn];
inline LL lowbit(LL x) { return x & -x; }
void add(LL x, LL v) {
for (; x < n + 1; x += lowbit(x))
c[x] += v;
}
LL sum(LL x) {
LL ret = 0;
for (; x > 0; x -= lowbit(x))
ret += c[x];
return ret;
}
LL sum(LL l, LL r) { return sum(r) - sum(l - 1); }
}

int main() {
#ifdef zerol
freopen("in", "r", stdin);
#endif
cin >> n;
FOR (_, 1, n) {
int u, v; scanf("%d%d", &u, &v);
G[u].push_back(v); G[v].push_back(u);
}
predfs(1, 1);
dfs(1, 1);
rmq::init(v, n + 1);

stack<int> S;
FOR (i, 1, n + 1) {
while (!S.empty() && query(S.top(), i, rmq::get_min, min, INF) < S.top()) {
R[S.top()] = i - 1;
S.pop();
}
S.push(i);
}
while (!S.empty()) { R[S.top()] = n; S.pop(); }

FORD (i, n, 0) {
while (!S.empty() && query(S.top(), i, rmq::get_max, max, -INF) > S.top()) {
L[S.top()] = i + 1;
S.pop();
}
S.push(i);
}
while (!S.empty()) { L[S.top()] = 1; S.pop(); }

FOR (i, 1, n + 1) RR[R[i]].push_back(i);
LL ans = 0;
FOR (i, 1, n + 1) {
BIT::add(i, 1);
ans += BIT::sum(L[i], i);
for (int& l: RR[i]) {
BIT::add(l, -1);
}
}
cout << ans << endl;
}

另解

  • \([l, r]\) 连通等价于 \([l, r]\) 内部的边恰好有 \(r - l\) 条。
  • 枚举 \(r\),维护 \(f(l)\)\([l,r]\) 内部边的数目。对于所有 \(r\) 与编号小于 \(r\) 的相连的边 \((t, r)\in E\)\(l\in[1,t]\)\(f(l)\) 增加了 1 条,然后要检查有多少 \(l\) 满足要求。当然枚举是不行的,区间加用线段树,每个 \(f(l)\) 初始化为 \(l\) ,那么检查的就是有多少 \(f(l)=r\) 就行了。
  • 代码中线段树维护的是最大值和最大值的数量。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (auto x: a) cout << x << ' '; err(args...); }
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 2E5 + 100;
vector<int> G[maxn];
int n;

namespace TREE {
struct Q {
int add;
Q(int add = 0): add(add) {}
void operator += (Q& q) { add += q.add; }
};
struct P {
int max, cnt;
P(int max = 0, int cnt = 1): max(max), cnt(cnt) {}
void up(Q& q) { max += q.add; }
};
template<typename T>
P operator & (T&& a, T&& b) { // 用模板的引用坍缩处理 P& 和 P&& 的情况
int c = 0;
if (a.max >= b.max) c += a.cnt;
if (a.max <= b.max) c += b.cnt;
return P(max(a.max, b.max), c);
}
P p[maxn << 2];
Q q[maxn << 2];
#define ls o * 2, l, (l + r) / 2
#define rs o * 2 + 1, (l + r) / 2 + 1, r
void maintain(int o, int l, int r) {
if (l == r) p[o] = P();
else p[o] = p[o * 2] & p[o * 2 + 1];
p[o].up(q[o]);
}
void pushdown(int o, int l, int r) {
q[o * 2] += q[o]; q[o * 2 + 1] += q[o];
q[o] = Q();
maintain(ls); maintain(rs);
}
P query(int ql, int qr, int o = 1, int l = 1, int r = n) {
if (ql > r || l > qr) return P();
if (ql <= l && r <= qr) return p[o];
pushdown(o, l, r);
return query(ql, qr, ls) & query(ql, qr, rs);
}
void update(int ql, int qr, Q v, int o = 1, int l = 1, int r = n) {
if (ql > r || l > qr) return;
if (ql <= l && r <= qr) q[o] += v;
else {
pushdown(o, l, r);
update(ql, qr, v, ls); update(ql, qr, v, rs);
}
maintain(o, l, r);
}
}


int main() {
#ifdef zerol
freopen("in", "r", stdin);
#endif
cin >> n;
FOR (_, 1, n) {
int u, v; scanf("%d%d", &u, &v);
if (u < v) swap(u, v);
G[u].push_back(v);
}
LL ans = 0;
FOR (i, 1, n + 1) {
for (int x: G[i]) TREE::update(1, x, {1});
TREE::update(i, i, i);
ans += TREE::p[1].cnt;
}
cout << ans << endl;
}