「题解」CF1827C Palindrome Partition

AtomAlpaca

Table of contents

题意

简要题意:称一个字符串是好的,当且仅当这个是偶回文串,或由多个偶回文串拼接得到。给定字符串 ss,求 ss 有多少好的子串。

|s|5×105|s| \le 5\times10^5

题解

fif_i 为结尾位置为 ii 的好串数量,gig_i 为结尾位置为 ii 的极短偶回文串的长度,如果不存在 gig_i00

考虑到当且仅当向一个好串后方接上一个偶回文串,这个串依然是好串;且极短偶回文串本身也是一个好串,我们可以写出转移方程:

fi={0,ifgi=0figi+1,otherwise\begin{aligned} f_i = \left\{ \begin{aligned} & 0, &if g_i=0 \\ &f_{i-g_i} + 1, &otherwise \end{aligned} \right. \end{aligned}

因为每次向后面加上的都是极短的偶回文串,这种 dp 方式一定是不重不漏的。答案即为 i=1|s|fi\sum_{i=1}^{|s|}{f_i}

现在问题转化为令快速求 gg。考虑每次求 gg 其实就是找字符串的最短的一段偶后缀,这和回文自动机的构造过程很相似。我们考虑对自动机上每个节点维护一个变量 hh,表示这个节点代表的回文串的短偶回文串长度,那么每次插入完一个字符,这个位置的 gg 就等于 hlasth_{last}

考虑如何维护 hh。只有两种情况对 hh 有贡献,一是这个节点有一段真后缀是偶回文串,此时 hu=hfailuh_u=h_{fail_u};二是自己本身是一个偶回文串,此时 hu=lenuh_u=len_u

整理一下:

h_u={hfailu,hfailu0lenu,ffailu=0andlenumod2=00,otherwise\begin{aligned} h\_u = \left\{ \begin{aligned} & h_{fail_u}, &h_{fail_u} \ne 0 \\ & len_u, & f_{fail_u} = 0 \ and \ len_u \operatorname {mod} 2=0 \\ &0, &otherwise \end{aligned} \right. \end{aligned}

在插入字符时维护即可。由于回文自动机复杂度是线性的,整体复杂度为 O(|s|)O(|s|)

代码

个人的回文自动机是按照 SAM 的写法改的,和主流写法不太一样,可能看起来有点奇怪。

#include <bits/stdc++.h>

typedef long long ll;

const int MAX = 5e5 + 5;
ll T, n, lst, tot, ans;
ll h[MAX], g[MAX], f[MAX];
char s[MAX];

struct E { ll l, f, c[27]; } t[MAX];

void init()
{
  lst = tot = 2;
  t[1].f = t[2].f = 1;
  t[1].l = -1; t[2].l = 0;
}

void clear()
{
  ans = 0;
  for (int i = 1; i <= tot; ++i)
  {
    t[i].f = t[i].l = 0; h[i] = 0;
    for (int j = 0; j <= 26; ++j) { t[i].c[j] = 0; }
  }
  for (int i = 1; i <= n; ++i) { f[i] = g[i] = 0; }
  init();
}

void add(int k, int c)
{
  int p = lst;
  while (s[k - t[p].l - 1] - 'a' != c) { p = t[p].f; }
  if (t[p].c[c]) { lst = t[p].c[c]; g[k] = h[lst]; return ; }
  int np = ++tot, q = t[p].f;
  while (s[k - t[q].l - 1] - 'a' != c) { q = t[q].f; }
  if (t[q].c[c]) { t[np].f = t[q].c[c]; } else { t[np].f = 2; }
  t[p].c[c] = np; t[np].l = t[p].l + 2;
  h[np] = h[t[np].f]; if (!h[np] and !(t[np].l & 1)) { h[np] = t[np].l; }
  lst = t[p].c[c]; g[k] = h[lst];
}

void solve()
{
  clear();
  scanf("%lld%s", &n, s + 1);
  for (int i = 1; i <= n; ++i) { add(i, s[i] - 'a'); }
  for (int i = 1; i <= n; ++i) { if (g[i]) { f[i] = 1 + f[i - g[i]]; } ans += f[i]; }
  printf("%lld\n", ans);
}

int main()
{
  scanf("%lld", &T); while (T--) { solve(); }
}