题意:一棵 nnn 个点的树,每个点有两个权值 ai,bia_i,b_iai,bi,有黑白两种颜色。mmm 次询问,每次给定一个 kkk,求一条端点异色的路径,使得 k∑ai+∑bik\sum a_i+\sum b_ik∑ai+∑bi 最大化。
n≤2×105n\leq 2\times 10^5n≤2×105
就差把“请在边分治的时候维护闵可夫斯基和”写题面上了……
直观来看是把 (ai,bi)(a_i,b_i)(ai,bi) 看成直线,但是并不好维护。不过半平面交和凸包是对偶的,并且凸包合并有闵可夫斯基和这个东西,所以可以把每个结点直接看成点维护凸包,询问的时候二分就可以了。
分治的时候对两边黑白分别求出凸壳,然后交叉合并,把点丢到答案集合里最后再做一个凸包就可以了。
然后就是三度化后处理答案的问题。把新建的虚点的参数从父亲那里复制,然后如果路径的 lca 是虚点强行给它加上去。可以在 dfs 的时候记一个当前走的方向,然后如果是往上走的可以进行一次换方向,并在这里统计当前点是否需要有贡献。
复杂度 O(nlogn)\Omicron(n\log n)O(nlogn)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <algorithm>
#define MAXN 200005
#define MAXM 400005
using namespace std;
const int INF=0x7fffffff;
inline int read()
{int ans=0,f=1;char c=getchar();while (!isdigit(c)) (c=='-')&&(f=-1),c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return f*ans;
}
typedef long long ll;
struct point{int x,y;inline point(const int& x=0,const int& y=0):x(x),y(y){}};
inline point operator +(const point& a,const point& b){return point(a.x+b.x,a.y+b.y);}
inline point operator -(const point& a,const point& b){return point(a.x-b.x,a.y-b.y);}
inline ll operator *(const point& a,const point& b){return (ll)a.x*b.y-(ll)a.y*b.x;}
inline bool operator <(const point& a,const point& b){return a.x<b.x||(a.x==b.x&&a.y>b.y);}
typedef vector<point> hull;
#define s(t) ((int)(t).size())
void make_hull(hull& A)
{hull t;sort(A.begin(),A.end());for (int i=0;i<(int)A.size();i++){if (i&&A[i].x==A[i-1].x) continue;while (s(t)>1&&(t[s(t)-1]-t[s(t)-2])*(A[i]-t[s(t)-1])>0) t.pop_back();t.push_back(A[i]);}A=t;
}
void merge(const hull& A,const hull& B,hull& C)
{if (A.empty()||B.empty()) return;int i,j;hull ans;for (i=j=0;i<s(A)-1&&j<s(B)-1;){if ((A[i+1]-A[i])*(B[j+1]-B[j])>0) ans.push_back(B[j+1]-B[j]),++j;else ans.push_back(A[i+1]-A[i]),++i;}while (i<s(A)-1) ans.push_back(A[i+1]-A[i]),++i;while (j<s(B)-1) ans.push_back(B[j+1]-B[j]),++j;point las=A[0]+B[0];C.push_back(las);for (int i=0;i<s(ans);i++) C.push_back(las=las+ans[i]);
}
vector<int> E[MAXN];
struct edge{int u,v;}e[MAXM];
int head[MAXN],nxt[MAXM],cnt=1;
inline void addnode(int u,int v)
{e[++cnt]=(edge){u,v};nxt[cnt]=head[u];head[u]=cnt;
}
point val[MAXN];
int vis[MAXN],dep[MAXN],type[MAXN],n,tot;
void dfs(int u)
{vis[u]=1;if ((int)E[u].size()<=3){for (int i=0;i<(int)E[u].size();i++)if (!vis[E[u][i]])dep[E[u][i]]=dep[u]+1,dfs(E[u][i]),addnode(u,E[u][i]),addnode(E[u][i],u); return; }int cur[2]={++tot,++tot},pos=0;addnode(u,cur[0]),addnode(cur[0],u);addnode(u,cur[1]),addnode(cur[1],u);val[cur[0]]=val[cur[1]]=val[u];type[cur[0]]=type[cur[1]]=type[u];dep[cur[0]]=dep[cur[1]]=dep[u]+1;for (int i=0;i<(int)E[u].size();i++)if (!vis[E[u][i]])E[cur[pos^=1]].push_back(E[u][i]);dfs(cur[0]),dfs(cur[1]);
}
int tp;
int rt,mn,siz[MAXN];
void findrt(int u,int f,int sum)
{siz[u]=1;for (int i=head[u];i;i=nxt[i])if (!vis[i>>1]&&e[i].v!=f){findrt(e[i].v,u,sum);if (max(siz[e[i].v],sum-siz[e[i].v])<mn) mn=max(siz[e[i].v],sum-siz[e[i].v]),rt=i;siz[u]+=siz[e[i].v];}
}
hull A[2],B[2],lis;
void dfs(int u,int f,point cur,hull* A,bool up)
{if (u<=n) A[type[u]].push_back(cur=cur+val[u]);for (int i=head[u];i;i=nxt[i])if (!vis[i>>1]&&e[i].v!=f&&(up^(dep[e[i].v]>dep[u])))dfs(e[i].v,u,cur,A,up);if (up){if (u>n) A[type[u]].push_back(cur=cur+val[u]);for (int i=head[u];i;i=nxt[i])if (!vis[i>>1]&&e[i].v!=f&&dep[e[i].v]>dep[u])dfs(e[i].v,u,cur,A,0);}
}
void calc()
{A[0].clear(),A[1].clear();B[0].clear(),B[1].clear();int u=e[rt].u,v=e[rt].v;if (dep[u]>dep[v]) swap(u,v);dfs(v,0,point(0,0),A,0);dfs(u,0,point(0,0),B,1);make_hull(A[0]),make_hull(A[1]);make_hull(B[0]),make_hull(B[1]);merge(A[0],B[1],lis);merge(A[1],B[0],lis);
}
void solve(int sum)
{if (mn==INF) return;vis[rt>>1]=1;calc();int cur=rt,sz=siz[e[rt].v];mn=INF,findrt(e[cur].v,0,sz),solve(sz);mn=INF,findrt(e[cur].u,0,sum-sz),solve(sum-sz);
}
int main()
{tot=n=read();int m=read();for (int i=1;i<=n;i++) val[i].x=read();for (int i=1;i<=n;i++) val[i].y=read();for (int i=1;i<=n;i++) type[i]=read();for (int i=1;i<n;i++){int u,v;u=read(),v=read();E[u].push_back(v),E[v].push_back(u);}dep[1]=1,dfs(1);memset(vis,0,sizeof(vis));mn=INF,findrt(1,0,tot),solve(tot);make_hull(lis);
// for (int i=0;i<(int)lis.size();i++) printf("%d %d\n",lis[i].x,lis[i].y);lis.push_back(point(lis.back().x,-INF));while (m--){int k=read();int l=0,r=s(lis)-2,mid;while (l<r){mid=(l+r)>>1;if (point(1,-k)*(lis[mid+1]-lis[mid])>0) l=mid+1;else r=mid;}printf("%lld\n",(ll)k*lis[l].x+lis[l].y);}return 0;
}