题意:给定一棵nnn个点的二叉树,叶子的权值输入给定且互不相同,非叶子结点iii的权值有pip_ipi的概率为儿子结点权值最大值,1−pi1-p_i1−pi的概率为最小值。求根结点取每种值的概率。模998244353998244353998244353。
n≤3×105n\leq 3\times 10^5n≤3×105
这都能线段树合并……觉了
设f(u,x)f(u,x)f(u,x)为uuu点值为xxx的概率,l,rl,rl,r为它的左右儿子
容易写出
f(u,x)=px[f(l,x)∑i=1x−1f(r,i)+f(r,x)∑i=1x−1f(l,i)]+(1−px)[f(l,x)∑i=x+1mf(r,i)+f(r,x)∑i=x+1mf(l,i)]f(u,x)=p_x[f(l,x)\sum_{i=1}^{x-1}f(r,i)+f(r,x)\sum_{i=1}^{x-1}f(l,i)]+(1-p_x)[f(l,x)\sum_{i=x+1}^mf(r,i)+f(r,x)\sum_{i=x+1}^mf(l,i)]f(u,x)=px[f(l,x)i=1∑x−1f(r,i)+f(r,x)i=1∑x−1f(l,i)]+(1−px)[f(l,x)i=x+1∑mf(r,i)+f(r,x)i=x+1∑mf(l,i)]
考虑线段树合并
设当前合并的区间是[L,R][L,R][L,R],在递归的时候顺便维护两个线段树结点[1,L−1][1,L-1][1,L−1]和[R+1,m][R+1,m][R+1,m]的和,乘到f(l,x)f(l,x)f(l,x)和f(r,x)f(r,x)f(r,x)上面,维护一个乘法标记。
文字不太好讲清楚,建议直接看代码。
复杂度O(nlogn)O(n\log n)O(nlogn)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
#define MAXN 300005
using namespace std;
inline int read()
{int ans=0;char c=getchar();while (!isdigit(c)) c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int qpow(int a,int p)
{int ans=1;while (p){if (p&1) ans=(ll)ans*a%MOD;a=(ll)a*a%MOD;p>>=1;}return ans;
}
namespace SGT
{int ch[MAXN<<5][2],sum[MAXN<<5],mul[MAXN<<5],cnt;inline void update(int x){sum[x]=(sum[ch[x][0]]+sum[ch[x][1]])%MOD;}inline void pushmul(int x,int v){sum[x]=(ll)sum[x]*v%MOD,mul[x]=(ll)mul[x]*v%MOD;}inline void pushdown(int x){if (mul[x]!=1){pushmul(ch[x][0],mul[x]),pushmul(ch[x][1],mul[x]);mul[x]=1;}}inline int newnode(){return ++cnt,sum[cnt]=mul[cnt]=1,cnt;}void insert(int& x,int l,int r,int k){x=newnode();if (l==r) return;int mid=(l+r)>>1;if (k<=mid) insert(ch[x][0],l,mid,k);else insert(ch[x][1],mid+1,r,k);}int merge(int x,int y,int l,int r,int xmul,int ymul,int v){if (!x&&!y) return 0;if (!x) return pushmul(y,ymul),y;if (!y) return pushmul(x,xmul),x;int mid=(l+r)>>1;pushdown(x),pushdown(y);int xl=sum[ch[x][0]],xr=sum[ch[x][1]],yl=sum[ch[y][0]],yr=sum[ch[y][1]];ch[x][0]=merge(ch[x][0],ch[y][0],l,mid,(xmul+(MOD+1ll-v)*yr)%MOD,(ymul+(MOD+1ll-v)*xr)%MOD,v);ch[x][1]=merge(ch[x][1],ch[y][1],mid+1,r,(xmul+(ll)v*yl)%MOD,(ymul+(ll)v*xl)%MOD,v);return update(x),x;}void getans(int x,int l,int r,int* &ans){if (l==r) return (void)(*(ans++)=sum[x]);pushdown(x);int mid=(l+r)>>1;getans(ch[x][0],l,mid,ans),getans(ch[x][1],mid+1,r,ans);}
}
using SGT::insert;
using SGT::merge;
using SGT::getans;
int rt[MAXN],ch[MAXN][2],p[MAXN],v[MAXN],m;
void dfs(int u)
{if (!ch[u][0]) return insert(rt[u],1,m,p[u]);dfs(ch[u][0]);if (!ch[u][1]) return (void)(rt[u]=rt[ch[u][0]]);dfs(ch[u][1]);rt[u]=merge(rt[ch[u][0]],rt[ch[u][1]],1,m,0,0,p[u]);
}
int ans[MAXN];
int main()
{int n=read();for (int i=1;i<=n;i++){int f=read();if (!f) continue;if (!ch[f][0]) ch[f][0]=i;else ch[f][1]=i;}int t=qpow(10000,MOD-2);for (int i=1;i<=n;i++){p[i]=read();if (ch[i][0]) p[i]=(ll)p[i]*t%MOD;else v[++m]=p[i];}sort(v+1,v+m+1);for (int i=1;i<=n;i++)if (!ch[i][0])p[i]=lower_bound(v+1,v+m+1,p[i])-v;dfs(1);int* p=ans+1;getans(rt[1],1,m,p);int res=0;for (int i=1;i<=m;i++) res=(res+(ll)i*v[i]%MOD*ans[i]%MOD*ans[i])%MOD;printf("%d\n",res);return 0;
}