文章目录
- 题目
- 题解
- code1(NTT)
- code2(EGF+卷积)
题目
大中锋的学院要组织学生参观博物馆,要求学生们在博物馆中排成一队进行参观。他的同学可以分为四类:一部分最喜欢唱、一部分最喜欢跳、一部分最喜欢rap,还有一部分最喜欢篮球。如果队列中k,k+1,k+2,k+3k,k + 1,k + 2,k + 3k,k+1,k+2,k+3位置上的同学依次,最喜欢唱、最喜欢跳、最喜欢rap、最喜欢篮球,那么他们就会聚在一起讨论蔡徐坤。大中锋不希望这种事情发生,因为这会使得队伍显得很乱。大中锋想知道有多少种排队的方法,不会有学生聚在一起讨论蔡徐坤。两个学生队伍被认为是不同的,当且仅当两个队伍中至少有一个位置上的学生的喜好不同。由于合法的队伍可能会有很多种,种类数对998244353取模。
输入格式
输入数据只有一行。每行5个整数,第一个整数n,代表大中锋的学院要组织多少人去参观博物馆。接下来四个整数a、b、c、d,分别代表学生中最喜欢唱的人数、最喜欢跳的人数、最喜欢rap的人数和最喜欢篮球的人数。保证a+b+c+d≥na+b+c+d \ge na+b+c+d≥n。
输出格式
每组数据输出一个整数,代表你可以安排出多少种不同的学生队伍,使得队伍中没有学生聚在一起讨论蔡徐坤。结果对998244353998244353998244353取模。
输入输出样例
输入
4 4 3 2 1
输出
174
输入
996 208 221 132 442
输出
442572391
说明/提示
对于20%的数据,有n=a=b=c=d≤500n=a=b=c=d\le500n=a=b=c=d≤500
对于100%的数据,有n≤1000n \le 1000n≤1000 , a,b,c,d≤500a, b, c, d \le 500a,b,c,d≤500
题解
考虑枚举有iii堆人在讨论cxk小姐姐,那么被cxk迷倒的人就有4i4i4i个,且每一堆人是挨在一起的,
所以我们可以把这4i4i4i个人缩成iii个群体,每一个群体可以展开成为444人
那么现在就只用考虑剩下的n−3in-3in−3i人,放置的方案数就是Cn−3iiC_{n-3i}^iCn−3ii
简单用整体法证明一下:
在n−3in-3in−3i人中,要选iii个拿来讨论,n−4in-4in−4i个不讨论,方案数是对应的Cn−3ii=Cn−3in−4iC_{n-3i}^i=C_{n-3i}^{n-4i}Cn−3ii=Cn−3in−4i
也可以理解为在n−3in-3in−3i个空格里面找iii个起点开始讨论小姐姐
接着枚举好了iii及它们可能出现的地方Cn−3iiC_{n-3i}^iCn−3ii,我们就要去考虑剩下n−4in-4in−4i人的排列,可以乱来
(n−4i)!(n-4i)!(n−4i)!
但是attentionattentionattention,如果看了我的上一篇指数型生成函数专练,就会与这里有一点相通
题目说的方案数不同的要求是至少有一个位置上的学生的喜好不同,而不是学生本人不同
所以我们要除掉喜好篮球,跳舞,唱歌,rap本身内部的乱拍,因为从爱好上来看是看不出来排列不同的
设a,b,c,da,b,c,da,b,c,d分别代表学生中最喜欢唱的人数、最喜欢跳的人数、最喜欢rap的人数和最喜欢篮球的总人数,我们考虑的是剩下的不讨论cxk姐姐的学生乱排,所以要剪掉外层所枚举的去参与讨论的人数
(n−4i)!(a−i)!(n−i)!(c−i)!(d−i)!\frac{(n-4i)!}{(a-i)!(n-i)!(c-i)!(d-i)!}(a−i)!(n−i)!(c−i)!(d−i)!(n−4i)!
最后写出每种喜好的生成函数,以喜欢篮球的人为例:
∑i=0cxii!\sum_{i=0}^{c}\frac{x^i}{i!}i=0∑ci!xi
把四种爱好卷起来,卷出最后乘积的第n−4in-4in−4i项就是我们需要除掉的东西,乘上(n−4i)!(n-4i)!(n−4i)!就是真正的乱排数
因为带取模,所以是用NTTNTTNTT跑,除的话就要变成乘逆元,所以我们可以在卷的时候就直接卷逆元
但是我们又发现虽然假设的是iii堆人在讨论,但是统计答案的时候却把≥i\ge i≥i堆人讨论的情况都统计了
而且当统计i=1i=1i=1时,会算两遍至少两堆人讨论的方案,三遍至少三堆人讨论的方案…在统计i=2i=2i=2时,会算三遍至少三堆人讨论的方案…
当统计iii的方案时,会多算CjiC_j^iCji次至少jjj堆人讨论的方案,所以我们可以用容斥来解决
总结一下答案应该是
∑i=0min(−1)i∗Cn−3ii∗(n−4i)!(a−i)!(n−i)!(c−i)!(d−i)!\sum_{i=0}^{min}(-1)^i*C_{n-3i}^i*\frac{(n-4i)!}{(a-i)!(n-i)!(c-i)!(d-i)!}i=0∑min(−1)i∗Cn−3ii∗(a−i)!(n−i)!(c−i)!(d−i)!(n−4i)!
code1(NTT)
#include <cstdio>
#include <iostream>
using namespace std;
#define mod 998244353
#define LL long long
#define MAXN 10005
int n, anum, bnum, cnum, dnum;
LL pi, ni, result;
LL a[MAXN], b[MAXN], c[MAXN], d[MAXN], rev[MAXN], fac[MAXN], Invfac[MAXN];LL qkpow ( LL x, LL y ) {LL ans = 1;while ( y ) {if ( y & 1 )ans = ans * x % mod;x = x * x % mod;y >>= 1;}return ans;
}LL C ( int n, int m ) {return fac[n] * Invfac[m] % mod * Invfac[n - m] % mod;
}void NTT ( LL *c, LL limit, LL f ) {for ( LL i = 0;i < limit;i ++ )if ( i < rev[i] )swap ( c[i], c[rev[i]] );for ( LL i = 1;i < limit;i <<= 1 ) {LL omega = qkpow ( f == 1 ? pi : ni, ( mod - 1 ) / ( i << 1 ) );for ( LL j = 0;j < limit;j += ( i << 1 ) ) {LL w = 1;for ( LL k = 0;k < i;k ++, w = w * omega % mod ) {LL x = c[k + j], y = w * c[i + j + k] % mod;c[k + j] = ( x + y ) % mod;c[k + j + i] = ( x - y + mod ) % mod;}}}LL inv = qkpow ( limit, mod - 2 );if ( f == -1 )for ( LL i = 0;i < limit;i ++ )c[i] = c[i] * inv % mod;
}void init () {Invfac[0] = fac[0] = 1;for ( int i = 1;i <= n;i ++ )fac[i] = fac[i - 1] * i % mod;Invfac[n] = qkpow ( fac[n], mod - 2 );for ( int i = n - 1;i;i -- )Invfac[i] = Invfac[i + 1] * ( i + 1 ) % mod;
}LL solve ( int n, int A, int B, int C, int D ) {LL len = 1, l = 0;while ( len < ( ( A + B + C + D ) << 1 ) ) {len <<= 1;l ++;}for ( LL i = 0;i < len;i ++ )rev[i] = ( rev[i >> 1] >> 1 ) | ( ( i & 1 ) << ( l - 1 ) );for ( int i = 0;i < len;i ++ )a[i] = ( i <= A ? Invfac[i] : 0 );for ( int i = 0;i < len;i ++ )b[i] = ( i <= B ? Invfac[i] : 0 );for ( int i = 0;i < len;i ++ )c[i] = ( i <= C ? Invfac[i] : 0 );for ( int i = 0;i < len;i ++ )d[i] = ( i <= D ? Invfac[i] : 0 );NTT ( a, len, 1 );NTT ( b, len, 1 );NTT ( c, len, 1 );NTT ( d, len, 1 );for ( int i = 0;i < len;i ++ )a[i] = a[i] * b[i] % mod * c[i] % mod * d[i] % mod;NTT ( a, len, -1 );return a[n] * fac[n] % mod;
}int main() {pi = 3;ni = mod / pi + 1;scanf ( "%d %d %d %d %d", &n, &anum, &bnum, &cnum, &dnum );init();int k = min ( min ( min ( min ( anum, bnum ), cnum ), dnum ), n / 4 );anum -= k;bnum -= k;cnum -= k;dnum -= k;result = 0;for ( k;k >= 0;k -- ) {LL tmp = C ( n - 3 * k, k ) % mod * solve ( n - 4 * k, anum, bnum, cnum, dnum ) % mod;anum ++;bnum ++;cnum ++;dnum ++;result = ( result + ( ( k & 1 ) ? mod - tmp : tmp ) ) % mod;}printf ( "%lld\n", result );return 0;
}
code2(EGF+卷积)
#include <cstdio>
#include <iostream>
using namespace std;
#define mod 998244353
#define LL long long
#define MAXN 1005
int n, a, b, c, d;
LL result;
LL foldAB[MAXN], foldCD[MAXN], fac[MAXN], Invfac[MAXN];LL qkpow ( LL x, LL y ) {LL ans = 1;while ( y ) {if ( y & 1 )ans = ans * x % mod;x = x * x % mod;y >>= 1;}return ans;
}LL C ( int n, int m ) {return fac[n] * Invfac[m] % mod * Invfac[n - m] % mod;
}void Fac () {Invfac[0] = fac[0] = 1;for ( int i = 1;i <= n;i ++ )fac[i] = fac[i - 1] * i % mod;Invfac[n] = qkpow ( fac[n], mod - 2 );for ( int i = n - 1;i;i -- )Invfac[i] = Invfac[i + 1] * ( i + 1 ) % mod;
}int main() {scanf ( "%d %d %d %d %d", &n, &a, &b, &c, &d );Fac();int k = min ( min ( min ( min ( a, b ), c ), d), n / 4 );a -= k;b -= k;c -= k;d -= k;for ( int i = 0;i <= a;i ++ )for ( int j = 0;j <= b;j ++ )foldAB[i + j] = ( foldAB[i + j] + Invfac[i] * Invfac[j] % mod ) % mod;for ( int i = 0;i <= c;i ++ )for ( int j = 0;j <= d;j ++ )foldCD[i + j] = ( foldCD[i + j] + Invfac[i] * Invfac[j] % mod ) % mod;result = 0;for ( k;k >= 0;k -- ) {LL ans = 0;int tp = n - 4 * k;for ( int i = 0;i <= tp;i ++ )ans = ( ans + foldAB[i] * foldCD[tp - i] % mod ) % mod;ans = ans * fac[tp] % mod * C ( n - k * 3, k ) % mod;result = ( result + ( ( k & 1 ) ? mod - ans : ans ) ) % mod;a ++;b ++;c ++;d ++;for ( int i = 0;i <= a;i ++ )foldAB[i + b] = ( foldAB[i + b] + Invfac[i] * Invfac[b] % mod ) % mod;for ( int i = 0;i <= b;i ++ )foldAB[i + a] = ( foldAB[i + a] + Invfac[i] * Invfac[a] % mod ) % mod;foldAB[a + b] = ( foldAB[a + b] - Invfac[a] * Invfac[b] % mod + mod ) % mod;for ( int i = 0;i <= c;i ++ )foldCD[i + d] = ( foldCD[i + d] + Invfac[i] * Invfac[d] % mod ) % mod;for ( int i = 0;i <= d;i ++ )foldCD[i + c] = ( foldCD[i + c] + Invfac[i] * Invfac[c] % mod ) % mod;foldCD[c + d] = ( foldCD[c + d] - Invfac[c] * Invfac[d] % mod ) % mod;}printf ( "%lld\n", result );return 0;
}
代码或者思路有任何问题欢迎评论