problem
给定 nnn 个城市,n−1n-1n−1 条道路,形成一棵树。每座城市上的人口为 wiw_iwi。
现要修建若干个中心城镇,满足任意两个中心城镇之间的距离严格大于 kkk。
最大化中心城镇的总人口。
n,k≤106,wi≤109n,k\le 10^6,w_i\le 10^9n,k≤106,wi≤109。
solution
这种限制树上关键点彼此之间距离是比较经典的题目了,通常都会考虑 uuu 子树内离 uuu 最近的关键点的距离为多少,设计状态转移方程。
有非常套路的树背包,设 fu,i:uf_{u,i}:ufu,i:u 子树内离 uuu 最近的关键点深度为 iii(此深度是以 111 为根意义下,在整棵大树中的深度)的最多人口数。
有 fu,0=wuf_{u,0}=w_ufu,0=wu。考虑逐个加入 uuu 的子树 vvv。
-
(i−dep[u])<<1>k(i-dep[u])<<1>k(i−dep[u])<<1>k
gu,i←fu,i+fv,ig_{u,i}\leftarrow f_{u,i}+f_{v,i}gu,i←fu,i+fv,i
-
(i−dep[u])<<1≤k(i-dep[u])<<1\le k(i−dep[u])<<1≤k
gu,i=max{gu,i+1,fu,i+fv,k+2dep[u]−i+1,fu,l+2dep[u]−i+1+fv,i}g_{u,i}=\max\Big\{g_{u,i+1},f_{u,i}+f_{v,k+2dep[u]-i+1},f_{u,l+2dep[u]-i+1}+f_{v,i}\Big\}gu,i=max{gu,i+1,fu,i+fv,k+2dep[u]−i+1,fu,l+2dep[u]−i+1+fv,i}
-
gu,dep[u]=max{gu,dep[u]+1,fu,0+fv,depu+k+1}g_{u,dep[u]}=\max\Big\{g_{u,dep[u]+1},f_{u,0}+f_{v,dep_u+k+1}\Big\}gu,dep[u]=max{gu,dep[u]+1,fu,0+fv,depu+k+1}
最后更新回去 f←gf\leftarrow gf←g。
所求即为 f1,0f_{1,0}f1,0。
事实上,有用的只有 i∈[dep[u],dep[u]+k]i\in\big[dep[u],dep[u]+k\big]i∈[dep[u],dep[u]+k],这是 O(nk)O(nk)O(nk) 的。
事实上,有用的只有 i∈[dep[u],dep[u]+lenu]i\in\big[dep[u],dep[u]+len_u\big]i∈[dep[u],dep[u]+lenu],其中 lenu:ulen_u:ulenu:u 子树的高度(链长度),长链剖分优化,时间复杂度就只有 O(n)O(n)O(n)。
fu,i:uf_{u,i}:ufu,i:u 子树内离 uuu 最近的关键点,二者的相对距离为 iii 的最多人口数。
excuse me???
动态数组我真的会谢,卷爷 vector\text{vector}vector 都能跑过去,这是什么人啊!
code(vector—MLE)
#include <bits/stdc++.h>
using namespace std;
#define maxn 1000005
#define int long long
vector < int > G[maxn], f[maxn];
int n, k, ans;
int w[maxn], g[maxn], len[maxn], son[maxn];void dfs1( int u, int fa ) {for( int v : G[u] ) {if( v == fa ) continue;else dfs1( v, u );if( len[son[u]] < len[v] ) son[u] = v;}len[u] = len[son[u]] + 1;
}void dfs2( int u, int fa ) {f[u].resize( len[u] + 1 );f[u][0] = w[u];if( son[u] ) {dfs2( son[u], u );for( int i = 1;i < len[u];i ++ ) f[u][i] = f[son[u]][i - 1];if( k < len[u] ) f[u][0] += f[son[u]][k - 1];f[u][0] = max( f[u][0], f[son[u]][0] );}ans = max( ans, f[u][0] );for( int v : G[u] ) {if( v == fa or v == son[u] ) continue;else dfs2( v, u );for( int i = 0;i <= len[v];i ++ ) g[i] = f[u][i];for( int i = 0;i <= k and i <= len[v];i ++ ) {if( i > ( k >> 1 ) ) {if( i ) g[i] += f[v][i - 1];}else {if( 0 <= k - i - 1 and k - i - 1 < len[v] )g[i] = max( g[i], f[u][i] + f[v][k - i - 1] );if( 0 <= i - 1 and k - i < len[u] )g[i] = max( g[i], f[u][k - i] + f[v][i - 1] );}if( i ) g[i] = max( g[i], f[v][i - 1] );}for( int i = len[v];~ i;i -- ) {f[u][i] = g[i];if( i + 1 < len[u] ) f[u][i] = max( f[u][i], f[u][i + 1] );ans = max( ans, f[u][i] );}}
}signed main() {scanf( "%lld %lld", &n, &k ); k ++;for( int i = 1;i <= n;i ++ ) scanf( "%lld", &w[i] );for( int i = 1, u, v;i < n;i ++ ) {scanf( "%lld %lld", &u, &v );G[u].push_back( v );G[v].push_back( u );}dfs1( 1, 0 );dfs2( 1, 0 );printf( "%lld\n", ans );return 0;
}
code(动态数组—AC)
#include <bits/stdc++.h>
using namespace std;
#define maxn 1000005
#define int long long
vector < int > G[maxn];
int *f[maxn], *ip;
int pos[maxn << 2];
int n, k, ans;
int w[maxn], g[maxn], len[maxn], son[maxn];void dfs1( int u, int fa ) {for( int v : G[u] ) {if( v == fa ) continue;else dfs1( v, u );if( len[son[u]] < len[v] ) son[u] = v;}len[u] = len[son[u]] + 1;
}void dfs2( int u, int fa ) {// f[u].resize( len[u] + 1 );f[u][0] = w[u];if( son[u] ) {f[son[u]] = f[u] + 1;dfs2( son[u], u );// for( int i = 1;i < len[u];i ++ ) f[u][i] = f[son[u]][i - 1];if( k < len[u] ) f[u][0] += f[son[u]][k - 1];f[u][0] = max( f[u][0], f[son[u]][0] );}ans = max( ans, f[u][0] );for( int v : G[u] ) {if( v == fa or v == son[u] ) continue;else f[v] = ip, ip += len[v], dfs2( v, u );for( int i = 0;i <= len[v];i ++ ) g[i] = f[u][i];for( int i = 0;i <= k and i <= len[v];i ++ ) {if( i > ( k >> 1 ) ) {if( i ) g[i] += f[v][i - 1];}else {if( 0 <= k - i - 1 and k - i - 1 < len[v] )g[i] = max( g[i], f[u][i] + f[v][k - i - 1] );if( 0 <= i - 1 and k - i < len[u] )g[i] = max( g[i], f[u][k - i] + f[v][i - 1] );}if( i ) g[i] = max( g[i], f[v][i - 1] );}for( int i = len[v];~ i;i -- ) {f[u][i] = g[i];if( i + 1 < len[u] ) f[u][i] = max( f[u][i], f[u][i + 1] );ans = max( ans, f[u][i] );}}
}signed main() {scanf( "%lld %lld", &n, &k ); k ++;for( int i = 1;i <= n;i ++ ) scanf( "%lld", &w[i] );for( int i = 1, u, v;i < n;i ++ ) {scanf( "%lld %lld", &u, &v );G[u].push_back( v );G[v].push_back( u );}dfs1( 1, 0 );ip = pos;f[1] = ip;ip += len[1];dfs2( 1, 0 );printf( "%lld\n", ans );return 0;
}