多项式开根
给定多项式 $g(x)$，求 $f(x)$，满足 $f^2(x) \equiv g(x) \pmod{x^n}$。
假设我们已经求出了 $g(x)$ 在模 $x^{\lceil \frac{n}{2} \rceil}$ 意义下的平方根 $f_0(x)$，现在要求模 $x^n$ 意义下的平方根 $f(x)$。
有 $f_0^2(x) \equiv g(x) \pmod{x^{\lceil \frac{n}{2} \rceil}}$。
移项再平方，由于 $2\lceil \frac{n}{2} \rceil \ge n$，有 $\left(f_0^2(x) - g(x)\right)^2 \equiv 0 \pmod{x^n}$。
则 $\left(f_0^2(x) + g(x)\right)^2 \equiv 4 f_0^2(x) g(x) \pmod{x^n}$。
两边同除以 $4 f_0^2(x)$，得 $g(x) \equiv \left(\frac{f_0^2(x) + g(x)}{2 f_0(x)}\right)^2 \pmod{x^n}$。
所以 $f(x) \equiv \frac{f_0^2(x) + g(x)}{2 f_0(x)} \pmod{x^n}$。
所以有 $f(x) \equiv 2^{-1} f_0(x) + 2^{-1} f_0^{-1}(x) g(x) \pmod{x^n}$，其中 $f_0^{-1}(x)$ 需用多项式求逆在模 $x^n$ 意义下求出（要求 $f_0(0) \ne 0$）。
对于 $g(0) = 1$ 的特殊情况（递归边界直接取 $f(x) \equiv 1 \pmod{x}$）：
// Polynomial square root modulo 998244353, special case g(0) = 1.
//
// Input : n, then the coefficients g_0 .. g_{n-1}.
// Output: f_0 .. f_{n-1} with f^2(x) ≡ g(x) (mod x^n) and f(0) = 1.
// Method: Newton iteration f = (f0 + g * f0^{-1}) / 2, with NTT-based multiplication.
#include <algorithm>
#include <cstdio>
#include <utility>

constexpr int N = 1 << 22;    // capacity of every work array = largest NTT length used
constexpr int MAX_N = N / 2;  // largest n whose NTT length (power of two >= 2n) fits in N
constexpr int mod = 998244353, inv2 = (mod + 1) >> 1;  // inv2 = 2^{-1} mod `mod`

// a: input g, b: result f, c/d: scratch, r: bit-reversal permutation.
int a[N], b[N], c[N], d[N], r[N];

// Returns base^exp mod `mod` by binary exponentiation (exp >= 0).
int quick_pow(int base, int exp) {
    int ans = 1;
    while (exp) {
        if (exp & 1) ans = 1ll * ans * base % mod;
        base = 1ll * base * base % mod;
        exp >>= 1;
    }
    return ans;
}

// Fills r[0..lim) with the bit-reversal permutation for an NTT of length lim
// (lim must be a power of two).
void get_r(int lim) {
    for (int i = 0; i < lim; i++) {
        r[i] = (i & 1) * (lim >> 1) + (r[i >> 1] >> 1);
    }
}

// In-place iterative NTT of f[0..lim); call get_r(lim) first.
// rev == 1: forward transform. rev == -1: inverse transform, including the 1/lim scaling.
void NTT(int *f, int lim, int rev) {
    for (int i = 0; i < lim; i++) {
        if (i < r[i]) std::swap(f[i], f[r[i]]);
    }
    for (int mid = 1; mid < lim; mid <<= 1) {
        // (2*mid)-th root of unity; 3 is a primitive root of 998244353.
        const int wn = quick_pow(3, (mod - 1) / (mid << 1));
        for (int cur = 0; cur < lim; cur += mid << 1) {
            int w = 1;
            for (int k = 0; k < mid; k++, w = 1ll * w * wn % mod) {
                const int x = f[cur + k];
                const int y = 1ll * w * f[cur + mid + k] % mod;
                f[cur + k] = (x + y) % mod;
                f[cur + mid + k] = (x - y + mod) % mod;
            }
        }
    }
    if (rev == -1) {
        // Inverse NTT = forward NTT with outputs f[1..lim) reversed, then divided by lim.
        const int inv = quick_pow(lim, mod - 2);
        std::reverse(f + 1, f + lim);
        for (int i = 0; i < lim; i++) f[i] = 1ll * f[i] * inv % mod;
    }
}

