P2495 [SDOI2011]消耗战
代码
有些虚树的建法并不会把所有关键点都建进树里：如果一个关键点的某个祖先已经是关键点，就直接跳过它（切断那个祖先上方的边就已经把它和 1 号点隔开了）。这样建出来的虚树，叶子一定是关键点，所以做 DP 时不需要再判断一个点是否是关键点。
il void push(int x)
{
    // NOTE(review): quoted from another solution. `il`, the stack s[1..t],
    // the adjacency lists v[] and dfn[]/lca() are defined there, not here.
    // It appears to assume s[1] holds the root and that key nodes are pushed
    // in dfn order -- confirm against the original solution.

    // Only the root is on the stack: x simply extends the chain.
    if (t == 1) { s[++t] = x; return; }
    int l = lca(x, s[t]);
    // Every call either returns before touching the stack or ends with
    // s[++t] = x, so for t > 1 the stack top is the key node pushed last.
    // If that key node is an ancestor of x, then (in this problem) cutting
    // the edge above it already separates x from the root, so x is dropped
    // and never enters the virtual tree. Consequently every leaf of this
    // virtual tree is a key node and the DP need not test for key nodes.
    if (l == s[t]) return;
    // Pop the chain nodes lying below l, linking each one under the node
    // beneath it on the stack.
    while (t > 1 && dfn[s[t - 1]] >= dfn[l]) v[s[t - 1]].push_back(s[t]), --t;
    // l is not on the stack yet: hang the last popped node under l and let
    // l take its slot.
    if (s[t] != l) v[l].push_back(s[t]), s[t] = l;
    s[++t] = x;
}
摘自 洛谷 Nemlit P2495 [SDOI2011]消耗战 题解
这种建虚树的写法非常精简。
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <deque>
#include <iostream>
#include <stack>
#include <vector>
#define mid (l+r>>1)
#define lowbit(x) (x&-x)
using namespace std;
typedef long long LL;
typedef pair<int,int> PII;

// Capacity of every node-indexed array.
const int N = 2.5e5 + 4;

// Edge pool shared by the input tree and, later, by each query's virtual tree.
//   h[x]  first edge leaving x (-1 marks an empty list)
//   ne[i] next edge in the same list, e[i] its head, w[i] its weight
//   idx   number of edges currently in the pool
// Sized 2*N so every tree edge can be stored in both directions.
int h[N];
int ne[N << 1], e[N << 1], w[N << 1];
int idx;

// Prepends the directed edge a -> b (weight c) to a's adjacency list.
void add(int a, int b, int c)
{
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];
    h[a] = idx;
    ++idx;
}

// Filled in by the tree DFS:
//   dfn[x]   preorder timestamp (cv is the running counter)
//   fa[x][j] 2^j-th ancestor of x (binary lifting)
//   deep[x]  depth of x
//   mi[x]    smallest edge weight on the path from the DFS root down to x
int dfn[N], cv;
int fa[N][20];
int deep[N];
LL mi[N];
// One suspended call of the tree DFS below: the node being expanded, its
// parent, and the next adjacency-list edge still to examine
// (-2 = node not entered yet, -1 = no edges left).
struct TreeDfsFrame
{
    int node, parent, edge;
};

// Preprocesses the subtree rooted at u, whose parent is f.
//   dfn[x]   preorder timestamp, taken from the global counter cv
//   deep[x]  depth, with deep[u] = deep[f] + 1
//   fa[x][j] 2^j-th ancestor; fa[u][0] is read, so the caller must set it
//   mi[y]    min(mi[parent of y], weight of the edge parent -> y);
//            mi[u] itself is only read
// Produces exactly the same visiting order as the plain recursive version,
// but keeps pending calls on the heap: a path-shaped tree with ~N nodes
// would otherwise need ~N nested calls and can overflow the call stack.
void dfs(int u, int f)
{
    vector<TreeDfsFrame> stk;
    TreeDfsFrame root = {u, f, -2};
    stk.push_back(root);
    while (!stk.empty()) {
        TreeDfsFrame &cur = stk.back();
        if (cur.edge == -2) {
            // Entry work of the recursive call.
            const int x = cur.node;
            dfn[x] = cv++;
            deep[x] = deep[cur.parent] + 1;
            for (int j = 1; j < 20; ++j) fa[x][j] = fa[fa[x][j - 1]][j - 1];
            cur.edge = h[x];
        }
        // Skip the edge leading back to the parent.
        while (cur.edge != -1 && e[cur.edge] == cur.parent) cur.edge = ne[cur.edge];
        if (cur.edge == -1) {
            stk.pop_back();
            continue;
        }
        const int i = cur.edge, x = cur.node, y = e[i];
        cur.edge = ne[i];  // resume after this edge once y is finished
        fa[y][0] = x;
        mi[y] = min(mi[x], (LL)w[i]);
        TreeDfsFrame child = {y, x, -2};
        stk.push_back(child);  // may invalidate `cur`; it is not used again
    }
}
// Lowest common ancestor of u and v, using the deep[] depths and the fa[][]
// binary-lifting table.
int lca(int u, int v)
{
    if (deep[u] < deep[v]) swap(u, v);  // u is now the deeper node

    // Raise u by exactly the depth difference, one power of two per set bit.
    const int gap = deep[u] - deep[v];
    for (int j = 19; j >= 0; --j)
        if ((gap >> j) & 1) u = fa[u][j];

    if (u == v) return v;  // v was an ancestor of u

    // Lift both as long as their ancestors still differ.
    for (int j = 19; j >= 0; --j) {
        if (fa[u][j] == fa[v][j]) continue;
        u = fa[u][j];
        v = fa[v][j];
    }
    return fa[u][0];
}

