P1600 [NOIP2016 提高组] 天天爱跑步
给定一颗有nnn个点的树,有mmm个人在树上移动,第iii个人从sis_isi点,移动到tit_iti点,且他们按照最短路移动,每秒移动一条边的距离,
点iii在wiw_iwi时刻有一个观察员,我们需要对每个点统计,在wiw_iwi时刻有多少个人恰好到达这个点。
如果第vvv个人,在wuw_uwu时刻恰好出现在点uuu,则一定有dis(sv,u)=wudis(s_v, u) = w_udis(sv,u)=wu,且uuu在sv,tvs_v, t_vsv,tv的路径上,
满足dis(s,u)+dis(u,t)=dis(s,t)dis(s, u) + dis(u, t) = dis(s, t)dis(s,u)+dis(u,t)=dis(s,t),假设lca(s,t)=zlca(s, t) = zlca(s,t)=z分两种情况讨论,以111号节点为根,d(i)d(i)d(i)表示第iii号节点的深度,
-
uuu在s−>zs->zs−>z的路径上,则有d(s)−d(u)+d(u)+d(t)−2×d(z)=d(s)+d(z)−2×d(z)d(s) - d(u) + d(u) + d(t) - 2 \times d(z) = d(s) + d(z) - 2 \times d(z)d(s)−d(u)+d(u)+d(t)−2×d(z)=d(s)+d(z)−2×d(z),且d(s)−d(u)=wud(s) - d(u) = w_ud(s)−d(u)=wu。
-
uuu在t−>zt->zt−>z的路径上,则有d(s)+d(u)−2×d(z)+d(t)−d(z)=d(s)+d(z)−2×d(z)d(s) + d(u) - 2 \times d(z) + d(t) - d(z) = d(s) + d(z) - 2 \times d(z)d(s)+d(u)−2×d(z)+d(t)−d(z)=d(s)+d(z)−2×d(z),且d(s)+d(u)−2×d(z)=wud(s) + d(u) - 2 \times d(z) = w_ud(s)+d(u)−2×d(z)=wu。
前项都是符合要求的,所以看后面的两项d(s)=d(u)+wud(s) = d(u) + w_ud(s)=d(u)+wu,d(u)−w=2×d(z)−d(s)d(u) - w = 2 \times d(z) - d(s)d(u)−w=2×d(z)−d(s)。
考虑树上差分,在sss点插入d(s)d(s)d(s),在ttt点插入2×d(z)−d(s)2 \times d(z) - d(s)2×d(z)−d(s),
在lcalcalca处减去d(s)d(s)d(s)的值,在fa(lca)fa(lca)fa(lca)处减去2×d(z)−d(s)2 \times d(z) - d(s)2×d(z)−d(s)的值,二者可互换顺序,之后只要线段树合并,再单点查询值即可。
#include <bits/stdc++.h>using namespace std;const int N = 3e5 + 10, maxn = 300000;int head[N], to[N << 1], nex[N << 1], cnt = 1;int dep[N], fa[N], son[N], sz[N], top[N];int w[N], ans[N], n, m;int root[N], ls[N << 5], rs[N << 5], sum[N << 5], num;vector<pair<int, int>> a[N];void add(int x, int y) {to[cnt] = y;nex[cnt] = head[x];head[x] = cnt++;
}void dfs1(int rt, int f) {fa[rt] = f, dep[rt] = dep[f] + 1, sz[rt] = 1;for (int i = head[rt]; i; i = nex[i]) {if (to[i] == f) {continue;}dfs1(to[i], rt);sz[rt] += sz[to[i]];if (!son[rt] || sz[to[i]] > sz[son[rt]]) {son[rt] = to[i];}}
}void dfs2(int rt, int tp) {top[rt] = tp;if (!son[rt]) {return ;}dfs2(son[rt], tp);for (int i = head[rt]; i; i = nex[i]) {if (to[i] == son[rt] || to[i] == fa[rt]) {continue;}dfs2(to[i], to[i]);}
}int lca(int u, int v) {while (top[u] != top[v]) {if (dep[top[u]] < dep[top[v]]) {swap(u, v);}u = fa[top[u]];}return dep[u] < dep[v] ? u : v;
}void update(int &rt, int l, int r, int x, int v) {if (!rt) {rt = ++num;}sum[rt] += v;if (l == r) {return ;}int mid = l + r >> 1;if (x <= mid) {update(ls[rt], l, mid, x, v);}else {update(rs[rt], mid + 1, r, x, v);}
}int query(int rt, int l, int r, int x) {if (l == r) {return sum[rt];}int mid = l + r >> 1;if (x <= mid) {return query(ls[rt], l, mid, x);}else {return query(rs[rt], mid + 1, r, x);}
}int merge(int x, int y, int l, int r) {if (!x || !y) {return x | y;}if (l == r) {sum[x] += sum[y];return x;}int mid = l + r >> 1;ls[x] = merge(ls[x], ls[y], l, mid);rs[x] = merge(rs[x], rs[y], mid + 1, r);sum[x] = sum[ls[x]] + sum[rs[x]];return x;
}void dfs(int rt, int fa) {for (auto it : a[rt]) {update(root[rt], -maxn, maxn, it.first, it.second);}for (int i = head[rt]; i; i = nex[i]) {if (to[i] == fa) {continue;}dfs(to[i], rt);root[rt] = merge(root[rt], root[to[i]], -maxn, maxn);}ans[rt] = query(root[rt], -maxn, maxn, dep[rt] + w[rt]);if (w[rt]) {ans[rt] += query(root[rt], -maxn, maxn, dep[rt] - w[rt]);}
}int main() {// freopen("in.txt", "r", stdin);// freopen("out.txt", "w", stdout);scanf("%d %d", &n, &m);for (int i = 1, x, y; i < n; i++) {scanf("%d %d", &x, &y);add(x, y);add(y, x);}dfs1(1, 0);dfs2(1, 1);for (int i = 1; i <= n; i++) {scanf("%d", &w[i]);}for (int i = 1, s, t; i <= m; i++) {scanf("%d %d", &s, &t);int f = lca(s, t), ff = fa[f];a[s].push_back({dep[s], 1});a[t].push_back({2 * dep[f] - dep[s], 1});a[f].push_back({dep[s], -1});a[ff].push_back({2 * dep[f] - dep[s], -1});}dfs(1, 0);for (int i = 1; i <= n; i++) {printf("%d%c", ans[i], i == n ? '\n' : ' ');}return 0;
}