// Computes b ≡ a^{-1} (mod x^n) by Newton iteration b = b0 * (2 - a * b0).
// Requires a[0] != 0 and b[1..lim) zero on entry (lim = NTT length for n).
// Clobbers the global scratch array c and the permutation r.
void polyinv1(int *a, int *b, int n) {
    if (n <= 0) return;  // nothing to compute; also stops (0 + 1) >> 1 == 0 recursing forever
    if (n == 1) {
        b[0] = quick_pow(a[0], mod - 2);
        return;
    }
    polyinv1(a, b, (n + 1) >> 1);
    int lim = 1;
    while (lim < 2 * n) lim <<= 1;
    get_r(lim);
    for (int i = 0; i < n; i++) c[i] = a[i];
    for (int i = n; i < lim; i++) c[i] = 0;
    NTT(b, lim, 1);
    NTT(c, lim, 1);
    for (int i = 0; i < lim; i++) {
        const int cur = (2 - 1ll * c[i] * b[i] % mod + mod) % mod;
        b[i] = 1ll * b[i] * cur % mod;
    }
    NTT(b, lim, -1);
    for (int i = n; i < lim; i++) b[i] = 0;  // truncate to mod x^n for the caller
}

// Computes b ≡ sqrt(a) (mod x^n), assuming a[0] == 1 (the base case fixes b[0] = 1).
// Expects b zero beyond what earlier (smaller) recursion levels wrote, as a fresh global is.
// Clobbers the global arrays c, d and r.
void polysqrt(int *a, int *b, int n) {
    if (n <= 0) return;  // guard against infinite recursion for n == 0
    if (n == 1) {
        b[0] = 1;
        return;
    }
    polysqrt(a, b, (n + 1) >> 1);
    int lim = 1;
    while (lim < 2 * n) lim <<= 1;
    // d = f0^{-1} mod x^n; polyinv1 needs d cleared up to lim.
    for (int i = 0; i < lim; i++) d[i] = 0;
    polyinv1(b, d, n);
    for (int i = 0; i < n; i++) c[i] = a[i];
    for (int i = n; i < lim; i++) c[i] = 0;
    get_r(lim);
    NTT(b, lim, 1);
    NTT(c, lim, 1);
    NTT(d, lim, 1);
    // Pointwise: f = 2^{-1} * f0 + 2^{-1} * f0^{-1} * g.
    for (int i = 0; i < lim; i++) {
        b[i] = (1ll * inv2 * b[i] % mod + 1ll * inv2 * d[i] % mod * c[i] % mod) % mod;
    }
    NTT(b, lim, -1);
    for (int i = n; i < lim; i++) b[i] = 0;
}

int main() {
    int n;
    if (scanf("%d", &n) != 1 || n < 0 || n > MAX_N) {
        fprintf(stderr, "invalid n\n");
        return 1;
    }
    for (int i = 0; i < n; i++) {
        scanf("%d", &a[i]);
        // Normalize into [0, mod) so the NTT butterflies cannot overflow int.
        a[i] %= mod;
        if (a[i] < 0) a[i] += mod;
    }
    polysqrt(a, b, n);
    for (int i = 0; i < n; i++) {
        printf("%d%c", b[i], i + 1 == n ? '\n' : ' ');
    }
    return 0;
}
一般情况：用二次剩余（Cipolla 算法）求出 $g(0)$ 的平方根作为递归边界（要求 $g(0)$ 为模 $998244353$ 的二次剩余，取较小的根）：
// Polynomial square root modulo 998244353, general case.
//
// Input : n, then the coefficients g_0 .. g_{n-1}.
// Output: f_0 .. f_{n-1} with f^2(x) ≡ g(x) (mod x^n), where f(0) is the smaller
//         square root of g(0) found by Cipolla's algorithm.
// Limitation: if g(0) == 0 every iteration stays at zero and the program prints all
//             zeros, which is only a valid answer when g ≡ 0 (mod x^n).
#include <algorithm>
#include <cstdio>
#include <random>
#include <utility>

