解析
只会扫描线树剖的三只log(悲
考虑对每个 uuu 考虑合法的 vvv 的集合,必然是一个联通块。
进一步的,观察到这个联通块就是由所有经过 uuu 的路径的端点形成的最小生成树。
我们有一个最小生成树的经典结论:最小生成树边权和等于按dfs序排列成圆后邻项距离和除以二,不难发现可以线段树维护。
把所有路径做一个树上差分,再结合线段树合并,即可进行求解了。
用欧拉序 st 表 O(1)O(1)O(1)求LCA,总复杂度 O((n+m)logn)O((n+m)\log n)O((n+m)logn)
代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned ll
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define ok debug("OK\n")
inline ll read() {ll x(0),f(1);char c=getchar();while(!isdigit(c)) {if(c=='-') f=-1;c=getchar();}while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}return x*f;
}const int N=2e5+100;
const double inf=1e18;
const int mod=998244353;inline ll ksm(ll x,ll k){ll res=1;while(k){if(k&1) res=res*x%mod;x=x*x%mod;k>>=1;}return res;
}int n,m;vector<int>e[N];
int fa[N],dfn[N],pos[N],siz[N],dep[N],tim;
int q[N],in[N],ed;
int mn[N][22],mi[50],lg[N];
inline int Min(int x,int y){return dep[x]<dep[y]?x:y;}
void init(){lg[0]=-1;for(int i=1;i<=ed;i++) lg[i]=lg[i>>1]+1;mi[0]=1;for(int i=1;i<=lg[ed];i++) mi[i]=mi[i-1]<<1;for(int i=1;i<=ed;i++) mn[i][0]=q[i];for(int k=1;k<=lg[ed];k++){for(int i=1;i+mi[k]-1<=ed;i++){mn[i][k]=Min(mn[i][k-1],mn[i+mi[k-1]][k-1]);}}return;
}
inline int Lca(int x,int y){int l=in[x],r=in[y];if(l>r) swap(l,r);int k=lg[r-l+1];return Min(mn[l][k],mn[r-mi[k]+1][k]);
}
inline int dis(int x,int y){return dep[x]+dep[y]-2*dep[Lca(x,y)];
}
void dfs(int x,int f){fa[x]=f;dep[x]=dep[f]+1;siz[x]=1;dfn[++tim]=x;pos[x]=tim;q[++ed]=x;in[x]=ed;for(int to:e[x]){if(to==f) continue;dfs(to,x);siz[x]+=siz[to];q[++ed]=x;}return;
}
struct node{int l,r,s;
};
node operator + (const node &x,const node &y){node res;if(x.l==0) return y;if(y.l==0) return x;res.l=x.l;res.r=y.r;res.s=x.s+y.s;res.s-=dis(x.l,x.r)+dis(y.l,y.r);res.s+=dis(res.l,res.r); res.s+=dis(x.r,y.l);return res;
}
#define mid ((l+r)>>1)
struct tree{int ls,rs,num;node o;
}tr[N*30];
int tot,rt[N];
inline void pushup(int k){tr[k].o=tr[tr[k].ls].o+tr[tr[k].rs].o;
}
void upd(int &k,int l,int r,int p,int w){if(!k) k=++tot;if(l==r){tr[k].num+=w;if(tr[k].num) tr[k].o=(node){dfn[l],dfn[l],0};else tr[k].o=(node){0,0,0};return;}if(p<=mid) upd(tr[k].ls,l,mid,p,w);else upd(tr[k].rs,mid+1,r,p,w);pushup(k);
}
int merge(int x,int y,int l,int r){if(!x||!y) return x|y;int now=++tot;if(l==r){tr[now].num=tr[x].num+tr[y].num;if(tr[now].num) tr[now].o=(node){dfn[l],dfn[l],0};else tr[now].o=(node){0,0,0}; return now;}tr[now].ls=merge(tr[x].ls,tr[y].ls,l,mid);tr[now].rs=merge(tr[x].rs,tr[y].rs,mid+1,r);pushup(now);return now;
}ll ans;
struct ope{int x,y,w;
};
vector<ope>v[N];
void calc(int x,int f){for(int to:e[x]){if(to==f) continue;calc(to,x);rt[x]=merge(rt[x],rt[to],1,n);}for(ope o:v[x]){upd(rt[x],1,n,pos[o.x],o.w);upd(rt[x],1,n,pos[o.y],o.w);}ans+=tr[rt[x]].o.s/2;
}signed main(){#ifndef ONLINE_JUDGEfreopen("a.in","r",stdin);freopen("a.out","w",stdout);#endifn=read();m=read();for(int i=1;i<n;i++){int x=read(),y=read();e[x].push_back(y);e[y].push_back(x);}dfs(1,0);init();for(int i=1;i<=m;i++){int x=read(),y=read(),lca=Lca(x,y);v[x].emplace_back((ope){x,y,1});v[y].emplace_back((ope){x,y,1});v[lca].emplace_back((ope){x,y,-1});v[fa[lca]].emplace_back((ope){x,y,-1});}calc(1,0);printf("%lld\n",ans/2);return 0;
}
/*
*/