正题
jzoj 7181
题目大意
给你由若干铁路组成的图(一个铁路上有若干点),问你从1到n在最短路径的前提下,乘坐的每一条铁路所花费时间的平方和的最大值
解题思路
先用dij跑出最短路图(即长度等于最短路的所有路径)
然后在最短路图上DP,设fif_ifi为到第i个点满足最短路的最大答案,转移方程
fi=max(fj+(disi−disj)2)f_i = max\ (f_j+(dis_i-dis_j)^2)fi=max (fj+(disi−disj)2)
对于该方程,可以用斜率优化
当计算到i时,枚举所在铁路,然后斜率优化(因为取max,且dis单调递增,所以用单调栈维护上凸壳)
代码
#include<queue>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
#define mp make_pair
#define fs first
#define sn second
#define sz st[g].size()
#define l1 st[g][sz - 1]
#define l2 st[g][sz - 2]
#define N 1000100
using namespace std;
int n, m, x, y, nn, tot, cnt, p[N], q[N], h[N];
ll z, f[N], X[N], Y[N], dis[N];
vector<pair<int, int> >w[N];
vector<pair<int, ll> >a[N];
vector<int>st[N<<1], wst[N];
priority_queue<pair<ll, int> >d;
struct rec
{int to, next;ll l;
}e[N];
void add(int x, int y, ll z)
{e[++tot].to = y;e[tot].l = z;e[tot].next = h[x];h[x] = tot;
}
void dij()
{memset(dis, 0x7f, sizeof(dis));dis[1] = 0;d.push(mp(0, 1));while(!d.empty()){int u = d.top().sn;d.pop();if (p[u]) continue;p[u] = 1;for (int i = h[u]; i; i = e[i].next){int v = e[i].to;if (dis[u] + e[i].l < dis[v]){dis[v] = dis[u] + e[i].l;d.push(mp(-dis[v], v));}}}return;
}
bool cmp(int x, int y)
{return dis[x] < dis[y];
}
void solve()
{for (int i = 1; i <= n; ++i)q[i] = i;//按dis枚举sort(q + 1, q + 1 + n, cmp);for (int i = 1; i <= n; ++i){int u = q[i];for (int j = 0; j < w[u].size(); ++j){x = w[u][j].fs;y = w[u][j].sn;if (y && dis[a[x][y - 1].fs] + a[x][y - 1].sn == dis[u]) wst[x][y] = wst[x][y - 1];//一个铁路可能在不同的最短路中,先判断在哪一段中else wst[x][y] = ++cnt;int g = wst[x][y];while(sz >= 2 && Y[l1] - Y[l2] < 2ll * dis[u] * (X[l1] - X[l2]))//斜率优化st[g].pop_back();if (sz){int v = st[g][sz - 1];f[u] = max(f[u], f[v] + (dis[u] - dis[v]) * (dis[u] - dis[v]));}}Y[u] = f[u] + dis[u] * dis[u];X[u] = dis[u];for (int j = 0; j < w[u].size(); ++j){x = w[u][j].fs;y = w[u][j].sn;int g = wst[x][y];while (sz >= 2 && (Y[l1] - Y[l2]) * (X[u] - X[l1]) < (Y[u] - Y[l1]) * (X[l1] - X[l2]))st[g].pop_back();st[g].push_back(u);}}
}
int main()
{scanf("%d%d", &n, &m);for (int i = 1; i <= m; ++i){scanf("%d", &nn);scanf("%d", &x);for (int j = 1; j <= nn; ++j){scanf("%lld%d", &z, &y);add(x, y, z);w[x].push_back(mp(i,j - 1));//在哪一段铁路的那个点a[i].push_back(mp(x,z));//插入到第i段铁路中,到下一个点的距离为zwst[i].push_back(0);x = y;}w[x].push_back(mp(i,nn));a[i].push_back(mp(x,0));wst[i].push_back(0);}dij();solve();printf("%lld %lld", dis[n], f[n]);return 0;
}