正题
题目大意
nnn个点的一棵树,每个节点有一个值valvalval和一个字符串SSS。对于每个点求∑x∈decp∑y∈decp(x<y)(valxxorvaly)∗∣LCP(Sx,Sy)∣\sum_{x\in dec_p}\sum_{y\in dec_p(x<y)}(val_x\ xor\ val_y)*|LCP(S_x,S_y)|x∈decp∑y∈decp(x<y)∑(valx xor valy)∗∣LCP(Sx,Sy)∣
decxdec_xdecx表示xxx的子树。
解题思路
我们可以通过建立一棵TrieTrieTrie来查询一个字符串和一堆字符串的LCPLCPLCP和。那么我们发现这样每个子树的运输次数是该子树的字符串长度和。
所以我们可以根据字符串长度和来进行启发式合并,然后按位用TrieTrieTrie统计答案即可。
时间复杂度O(nlognlogai)O(n\log n\log a_i)O(nlognlogai)
codecodecode
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=5e5+10;
struct node{ll to,next;
}a[N*2];
ll n,tot,W,val[N],ls[N],l[N],s[N],pos[N];
ll siz[N],son[N],ans[N],out[N];
char st[N];
struct Trie{ll cnt,t[N][26],val[N];void Clear(){cnt=1;memset(t[1],0,sizeof(t[1]));return;}void Insert(ll w){ll x=1;for(ll i=1;i<=l[w];i++){ll z=s[pos[w]+i];if(!t[x][z]){t[x][z]=++cnt;val[cnt]=0;memset(t[cnt],0,sizeof(t[cnt]));}x=t[x][z];val[x]++;}return;}ll Ask(ll w){ll x=1,ans=0;for(ll i=1;i<=l[w];i++){ll z=s[pos[w]+i];if(!t[x][z])return ans;x=t[x][z];ans+=val[x];}return ans;}
}T1,T0;
void addl(ll x,ll y){a[++tot].to=y;a[tot].next=ls[x];ls[x]=tot;return;
}
void dfs(ll x,ll fa){siz[x]=l[x]+1;for(ll i=ls[x];i;i=a[i].next){ll y=a[i].to;if(y==fa)continue;dfs(y,x);siz[x]+=siz[y];if(siz[y]>siz[son[x]])son[x]=y;}return;
}
ll calcA(ll x,ll fa){ll ans=(val[x]&W)?T0.Ask(x):T1.Ask(x);for(ll i=ls[x];i;i=a[i].next){ll y=a[i].to;if(y==fa)continue;ans+=calcA(y,x);}return ans;
}
void calcI(ll x,ll fa){(val[x]&W)?T1.Insert(x):T0.Insert(x);for(ll i=ls[x];i;i=a[i].next){ll y=a[i].to;if(y==fa)continue;calcI(y,x);}return;
}
void solve(ll x,ll fa,ll top){ans[x]=0;for(ll i=ls[x];i;i=a[i].next){ll y=a[i].to;if(y==fa||y==son[x])continue;solve(y,x,y);ans[x]+=ans[y];}if(son[x])solve(son[x],x,top);ans[x]+=ans[son[x]];for(ll i=ls[x];i;i=a[i].next)if(a[i].to!=fa&&a[i].to!=son[x])ans[x]+=calcA(a[i].to,x),calcI(a[i].to,x);ans[x]+=(val[x]&W)?T0.Ask(x):T1.Ask(x);out[x]+=ans[x]*W;if(x==top)T0.Clear(),T1.Clear();else (val[x]&W)?T1.Insert(x):T0.Insert(x);return;
}
int main()
{freopen("tree.in","r",stdin);freopen("tree.out","w",stdout);scanf("%lld",&n);for(ll i=1;i<=n;i++)scanf("%lld",&val[i]);for(ll i=1;i<=n;i++){scanf("%s",st+1);l[i]=strlen(st+1);for(ll j=1;j<=l[i];j++)s[pos[i]+j]=st[j]-'a';pos[i+1]=pos[i]+l[i];}for(ll i=1;i<n;i++){ll x,y;scanf("%lld%lld",&x,&y);addl(x,y);addl(y,x);}T1.Clear();T0.Clear(); dfs(1,0);for(W=1;W<=1e5;W<<=1)solve(1,1,1);for(ll i=1;i<=n;i++)printf("%lld\n",out[i]);
}