【题解】nowcoder-203F-Palindrome

题目

https://www.nowcoder.com/acm/contest/203/F

题意

求字符串 \(s\) 中选取两个非空子串连接成回文串的方案数。

题解

  • 这题很类似
  • 假设选取的串是 \(p+s\)\(s'\),其中 \(p\) 是回文串,\(s'\)\(s\) 的反串且 \(s\) 非空,\(p\) 可以为空。答案就是原串的方案数 + 反串的方案数减去一些什么。(一些什么就是选取的串是 \(s\)\(s'\),这种情况下被算了两次)
  • 枚举 \(p\)\(s\) 的交界处 \(i\),用回文自动机求出以 \(i\) 结尾的回文串的个数 \(pcnt\),用广义后缀自动机求出 \(i+1\) 开始的后缀和反串的所有后缀的 LCP 之和 \(scnt\),那么交界 \(i\) 的贡献就是 \(pcnt \times scnt\)。具体来说就是 \(i+1\) 开始的后缀和反串所有后缀的结点的 LCA 的能表示的最长串的和。
  • 考虑求一条链和若干条链的 LCA 之和,这是一个经典问题,把那若干条链到根的路径上都贡献次数 1(其实就是获得每个结点的 \(EndPos\) 的大小),那么只需要求询问的链到根的和。这部分在 \(O(n)\) 预处理后可以 \(O(1)\) 查询。
  • 总复杂度 \(O(n)\)
  • 注意答案会爆 LL。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(x...) do { cout << "\033[32;1m" << #x << " -> "; err(x); } while (0)
#else
#define dbg(...)
#endif
void err() { cout << "\033[39;0m" << endl; }
template<template<typename...> class T, typename t, typename... A>
void err(T<t> a, A... x) { for (auto v: a) cout << v << ' '; err(x...); }
template<typename T, typename... A>
void err(T a, A... x) { cout << a << ' '; err(x...); }
// -----------------------------------------------------------------------------
const int N = 8E5 + 100;
char s[N];
int n;

template<typename T>
void o(T p) {
static int stk[70], tp;
if (p == 0) { putchar('0'); return; }
if (p < 0) { p = -p; putchar('-'); }
while (p) stk[++tp] = p % 10, p /= 10;
while (tp) putchar(stk[tp--] + '0');
}

namespace pam {
int t[N][26], fa[N], len[N], rs[N], cnt[N], num[N];
int sz, n, last;
int _new(int l) {
memset(t[sz], 0, sizeof t[0]);
len[sz] = l; cnt[sz] = num[sz] = 0;
return sz++;
}
void init() {
rs[n = sz = 0] = -1;
last = _new(0);
fa[last] = _new(-1);
}
int get_fa(int x) {
while (rs[n - 1 - len[x]] != rs[n]) x = fa[x];
return x;
}
void ins(int ch) {
rs[++n] = ch;
int p = get_fa(last);
if (!t[p][ch]) {
int np = _new(len[p] + 2);
num[np] = num[fa[np] = t[get_fa(fa[p])][ch]] + 1;
t[p][ch] = np;
}
++cnt[last = t[p][ch]];
}
}

namespace sam {
int t[N][26], len[N] = {-1}, fa[N], sz = 2, last = 1;
LL cnt[N][2];
void init() {
memset(t, 0, (sz + 10) * sizeof t[0]);
memset(cnt, 0, (sz + 10) * sizeof cnt[0]);
sz = 2;
last = 1;
}
void reset() { last = 1; }
void ins(int ch, int id) {
int p = last, np = 0, nq = 0, q = -1;
if (!t[p][ch]) {
np = sz++;
len[np] = len[p] + 1;
for (; p && !t[p][ch]; p = fa[p]) t[p][ch] = np;
}
if (!p) fa[np] = 1;
else {
q = t[p][ch];
if (len[p] + 1 == len[q]) fa[np] = q;
else {
nq = sz++; len[nq] = len[p] + 1;
memcpy(t[nq], t[q], sizeof t[0]);
fa[nq] = fa[q];
fa[np] = fa[q] = nq;
for (; t[p][ch] == q; p = fa[p]) t[p][ch] = nq;
}
}
last = np ? np : nq ? nq : q;
cnt[last][id] = 1;
}
int c[N] = {1}, a[N];
LL sum[N];
void rsort() {
FOR (i, 1, sz) c[i] = 0;
FOR (i, 1, sz) c[len[i]]++;
FOR (i, 1, sz) c[i] += c[i - 1];
FOR (i, 1, sz) a[--c[len[i]]] = i;
FORD (i, sz - 1, 0) {
int u = a[i];
cnt[fa[u]][0] += cnt[u][0];
}
FOR (i, 2, sz) {
int u = a[i];
sum[u] = sum[fa[u]] + cnt[u][0] * 1LL * (len[u] - len[fa[u]]);
}
}
}

int pos[N];
__int128 Z;
__int128 go() {
sam::init(); pam::init();
FOR (i, 0, n) sam::ins(s[i] - 'a', 0);
sam::reset();
FORD (i, n - 1, -1) {
sam::ins(s[i] - 'a', 1);
pos[i] = sam::last;
}
sam::rsort();
__int128 ans = 0;
FOR (i, -1, n) {
LL pcnt = 1;
if (i != -1) {
pam::ins(s[i] - 'a');
pcnt += pam::num[pam::last];
}
LL scnt = 0;
if (i != n - 1) scnt += sam::sum[pos[i + 1]];
ans += (__int128)pcnt * scnt;
}
Z = 0; FOR (i, 0, n) Z += sam::sum[pos[i]];
return ans;
}

int main() {
#ifdef zerol
freopen("inin", "r", stdin);
#endif
scanf("%d%s", &n, s);
__int128 ans = go();
reverse(s, s + n);
o(ans + go() - Z); puts("");
}