题意:给三棵基于同一点集的带边权的树,边权非负,求两点间三棵树上距离之和的最大值。
n≤105n\leq 10^5n≤105
一句话题解:在第一棵树上做边分治,丢到第二棵树上建虚树,在虚树上根据第三棵树的直径dp。
首先,这个问题难搞的地方只在于需要统计 lca。所以我们做的一切工作都是为了搞掉 lca 。
在求最大值的问题上有以下搞法:
- 树分治
- 枚举 lca 统计贡献。或者说就是 dp。
把这两个合起来就可以做 暴力写挂 了。但这个问题还多了一棵树,还要再用一个方法。
哲学分析,如果没有特殊性质,剩下这个方法肯定不会太弱于点分治,然后你就可以弃疗了。仔细观察,这题唯一有的特殊条件就只有藏在数据范围里的边权非负了。
边权非负的时候直径是可以合并的,所以第三个方法就是利用直径。
整理一下,在第一棵树上点分治,通过到分治中心的距离之和代替距离,搞掉第一棵树的 lca。在第二棵树的虚树上 dp,只在 lca 处更新答案,搞掉第二棵树的 lca。在 dp 时记录虚树的子树中两种颜色的点集在第三棵树上的直径,枚举端点合并,直接处理第三棵树上的距离。
设分治中心边的边权为 www,第一棵树上的点到分治中心的距离为 disdisdis,第二棵树上的点的深度为 depdepdep,第三棵树上两点距离为 distdistdist,我们相当于求这个东西的最大值
w+disa+disb+depa+depb−2deplca(a,b)+dist(a,b)w+dis_a+dis_b+dep_a+dep_b-2dep_{\operatorname{lca}(a,b)}+dist(a,b)w+disa+disb+depa+depb−2deplca(a,b)+dist(a,b)
我们假装在第三棵树上每个点挂两个叶子,边权为 disa+depadis_a+dep_adisa+depa,所以直径可合并的结论在端点有权值时也是成立的。
复杂度是 O(nlogn)O(n\log n)O(nlogn)。
人生最长代码,不过很多复制粘贴,不算难写。
要注意虚树上的虚点不属于任何颜色。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <algorithm>
#define MAXN 200005
#define MAXM 400005
using namespace std;
typedef long long ll;
const int INF=0x7fffffff;
inline int read()
{int ans=0;char c=getchar();while (!isdigit(c)) c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return ans;
}
inline ll readll()
{ll ans=0;char c=getchar();while (!isdigit(c)) c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return ans;
}
struct edge{int u,v;ll w;}e[MAXM];
int head[MAXN],nxt[MAXM],cnt=1;
inline void addnode(int u,int v,ll w)
{e[++cnt]=(edge){u,v,w};nxt[cnt]=head[u];head[u]=cnt;
}
vector<edge> E[MAXN];
int vis[MAXN],n,tot;
void dfs(int u)
{vis[u]=1;if ((int)E[u].size()<=3){for (int i=0;i<(int)E[u].size();i++){int v=E[u][i].v;ll w=E[u][i].w;if (!vis[v])dfs(v),addnode(u,v,w),addnode(v,u,w); }return;}int cur[2]={++tot,++tot},pos=0;addnode(u,cur[0],0),addnode(cur[0],u,0);addnode(u,cur[1],0),addnode(cur[1],u,0);for (int i=0;i<(int)E[u].size();i++)if (!vis[E[u][i].v])E[cur[pos^=1]].push_back(E[u][i]);dfs(cur[0]),dfs(cur[1]);
}
int rt,mi,siz[MAXN];
void findrt(int u,int f,int sum)
{siz[u]=1;for (int i=head[u];i;i=nxt[i])if (!vis[i>>1]&&e[i].v!=f){findrt(e[i].v,u,sum);if (max(siz[e[i].v],sum-siz[e[i].v])<mi) mi=max(siz[e[i].v],sum-siz[e[i].v]),rt=i;siz[u]+=siz[e[i].v];}
}
int LOG[MAXM];
namespace FT
{edge e[MAXM];int head[MAXN],nxt[MAXM],cnt;inline void addnode(int u,int v,ll w){e[++cnt]=(edge){u,v,w};nxt[cnt]=head[u];head[u]=cnt;}ll dis[MAXN],val[MAXN];int dfn[MAXN],lis[MAXM],tim;inline bool cmp(const int& x,const int& y){return dfn[x]<dfn[y];}void dfs(int u,int f){lis[dfn[u]=++tim]=u;for (int i=head[u];i;i=nxt[i])if (e[i].v!=f){dis[e[i].v]=dis[u]+e[i].w;dfs(e[i].v,u);lis[++tim]=u;}}int st[MAXM][20];inline int lca(int x,int y){x=dfn[x],y=dfn[y];if (x>y) swap(x,y);int t=LOG[y-x+1];return min(st[x][t],st[y-(1<<t)+1][t],cmp);}inline ll dist(int x,int y){return x&&y? dis[x]+dis[y]+val[x]+val[y]-2*dis[lca(x,y)]:-1;}void input(){for (int i=1;i<n;i++){int u,v;ll w;u=read(),v=read(),w=readll();addnode(u,v,w),addnode(v,u,w);}dfs(1,0);for (int i=1;i<=tim;i++) st[i][0]=lis[i];for (int j=1;j<20;j++)for (int i=1;i+(1<<(j-1))<=tim;i++)st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1],cmp);}
}
namespace VT
{edge e[MAXM];int head[MAXN],nxt[MAXM],cnt;inline void addnode(int u,int v,ll w){e[++cnt]=(edge){u,v,w};nxt[cnt]=head[u];head[u]=cnt;}ll dis[MAXN];int dfn[MAXN],lis[MAXM],tim;inline bool cmp(const int& x,const int& y){return dfn[x]<dfn[y];}void dfs(int u,int f){lis[dfn[u]=++tim]=u;for (int i=head[u];i;i=nxt[i])if (e[i].v!=f){dis[e[i].v]=dis[u]+e[i].w;dfs(e[i].v,u);lis[++tim]=u;}}int st[MAXM][20],col[MAXN];inline int lca(int x,int y){x=dfn[x],y=dfn[y];if (x>y) swap(x,y);int t=LOG[y-x+1];return min(st[x][t],st[y-(1<<t)+1][t],cmp);}void input(){memset(col,-1,sizeof(col));for (int i=1;i<n;i++){int u,v;ll w;u=read(),v=read(),w=readll();addnode(u,v,w),addnode(v,u,w);}dfs(1,0);for (int i=1;i<=tim;i++) st[i][0]=lis[i];for (int j=1;j<20;j++)for (int i=1;i+(1<<(j-1))<=tim;i++)st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1],cmp);}vector<int> p,son[MAXN];struct path{int x,y;inline path(const int& x=0,const int& y=0):x(x),y(y){}};inline ll calc(const path& a){return FT::dist(a.x,a.y);}inline bool operator <(const path& a,const path& b){return calc(a)<calc(b);}int stk[MAXN],tp;path mx[MAXN][2];ll ans;void dfs(int u){mx[u][0]=mx[u][1]=path(0,0);if (~col[u]) mx[u][col[u]]=path(u,u);for (int i=0;i<(int)son[u].size();i++){int v=son[u][i];dfs(v);path t;t=max(t,path(mx[u][0].x,mx[v][1].x));t=max(t,path(mx[u][0].x,mx[v][1].y));t=max(t,path(mx[u][0].y,mx[v][1].x));t=max(t,path(mx[u][0].y,mx[v][1].y));t=max(t,path(mx[u][1].x,mx[v][0].x));t=max(t,path(mx[u][1].x,mx[v][0].y));t=max(t,path(mx[u][1].y,mx[v][0].x));t=max(t,path(mx[u][1].y,mx[v][0].y));ans=max(ans,calc(t)-2*dis[u]);t=max(mx[u][0],mx[v][0]);t=max(t,path(mx[u][0].x,mx[v][0].x));t=max(t,path(mx[u][0].x,mx[v][0].y));t=max(t,path(mx[u][0].y,mx[v][0].x));t=max(t,path(mx[u][0].y,mx[v][0].y));mx[u][0]=t;t=max(mx[u][1],mx[v][1]);t=max(t,path(mx[u][1].x,mx[v][1].x));t=max(t,path(mx[u][1].x,mx[v][1].y));t=max(t,path(mx[u][1].y,mx[v][1].x));t=max(t,path(mx[u][1].y,mx[v][1].y));mx[u][1]=t;}}ll solve(){sort(p.begin(),p.end(),cmp);int s=p.size();for (int i=0;i<s-1;i++) p.push_back(lca(p[i],p[i+1]));sort(p.begin(),p.end(),cmp);p.erase(unique(p.begin(),p.end()),p.end());tp=0;for (int i=0;i<(int)p.size();i++){while (tp&&lca(stk[tp],p[i])!=stk[tp]) --tp;if (tp) son[stk[tp]].push_back(p[i]);stk[++tp]=p[i];}ans=0;dfs(stk[1]);for (int i=0;i<(int)p.size();i++) son[p[i]].clear(),col[p[i]]=-1;p.clear();return ans;}
}
void dfs(int u,int f,int c,ll d)
{if (u<=n) FT::val[u]=d+VT::dis[u],VT::p.push_back(u),VT::col[u]=c;for (int i=head[u];i;i=nxt[i])if (!vis[i>>1]&&e[i].v!=f)dfs(e[i].v,u,c,d+e[i].w);
}
ll ans;
void solve(int sum)
{if (mi==INF) return;vis[rt>>1]=1;dfs(e[rt].v,0,0,0);dfs(e[rt].u,0,1,0); ans=max(ans,VT::solve()+e[rt].w);int cur=rt,sz=siz[e[rt].v];mi=INF,findrt(e[cur].v,0,sz),solve(sz);mi=INF,findrt(e[cur].u,0,sum-sz),solve(sum-sz);
}
int main()
{LOG[0]=-1;for (int i=1;i<MAXM;i++) LOG[i]=LOG[i>>1]+1;tot=n=read();for (int i=1;i<n;i++){int u,v;ll w;u=read(),v=read(),w=readll();E[u].push_back((edge){u,v,w}),E[v].push_back((edge){v,u,w});}dfs(1);memset(vis,0,sizeof(vis));VT::input();FT::input();mi=INF,findrt(1,0,tot),solve(tot);cout<<ans;return 0;
}