题意:给一棵 nnn 个点的树,每条边需要染成黑白两种颜色中的一种。给出 mmm 个条件,每个条件给出 u,vu,vu,v,其中 uuu 是 vvv 的祖先,要求 uuu 到 vvv 的链上至少一条黑边。求方案数 模 998244353998244353998244353。
n,m≤5×105n,m\leq 5\times 10^5n,m≤5×105
这个dp想了一上午
对于树上的一个点,考虑子树内有关的所有限制,唯一不好处理的是超出子树的部分,而这部分只需要考虑超出最短的。
定义: dp(u,k)dp(u,k)dp(u,k) 表示有多少种确定 uuu 子树内的边的颜色的方案,使得所有下端点在 uuu 子树内并且尚未满足的条件 的上端点的深度最大值恰好为 kkk。如果所有上述条件都满足, k=0k=0k=0。
人话翻译:
考虑 uuu 子树内能影响到的条件,分为下列两种:
- 上端点在子树内(显然下端点就在子树内了)。如果这种条件没有满足,就永远不可能满足了,这个时候上面的定义表现为 k≥depuk\geq dep_uk≥depu,后面可以看到这部分状态是无用的。
- 上端点是 uuu 的严格祖先,下端点在 uuu 子树内,且 uuu 到下端点这段没有黑边。此时就需要上端点到 uuu 有黑边。如果这样的条件的上端点的最大深度为 kkk,那么所有条件成立当且仅当 uuu 深度为 kkk 的祖先到 uuu 有一条黑边,处理方式后述。
进行一次 dfs,每个点 uuu 先假设它没有儿子,即让 dp(u,x)=1dp(u,x)=1dp(u,x)=1,其中 xxx 为所有下端点为 uuu 的条件的上端点的最大深度。
然后依次突然加入每个儿子,设儿子为 vvv,得到新的 dp 数组为 dp′dp'dp′
考虑连接儿子的这条边是黑边还是白边。
如果是黑边,对于 uuu 来说,从 vvv 子树内来的条件就全部满足了(当然要原来有机会满足),但 uuu 原来不满足的还是不满足。即
dp(u,k)∑i=0depudp(v,i)dp(u,k)\sum_{i=0}^{dep_u}dp(v,i)dp(u,k)i=0∑depudp(v,i)
如果是白边,那么要同时满足两边的深度限制,即
∑max(i,j)=kdp(u,i)dp(v,j)\sum_{\max(i,j)=k}dp(u,i)dp(v,j)max(i,j)=k∑dp(u,i)dp(v,j)
整理一下
dp′(u,k)=dp(u,k)∑i=0depudp(v,i)+dp(u,k)∑i=0kdp(v,i)+dp(v,k)∑i=0k−1dp(u,i)dp'(u,k)=dp(u,k)\sum_{i=0}^{dep_u}dp(v,i)+dp(u,k)\sum_{i=0}^kdp(v,i)+dp(v,k)\sum_{i=0}^{k-1}dp(u,i)dp′(u,k)=dp(u,k)i=0∑depudp(v,i)+dp(u,k)i=0∑kdp(v,i)+dp(v,k)i=0∑k−1dp(u,i)
长这样子的式子都可以考虑线段树合并。
∑i=0depudp(v,i)\sum_{i=0}^{dep_u}dp(v,i)∑i=0depudp(v,i) 是个常数,先算出来。
合并的时候顺便维护左边遍历过的结点的和,如果一边的结点为空,用维护的和给另一边的结点打乘法标记。递归到叶结点了再处理求和符号的边界情况。
注意维护的这个和是合并前的,要先维护再打标记。可以在递归的时候传引用。
复杂度 O(nlogn+m)O(n\log n+m)O(nlogn+m)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#define MAXN 500005
using namespace std;
inline int read()
{int ans=0;char c=getchar();while (!isdigit(c)) c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return ans;
}
typedef long long ll;
const int MOD=998244353;
inline int add(const int& x,const int& y){return x+y>=MOD? x+y-MOD:x+y;}
int n;
int ch[MAXN<<5][2],sum[MAXN<<5],mul[MAXN<<5],cnt;
inline void update(int x){sum[x]=add(sum[ch[x][0]],sum[ch[x][1]]);}
inline void pushlzy(int x,ll v){sum[x]=sum[x]*v%MOD,mul[x]=mul[x]*v%MOD;}
inline void pushdown(int x)
{if (mul[x]!=1){if (ch[x][0]) pushlzy(ch[x][0],mul[x]);if (ch[x][1]) pushlzy(ch[x][1],mul[x]);mul[x]=1;}
}
void modify(int& x,int l,int r,int k)
{if (!x) mul[x=++cnt]=1;if (l==r) return (void)(++sum[x]);int mid=(l+r)>>1;if (k<=mid) modify(ch[x][0],l,mid,k);else modify(ch[x][1],mid+1,r,k);update(x);
}
int query(int x,int l,int r,int ql,int qr)
{if (!x) return 0;if (ql<=l&&r<=qr) return sum[x];if (qr<l||r<ql) return 0;pushdown(x);int mid=(l+r)>>1;return add(query(ch[x][0],l,mid,ql,qr),query(ch[x][1],mid+1,r,ql,qr));
}
int merge(int x,int y,int l,int r,int& xsum,int& ysum)
{if (!x&&!y) return 0;if (!x) return ysum=add(ysum,sum[y]),pushlzy(y,xsum),y;if (!y) return xsum=add(xsum,sum[x]),pushlzy(x,ysum),x;if (l==r){ysum=add(ysum,sum[y]);int t=sum[x];sum[x]=((ll)sum[x]*ysum+(ll)xsum*sum[y])%MOD;xsum=add(xsum,t);return x; } pushdown(x),pushdown(y);int mid=(l+r)>>1;ch[x][0]=merge(ch[x][0],ch[y][0],l,mid,xsum,ysum);ch[x][1]=merge(ch[x][1],ch[y][1],mid+1,r,xsum,ysum);update(x);return x;
}
vector<int> e[MAXN],lis[MAXN];
int dep[MAXN],rt[MAXN];
void dfs(int u)
{int mx=0;for (int i=0;i<(int)lis[u].size();i++) mx=max(mx,dep[lis[u][i]]);modify(rt[u],0,n,mx);for (int i=0;i<(int)e[u].size();i++)if (!dep[e[u][i]]){dep[e[u][i]]=dep[u]+1;dfs(e[u][i]);int xsum=0,ysum=query(rt[e[u][i]],0,n,0,dep[u]);rt[u]=merge(rt[u],rt[e[u][i]],0,n,xsum,ysum);}
}
int main()
{n=read();for (int i=1;i<n;i++){int u,v;u=read(),v=read();e[u].push_back(v),e[v].push_back(u);}int m=read();while (m--){int u,v;u=read(),v=read();lis[v].push_back(u);}dfs(dep[1]=1);printf("%d\n",query(rt[1],0,n,0,0));return 0;
}