【题解】LOJ-2116-BZOJ-4012-「HNOI2015」开店

题目

https://loj.ac/problem/2116

题意

询问一个点到所有点权在 $[l, r]$ 之间的点的距离和。(强制在线)

题解

方法1

  • $\displaystyle \sum_{val_v \in [l, r]} dist(u, v)= \displaystyle \sum_{val_v \in [l, r]} \left(dep_v+dep_u-2dep_{lca(u,v)}\right)$,其中 $dep$ 表示到根的带权距离。
  • 前两项只与满足条件的点数及其 $dep$ 之和有关,对离散化后的 $val$ 做前缀和,相减即可 $O(1)$ 得到。
  • 如果不考虑 $val_v$ 的限制,计算 $u$ 与所有点的 $lca$ 到根的距离和,这部分可以用树链剖分 + 线段树完成。
  • 每个点为从自己到根的所有边贡献 1,最后查询 $u$ 到根路径上每条边的边权与其贡献次数之积的和,即为所有点与 $u$ 的 $lca$ 的 $dep$ 之和。
  • 由于有区间限制,使用函数式线段树,按 $val$ 从小到大计算贡献,保留每个 $val$ 对应的根,最后对两棵历史版本的线段树上的查询答案相减即可。
  • $val$ 需要离散化。空间开销难以精确估计:理论上需要约 $2n\log^2n$ 个节点,但开不下,代码中实际开了 `maxn * 45 * 2` 个节点。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
// Debug printing helpers used by the `dbg` macro (which expands to nothing unless
// `zerol` is defined).
// Terminal case: reset the ANSI colour and finish the line.
void err() {
    cout << "\033[39;0m" << endl;
}
// Container argument: print every element followed by a space, then the rest.
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) {
    for (auto it = a.begin(); it != a.end(); ++it) cout << *it << ' ';
    err(args...);
}
// Scalar argument: print it followed by a space, then the rest.
template<typename T, typename... Args>
void err(T a, Args... args) {
    cout << a << ' ';
    err(args...);
}
// -----------------------------------------------------------------------------
// Array capacity; presumably the problem's bound on n plus slack — TODO confirm.
const int maxn = 15E4 + 100;
// v[i]: value of vertex i + 1 (read 0-indexed in main).
int v[maxn];
// Adjacency-list edge: neighbour vertex and edge length.
struct E {
int to, d;
};
vector<E> G[maxn];
int n, Q;
// ans: previous answer, used to decode the next query; A: decoding modulus.
LL ans, A;
// Heavy-light decomposition data:
//   sz  - subtree size,  son - heavy child (-1 for a leaf),  fa - parent,
//   idx - DFS position,  top - head of the vertex's heavy chain,
//   dis - distance from root 1,  clk - DFS position counter,
//   len - edge length stored at the child's DFS position (prefix-summed in main).
int sz[maxn], son[maxn], fa[maxn], idx[maxn], top[maxn], dis[maxn], clk, len[maxn];

void predfs(int u, int d = 0) {
dis[u] = d; sz[u] = 1;
int& maxs = son[u] = -1;
for (E& e: G[u]) {
int v = e.to;
if (v == fa[u]) continue;
fa[v] = u;
predfs(v, d + e.d);
sz[u] += sz[v];
if (maxs == -1 || sz[v] > sz[maxs]) maxs = v;
}
}

void dfs(int u, int tp) {
top[u] = tp;
idx[u] = ++clk;
if (~son[u]) dfs(son[u], tp);
for (E& e: G[u]) {
int v = e.to;
if (v != son[u] && v != fa[u]) dfs(v, v);
if (v != fa[u]) len[idx[v]] = e.d;
}
}

// Walk from u up to the root (vertex 1) and call fn(ql, qr) for each range of
// DFS positions covering the edges on that path. The root is visited first by
// dfs(1, 1), so idx[1] == 1 and that position carries no edge. The last segment
// on the root's chain therefore starts at position 2 and is skipped when u == 1.
template<typename T>
void mod(int u, T&& fn) {
    for (; top[u] != 1; u = fa[top[u]]) fn(idx[top[u]], idx[u]);
    if (u != 1) fn(2, idx[u]);
}
// After the prefix pass in main: dsum[k] is the sum of dis[] and cnt[k] the number
// of vertices whose discretized value rank is <= k.
LL dsum[maxn], cnt[maxn];
// A vertex u together with the discretized rank v of its value.
struct P {
int u, v;
};
// One entry per vertex; sorted by rank in main before building the versions.
vector<P> a;

