解析
算法一
定义 upx,kup_{x,k}upx,k 为节点 xxx 从自己的颜色所在位置在返祖链上往后跳 2k2^k2k 个颜色到达的节点。
可以像倍增一样的求解。
这样对于一次询问 (s,t)(s,t)(s,t) 我们就能求出 (s,lca)(s,lca)(s,lca) 这一段能取到哪里了。
对于向下的情况,再处理一个 upx,k′up'_{x,k}upx,k′ 表示节点 xxx 从自己的颜色所在位置在返祖链上往前跳 2k2^k2k 个颜色到达的节点。
然后二分每一个询问的答案,从答案开始往前跳,看能否与 (s,lca)(s,lca)(s,lca) 相接即可判定是否合法。
时间复杂度 O(nlogn+mlogClogn)O(n\log n+m\log C\log n)O(nlogn+mlogClogn)。
算法二
考虑优化后一段 (lca,t)(lca,t)(lca,t) 的过程。
假设询问 iii 在 (s,lca)(s,lca)(s,lca) 过程中跳到了颜色 ccc,就在 lcalcalca 处增加一个 (i,c)(i,c)(i,c) 的元素,在 ttt 处打一个 (i)(i)(i) 标记。
考虑我们 dfsdfsdfs 过程中需要维护什么:
- 插入二元组 (i,c)(i,c)(i,c)
- 如果当前节点颜色为 ccc,收集器上的下一个颜色为 sufsufsuf,就使所有 (i,c)→(i,suf)(i,c)\to(i,suf)(i,c)→(i,suf)。
- 查询当前的 iii 元素的特征值。
- 撤销当前dfs的影响。
这个东西可以用可撤销并查集维护。
总复杂度 O(nlogn+mlogn)O(n\log n+m\log n)O(nlogn+mlogn)
代码
写的是算法二。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define ok debug("OK\n")
using namespace std;const int N=2e6+100;
const int M=50050;
const int mod=1e9+7;
const double eps=1e-9;inline ll read() {ll x(0),f(1);char c=getchar();while(!isdigit(c)) {if(c=='-')f=-1;c=getchar();}while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}return x*f;
}int n,m,C,Mx;struct node{int to,nxt;
}e[N<<1];
int fi[N],cnt;
inline void addline(int x,int y){e[++cnt]=(node){y,fi[x]};fi[x]=cnt;return;
}
int p[N],col[N];int pre[N],up[N][20],pl[N][20],dep[N],suf[N];
void dfs(int x,int f){dep[x]=dep[f]+1;pl[x][0]=f;for(int k=1;pl[x][k-1];k++) pl[x][k]=pl[pl[x][k-1]][k-1];up[x][0]=pre[suf[col[x]]];for(int k=1;up[x][k-1];k++) up[x][k]=up[up[x][k-1]][k-1];int ori=pre[col[x]];pre[col[x]]=x;for(int i=fi[x];~i;i=e[i].nxt){int to=e[i].to;if(to==f) continue;dfs(to,x);}pre[col[x]]=ori;return;
}
inline int Lca(int x,int y){if(dep[x]<dep[y]) swap(x,y);for(int k=17;k>=0;k--){if(dep[pl[x][k]]<dep[y]) continue;x=pl[x][k];}if(x==y) return x;for(int k=17;k>=0;k--){if(pl[x][k]==pl[y][k]) continue;x=pl[x][k];y=pl[y][k];}return pl[x][0];
}
struct query{int s,t,lca,id;
};
vector<query>v[N];
inline int jump(int x,int top){//return color;for(int k=17;k>=0;k--){if(dep[up[x][k]]<dep[top]) continue;x=up[x][k];}return suf[col[x]];
}
struct add{int id,c;
};
vector<add>ad[N];
vector<int>q[N];
void solve1(int x,int f){int ori=pre[col[x]];pre[col[x]]=x;for(query o:v[x]){int s=o.s,t=o.t,lca=o.lca,id=o.id;s=pre[p[1]];if(dep[s]<dep[lca]) ad[lca].push_back((add){id,p[1]});else ad[lca].push_back((add){id,jump(s,lca)});q[t].push_back(id);}for(int i=fi[x];~i;i=e[i].nxt){int to=e[i].to;if(to==f) continue;solve1(to,x);}pre[col[x]]=ori;return;
}
int mx[N],fa[N],siz[N];
struct ope{int op;//1:fa 2:siz 3:mx 4:belint id,ori;
}zhan[N<<3];
int top,nam[N],bel[N],tot;
int find(int x){return fa[x]==x?x:find(fa[x]);
}
inline int New(int val){++tot;fa[tot]=tot;siz[tot]=1;mx[tot]=val;return tot;
}
void merge(int x,int y){x=find(x);y=find(y);if(siz[x]>siz[y]) swap(x,y);zhan[++top]=(ope){1,x,fa[x]};fa[x]=y;zhan[++top]=(ope){2,y,siz[y]};siz[y]+=siz[x];zhan[++top]=(ope){3,y,mx[y]};mx[y]=max(mx[y],mx[x]);return;
}
void del(int tim){while(top!=tim){if(zhan[top].op==1) fa[zhan[top].id]=zhan[top].ori;else if(zhan[top].op==2) siz[zhan[top].id]=zhan[top].ori;else if(zhan[top].op==3) mx[zhan[top].id]=zhan[top].ori;else if(zhan[top].op==4) bel[zhan[top].id]=zhan[top].ori;top--;}return;
}
int ans[N];
int rk[N];
void solve2(int x,int f){int ori=top;for(add o:ad[x]){int id=o.id,c=o.c,now=New(0);nam[id]=now;merge(now,bel[c]);//ans[id]=rk[c];}//assert(mx[find(bel[suf[col[x]]])]==rk[suf[col[x]]]);//assert(mx[find(bel[col[x]])]==rk[col[x]]);merge(bel[col[x]],bel[suf[col[x]]]);//if(mx[find(bel[col[x]])]!=rk[col[x]]+1){// debug("%d %d\n",mx[find(bel[col[x]])],rk[col[x]]);exit(0);//}//assert(mx[find(bel[col[x]])]==rk[col[x]]+1);zhan[++top]=(ope){4,col[x],bel[col[x]]};bel[col[x]]=New(rk[col[x]]);for(int id:q[x]){int o=nam[id];ans[id]=mx[find(o)];}for(int i=fi[x];~i;i=e[i].nxt){int to=e[i].to;if(to==f) continue;solve2(to,x);}del(ori);
}
int main() {
#ifndef ONLINE_JUDGEfreopen("a.in","r",stdin);freopen("a.out","w",stdout);
#endifmemset(fi,-1,sizeof(fi));cnt=-1;n=read();Mx=read();C=read();for(int i=1;i<=C;i++) p[i]=read();for(int i=1;i<=C;i++) suf[p[i]]=p[i+1],rk[p[i]]=i;rk[0]=C+1;for(int i=0;i<=Mx;i++) bel[i]=New(rk[i]);for(int i=1;i<=n;i++) col[i]=read();for(int i=1;i<n;i++){int x=read(),y=read();addline(x,y);addline(y,x);}dfs(1,0);m=read();for(int i=1;i<=m;i++){int s=read(),t=read(),lca=Lca(s,t);v[s].push_back((query){s,t,lca,i});}solve1(1,0);solve2(1,0);for(int i=1;i<=m;i++) printf("%d\n",ans[i]-1);return 0;
}
/*
1
3 3
1000000 2000000
0 0
*/