题目
521. 运输计划
算法标签: 树上倍增, l c a lca lca, 前缀和, 树上差分, 二分
思路
注意到答案是具有二分性质的, 对于某个时间 m i d mid mid假设是最优答案, 小于该时间是不可以的, 但是大于该时间是可行的, 因此可以二分答案
这样就将问题转化为, 对于给定的时间 m i d mid mid, 将树中的一条边权变为 0 0 0, 所有的运输路线耗时是否 ≤ m i d \le mid ≤mid
可以将所有运输的路线分为两类, 一种是运输时间 ≤ m i d \le mid ≤mid的, 这种路线不要需要删除边
但是还有一种路线是 > m i d > mid >mid, 对于这些路线需要找个这些路线的公共边, 将这个公共边的权值变为 0 0 0, 但是直接枚举所有的边和路线会超时, 因此需要进行优化
可以在所有路线上的边 + 1 + 1 +1, 最终结果就是公共边被加了 t t t次, t t t是大于 m i d mid mid的路线的数量, 这样就找到了这个边, 利用树上差分, 实现对每个边 + 1 +1 +1的操作
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>using namespace std;const int N = 300010, M = N << 1, K = 19;int n, m;
int head[N], ed[M], ne[M], w[M], idx;
int fa[N][K], depth[N], d[N];
struct Path {int u, v, p, d;
} path[N];
int s[N];void add(int u, int v, int val) {ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}void dfs(int u, int pre, int dep) {depth[u] = dep;for (int i = head[u]; ~i; i = ne[i]) {int v = ed[i];if (v == pre) continue;fa[v][0] = u;for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];d[v] = d[u] + w[i];dfs(v, u, dep + 1);}
}int lca(int u, int v) {if (depth[u] < depth[v]) swap(u, v);for (int k = K - 1; k >= 0; --k) {if (depth[fa[u][k]] >= depth[v]) {u = fa[u][k];}}if (u == v) return v;for (int k = K - 1; k >= 0; --k) {if (fa[u][k] != fa[v][k]) {u = fa[u][k];v = fa[v][k];}}return fa[u][0];
}void dfs_sum(int u, int pre) {for (int i = head[u]; ~i; i = ne[i]) {int v = ed[i];if (v == pre) continue;dfs_sum(v, u);s[u] += s[v];}
}bool check(int mid) {memset(s, 0, sizeof s);int c = 0, max_d = 0;for (int i = 0; i < m; ++i) {auto [u, v, p, val] = path[i];if (val > mid) {c++;max_d = max(max_d, val);s[u]++;s[v]++;s[p] -= 2;}}if (c == 0) return true;dfs_sum(1, -1);for (int u = 2; u <= n; ++u) {if (s[u] == c && max_d - (d[u] - d[fa[u][0]]) <= mid) {return true;}}return false;
}int main() {ios::sync_with_stdio(false);cin.tie(0), cout.tie(0);memset(head, -1, sizeof head);cin >> n >> m;for (int i = 0; i < n - 1; ++i) {int u, v, w;cin >> u >> v >> w;add(u, v, w), add(v, u, w);}dfs(1, -1, 1);for (int i = 0; i < m; ++i) {int u, v;cin >> u >> v;int p = lca(u, v);int dis = d[u] + d[v] - 2 * d[p];path[i] = {u, v, p, dis};}int l = 0, r = 3e8;while (l < r) {int mid = l + r >> 1;if (check(mid)) r = mid;else l = mid + 1;}cout << l << "\n";return 0;
}
* v e c t o r vector vector存邻接表会超时
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>using namespace std;typedef pair<int, int> PII;
const int N = 300010, M = N << 1, K = 19;int n, m;
vector<PII> head[N];
int fa[N][K], depth[N], d[N];
struct Path {int u, v, p, d;
};
vector<Path> path;
int s[M];void init() {path.resize(m + 1);
}void add(int u, int v, int w) {head[u].push_back({v, w});
}void dfs(int u, int pre, int dep) {depth[u] = dep;for (auto [v, w] : head[u]) {if (v == pre) continue;fa[v][0] = u;for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];d[v] = d[u] + w;dfs(v, u, dep + 1);}
}int lca(int u, int v) {if (depth[u] < depth[v]) swap(u, v);for (int k = K - 1; k >= 0; --k) {if (depth[fa[u][k]] >= depth[v]) {u = fa[u][k];}}if (u == v) return u;for (int k = K - 1; k >= 0; --k) {if (fa[u][k] != fa[v][k]) {u = fa[u][k];v = fa[v][k];}}return fa[u][0];
}void dfs_sum(int u, int fa) {for (auto [v, w] : head[u]) {if (v == fa) continue;dfs_sum(v, u);s[u] += s[v];}
}bool check(int mid) {memset(s, 0, sizeof s);int cnt = 0, max_d = 0;for (auto [u, v, p, dis] : path) {if (dis > mid) {cnt++;s[u]++;s[v]++;s[p] -= 2;max_d = max(max_d, dis);}}if (cnt == 0) return true;dfs_sum(1, -1);for (int u = 2; u <= n; ++u) {if (s[u] == cnt && max_d - (d[u] - d[fa[u][0]]) <= mid) return true;}return false;
}int main() {ios::sync_with_stdio(false);cin.tie(0), cout.tie(0);cin >> n >> m;init();for (int i = 0; i < n - 1; ++i) {int u, v, w;cin >> u >> v >> w;add(u, v, w), add(v, u, w);}dfs(1, -1, 1);for (int i = 0; i < m; ++i) {int u, v;cin >> u >> v;int p = lca(u, v);path[i] = {u, v, p, d[u] + d[v] - 2 * d[p]};}int l = 0, r = 3e8;while (l < r) {int mid = l + r >> 1;if (check(mid)) r = mid;else l = mid + 1;}cout << l << "\n";return 0;
}