// Persistent segment tree over DFS positions [1, n] with permanent (never
// pushed-down) lazy marks.
// Each node stores:
//   add - the amount added to its whole range by updates that stopped at this node;
//   sum - the sum over its range of (added amount * edge length at that position),
//         counting only marks stored in this node and its descendants.
// Node 0 is the shared empty node (all fields zero), so allocation starts at 1.
// The midpoint is computed locally rather than through file-wide macros, so no
// `mid`/`lson`/`rson` names leak into the rest of the translation unit.
namespace tree {
struct P {
    LL add, sum;
    int ls, rs;
} tr[maxn * 45 * 2];
int sz = 1;

// Allocate a node for range [l, r] with mark `add` and children ls/rs. Its sum is
// rebuilt from the children plus add * (total edge length in [l, r]). Requires
// len[] to already hold prefix sums.
int N(LL add, int l, int r, int ls, int rs) {
    tr[sz] = {add, tr[ls].sum + tr[rs].sum + add * (len[r] - len[l - 1]), ls, rs};
    return sz++;
}

// Return the root of a new version equal to version `o` with `add` added on
// [ql, qr]. [l, r] is the range covered by node o. Untouched subtrees are shared
// with the old version.
int update(int o, int ql, int qr, int l, int r, LL add) {
    if (ql > r || l > qr) return o;
    const P& t = tr[o];
    if (ql <= l && r <= qr) return N(add + t.add, l, r, t.ls, t.rs);
    const int m = (l + r) >> 1;
    const int left = update(t.ls, ql, qr, l, m, add);
    const int right = update(t.rs, ql, qr, m + 1, r, add);
    return N(t.add, l, r, left, right);
}

// Sum over positions in [ql, qr] of (added amount * edge length) in version `o`.
// `add` accumulates the marks of o's ancestors, since marks are never pushed down.
LL query(int o, int ql, int qr, int l, int r, LL add = 0) {
    if (ql > r || l > qr) return 0;
    const P& t = tr[o];
    if (ql <= l && r <= qr) return add * (len[r] - len[l - 1]) + t.sum;
    const int m = (l + r) >> 1;
    return query(t.ls, ql, qr, l, m, add + t.add) + query(t.rs, ql, qr, m + 1, r, add + t.add);
}
}

int rt[maxn];
int main() {
cin >> n >> Q >> A;
vector<int> vv;
FOR (i, 0, n) {
scanf("%d", &v[i]);
vv.push_back(v[i]);
}
sort(vv.begin(), vv.end());
vv.erase(unique(vv.begin(), vv.end()), vv.end());
FOR (i, 0, n) a.push_back({i + 1, int(lower_bound(vv.begin(), vv.end(), v[i]) - vv.begin() + 1)});
FOR (_, 1, n) {
int u, v, d; scanf("%d%d%d", &u, &v, &d);
G[u].push_back({v, d});
G[v].push_back({u, d});
}
predfs(1); dfs(1, 1);
FOR (i, 1, n + 1) len[i] += len[i - 1];
for (P& x: a) dsum[x.v] += dis[x.u], cnt[x.v]++;
FOR (i, 1, vv.size() + 1) dsum[i] += dsum[i - 1], cnt[i] += cnt[i - 1];
sort(a.begin(), a.end(), [](const P& a, const P& b)->bool { return a.v < b.v; });
int pt = 0, o = 0;
FOR (val, 1, (int)vv.size() + 1) {
while (pt < n && a[pt].v == val) mod(a[pt++].u, [&](int ql, int qr) {
o = tree::update(o, ql, qr, 1, n, 1);
});
rt[val] = o;
}
while (Q--) {
int u; LL a, b; scanf("%d%lld%lld", &u, &a, &b);
a = (a + ans) % A; b = (b + ans) % A;
LL l = min(a, b), r = max(a, b);
l = lower_bound(vv.begin(), vv.end(), l) - vv.begin() + 1;
r = upper_bound(vv.begin(), vv.end(), r) - vv.begin();
if (l > r) { cout << (ans = 0) << endl; continue; }
ans = dsum[r] - dsum[l - 1] + 1LL * (cnt[r] - cnt[l - 1]) * dis[u];
mod(u, [&](int ql, int qr) {
ans -= (tree::query(rt[r], ql, qr, 1, n) -
tree::query(rt[l - 1], ql, qr, 1, n)) * 2;
});
cout << ans << endl;
}
}

方法2

  • 动态点分治:对每个重心维护一个按点权排序的 vector,存储分治块内各点到该重心的距离并做前缀和,从而可以二分求出点权区间内的距离和。
  • 代码比方法1短(似乎也更好写),时间复杂度一致,空间复杂度少一个 $\log$。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
// Debug printing helpers used by the `dbg` macro (which expands to nothing unless
// `zerol` is defined).
// Terminal case: reset the ANSI colour and finish the line.
void err() {
    cout << "\033[39;0m" << endl;
}
// Container argument: print every element followed by a space, then the rest.
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) {
    for (auto it = a.begin(); it != a.end(); ++it) cout << *it << ' ';
    err(args...);
}
// Scalar argument: print it followed by a space, then the rest.
template<typename T, typename... Args>
void err(T a, Args... args) {
    cout << a << ' ';
    err(args...);
}
// -----------------------------------------------------------------------------
// Array capacity (presumably the problem's n bound plus slack — TODO confirm).
// INF is the initial "best" value in the centroid search, and -INF is the value of
// the sentinel element placed at the front of every sorted vector in main.
const int maxn = 15E4 + 100, INF = 1E9;
// Adjacency-list edge: neighbour vertex and edge length.
struct E {
int to, d;
};
vector<E> G[maxn];
// w[i]: value of vertex i (1-indexed).
int n, Q, w[maxn];
// A: modulus used to decode queries; ans: previous answer (zero-initialized).
LL A, ans;

