题意:给一棵 nnn 个点的树,每个点有个字符,另给一个长度为 mmm 的特征串,求树上 n2n^2n2 条有向路径在特征串中出现的次数之和。
n,m≤5×104n,m\leq 5\times 10^4n,m≤5×104
看到母串先建 SAM (bushi
树上路径统计问题,考虑点分治。
对当前分治中心,我们希望求出所有经过分治中心的路径的贡献。因为是计数问题,同一个子树的可以容斥掉,所以这不会成为思维瓶颈。
考虑求出特征串上每个位置作为分治中心的匹配点的方案数,我们只需要求出这个位置结束的后缀以及开始的前缀总共可以匹配多少个以分治中心为终点或起点的路径。两个是同理的,考虑前面一个。
我们然后从分治中心开始 dfs,记录从当前点到分治中心的路径构成的串在 SAM 上的位置。如果这个串不是其等价类中最长的,那么判断一下加入一个字符后是否是等价类中长度多 111 的串,不是就返回。
如果是等价类中最长的,前面加一个字符后可(yi)能(ding)会对应多个状态。在 SAM 上预处理出一个 nxt(p,c)nxt(p,c)nxt(p,c),表示 ppp 号结点代表的最长串在前面加上字符 ccc 后到达的状态,没有为 000。每访问一个位置就让该结点的计数器 +1+1+1。
然后我们要求某个位置结尾的后缀被访问次数之和,也就是到根的路径和,做一个树上前缀和即可。
这样对一个大小为 sizsizsiz 的连通块复杂度是 O(siz+m)\Omicron(siz+m)O(siz+m),在 sizsizsiz 很小时显得浪费。
考虑根号分治,当 siz<nsiz<\sqrt nsiz<n 时直接 O(siz2)\Omicron(siz^2)O(siz2) 暴力。这样点分治只会递归一半的层数,暴力的次数也只有 O(n)\Omicron(\sqrt n)O(n)。总复杂度 O((n+m)n)\Omicron((n+m)\sqrt n)O((n+m)n)
注意容斥的时候也要根号分治,否则会被菊花图卡。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <cmath>
#define MAXN 100005
using namespace std;
typedef long long ll;
int n,m,B;
inline int read()
{int ans=0;char c=getchar();while (!isdigit(c)) c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return ans;
}
struct SAM
{char s[MAXN];int ch[MAXN][26],nxt[MAXN][26],fa[MAXN],len[MAXN],pos[MAXN],rt[MAXN],sum[MAXN],siz[MAXN],tot,las;void insert(int c,int k){int cur=++tot,p=las;len[cur]=len[p]+1;for (;p&&!ch[p][c];p=fa[p]) ch[p][c]=cur;if (!p) fa[cur]=1;else{int q=ch[p][c];if (len[q]==len[p]+1) fa[cur]=q;else{int _q=++tot;len[_q]=len[p]+1;fa[_q]=fa[q],fa[q]=fa[cur]=_q;memcpy(ch[_q],ch[q],sizeof(ch[q]));for (;ch[p][c]==q;p=fa[p]) ch[p][c]=_q; } } ++siz[pos[k]=las=cur];}int a[MAXN],c[MAXN];void build(){tot=las=1;for (int i=1;i<=m;i++) insert(s[i]-'a',i);for (int i=1;i<=tot;i++) ++c[len[i]];for (int i=1;i<=tot;i++) c[i]+=c[i-1];for (int i=1;i<=tot;i++) a[c[len[i]]--]=i;for (int i=1;i<=m;i++) rt[pos[i]]=i;for (int i=tot;i>=1;i--) rt[fa[a[i]]]=rt[a[i]],siz[fa[a[i]]]+=siz[a[i]];for (int i=2;i<=tot;i++) nxt[fa[i]][s[rt[i]-len[fa[i]]]-'a']=i;}int getnxt(int p,int l,int c){if (l<len[p]) return c==s[rt[p]-l]-'a'? (++sum[p],p):0;return ++sum[p=nxt[p][c]],p;}inline void clear(){for (int i=1;i<=tot;i++) sum[i]=0;}inline void getsum(int* ans){for (int i=2;i<=tot;i++) sum[a[i]]+=sum[fa[a[i]]];for (int i=1;i<=m;i++) ans[i]=sum[pos[i]];}
}S,T;
ll ans;
vector<int> e[MAXN];
char s[MAXN];
int vis[MAXN],fa[MAXN];
int siz[MAXN],maxp[MAXN],rt;
void findrt(int u,int f,int sum)
{siz[u]=1,maxp[u]=0;for (int i=0;i<(int)e[u].size();i++)if (!vis[e[u][i]]&&e[u][i]!=f){findrt(e[u][i],u,sum);siz[u]+=siz[e[u][i]];maxp[u]=max(maxp[u],siz[e[u][i]]);}maxp[u]=max(maxp[u],sum-siz[u]);if (maxp[u]<maxp[rt]) rt=u;
}
void dfs(SAM& S,int u,int f,int l,int p)
{siz[u]=1,p=S.getnxt(p,l,s[u]-'a');for (int i=0;i<(int)e[u].size();i++)if (!vis[e[u][i]]&&e[u][i]!=f)dfs(S,e[u][i],u,l+1,p),siz[u]+=siz[e[u][i]];
}
vector<int> lis;
void dfs(int u,int f)
{lis.push_back(u),fa[u]=f;for (int i=0;i<(int)e[u].size();i++)if (!vis[e[u][i]]&&e[u][i]!=f)dfs(e[u][i],u);
}
void dfs(int u,int f,int p,int v)
{if (!(p=S.ch[p][s[u]-'a'])) return;ans+=v*S.siz[p];for (int i=0;i<(int)e[u].size();i++)if (!vis[e[u][i]]&&e[u][i]!=f)dfs(e[u][i],u,p,v);
}
int a[MAXN],b[MAXN];
void calc()
{S.clear(),T.clear();dfs(S,rt,0,0,1),dfs(T,rt,0,0,1);S.getsum(a),T.getsum(b);for (int i=1;i<=m;i++) ans+=(ll)a[i]*b[m-i+1];for (int i=0;i<(int)e[rt].size();i++)if (!vis[e[rt][i]]){if (siz[e[rt][i]]<B){lis.clear(),dfs(e[rt][i],rt);for (int j=0;j<(int)lis.size();j++){int v=lis[j],p=1;while (fa[v]) p=S.ch[p][s[v]-'a'],v=fa[v];p=S.ch[p][s[rt]-'a'];dfs(e[rt][i],0,p,-1);}}else{S.clear(),T.clear();dfs(S,e[rt][i],0,1,S.ch[1][s[rt]-'a']),dfs(T,e[rt][i],0,1,T.ch[1][s[rt]-'a']);S.getsum(a),T.getsum(b);for (int i=1;i<=m;i++) ans-=(ll)a[i]*b[m-i+1];}}
}
void calcV()
{lis.clear(),dfs(rt,0);for (int i=0;i<(int)lis.size();i++) dfs(lis[i],0,1,1);
}
void solve()
{int u=rt;vis[u]=1;calc();for (int i=0;i<(int)e[u].size();i++)if (!vis[e[u][i]]){rt=0,findrt(e[u][i],0,siz[e[u][i]]);siz[e[u][i]]<B? calcV():solve();}
}
int main()
{maxp[0]=0x7fffffff;B=sqrt(n=read()),m=read();for (int i=1;i<n;i++){int u,v;u=read(),v=read();e[u].push_back(v),e[v].push_back(u);}scanf("%s",s+1);scanf("%s",S.s+1);for (int i=1;i<=m;i++) T.s[m-i+1]=S.s[i];S.build(),T.build();findrt(1,0,n);solve();cout<<ans<<'\n';return 0;
}