F - Colorful Tree
给定一棵树,边有边权,且每条边有一个颜色,有 $m$ 次询问,
每次给定 $x, y, u, v$,问如果把所有颜色为 $x$ 的边的边权都修改为 $y$,$u, v$ 两点间的距离是多少(各询问相互独立,并不真正修改)。考虑:
设 $1$ 号节点为根节点,
设 $d[i]$ 为 $1$ 到 $i$ 的距离,则 $dis(u, v) = d[u] + d[v] - 2 \times d[lca]$,其中 $lca$ 为 $u, v$ 的最近公共祖先,
设 $num[i][x]$ 为 $1$ 到 $i$ 的路径上颜色为 $x$ 的边的数量,$sum[i][x]$ 为 $1$ 到 $i$ 的路径上颜色为 $x$ 的边的边权和,
则 $d'[u] = d[u] - sum[u][x] + num[u][x] \times y$,$dis'(u, v) = d'[u] + d'[v] - 2 \times d'[lca]$。直接存下所有的 $num[i][x]$ 和 $sum[i][x]$ 空间开销过大,因此离线处理:把每个询问挂在 $u$、$v$、$lca$ 三个节点上,对整棵树做一次 DFS,过程中只维护当前根到当前节点路径上每种颜色的边数和边权和,到达节点时回答挂在该节点上的询问即可。
#include <cstdio>
#include <utility>
#include <vector>

// Offline solution for "Colorful Tree": for each query (x, y, u, v), report the
// distance between u and v if every edge of color x had weight y.
//
// With vertex 1 as the root and d'[w] = dis[w] - sum_x(w) + num_x(w) * y, where
// sum_x(w) / num_x(w) are the total weight / count of color-x edges on the path
// from the root to w, the answer is d'[u] + d'[v] - 2 * d'[lca(u, v)].
// Each query is attached to u, v and lca(u, v); a single DFS that maintains the
// per-color totals of the current root path then evaluates all three terms.
//
// NOTE(review): dfs1, dfs2 and dfs recurse once per tree level, so a path-shaped
// tree with ~1e5 vertices needs a correspondingly deep call stack.

// Array capacity: the code indexes vertices, colors and queries from 1, so all
// of them must stay below this bound.
constexpr int N = 100000 + 10;

// Adjacency lists stored as linked edge records. Edge ids start at 1, so the
// value 0 in head[] / nex[] marks the end of a list.
int head[N], to[N << 1], nex[N << 1], col[N << 1], val[N << 1], cnt = 1;

// Heavy-path decomposition data: heavy child, subtree size, parent, chain top
// and depth of every vertex (the root has depth 1, the sentinel 0 has depth 0).
int son[N], sz[N], fa[N], top[N], dep[N];

int n, m;

// num[c] / sum[c]: count / total weight of color-c edges on the path from the
// root to the vertex currently visited by dfs().
int num[N];
long long sum[N];

// dis[v]: weighted distance from the root to v.  ans[q]: answer of query q.
long long dis[N], ans[N];

// One term of a query, attached to a vertex: recolored color x, new weight y,
// coefficient `add` (+1 for an endpoint, -2 for the LCA) and the query index.
struct Res {
    int x, y, add, id;
};

// a[v]: query terms to evaluate when dfs() reaches v.
std::vector<Res> a[N];

// Appends the directed edge x -> y with color c and weight d.
void add(int x, int y, int c, int d) {
    to[cnt] = y;
    nex[cnt] = head[x];
    col[cnt] = c;
    val[cnt] = d;
    head[x] = cnt++;
}

// First pass: fills fa, dep, sz, dis and picks the heavy child son[rt]
// (the child with the largest subtree) for the subtree rooted at rt.
void dfs1(int rt, int f) {
    fa[rt] = f;
    dep[rt] = dep[f] + 1;
    sz[rt] = 1;
    for (int i = head[rt]; i; i = nex[i]) {
        const int child = to[i];
        if (child == f) {
            continue;
        }
        dis[child] = dis[rt] + val[i];
        dfs1(child, rt);
        sz[rt] += sz[child];
        if (!son[rt] || sz[son[rt]] < sz[child]) {
            son[rt] = child;
        }
    }
}

// Second pass: assigns top[rt] = tp. The heavy child continues the current
// chain; every other child starts a new chain with itself as the top.
void dfs2(int rt, int tp) {
    top[rt] = tp;
    if (!son[rt]) {
        return;
    }
    dfs2(son[rt], tp);
    for (int i = head[rt]; i; i = nex[i]) {
        const int child = to[i];
        if (child == son[rt] || child == fa[rt]) {
            continue;
        }
        dfs2(child, child);
    }
}

// Lowest common ancestor of x and y: repeatedly lifts the vertex whose chain
// top is deeper to the parent of that top until both share a chain.
int lca(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) {
            std::swap(x, y);
        }
        x = fa[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

// Answers the query terms attached to rt, then descends into the children
// while keeping num[] / sum[] equal to the totals of the root-to-vertex path.
void dfs(int rt, int parent) {
    for (const auto &it : a[rt]) {
        // Adjusted root distance: drop the real color-x weights and count
        // every color-x edge as weight y instead.
        const long long adjusted = dis[rt] - sum[it.x] + 1ll * num[it.x] * it.y;
        ans[it.id] += 1ll * it.add * adjusted;
    }
    for (int i = head[rt]; i; i = nex[i]) {
        if (to[i] == parent) {
            continue;
        }
        num[col[i]]++;
        sum[col[i]] += val[i];
        dfs(to[i], rt);
        // Undo the edge so sibling subtrees see only their own root path.
        num[col[i]]--;
        sum[col[i]] -= val[i];
    }
}

// Reads the tree and the queries from stdin and prints one answer per line.
// Returns 1 if the input ends early or is malformed.
int main() {
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    if (std::scanf("%d %d", &n, &m) != 2) {
        return 1;
    }
    for (int i = 1; i < n; i++) {
        int x, y, c, d;
        if (std::scanf("%d %d %d %d", &x, &y, &c, &d) != 4) {
            return 1;
        }
        add(x, y, c, d);
        add(y, x, c, d);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    for (int i = 1; i <= m; i++) {
        int x, y, u, v;
        if (std::scanf("%d %d %d %d", &x, &y, &u, &v) != 4) {
            return 1;
        }
        const int f = lca(u, v);
        a[u].push_back({x, y, 1, i});
        a[v].push_back({x, y, 1, i});
        a[f].push_back({x, y, -2, i});
    }
    dfs(1, 0);
    for (int i = 1; i <= m; i++) {
        std::printf("%lld\n", ans[i]);
    }
    return 0;
}