莓良心
- problem
- solution
- code
problem
莓在执行任务时,收集到了 nnn 份岩浆能源,其中第 iii 份的能量值是 wiw_iwi ,她
决定将它们分成恰好 kkk 组带回基地,每一组都要有至少 111 份能源。
每一组能源会对运输设备产生负荷值,若该组有 xxx 份能源,这 xxx 份能源能
量值之和为 yyy , 则产生的负荷值为 x×yx × yx×y 。
每种分组方案产生的负荷是每一组能源产生的负荷值总和,莓想知道所有可
能的分组方案产生的负荷之和对 998244353
取模的结果 。
k≤n≤1e6k\le n\le 1e6k≤n≤1e6。
solution
将 nnn 个不同物品分成 kkk 组,球有标号,盒子无标号,显然是第二类斯特林数。
考场上倒是把公式都推出来了,但是没反应过来是第二类斯特林。。。但这不重要,主要是没想到优化计算方式。
暴力求第二类斯特林数:fi,j=fi−1,j∗j+fi−1,j−1f_{i,j}=f_{i-1,j}*j+f_{i-1,j-1}fi,j=fi−1,j∗j+fi−1,j−1。
容斥求第二类斯特林数:∑i=0k(−1)i(ki)(k−i)nk!\frac{\sum_{i=0}^k(-1)^i\binom{k}{i}(k-i)^n}{k!}k!∑i=0k(−1)i(ik)(k−i)n。
考虑将每个数单独拆开计算贡献,暴力枚举其所在组的个数,然后乘以剩下数分成 k−1k-1k−1 组的方案数,时间复杂度是 O(n2)O(n^2)O(n2) 的。
我就卡在这里优化不了了。
考虑若 u,vu,vu,v 分在一组,则对答案有 wu+wvw_u+w_vwu+wv 的贡献,即 ans=∑u=1nwu⋅{nk}+∑u≠v(wu+wv)⋅{n−1k}ans=\sum_{u=1}^nw_u·\left\{\begin{matrix}n\\k\end{matrix}\right\}+\sum_{u\neq v}(w_u+w_v)·\left\{\begin{matrix}n-1\\k\end{matrix}\right\}ans=∑u=1nwu⋅{nk}+∑u=v(wu+wv)⋅{n−1k}
⇒ans=∑wi⋅({nk}+(n−1){n−1k})\Rightarrow ans=\sum_{w_i}·\Bigg(\left\{\begin{matrix}n\\k\end{matrix}\right\}+(n-1)\left\{\begin{matrix}n-1\\k\end{matrix}\right\}\Bigg)⇒ans=∑wi⋅({nk}+(n−1){n−1k})
线性筛预处理 (k−i)n(k-i)^n(k−i)n,时间复杂度为 O(n)O(n)O(n)。
code
#include <cstdio>
#include <iostream>
using namespace std;
#define Pair pair < int, int >
#define int long long
#define mod 998244353
#define maxn 1000005
int n, k, cnt;
int fac[maxn], inv[maxn], Pow1[maxn], Pow2[maxn], prime[maxn];
bool vis[maxn];int qkpow( int x, int y ) {int ans = 1;while( y ) {if( y & 1 ) ans = ans * x % mod;x = x * x % mod;y >>= 1;}return ans;
}void init() {fac[0] = inv[0] = 1;for( int i = 1;i <= n;i ++ ) fac[i] = fac[i - 1] * i % mod;inv[n] = qkpow( fac[n], mod - 2 );for( int i = n - 1;i;i -- ) inv[i] = inv[i + 1] * ( i + 1 ) % mod;
}void sieve() {Pow1[1] = Pow2[1] = 1;for( int i = 2;i <= n;i ++ ) {if( ! vis[i] ) {prime[++ cnt] = i;Pow1[i] = qkpow( i, n );Pow2[i] = qkpow( i, n - 1 );}for( int j = 1;j <= cnt and i * prime[j] <= n;j ++ ) {vis[i * prime[j]] = 1;Pow1[i * prime[j]] = Pow1[i] * Pow1[prime[j]] % mod;Pow2[i * prime[j]] = Pow2[i] * Pow2[prime[j]] % mod;if( i % prime[j] == 0 ) break;}}
}int C( int n, int m ) { return fac[n] * inv[m] % mod * inv[n - m] % mod; }Pair calc() {int ans1 = 0, ans2 = 0;for( int i = 0;i <= k;i ++ ) {int t = ( i & 1 ) ? -1 : 1;ans1 = ( ans1 + t * C( k, i ) * Pow1[k - i] % mod ) % mod;ans2 = ( ans2 + t * C( k, i ) * Pow2[k - i] % mod ) % mod;}ans1 = ans1 * inv[k] % mod;ans2 = ans2 * inv[k] % mod;return { ( ans1 + mod ) % mod, ( ans2 + mod ) % mod };
}signed main() {freopen( "ichigo.in", "r", stdin );freopen( "ichigo.out", "w", stdout );scanf( "%lld %lld", &n, &k );int sum = 0;for( int i = 1, w;i <= n;i ++ ) scanf( "%lld", &w ), sum = ( sum + w ) % mod;init();sieve();Pair ans = calc();printf( "%lld\n", ( ans.first + ( n - 1 ) * ans.second ) % mod * sum % mod );return 0;
}