NTT经典例题
CCPC-Winter-Camp-day6-A——NTT经典例题
对于上面的式子,若想求出每个 $i$ 对应的值,可以用卷积(NTT)一次性求出:含 $j!$ 的项与含 $(i-j)!$ 的项相乘时,两者的下标之和为 $j+(i-j)=i$,恰好是卷积的形式。
补充:二次剩余(Cipolla 算法)模板
P5491 【模板】二次剩余 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
//#include<bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<numeric>
#include<cstring>//rfind("string"),s.find(string,begin)!=s.npos,find_first _of(),find_last_of()
#include<string>//to_string(value),s.substr(int begin, int length);
#include<cstdio>
#include<cmath>
#include<vector>//res.erase(unique(res.begin(), res.end()), res.end()),resize(n)//size of vector,vector<int>().swap(at[mx])
#include<queue>//priority_queue(big) /priority_queue<int, vector<int>, greater<int>> q(small)
#include<stack>
#include<map>
#include<set>
#include<unordered_map>
#include<unordered_set>
#include<bitset>
#include<random>
#include<chrono>
//#include<ext/pb_ds/assoc_container.hpp>//gp_hash_table
//#include<ext/pb_ds/hash_policy.hpp>
//using namespace __gnu_pbds;
// 64-bit Mersenne Twister seeded from the steady clock.
std::mt19937_64 rnd(std::chrono::steady_clock::now().time_since_epoch().count());
using namespace std;
// Contest shorthand: every `int` below is really a 64-bit long long.
// (For an even wider range GCC offers __int128, max 2^127-1.)
#define int long long
// Pair of (64-bit) ints.
#define PII pair<int,int>
// Element of the extension field F_p[sqrt(w)]: represents x + y*sqrt(w),
// where w is the global set by cipolla().
struct num {
    int x;  // real part
    int y;  // coefficient of the imaginary unit sqrt(w)
};

// t: number of test cases; n, p: current query; w: the non-residue a^2 - n
// chosen by cipolla() and read by mul().
int t, w, n, p;

