题意:给两棵基于同一点集的带边权树,记 lca(x,y),depth(x)\operatorname{lca}(x,y),\operatorname{depth}(x)lca(x,y),depth(x) 为第一棵树上的 lca、到根的边长度之和,lca′(x,y),depth′(x)\operatorname{lca}'(x,y),\operatorname{depth}'(x)lca′(x,y),depth′(x) 为第二棵树的,最大化
depth(x)+depth(y)−(depth(lca(x,y))+depth′(lca′(x,y)))\operatorname{depth}(x)+\operatorname{depth}(y)-(\operatorname{depth(\operatorname{lca}(x,y))}+\operatorname{depth}'(\operatorname{lca}'(x,y)))depth(x)+depth(y)−(depth(lca(x,y))+depth′(lca′(x,y)))
n≤366666n\leq 366666n≤366666
这个式子非常诡异,先推一下发现等于这个
12(dist(x,y)+depth(x)+depth(y)−depth′(lca′(x,y)))\frac 12(\operatorname{dist}(x,y)+\operatorname{depth}(x)+\operatorname{depth}(y)-\operatorname{depth}'(\operatorname{lca}'(x,y)))21(dist(x,y)+depth(x)+depth(y)−depth′(lca′(x,y)))
左边是个距离,而右边只有一个二元函数,考虑对第一棵树分治
我们用点分治或边分治可以把 dist(x,y)\operatorname{dist}(x,y)dist(x,y) 拆成两项分别只与 xxx 和 yyy 有关的东西,就可以和 depth\operatorname{depth}depth 合并。现在的问题时怎么处理右边的东西。
不管是点分治还是边分治,每次计算时都有两个点集 S,TS,TS,T,要统计所有 x∈S,y∈Tx\in S,y\in Tx∈S,y∈T 的贡献。
考虑虚树。在第二棵树上用之前的代价标记 S,TS,TS,T 中的点,然后建出虚树,维护子树内两种集合中的权值最大值,在 lca\operatorname{lca}lca 处统计贡献。
这样复杂度是 O(SlogS)\Omicron(S\log S)O(SlogS),其中 SSS 为两个集合的点集大小。所以只能边分治。
用欧拉序做 O(nlogn)−O(1)\Omicron(n\log n)-\Omicron(1)O(nlogn)−O(1) lca\operatorname{lca}lca,总复杂度可以做到严格 O(nlogn)\Omicron(n\log n)O(nlogn)
码量虽大但没什么细节,还是比较好写的。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <algorithm>
#define MAXN 1000005
#define MAXM 2000005
using namespace std;
typedef long long ll;
const ll INF=1e18;
inline int read()
{int ans=0,f=1;char c=getchar();while (!isdigit(c)) (c=='-')&&(f=-1),c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return f*ans;
}
struct edge{int u,v,w;}e[MAXM];
int head[MAXN],nxt[MAXM],cnt=1;
inline void addnode(int u,int v,int w)
{e[++cnt]=(edge){u,v,w};nxt[cnt]=head[u];head[u]=cnt;
}
vector<edge> E[MAXN];
ll dis[MAXN];
int vis[MAXM],n,tot;
void dfs(int u)
{vis[u]=1;if ((int)E[u].size()<=3){for (int i=0;i<(int)E[u].size();i++){int v=E[u][i].v,w=E[u][i].w;if (vis[v]) continue;dfs(v),addnode(u,v,w),addnode(v,u,w);}return; }int cur[2]={++tot,++tot},pos=0;addnode(u,cur[0],0),addnode(cur[0],u,0);addnode(u,cur[1],0),addnode(cur[1],u,0);for (int i=0;i<(int)E[u].size();i++){int v=E[u][i].v,w=E[u][i].w;if (vis[v]) continue;E[cur[pos]].push_back((edge){cur[pos],v,w}),pos^=1;}dfs(cur[0]),dfs(cur[1]);
}
void dfs(int u,int f)
{for (int i=head[u];i;i=nxt[i])if (e[i].v!=f)dis[e[i].v]=dis[u]+e[i].w,dfs(e[i].v,u);
}
int rt,siz[MAXN];
ll mn;
void findrt(int u,int f,int sum)
{siz[u]=1;for (int i=head[u];i;i=nxt[i])if (!vis[i>>1]&&e[i].v!=f){findrt(e[i].v,u,sum);if (max(siz[e[i].v],sum-siz[e[i].v])<mn)mn=max(siz[e[i].v],sum-siz[e[i].v]),rt=i;siz[u]+=siz[e[i].v];}
}
namespace VT
{edge e[MAXM];int head[MAXN],nxt[MAXM],cnt;inline void addnode(int u,int v,int w){e[++cnt]=(edge){u,v,w};nxt[cnt]=head[u];head[u]=cnt;}int dfn[MAXN],lis[MAXM],LOG[MAXM],st[MAXM][21],tim;ll dis[MAXN];inline bool cmp(const int& x,const int& y){return dfn[x]<dfn[y];}void dfs(int u,int f){lis[dfn[u]=++tim]=u;for (int i=head[u];i;i=nxt[i])if (e[i].v!=f){ dis[e[i].v]=dis[u]+e[i].w;dfs(e[i].v,u);lis[++tim]=u;}}inline int lca(int x,int y){x=dfn[x],y=dfn[y];if (x>y) swap(x,y);int t=LOG[y-x+1];return min(st[x][t],st[y-(1<<t)+1][t],cmp);}void input(){LOG[0]=-1;for (int i=1;i<MAXM;i++) LOG[i]=LOG[i>>1]+1;for (int i=1;i<n;i++){int u,v,w;u=read(),v=read(),w=read();addnode(u,v,w),addnode(v,u,w);} dfs(1,0);for (int i=1;i<=tim;i++) st[i][0]=lis[i];for (int j=1;j<21;j++)for (int i=1;i+(1<<(j-1))<=tim;i++)st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1],cmp);}vector<int> s,son[MAXN];ll val[MAXN],x[MAXN],y[MAXN],ans;int type[MAXN];inline void insert(int u,ll v,int t){val[u]=v,type[u]=t,s.push_back(u);}void dfs(int u){x[u]=y[u]=-INF;if (type[u]==1) x[u]=val[u];if (type[u]==2) y[u]=val[u];for (int i=0;i<(int)son[u].size();i++){dfs(son[u][i]);ans=max(ans,max(x[u]+y[son[u][i]],x[son[u][i]]+y[u])-2*dis[u]);x[u]=max(x[u],x[son[u][i]]),y[u]=max(y[u],y[son[u][i]]);}}int stk[MAXN],tp;ll solve(){sort(s.begin(),s.end(),cmp);int siz=s.size();for (int i=0;i<siz-1;i++) s.push_back(lca(s[i],s[i+1]));sort(s.begin(),s.end(),cmp);s.erase(unique(s.begin(),s.end()),s.end());tp=0;for (int i=0;i<(int)s.size();i++){while (tp&&lca(stk[tp],s[i])!=stk[tp]) --tp;if (tp) son[stk[tp]].push_back(s[i]);stk[++tp]=s[i];}ans=-INF;dfs(stk[1]);for (int i=0;i<(int)s.size();i++) son[s[i]].clear();s.clear();return ans;}
}
void dfs(int u,int f,ll d,int type)
{if (u<=n) VT::insert(u,dis[u]+d,type);for (int i=head[u];i;i=nxt[i])if (!vis[i>>1]&&e[i].v!=f)dfs(e[i].v,u,d+e[i].w,type);
}
ll ans=-INF;
void solve(int sum)
{if (mn==INF) return;vis[rt>>1]=1;dfs(e[rt].v,0,0,1);dfs(e[rt].u,0,0,2);ans=max(ans,VT::solve()+e[rt].w);int sz=siz[e[rt].v],cur=rt;mn=INF,findrt(e[cur].v,0,sz),solve(sz);mn=INF,findrt(e[cur].u,0,sum-sz),solve(sum-sz);
}
int main()
{tot=n=read();for (int i=1;i<n;i++){int u,v,w;u=read(),v=read(),w=read();E[u].push_back((edge){u,v,w}),E[v].push_back((edge){v,u,w});}VT::input();dfs(1);dfs(1,0);memset(vis,0,sizeof(vis));mn=INF,findrt(1,0,tot),solve(tot);for (int i=1;i<=n;i++) ans=max(ans,2*(dis[i]-VT::dis[i]));cerr<<ans<<'\n';cout<<(ans>>1);return 0;
}