#6073. 「2017 山东一轮集训 Day5」距离
给定一棵有 $n$ 个节点、带边权的树，以及一个排列 $p$。$path(u, v)$ 为 $u, v$ 路径上的点集，$dist(u, v)$ 为 $u, v$ 之间最短路的长度。
有 $m$ 次询问，每次给定 $u, v, k$，求 $\sum\limits_{i \in path(u, v)} dist(p_i, k)$，要求在线求解。
我们选定 $1$ 号节点为根节点，定义 $d(i)$ 为点 $i$ 到根节点的距离。对于所求的 $\sum\limits_{i \in path(u, v)} dist(p_i, k)$，有：
$$\sum\limits_{i \in path(u, v)} dist(p_i, k) = \sum\limits_{i \in path(u, v)} \left(d(p_i) + d(k) - 2 \times d(lca(p_i, k))\right)$$
对于 $d(k)$ 部分：由于 $d(k)$ 是一个定值，这一部分的答案即为 $u, v$ 路径上点的个数 $\times\, d(k)$。
对于 $d(p_i)$ 部分，我们用树上差分（前缀和）来求解。定义 $sum[n] = \sum\limits_{i \in path(1, n)} d(p_i)$，
则 $\sum\limits_{i \in path(u, v)} d(p_i) = sum[u] + sum[v] - sum[lca(u, v)] - sum[fa(lca(u, v))]$。
最后一步，考虑最难计算的 $\sum\limits_{i \in path(u, v)} d(lca(p_i, k))$，可以仿照 P4211 [LNOI2014] LCA 这题的计算方式：对每个 $p_i$，将 $p_i$ 到根路径上每条边的计数加一；那么在 $k$ 到根的路径上，每条边的"计数 $\times$ 边权"之和恰好等于 $\sum d(lca(p_i, k))$。
由于强制在线，这题必须使用主席树（可持久化线段树）。我们定义点 $i$ 所代表的主席树，维护从 $1$ 到 $i$（即根节点到 $i$）路径上的每个点 $j$，
把 $p_j \to 1$ 路径加入后的信息。最后只需用四棵主席树按 $u + v - lca(u, v) - fa(lca(u, v))$ 做差分即可解决问题。
// LOJ #6073 "距离": online queries of sum_{i in path(u, v)} dist(p_i, k).
//
// answer = |path(u, v)| * d(k)                      (d = distance to root 1)
//        + sum_{i in path(u, v)} d(p_i)             (tree prefix sums s[])
//        - 2 * sum_{i in path(u, v)} d(lca(p_i, k)) (persistent segment tree)
//
// For the last term, version root[x] of the persistent segment tree (indexed
// by heavy-light dfs order) holds, for every node j on the root->x path, a +1
// on every edge of the p_j->root path. Summing count * edge weight over the
// k->root path then yields sum_j d(lca(p_j, k)). Path (u, v) is handled by the
// difference root[u] + root[v] - root[lca] - root[fa(lca)].
#include <algorithm>
#include <cstdio>
#include <utility>

using ll = long long;

const int N = 2e5 + 10;
// Node pool for the persistent segment tree (same capacity as N * 100).
// Each of the n insertions walks O(log n) heavy chains, and each chain update
// allocates O(log n) fresh nodes; no overflow check is performed.
const int POOL = N * 100;

// Adjacency list; edge indices start at 1 so that 0 terminates a list.
int head[N], to[N << 1], nex[N << 1], value[N << 1], cnt = 1;
int p[N], n, m, type;
// Heavy-light decomposition: parent, subtree size, heavy child, chain top,
// dfs-order index, depth, and w[v] = weight of the edge (fa[v], v).
int fa[N], sz[N], son[N], top[N], id[N], dep[N], w[N], tot;
// Persistent segment tree: per-tree-node version roots and child links.
int root[N], ls[POOL], rs[POOL], num;
// dis[v]: distance from root to v. len[]: prefix sums of w[] in dfs order.
// s[v]: sum of dis[p[j]] over j on the root->v path.
// sum[]/lazy[]: weighted node total and permanent (never pushed down) +1 tag.
ll dis[N], len[N], s[N], sum[POOL], lazy[POOL];

// Appends the directed edge x -> y with the given weight.
void add(int x, int y, int weight) {
    to[cnt] = y;
    nex[cnt] = head[x];
    value[cnt] = weight;
    head[x] = cnt++;
}

// First HLD pass: fills fa, dep, sz, son, dis and w for the subtree of rt.
// NOTE(review): recursion depth equals the tree height (up to n on a chain);
// this relies on the judge providing a large enough stack.
void dfs1(int rt, int f) {
    fa[rt] = f;
    dep[rt] = dep[f] + 1;
    sz[rt] = 1;
    for (int i = head[rt]; i; i = nex[i]) {
        const int child = to[i];
        if (child == f) {
            continue;
        }
        dis[child] = dis[rt] + value[i];
        dfs1(child, rt);
        w[child] = value[i];
        sz[rt] += sz[child];
        if (!son[rt] || sz[son[rt]] < sz[child]) {
            son[rt] = child;
        }
    }
}

