【题解】洛谷-3181-BZOJ-4566-[HAOI2016]找相同字符

题目

https://www.luogu.org/problemnew/show/P3181

题意

求两个字符串相同子串的对数(位置不同看做不同)。

题解

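  • 代码 1:对第一个串建立 SAM,用第二个串在上面匹配。对第二个串的每个位置,设当前匹配到状态 \(u\)、匹配长度为 \(l\),则以该位置结尾的公共子串对数为 \(|EndPos(u)|\cdot(l-len(fa(u)))\),再加上 \(fa(u)\) 到根的链上(不含根)各状态 \(v\) 的 \(|EndPos(v)|\cdot(len(v)-len(fa(v)))\) 之和,后者可以预处理。若连根都无法沿当前字符转移,则该位置的贡献为 0。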
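  • 代码 2, 3:对两个串建立广义 SAM,分别统计每个状态 \(u\) 在两个串中的出现次数 \(|EndPos_1(u)|\)、\(|EndPos_2(u)|\),答案为 \(\sum_u |EndPos_1(u)|\cdot|EndPos_2(u)|\cdot(len(u)-len(fa(u)))\)。其中代码 3 的写法不会产生冗余状态。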
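  • 如果建立广义后缀自动机的话,有的写法(如代码 2)会产生一些无效状态 u,即满足 len[fa[u]]==len[u] 的状态。此时按 len 基数排序得到的序列中,u 与 fa[u] 落在同一个桶内,先后次序只取决于排序对相同 len 的处理方式,不再保证是合法的拓扑序;若父亲先于 u 被处理,u 的计数就无法累加到父亲上,出现错误的转移。(代码 2 能得到正确结果,是因为无效状态 np 的编号总是小于其父亲 nq,而代码中的基数排序恰好让同一桶内编号较小的状态先被处理。)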

代码1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// FOR(i, x, y): i runs over [x, y) upward; y is evaluated once and i has y's decayed type.
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
// FORD(i, x, y): i runs from x down to y + 1 (y itself excluded); i has x's decayed type.
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
// dbg(...): when the macro zerol is defined, prints the argument text and values to stdout
// (wrapped in ANSI escape sequences); otherwise it expands to nothing.
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
// Terminal case of err: emits the closing escape sequence and ends the line.
void err() { cout << "\033[39;0m" << endl; }
// err overload for arguments of the form T<t>: prints each element (needs range-for support).
// Note: the argument is taken by value, so it is copied.
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (auto x: a) cout << x << ' '; err(args...); }
// Generic err overload: prints one value followed by a space, then recurses on the rest.
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 2E5 + 100;  // capacity for one input string
const int M = maxn << 1;     // state capacity: two states per character

// Suffix automaton of the first string.
//   t[u][ch] : transition of state u on letter ch
//   len[u]   : length of the longest substring represented by u
//   fa[u]    : suffix link of u
// State 0 is a sentinel (len = -1), state 1 is the root, sz is the next unused
// index, and last is the state of the whole prefix inserted so far.
int t[M][26], len[M] = {-1}, fa[M], sz = 2, last = 1;
// cnt[u] is 1 for every state created as a new prefix end; init() later turns it
// into |EndPos(u)|. sum[u] is filled in by init().
LL cnt[M], sum[M];

// Appends letter ch (0..25) to the automaton (online SAM extension).
void ins(int ch) {
const int cur = sz++;
int p = last;
last = cur;
len[cur] = len[p] + 1;
cnt[cur] = 1;
// Every suffix state that lacks a ch-transition now leads to cur.
while (p && !t[p][ch]) {
    t[p][ch] = cur;
    p = fa[p];
}
if (!p) {
    fa[cur] = 1;
    return;
}
const int q = t[p][ch];
if (len[q] == len[p] + 1) {
    fa[cur] = q;
    return;
}
// q also represents longer strings: split off a clone of length len[p] + 1.
const int clone = sz++;
len[clone] = len[p] + 1;
memcpy(t[clone], t[q], sizeof t[0]);
fa[clone] = fa[q];
fa[q] = clone;
fa[cur] = clone;
while (t[p][ch] == q) {
    t[p][ch] = clone;
    p = fa[p];
}
}

int c[maxn] = {1}, a[M];
void rsort() {
FOR (i, 1, sz) c[len[i]]++;
FOR (i, 1, maxn) c[i] += c[i - 1];
FOR (i, 1, sz) a[--c[len[i]]] = i;
}

char s[maxn];
void init() {
FORD (i, sz - 1, 1) {
int u = a[i];
cnt[fa[u]] += cnt[u];
}
FOR (i, 2, sz) {
int u = a[i];
sum[u] = cnt[u] * (len[u] - len[fa[u]]) + sum[fa[u]];
}
}
// Runs the second string (stored in s) through the automaton of the first one and
// sums, over every position i of s, the number of occurrences in the first string
// of all substrings of s that end at i.
// (u, l) is the state and length of the longest suffix of s[0..i] that occurs
// in the first string.
LL go() {
LL ret = 0;
int u = 1, l = 0;
FOR (i, 0, strlen(s)) {
    int ch = s[i] - 'a';
    // Shorten the match along suffix links until a ch-transition exists.
    while (u && !t[u][ch]) { u = fa[u]; l = len[u]; }
    if (!u) {
        // Not even the root has a ch-transition: no suffix ending here occurs in
        // the first string. Restart from the root and add nothing. (Previously
        // this path fell through with u = 1, l = 0 and wrongly added cnt[1].)
        u = 1; l = 0;
        continue;
    }
    ++l; u = t[u][ch];
    // Lengths (len[fa[u]], l] belong to u and each occurs cnt[u] times;
    // shorter suffixes are accounted for by sum[fa[u]].
    ret += sum[fa[u]] + cnt[u] * (l - len[fa[u]]);
}
return ret;
}

int main() {
scanf("%s", s);
FOR (i, 0, strlen(s)) ins(s[i] - 'a');
rsort();
init();
scanf("%s", s);
cout << go() << endl;
}

代码2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// FOR(i, x, y): i runs over [x, y) upward; y is evaluated once and i has y's decayed type.
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
// FORD(i, x, y): i runs from x down to y + 1 (y itself excluded); i has x's decayed type.
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
// dbg(...): when the macro zerol is defined, prints the argument text and values to stdout
// (wrapped in ANSI escape sequences); otherwise it expands to nothing.
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
// Terminal case of err: emits the closing escape sequence and ends the line.
void err() { cout << "\033[39;0m" << endl; }
// err overload for arguments of the form T<t>: prints each element (needs range-for support).
// Note: the argument is taken by value, so it is copied.
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (auto x: a) cout << x << ' '; err(args...); }
// Generic err overload: prints one value followed by a space, then recurses on the rest.
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 4E5 + 100;
const int M = maxn << 1;

// Generalized suffix automaton shared by both strings.
//   t[u][ch] transitions, len[u] longest length of state u, fa[u] suffix link.
//   State 0 is a sentinel (len = -1), state 1 is the root, sz is the next free index,
//   last is the state of the prefix of the current string inserted so far.
int t[M][26], len[M] = {-1}, fa[M], sz = 2, last = 1;
// cnt[u][id]: set to 1 when u is the state of a prefix of string id;
// go() later accumulates these along suffix links.
LL cnt[M][2];
// Appends letter ch (0..25) of string id after state `last`.
// If `last` already has a ch-transition whose target has length len[last] + 1, that
// state is reused. Otherwise a new state np is always created, even when `last`
// already has a ch-transition. In that case the first loop below does nothing, so np
// gets no incoming transition, and the clone branch gives len[np] == len[fa[np]]
// (a redundant state).
// Correctness then relies on statement order: np is allocated before its parent nq,
// and rsort() puts the smaller index of an equal-len pair where go()'s reverse scan
// visits it first. Do not reorder the allocations.
void ins(int ch, int id) {
if (t[last][ch] && len[t[last][ch]] == len[last] + 1) { last = t[last][ch]; cnt[last][id] = 1; return; }
int p = last, np = last = sz++;
len[np] = len[p] + 1; cnt[np][id] = 1; dbg(np, id);
for (; p && !t[p][ch]; p = fa[p]) t[p][ch] = np;
if (!p) { fa[np] = 1; return; }
int q = t[p][ch];
if (len[p] + 1 == len[q]) fa[np] = q;
else {
// Split: clone q with length len[p] + 1 and redirect p's suffix chain to the clone.
int nq = sz++; len[nq] = len[p] + 1;
memcpy(t[nq], t[q], sizeof t[0]);
fa[nq] = fa[q];
fa[np] = fa[q] = nq;
for (; t[p][ch] == q; p = fa[p]) t[p][ch] = nq;
}
}

int c[maxn] = {1}, a[M];
void rsort() {
FOR (i, 1, sz) c[len[i]]++;
FOR (i, 1, maxn) c[i] += c[i - 1];
FOR (i, 1, sz) a[--c[len[i]]] = i;
}

char s[maxn];

// Counts pairs of equal substrings (one from each string) on the generalized
// automaton. State u stands for len[u] - len[fa[u]] distinct substrings; once its
// subtree has been accumulated, each of them occurs cnt[u][0] times in the first
// string and cnt[u][1] times in the second. States are taken from a[] in reverse
// (longest len first), and each state's counts are pushed to its suffix-link parent
// after it is used. Redundant states (len[u] == len[fa[u]]) contribute 0 themselves.
LL go() {
LL ret = 0;
for (int i = sz - 1; i > 1; --i) {
    const int u = a[i];
    const int p = fa[u];
    dbg(u, len[u], cnt[u][0], cnt[u][1], len[u] - len[fa[u]], fa[u]);
    ret += 1LL * cnt[u][0] * cnt[u][1] * (len[u] - len[p]);
    cnt[p][0] += cnt[u][0];
    cnt[p][1] += cnt[u][1];
}
return ret;
}

int main() {
FOR (i, 0, 2) {
last = 1;
scanf("%s", s);
FOR (j, 0, strlen(s)) ins(s[j] - 'a', i);
}
rsort();
cout << go() << endl;
}

代码3

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// FOR(i, x, y): i runs over [x, y) upward; y is evaluated once and i has y's decayed type.
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
// FORD(i, x, y): i runs from x down to y + 1 (y itself excluded); i has x's decayed type.
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
// dbg(...): when the macro zerol is defined, prints the argument text and values to stdout
// (wrapped in ANSI escape sequences); otherwise it expands to nothing.
#ifdef zerol
#define dbg(args...) do { cout << "\033[32;1m" << #args << " -> "; err(args); } while (0)
#else
#define dbg(...)
#endif
// Terminal case of err: emits the closing escape sequence and ends the line.
void err() { cout << "\033[39;0m" << endl; }
// err overload for arguments of the form T<t>: prints each element (needs range-for support).
// Note: the argument is taken by value, so it is copied.
template<template<typename...> class T, typename t, typename... Args>
void err(T<t> a, Args... args) { for (auto x: a) cout << x << ' '; err(args...); }
// Generic err overload: prints one value followed by a space, then recurses on the rest.
template<typename T, typename... Args>
void err(T a, Args... args) { cout << a << ' '; err(args...); }
// -----------------------------------------------------------------------------
const int maxn = 4E5 + 100;
const int M = maxn << 1;

// Generalized suffix automaton shared by both strings.
//   t[u][ch] transitions, len[u] longest length of state u, fa[u] suffix link.
//   State 0 is a sentinel (len = -1), state 1 is the root, sz is the next free index,
//   last is the state of the prefix of the current string inserted so far.
int t[M][26], len[M] = {-1}, fa[M], sz = 2, last = 1;
// cnt[u][id]: set to 1 when u is the state of a prefix of string id.
LL cnt[M][2];

// Appends letter ch (0..25) of string id after state `last` without creating
// redundant states:
//   * if last has no ch-transition, a new state np is created as in the ordinary SAM;
//   * otherwise the existing target q is reused, or, when len[q] != len[last] + 1,
//     a clone nq of length len[last] + 1 is split off and becomes the new last.
void ins(int ch, int id) {
int p = last;
int np = 0;  // newly created state, or 0 when an existing state is reused
int nq = 0;  // clone of q, or 0 when no split was needed
int q = 0;
if (!t[p][ch]) {
    np = sz++;
    len[np] = len[p] + 1;
    for (; p && !t[p][ch]; p = fa[p]) t[p][ch] = np;
}
if (!p) {
    // Only reachable after creating np, since p starts at last >= 1.
    fa[np] = 1;
} else {
    q = t[p][ch];
    if (len[p] + 1 == len[q]) {
        // Only a genuinely new state gets a suffix link. With np == 0 this used
        // to write fa[0] and clobber the sentinel.
        if (np) fa[np] = q;
    } else {
        nq = sz++;
        len[nq] = len[p] + 1;
        memcpy(t[nq], t[q], sizeof t[0]);
        fa[nq] = fa[q];
        fa[q] = nq;
        if (np) fa[np] = nq;
        for (; t[p][ch] == q; p = fa[p]) t[p][ch] = nq;
    }
}
last = np ? np : nq ? nq : q;
cnt[last][id] = 1;
}

int c[maxn] = {1}, a[M];
void rsort() {
FOR (i, 1, sz) c[len[i]]++;
FOR (i, 1, maxn) c[i] += c[i - 1];
FOR (i, 1, sz) a[--c[len[i]]] = i;
}

char s[maxn];

// Counts pairs of equal substrings (one from each string) on the generalized
// automaton. State u stands for len[u] - len[fa[u]] distinct substrings; once its
// subtree has been accumulated, each of them occurs cnt[u][0] times in the first
// string and cnt[u][1] times in the second. States are taken from a[] in reverse
// (longest len first), and each state's counts are pushed to its suffix-link parent
// after it is used.
LL go() {
LL ret = 0;
for (int i = sz - 1; i > 1; --i) {
    const int u = a[i];
    const int p = fa[u];
    dbg(u, len[u], cnt[u][0], cnt[u][1], len[u] - len[fa[u]], fa[u]);
    ret += 1LL * cnt[u][0] * cnt[u][1] * (len[u] - len[p]);
    cnt[p][0] += cnt[u][0];
    cnt[p][1] += cnt[u][1];
}
return ret;
}

int main() {
FOR (i, 0, 2) {
last = 1;
scanf("%s", s);
FOR (j, 0, strlen(s)) ins(s[j] - 'a', i);
}
rsort();
cout << go() << endl;
}