A and B and Lecture Rooms
题意要求我们找有多少个点iii满足dis(i,x),dis(i,y)dis(i, x), dis(i, y)dis(i,x),dis(i,y),输出点iii的数量即可。
首先特判无解的情况就是dis(x,y)dis(x, y)dis(x,y)为奇数时,接下来我们讨论有解的情况,大致分为两类。
首先我们一定可以在x−>yx->yx−>y的路径上找到一个点满足要求。
- 这个点不在lca(x,y)lca(x, y)lca(x,y)上:
如图我们要找的是(5, 6)的满足要求的点有多少个,lca(5,6)=7lca(5, 6) = 7lca(5,6)=7
显然3是其路径上的一个满足要求的点,因为5号节点是从3号节点的父亲连过来的,所以3号节点的父节点往上的节点均不满足要求。
同样的6号节点是在3号节点的某一棵子树上,所以3号节点要舍弃以4号节点为根节点的子树,
所以这种情况就变成了,sz[3]−sz[4]sz[3] - sz[4]sz[3]−sz[4],显然我们可以得到x,yx, yx,y路径上的中点记为uuu,x,yx, yx,y中深度更大的节点xxx一定在uuu的子树上,所以uuu的某个儿子vvv的子树包含xxx节点要舍弃,
所以答案就是sz[u]−sz[v]sz[u] - sz[v]sz[u]−sz[v]。
- 这个点在lca(x,y)lca(x, y)lca(x,y)上
这个情况比上面就简单了,x,yx, yx,y一定都在lcalcalca的某两个不同的儿子上,
所以找到包含xxx的儿子uuu,和包含yyy的儿子vvv,然后n−sz[u]−sz[v]n - sz[u] - sz[v]n−sz[u]−sz[v]即为答案。
最后特判一下x==yx == yx==y的情况即可。
/*Author : lifehappy
*/
#include <bits/stdc++.h>using namespace std;const int N = 1e6 + 10;int head[N], to[N], nex[N], cnt = 1;int fa[N], top[N], son[N], sz[N], dep[N], id[N], rk[N], tot;int n, m;void add(int x, int y) {to[cnt] = y;nex[cnt] = head[x];head[x] = cnt++;
}void dfs1(int rt, int f) {fa[rt] = f, dep[rt] = dep[f] + 1;sz[rt] = 1;for(int i = head[rt]; i; i = nex[i]) {if(to[i] == f) continue;dfs1(to[i], rt);sz[rt] += sz[to[i]];if(!son[rt] || sz[son[rt]] < sz[to[i]]) son[rt] = to[i];}
}void dfs2(int rt, int tp) {top[rt] = tp;rk[++tot] = rt;id[rt] = tot;if(!son[rt]) return ;dfs2(son[rt], tp);for(int i = head[rt]; i; i = nex[i]) {if(to[i] == fa[rt] || to[i] == son[rt]) continue;dfs2(to[i], to[i]);}
}int lca(int x, int y) {while(top[x] != top[y]) {if(dep[top[x]] < dep[top[y]]) swap(x, y);x = fa[top[x]];}return dep[x] < dep[y] ? x : y;
}int dis(int x, int y) {return dep[x] + dep[y] - 2 * dep[lca(x, y)];
}int get_fa(int x, int k) {while(k > id[x] - id[top[x]]) {k -= id[x] - id[top[x]] + 1;x = fa[top[x]];}return rk[id[x] - k];
}int main() {// freopen("in.txt", "r", stdin);// freopen("out.txt", "w", stdout);// ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);scanf("%d", &n);for(int i = 1; i < n; i++) {int x, y;scanf("%d %d", &x, &y);add(x, y);add(y, x);}dfs1(1, 0);dfs2(1, 1);scanf("%d", &m);for(int i = 1; i <= m; i++) {int x, y;scanf("%d %d", &x, &y);if(x == y) {printf("%d\n", n);continue;}int d = dis(x, y), l = lca(x, y);if(d & 1) {puts("0");continue;}if(dep[x] < dep[y]) swap(x, y);int p = get_fa(x, d / 2);if(p == l) {int u = get_fa(x, d / 2 - 1), v = get_fa(y, d / 2 - 1);printf("%d\n", n - sz[u] - sz[v]);}else {int u = get_fa(x, d / 2 - 1);printf("%d\n", sz[p] - sz[u]);}}return 0;
}