constexpr int N = 1 << 22;    // capacity of every work array = largest NTT length used
constexpr int MAX_N = N / 2;  // largest n whose NTT length (power of two >= 2n) fits in N
constexpr int mod = 998244353, inv2 = (mod + 1) >> 1;  // inv2 = 2^{-1} mod `mod`

// a: input g, b: result f, c/d: scratch, r: bit-reversal permutation.
int a[N], b[N], c[N], d[N], r[N];

// Returns base^exp mod `mod` by binary exponentiation (exp >= 0).
int quick_pow(int base, int exp) {
    int ans = 1;
    while (exp) {
        if (exp & 1) ans = 1ll * ans * base % mod;
        base = 1ll * base * base % mod;
        exp >>= 1;
    }
    return ans;
}

namespace Quadratic_residue {

// Element r + i*w of F_p[w] / (w^2 - I2), the extension field used by Cipolla's algorithm.
struct Complex {
    int r, i;
    Complex(int _r = 0, int _i = 0) : r(_r), i(_i) {}
};

int I2;  // w^2: the quadratic non-residue t^2 - n chosen in get_residue

Complex operator*(const Complex &x, const Complex &y) {
    return Complex((1ll * x.r * y.r % mod + 1ll * x.i * y.i % mod * I2 % mod) % mod,
                   (1ll * x.r * y.i % mod + 1ll * x.i * y.r % mod) % mod);
}

// Returns base^exp in the extension field (exp >= 0).
Complex complex_pow(Complex base, int exp) {
    Complex ans(1, 0);
    while (exp) {
        if (exp & 1) ans = ans * base;
        base = base * base;
        exp >>= 1;
    }
    return ans;
}

// Returns the smaller square root of n modulo `mod` (n in [0, mod)),
// 0 when n == 0, or -1 when n is a quadratic non-residue (Euler's criterion).
int get_residue(int n) {
    if (n == 0) return 0;
    if (quick_pow(n, (mod - 1) >> 1) == mod - 1) return -1;
    std::mt19937 engine(233);  // fixed seed: the returned root is canonical anyway
    std::uniform_int_distribution<int> dist(0, mod - 1);
    // Pick t with t^2 - n not a nonzero quadratic residue.
    int t = dist(engine);
    while (quick_pow((1ll * t * t % mod - n + mod) % mod, (mod - 1) >> 1) == 1) {
        t = dist(engine);
    }
    I2 = (1ll * t * t % mod - n + mod) % mod;
    // (t + w)^((p + 1) / 2) has a zero w-component; its real part is a root of n.
    const int x = complex_pow(Complex(t, 1), (mod + 1) >> 1).r;
    return std::min(x, mod - x);
}

}  // namespace Quadratic_residue

// Fills r[0..lim) with the bit-reversal permutation for an NTT of length lim
// (lim must be a power of two).
void get_r(int lim) {
    for (int i = 0; i < lim; i++) {
        r[i] = (i & 1) * (lim >> 1) + (r[i >> 1] >> 1);
    }
}

// In-place iterative NTT of f[0..lim); call get_r(lim) first.
// rev == 1: forward transform. rev == -1: inverse transform, including the 1/lim scaling.
void NTT(int *f, int lim, int rev) {
    for (int i = 0; i < lim; i++) {
        if (i < r[i]) std::swap(f[i], f[r[i]]);
    }
    for (int mid = 1; mid < lim; mid <<= 1) {
        // (2*mid)-th root of unity; 3 is a primitive root of 998244353.
        const int wn = quick_pow(3, (mod - 1) / (mid << 1));
        for (int cur = 0; cur < lim; cur += mid << 1) {
            int w = 1;
            for (int k = 0; k < mid; k++, w = 1ll * w * wn % mod) {
                const int x = f[cur + k];
                const int y = 1ll * w * f[cur + mid + k] % mod;
                f[cur + k] = (x + y) % mod;
                f[cur + mid + k] = (x - y + mod) % mod;
            }
        }
    }
    if (rev == -1) {
        // Inverse NTT = forward NTT with outputs f[1..lim) reversed, then divided by lim.
        const int inv = quick_pow(lim, mod - 2);
        std::reverse(f + 1, f + lim);
        for (int i = 0; i < lim; i++) f[i] = 1ll * f[i] * inv % mod;
    }
}

