正题
题目链接:https://www.luogu.com.cn/problem/P4831
题目大意
n∗mn*mn∗m的网格上放置2n2n2n个炮,要求互不能攻击。
数据满足n≤m≤2000n\leq m\leq 2000n≤m≤2000或n≤m≤105n\leq m\leq 10^5n≤m≤105且m−n≤10m-n\leq 10m−n≤10
解题思路
每行每列最多222个炮,所以模型可以转换为求有多少种方案满足:1∼n1\sim n1∼n的数字各两个填在mmm个无序2元组(可以有空),并且每个组中的数互不相同。
直接硬钢推式子很难做(好像可以推到生成函数那边去),考虑一下巧妙的方法。
设g(n,m)g(n,m)g(n,m)表示2n2n2n个格子填下1∼m1\sim m1∼m中的数字各两个的方案。这个的方案就是
g(n,m)=(2n)!∑i=0min{n,m−n}(mn−i)(m−n+i2i)2n−ig(n,m)=(2n)!\sum_{i=0}^{min\{n,m-n\}}\frac{\binom{m}{n-i}\binom{m-n+i}{2i}}{2^{n-i}}g(n,m)=(2n)!i=0∑min{n,m−n}2n−i(n−im)(2im−n+i)
表示mmm个数组中选出n−in-in−i对相同的来填,剩下的里面选出2i2i2i个单独的来填,然后交换导致重复的情况有2n−i2^{n-i}2n−i种,要除去。
这个式子就和m−nm-nm−n有很大的关系了。
将这个式子和答案联系起来,设f(n,m)f(n,m)f(n,m)表示答案,那么有
g(n,m)=∑i=0n2n−i(ni)Pmif(n−i,m−i)g(n,m)=\sum_{i=0}^n2^{n-i}\binom{n}{i}P_{m}^if(n-i,m-i)g(n,m)=i=0∑n2n−i(in)Pmif(n−i,m−i)
因为f(n,m)f(n,m)f(n,m)是不同无序二元组,(ni)Pmi\binom{n}{i}P_{m}^i(in)Pmi表示nnn对中选出iii个是相同的填入,剩下的都是不同的方案就是f(n−i,m−i)f(n-i,m-i)f(n−i,m−i),然后因为ggg是统计有序二元组的,所以2n2^n2n表示随意交换。
2nf(n,m)=∑i=0n(ni)Pmig(n−i,m−i)2^{n}f(n,m)=\sum_{i=0}^n\binom{n}{i}P_{m}^ig(n-i,m-i)2nf(n,m)=i=0∑n(in)Pmig(n−i,m−i)
f(n,m)=12n∑i=0n(ni)Pmig(n−i,m−i)f(n,m)=\frac{1}{2^n}\sum_{i=0}^n\binom{n}{i}P_{m}^ig(n-i,m-i)f(n,m)=2n1i=0∑n(in)Pmig(n−i,m−i)
然后直接计算就好了,时间复杂度O(n×min{n,m−n})O(n\times min\{n,m-n\})O(n×min{n,m−n})
Update:Update:Update:修改了反演前的公式错误
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=2e5+10,P=998244353,inv2=(P+1)/2;
ll n,m,fac[N],inv[N],pv2[N],ans;
ll power(ll x,ll b){ll ans=1;while(b){if(b&1)ans=ans*x%P;x=x*x%P;b>>=1;}return ans;
}
ll C(ll n,ll m)
{return fac[n]*inv[m]%P*inv[n-m]%P;}
ll A(ll n,ll m)
{return fac[n]*inv[n-m]%P;}
ll g(ll n,ll m){ll ans=0;for(ll i=0;i<=min(n,m-n);i++)(ans+=C(m,n-i)*C(m-n+i,2*i)%P*pv2[n-i]%P)%=P;return ans*fac[2*n]%P;
}
signed main()
{scanf("%lld%lld",&n,&m);inv[1]=1;for(ll i=2;i<N;i++)inv[i]=P-(P/i)*inv[P%i]%P;fac[0]=inv[0]=pv2[0]=1;for(ll i=1;i<N;i++)fac[i]=fac[i-1]*i%P,inv[i]=inv[i-1]*inv[i]%P,pv2[i]=pv2[i-1]*inv2%P;for(ll i=0,p=1;i<=n;i++,p=-p)(ans+=p*(g(n-i,m-i)*C(n,i)%P*A(m,i)%P))%=P;printf("%lld\n",(ans*pv2[n]%P+P)%P);return 0;
}