正题
题目链接:https://www.luogu.com.cn/problem/CF204E
题目大意
nnn个字符串的一个字符串集合,对于每个字符串求有多少个子串是这个字符串集合中至少kkk个字符串的子串。
解题思路
因为对于每个字符串我们需要维护的信息不同,不能累加,所以考虑使用线段树合并。
先将nnn个字符串构建出一个广义SAMSAMSAM,然后对于每个节点维护一个该线段树表示该节点属于的字符串。然后在parentsparentsparents树上从下往上合并,如果属于字符串的数量多余kkk,那么打上标记。
然后再上往下走,每个节点产生的答案就是在它parentsparentsparents树上的祖先中最近的一个打了标记的节点的lenlenlen。
时间复杂度O(nlogn)O(n\log n)O(nlogn)
好像还可以先接起来跑一遍SASASA,然后用单调队列类似于统计矩形面积一样的方法来做,也是O(nlogn)O(n\log n)O(nlogn)当然我这里写的是SAMSAMSAM
codecodecode
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
const int N=2e5+10;
struct node{int to,next;
}a[N<<1];
int n,k,cnt,tot,ls[N],rt[N];
int ch[N][26],fa[N],len[N];
bool mark[N];char s[N];
long long ans[N];
vector<int> q[N];
struct Seq_Tree{int w[N<<4],ls[N<<4],rs[N<<4],cnt;int Change(int x,int L,int R,int pos,int val){int y=++cnt;if(L==R){w[y]=val;return y;}int mid=(L+R)>>1;if(pos<=mid)ls[y]=Change(ls[x],L,mid,pos,val),rs[y]=rs[x];else ls[y]=ls[x],rs[y]=Change(rs[x],mid+1,R,pos,val);w[y]=w[ls[y]]+w[rs[y]];return y;}int Ask(int x,int L,int R,int pos){if(!x)return 0;if(L==R)return w[x];int mid=(L+R)>>1;if(pos<=mid)return Ask(ls[x],L,mid,pos);return Ask(rs[x],mid+1,R,pos);}int Merge(int x,int y,int L,int R){if((!x)||(!y))return x|y;if(L==R){w[x]=w[x]|w[y];return x;}int mid=(L+R)>>1;ls[x]=Merge(ls[x],ls[y],L,mid);rs[x]=Merge(rs[x],rs[y],mid+1,R);w[x]=w[ls[x]]+w[rs[x]];return x;}
}T;
void addl(int x,int y){a[++tot].to=y;a[tot].next=ls[x];ls[x]=tot;return;
}
int Insert(int c,int p){if(ch[p][c]){int q=ch[p][c];if(len[q]==len[p]+1)return q;int nq=++cnt;len[nq]=len[p]+1;memcpy(ch[nq],ch[q],sizeof(ch[nq]));fa[nq]=fa[q];fa[q]=nq;for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;return nq;}int np=++cnt;len[np]=len[p]+1;for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;if(!p)fa[np]=1;else{int q=ch[p][c];if(len[q]==len[p]+1)fa[np]=q;else{int nq=++cnt;len[nq]=len[p]+1;memcpy(ch[nq],ch[q],sizeof(ch[nq]));fa[nq]=fa[q];fa[q]=fa[np]=nq;for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;}}return np;
}
void dfs(int x){for(int i=ls[x];i;i=a[i].next){int y=a[i].to;dfs(y);rt[x]=T.Merge(rt[x],rt[y],1,n);}if(T.w[rt[x]]>=k)mark[x]=1;return;
}
void solve(int x,int res){if(mark[x])res=len[x];for(int i=ls[x];i;i=a[i].next){int y=a[i].to;solve(y,res);}for(int i=0;i<q[x].size();i++)ans[q[x][i]]+=res;return;
}
int main()
{scanf("%d%d",&n,&k);cnt=1;for(int i=1;i<=n;i++){scanf("%s",s);int l=strlen(s),last=1;for(int j=0;j<l;j++){last=Insert(s[j]-'a',last);rt[last]=T.Change(rt[last],1,n,i,1);q[last].push_back(i);}}for(int i=2;i<=cnt;i++)addl(fa[i],i);dfs(1);solve(1,0);for(int i=1;i<=n;i++)printf("%lld\n",ans[i]);return 0;
}