正题
题目链接:http://noi.ac/problem/2266
题目大意
给出nnn个点的一棵树,有一些边上有中转站(边长度为222,中间有一个中转站),否则就是边长为111。
mmm次询问一个东西从xxx出发走到yyy,每隔kkk步中转站会关闭一次(kkk的倍数步走完后不能在中转站上)。求在关闭多少次以内可以到达
1≤n,m≤1051\leq n,m\leq 10^51≤n,m≤105
解题思路
发现最多只需要走2n2n2n步,然后每隔kkk步关闭一次,所以可以考虑根号分治。
先处理好总的倍增数组,后面求LCALCALCA和跳链要用。
对于k=1k=1k=1的询问,就看一下中间有没有中转站,如果有就是−1-1−1否则就是距离
对于k≤nk\leq \sqrt nk≤n的询问,我们对于每个kkk都进行一次预处理,处理每个周期每个点往上走能走到哪里。然后再处理一个倍增数组,然后询问的时候就在上面跳就好了
对于k>nk>\sqrt nk>n的询问直接每次暴力跳kkk步如果是中转站就跳k−1k-1k−1步,然后一直跳到LCALCALCA处
时间复杂度O(nnlogn)O(n\sqrt n\log n)O(nnlogn),调一下块的大小就能过了
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
const int N=2e5+10,T=17;
struct edge{int to,next;
}a[N<<1];
struct node{int x,y,k,id;
}q[N];
int n,m,Q,tot,num,ans[N],ls[N],dep[N],sd[N];
int g[N][100],f[N][T+1],h[N][T+1];
void addl(int x,int y){a[++tot].to=y;a[tot].next=ls[x];ls[x]=tot;return;
}
bool cmp(node x,node y)
{return x.k<y.k;}
void dfs(int x,int fa){g[x][0]=x;sd[x]=sd[fa]+(x>n);f[x][0]=fa;dep[x]=dep[fa]+1;for(int i=1;i<=Q;i++)g[x][i]=g[fa][i-1];for(int i=ls[x];i;i=a[i].next){int y=a[i].to;if(y==fa)continue;dfs(y,x);}return;
}
int LCA(int x,int y){if(dep[x]>dep[y])swap(x,y);for(int i=T;i>=0;i--)if(dep[f[y][i]]>=dep[x])y=f[y][i];if(x==y)return x;for(int i=T;i>=0;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];return f[x][0];
}
void calc(int x,int fa,int k){if(g[x][k]>n)h[x][0]=g[x][k-1];else h[x][0]=g[x][k];for(int i=ls[x];i;i=a[i].next){int y=a[i].to;if(y==fa)continue;calc(y,x,k);}return;
}
int query(int x,int y,int k){int p=LCA(x,y),ans=0;for(int i=T;i>=0;i--){if(dep[h[x][i]]>dep[p])x=h[x][i],ans+=(1<<i);if(dep[h[y][i]]>dep[p])y=h[y][i],ans+=(1<<i);}if(x!=y){int dis=dep[x]+dep[y]-2*dep[p];if(dis>=0&&dis<=k)ans++;else if(dis>k) ans+=2;}return ans;
}
int getf(int x,int k){for(int i=0;i<=T;i++)if((k>>i)&1)x=f[x][i];return x;
}
int solve(int x,int y,int k){int p=LCA(x,y),ans=0;while(dep[x]>dep[p]){int z=getf(x,k-1),t;if(f[z][0]>n)t=z;else t=f[z][0];if(dep[t]>dep[p])x=t,ans++;else break;}while(dep[y]>dep[p]){int z=getf(y,k-1),t;if(f[z][0]>n)t=z;else t=f[z][0];if(dep[t]>dep[p])y=t,ans++;else break;}if(x!=y){int dis=dep[x]+dep[y]-2*dep[p];if(dis>=0&&dis<=k)ans++;else if(dis>k) ans+=2;}return ans;
}
int main()
{scanf("%d",&n);num=n;for(int i=1;i<n;i++){int x,y,w;scanf("%d%d%d",&x,&y,&w);if(w==1)addl(x,y),addl(y,x);else{++num;addl(x,num);addl(num,y);addl(y,num);addl(num,x);}}Q=sqrt(n);if(Q>=70)Q=70;scanf("%d",&m);for(int i=1;i<=m;i++){scanf("%d%d%d",&q[i].x,&q[i].y,&q[i].k);q[i].id=i;}sort(q+1,q+1+m,cmp);dfs(1,0);for(int j=1;j<=T;j++)for(int i=1;i<=num;i++)f[i][j]=f[f[i][j-1]][j-1];int l=1,r=1;for(;r<=m&&q[r].k<=Q;r++,l=r){while(r<m&&q[r].k==q[r+1].k)r++;if(q[r].k==1){for(int i=l;i<=r;i++){int x=q[i].x,y=q[i].y,lca=LCA(x,y);if(sd[x]+sd[y]-2*sd[lca])ans[q[i].id]=-1;else ans[q[i].id]=dep[x]+dep[y]-2*dep[lca];}continue;}calc(1,1,q[r].k);for(int j=1;j<=T;j++)for(int i=1;i<=num;i++)h[i][j]=h[h[i][j-1]][j-1];for(int i=l;i<=r;i++)ans[q[i].id]=query(q[i].x,q[i].y,q[i].k);}for(int i=r;i<=m;i++)ans[q[i].id]=solve(q[i].x,q[i].y,q[i].k);for(int i=1;i<=m;i++)printf("%d\n",ans[i]);return 0;
}