正题
luogu 7518
题目大意
给你一棵树,一条路径的价值为:路径上点权以1开始依次递增1的子序列,有q次询问,每次询问一条路径的价值
解题思路
n,m值比较大,对于每次询问只有O(log2n)O(log^2n)O(log2n)的时间
考虑树链剖分,将询问分成logloglog段
先预处理出uji,j,dji,juj_{i,j},dj_{i,j}uji,j,dji,j,分别是当前枚举到i点,点权为wiw_iwi,在该重链上向上/下跳2j2^j2j个权值跳到的点
那么找到一个起始点后,就可以对这条重链进行查询了
对于起始点,可以对每个权值的点建一个set,因为一条重链上的点dfs序是连在一起的,所以在set上找dfs序大于等于或小于等于该重链起始节点的点即可
代码
#include<set>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
#define N 200021
using namespace std;
int n, m, x, y, q, cc, ww, tot;
int c[N], w[N], v[N], p[N], fa[N], sz[N], bv[N], hs[N];
int uj[N][20], dj[N][20], dep[N], top[N], head[N];
set<int>gm[N];
struct rec
{int to, next;
}a[N<<1];
int read()
{char x=getchar();int d=1,l=0;while (x<'0'||x>'9') {if (x=='-') d=-1;x=getchar();}while (x>='0'&&x<='9') l=(l<<3)+(l<<1)+x-48,x=getchar();return l*d;
}
void add(int x, int y)
{a[++tot].to = y;a[tot].next = head[x];head[x] = tot;return;
}
void dfs1(int x)
{sz[x] = 1;for (int i = head[x]; i; i = a[i].next)if (a[i].to != fa[x]){fa[a[i].to] = x;dep[a[i].to] = dep[x] + 1;dfs1(a[i].to);sz[x] += sz[a[i].to];if (sz[a[i].to] > sz[hs[x]]) hs[x] = a[i].to;}return;
}
void dfs2(int x, int anc)
{v[x] = ++ww;bv[ww] = x;top[x] = anc;if (top[p[w[x] + 1]] == top[x]) uj[x][0] = p[w[x] + 1];p[w[x]] = x; for (int i = 1; i <= cc; ++i)uj[x][i] = uj[uj[x][i - 1]][i - 1];//求ujif (hs[x]) dfs2(hs[x], anc);for (int i = head[x]; i; i = a[i].next)if (a[i].to != fa[x] && a[i].to != hs[x])dfs2(a[i].to, a[i].to);return;
}
void dfs3(int x)
{for (int i = head[x]; i; i = a[i].next)if (a[i].to != fa[x] && a[i].to != hs[x])dfs3(a[i].to);if (hs[x]) dfs3(hs[x]);if (top[p[w[x] + 1]] == top[x]) dj[x][0] = p[w[x] + 1];//求djp[w[x]] = x;for (int i = 1; i <= cc; ++i)dj[x][i] = dj[dj[x][i - 1]][i - 1];return;
}
int lca(int x, int y)
{while(top[x] != top[y]){if (dep[top[x]] < dep[top[y]]) swap(x, y);x = fa[top[x]];}if (dep[x] > dep[y]) swap(x, y);return x;
}
int fu(int x, int y, int now)
{int g = bv[*--gm[now + 1].upper_bound(v[x])];//找一个小于等于该点的if (v[g] <= v[x] && w[g] == now + 1 && top[x] == top[g] && dep[g] >= dep[y] && g){now++;for (int i = cc; i >= 0; --i)if (top[x] == top[uj[g][i]] && dep[uj[g][i]] >= dep[y] && uj[g][i])g = uj[g][i], now += 1<<i;}if (dep[fa[top[x]]] >= dep[y] && top[x] != 1) return fu(fa[top[x]], y, now);else return now;
}
int fd(int x, int y, int now)
{if (dep[fa[top[x]]] >= dep[y] && top[x] != 1) now = fd(fa[top[x]], y, now);//递归处理int gg = top[x], g;if (dep[gg] < dep[y]) gg = y;g = bv[*gm[now + 1].lower_bound(v[gg])];if (v[g] >= v[gg] && w[g] == now + 1 && top[x] == top[g] && dep[g] <= dep[x] && g){now++;for (int i = cc; i >= 0; --i)if (top[x] == top[dj[g][i]] && dep[dj[g][i]] <= dep[x] && dj[g][i])g = dj[g][i], now += 1<<i;}return now;
}
int main()
{scanf("%d%d%d", &n, &cc, &m);cc = log2(cc);for (int i = 1; i <= m; ++i){scanf("%d", &x);c[x] = i;}for (int i = 1; i <= n; ++i){scanf("%d", &x);w[i] = c[x]; }for (int i = 1; i < n; ++i){scanf("%d%d", &x, &y);add(x, y);add(y, x);}dep[1] = fa[1] = 1;dfs1(1);dfs2(1, 1);for (int i = 1; i <= n; ++i)gm[w[i]].insert(v[i]);memset(p, 0, sizeof(p));dfs3(1);scanf("%d", &q);while(q--){scanf("%d%d", &x, &y);int z = lca(x, y), g;g = fu(x, z, 0);g = fd(y, z, g);printf("%d\n", g);}return 0;
}