bool vis[maxn];
int sz[maxn];
int get_sz(int u, int fa) {
int& s = sz[u] = 1;
for (E& e: G[u]) {
int v = e.to;
if (vis[v] || v == fa) continue;
s += get_sz(v, u);
}
return s;
}

void get_rt(int u, int fa, int s, int& m, int& rt) {
int t = s - sz[u];
for (E& e: G[u]) {
int v = e.to;
if (vis[v] || v == fa) continue;
get_rt(v, u, s, m, rt);
t = max(t, sz[v]);
}
if (t < m) { m = t; rt = u; }
}

int dep[maxn], md[maxn];
void get_dep(int u, int fa, int d) {
dep[u] = d; md[u] = 0;
for (E& e: G[u]) {
int v = e.to;
if (vis[v] || v == fa) continue;
get_dep(v, u, d + e.d);
md[u] = max(md[u], md[v] + 1);
}
}

// Element of a per-centroid vector: vertex value w and distance s to the centroid.
// After sorting by w in main, s is turned into a prefix sum.
struct P {
int w;
LL s;
};
using VP = vector<P>;
// One centroid ancestor of a vertex x, stored in tr[x]:
//   rt  - vector of that centroid's whole component;
//   rt2 - vector of the sub-component containing x (nullptr when x is the centroid);
//   dep - distance from x to that centroid.
struct R {
VP *rt, *rt2;
int dep;
};
// Backing storage for all vectors; pit points at the next unused one. Each centroid
// takes one vector plus one per unremoved neighbour, so at most n + (n - 1) are used.
VP pool[maxn << 1], *pit = pool;
vector<R> tr[maxn];

void go(int u, int fa, VP* rt, VP* rt2) {
tr[u].push_back({rt, rt2, dep[u]});
for (E& e: G[u]) {
int v = e.to;
if (v == fa || vis[v]) continue;
go(v, u, rt, rt2);
}
}

// Centroid decomposition. Given any vertex u of an unremoved component: find its
// centroid, mark it removed, and link every vertex of the component to the
// centroid's vectors through tr[]. Then recurse into each remaining sub-component.
// The vectors are only allocated here; main fills them from tr[].
void dfs(int u) {
// get_rt receives u by reference as `rt` and overwrites it with the centroid.
int tmp = INF; get_rt(u, -1, get_sz(u, -1), tmp, u);
vis[u] = true;
// dep[x] = distance from the centroid u for every x in the component.
get_dep(u, -1, 0);
// rt: vector for the whole component. The centroid links to it at distance 0
// with no sub-component vector.
VP* rt = pit++; tr[u].push_back({rt, nullptr, 0});
for (E& e: G[u]) {
int v = e.to;
if (vis[v]) continue;
// Each sub-component gets its own vector (rt2), so pairs inside the same
// sub-component can be subtracted at query time. go() must run before dfs(v)
// because the recursive call overwrites dep[] within v's sub-component. The
// other sub-components are disjoint from it, so their dep[] values stay valid.
go(v, u, rt, pit++);
dfs(v);
}
}

// Order vector entries by value only.
bool cmp(const P& a, const P& b) { return a.w < b.w; }

// p must be sorted by w, with s prefix-summed and the sentinel {-INF, 0} at index 0.
// Returns the total of the original (pre-prefix-sum) distances of entries whose w
// lies in [l, r], plus d for each such entry.
LL query(VP& p, int d, int l, int r) {
    const auto first = lower_bound(p.begin(), p.end(), P{l, -1}, cmp);
    const auto last = upper_bound(p.begin(), p.end(), P{r, -1}, cmp);
    const LL count = last - first;
    // The sentinel guarantees first - 1 and last - 1 are valid elements.
    return (last - 1)->s - (first - 1)->s + count * d;
}

int main() {
cin >> n >> Q >> A;
FOR (i, 1, n + 1) scanf("%d", &w[i]);
FOR (_, 1, n) {
int u, v, d; scanf("%d%d%d", &u, &v, &d);
G[u].push_back({v, d}); G[v].push_back({u, d});
}
dfs(1);
FOR (i, 1, n + 1)
for (R& x: tr[i]) {
x.rt->push_back({w[i], x.dep});
if (x.rt2) x.rt2->push_back({w[i], x.dep});
}
FOR (it, pool, pit) {
it->push_back({-INF, 0});
sort(it->begin(), it->end(), cmp);
FOR (i, 1, it->size())
(*it)[i].s += (*it)[i - 1].s;
}
while (Q--) {
int u; LL a, b; scanf("%d%lld%lld", &u, &a, &b);
a = (a + ans) % A; b = (b + ans) % A;
int l = min(a, b), r = max(a, b);
ans = 0;
for (R& x: tr[u]) {
ans += query(*(x.rt), x.dep, l, r);
if (x.rt2) ans -= query(*(x.rt2), x.dep, l, r);
}
printf("%lld\n", ans);
}
}