LG P3233 [HNOI2014]世界树
Solution
看完题意,显然是虚树。
建出虚树后,可以容易地求出虚树上的点会被哪一个点管辖,关键在于不在虚树上的点归属于哪个点,我们分类讨论不在虚树上的点的贡献:
我们先假设虚树上的点全是关键点,注意后文的子树都是原树的子树。
- 在虚树上点x,yx,yx,y路径上(不包含x,yx,yx,y)的点(设依次为v1,v2...vkv_1,v_2...v_kv1,v2...vk,它们不在虚树上)及其子树中的点:它们要么属于xxx,要么属于yyy,且必然存在一个midmidmid,使得v1,v2...vmid−1v_1,v_2...v_{mid-1}v1,v2...vmid−1属于xxx,vmid...vkv_{mid}...v_kvmid...vk属于yyy,而求解这个midmidmid位置的判定条件是distx,middist_{x,mid}distx,mid和distmid,ydist_{mid,y}distmid,y的大小(大小相同看编号大小),这个可以通过二分简单地得到。而对于那些viv_ivi子树中的点,一定和viv_ivi的归属相同。
- 在虚树上的点xxx的子树中不在虚树上的儿子vvv以及它的子树中的点:也就是vvv子树中没有关键点,那么一定整个子树归属于xxx,直接统计即可。
- 完全不在虚树上的点:它们一定不在虚树的根的子树内(可以理解为在虚树的上面),它们一定归属于虚树的根。
实现时,我们通过一个向上和一个向下的dpdpdp求出虚树上点的归属。
然后再对于每个点xxx,枚举其出边vvv,求出midmidmid,计算x,vx,vx,v的新增贡献。
并且记录一个gxg_xgx表示2,32,32,3类的答案,初始为子树大小,枚举出边vvv时,把xxx包含vvv的儿子的子树结点个数去掉,最后让xxx的贡献加上gxg_xgx即可。
时间复杂度O(nlgn)O(nlgn)O(nlgn)。
有一个实现过程中的小tricktricktrick是建虚树时直接把111结点放入虚树,会大大减少一些不必要的分类讨论。
Code
#include <vector>
#include <list>
#include <map>
#include <set>
#include <deque>
#include <queue>
#include <stack>
#include <bitset>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <string>
#include <cstring>
#include <ctime>
#include <cassert>
#include <string.h>
//#include <unordered_set>
//#include <unordered_map>
//#include <bits/stdc++.h>#define MP(A,B) make_pair(A,B)
#define PB(A) push_back(A)
#define SIZE(A) ((int)A.size())
#define LEN(A) ((int)A.length())
#define FOR(i,a,b) for(int i=(a);i<(b);++i)
#define fi first
#define se secondusing namespace std;template<typename T>inline bool upmin(T &x,T y) { return y<x?x=y,1:0; }
template<typename T>inline bool upmax(T &x,T y) { return x<y?x=y,1:0; }typedef long long ll;
typedef unsigned long long ull;
typedef long double lod;
typedef pair<int,int> PR;
typedef vector<int> VI;const lod eps=1e-11;
const lod pi=acos(-1);
const int oo=1<<30;
const ll loo=1ll<<62;
const int mods=1e9+7;
const int MAXN=600005;
const int INF=0x3f3f3f3f;//1061109567
/*--------------------------------------------------------------------*/
inline int read()
{int f=1,x=0; char c=getchar();while (c<'0'||c>'9') { if (c=='-') f=-1; c=getchar(); }while (c>='0'&&c<='9') { x=(x<<3)+(x<<1)+(c^48); c=getchar(); }return x*f;
}
PR mn[MAXN];
vector<int> e[MAXN],E[MAXN];
int a[MAXN],b[MAXN],f[MAXN],g[MAXN],stk[MAXN],top=0,n,m;
int dep[MAXN],sz[MAXN],Log[MAXN],dfn[MAXN],fa[MAXN][20],head[MAXN],flag[MAXN],DFN=0,edgenum;
int getlca(int x,int y)
{if (dep[x]<dep[y]) swap(x,y);for (int i=Log[dep[x]];i>=0;i--)if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];if (x==y) return x;for (int i=Log[dep[x]];i>=0;i--)if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];return fa[x][0];
}
int jump(int x,int d)
{for (int i=Log[dep[x]];i>=0;i--)if (dep[fa[x][i]]>=d) x=fa[x][i];return x;
}
void dfs(int x,int father)
{fa[x][0]=father,sz[x]=1,dep[x]=dep[father]+1,dfn[x]=++DFN;for (int i=1;i<=Log[dep[x]];i++) fa[x][i]=fa[fa[x][i-1]][i-1];for (auto v:e[x]) if (v!=father) dfs(v,x),sz[x]+=sz[v];
}
void Init()
{dep[0]=-1,Log[1]=0;for (int i=1;i<=n;i++) Log[i]=Log[i>>1]+1;dfs(1,0);
}void add(int u,int v) { E[u].PB(v); }
void build()
{sort(a+1,a+m+1,[&](int x,int y){ return dfn[x]<dfn[y]; });stk[top=1]=1;for (int i=1+(a[1]==1);i<=m;i++){int lca=getlca(stk[top],a[i]);while (top>1&&dep[stk[top-1]]>dep[lca]) add(stk[top-1],stk[top]),top--;if (dep[stk[top]]>dep[lca]) add(lca,stk[top--]);if (!top||stk[top]!=lca) stk[++top]=lca;stk[++top]=a[i];}while (top>1) add(stk[top-1],stk[top]),top--;
}void up(int x,int father)
{mn[x]=(flag[x]?MP(0,x):MP(INF,x));for (auto v:E[x]){if (v==father) continue;up(v,x),upmin(mn[x],MP(mn[v].fi+dep[v]-dep[x],mn[v].se));}
}
void down(int x,int father)
{for (auto v:E[x])if (v!=father) upmin(mn[v],MP(mn[x].fi+dep[v]-dep[x],mn[x].se)),down(v,x);
}void tree_dp(int x,int father)
{for (auto v:E[x])if (v!=father) tree_dp(v,x);g[x]=sz[x];for (auto v:E[x]){int t=jump(v,dep[x]+1); g[x]-=sz[t];if (mn[x].se==mn[v].se) { f[mn[x].se]+=sz[t]-sz[v]; continue; }int mid=v;for (int i=Log[dep[v]];i>=0;i--){int p=fa[mid][i];if (dep[p]<=dep[x]) continue;if (MP(dep[p]-dep[x]+mn[x].fi,mn[x].se)>MP(dep[v]-dep[p]+mn[v].fi,mn[v].se)) mid=p;}f[mn[x].se]+=sz[t]-sz[mid];f[mn[v].se]+=sz[mid]-sz[v];}f[mn[x].se]+=g[x];
}void clean(int x,int father)
{for (auto v:E[x]) if (v!=father) clean(v,x);f[x]=g[x]=0,E[x].clear();
}
void clear()
{for (int i=1;i<=m;i++) flag[a[i]]=0;clean(1,0),top=0;
}signed main()
{n=read();for (int i=1,u,v;i<n;i++) u=read(),v=read(),e[u].PB(v),e[v].PB(u);Init();int Case=read();while (Case--){m=read();for (int i=1;i<=m;i++) a[i]=b[i]=read(),flag[a[i]]=1;build(),up(1,0),down(1,0),tree_dp(1,0);for (int i=1;i<=m;i++) printf("%d ",f[b[i]]); puts("");clear();}return 0;
}