正题
题目链接:https://ac.nowcoder.com/acm/contest/1100/B
题目大意
nnn个点的一棵树,对于每个点求
- 距离该点不超过kkk的点数
- 每个点的权值是以该点为起点长度所有不超过kkk的路径覆盖该点的次数,求所有点的乘积
解题思路
对于第一问我们先考虑在子树中的,sizx,zsiz_{x,z}sizx,z表示距离该点不超过zzz的点数,有转移sizx,z=1+∑x−>ysizy,z−1siz_{x,z}=1+\sum_{x->y}siz_{y,z-1}sizx,z=1+x−>y∑sizy,z−1
然后那么我们每个点的答案我们就只需要计算xxx的祖宗节点即可,
定义xxx的第zzz代祖宗表示为fazfa_zfaz,有贡献sizfaz,k−z−sizfaz−1,k−z−1siz_{fa_z,k-z}-siz_{fa_{z-1},k-z-1}sizfaz,k−z−sizfaz−1,k−z−1
这样我们就解决了第111问,考虑第222问
依旧先考虑子树定义mulx,zmul_{x,z}mulx,z表示距离为zzz时xxx点子树中的值(注意不能计算xxx点,因为xxx的价值还不知道)
有mulx,z=∏x−>y(muly,z−1∗sizy,z−1)mul_{x,z}=\prod_{x->y} (mul_{y,z-1}*siz_{y,z-1})mulx,z=x−>y∏(muly,z−1∗sizy,z−1)
那依旧对于每一个的子树祖宗,有贡献mulfaz,k−zmulfaz−1,k−z−1\frac{mul_{fa_z,k-z}}{mul_{fa_{z-1,k-z-1}}}mulfaz−1,k−z−1mulfaz,k−z
但是该点的值还没有计算,我们发现对于该点的值就是它和所有往上的祖宗的sizfaz,k−z+sizfaz−1,k−z−1siz_{fa_z,k-z}+siz_{fa_{z-1},k-z-1}sizfaz,k−z+sizfaz−1,k−z−1之和,在上一问计算时计入即可。
时间复杂度O(nk)O(nk)O(nk)
codecodecode
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=1e5+100,XJQ=1e9+7;
struct node{ll to,next;
}a[N*2];
ll n,k,tot,ls[N],f[N],g[N],siz[N][15],mul[N][15],fa[N],num[15];
void addl(ll x,ll y)
{a[++tot].to=y;a[tot].next=ls[x];ls[x]=tot;
}
void dfs(ll x)
{for(ll j=0;j<=10;j++)siz[x][j]=mul[x][j]=1;for(ll i=ls[x];i;i=a[i].next){ll y=a[i].to;if(y==fa[x]) continue;fa[y]=x;dfs(y);for(ll j=1;j<=10;j++){siz[x][j]+=siz[y][j-1];(mul[x][j]*=mul[y][j-1]*siz[y][j-1]%XJQ)%=XJQ;}}
}
ll power(ll x,ll b)
{ll ans=1;while(b){if(b&1) ans=ans*x%XJQ;x=x*x%XJQ;b>>=1;}return ans;
}
void dfs2(ll x)
{ll now=x,z=k;f[x]=(num[k]=siz[x][k]);while(--z&&fa[now]){f[x]+=siz[fa[now]][z]-siz[now][z-1];num[z]=f[x];now=fa[now];}if(fa[now])f[x]++;now=x;z=k;g[x]=mul[x][k]*f[x]%XJQ;while(--z&&fa[now]){(g[x]*=mul[fa[now]][z]*power(mul[now][z-1]*siz[now][z-1]%XJQ,XJQ-2)%XJQ)%=XJQ;(g[x]*=f[x]-num[z+1])%=XJQ;now=fa[now];}for(ll i=ls[x];i;i=a[i].next){ll y=a[i].to;if(y==fa[x]) continue;dfs2(y);}
}
int main()
{scanf("%lld%lld",&n,&k);for(ll i=1;i<n;i++){ll x,y;scanf("%lld%lld",&x,&y);addl(x,y);addl(y,x);}dfs(1);dfs2(1);for(ll i=1;i<=n;i++)printf("%lld ",f[i]);putchar('\n');for(ll i=1;i<=n;i++)printf("%lld ",g[i]);
}