今天写了一道题目,需要采用线段树合并+树上差分来解决
题目链接:P1600 [NOIP2016 提高组] 天天爱跑步 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
其实当时已经想到要用这两种方法,但苦于一直找不到转移方程,最后看了答案才领悟到一点推导公式的思路
我一开始的想法是对于起点s和终点t而言,对于x到其lca的路径中的节点,有人出现的时间结点就是dep[x]一直往上递推过去,然后我只需要最后合并找出符合w[i]的时间点就行了
但细想一下这个方法根本就不能实现,首先是lca到t的路径的结点无法处理,其次是差分数组加和的时候每次区间下标都要+1,而且也不知道x出发的人到哪个结点就消失,所以这种顺推的递推公式并不成立
那正确的思路是什么呢?
对于x到y路径
我们先分析x到lca的路径上的结点i,对于这个节点i,如果x出发的人能够被他观测到,那么应该满足关系式:dep[x]-dep[i]=w[i],这样一来,我们就可以得到:dep[i]+w[i]=dep[x],也就是说,对于路径上的每个这样的结点i,我们只要计数dep[x]的个数,就可以知道他能够观测到多少个人了
我们再来分析lca到y的路径上的结点j,对于这个结点j,如果x出发的人能够被他观测到,那么应该满足关系式:dep[x]-dep[lca]+dep[j]-dep[lca]=w[j],这样一来,我们就可以得到:dep[j]-w[i]=2*dep[lca]-dep[x],也就是说,对于路径上的每个这样的结点j,我们只需要计数2*dep[lca]-dep[x]的个数,就可以知道他可以观测到多少个人了
那么我们需要从哪里开始修改呢?
显然这些dep[x],2*dep[lca]-dep[x]所代表的人的个数只在x到y这条路径上生效,也就是说,我们只需要修改x结点上dep[x]的个数,然后在由下往上合并差分数组的时候,每个在路径(x,lca)的结点都可以享受到这份dep[x]的贡献。同理,我们只需要修改y结点上2*dep[lca]-dep[x]的个数,那么每个在路径(lca,y)的节点都可以享受到这份2*dep[lca]-dep[x]的贡献
而对于他们的lca来说,它既享受到了dep[x]的贡献,又享受到了2*dep[lca]-dep[x]的贡献,哪份是它不应该拥有的呢?其实这两者在这个节点这里是等效的,所以我们随机减免一个就可以了,我们这里选择给dep[x]的权值减1。那么对于lca的父亲而言,它就多享受到了一个2*dep[lca]-dep[x]的贡献,因此我们给其2*dep[lca]-dep[x]的权值-1,从而把这个差分修改的影响限制在了路径(x,y)中
这里我们选择用线段树来维护这个权值区间
实现代码如下:
数组版本:
//树上差分来处理树上路径的信息
//每个点建一棵权值线段树
//结点深度为区间下标
//sum维护结点深度出现的次数#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int N = 3E5 + 10, M = N << 1;
#define mid (l+r>>1)
//链式前向星
int to[M], nxt[M], h[N], idx;
//左儿子,右儿子,sum差分权值和
int ls[N * 100], rs[N * 100], sum[N * 100];
//fa倍增父亲
//dep节点深度
int fa[N][22], dep[N];
//n棵线段树
int root[N], tot;
int lg[N];
//w:观察员
//ans:第i个节点的答案
int n, m, w[N], ans[N];
//加边
void add(int a,int b){to[++idx] = b;nxt[idx] = h[a];h[a] = idx;
}
//lg[i]==log2+1,i向下取整
void init(){for (int i = 1; i <= n;i++){lg[i] = lg[i - 1] + (1<<lg[i-1] == i);}
}
//快读
void read(int &x){x = 0;char c = getchar();while(!isdigit(c)){c = getchar();}while(isdigit(c)){x = (x << 3 + x << 1) + c - '0';c = getchar();}
}//树增
void dfs(int x,int f){dep[x] = dep[f] + 1;fa[x][0] = f;//距离x的递增父亲的距离不会超过x的深度for (int i = 1; i < lg[dep[x]];i++){//距离x的距离为2^i的父亲等于距离x的距离为2^(i-1)的父亲的距离它为距离2^(i-1)的父亲fa[x][i] = fa[fa[x][i - 1]][i - 1];}//递归每一个子结点for (int i = h[x], y; y = to[i];i=nxt[i]){if(y!=f)dfs(y, x);}
}int lca(int x,int y){//默认x为深度大的结点if(dep[x]<dep[y]){swap(x, y);}//往上爬的距离不会超过x与y的深度差for (int i = lg[dep[x] - dep[y]] - 1; ~i;i--){//如果跳到父亲的深度还是比y大,可以放心跳if(dep[fa[x][i]]>=dep[y]){x = fa[x][i];}}if(x==y){return y;}for (int i = lg[dep[x]] - 1; ~i;i--){if(fa[x][i]!=fa[y][i]){x = fa[x][i];y = fa[y][i];}}return fa[x][0];
}//动态开点
//单点修改
void change(int &u,int l,int r,int p,int k){if(!u)u = ++tot;if(l==r){sum[u] += k;return;}if(p<=mid){change(ls[u], l, mid, p, k);}else{change(rs[u], mid + 1, r, p, k);}
}//把x线段树合并到y上
int merge(int x,int y,int l,int r){if(!x||!y){return x + y;}if(l==r){sum[x] += sum[y];return x;}ls[x] = merge(ls[x], ls[y], l, mid);rs[x] = merge(rs[x], rs[y], mid + 1, r);return x;
}//单点查询
int query(int u,int l,int r,int p){if(l==r){return sum[u];}if(p<=mid){return query(ls[u], l, mid, p);}else{return query(rs[u], mid + 1, r, p);}
}//递归合并线段树
void dfs2(int x){for (int i = h[x], y; y = to[i];i=nxt[i]){if(y==fa[x][0])continue;dfs2(y);//由于整体向右平移了值域n,所以线段树的区间要开到2*nroot[x]=merge(root[x],root[y],1,n<<1);}//如果这里有观察员//而且if(w[x]&&n+dep[x]+w[x]<=(n<<1)){ans[x] += query(root[x], 1, n << 1, n + dep[x] + w[x]);}ans[x] += query(root[x], 1, n << 1, n + dep[x] - w[x]);
}int main(){scanf("%d%d", &n, &m);for (int i = 1, x, y; i < n;i++){scanf("%d%d", &x, &y);add(x, y);add(y, x);}init();for (int i = 1; i <= n;i++){scanf("%d", w + i);}dfs(1, 0);for (int i = 1, x, y,l; i <= m;i++){scanf("%d%d", &x, &y);l = lca(x, y);change(root[x], 1, n << 1, n + dep[x], 1);change(root[y], 1, n << 1, n + 2 * dep[l] - dep[x], 1);change(root[l], 1, n << 1, n + dep[x], -1);change(root[fa[l][0]], 1, n << 1, n + 2 * dep[l] - dep[x],-1);}dfs2(1);for (int i = 1; i <= n;i++){printf("%d ",ans[i]);}return 0;
}
指针版本:
#include <iostream>
#include <cstdio>
#include <cstring>
#define mid (l+r>>1)
using namespace std;
struct tree {struct tree* l = nullptr, * r = nullptr;int sum=0;
};
const int N = 3E5 + 10, M = N << 1;
int to[M], nxt[M], h[N], tot;
void add(int a, int b) {to[++tot] = b;nxt[tot] = h[a];h[a] = tot;
}
tree* tr[N];
int son[N], siz[N], top[N], dep[N], fa[N];
int ans[N], w[N];
int n, m;
//fa,dep,siz,son
void dfs1(int u) {dep[u] = dep[fa[u]] + 1;siz[u] = 1;for (int i = h[u], v; v = to[i]; i = nxt[i]) {if (v == fa[u])continue;fa[v] = u;dfs1(v);siz[u] += siz[v];if (siz[v] > siz[son[u]])son[u] = v;}
}//top
void dfs2(int u, int tp) {top[u] = tp;top[son[u]] = tp;if(son[u])dfs2(son[u], tp);for (int i = h[u], v; v = to[i]; i = nxt[i]) {if (v == fa[u] || v == son[u])continue;dfs2(v, v);}
}//公共祖先
int lca(int x, int y) {//不在一条链上的时候while (top[x] != top[y]) {if (dep[top[x]] < dep[top[y]]) {swap(x, y);}x = fa[top[x]];}return dep[x] < dep[y] ? x : y;
}//点修
void change(tree** u, int l, int r, int p, int k) {if (!*u) {*u = new tree;}if (l == r) {(*u)->sum += k;return;}if (p <= mid) {change(&((*u)->l), l, mid, p, k);}else {change(&((*u)->r), mid + 1, r, p, k);}//反正不用区间查找,所以不Pushup也没关系了
}//点查
int query(tree* u, int l, int r, int p) {if (!u) {return 0;}if (l == r) {return u->sum;}if (p <= mid) {return query(u->l, l, mid, p);}else {return query(u->r, mid + 1, r, p);}
}tree* merge(tree* x, tree* y) {if (!x || !y) {return x ? x : y;}x->sum += y->sum;x->l = merge(x->l, y->l);x->r = merge(x->r, y->r);return x;
}void dfs3(int u) {for (int i = h[u], v; v = to[i]; i = nxt[i]) {if (v == fa[u])continue;dfs3(v);tr[u] = merge(tr[u], tr[v]);}if (w[u]&&dep[u] + w[u] <= n) {ans[u] += query(tr[u], -n, n, dep[u] + w[u]);}ans[u] += query(tr[u], -n, n, dep[u] - w[u]);
}int main() {scanf("%d%d", &n, &m);for (int i = 1, x, y; i < n; i++) {scanf("%d%d", &x, &y);add(x, y);add(y, x);}for (int i = 1; i <= n; i++) {scanf("%d", w + i);}dfs1(1);dfs2(1, 1);for (int i = 1, s, t, l; i <= m; i++) {scanf("%d%d", &s, &t);l = lca(s, t);change(&tr[s], -n, n, dep[s], 1);change(&tr[t], -n, n, 2 * dep[l] - dep[s], 1);change(&tr[l], -n, n, dep[s], -1);change(&tr[fa[l]], -n, n, 2 * dep[l] - dep[s], -1);}dfs3(1);for (int i = 1; i <= n; i++) {printf("%d ", ans[i]);}return 0;
}