传送门
文章目录
- 题意:
- 思路:
题意:
求一棵树的每对叶子节点之间距离平方的和。
思路:
这个题貌似可以容斥,但是我不会,所以我写了个淀粉质。
要知道,淀粉质的思想就是将子树内部的递归处理,当前这层处理不同子树之间的距离即可,考虑化简式子分别求贡献。
假设(ai+aj)2(a_i+a_j)^2(ai+aj)2为两点间距离平方和,ai,aja_i,a_jai,aj为叶子到当前找的伪重心的距离,把式子化简出来就是ai2+aj2+2∗ai∗aja_i^2+a_j^2+2*a_i*a_jai2+aj2+2∗ai∗aj,考虑每一块的贡献。
我们设当前遍历的子树的叶子距离为aia_iai,之前遍历过的子树的叶子距离为aja_jaj,个数为cntcntcnt个,平方和为sum1sum1sum1,和为sum2sum2sum2。算这个子树和之前遍历过的子树信息的时候,ai2a_i^2ai2的贡献是cnt∗ai2cnt*a_i^2cnt∗ai2,bi2b_i^2bi2就是sum1sum1sum1,2∗ai∗aj2*a_i*a_j2∗ai∗aj的贡献为2∗ai∗sum22*a_i*sum22∗ai∗sum2,这样我们就可以分开统计贡献,淀粉质板子套一下就好啦。
//#pragma GCC optimize(2)
#include<cstdio>
#include<iostream>
#include<string>
#include<cstring>
#include<map>
#include<cmath>
#include<cctype>
#include<vector>
#include<set>
#include<queue>
#include<algorithm>
#include<sstream>
#include<ctime>
#include<cstdlib>
#define X first
#define Y second
#define L (u<<1)
#define R (u<<1|1)
#define pb push_back
#define mk make_pair
#define Mid (tr[u].l+tr[u].r>>1)
#define Len(u) (tr[u].r-tr[u].l+1)
#define random(a,b) ((a)+rand()%((b)-(a)+1))
#define db puts("---")
using namespace std;//void rd_cre() { freopen("d://dp//data.txt","w",stdout); srand(time(NULL)); }
//void rd_ac() { freopen("d://dp//data.txt","r",stdin); freopen("d://dp//AC.txt","w",stdout); }
//void rd_wa() { freopen("d://dp//data.txt","r",stdin); freopen("d://dp//WA.txt","w",stdout); }typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;const int N=100010,M=N*2,mod=1e9+7,INF=0x3f3f3f3f;
const double eps=1e-6;int n,m;
int h[N],e[M],ne[M],w[M],idx;
LL p[N],q[N];
int d[N];
bool st[N];void add(int a,int b,int c)
{e[idx]=b,w[idx]=c,ne[idx]=h[a],h[a]=idx++;
}int get_wc(int u,int f,int tot,int &wc,int &mi)
{if(st[u]) return 0;int sum=1,mx=0;for(int i=h[u];~i;i=ne[i]){int j=e[i];if(j==f) continue;int t=get_wc(j,u,tot,wc,mi);mx=max(mx,t); sum+=t;}mx=max(mx,tot-sum);if(mx<mi) wc=u,mi=mx;return sum;
}int get_size(int u,int f)
{if(st[u]) return 0;int sum=1;for(int i=h[u];~i;i=ne[i])if(e[i]!=f)sum+=get_size(e[i],u);return sum;
}void get_dis(int u,int f,int dis,int &qt)
{if(st[u]) return ;int cnt=0;for(int i=h[u];~i;i=ne[i]){int j=e[i];if(j==f) continue;cnt++;get_dis(j,u,dis+w[i],qt);}if(d[u]==1) q[qt++]=dis;
}LL cal(int u)
{if(st[u]) return 0;LL ans=0; int tt=INF;get_wc(u,-1,get_size(u,-1),u,tt);st[u]=true;LL pt=0,pre=0,cnt=0;for(int i=h[u];~i;i=ne[i]){int j=e[i],qt=0;get_dis(j,-1,w[i],qt);for(int k=0;k<qt;k++) ans+=1ll*q[k]*q[k]*cnt,ans+=1ll*2*pt*q[k],ans+=pre;for(int k=0;k<qt;k++) pt+=q[k],pre+=1ll*q[k]*q[k],cnt++;}for(int i=h[u];~i;i=ne[i]) ans+=cal(e[i]);return ans;
}int main()
{
// ios::sync_with_stdio(false);
// cin.tie(0);scanf("%d",&n); memset(h,-1,sizeof(h));memset(st,false,sizeof(st));for(int i=1;i<=n-1;i++){int a,b,c; scanf("%d%d%d",&a,&b,&c);add(a,b,c); add(b,a,c);d[a]++; d[b]++;}printf("%lld\n",cal(1));return 0;
}
/**/