前言
一道看起来很毒瘤但其实还算小清新的题?
理解后感觉其实并没有那么难。
暴力分非常足,好评。
奇妙的线段树合并技巧增加了。
解析
解法1
你是怎么手玩的样例一?
大部分(比如我)都是容斥吧。
把手玩的方法搬到代码上就得到了一个从数据范围来看出题人也很想让我们写的 O(2mpoly)O(2^mpoly)O(2mpoly) 做法。
仔细想想可以树剖线段树维护没有被覆盖的边的个数,dfs的时候顺便修改,时间复杂度 O(2mlog2n)O(2^m\log ^2n)O(2mlog2n)。
期望得分 32-40 分。
不太懂为什么网上都是写的 O(2mmlog2n)O(2^mm\log ^2n)O(2mmlog2n) 啊
解法2
你不会真的去写解法1了吧。
这个题的形式一看就非常dp。
设计 fx,df_{x,d}fx,d 表示节点 xxx 的子树处理完毕,没有解决的返祖链的祖先的最大深度为 ddd 的方案数。
容易用类似树上背包的形式转移,时间复杂度 O(nmin(n,m))O(n\min(n,m))O(nmin(n,m)),期望得分 646464 分。
然而我写的破玩意做不了完全二叉树,结果就只有 565656 分了(悲)。
(前缀和优化的做法写起来更加好写,而且才能引入后面的正解)
解法3
观察dp转移,发现转移非常有规律,似乎可以维护个线段树之类的东西。
然后垃圾的我根本想不到线段树合并,只能想到重链剖分后每个重链维护个线段树,下标只存子树内有的深度限制(类似于离散化),这样所有线段树的叶子个数之和是 O(nlogn)O(n\log n)O(nlogn),重链处理完后暴力向链头父亲合并,大概是 O(nlog2n)O(n\log^2n)O(nlog2n) 的。
期望得分 88 分。
实现起来非常谔谔,所以只是胡了胡,并没有写。
如果假了的话轻喷(逃
华丽的分割线。
然后我的水平就卡在这个地方了。
接下来就是贺题解环节。
解法4
把dp转移写成前缀和形式:
dpx,i=dpx,i′∑j=0depxdps,j+dpx,i′∑j=0idps,j+dps,i∑j=0i−1dpx,jdp_{x,i}=dp'_{x,i}\sum_{j=0}^{dep_x}dp_{s,j}+dp'_{x,i}\sum_{j=0}^{i}dp_{s,j}+dp_{s,i}\sum_{j=0}^{i-1}dp_{x,j}dpx,i=dpx,i′j=0∑depxdps,j+dpx,i′j=0∑idps,j+dps,ij=0∑i−1dpx,j
也就是:
dpx,i=dpx,i′(sums,depx+sums,i)+dps,isumx,i−1dp_{x,i}=dp'_{x,i}(sum_{s,dep_x}+sum_{s,i})+dp_{s,i}sum_{x,i-1}dpx,i=dpx,i′(sums,depx+sums,i)+dps,isumx,i−1
注意到除了 sums,depxsum_{s,dep_x}sums,depx 是一个常量,其它的东西都与下标密切相关。
考虑线段树合并。
令 s1=sums,depx+sums,i,s2=sumx,i−1s1=sum_{s,dep_x}+sum_{s,i},s2=sum_{x,i-1}s1=sums,depx+sums,i,s2=sumx,i−1。
在合并时先递归左子树,在递归右子树,不断更新 s1,s2s1,s2s1,s2 即可。
我觉得你看看代码可能比听我讲更容易理解。
ll s1,s2;
int merge(int x,int y,int l,int r){if(!x&&!y) return 0;else if(!x||!y){if(x){add(s2,tr[x].sum);Mul(x,s1);return x;}else{add(s1,tr[y].sum);Mul(y,s2);return y;}}else if(l==r){int now=New();int a=tr[x].sum,b=tr[y].sum;add(s1,b);tr[now].sum=(s1*a+s2*b)%mod;;add(s2,a);return now;}pushdown(x);pushdown(y);int now=New(); tr[now].ls=merge(tr[x].ls,tr[y].ls,l,mid);tr[now].rs=merge(tr[x].rs,tr[y].rs,mid+1,r);pushup(now);return now;
}
然后这道题就做完啦。
时空复杂度 O(nlogn)O(n\log n)O(nlogn),期望得分 100 分。
代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define ok debug("ok\n")
inline ll read(){ll x(0),f(1);char c=getchar();while(!isdigit(c)) {if(c=='-')f=-1;c=getchar();}while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}return x*f;
}const int N=5e5+100;
const int inf=1e9+100;
const int mod=998244353;
const bool Flag=0;
#define add(x,y) (((x)+=(y))>=mod&&((x)-=mod))int n,m;#define mid ((l+r)>>1)
struct node{int ls,rs;ll sum,mul=1;
}tr[N*40];
int tot;
vector<int>e[N];
inline int New(){tr[++tot].mul=1;return tot;
}
inline void Mul(int x,int w){if(x){tr[x].sum=tr[x].sum*w%mod;tr[x].mul=tr[x].mul*w%mod;}return;
}
inline void pushdown(int x){if(tr[x].mul!=1){Mul(tr[x].ls,tr[x].mul);Mul(tr[x].rs,tr[x].mul);tr[x].mul=1;}return;
}
inline void pushup(int k){tr[k].sum=(tr[tr[k].ls].sum+tr[tr[k].rs].sum)%mod;
}
ll ask(int k,int l,int r,int x,int y){if(!k) return 0;if(x<=l&&r<=y) return tr[k].sum;pushdown(k);ll res(0);if(x<=mid) add(res,ask(tr[k].ls,l,mid,x,y));if(y>mid) add(res,ask(tr[k].rs,mid+1,r,x,y));return res;
}
ll s1,s2;
int merge(int x,int y,int l,int r){//if(l>lim) return 0;if(!x&&!y) return 0;else if(!x||!y){if(x){add(s2,tr[x].sum);Mul(x,s1);return x;}else{add(s1,tr[y].sum);Mul(y,s2);//printf("(%d %d) s2=%lld sum=%lld\n",l,r,s2,tr[y].sum);return y;}}else if(l==r){int now=New();int a=tr[x].sum,b=tr[y].sum;add(s1,b);tr[now].sum=(s1*a+s2*b)%mod;;add(s2,a);return now;}pushdown(x);pushdown(y);int now=New(); tr[now].ls=merge(tr[x].ls,tr[y].ls,l,mid);tr[now].rs=merge(tr[x].rs,tr[y].rs,mid+1,r);pushup(now);return now;
}
void ins(int &k,int l,int r,int p){if(!k) k=New();if(l==r){tr[k].sum++;return;}pushdown(k);if(p<=mid) ins(tr[k].ls,l,mid,p);else ins(tr[k].rs,mid+1,r,p);pushup(k);
}
void print(int k,int l,int r){if(!k) return;if(l==r){printf("d=%d dp=%lld\n",l,tr[k].sum);return;}pushdown(k);print(tr[k].ls,l,mid);print(tr[k].rs,mid+1,r);return;
}
int rt[N];int dep[N],d[N];
void init(int x,int fa){dep[x]=dep[fa]+1;for(int to:e[x]){if(to==fa) continue;init(to,x);}return;
}
void dfs(int x,int fa){for(int to:e[x]){if(to==fa) continue;dfs(to,x);}ins(rt[x],0,n,d[x]);for(int to:e[x]){if(to==fa) continue;s1=ask(rt[to],0,n,0,dep[x]);s2=0;rt[x]=merge(rt[x],rt[to],0,n);//printf("----------%d->%d\n",x,to);//print(rt[x],0,n);//puts("");}return;
}signed main(){
#ifndef ONLINE_JUDGEfreopen("a.in","r",stdin);freopen("a.out","w",stdout);
#endifn=read();for(int i=1;i<n;i++){int x=read(),y=read();e[x].push_back(y);e[y].push_back(x);}init(1,0);m=read();for(int i=1;i<=m;i++){int u=read(),v=read();d[v]=max(d[v],dep[u]);}dfs(1,0);printf("%lld\n",ask(rt[1],0,n,0,0));return 0;
}