正题
题目大意
一棵树上每个节点有不同的颜色,然后每次询问(x,y,a,b)(x,y,a,b)(x,y,a,b)表示将颜色aaa看为颜色bbb的情况下询问xxx到yyy有多少不种的颜色。
解题思路
数颜色,显然树上莫队。我们维护一个欧拉序dfndfndfn(进去时记录一次出来时记录一次),然后rfnirfn_irfni表示点iii进入时在dfndfndfn中的位置,rgnirgn_irgni表示点iii在出去时在iii的位置。
然后每次移动端点时,只有在[l,r][l,r][l,r]这个区间中出现一次的点的颜色才会被记录入cntcntcnt(让出现两次的相互抵消)。那么我们就可以让x,yx,yx,y所在它的LCALCALCA的子树中这个序列,让他们中间的相互抵消。
不过要注意,若x,yx,yx,y不是他们的LCALCALCA,那么他们的LCALCALCA就被计算了两次,就重复了,统计答案时要减去重复的。
codecodecode
#include<cstdio>
#include<algorithm>
#include<queue>
#include<cmath>
using namespace std;
const int N=51000,M=110000;
struct line{int to,next,w;
}a[M];
struct Que_node{int l,r,a,b,id,pos;
}que[M];
bool operator<(Que_node x,Que_node y)
{return x.pos==y.pos?x.r<y.r:x.pos<y.pos;}
int n,m,c[N],dfn[M],rfn[N],rgn[N],v[N],cnt,num,ans[N];
int tot,x,y,ls[N],dep[N],f[N][30],t,T;
bool b[N];
queue<int> q;
inline int read()
{int X=0,w=0; char c=0;while(c<'0'||c>'9') {w|=c=='-';c=getchar();}while(c>='0'&&c<='9') X=(X<<3)+(X<<1)+(c^48),c=getchar();return w?-X:X;
}
inline void addl(int x,int y,int w)
{a[++tot].to=y;a[tot].next=ls[x];a[tot].w=w;ls[x]=tot;
}
void dfs(int x,int fa)
{dfn[++cnt]=x;rfn[x]=cnt;for(int i=ls[x];i;i=a[i].next){int y=a[i].to;if(y==fa) continue;dfs(y,x);}dfn[++cnt]=x;rgn[x]=cnt;
}
inline void bfs(int s)
{q.push(s);dep[s]=1;while(!q.empty()){int x=q.front();q.pop();for (int i=ls[x];i;i=a[i].next){int y=a[i].to;if (dep[y]) continue;q.push(y);f[y][0]=x;dep[y]=dep[x]+1;}}T=(int)(log(n)/log(2))+1;for (int j=1;j<=T;j++)for (int i=1;i<=n;i++)f[i][j]=f[f[i][j-1]][j-1];
}
inline int LCA(int x,int y)
{if (dep[x]>dep[y]) swap(x,y);for (int i=T;i>=0;i--)if (dep[f[y][i]]>=dep[x]) y=f[y][i];if (x==y) return x;for (int i=T;i>=0;i--)if (f[y][i]!=f[x][i]) {x=f[x][i];y=f[y][i];}return f[x][0];
}
void rev(int x)
{if(b[x]) v[c[x]]--,num-=(v[c[x]]==0);else num+=(v[c[x]]==0),v[c[x]]++;b[x]^=1;
}
void Keep_zzy(int &l,int &r,int L,int R)
{while(l<L) rev(dfn[l]),l++;while(l>L) l--,rev(dfn[l]);while(r<R) r++,rev(dfn[r]);while(r>R) rev(dfn[r]),r--;
}
int main()
{freopen("apple.in","r",stdin);freopen("apple.out","w",stdout);n=read();m=read();for(int i=1;i<=n;i++)c[i]=read();for(int i=1;i<=n;i++){addl(x=read(),y=read(),1);addl(y,x,1);}dfs(0,0);bfs(0);t=(int)sqrt((double)cnt);for(int i=1;i<=m;i++){scanf("%d%d%d%d",&que[i].l,&que[i].r,&que[i].a,&que[i].b);if(rfn[que[i].l]>rfn[que[i].r])swap(que[i].l,que[i].r);que[i].l=rgn[que[i].l];que[i].r=rfn[que[i].r];if(que[i].l>=que[i].r) que[i].l=rfn[dfn[que[i].l]];que[i].id=i;que[i].pos=(que[i].l-1)/t+1;}sort(que+1,que+1+m);int l=1,r=0;v[0]=1; for(int i=1;i<=m;i++){int x=que[i].l,y=que[i].r;Keep_zzy(l,r,x,y);int lca=LCA(dfn[x],dfn[y]);bool flag=0;if(dfn[x]!=lca&&dfn[y]!=lca)rev(lca),flag=1;ans[que[i].id]=num;if(v[que[i].a]&&v[que[i].b]&&que[i].a!=que[i].b)ans[que[i].id]--;if(flag) rev(lca);}for(int i=1;i<=m;i++)printf("%d\n",ans[i]);
}