P3899 [湖南集训]谈笑风生
给定一颗以111号节点为根的树,如果a≠ba \neq ba=b,且aaa是bbb的祖先,则aaa比bbb更厉害,如果a≠ba \neq ba=b,且dis(a,b)≤xdis(a, b) \leq xdis(a,b)≤x,xxx为给定的一个数,则a,ba, ba,b紧邻。
现有mmm次询问,每次询问给定p,kp, kp,k,为存在多少个三元组(p,b,c)(p, b, c)(p,b,c)满足一下条件:
- p,bp, bp,b都比ccc厉害。
- p,bp, bp,b彼此紧邻对于给定的常数kkk。
因为p,bp, bp,b都比ccc更厉害,则p,bp, bp,b都是ccc的祖先,分情况讨论:
-
ppp是bbb的祖先
只要满足bbb在ppp的子树中且,ccc在bbb的子树中即可。
-
bbb是ppp的祖先
bbb一定在1−>p1-> p1−>p的路径上,且ccc在ppp的子树中。
线段树合并搞一搞,在线,离线都可,复杂度O(nlogn)O(n \log n)O(nlogn)。
#include <bits/stdc++.h>using namespace std;const int N = 3e5 + 10;int head[N], to[N << 1], nex[N << 1], cnt = 1;int root[N], ls[N << 5], rs[N << 5], num;int dep[N], sz[N], n, m;long long sum[N << 5];long long ans[N];vector<pair<int, int>> query[N];void add(int x, int y) {to[cnt] = y;nex[cnt] = head[x];head[x] = cnt++;
}int merge(int x, int y, int l, int r) {if (!x || !y) {return x | y;}if (l == r) {sum[x] += sum[y];return x;}int mid = l + r >> 1;ls[x] = merge(ls[x], ls[y], l, mid);rs[x] = merge(rs[x], rs[y], mid + 1, r);sum[x] = sum[ls[x]] + sum[rs[x]];return x;
}void update(int &rt, int l, int r, int x, int v) {if (!rt) {rt = ++num;}sum[rt] += v;if (l == r) {return ;}int mid = l + r >> 1;if (x <= mid) {update(ls[rt], l, mid, x, v);}else {update(rs[rt], mid + 1, r, x, v);}
}long long ask(int rt, int l, int r, int L, int R) {if (l >= L && r <= R) {return sum[rt];}long long ans = 0;int mid = l + r >> 1;if (L <= mid) {ans += ask(ls[rt], l, mid, L, R);}if (R > mid) {ans += ask(rs[rt], mid + 1, r, L, R);}return ans;
}void dfs(int rt, int fa) {dep[rt] = dep[fa] + 1, sz[rt] = 1;for (int i = head[rt]; i; i = nex[i]) {if (to[i] == fa) {continue;}dfs(to[i], rt);sz[rt] += sz[to[i]];root[rt] = merge(root[rt], root[to[i]], 1, n);}for (auto it : query[rt]) {int id = it.first, k = it.second;ans[id] = ask(root[rt], 1, n, dep[rt], min(dep[rt] + k, n));ans[id] += 1ll * (dep[rt] - max(1, dep[rt] - k)) * (sz[rt] - 1);}update(root[rt], 1, n, dep[rt], sz[rt] - 1);
}int main() {// freopen("in.txt", "r", stdin);// freopen("out.txt", "w", stdout);scanf("%d %d", &n, &m);for (int i = 1, x, y; i < n; i++) {scanf("%d %d", &x, &y);add(x, y);add(y, x);}for (int i = 1, x, k; i <= m; i++) {scanf("%d %d", &x, &k);query[x].push_back({i, k});}dfs(1, 0);for (int i = 1; i <= m; i++) {printf("%lld\n", ans[i]);}return 0;
}