小 L 出了一道题:
给定一棵 n n n 个点的树,定义两点之间的距离为连接两点的唯一简单路径的边的条数。求树上的点两两之间的距离之和。
小 Q 觉得这题太简单了,于是给它加强了一下:
给定一棵 n n n 个点的树,求树上的点两两之间的距离的 k k k 次方之和。
这下他们都不会做了,你能帮帮他们吗?
几个记号: T ( n ) T(n) T(n) 表示子树 n n n, f a i fa_i fai 表示 i i i 的父亲, dis ( i , j ) \operatorname{dis}(i,j) dis(i,j) 表示树上两点 i , j i,j i,j 的距离, s o n i son_i soni 表示 i i i 的儿子。
将 x k x^k xk 斯特林展开得到
x k = ∑ i = 0 x { k i } ⋅ i ! ⋅ ( x i ) x^k=\sum\limits_{i=0}^x{k\brace i}\cdot i!\cdot\binom xi xk=i=0∑x{ik}⋅i!⋅(ix)
其中 { k i } {k\brace i} {ik} 是第二类斯特林数,表示把 k k k 个数划分成 i i i 个无序非空集合的方案数。
从组合意义上可以说明式子成立: x k x^k xk 表示把 k k k 个不同的球放入 x x x 个不同盒子(最后盒子可以为空)的方案数。有 ( x i ) \binom xi (ix) 种方案选择 i i i 个盒子放球,其他盒子为空。然后 { k i } k\brace i {ik} 就是把 k k k 个球分给 i i i 个相同盒子的方案数。因为盒子是不同的,所以还要乘上 i ! i! i!。式子显然成立。
在题目中,只需枚举到 k k k 即可(因为若 k < i k<i k<i 斯特林数就为 0 0 0)
而题目是要求 ∑ i = 1 n ∑ j = i + 1 n ∑ l = 0 k { k l } ⋅ l ! ⋅ ( dis ( i , j ) l ) = ∑ l = 0 k { k l } ⋅ l ! ( ∑ i = 1 n ∑ j = i + 1 n ( dis ( i , j ) l ) ) \sum\limits_{i=1}^n\sum\limits_{j=i+1}^n\sum\limits_{l=0}^k{k\brace l}\cdot l!\cdot\binom{\operatorname{dis}(i,j)}l=\sum\limits_{l=0}^k{k\brace l}\cdot l!\left(\sum\limits_{i=1}^n\sum\limits_{j=i+1}^n\binom{\operatorname{dis}(i,j)}l\right) i=1∑nj=i+1∑nl=0∑k{lk}⋅l!⋅(ldis(i,j))=l=0∑k{lk}⋅l!(i=1∑nj=i+1∑n(ldis(i,j)))
问题转换为对于每个 l ∈ [ 1 , k ] l\in[1,k] l∈[1,k]
求 ∑ i = 1 n ∑ j = i + 1 n ( dis ( i , j ) l ) \sum\limits_{i=1}^n\sum\limits_{j=i+1}^n\binom{\operatorname{dis}(i,j)}l i=1∑nj=i+1∑n(ldis(i,j))。
设 f i , j f_{i,j} fi,j 表示 ∑ x ∈ T ( i ) ( d i s ( f a i , x ) j ) \sum\limits_{x\in T(i)}\binom{\operatorname{dis(fa_i,x)}}{j} x∈T(i)∑(jdis(fai,x)),就是 ( i 子树的所有点到 f a i 距离 j ) \binom{i 子树的所有点到 fa_i 距离}{j} (ji子树的所有点到fai距离) 之和。
考虑如何转移。由于有 ( x + 1 j ) = ( x j ) + ( x j − 1 ) \binom{x+1}{j}=\binom{x}{j}+\binom{x}{j-1} (jx+1)=(jx)+(j−1x),所以 f i , j = ∑ x ∈ s o n i ( f x , j + f x , j − 1 ) f_{i,j}=\sum\limits_{x\in son_i}(f_{x,j}+f_{x,j-1}) fi,j=x∈soni∑(fx,j+fx,j−1)。
在计算答案时,要将每个 i i i 的子树进行“合并”。由于 ( x + y j ) = ∑ i = 0 j ( x j ) ( y i − j ) \binom{x+y}{j}=\sum\limits_{i=0}^j\binom{x}{j}\binom{y}{i-j} (jx+y)=i=0∑j(jx)(i−jy),所以对于每个 j j j 答案要加上 ∑ l = 0 j f x , l f y , j − l \sum\limits_{l=0}^jf_{x,l}f_{y,j-l} l=0∑jfx,lfy,j−l。这里可以用前缀和优化。时间复杂度 O ( n k 2 ) O(nk^2) O(nk2)
O ( n k 2 ) O(nk^2) O(nk2) 实际上是过不了的,所以我们要发扬人类智慧。如果是一条链,由于儿子只有一个,显然不可能有“合并”,就不用算了;如果所有的 x ∈ T ( i ) x\in{T(i)} x∈T(i),都有 dis ( f a i , x ) < j \operatorname{dis}(fa_i,x)<j dis(fai,x)<j,就说明 f i , k f_{i,k} fi,k 为 0 0 0,可以少枚举。通过一系列操作,成功水过。
代码如下
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const int N=1e6+1,INF=1e9;
int n,k,fa[N],b[N],sz[N];
int head[N],nxt[N<<1],to[N<<1],cnt,dep[N],maxdep[N];
ll ans[101],Ans,stl[101][101],f[101],inv[101],dp[N][101];
void add(int u,int v)
{to[++cnt]=v;nxt[cnt]=head[u];head[u]=cnt;
}
ll ksm(ll a,ll b)
{ll ans=1;while(b){if(b&1) ans=ans*a%mod;b>>=1;a=a*a%mod;}return ans;
}
ll C(ll x,ll y)
{ll sum=1;for(int i=1;i<=y;i++){sum=sum*(x-i+1)%mod;}return (sum*inv[y]%mod+mod)%mod;
}
void dfs(int u,int fa)
{dep[u]=dep[fa]+1;maxdep[u]=dep[u];for(int i=head[u];i;i=nxt[i]) if(to[i]!=fa) dfs(to[i],u),maxdep[u]=max(maxdep[u],maxdep[to[i]]);
}
void solve(int u,int fa)
{int sz=0;for(int i=head[u];i;i=nxt[i]){if(to[i]!=fa){++sz;solve(to[i],u);if(sz>1){for(int l=0;l<=min(k,maxdep[u]*2-2*dep[u]);l++){for(int p=0;p<=l;p++)ans[l]=(ans[l]+1ll*dp[u][p]*dp[to[i]][l-p])%mod;}}for(int l=0;l<=k;l++) ans[l]=(ans[l]+dp[to[i]][l])%mod;for(int j=0;j<=k;j++) (dp[u][j]+=dp[to[i]][j])%=mod;}}dp[u][0]++;for(int j=1;j<=k;j++){for(int i=head[u];i;i=nxt[i]){if(to[i]==fa) continue;(dp[u][j]+=dp[to[i]][j-1])%=mod;}if(j<=1) dp[u][j]++;dp[u][j]%=mod;}
}
void init()
{stl[0][0]=1;for(int i=1;i<=k;i++){for(int j=1;j<=i;j++){stl[i][j]=(stl[i-1][j-1]+stl[i-1][j]*j)%mod;}}f[0]=1;for(int i=1;i<=k;i++) f[i]=f[i-1]*i%mod;inv[k]=ksm(f[k],mod-2);for(int i=k-1;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod;
}
int read()
{int sum=0,c=getchar();while(c<48||c>57) c=getchar();while(c>=48&&c<=57) sum=sum*10+c-48,c=getchar();return sum;
}
int main()
{n=read(),k=read();init();for(int i=1,x,y;i<n;i++){x=read(),y=read();add(x,y),add(y,x);}dfs(1,0);solve(1,0);for(int i=0;i<=k;i++) Ans=(Ans+stl[k][i]*f[i]%mod*ans[i])%mod;printf("%lld",Ans);
}