【题解】HackerRank-Similar Strings

题目

https://www.hackerrank.com/challenges/similar-strings/problem

题意

给定一个只含字母 a–j 的字符串 $s$。每次询问给出区间 $[l, r]$,求 $s$ 中有多少个子串与 $s[l..r]$ 同构(即两串长度相同,且存在字母间的一一映射使它们逐位对应)。

题解

  • 前情提要
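  • 建出同构意义下的后缀数组,然后二分得到同构子串的个数。具体地:对每个字母 $c$ 构造 01 串 $a_c$($s_i = c$ 处为 0,否则为 1),把这 10 个串从后往前插入同一个广义后缀自动机,其 parent 树即为这些 01 串的后缀树;对每个位置 $i$,将 10 个从 $i$ 开始的 01 后缀按 DFS 序排序,则两个位置在同构意义下的 LCP 等于排序后逐对对应的 10 对后缀在后缀树上 LCA 的 len 的最小值。以此为比较依据排序得到后缀数组;询问 $[l, r]$ 时,在 $l$ 的排名两侧二分出与它 LCP $\ge r - l + 1$ 的最长区间,区间长度即为答案。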
  • 如果对后缀数组的高度数组建 ST 表的话,复杂度可以压到只有一个 $\log$。但这题数据范围较小,可以两个 $\log$,于是也可以用 Hash 做。

代码

#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// FOR(i, x, y): i runs over [x, y) in increasing order; the bound y is
// evaluated once into _i, and i has the decayed type of y.
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
// FORD(i, x, y): i runs from x down to y, excluding y; i has the decayed type of x.
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
// dbg(args...): when compiled with -Dzerol, prints the argument text followed by
// the values (via err) in colour; otherwise it expands to nothing.
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
// Base case of the debug printer: restore the default terminal colour and end
// the line (std::endl also flushes).
void err() {
    cout << "\033[39;0m" << endl;
}
// Argument of the form Seq<Elem> (a one-argument instantiation of a class
// template): print each element followed by a space, then the remaining args.
template<template<typename...> class Seq, typename Elem, typename... Rest>
void err(Seq<Elem> seq, Rest... rest) {
    for (auto item : seq) {
        cout << item << ' ';
    }
    err(rest...);
}
// Any other argument: print it followed by a space, then the remaining args.
template<typename Value, typename... Rest>
void err(Value value, Rest... rest) {
    cout << value << ' ';
    err(rest...);
}
// -----------------------------------------------------------------------------
// Capacity for the automaton arrays. Ten 0/1 strings (one per letter
// 'a'..'j'), each of length n, are inserted; 5E4 is presumably the maximum n
// (TODO confirm against the problem limits), and the factor 2 covers the usual
// bound of fewer than two automaton states per inserted character.
const int N = 5E4 * 10 * 2, M = N;

// Suffix-automaton storage: t[x][ch] = transition of state x on ch (ch is 0/1,
// value 0 means "no transition"), len[x] = length of the longest string of x,
// fa[x] = suffix link, sz = next free state id (state 0 is never allocated,
// state 1 is the root), last = state reached by the string inserted so far.
int t[M][2], len[M] = {0}, fa[M], sz = 2, last = 1;
// one[x]: pointer into a[c] recorded when state x was created (a clone copies
// it from the state it splits); rsort reads *(one[x] + len[fa[x]]) as the
// first character of the suffix-link-tree edge fa[x] -> x.
char* one[M];
// Append character ch (0 or 1) to the automaton after state `last`; pp points
// at that character inside a[c]. This is the generalised variant: main resets
// `last` to 1 (the root) before each new string, so t[last][ch] may already
// exist, in which case no new state is created.
// On return, `last` is the state for the string inserted so far.
void ins(int ch, char* pp) {
// np: newly created state (0 if none), nq: clone (0 if none), q: existing target.
int p = last, np = 0, nq = 0, q = -1;
if (!t[p][ch]) {
// Create the new state, then point every suffix-link ancestor that has no
// ch-transition yet at it.
np = sz++; one[np] = pp;
len[np] = len[p] + 1;
for (; p && !t[p][ch]; p = fa[p]) t[p][ch] = np;
}
// Walked past the root (fa[1] is 0): the new state links to the root.
if (!p) fa[np] = 1;
else {
q = t[p][ch];
// When np == 0 the fa[np] writes below only touch the unused state 0.
// If q's longest string is exactly one longer than p's, link to q directly.
if (len[p] + 1 == len[q]) fa[np] = q;
else {
// Split q: clone it as nq with length len[p] + 1 (copying transitions,
// suffix link and start pointer), and make q and np link to the clone.
nq = sz++; len[nq] = len[p] + 1; one[nq] = one[q];
memcpy(t[nq], t[q], sizeof t[0]);
fa[nq] = fa[q];
fa[np] = fa[q] = nq;
// Redirect ancestors whose ch-transition pointed at q to the clone.
for (; t[p][ch] == q; p = fa[p]) t[p][ch] = nq;
}
}
// State for the current string: the new state if one was made, otherwise the
// clone if q had to be split, otherwise q itself.
last = np ? np : nq ? nq : q;
}
// up[x]: first character of the suffix-link-tree edge fa[x] -> x (set in rsort).
// c: counting-sort buckets; c[0] = 2 offsets the output so sorted states land in aa[2..sz).
// aa: states 2..sz-1 ordered by up[].
int up[M], c[256] = {2}, aa[M];
// G[x]: children of state x in the suffix-link tree, in increasing order of up[].
vector<int> G[M];
void rsort() {
FOR (i, 1, 256) c[i] = 0;
FOR (i, 2, sz) up[i] = *(one[i] + len[fa[i]]);
FOR (i, 2, sz) c[up[i]]++;
FOR (i, 1, 256) c[i] += c[i - 1];
FOR (i, 2, sz) aa[--c[up[i]]] = i;
FOR (i, 2, sz) G[fa[aa[i]]].push_back(aa[i]);
}

// Euler-tour data for LCA on the suffix-link tree:
//   idx[x]    = 1-based preorder entry time of state x (counter: clk),
//   dfn_p     = current length of the tour,
//   dfn[k][0] = k-th node of the tour; dfn[k][j] = shallowest node among tour
//               positions k .. k + 2^j - 1 (sparse table, filled by rmq_init),
//   rdfn[x]   = tour position of x's first occurrence,
//   dep[x]    = depth of x (root has depth 0).
int idx[N], clk, dfn_p, dfn[N * 2][22], rdfn[N], dep[N];
// Euler-tour the suffix-link tree from u (main passes the root, 1):
//   idx[x]    <- 1-based preorder entry time (counter clk),
//   rdfn[x]   <- tour position of x's first appearance,
//   dfn[k][0] <- k-th tour node; a parent is emitted again after each child,
//   dep[v]    <- dep[parent] + 1.
// Children are visited in the order stored in G. An explicit stack replaces
// recursion because the tree can be a chain as long as the input string (e.g. a
// letter absent from s gives an all-ones indicator string), which would nest
// one call per character.
void dfs(int u) {
    // Frame: (node, index of that node's next child to visit).
    vector<pair<int, size_t> > stk;
    idx[u] = ++clk;
    rdfn[u] = dfn_p; dfn[dfn_p++][0] = u;
    stk.push_back(make_pair(u, size_t(0)));
    while (!stk.empty()) {
        const int x = stk.back().first;
        if (stk.back().second < G[x].size()) {
            const int v = G[x][stk.back().second++];
            dep[v] = dep[x] + 1;
            idx[v] = ++clk;
            rdfn[v] = dfn_p; dfn[dfn_p++][0] = v;
            stk.push_back(make_pair(v, size_t(0)));
        } else {
            stk.pop_back();
            // Returning to the parent: the tour lists the parent again.
            if (!stk.empty()) dfn[dfn_p++][0] = stk.back().first;
        }
    }
}
// floor(log2(x)): index of the highest set bit of x. Requires x > 0, since
// __builtin_clz(0) is undefined.
inline int highbit(int x) {
    const int leadingZeros = __builtin_clz(x);
    return 31 - leadingZeros;
}
// Return whichever of nodes a and b is shallower (dep[] smaller); b on ties.
inline int dmin(int a, int b) { return dep[b] <= dep[a] ? b : a; }
void rmq_init() {
FOR (x, 1, highbit(dfn_p) + 1)
FOR (i, 0, dfn_p - (1 << x) + 1)
dfn[i][x] = dmin(dfn[i][x - 1], dfn[i + (1 << (x - 1))][x - 1]);
}
int lca(int u, int v) {
u = rdfn[u]; v = rdfn[v]; if (u > v) swap(u, v);
int t = highbit(v - u + 1);
return dmin(dfn[u][t], dfn[v - (1 << t) + 1][t]);
}

// Input string (letters 'a'..'j').
char s[N];
// mp[i][ch] = rank (0..9) of letter ch's indicator suffix at position i once
// the ten suffixes p[i][*] are sorted by DFS entry time (filled in main).
int mp[N][10];
// a[c][i] = 0 if s[i] is the c-th letter ('a' + c), otherwise 1.
char a[10][N];
// An automaton state u paired with the letter c whose indicator string produced it.
struct P { int u, c; };
// p[i][*]: the ten states reached after inserting the suffix starting at i of
// each indicator string; main sorts each row by idx[u].
P p[N][10];

// Length of s.
int n;
int lcp(int a, int b) {
int ret = N;
FOR (c, 0, 10) ret = min(ret, len[lca(p[a][c].u, p[b][c].u)]);
return min(ret, min(n - a, n - b));
}

// Ordering of suffix start positions used to build sa. Skip the common prefix
// reported by lcp. If either suffix runs out there, the one starting later
// (the shorter) sorts first. Otherwise compare the ranks (mp) of the first
// differing letters.
bool cmp(int x, int y) {
    const int common = lcp(x, y);
    const bool exhausted = x + common >= n || y + common >= n;
    if (exhausted) return x > y;
    const int letterX = s[x + common] - 'a';
    const int letterY = s[y + common] - 'a';
    return mp[x][letterX] < mp[y][letterY];
}
// sa[k] = start position of the k-th suffix in cmp order; rk is its inverse (rk[sa[k]] == k).
int sa[N], rk[N];

int main() {
int Qn; cin >> n >> Qn;
scanf("%s", s);
FOR (c, 0, 10) {
last = 1;
FORD (i, n - 1, -1) {
a[c][i] = s[i] - 'a' == c ? 0 : 1;
ins(a[c][i], &a[c][i]);
p[i][c] = {last, c};
}
}
rsort();
dfs(1);
rmq_init();
FOR (i, 0, n) {
sort(p[i], p[i] + 10, [](const P& a, const P& b){ return idx[a.u] < idx[b.u]; });
FOR (c, 0, 10) mp[i][p[i][c].c] = c;
}
FOR (i, 0, n) sa[i] = i;
sort(sa, sa + n, cmp);
FOR (i, 0, n) rk[sa[i]] = i;
while (Qn--) {
int L, R; scanf("%d%d", &L, &R); --L; --R;
int len = R - L + 1;
int l = 0, r = rk[L], ll = -1, rr = -1;
while (l <= r) {
int m = (l + r) / 2;
if (lcp(L, sa[m]) >= len) { ll = m; r = m - 1; }
else l = m + 1;
}
l = rk[L]; r = n - 1;
while (l <= r) {
int m = (l + r) / 2;
if (lcp(L, sa[m]) >= len) { rr = m; l = m + 1; }
else r = m - 1;
}
printf("%d\n", rr - ll + 1);
}
}