int a[N];     // key nodes of the current query, stored from index 1
bool dis[N];  // dis[x] is set while x is a key node of the current query
// Orders node ids by preorder timestamp; used to sort each query's keys.
bool cmp(int a, int b) { return dfn[a] < dfn[b]; }

// One suspended call of the virtual-tree DP below.
struct VirtDpFrame
{
    int node;   // virtual-tree node being evaluated
    int edge;   // next outgoing edge of `node` still to process (-1: none)
    LL sum;     // cost accumulated from the children processed so far
    bool keep;  // add the value of the child currently being evaluated?
};

// Virtual-tree DP: returns the minimum cost of cutting every key node in u's
// virtual subtree off from above.
//   - a leaf (h[u] == -1) costs mi[u];
//   - a key child c always costs mi[c] (its subtree is still walked, only to
//     clear it); any other child contributes its own DP value;
//   - the total is capped at mi[u] (cutting above u instead).
// Every visited node gets h[x] = -1 so the edge pool can be reused by the
// next query. Evaluated with an explicit stack instead of recursion, since
// the virtual tree can be a chain of ~k nested nodes; the results and the
// clearing are identical to the recursive formulation.
LL dfs(int u)
{
    if (h[u] == -1) return mi[u];

    vector<VirtDpFrame> stk;
    VirtDpFrame first = {u, h[u], 0, false};
    stk.push_back(first);
    for (;;) {
        VirtDpFrame &cur = stk.back();
        if (cur.edge != -1) {
            const int c = e[cur.edge];
            cur.edge = ne[cur.edge];
            if (dis[c]) cur.sum += mi[c];  // key child: cut right above it
            cur.keep = !dis[c];
            if (h[c] == -1) {
                // Leaf: its value is mi[c] and no frame is needed.
                if (cur.keep) cur.sum += mi[c];
            } else {
                VirtDpFrame next = {c, h[c], 0, false};
                stk.push_back(next);  // may invalidate `cur`; not used again
            }
            continue;
        }
        // All children done: finish this call exactly like the recursive form.
        h[cur.node] = -1;
        const LL val = min(mi[cur.node], cur.sum);
        stk.pop_back();
        if (stk.empty()) return val;
        if (stk.back().keep) stk.back().sum += val;
    }
}
// Chain stack used by build(); every call leaves it empty.
stack<int> st;

// Builds the virtual tree spanned by node 1 and the key nodes a[1..n]
// (main sorts them by dfn[] before calling). The tree is built in the reused
// h/e/ne/w edge pool with weight 0 on every virtual edge. Every key is marked
// in dis[], and the function returns the DP value dfs(1) on the new tree.
LL build(int n)
{
    idx = 0;     // discard the previous edges; the pool is reused
    st.push(1);  // the stack holds the current rightmost chain from the root

    for (int i = 1; i <= n; ++i) {
        const int key = a[i];
        const int meet = lca(key, st.top());
        // Unwind the chain down to `meet`, hooking each popped node under
        // whatever lies beneath it (inserting `meet` when it is missing).
        while (st.top() != meet) {
            const int below = st.top();
            st.pop();
            if (dfn[st.top()] < dfn[meet])
                st.push(meet);
            add(st.top(), below, 0);
        }
        st.push(key);
        dis[key] = 1;
    }

    // Hook the remaining chain together from the bottom up.
    while (st.size() > 1) {
        const int below = st.top();
        st.pop();
        add(st.top(), below, 0);
    }
    st.pop();  // drop the root so the stack is empty for the next query

    return dfs(1);
}
int main()
{
    int n, m;
    scanf("%d", &n);

    // Empty adjacency lists for the input tree.
    for (int x = 1; x <= n; ++x) h[x] = -1;
    // Huge fill value: mi[1] is never overwritten by the tree DFS and acts
    // as +infinity for the root.
    memset(mi, 0x7f, sizeof mi);

    for (int left = n - 1; left > 0; --left) {
        int from, to, cost;
        scanf("%d%d%d", &from, &to, &cost);
        add(from, to, cost);
        add(to, from, cost);
    }
    dfs(1, 0);

    // The edge pool is reused for the virtual trees: start from empty lists.
    for (int x = 1; x <= n; ++x) h[x] = -1;

    scanf("%d", &m);
    while (m--) {
        int k;
        scanf("%d", &k);
        for (int i = 1; i <= k; ++i) scanf("%d", a + i);
        sort(a + 1, a + k + 1, cmp);
        printf("%lld\n", build(k));
        // Unmark this query's key nodes.
        for (int i = 1; i <= k; ++i) dis[a[i]] = 0;
    }
    return 0;
}