基本概念:
1.重儿子:假设节点u有n个子结点,其中以v子节点的为根子树的大小最大,那么v就是u的重儿子
2.轻儿子:除了重儿子以外的全部儿子都是轻儿子
3.轻边:结点u与轻儿子连接的边
4.重边:结点u与重儿子连接的边
5.轻链:均由轻儿子组成的一条链
6.重链:均由重儿子组成的一条链
预处理节点信息:
dep[u]:u节点的深度
fa[u]:u结点的父亲结点
son[u]:u结点的重儿子
siz[u]:以u节点为根的子树的大小
top[u]:u结点所在的链的顶点
首先,我们可以很简单地通过dfs获取一个结点的dep,fa和siz,从而也就获得了siz
实现代码如下:
#include <iostream>
using namespace std;
const int N = 2E5 + 10;
int dep[N], fa[N], son[N], siz[N], top[N];
int to[N << 1], nxt[N << 1], h[N], tot;
void dfs1(int u,int f){siz[u] = 1;dep[u] = dep[f] + 1;fa[u] = f;int max = 0;for (int i = h[u], v; v = to[i];i=nxt[i]){if(v==f){continue;}dfs1(v, u);siz[u] += siz[v];if(siz[v]>max){max = siz[v];son[u] = v;}}
}
接下来,我们再用一个dfs来获取top数组
处理的方式为:
重儿子的top就等于自己u节点的top
轻儿子的top就等于轻儿子本身
实现代码如下:
void dfs2(int u,int f){for (int i = h[u], v; v = to[i];i=nxt[i]){if(v==f){continue;}if(v==son[u]){top[v] = top[u];}else{top[v] = v;}dfs2(v, u);}
}
树链剖分的应用:
1.寻找最近公共祖先(lca)
对于两个结点x和y
假设他们在同一条链上,也就是top相同,那么他们的lca就是深度比较小的一方
如果他们不在同一条链上
我们知道,他们肯定可以通过若干条链走到同一条链上
如何将一个结点从一条链转移到另一条链呢?
只需要让top[u]的深度较大的一方跳到其top[u]的父亲结点上,自然就到了另一条新链了
而且可以保证他们两个的top越来越接近,直到top相同
实现代码如下:
int lca(int x,int y){while(top[x]!=top[y]){if(dep[top[x]]<dep[top[y]]){swap(x, y);}x = fa [top [x]];}return dep[x] < dep[y] ? x : y;
}
2.维护树上区间
我们轻重链剖分以后,每条链都是一个连续的区间
如果想要对路径(x,y)做区间修改和区间查询的操作
只需要对组成这条路径的若干条树链进行维护即可
具体操作为:
在跳跃到x与y在共同链之前
我们对区间dfn[top[x]]到dfn[x]进行修改,查询
维护区间的数据结构我们选择使用线段树
题目链接:P3384 【模板】重链剖分/树链剖分 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
实现代码如下(码了一个小时,找bug找了一个小时,还是对线段树不是很熟练):
#include <iostream>
using namespace std;
const int N = 1E5 + 10;
#define ll long long
#define ls (i << 1)
#define rs (i << 1 | 1)
#define mid (left + right >> 1)
int to[N << 1], nxt[N << 1], h[N], tot;
int dfn[N], siz[N], son[N], top[N], dep[N], fa[N], idx;
// dfn[u],轻重链u结点的dfs序(区间序号)
// v[u],u结点的点权
// a[i],区间下标i的点权
int a[N], v[N];
int n, m, s;
long long mod;
// 加边
void add(int a, int b)
{to[++tot] = b;nxt[tot] = h[a];h[a] = tot;
}
// 获取每个结点的根树的size,获取深度dep,获取儿子son
void dfs1(int u, int f)
{dep[u] = dep[f] + 1;siz[u] = 1;fa[u] = f;int max = 0;for (int i = h[u], v; v = to[i]; i = nxt[i]){if (v == f){continue;}dfs1(v, u);if (siz[v] > max){max = siz[v];son[u] = v;}siz[u] += siz[v];}
}
// 获取dfn序和top
void dfs2(int u, int f)
{dfn[u] = ++idx;a[idx] = v[u];if (son[u]){top[son[u]] = top[u];dfs2(son[u], u);}for (int i = h[u], v; v = to[i]; i = nxt[i]){if (v == f || v == son[u]){continue;}top[v] = v;dfs2(v, u);}
}
//线段树
struct node
{int l, r;ll sum;ll tag;
} tr[4 * N];
void pushup(int i)
{tr[i].sum = (tr[ls].sum + tr[rs].sum) % mod;
}
void pushdown(int i)
{if (tr[i].l != tr[i].r && tr[i].tag){tr[ls].sum = (tr[ls].sum + ((tr[ls].r-tr[ls].l+1)%mod*tr[i].tag)) % mod;tr[rs].sum = (tr[rs].sum + ((tr[rs].r-tr[rs].l+1)*tr[i].tag)) % mod;tr[rs].tag = (tr[i].tag+tr[rs].tag)%mod;tr[ls].tag = (tr[i].tag+tr[ls].tag)%mod;tr[i].tag = 0;}
}
// 建树
void build(int i, int left, int right)
{tr[i].l = left;tr[i].r = right;if (left == right){tr[i].sum = a[left];return;}build(ls, left, mid);build(rs, mid + 1, right);pushup(i);
}
void add(int i, ll k, int left, int right)
{if (tr[i].l >= left && tr[i].r <= right){tr[i].sum = (((tr[i].r - tr[i].l + 1) % mod * k) % mod + tr[i].sum) % mod;tr[i].tag = (k + tr[i].tag) % mod;return;}int mmid = (tr[i].l + tr[i].r >> 1);pushdown(i);if (right >= mmid + 1){add(rs, k, left, right);}if (left <= mmid){add(ls, k, left, right);}pushup(i);
}
ll search(int i, int left, int right)
{if (tr[i].l >= left && tr[i].r <= right){return tr[i].sum;}pushdown(i);ll res = 0;int mmid = (tr[i].l + tr[i].r >> 1);if (right >= mmid + 1){res += search(rs, left, right);}if (left <= mmid){res = (res + search(ls, left, right)) % mod;}return res;
}
int main()
{cin >> n >> m >> s >> mod;for (int i = 1; i <= n; i++){cin >> v[i];v[i] %= mod;}for (int i = 1, x, y; i < n; i++){cin >> x >> y;add(x, y);add(y, x);}dfs1(s, 0);top[s] = s;dfs2(s, 0);build(1, 1, n);ll z;int l, r;ll ans;for (int i = 1, opt, x, y; i <= m; i++){cin >> opt;if (opt == 1){cin >> x >> y >> z;while (top[x] != top[y]){if (dep[top[x]] < dep[top[y]]){swap(x, y);}add(1, z, dfn[top[x]], dfn[x]);x = fa[top[x]];}if (dep[x] < dep[y]){l = dfn[x];r = dfn[y];}else{l = dfn[y];r = dfn[x];}add(1, z, l, r);}else if (opt == 2){cin >> x >> y;ans = 0;while (top[x] != top[y]){if (dep[top[x]] < dep[top[y]]){swap(x, y);}ans = (ans + search(1, dfn[top[x]], dfn[x])) % mod;x = fa[top[x]];}if (dep[x] < dep[y]){l = dfn[x];r = dfn[y];}else{l = dfn[y];r = dfn[x];}ans = (ans + search(1, l, r)) % mod;cout << ans << endl;}else if (opt == 3){cin >> x >> z;add(1, z, dfn[x], dfn[x] + siz[x] - 1);}else if (opt == 4){cin >> x;cout << search(1, dfn[x], dfn[x] + siz[x] - 1) << endl;}}return 0;
}