BZOJ4589. Hard Nim
题意:
Claris和NanoApe在玩石子游戏,他们有n堆石子,规则如下:
- Claris和NanoApe两个人轮流拿石子,Claris先拿。
- 每次只能从一堆中取若干个,可将一堆全取走,但不可不取,拿到最后1颗石子的人获胜。
不同的初始局面,决定了最终的获胜者,有些局面下先拿的Claris会赢,其余的局面Claris会负。
Claris很好奇,如果这n堆石子满足每堆石子的初始数量是不超过m的质数,而且他们都会按照最优策略玩游戏,那么NanoApe能获胜的局面有多少种。
由于答案可能很大,你只需要给出答案对10^9+7取模的值。
题解:
首先要知道nim先手获胜条件是所有堆的数量异或为0
那么问题就抽象为:n个数,每个数取值范围是[2,m]中的质数,可以取重,问一共有多少种方案,使得这n个数异或为0
对于每一个2~m的质数p,我们都可以取,用数组b来存就是b[p]=1,质数都被标记为1
CkC_kCk表示异或和为k的方案数有多少种,b[i]=1说明第i位是质数,否则不是
有:Ck=∑i1⊕i2⊕.....⊕in=kbi1×bi2×....×binC_k=\sum_{i_1⊕i_2⊕.....⊕i_n=k}b_{i_1}×b_{i_2}×....×b_{i_n}Ck=∑i1⊕i2⊕.....⊕in=kbi1×bi2×....×bin
因为我们要求异或和为0,所以k等于0
式子就是:
求n个序列,其中i1⊕i2⊕.....⊕in=0i_1⊕i_2⊕.....⊕i_n=0i1⊕i2⊕.....⊕in=0
C0=∑i1⊕i2⊕.....⊕in=0bi1×bi2×....×binC_0=\sum_{i_1⊕i_2⊕.....⊕i_n=0}b_{i_1}×b_{i_2}×....×b_{i_n}C0=∑i1⊕i2⊕.....⊕in=0bi1×bi2×....×bin
很明显这就是FWT能解决的问题
但是本题中n很大,n<=1e9,如果直接乘一定会超时,我们仔细观察,这个式子就是n个b序列,而所有b序列都是一样的,相当于其实是n个b相乘,也就是bnb^nbn,所以可以通过快速相乘,结合FWT模板理解一下就懂了
代码:
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <map>
#include <queue>
using namespace std;
typedef long long ll;
typedef int itn;
typedef pair<int, int>PII;
const int N = 5e5 + 7, mod = 1e9 + 7;
const double PI = acos(-1.0);int n, m, limit, t;
int a[N], b[N];
bool vis[N];
int primes[N], cnt;int qpow(int a, int b)
{int res = 1;while(b) {if(b & 1) res = 1ll * res * a % mod;a = 1ll * a * a % mod;b >>= 1;}return res;
}int inv2 = qpow(2, mod - 2);void init(int n)
{for(int i = 2; i <= n; ++ i) {if(vis[i] == 0) primes[ ++ cnt] = i;for(int j = 1; j <= cnt && i * primes[j] <= n; ++ j) {vis[i * primes[j]] = true;if(i % primes[j] == 0) break;}}
}void XOR(int *a, int n, int type = 1)
{for(int o = 2; o <= n; o <<= 1) {for(int i = 0, k = o >> 1; i < n; i += o) {for(int j = 0; j < k; ++j) {int X = a[i + j];int Y = a[i + j + k];a[i + j] = (1ll * X + Y) % mod;a[i + j + k] = ((1ll * X - Y) % mod + mod) % mod;if(type == -1) {a[i + j] = (1ll * a[i + j] * inv2) % mod;a[i + j + k] = (1ll * a[i + j + k] * inv2) % mod;}}}}
}void solve()
{memset(b, 0, sizeof b);memset(a, 0, sizeof a);for(int i = 2; i <= m; ++ i)if(vis[i] == 0) b[i] = 1;limit = 1;while(limit <= m) limit <<= 1;XOR(b, limit);
// for(int i=0;i<=limit;i++){
// cout<<"a[i]="<<a[i]<<endl;
// }
// for(int i=0;i<=limit;i++){
// cout<<"b[i]="<<b[i]<<endl;
// }for(int i = 0; i <= limit; ++ i) {b[i] = 1ll * qpow(b[i], n) % mod;}XOR(b, limit, -1);printf("%d\n", b[0]);return ;
}int main()
{
// freopen("data.in", "r", stdin);init(N - 7);while(scanf("%d%d", &n, &m) != EOF) {solve();}
}