// Computes b ≡ a^{-1} (mod x^n) by Newton iteration b = b0 * (2 - a * b0).
// Requires a[0] != 0 and b[1..lim) zero on entry (lim = NTT length for n).
// Clobbers the global scratch array c and the permutation r.
void polyinv1(int *a, int *b, int n) {
    if (n <= 0) return;  // nothing to compute; also stops (0 + 1) >> 1 == 0 recursing forever
    if (n == 1) {
        b[0] = quick_pow(a[0], mod - 2);
        return;
    }
    polyinv1(a, b, (n + 1) >> 1);
    int lim = 1;
    while (lim < 2 * n) lim <<= 1;
    get_r(lim);
    for (int i = 0; i < n; i++) c[i] = a[i];
    for (int i = n; i < lim; i++) c[i] = 0;
    NTT(b, lim, 1);
    NTT(c, lim, 1);
    for (int i = 0; i < lim; i++) {
        const int cur = (2 - 1ll * c[i] * b[i] % mod + mod) % mod;
        b[i] = 1ll * b[i] * cur % mod;
    }
    NTT(b, lim, -1);
    for (int i = n; i < lim; i++) b[i] = 0;  // truncate to mod x^n for the caller
}

// Computes b ≡ sqrt(a) (mod x^n); the base case takes the smaller square root of a[0].
// Expects b zero beyond what earlier (smaller) recursion levels wrote, as a fresh global is.
// Clobbers the global arrays c, d and r.
void polysqrt(int *a, int *b, int n) {
    if (n <= 0) return;  // guard against infinite recursion for n == 0
    if (n == 1) {
        b[0] = Quadratic_residue::get_residue(a[0]);
        return;
    }
    polysqrt(a, b, (n + 1) >> 1);
    int lim = 1;
    while (lim < 2 * n) lim <<= 1;
    // d = f0^{-1} mod x^n; polyinv1 needs d cleared up to lim.
    for (int i = 0; i < lim; i++) d[i] = 0;
    polyinv1(b, d, n);
    for (int i = 0; i < n; i++) c[i] = a[i];
    for (int i = n; i < lim; i++) c[i] = 0;
    get_r(lim);
    NTT(b, lim, 1);
    NTT(c, lim, 1);
    NTT(d, lim, 1);
    // Pointwise: f = 2^{-1} * f0 + 2^{-1} * f0^{-1} * g.
    for (int i = 0; i < lim; i++) {
        b[i] = (1ll * inv2 * b[i] % mod + 1ll * inv2 * d[i] % mod * c[i] % mod) % mod;
    }
    NTT(b, lim, -1);
    for (int i = n; i < lim; i++) b[i] = 0;
}

int main() {
    int n;
    if (scanf("%d", &n) != 1 || n < 0 || n > MAX_N) {
        fprintf(stderr, "invalid n\n");
        return 1;
    }
    for (int i = 0; i < n; i++) {
        scanf("%d", &a[i]);
        // Normalize into [0, mod) so the NTT butterflies cannot overflow int.
        a[i] %= mod;
        if (a[i] < 0) a[i] += mod;
    }
    // Without a square root of g(0) there is no square root of g; previously the
    // -1 sentinel was fed into the NTT as a coefficient and produced garbage.
    if (n > 0 && Quadratic_residue::get_residue(a[0]) == -1) {
        fprintf(stderr, "g(0) is not a quadratic residue\n");
        return 1;
    }
    polysqrt(a, b, n);
    for (int i = 0; i < n; i++) {
        printf("%d%c", b[i], i + 1 == n ? '\n' : ' ');
    }
    return 0;
}