正题
题目大意
nnn个点的一颗树,合法路径定义为一条路径上每个点的编号相差大于KKK。求合法路径数
解题思路
首先我们可以求不合法的路径数,这样我们就有了K∗nK*nK∗n个不合法(即不能在同一个路径上)的点对。
然后这题就和之前一题jzoj6276一样了
大概就是用矩形表示不合法的路径,之后用扫面线求矩形的面积并即可。
codecodecode
#pragma GCC optimize(2)
%:pragma GCC optimize(3)
%:pragma GCC optimize("Ofast")
%:pragma GCC optimize("inline")
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cctype>
using namespace std;
const int N=3e5+10;
struct node{int to,next;
}a[N*2];
struct line{int x,l,r,w;
}l[N*40];
bool operator<(line x,line y)
{return x.x<y.x;}
int n,K,tot,cnt,num;
int rfn[N],ed[N],f[N][21],dep[N];
int w[N*4],mark[N*4],ls[N];
long long ans;
__attribute__((optimize("O3"))) inline int read() {int x=0,f=1; char c=getchar();while(!isdigit(c)) {if(c=='-')f=-f;c=getchar();}while(isdigit(c)) x=(x<<1)+(x<<3)+c-48,c=getchar();return x*f;
}
void addl(int x,int y){a[++tot].to=y;a[tot].next=ls[x];ls[x]=tot;return;
}
void dfs(int x,int fa){rfn[x]=++cnt;for(int i=ls[x];i;i=a[i].next){int y=a[i].to;if(y==fa)continue;dep[y]=dep[x]+1;f[y][0]=x;dfs(y,x);}ed[x]=cnt;return;
}
int LCA(int x,int y){for(int i=20;i>=0;i--)if(dep[f[y][i]]>dep[x])y=f[y][i];return y;
}
void addc(int x1,int x2,int y1,int y2){if(x1>x2)swap(x1,x2);if(y1>y1)swap(y1,y2);l[++num]=(line){x1,y1,y2,1};l[++num]=(line){x2+1,y1,y2,-1};
}
void Ban(int x,int y){if(rfn[x]>rfn[y])swap(x,y);if(rfn[x]<=rfn[y]&&rfn[y]<=ed[x]){int top=LCA(x,y);if(rfn[top]!=1)addc(1,rfn[top]-1,rfn[y],ed[y]);if(ed[top]!=n)addc(rfn[y],ed[y],ed[top]+1,n);}else addc(rfn[x],ed[x],rfn[y],ed[y]);return;
}
void Change(int x,int L,int R,int l,int r,int val){if(L==l&&R==r){mark[x]+=val;if(mark[x])w[x]=r-l+1;else if(l==r)w[x]=0;else w[x]=w[x*2]+w[x*2+1];return;}int mid=(L+R)>>1;if(r<=mid)Change(x*2,L,mid,l,r,val);else if(l>mid)Change(x*2+1,mid+1,R,l,r,val);else Change(x*2,L,mid,l,mid,val),Change(x*2+1,mid+1,R,mid+1,r,val);if(mark[x])w[x]=R-L+1;else w[x]=w[x*2]+w[x*2+1];return;
}
int main()
{freopen("data.in","r",stdin);int size = 256 << 20; //250Mchar*p=(char*)malloc(size) + size;__asm__("movl %0, %%esp\n" :: "r"(p) );n=read();K=read(); for(int i=1;i<n;i++){int x=read(),y=read();addl(x,y);addl(y,x);}dfs(1,1);for(int i=1;i<=20;i++)for(int j=1;j<=n;j++)f[j][i]=f[f[j][i-1]][i-1];for(int i=1;i<=n;i++)for(int j=i+1;j<=min(i+K,n);j++)Ban(i,j);sort(l+1,l+1+num);int z=1;for(int i=1;i<=n;i++){while(z<=num&&l[z].x<=i)Change(1,1,n,l[z].l,l[z].r,l[z].w),z++;ans+=w[1];}printf("%lld",1ll*n*(n-1)/2-ans+n);
}