树上前缀和+LCA
暴力做法:
我们先求出不跳过任何点时整条路线的总花费sum,然后依次枚举被跳过的点。例如路线为a1,a2,a3,跳过a2时,答案就是sum-cost(a1,a2)-cost(a2,a3)+cost(a1,a3);如果跳过的是首点或尾点,只需从sum中减去与它相邻的那一段花费。
DFS暴力,下面是代码:
#include<bits/stdc++.h>
using namespace std;
// Shared state of the brute-force solution.
// 64-bit integer used for accumulated path costs.
typedef long long ll;
// n: number of tree nodes (main reads n-1 edges); k: number of route stops read into a[1..k].
int k,n;
// Pair of ints: (neighbor, edge weight) in the adjacency lists, (from, to) as a key of st.
typedef pair<int,int> pii;
// a[1..k]: the route in visiting order; untouched entries keep their zero initial value.
int a[100010];
// edge[u]: adjacency list of node u; main stores every input edge in both directions.
vector<pii> edge[100010];
// st[{x,y}]: path cost between x and y, written symmetrically by dfs when it reaches its target.
map<pii,ll> st;
// Depth-first search for the path from s to v.
//   s   - start node of the whole search (used only as the cache key)
//   u   - node currently being visited
//   fa  - node we arrived from, skipped so the walk never turns back
//   v   - target node
//   sum - cost accumulated on the way from s to u
// On reaching v the cost is stored in st under both {s,v} and {v,s}.
// Returns true as soon as v has been found, false if this subtree does not contain it.
bool dfs(int s,int u,int fa,int v,ll sum)
{
    if (u == v) {
        st[{s, v}] = sum;
        st[{v, s}] = sum;
        return true;
    }
    for (const auto& link : edge[u]) {
        const int next = link.first;
        if (next == fa) {
            continue;
        }
        if (dfs(s, next, u, v, sum + link.second)) {
            return true;
        }
    }
    return false;
}
// Brute-force driver: reads the tree and the route, then for every stop prints
// the route cost obtained when that stop is skipped (values separated by spaces).
int main()
{
    cin >> n >> k;
    for (int i = 1; i <= n - 1; i++) {
        int u, v, t;
        cin >> u >> v >> t;
        edge[u].push_back({v, t});
        edge[v].push_back({u, t});
    }
    for (int i = 1; i <= k; i++) {
        cin >> a[i];
    }

    // Cost of the complete route. A route of k stops has only k-1 legs, so the
    // loop must stop at k-1: going up to k would read a[k+1] and launch a
    // full-tree search for a node that does not exist.
    ll ans = 0;
    for (int i = 1; i < k; i++) {
        dfs(a[i], a[i], -1, a[i + 1], 0);
        ans += st[{a[i], a[i + 1]}];
    }

    for (int i = 1; i <= k; i++) {
        ll tp = ans;
        // Drop the legs adjacent to stop i; an endpoint has only one such leg.
        if (i > 1) {
            tp -= st[{a[i - 1], a[i]}];
        }
        if (i < k) {
            tp -= st[{a[i], a[i + 1]}];
        }
        // An interior stop is bypassed by the direct path between its neighbours.
        if (i > 1 && i < k) {
            dfs(a[i - 1], a[i - 1], -1, a[i + 1], 0);
            tp += st[{a[i - 1], a[i + 1]}];
        }
        cout << tp << " ";
    }
}
正确做法:
我们先预处理出各个点到根节点的距离sum(即树上前缀和),那么两点a1,a2之间的距离就是sum[a1]+sum[a2]-2*sum[fa],其中fa为a1,a2的最近公共祖先。下面是用倍增实现LCA的模板:
https://www.luogu.com.cn/problem/P3379
#include<bits/stdc++.h>
using namespace std;
// n: number of nodes, m: number of LCA queries, s: root node (all read in main).
int n,m,s;
// edge[x]: adjacency list of node x; main stores every input edge in both directions.
vector<int> edge[500010];
// dep[x]: depth of x; main sets dep[s]=1 and dfs assigns parent depth + 1.
int dep[500010];
// fa[x][i]: ancestor of x reached by 2^i parent steps (filled in dfs); entries never written stay 0.
int fa[500010][22];
// Highest level index scanned by lca(); equals the last valid column of fa.
int maxd=21;
// Builds the binary-lifting table for the subtree rooted at x.
//   x    - node being visited
//   fath - parent of x (ignored for the root s)
// For every non-root node it records the direct parent, the depth and the
// 2^level-th ancestors for all levels with 2^level <= n, then recurses into
// each neighbour except the parent.
void dfs(int x,int fath)
{
    if (x != s) {
        fa[x][0] = fath;
        dep[x] = dep[fath] + 1;
        for (int level = 1; (1 << level) <= n; ++level) {
            // The 2^level-th ancestor is the 2^(level-1)-th ancestor taken twice.
            const int half = fa[x][level - 1];
            fa[x][level] = fa[half][level - 1];
        }
    }
    for (int child : edge[x]) {
        if (child != fath) {
            dfs(child, x);
        }
    }
}
// Returns the node reached from x after d parent steps, following one table
// jump for every set bit of d (only bits with 2^bit <= n are examined).
int up(int x,int d)
{
    int node = x;
    int level = 0;
    while ((1 << level) <= n) {
        if ((d & (1 << level)) != 0) {
            node = fa[node][level];
        }
        ++level;
    }
    return node;
}
// Lowest common ancestor of x and y via binary lifting.
// The deeper node is first raised to the depth of the other one; if they do
// not coincide, both are lifted together from the largest jump downwards
// while their ancestors still differ, leaving them just below the answer.
int lca(int x,int y)
{
    if (dep[x] < dep[y]) {
        swap(x, y);
    }
    x = up(x, dep[x] - dep[y]);
    if (x == y) {
        return x;
    }
    for (int level = maxd; level >= 0; --level) {
        const int px = fa[x][level];
        const int py = fa[y][level];
        if (px == py) {
            continue;
        }
        x = px;
        y = py;
    }
    return fa[x][0];
}
// LCA template driver: reads n, m and the root s, then n-1 undirected edges,
// builds the lifting table and answers m queries, one result per line.
int main()
{
    // Only iostream is used here, so drop the stdio synchronisation and the
    // cin/cout tie; together with '\n' instead of endl this avoids flushing
    // the output after every single query.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m >> s;
    for (int i = 1; i <= n - 1; i++) {
        int x, y;
        cin >> x >> y;
        edge[x].push_back(y);
        edge[y].push_back(x);
    }

    dep[s] = 1;  // dfs skips the root, so its depth is set here
    dfs(s, -1);

    while (m--) {
        int a1, b1;
        cin >> a1 >> b1;
        cout << lca(a1, b1) << '\n';
    }
}
本题的AC代码:
#include<bits/stdc++.h>
using namespace std;
// Shared state of the accepted solution.
// 64-bit integer used for distances and route totals.
typedef long long ll;
// n: number of tree nodes (main reads n-1 edges); k: number of route stops read into a[1..k].
int k,n;
// (neighbor, edge weight) pair stored in the adjacency lists.
typedef pair<int,ll> pii;
// a[1..k]: the route in visiting order; untouched entries keep their zero initial value.
int a[100010];
// edge[u]: adjacency list of node u; main stores every input edge in both directions.
vector<pii> edge[100010];
// Not referenced by any function of this solution; apparently left over from
// the brute-force version -- TODO confirm before removing.
map<pii,ll> st;
// sum[x]: distance from node 1 to x, filled by dfs (sum[1] is set to 0 in main).
ll sum[100010];
// dep[x]: depth of x; main sets dep[1]=1 and dfss assigns parent depth + 1.
int dep[100010];
// fa[x][i]: ancestor of x reached by 2^i parent steps (filled in dfss); entries never written stay 0.
int fa[100010][22];
// Highest level index scanned by lca(); equals the last valid column of fa.
int maxd=21;
// Builds the binary-lifting table for the subtree rooted at x (tree root is node 1).
//   x    - node being visited
//   fath - parent of x (ignored for node 1)
// For every node other than 1 it records the direct parent, the depth and the
// 2^level-th ancestors for all levels with 2^level <= n, then recurses into
// each neighbour except the parent.
void dfss(int x,int fath)
{
    if (x != 1) {
        fa[x][0] = fath;
        dep[x] = dep[fath] + 1;
        for (int level = 1; (1 << level) <= n; ++level) {
            // The 2^level-th ancestor is the 2^(level-1)-th ancestor taken twice.
            const int half = fa[x][level - 1];
            fa[x][level] = fa[half][level - 1];
        }
    }
    for (const auto& link : edge[x]) {
        const int child = link.first;
        if (child != fath) {
            dfss(child, x);
        }
    }
}
// Returns the node reached from x after d parent steps, following one table
// jump for every set bit of d (only bits with 2^bit <= n are examined).
int up(int x,int d)
{
    int node = x;
    int level = 0;
    while ((1 << level) <= n) {
        if ((d & (1 << level)) != 0) {
            node = fa[node][level];
        }
        ++level;
    }
    return node;
}
// Lowest common ancestor of x and y via binary lifting.
// The deeper node is first raised to the depth of the other one; if they do
// not coincide, both are lifted together from the largest jump downwards
// while their ancestors still differ, leaving them just below the answer.
int lca(int x,int y)
{
    if (dep[x] < dep[y]) {
        swap(x, y);
    }
    x = up(x, dep[x] - dep[y]);
    if (x == y) {
        return x;
    }
    for (int level = maxd; level >= 0; --level) {
        const int px = fa[x][level];
        const int py = fa[y][level];
        if (px == py) {
            continue;
        }
        x = px;
        y = py;
    }
    return fa[x][0];
}
// Fills sum[] for the subtree below x: each child's value is the value of x
// plus the weight of the connecting edge.
//   x  - node whose sum[x] is already known
//   fa - node we arrived from, skipped so the walk never turns back
//        (this parameter hides the global table of the same name inside the function)
void dfs(int x,int fa)
{
    for (const auto& link : edge[x]) {
        const int child = link.first;
        if (child == fa) {
            continue;
        }
        sum[child] = sum[x] + link.second;
        dfs(child, x);
    }
}
// Distance between x and y in the tree: both root distances added together,
// minus twice the root distance of their lowest common ancestor.
ll dis(int x,int y)
{
    const int meet = lca(x, y);
    return sum[x] + sum[y] - 2 * sum[meet];
}
int main()
{cin>>n>>k;for(int i=1;i<=n-1;i++){int u,v,t;cin>>u>>v>>t;edge[u].push_back({v,t});edge[v].push_back({u,t});}for(int i=1;i<=k;i++) cin>>a[i];dep[1]=1;sum[1]=0;dfs(1,-1);dfss(1,-1);ll summ=0;for(int i=1;i<=k-1;i++) summ+=dis(a[i],a[i+1]);for(int i=1;i<=k;i++){ll ans=summ;if(i>1&&i<k){ans-=dis(a[i-1],a[i])+dis(a[i],a[i+1]);ans+=dis(a[i-1],a[i+1]);}if(i==1) ans-=dis(a[1],a[2]);if(i==k) ans-=dis(a[k-1],a[k]);cout<<ans<<" ";}}