牛客练习赛55E树
题意:
你有一颗大小为n 的树,点从 1 到 n 标号。
设dis(x,y)表示 x 到 y 的距离。
求∑i=1n∑j=1ndis2(i,j)\sum_{i=1}^{n}\sum_{j=1}^{n}dis^2(i,j)∑i=1n∑j=1ndis2(i,j)对998244353取模的结果
题解:
我们以1为根,设dep[i]表示第i个点的深度
dis(x,y)=dep[x]+dep[y]−2dep[lca(x,y)]dis(x,y)=dep[x]+dep[y]-2dep[lca(x,y)]dis(x,y)=dep[x]+dep[y]−2dep[lca(x,y)]
所以
将dis(i,j)展开
dis2(i,j)=dep2[x]+dep2[y]+2∗dep[x]∗dep[y]+4∗dep2[lca(x,y)]−4∗(dep[x]+dep[y])∗dep[lca(x,y)]dis^2(i,j)=dep^2[x]+dep^2[y]+2*dep[x]*dep[y]+4*dep^2[lca(x,y)]-4*(dep[x]+dep[y])*dep[lca(x,y)]dis2(i,j)=dep2[x]+dep2[y]+2∗dep[x]∗dep[y]+4∗dep2[lca(x,y)]−4∗(dep[x]+dep[y])∗dep[lca(x,y)]
对于前两项dep2[x]+dep2[y]dep^2[x]+dep^2[y]dep2[x]+dep2[y]:因为x和y都是枚举1~n,所以就是求2dep[i]∗dep[i],1<=i<=n2dep[i]*dep[i],1<=i<=n2dep[i]∗dep[i],1<=i<=n
对于第三项:2∗dep[x]∗dep[y]2*dep[x]*dep[y]2∗dep[x]∗dep[y]:先求出maxx=∑i=1ndep[i]maxx=\sum_{i=1}^{n}dep[i]maxx=∑i=1ndep[i],然后用∑i=1nmaxx∗dep[i]∗2\sum_{i=1}^{n}maxx*dep[i]*2∑i=1nmaxx∗dep[i]∗2
对于后两部分,我们需要计算lca(x,y)=i的(x,y)这样的数对个数,以及dep[x]+dep[y]之和
先解答第一个:数对个数为:size2[i]−∑j∈son[i]size2[j]size^2[i]-\sum_{j∈son[i]}size^2[j]size2[i]−∑j∈son[i]size2[j].相当于整个子树内的任意两个点组合,这样会重复,儿子j子树内的点会重复,要减去
第二个:
设sumisum_{i}sumi表示i这个子树的dep[x]之和
那么dep[x]+dep[y]之和为:2∑j∈son[i]sumj∗(size[i]−size[j])+2dep[i]∗size[i]2\sum_{j∈son[i]}sum_{j}*(size[i]-size[j])+2dep[i]*size[i]2∑j∈son[i]sumj∗(size[i]−size[j])+2dep[i]∗size[i]。以y为根的子树内的点与会除y为根的子树外的点(都在以x的根的子树内)相互组队,同时所有点都可以与点x组队
详细可以看代码
代码:
敲错一个变量名,调了半小时,老演员了
#include <bits/stdc++.h>
#include <unordered_map>
#define debug(a, b) printf("%s = %d\n", a, b);
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
clock_t startTime, endTime;
//Fe~Jozky
const ll INF_ll= 1e18;
const int INF_int= 0x3f3f3f3f;
void read(){};
template <typename _Tp, typename... _Tps> void read(_Tp& x, _Tps&... Ar)
{x= 0;char c= getchar();bool flag= 0;while (c < '0' || c > '9')flag|= (c == '-'), c= getchar();while (c >= '0' && c <= '9')x= (x << 3) + (x << 1) + (c ^ 48), c= getchar();if (flag)x= -x;read(Ar...);
}
template <typename T> inline void write(T x)
{if (x < 0) {x= ~(x - 1);putchar('-');}if (x > 9)write(x / 10);putchar(x % 10 + '0');
}
void rd_test()
{
#ifdef ONLINE_JUDGE
#elsestartTime = clock ();freopen("data.in", "r", stdin);
#endif
}
void Time_test()
{
#ifdef ONLINE_JUDGE
#elseendTime= clock();printf("\nRun Time:%lfs\n", (double)(endTime - startTime) / CLOCKS_PER_SEC);
#endif
}
const int maxn=2e6+9;
vector<int>vec[maxn];
ll siz[maxn];
ll sum=0;
ll dep[maxn];
ll num[maxn];
const int mod=998244353;
void dfs(int u,int fa){siz[u]=1;dep[u]=dep[fa]+1;num[u]=dep[u];for(auto v:vec[u]){if(v==fa)continue;dfs(v,u);num[u]=(num[u]+num[v])%mod;siz[u]+=siz[v];}
} void solve(int u,int fa){ll sum1=siz[u]*siz[u]%mod;//lca(x,y)=i的(x,y)这样的数对个数ll sum2=2*dep[u]*siz[u]%mod;//fx+fy for(auto v:vec[u]){if(v==fa)continue;solve(v,u);sum1=(sum1-siz[v]*siz[v]+mod)%mod;sum2=(sum2+2*num[v]*(siz[u]-siz[v]))%mod;}sum=(sum+4*sum1%mod*dep[u]%mod*dep[u])%mod;sum=((sum+mod-4*sum2%mod*dep[u]%mod)%mod+mod)%mod;
}signed main()
{rd_test();int n;read(n);for(int i=1;i<n;i++){int u,v;read(u,v);vec[u].push_back(v);vec[v].push_back(u);}dfs(1,0);ll maxx=0; for(int i=1;i<=n;i++){maxx=(maxx+dep[i])%mod;//第三部分 sum=(sum+(dep[i]*dep[i])%mod*2ll*n)%mod;//第一,二部分 }for(int i=1;i<=n;i++){sum=(sum+(maxx*dep[i])%mod*2ll)%mod;// 第三部分 }
// cout<<(sum%mod+mod)%mod<<endl;solve(1,0);cout<<(sum%mod+mod)%mod<<endl;//Time_test();
}