// Multiply two elements of F_p[sqrt(w)] modulo p:
//   (a.x + a.y*sqrt(w)) * (b.x + b.y*sqrt(w))
//     = (a.x*b.x + a.y*b.y*w) + (a.x*b.y + a.y*b.x) * sqrt(w)
// Components and w are expected to lie in [0, p); every intermediate product
// is then below p^2, which fits in 64 bits for p <= ~3e9.
num mul(num a, num b, int p) {
    num res;
    res.x = ((a.x * b.x % p + a.y * b.y % p * w % p) % p + p) % p;
    res.y = ((a.x * b.y % p + a.y * b.x % p) % p + p) % p;
    return res;
}
// Modular exponentiation over the integers: returns a^b mod p (b >= 0).
// The base is reduced into [0, p) first so that a*a cannot overflow even if
// the caller passes an unreduced or negative base.
int qpow_r(int a, int b, int p) {
    int res = 1 % p;
    a %= p;
    if (a < 0) a += p;
    while (b) {
        if (b & 1) res = res * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return res;
}
// Fast exponentiation in F_p[sqrt(w)]: computes a^b with mul() and returns
// only the real part. For Cipolla's x = a + sqrt(w) raised to (p+1)/2 the
// imaginary part is zero, so the real part is the square root.
int qpow_i(num a, int b, int p) {
    num res = {1, 0};
    while (b) {
        if (b & 1) res = mul(res, a, p);
        a = mul(a, a, p);
        b >>= 1;
    }
    return res.x % p;
}
int cipolla(int n, int p) {n %= p;if (qpow_r(n, (p - 1) / 2, p) == -1 + p) return -1;// 据欧拉准则判定是否有解 int a;while (1) {// 找出一个符合条件的aa = rand() % p;w = (((a * a) % p - n) % p + p) % p;// w = a^2 - n,虚数单位的平方if (qpow_r(w, (p - 1) / 2, p) == -1 + p) break;}num x = { a,1 };return qpow_i(x, (p + 1) / 2, p);
}
signed main() {srand(time(0));cin >> t;while (t--) {cin >> n >> p;if (!n) {printf("0\n");continue;}int ans1 = cipolla(n, p), ans2 = -ans1 + p;// 另一个解就是其相反数,ans1正数解 if (ans1 == -1) printf("Hola!\n");//无解else {if (ans1 > ans2) swap(ans1, ans2);if (ans1 == ans2) printf("%lld\n", ans1);else printf("%lld %lld\n", ans1, ans2);}}return 0;
}
NTT背包合并
NNFly (nowcoder.com)
PowerPoint 演示文稿 (nowcoder.com)
思路有点像数位 DP,其中的背包合并可以用多项式乘法(NTT)完成;若要合并 n 个背包,可以借鉴线段树 / 启发式合并的思想:分治地两两相乘,使每次相乘的两个多项式规模大致相当。
//#include<bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<numeric>
#include<cstring>//rfind("string"),s.find(string,begin)!=s.npos,find_first _of(),find_last_of()
#include<string>//to_string(value),s.substr(int begin, int length);
#include<cstdio>
#include<cmath>
#include<vector>//res.erase(unique(res.begin(), res.end()), res.end()),resize(n)//size of vector,vector<int>().swap(at[mx])
#include<queue>//priority_queue(big) /priority_queue<int, vector<int>, greater<int>> q(small)
#include<stack>
#include<map>
#include<set>
#include<unordered_map>
#include<unordered_set>
#include<bitset>
#include<random>
#include<chrono>
//#include<ext/pb_ds/assoc_container.hpp>//gp_hash_table
//#include<ext/pb_ds/hash_policy.hpp>
//using namespace __gnu_pbds;
// 64-bit Mersenne Twister seeded from the steady clock.
std::mt19937_64 rnd(std::chrono::steady_clock::now().time_since_epoch().count());
using namespace std;
// Contest shorthand: every `int` below is really a 64-bit long long.
// (For an even wider range GCC offers __int128, max 2^127-1.)
#define int long long
// Pair of (64-bit) ints.
#define PII pair<int,int>
// N: capacity of the NTT scratch arrays (a, b, r) in namespace ntt.
// mod: NTT-friendly prime 998244353 = 119 * 2^23 + 1 (primitive root 3).
const int N = 3e6 + 5, mod = 998244353;
namespace ntt {const int g = 3;int a[N], b[N];int r[N], tot, bit;int invg;int qpow(int a, int b) {int res = 1;while (b) {if (b & 1) res = 1ll * res * a % mod;a = 1ll * a * a % mod;b >>= 1;}return res;}void add(int& a, int b) {a += b;if (a >= mod) a -= mod;}void NTT(int a[], int inv) {for (int i = 0; i < tot; i++)if (i < r[i])swap(a[i], a[r[i]]);for (int mid = 1; mid < tot; mid <<= 1) {int g1 = qpow(inv == 1 ? g : invg, (mod - 1) / (mid << 1));for (int i = 0; i < tot; i += mid << 1) {for (int j = 0, gk = 1; j < mid; j++, gk = 1ll * gk * g1 % mod) {int x = a[i + j], y = 1ll * gk * a[i + j + mid] % mod;a[i + j] = (x + y) % mod, a[i + j + mid] = (x - y + mod) % mod;}}}if (inv == -1) {int invtot = qpow(tot, mod - 2);for (int i = 0; i < tot; i++) {a[i] = 1ll * a[i] * invtot % mod;}}}struct Poly {vector<int> coef;int deg;int& operator[](int x) {return coef[x];}Poly(int deg = -1) : deg(deg) {coef = vector<int>(deg + 1, 0);}void norm(int deg) {this->deg = deg;coef.resize(deg + 1);}};void init(int len) {bit = tot = 0;while ((1ll << bit) <= len) bit++;tot = 1ll << bit;for (int i = 0; i < tot; i++) a[i] = b[i] = 0;for (int i = 1; i < tot; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));}Poly operator*(const Poly& f, const Poly& g) {Poly res(f.deg + g.deg);if (f.deg <= 8 || g.deg <= 8) {for (int i = 0; i <= f.deg; i++)for (int j = 0; j <= g.deg; j++)add(res[i + j], 1ll * f.coef[i] * g.coef[j] % mod);return res;}init(res.deg);copy(f.coef.begin(), f.coef.end(), a);copy(g.coef.begin(), g.coef.end(), b);NTT(a, 1), NTT(b, 1);for (int i = 0; i < tot; i++) a[i] = 1ll * a[i] * b[i] % mod;NTT(a, -1);copy(a, a + res.deg + 1, res.coef.begin());return res;}int __ = []{invg = qpow(g, mod - 2);return 0;}();
}
using namespace ntt;
signed main()
{ios_base::sync_with_stdio(0); cin.tie(0), cout.tie(0);int n, k;int m;cin >> n >> m >> k;vector<int>a(n + 1);vector<Poly>v;int sum = 0;for (int i = 1; i <= n; i++) {cin >> a[i];sum += a[i];Poly f(a[i]);f[0] = f[a[i]] = 1;v.emplace_back(f);}auto solve = [&](auto self, int l, int r)->Poly {if (l == r) return v[l];int mid = l + r >> 1;return self(self, l, mid) * self(self, mid + 1, r);};Poly f = solve(solve, 0, v.size() - 1);//assert(f.deg == sum);vector<vector<int>>ban(60, vector<int>());//array<vector<int>, 60>ban;while (k--){int b, c;cin >> b >> c;ban[c].push_back(b);}vector<Poly>dp(2);dp[0].norm(0);dp[0][0] = 1;for (int i = 0; i < 60; i++) {Poly g = f;sort(ban[i].begin(), ban[i].end());ban[i].erase(unique(ban[i].begin(), ban[i].end()), ban[i].end());for (auto x : ban[i]) {for (int j = a[x]; j <= sum; j++) {g[j] -= g[j - a[x]];if (g[j] < 0)g[j] += mod;}}vector<Poly>f(2), ndp(2);f[0] = dp[0] * g;f[1] = dp[1] * g;for (auto t : { 0,1 }) ndp[t].norm(f[t].deg / 2);for (auto t : { 0,1 }) {for (int j = 0; j <= f[t].deg; j++) {if (j % 2 == (m >> i & 1)) {add(ndp[t][j / 2], f[t][j]);}else if (j % 2 > (m >> i & 1)) {add(ndp[1][j / 2], f[t][j]);}else {add(ndp[0][j / 2], f[t][j]);}}}dp = ndp;}cout << dp[0][0] << "\n";
}