【题解】2018-ICPC-南京-M-Mediocre String Problem

题目

https://codeforces.com/gym/101981/problem/M

题意

给出两个字符串 \(s\)\(t\),要求从 \(s\) 切出一段非空子串,\(t\) 中切出一段非空前缀,且 \(s\) 的比 \(t\) 的长,使得拼接起来是个回文串。

题解

  • \(s\) 中切出来的是 \(t\) 中的某个子串的反串再接上一段回文。
  • \(s\) 翻转后可以方便的求出 \(s\) 中每个位置开始的回文串个数(回文树),以及每个位置开始向前能匹配的 \(t\) 的前缀的长度(Z Algorithm),错一位相乘累加即可。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
#ifdef zerol
#define dbg(x...) do { cout << "\033[32;1m" << #x << " -> "; err(x); } while (0)
void err() { cout << "\033[39;0m" << endl; }
template<template<typename...> class T, typename t, typename... A>
void err(T<t> a, A... x) { for (auto v: a) cout << v << ' '; err(x...); }
template<typename T, typename... A>
void err(T a, A... x) { cout << a << ' '; err(x...); }
#else
#define dbg(...)
#endif
// -----------------------------------------------------------------------------
const int N = 2E6 + 100;

namespace pam {
int t[N][26], fa[N], len[N], rs[N], cnt[N], num[N];
int sz, n, last;
int _new(int l) {
memset(t[sz], 0, sizeof t[0]);
len[sz] = l; cnt[sz] = num[sz] = 0;
return sz++;
}
void init() {
rs[n = sz = 0] = -1;
last = _new(0);
fa[last] = _new(-1);
}
int get_fa(int x) {
while (rs[n - 1 - len[x]] != rs[n]) x = fa[x];
return x;
}
void ins(int ch) {
rs[++n] = ch;
int p = get_fa(last);
if (!t[p][ch]) {
int np = _new(len[p] + 2);
num[np] = num[fa[np] = t[get_fa(fa[p])][ch]] + 1;
t[p][ch] = np;
}
++cnt[last = t[p][ch]];
}
}

void get_z(int a[], char s[], int n) {
int l = 0, r = 0; a[0] = n;
FOR (i, 1, n) {
a[i] = i > r ? 0 : min(r - i + 1, a[i - l]);
while (i + a[i] < n && s[a[i]] == s[i + a[i]]) ++a[i];
if (i + a[i] - 1 > r) { l = i; r = i + a[i] - 1; }
}
}

char s[N], t[N];
int z[N], c[N];
int main() {
scanf("%s%s", s, t);
int n = strlen(s), m = strlen(t);
reverse(s, s + n);
t[m] = '#'; FOR (i, 0, n) t[m + 1 + i] = s[i];
get_z(z, t, n + m + 1);
pam::init();
FOR (i, 0, n) {
pam::ins(s[i] - 'a');
c[i] = pam::num[pam::last];
}
LL ans = 0;
FOR (i, 0, n - 1)
ans += c[i] * 1LL * z[m + 1 + i + 1];
cout << ans << endl;
}