正题
题目链接:https://www.luogu.com.cn/problem/P1117
题目大意
长度为nnn的字符串,求所有子串有多少种分割成AABBAABBAABB的方式。
解题思路
aia_iai表示以iii结尾的子串中有多少种分割成AAAAAA的方式
bib_ibi表示以iii开头的子串中有多少种分割成AAAAAA的方式
然后答案就是∑i=1n−1aibi+1\sum_{i=1}^{n-1}a_ib_{i+1}i=1∑n−1aibi+1
考虑用SASASA来计算a,ba,ba,b。
枚举长度lenlenlen,考虑所有长度为2∗len2*len2∗len的AAAAAA串,我们没隔lenlenlen格放置一个点,那么每个串必定经过了两个点,现在考虑求出相邻两个点之间的贡献
对于相邻两个点l,rl,rl,r,求出它们的LCPLCPLCP和LCSLCSLCS,分情况讨论
- LCP+LCS<len:LCP+LCS<len:LCP+LCS<len:那么我们可以发现没有任何一个串2∗len2*len2∗len的AAAAAA串同时经过这两个点,因为在lll的右边和rrr的左边,这两个串必定有一个地方不同。
- LCP+LCS≥len:LCP+LCS\geq len:LCP+LCS≥len:那么此时有串经过这两个点,且
s≥l−LCP+1,t≤r+LCPs\geq l-LCP+1,t\leq r+LCPs≥l−LCP+1,t≤r+LCP的串都满足条件
时间复杂度:O(nlogn):O(n\log n):O(nlogn)
codecodecode
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=2e5+10;
int T,n;
long long ans,a[N],b[N];
struct SA{char s[N];int m,sa[N],rk[N],height[N],lg[N];int c[N],x[N],y[N],st[N][25];void Qsort(){for(int i=1;i<=m;i++) c[i]=0;for(int i=1;i<=n;i++) c[x[i]]++;for(int i=1;i<=m;i++) c[i]+=c[i-1];for(int i=n;i>=1;i--) sa[c[x[y[i]]]--]=y[i],y[i]=0;return;}void Get_SA(){m=256;for(int i=1;i<=n;i++)x[i]=s[i],y[i]=i;Qsort();for(int w=1;w<=n;w<<=1){int p=0;for(int i=n-w+1;i<=n;i++) y[++p]=i;for(int i=1;i<=n;i++)if(sa[i]>w) y[++p]=sa[i]-w;Qsort();swap(x,y);x[sa[1]]=p=1;for(int i=2;i<=n;i++)x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+w]==y[sa[i-1]+w])?p:++p;if(p==n) break;m=p;}return;}void Get_Height(){int k=0;for(int i=1;i<=n;i++)rk[sa[i]]=i;for(int i=1;i<=n;i++){if(rk[i]==1) continue;if(k)k--;int j=sa[rk[i]-1];while(i+k<=n&&j+k<=n&&s[i+k]==s[j+k]) k++;height[rk[i]]=k;}return;}void Get_ST(){lg[0]=-1;for(int i=1;i<=n;i++)lg[i]=lg[i>>1]+1,st[i][0]=height[i];for(int j=1;(1<<j)<=n;j++)for(int i=1;i+(1<<j)-1<=n;i++)st[i][j]=min(st[i+(1<<(j-1))][j-1],st[i][j-1]);return;}void Build(){memset(rk,0,sizeof(rk));memset(st,0,sizeof(st));memset(height,0,sizeof(height));memset(sa,0,sizeof(sa));memset(x,0,sizeof(x));memset(y,0,sizeof(y));Get_SA();Get_Height();Get_ST();return;}int LCP(int l,int r){l=rk[l];r=rk[r];if(l>r) swap(l,r);l++;int z=lg[r-l+1];return min(st[l][z],st[r+1-(1<<z)][z]);}
}s1,s2;
int main()
{scanf("%d",&T);while(T--){memset(a,0,sizeof(a));memset(b,0,sizeof(b));scanf("%s",s1.s+1);n=strlen(s1.s+1);for(int i=1;i<=n;i++)s2.s[n-i+1]=s1.s[i];s1.Build();s2.Build();for(int len=1;len<=n/2;len++){for(int i=len;i<=n;i+=len){int l=i,r=i+len;int L=n-r+2,R=n-l+2;int lcp=min(len,s1.LCP(l,r));int lcs=min(len-1,s2.LCP(L,R));if(lcp+lcs>=len){b[l-lcs]++;b[l+lcp-len+1]--;a[r+lcp]--;a[r-lcs+len-1]++;}}}ans=0;for(int i=1;i<=n;i++){a[i]+=a[i-1],b[i]+=b[i-1];ans+=a[i-1]*b[i];}printf("%lld\n",ans);}return 0;
}