Tree
让我们找满足一下五个条件的(x,y(x, y(x,y)点对有多少:
- x≠yx \neq yx=y
- xxx不是yyy的祖先
- yyy不是xxx的祖先
- dis(x,y)≤kdis(x, y)\leq kdis(x,y)≤k
- zzz是x,yx, yx,y的最近公共祖先,valuex+valuey=2valuezvalue_x + value_y = 2value_zvaluex+valuey=2valuez。
读题目观察到每个节点的valuevaluevalue只有[0,105][0, 10 ^ 5][0,105](如果不是的话,也可离散化处理一下吧),所以我们可以建立10510 ^ 5105棵线段树,每棵线段树里面记录的是点权为iii的节点的深度信息,
所以我们只要做一次dsuontreedsu\ on\ treedsu on tree,动态维护这颗线段树,然后按照需要查询即可,好像并不是特别难。
#include <bits/stdc++.h>using namespace std;const int N = 2e5 + 10;int head[N], to[N], nex[N], cnt = 1;int value[N], n, m;int son[N], sz[N], dep[N], l[N], r[N], rk[N], tot;int root[N], ls[N << 6], rs[N << 6], sum[N << 6], num;void add(int x, int y) {to[cnt] = y;nex[cnt] = head[x];head[x] = cnt++;
}void dfs(int rt, int fa) {dep[rt] = dep[fa] + 1, sz[rt] = 1, l[rt] = ++tot, rk[tot] = rt;for (int i = head[rt]; i; i = nex[i]) {if (to[i] == fa) {continue;}dfs(to[i], rt);sz[rt] += sz[to[i]];if (!son[rt] || sz[son[rt]] < sz[to[i]]) {son[rt] = to[i];}}r[rt] = tot;
}void push_up(int rt) {sum[rt] = sum[ls[rt]] + sum[rs[rt]];
}void update(int &rt, int l, int r, int x, int value) {if (!rt) {rt = ++num;}if (l == r) {sum[rt] += value;return ;}int mid = l + r >> 1;if (x <= mid) {update(ls[rt], l, mid, x, value);}else {update(rs[rt], mid + 1, r, x, value);}push_up(rt);
}int query(int rt, int l, int r, int L, int R) {if (!rt) {return 0;}if (l >= L && r <= R) {return sum[rt];}int mid = l + r >> 1, ans = 0;if (L <= mid) {ans += query(ls[rt], l, mid, L, R);}if (R > mid) {ans += query(rs[rt], mid + 1, r, L, R);}return ans;
}long long ans;void dfs(int rt, int fa, bool keep) {for (int i = head[rt]; i; i = nex[i]) {if (to[i] == fa || to[i] == son[rt]) {continue;}dfs(to[i], rt, 0);}if (son[rt]) {dfs(son[rt], rt, 1);}int v = 2 * value[rt], d = dep[rt];for (int i = head[rt]; i; i = nex[i]) {if (to[i] == fa || to[i] == son[rt]) {continue;}for (int j = l[to[i]]; j <= r[to[i]]; j++) {int target_v = v - value[rk[j]], last_d = m - (dep[rk[j]] - d);//目标权值,剩下的可延展的距离if (target_v < 0 || last_d <= 0) {//如果目标权值小于0或者剩下的可延展距离没有了,提前剪除不合法continue;}int l = d + 1, r = d + last_d;//深度的区间范围,然后查询即可。ans += query(root[target_v], 1, n, l, r);}for (int j = l[to[i]]; j <= r[to[i]]; j++) {update(root[value[rk[j]]], 1, n, dep[rk[j]], 1);}}update(root[value[rt]], 1, n, dep[rt], 1);if (!keep) {for (int i = l[rt]; i <= r[rt]; i++) {update(root[value[rk[i]]], 1, n, dep[rk[i]], -1);}}
}int main() {// freopen("in.txt", "r", stdin);// freopen("out.txt", "w", stdout);// ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);scanf("%d %d", &n, &m);for (int i = 1; i <= n; i++) {scanf("%d", &value[i]);}for (int i = 2; i <= n; i++) {int x;scanf("%d", &x);add(x, i);}dfs(1, 0);dfs(1, 0, 1);printf("%lld\n", ans * 2);return 0;
}