正题
题目大意
给出nnn个点的一棵树,每个点有个颜色aia_iai,你每次可以选择一个颜色全部变成另一个颜色。
求最少多少次操作可以把一种颜色变成一个完整的连通块。
1≤k≤n≤2×1051\leq k\leq n\leq 2\times 10^51≤k≤n≤2×105
解题思路
考虑如果我们要把一个颜色变成一个联通块,那么首先得把它目前包含它颜色点的最小联通子图全都同化,并且同化这些颜色之后还有可能需要同化其他更多颜色。
这是一个类似于跑图的过程,我们可以考虑建边,如果颜色AAA需要颜色BBB那么A→BA\rightarrow BA→B,最后跑出来的图我们找一个点能走到的点数最少即可。
至于怎么优化这个建图的过程,我们树链剖分+线段树上维护每个节点,然后每个节点连向对应的颜色。
然后对于一种颜色我们按照点的dfsdfsdfs序排序,然后收尾相连相邻的点之间的路径拼起来恰好是它颜色的生成子图的两倍,之间用树链剖分去连边就好了。
至于建好图之后跑个tarjan就能求出答案了
时间复杂度:O(nlog2n)O(n\log^2 n)O(nlog2n)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<stack>
using namespace std;
const int N=2e5+10,M=N*5;
struct node{int to,next;
}a[N<<1];
int n,k,cnt,tot,c[N],ls[N],fa[N],siz[N];
int seg[N],id[N],top[N],son[N],dep[N];
int num,dfc,p[M],dfn[M],low[M],col[M],f[M],in[M];
bool ins[N];stack<int> s;vector<int> v[N],G[M];
void addl(int x,int y){a[++tot].to=y;a[tot].next=ls[x];ls[x]=tot;return;
}
void dfs1(int x){dep[x]=dep[fa[x]]+1;siz[x]=1;for(int i=ls[x];i;i=a[i].next){int y=a[i].to;if(y==fa[x])continue;fa[y]=x;dfs1(y);siz[x]+=siz[y];if(siz[y]>siz[son[x]])son[x]=y;}return;
}
void dfs2(int x){seg[++cnt]=x;id[x]=cnt;if(son[x]){top[son[x]]=top[x];dfs2(son[x]);}for(int i=ls[x];i;i=a[i].next){int y=a[i].to;if(y==fa[x]||y==son[x])continue;top[y]=y;dfs2(y);}return;
}
void Build(int x,int L,int R){p[x]=++cnt;if(L==R){G[p[x]].push_back(c[seg[L]]);return;}int mid=(L+R)>>1;Build(x*2,L,mid);Build(x*2+1,mid+1,R);G[p[x]].push_back(p[x*2]);G[p[x]].push_back(p[x*2+1]);return;
}
void Change(int x,int L,int R,int l,int r,int vp){if(L==l&&R==r){G[vp].push_back(p[x]);return;}int mid=(L+R)>>1;if(r<=mid)Change(x*2,L,mid,l,r,vp);else if(l>mid)Change(x*2+1,mid+1,R,l,r,vp);else Change(x*2,L,mid,l,mid,vp),Change(x*2+1,mid+1,R,mid+1,r,vp);return;
}
void Recovery(int x,int y,int s){
// printf("%d %d %d\n",x,y,s);while(top[x]!=top[y]){if(dep[top[x]]<dep[top[y]])swap(x,y);Change(1,1,n,id[top[x]],id[x],s);x=fa[top[x]];}if(dep[x]>dep[y])swap(x,y);Change(1,1,n,id[x],id[y],s);return;
}
void tarjan(int x){dfn[x]=low[x]=++dfc;ins[x]=1;s.push(x);for(int i=0;i<G[x].size();i++){int y=G[x][i];if(!dfn[y]){tarjan(y);low[x]=min(low[x],low[y]);}else if(ins[y])low[x]=min(low[x],dfn[y]);}if(low[x]==dfn[x]){int y;++num;do{y=s.top();s.pop();ins[y]=0;f[num]+=(y<=k);col[y]=num;}while(y!=x);}return;
}
bool cmp(int x,int y)
{return id[x]<id[y];}
int main()
{freopen("color.in","r",stdin);freopen("color.out","w",stdout);scanf("%d%d",&n,&k);for(int i=1;i<n;i++){int x,y;scanf("%d%d",&x,&y);addl(x,y);addl(y,x);}for(int i=1;i<=n;i++)scanf("%d",&c[i]),v[c[i]].push_back(i);dfs1(1);top[1]=1;dfs2(1);cnt=k;Build(1,1,n);for(int i=1;i<=k;i++){if(v[i].size()<2)continue;sort(v[i].begin(),v[i].end(),cmp);for(int j=1;j<v[i].size();j++)Recovery(v[i][j-1],v[i][j],i);}for(int i=1;i<=cnt;i++)if(!dfn[i])tarjan(i);for(int x=1;x<=cnt;x++){for(int i=0;i<G[x].size();i++){int y=G[x][i];if(col[x]==col[y])continue;in[col[x]]++;}}int ans=k;for(int i=1;i<=k;i++)if(!in[col[i]])ans=min(ans,f[col[i]]);printf("%d\n",ans-1);return 0;
}