【题解】HackerRank-How Many Substrings

题目

https://www.hackerrank.com/challenges/how-many-substrings/problem

题意

询问一个串的某个子串中本质不同子串的个数。

题解

  • 写这道题之前要对 SAM 和 LCT 有一定的理解。
  • 往 SAM 里一个个插入字符,如果当前插入到第 r 个字符,回答所有 l, r 的询问。所以需要维护对于所有左端点询问的答案。
  • 考虑新加入第 i 个字符后对答案产生的影响。在后缀树上,会建一个 EndPos={i} 的接受结点,然后可能会在后缀树上的一条边上插入一个点。最后从接受结点到根的这一条路径上的所有点的 EndPos 集合会增加 i。
  • 对于每个结点,维护一个 last,表示 EndPos 集合中的最大元素。那么这个结点能表示的串可以看做 [last-R+1, last-L+1]~last 的共 R-L+1 个串(其实就是看做尽可能靠右的串),其中 L, R 指的是这个结点表示的串的长度范围(也就是 L=len[u], R=len[fa[u]]+1)。last 变为 last' 后需要对 last-d+1~last'-d+1 \((L\le d\le R)\) 共 R-L+1 个区间加 1,这部分可以用数据结构维护。
  • 考虑如何对 last 进行修改,每次修改是在一条链上,修改完后整条链的 last 都一样了。考虑用 LCT 维护,LCT 的每条实链上的 last 都一样,所以修改时可以一起处理(相当于把链上的若干个点的 L 和 R 合并了),而这个修改 last 的过程正是 access。
  • 实现细节:
    • 区间修改差分,单点求和可以用树状数组。
    • 后缀树的根节点(长度为 0 的结点)可以不加入动态树。
    • 在后缀树的边上插入结点时要继承 last,而且是标记下传后继承。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (auto x: a) cout << x << ' '; err(args...); }
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int N = 2E5 + 1000;
namespace bit {
LL c[N], cc[N];
inline int lowbit(int x) { return x & -x; }
void add(int x, LL v) {
for (int i = x; i < N; i += lowbit(i)) {
c[i] += v; cc[i] += x * v;
}
}
void add(int l, int r, LL v) {
assert(0 <= r && l <= r && r < N - 1);
add(l, v); add(r + 1, -v);
}
LL sum(int x) {
LL ret = 0;
for (int i = x; i > 0; i -= lowbit(i))
ret += (x + 1) * c[i] - cc[i];
return ret;
}
LL sum(int l, int r) { return sum(r) - sum(l - 1); }
}

namespace lct_sam {
extern struct P *const null;
const int M = N;
struct P {
P *fa, *ls, *rs;
int last;

bool has_fa() { return fa->ls == this || fa->rs == this; }
bool d() { return fa->ls == this; }
P*& c(bool x) { return x ? ls : rs; }
P* up() { return this; }
void down() {
if (ls != null) ls->last = last;
if (rs != null) rs->last = last;
}
void all_down() { if (has_fa()) fa->all_down(); down(); }
} *const null = new P{0, 0, 0, 0}, pool[M], *pit = pool;
P* G[N];
int t[M][26], len[M] = {-1}, fa[M], sz = 2, last = 1;

void rot(P* o) {
bool dd = o->d();
P *f = o->fa, *t = o->c(!dd);
if (f->has_fa()) f->fa->c(f->d()) = o; o->fa = f->fa;
if (t != null) t->fa = f; f->c(dd) = t;
o->c(!dd) = f->up(); f->fa = o;
}
void splay(P* o) {
o->all_down();
while (o->has_fa()) {
if (o->fa->has_fa())
rot(o->d() ^ o->fa->d() ? o : o->fa);
rot(o);
}
o->up();
}
void access(int last, P* u, P* v = null) {
if (u == null) { v->last = last; return; }
splay(u);
P *t = u;
while (t->ls != null) t = t->ls;
int L = len[fa[t - pool]] + 1, R = len[u - pool];

if (u->last) bit::add(u->last - R + 2, u->last - L + 2, 1);
else bit::add(1, 1, R - L + 1);
bit::add(last - R + 2, last - L + 2, -1);

u->rs = v;
access(last, u->up()->fa, u);
}
void insert(P* u, P* v, P* t) {
if (v != null) { splay(v); v->rs = null; }
splay(u);
u->fa = t; t->fa = v;
}

void ins(int ch, int pp) {
int p = last, np = last = sz++;
len[np] = len[p] + 1;
for (; p && !t[p][ch]; p = fa[p]) t[p][ch] = np;
if (!p) fa[np] = 1;
else {
int q = t[p][ch];
if (len[p] + 1 == len[q]) { fa[np] = q; G[np]->fa = G[q]; }
else {
int nq = sz++; len[nq] = len[p] + 1;
memcpy(t[nq], t[q], sizeof t[0]);
insert(G[q], G[fa[q]], G[nq]);
G[nq]->last = G[q]->last;
fa[nq] = fa[q];
fa[np] = fa[q] = nq;
G[np]->fa = G[nq];
for (; t[p][ch] == q; p = fa[p]) t[p][ch] = nq;
}
}
access(pp + 1, G[np]);
}

void init() {
++pit;
FOR (i, 1, N) {
G[i] = pit++;
G[i]->ls = G[i]->rs = G[i]->fa = null;
}
G[1] = null;
}
}


char s[N];
struct P { int l, idx; };
vector<P> query[N];
LL ans[N];
int main() {
lct_sam::init();
int n, Qn; cin >> n >> Qn;
scanf("%s", s);
FOR (i, 0, Qn) {
int l, r; scanf("%d%d", &l, &r);
query[r].push_back({l, i});
}
FOR (i, 0, n) {
lct_sam::ins(s[i] - 'a', i);
for (P& q: query[i])
ans[q.idx] = bit::sum(q.l + 1);
}
FOR (i, 0, Qn)
printf("%lld\n", ans[i]);
}