正题
题目链接:https://uoj.ac/problem/33
题目大意
给出nnn个点的一棵树
定义f(x,y)=gcd(dis(x,lca),dis(y,lca))f(x,y)=gcd(\ dis(x,lca),dis(y,lca)\ )f(x,y)=gcd( dis(x,lca),dis(y,lca) )。
对于每个iii求有多少对f(x,y)=i(x<y)f(x,y)=i(x<y)f(x,y)=i(x<y)
1≤n≤1051\leq n\leq 10^51≤n≤105
解题思路
首先肯定是枚举lcalcalca节点,然后看他子树里的情况,比较麻烦的是gcdgcdgcd刚刚好是ddd,但是其实我们可以是ddd的倍数的情况,然后后面再容斥出答案。
如果,然后暴力算的话首先需要一个长链剖分,然后每次是lenloglenlen\ log\ lenlen log len的。
但是仔细想一想就会发现这个复杂度其实是假的,因为每次暴力算的话的lenlenlen是这条链上面那条链的lenlenlen。
考虑点其他做法,因为是枚举倍数,我们可以上我们的根号分治
对于dis>ndis>\sqrt ndis>n的情况,我们之间暴力枚举倍数,因为这样不会超过n\sqrt nn次
对于dis≤ndis\leq \sqrt ndis≤n的情况,我们考虑储存一些东西,设gi,jg_{i,j}gi,j表示当前的链中depdepdep模iii为jjj的点的个数,然后处理的时候我们就可以直接用这个来计算了。
这样平衡下来时间复杂度就是O(nn)O(n\sqrt n)O(nn)了
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cctype>
using namespace std;
const int N=2e5+10;
struct node{int to,next;
}a[N];
int n,T,len[N],dep[N],h[N],g[400][400];
int tot,t[N],ls[N],son[N],*f[N],*now;
long long ans[N],pre[N];
int read(){int x=0,f=1;char c=getchar();while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}while(isdigit(c)){x=(x<<1)+(x<<3)+c-'0';c=getchar();}return x*f;
}
void addl(int x,int y){a[++tot].to=y;a[tot].next=ls[x];ls[x]=tot;
}
void dfs(int x){for(int i=ls[x];i;i=a[i].next){int y=a[i].to;dep[y]=dep[x]+1;dfs(y);if(len[y]>len[son[x]])son[x]=y;}len[x]=len[son[x]]+1;return;
}
void calc(int x,int top){f[x][0]=1;for(int i=ls[x];i;i=a[i].next){int y=a[i].to;if(y==son[x])continue;f[y]=now;now+=len[y];calc(y,y);}if(son[x]){f[son[x]]=f[x]+1;calc(son[x],top);}for(int i=ls[x];i;i=a[i].next){int y=a[i].to;if(y==son[x])continue;for(int j=1;j<=len[y];j++)t[j]=f[y][j-1];for(int j=1;j<=len[y];j++)for(int k=2*j;k<=len[y];k+=j)t[j]+=t[k];for(int j=1;j<=len[y];j++)if(j>T){for(int k=j;k<len[x];k+=j)ans[j]+=1ll*f[x][k]*t[j];}else ans[j]+=1ll*g[j][dep[x]%j]*t[j];for(int j=1;j<=len[y];j++)f[x][j]+=f[y][j-1];for(int j=0;j<len[y];j++)for(int k=1;k<=T;k++)g[k][(j+dep[y])%k]+=f[y][j];}for(int i=1;i<=T;i++)g[i][dep[x]%i]++;if(x==top){for(int i=1;i<=T;i++)for(int j=0;j<len[x];j++)g[i][(j+dep[x])%i]=0;}return;
}
signed main()
{n=read();for(int i=2;i<=n;i++)addl(read(),i);dfs(1);T=sqrt(n);if(T>350)T=350;f[1]=now=h;now+=len[1];calc(1,1);for(int i=n;i>=1;i--)for(int j=2*i;j<=n;j+=i)ans[i]-=ans[j];for(int i=2;i<=n;i++)pre[dep[i]]++;for(int i=n;i>=1;i--)pre[i]+=pre[i+1];for(int i=1;i<n;i++)printf("%lld\n",ans[i]+pre[i]);return 0;
}