这几天学了一个树链剖分,觉得还不是很难,这里我试着讲一讲吧。
首先,我认为树链剖分是把在树上一个节点一个节点的走改为按照某种规则跳,从而降低了时间复杂度。
那这是什么规则呢?
首先我们得知道什么是重链,知道什么是重链就得先知道什么是重儿子,重儿子就是子树较大的儿子。然后对于一个点,我们总是往他的重儿子走,这样就构成了重链,那么剩下的就是轻链。
放张图直观些
然后我们同样可以对树进行dfs,只不过重儿子优先,这样我们也得到了一个dfs序,于是我们把树上问题成功转化成了线性问题。接着就可以用线段树等数据结构维护了。
那就拿这道板子体为例:https://www.luogu.org/problemnew/show/P3384
首先是两遍dfs,第一遍dfs维护子树大小size[],节点深度dep[],重儿子son[],以及一个节点的父亲节点(因为如果跳到了一条链的顶端,就要再自己走到他的父亲节点)。
1 void dfs1(int now) 2 { 3 vis[now] = 1; 4 size[now] = 1; 5 for(int i = 0; i < (int)v[now].size(); ++i) 6 { 7 if(!vis[v[now][i]]) 8 { 9 dep[v[now][i]] = dep[now] + 1; 10 fa[v[now][i]] = now; 11 dfs1(v[now][i]); 12 size[now] += size[v[now][i]]; 13 if(!son[now] || size[son[now]] < size[v[now][i]]) son[now] = v[now][i]; 14 //如果没有重儿子,或者当前子树大小大于重儿子的子树大小,就更新重儿子 15 } 16 } 17 }
第二遍dfs是维护dfs序dfsx[],每一条链的顶端是哪一个节点。但我们还要在维护一个pos[],因为当我们将树转化成线性后,用线段树建树的时候需要添加节点,而这个节点的编号是dfs序的编号,所以需要再用一个数组记录dfs序的编号所对应的树上节点编号。
1 int cnt = 0, dfsx[maxn], pos[maxn], top[maxn]; 2 void dfs2(int now) 3 { 4 //dfsx[]因为智慧更新一次,所以可以当做vis[]用 5 dfsx[now] = ++cnt; pos[cnt] = now; 6 if(son[now]) 7 { 8 top[son[now]] = top[now]; 9 dfs2(son[now]); //优先走重儿子,保证一条链在dfs序上的编号是连续的 10 } 11 for(int i = 0; i < (int)v[now].size(); ++i) 12 { 13 if(!dfsx[v[now][i]] && son[now] != v[now][i]) //再走不是重儿子的节点 14 { 15 top[v[now][i]] = v[now][i]; //轻儿子所在的链只有他自己一个节点,所以顶端节点就是他自己 16 dfs2(v[now][i]); 17 } 18 } 19 }
这两个预处理完事后就可以看看题了。
第一个询问,将树从x到y结点最短路径上所有节点的值都加上z。
首先我们要将x,y移到同一条链上,具体操作就是如果其中一个点所在链的顶端的深度更低,就将他跳到链的顶端,并更新他到顶端节点的区间。
移到同一条链上后,就更新这两个点的区间就行了
1 void pathUpdate(int x, int y, int z) 2 { 3 while(top[x] != top[y]) //先把这俩搞到一条链上 4 { 5 if(dep[top[x]] < dep[top[y]]) swap(x, y); //默认让x跳 6 update(dfsx[top[x]], dfsx[x], 1, z); 7 x = fa[top[x]]; 8 } 9 if(dfsx[x] > dfsx[y]) swap(x, y); 10 update(dfsx[x], dfsx[y], 1, z); 11 }
操作2: 求树从x到y结点最短路径上所有节点的值之和
和修改一样,先把两点移到同一条链上,然后计算跳的点在该链上的贡献
1 ll pathQuery(int x, int y) 2 { 3 ll ret = 0; 4 while(top[x] != top[y]) 5 { 6 if(dep[top[x]] < dep[top[y]]) swap(x, y); 7 ret += query(dfsx[top[x]], dfsx[x], 1); ret %= mod; 8 x = fa[top[x]]; 9 } 10 if(dfsx[x] > dfsx[y]) swap(x, y); 11 ret += query(dfsx[x], dfsx[y], 1); ret %= mod; 12 return ret; 13 }
操作3: 将以x为根节点的子树内所有节点值都加上z
值得一提的是,尽管我们在维护dfs序时是重链优先遍历,但仍满足一个节点以及他的子树在dfs序上是一段长为子树大小的连续区间,自己画一画就明白了
这里和查询放一块
1 void sbtUpdate(int x, int z) 2 { 3 update(dfsx[x], dfsx[x] + size[x] - 1, 1, z); 4 } 5 ll sbtQuery(int x) 6 { 7 return query(dfsx[x], dfsx[x] + size[x] - 1, 1); 8 }
这样板子就写完了,是不是很简单?
然后最重要的一件事是别忘了取模,而且每一个运算后都要取,否则你就可能70分代码debug一小时……
1 #include<cstdio> 2 #include<iostream> 3 #include<cstring> 4 #include<cmath> 5 #include<algorithm> 6 #include<vector> 7 #include<cctype> 8 using namespace std; 9 #define enter printf("\n") 10 #define space printf(" ") 11 typedef long long ll; 12 const int INF = 0x3f3f3f3f; 13 const int maxn = 1e5 + 5; 14 inline ll read() 15 { 16 ll ans = 0; 17 char ch = getchar(), last = ' '; 18 while(!isdigit(ch)) {last = ch; ch = getchar();} 19 while(isdigit(ch)) 20 { 21 ans = ans * 10 + ch - '0'; ch = getchar(); 22 } 23 if(last == '-') ans = -ans; 24 return ans; 25 } 26 inline void write(ll x) 27 { 28 if(x < 0) x = -x, putchar('-'); 29 if(x >= 10) write(x / 10); 30 putchar('0' + x % 10); 31 } 32 33 int n, m, s, mod; 34 int a[maxn]; 35 vector<int> v[maxn]; 36 37 bool vis[maxn]; 38 int fa[maxn], son[maxn], size[maxn], dep[maxn]; 39 void dfs1(int now) 40 { 41 vis[now] = 1; 42 size[now] = 1; 43 for(int i = 0; i < (int)v[now].size(); ++i) 44 { 45 if(!vis[v[now][i]]) 46 { 47 dep[v[now][i]] = dep[now] + 1; 48 fa[v[now][i]] = now; 49 dfs1(v[now][i]); 50 size[now] += size[v[now][i]]; 51 if(!son[now] || size[son[now]] < size[v[now][i]]) son[now] = v[now][i]; 52 //如果没有重儿子,或者当前子树大小大于重儿子的子树大小,就更新重儿子 53 } 54 } 55 } 56 //第二遍dfs是维护dfs序dfsx[],每一条链的顶端是哪一个节点 57 58 int cnt = 0, dfsx[maxn], pos[maxn], top[maxn]; 59 void dfs2(int now) 60 { 61 //dfsx[]因为智慧更新一次,所以可以当做vis[]用 62 dfsx[now] = ++cnt; pos[cnt] = now; 63 if(son[now]) 64 { 65 top[son[now]] = top[now]; 66 dfs2(son[now]); //优先走重儿子,保证一条链在dfs序上的编号是连续的 67 } 68 for(int i = 0; i < (int)v[now].size(); ++i) 69 { 70 if(!dfsx[v[now][i]] && son[now] != v[now][i]) //再走不是重儿子的节点 71 { 72 top[v[now][i]] = v[now][i]; //轻儿子所在的链只有他自己一个节点,所以顶端节点就是他自己 73 dfs2(v[now][i]); 74 } 75 } 76 } 77 78 int l[maxn << 2], r[maxn << 2]; 79 ll sum[maxn << 2], lazy[maxn << 2]; 80 void build(int L, int R, int now) 81 { 82 l[now] = L; r[now] = R; 83 if(L == R) {sum[now] = a[pos[L]]; return;} 84 int mid = (L + R) >> 1; 85 build(L, mid, now << 1); 86 build(mid + 1, R, now << 1 | 1); 87 sum[now] = (sum[now << 1] + sum[now << 1 | 1]) % mod; 88 } 89 void pushdown(int now) 90 { 91 if(lazy[now]) 92 { 93 lazy[now << 1] += lazy[now]; lazy[now << 1] %= mod; 94 lazy[now << 1 | 1] += lazy[now]; lazy[now << 1 | 1] %= mod; 95 sum[now << 1] += (ll)(r[now << 1] - l[now << 1] + 1) * lazy[now]; sum[now << 1] %= mod; 96 sum[now << 1 | 1] += (ll)(r[now << 1 | 1] - l[now << 1 | 1] + 1) * lazy[now]; sum[now << 1 | 1] %= mod; 97 lazy[now] = 0; 98 } 99 } 100 void update(int L, int R, int now, int d) 101 { 102 if(L == l[now] && R == r[now]) 103 { 104 sum[now] += (ll)(r[now] - l[now] + 1) * d; sum[now] %= mod; 105 lazy[now] += d; lazy[now] %= mod; 106 return; 107 } 108 pushdown(now); 109 int mid = (l[now] + r[now]) >> 1; 110 if(R <= mid) update(L, R, now << 1, d); 111 else if(L > mid) update(L, R, now << 1 | 1, d); 112 else {update(L, mid, now << 1, d); update(mid + 1, R, now << 1 | 1, d);} 113 sum[now] = sum[now << 1] + sum[now << 1 | 1]; 114 } 115 ll query(int L, int R, int now) 116 { 117 if(L == l[now] && R == r[now]) return sum[now]; 118 pushdown(now); 119 int mid = (l[now] + r[now]) >> 1; 120 if(R <= mid) return query(L, R, now << 1); 121 else if(L > mid) return query(L, R, now << 1 | 1); 122 else return (query(L, mid, now << 1) + query(mid + 1, R, now << 1 | 1)) % mod; 123 } 124 125 void pathUpdate(int x, int y, int z) 126 { 127 while(top[x] != top[y]) //先把这俩搞到一条链上 128 { 129 if(dep[top[x]] < dep[top[y]]) swap(x, y); //默认让x跳 130 update(dfsx[top[x]], dfsx[x], 1, z); 131 x = fa[top[x]]; 132 } 133 if(dfsx[x] > dfsx[y]) swap(x, y); 134 update(dfsx[x], dfsx[y], 1, z); 135 } 136 ll pathQuery(int x, int y) 137 { 138 ll ret = 0; 139 while(top[x] != top[y]) 140 { 141 if(dep[top[x]] < dep[top[y]]) swap(x, y); 142 ret += query(dfsx[top[x]], dfsx[x], 1); ret %= mod; 143 x = fa[top[x]]; 144 } 145 if(dfsx[x] > dfsx[y]) swap(x, y); 146 ret += query(dfsx[x], dfsx[y], 1); ret %= mod; 147 return ret; 148 } 149 150 void sbtUpdate(int x, int z) 151 { 152 update(dfsx[x], dfsx[x] + size[x] - 1, 1, z); 153 } 154 ll sbtQuery(int x) 155 { 156 return query(dfsx[x], dfsx[x] + size[x] - 1, 1); 157 } 158 159 int main() 160 { 161 n = read(); m = read(); s = read(); mod = read(); 162 for(int i = 1; i <= n; ++i) a[i] = read(); 163 for(int i = 1 ; i < n; ++i) 164 { 165 int a = read(), b = read(); 166 v[a].push_back(b); v[b].push_back(a); 167 } 168 dfs1(s); 169 top[s] = s; dfs2(s); 170 build(1, n, 1); 171 for(int i = 1; i <= m ; ++i) 172 { 173 int d = read(); 174 if(d == 1) 175 { 176 int x = read(), y = read(), z = read(); 177 pathUpdate(x, y, z); 178 } 179 else if(d == 2) 180 { 181 int x = read(), y = read(); 182 write(pathQuery(x, y)); enter; 183 } 184 else if(d == 3) 185 { 186 int x = read(), z = read(); 187 sbtUpdate(x, z); 188 } 189 else 190 { 191 int x = read(); 192 write(sbtQuery(x)); enter; 193 } 194 } 195 return 0; 196 }