题意:给一棵 nnn 个点的树,求两两距离相等的三元组个数。
n≤105n\leq 10^5n≤105
显然相当于是找一个点到这三个点距离相等。子树内和子树外到当前点的距离为某个值的点的个数可以长链剖分快速得到,但统计答案非常棘手。
接下来是个鬼才想得到的 dp:
f(u,x)f(u,x)f(u,x) 为 uuu 子树内距离 uuu 距离为 xxx 的点的个数,这个比较好求。
g(u,x)g(u,x)g(u,x) 表示假想 uuu 子树外有一个距离为 xxx 的点,在子树内选两个点与这个假想点形成合法三元组的方案数。
鸟语翻译:在 uuu 子树内找两个深度相同的点,若它们到它们的 lca (不一定是 uuu)距离为 ddd ,那么这个 lca 到 uuu 的距离必须是 u−xu-xu−x,求符合条件的两个点的对数。
也就是我们把这个 ggg 当接口,乘一个子树外为 xxx 的点的个数,就可以直接得到某个范围内的合法三元组个数。
考虑加入一个子树 vvv,fff 的转移很显然
f(u,x)⟵f(u,x)+f(v,x−1)f(u,x)\longleftarrow f(u,x)+f(v,x-1)f(u,x)⟵f(u,x)+f(v,x−1)
考虑 ggg 的转移
- 两个都在 uuu 这边: g(u,x)⟵g(u,x)g(u,x)\longleftarrow g(u,x)g(u,x)⟵g(u,x)
- 两个都在 vvv 这边: g(u,x)⟵g(v,x+1)g(u,x)\longleftarrow g(v,x+1)g(u,x)⟵g(v,x+1)
- 一个 uuu 一个 vvv,此时的 lca 一定是 uuu:g(u,x)⟵f(u,x)f(v,x−1)g(u,x)\longleftarrow f(u,x)f(v,x-1)g(u,x)⟵f(u,x)f(v,x−1)
综上
g(u,x)⟵g(u,x)+g(v,x+1)+f(u,x)f(v,x−1)g(u,x)\longleftarrow g(u,x)+g(v,x+1)+f(u,x)f(v,x-1)g(u,x)⟵g(u,x)+g(v,x+1)+f(u,x)f(v,x−1)
然后在 dp 的时候顺便算答案
ans⟵ans+f(u,x)g(v,x+1)+f(v,x)g(u,x+1)ans\longleftarrow ans+f(u,x)g(v,x+1)+f(v,x)g(u,x+1)ans⟵ans+f(u,x)g(v,x+1)+f(v,x)g(u,x+1)
长链剖分优化, fff 往后继承,ggg 往前继承,轻儿子暴力转移。继承的时候 uuu 是空的,只有 f(u,0)g(sonu,1)f(u,0)g(son_u,1)f(u,0)g(sonu,1) 会影响答案,之后对答案的贡献就可以暴力算了。
注意 ggg 要开两倍在中间取指针,前面用来继承后面用来存信息。
复杂度 O(n)\Omicron(n)O(n)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#define MAXN 100005
using namespace std;
inline int read()
{int ans=0;char c=getchar();while (!isdigit(c)) c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return ans;
}
typedef long long ll;
vector<int> e[MAXN];
ll buf[MAXN<<3],*cur=buf;
inline ll* newbuf(int x){ll* p=cur;cur+=x;return p;}
int fa[MAXN],mx[MAXN],son[MAXN];
void dfs(int u,int f)
{fa[u]=f;for (int i=0;i<(int)e[u].size();i++)if (e[u][i]!=f){dfs(e[u][i],u);if (mx[e[u][i]]>mx[son[u]]) son[u]=e[u][i];}mx[u]=mx[son[u]]+1;
}
ll *f[MAXN],*g[MAXN];
ll ans;
void dfs(int u)
{*f[u]=1;if (son[u]) f[son[u]]=f[u]+1,g[son[u]]=g[u]-1,dfs(son[u]),ans+=*g[u];;for (int i=0;i<(int)e[u].size();i++)if (e[u][i]!=fa[u]&&e[u][i]!=son[u]){int v=e[u][i];f[v]=newbuf(mx[v]),g[v]=newbuf(2*mx[v])+mx[v]-1;dfs(v);for (int j=1;j<=mx[v];j++)ans+=f[u][j-1]*g[v][j];for (int j=0;j<=mx[v]&&j<mx[u];j++)ans+=g[u][j+1]*f[v][j];for (int j=1;j<=mx[v];j++) g[u][j-1]+=g[v][j],g[u][j]+=f[u][j]*f[v][j-1];for (int j=0;j<=mx[v]&&j<mx[u];j++)f[u][j+1]+=f[v][j];}
// printf("f[%d]:",u);
// for (int i=0;i<=mx[u];i++) printf("%lld ",f[u][i]);
// puts("");
// printf("g[%d]:",u);
// for (int i=0;i<=mx[u];i++) printf("%lld ",g[u][i]);
// puts("");
}
int main()
{
// freopen("test.in","r",stdin);int n=read();for (int i=1;i<n;i++){int u,v;u=read(),v=read();e[u].push_back(v),e[v].push_back(u); } dfs(1,0),f[1]=newbuf(mx[1]),g[1]=newbuf(2*mx[1])+mx[1]-1,dfs(1);cout<<ans;return 0;
}