题解:
写起来还稍微有点麻烦。
dfs序+线段树可以维护子树的整体修改和查询。 因此,这道题我们要往子树上靠。
我们首先从1号点进行dfs遍历,顺便求出点的dfs序和深度,然后我们采用倍增的思想,可以预处理出每个点的祖先是谁。然后可以在O(log(n))O(log(n))的时间复杂度内求出任意两点的lca(u,v)lca(u,v)。
而现在整个树的根是可以改变的,因此,我们需要一个结论,也就是说当树的根节点被改变为root时候,u和v的新的lca,也就是newlca(u,v)=lca(u,v)xorlca(u,root)xorlca(v,root)newlca(u,v)=lca(u,v)xorlca(u,root)xorlca(v,root) (这个可以自己画画图看一下)。
找到newlca以后还不行,根据newlca与root的关系不一样,还需要进一步讨论。
1. 当newlca=rootnewlca=root的时候,要操作的子树就是整颗树。
2. 当lca(newlca,root)!=newlcalca(newlca,root)!=newlca 那么要操作的子树就是以1为根节点时候的newlca的子树。
3. 当lca(newlca,root)==newlcalca(newlca,root)==newlca的时候,那么要操作的就是整颗树减去以(root到newlca链上深度为dep[newlca]-1的)点的子树。
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define int long long
#define pr(x) cout<<#x<<":"<<x<<endl
const int maxn = 1e5+7;
int n,q;
struct edge{int u,v,nxt;
}es[maxn<<1];
int head[maxn];
int tot = 0;
void addedge(int u,int v){es[tot].u = u,es[tot].v = v,es[tot].nxt = head[u];head[u] = tot++;
}
int fa[maxn][22],dep[maxn];
int idx = 0,IN[maxn],OUT[maxn];
int segtree[maxn<<3],addmark[maxn<<3];
void init(){memset(fa,0,sizeof(fa));memset(head,-1,sizeof(head));memset(IN,0,sizeof(IN));memset(OUT,0,sizeof(OUT));memset(dep,0,sizeof(dep));tot = idx = 0;
}
void dfs(int u,int myfa,int dp){dep[u] = dp;IN[u] = ++idx;for(int e = head[u];e != -1;e = es[e].nxt){int v = es[e].v;if(v == myfa) continue;fa[v][0] = u;dfs(v,u,dp+1);}OUT[u] = ++idx;
}
void pushdown(int rt,int lft,int rgt){if(addmark[rt]){int mid = (lft + rgt)/2;addmark[2*rt] += addmark[rt];addmark[2*rt+1] += addmark[rt];segtree[2*rt] += addmark[rt]*(mid-lft+1);segtree[2*rt+1] += addmark[rt]*(rgt-mid);addmark[rt] = 0;}
}
void pushup(int rt){segtree[rt] = segtree[rt*2] + segtree[rt*2+1];
}
int val[maxn];
/*
void build(int rt,int L,int R){if(R == L) {segtree[rt] = val[L];}else{build(L,mid);build(mid+1,R);}
}*/
void ins(int rt,int lft,int rgt,int L,int R,int adv){if(rgt < L || lft > R) return ;if(L <= lft && R >= rgt) {segtree[rt] += (rgt-lft+1)*adv;addmark[rt] += adv;return ;}int mid = (lft + rgt) / 2;pushdown(rt,lft,rgt);ins(rt*2,lft,mid,L,R,adv);ins(rt*2+1,mid+1,rgt,L,R,adv);pushup(rt);
}
int ask(int rt,int lft,int rgt,int L,int R){if(rgt < L || lft > R) return 0;if(L <= lft && R >= rgt) return segtree[rt];pushdown(rt,lft,rgt);int mid = (lft + rgt)/2;return ask(rt*2,lft,mid,L,R) + ask(rt*2+1,mid+1,rgt,L,R);
}
void makelca(){for(int i = 1;i < 20;++i){for(int u = 1;u <= n;++u){fa[u][i] = fa[fa[u][i-1]][i-1];}}
}
int lca(int u,int v){if(dep[u] < dep[v]) swap(u,v);int dpc = dep[u] - dep[v];if(dpc){int t = 0;while(dpc){if(dpc & 1)u = fa[u][t];t++;dpc >>= 1;}}if(u == v) return u;for(int i = 19;u != v && i >= 0;--i ){if(fa[u][i] != fa[v][i]) {u = fa[u][i];v = fa[v][i];}}return fa[u][0];
}int root = 1;
main(){init();scanf("%lld%lld",&n,&q);for(int i = 1;i <= n;++i){scanf("%lld",&val[i]);}for(int i = 0;i < n-1;++i){int u,v;scanf("%lld%lld",&u,&v);addedge(u,v);addedge(v,u);}dfs(1,-1,0);makelca();for(int i = 1;i <= n;++i){ins(1,1,2*n,IN[i],IN[i],val[i]);ins(1,1,2*n,OUT[i],OUT[i],val[i]);}//计算lcafor(int i = 0;i < q;++i){int op ;scanf("%lld",&op);if(op == 1){scanf("%lld",&root);}else if(op == 2){int u,v,x;scanf("%lld%lld%lld",&u,&v,&x);int rt = lca(u,v)^lca(u,root)^lca(root,v);if(rt == root) ins(1,1,2*n,1,2*n,x);else if(lca(rt,root) != rt) ins(1,1,2*n,IN[rt],OUT[rt],x);else{int dpc = dep[root]-dep[rt]-1;int t = 0;int tmp = root;while(dpc){if(dpc&1)tmp = fa[tmp][t];t++;dpc >>= 1;}//cout<<tmp<<' '<<IN[tmp]<<' '<<OUT[tmp]<<endl;ins(1,1,2*n,1,2*n,x);ins(1,1,2*n,IN[tmp],OUT[tmp],-x);}}else if(op == 3){int v;scanf("%lld",&v);if(v == root){printf("%lld\n",ask(1,1,2*n,1,2*n)/2);}else if(lca(v,root) != v){printf("%lld\n",ask(1,1,2*n,IN[v],OUT[v])/2);}else{int tmp = root;int dpc = dep[root]-1-dep[v];int t = 0;while(dpc){if(dpc&1)tmp = fa[tmp][t];++t;dpc >>= 1;}int ans = (ask(1,1,2*n,1,2*n)-ask(1,1,2*n,IN[tmp],OUT[tmp]))/2;printf("%lld\n",ans);}}}//return 0;
}