/*
题目链接: http://poj.org/problem?id=3728
思路: 题目要求在树上沿 a -> b 的路径行走, 先在某个结点买入、再在路径上之后的某个结点卖出, 求最大收益(卖出价减买入价; 没有正收益时答案为 0)。
设 f = LCA(a, b)。令 up[a] 表示沿 a -> f 行走能获得的最大收益, down[a] 表示沿 f -> a 行走能获得的最大收益, dp_max[a] 表示 a -> f 路径上的最大值, dp_min[a] 表示 a -> f 路径上的最小值。于是可以得出关系: ans[id] = max(max(up[a], down[b]), dp_max[b] - dp_min[a])。
*/
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

// Capacity of every per-vertex / per-query array
// (presumably the problem's 50000 limit plus some slack).
const int MAX_N = 50000 + 5000;

// Capacity of each edge pool: four times the vertex capacity.
const int MAX_M = MAX_N * 4;

// Conventional "infinity" sentinel 0x3f3f3f3f (not referenced in the visible code).
const int inf = 0x3f3f3f3f;
int NE1, NE2, NE3, head1[MAX_N], head2[MAX_N], head3[MAX_N];void Init()
{NE1 = NE2 = NE3 = 0;memset(head1, -1, sizeof(head1));memset(head2, -1, sizeof(head2));memset(head3, -1, sizeof(head3));}int N, Q, ans[MAX_N], value[MAX_N], vis[MAX_N];struct Edge1 {int v, next;
} edge1[MAX_M];void Insert1(int u, int v)
{edge1[NE1].v = v;edge1[NE1].next = head1[u];head1[u] = NE1++;
}struct Edge {int v, id, next;
} edge2[MAX_M], edge3[MAX_M];void Insert2(int u, int v, int id, int flag)
{if (!flag) {edge2[NE2].v = v;edge2[NE2].id = id;edge2[NE2].next = head2[u];head2[u] = NE2++;} else {edge3[NE3].v = v;edge3[NE3].id = id;edge3[NE3].next = head3[u];head3[u] = NE3++;}
}int parent[MAX_N];
int up[MAX_N], down[MAX_N], dp_max[MAX_N], dp_min[MAX_N];int find(int x)
{if (x == parent[x]) {return x;}int fa = parent[x];parent[x] = find(parent[x]);up[x] = max(max(up[x], up[fa]), dp_max[fa] - dp_min[x]);down[x] = max(max(down[x], down[fa]), dp_max[x] - dp_min[fa]);dp_max[x] = max(dp_max[x], dp_max[fa]);dp_min[x] = min(dp_min[x], dp_min[fa]);return parent[x];
}struct Node {int u, v;
} node[MAX_N];void Tarjan(int u)
{vis[u] = 1;parent[u] = u;//Q;for (int i = head2[u]; ~i; i = edge2[i].next) {int v = edge2[i].v, id = edge2[i].id;if (!vis[v]) continue;int fa = find(v);Insert2(fa, v, id, 1);}for (int i = head1[u]; ~i; i = edge1[i].next) {int v = edge1[i].v;if (vis[v]) continue;Tarjan(v);parent[v] = u;}//edge3for (int i = head3[u]; ~i; i = edge3[i].next) {int id = edge3[i].id;find(node[id].u);find(node[id].v);ans[id] = max(max(up[node[id].u], down[node[id].v]), dp_max[node[id].v] - dp_min[node[id].u]);}
}int main()
{while (~scanf("%d", &N)) {for (int i = 1; i <= N; ++i) {scanf("%d", &value[i]);up[i] = down[i] = 0;dp_max[i] = dp_min[i] = value[i];}Init();for (int i = 1; i < N; ++i) {int u, v;scanf("%d %d", &u, &v);Insert1(u, v);Insert1(v, u);}scanf("%d", &Q);for (int i = 1; i <= Q; ++i) {scanf("%d %d", &node[i].u, &node[i].v);Insert2(node[i].u, node[i].v, i, 0);Insert2(node[i].v, node[i].u, i, 0);}memset(vis, 0, sizeof(vis));Tarjan(1);for (int i = 1; i <= Q; ++i) printf("%d\n", ans[i]);}return 0;
}