// Second HLD pass: assigns dfs order (heavy child first) and chain tops, and
// stores each node's parent-edge weight at its dfs position in len[].
void dfs2(int rt, int tp) {
    id[rt] = ++tot;
    len[tot] = w[rt];
    top[rt] = tp;
    if (!son[rt]) {
        return;
    }
    dfs2(son[rt], tp);
    for (int i = head[rt]; i; i = nex[i]) {
        const int child = to[i];
        if (child == fa[rt] || child == son[rt]) {
            continue;
        }
        dfs2(child, child);
    }
}

// Lowest common ancestor via heavy-light chain jumping.
int lca(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) {
            std::swap(u, v);
        }
        u = fa[top[u]];
    }
    return dep[u] < dep[v] ? u : v;
}

// Persistent range increment: rt becomes a new version of `pre` in which every
// position of [L, R] ∩ [l, r] gains +1. Each touched node's sum[] grows by the
// edge-weight length of that intersection (len[] must already be prefix sums).
void update(int &rt, int pre, int l, int r, int L, int R) {
    rt = ++num;
    ls[rt] = ls[pre];
    rs[rt] = rs[pre];
    sum[rt] = sum[pre];
    lazy[rt] = lazy[pre];
    sum[rt] += len[std::min(r, R)] - len[std::max(l, L) - 1];
    if (L <= l && r <= R) {
        lazy[rt] += 1;  // permanent tag: counted again on the way down in query
        return;
    }
    const int mid = (l + r) >> 1;
    if (L <= mid) {
        update(ls[rt], ls[pre], l, mid, L, R);
    }
    if (R > mid) {
        update(rs[rt], rs[pre], mid + 1, r, L, R);
    }
}

// Weighted sum over [L, R] of the version difference u + v - f - ff, where
// each argument is a node of the corresponding version. Tags of partially
// covered ancestors are applied to the covered length on the way down.
ll query(int u, int v, int f, int ff, int l, int r, int L, int R) {
    if (L <= l && r <= R) {
        return sum[u] + sum[v] - sum[f] - sum[ff];
    }
    ll ans = (lazy[u] + lazy[v] - lazy[f] - lazy[ff]) *
             (len[std::min(r, R)] - len[std::max(l, L) - 1]);
    const int mid = (l + r) >> 1;
    if (L <= mid) {
        ans += query(ls[u], ls[v], ls[f], ls[ff], l, mid, L, R);
    }
    if (R > mid) {
        ans += query(rs[u], rs[v], rs[f], rs[ff], mid + 1, r, L, R);
    }
    return ans;
}

// Builds root[rt] (and s[rt]) from the parent's version by adding the
// p[rt] -> root path, then recurses into the children of rt.
void dfs(int rt, int f) {
    s[rt] = s[f] + dis[p[rt]];
    root[rt] = root[f];
    for (int cur = p[rt]; cur; cur = fa[top[cur]]) {
        update(root[rt], root[rt], 1, n, id[top[cur]], id[cur]);
    }
    for (int i = head[rt]; i; i = nex[i]) {
        if (to[i] == f) {
            continue;
        }
        dfs(to[i], rt);
    }
}

// Sums the (u + v - f - ff) versions over every heavy chain of the k -> root
// path, i.e. sum_{j in path(u, v)} d(lca(p_j, k)).
ll query_to_root(int u, int v, int f, int ff, int k) {
    ll ans = 0;
    for (; k; k = fa[top[k]]) {
        ans += query(root[u], root[v], root[f], root[ff], 1, n, id[top[k]], id[k]);
    }
    return ans;
}

int main() {
    if (std::scanf("%d %d %d", &type, &n, &m) != 3) {
        return 0;
    }
    for (int i = 1; i < n; i++) {
        int u, v, weight;
        std::scanf("%d %d %d", &u, &v, &weight);
        add(u, v, weight);
        add(v, u, weight);
    }
    for (int i = 1; i <= n; i++) {
        std::scanf("%d", &p[i]);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    for (int i = 1; i <= n; i++) {
        len[i] += len[i - 1];  // turn edge weights into dfs-order prefix sums
    }
    dfs(1, 0);
    ll ans = 0;
    for (int i = 1; i <= m; i++) {
        int u, v, k;
        std::scanf("%d %d %d", &u, &v, &k);
        // Online mode (type == 1): inputs are XOR-ed with the previous answer.
        u = static_cast<int>(u ^ (ans * type));
        v = static_cast<int>(v ^ (ans * type));
        k = static_cast<int>(k ^ (ans * type));
        const int f = lca(u, v);
        const int ff = fa[f];
        // dep[u] + dep[v] - dep[f] - dep[ff] is the number of nodes on path(u, v).
        ans = static_cast<ll>(dep[u] + dep[v] - dep[f] - dep[ff]) * dis[k];
        ans += s[u] + s[v] - s[f] - s[ff];
        ans -= 2 * query_to_root(u, v, f, ff, k);
        std::printf("%lld\n", ans);
    }
    return 0;
}