P5170 【模板】类欧几里得算法
Description
要求在 $O(\log n)$ 的时间内求出（结果对 $998244353$ 取模）：
$$\sum_{i=0}^{n}\left\lfloor\frac{ai+b}{c}\right\rfloor$$
$$\sum_{i=0}^{n}i\left\lfloor\frac{ai+b}{c}\right\rfloor$$
$$\sum_{i=0}^{n}\left\lfloor\frac{ai+b}{c}\right\rfloor^2$$
Solution
其实也就是一种优秀的求和思想。
Part one
我们先考虑最基本的第一个:
$$\sum_{i=0}^{n}\left\lfloor\frac{ai+b}{c}\right\rfloor=\sum_{i=0}^{n}\left(\left\lfloor\frac{a'i+b'}{c}\right\rfloor+pi+q\right)$$
其中 $p=\left\lfloor\frac{a}{c}\right\rfloor$，$q=\left\lfloor\frac{b}{c}\right\rfloor$，$a'=a \bmod c$，$b'=b \bmod c$。
那么后面的 $\sum_{i=0}^{n}(pi+q)=p\cdot\frac{n(n+1)}{2}+q(n+1)$ 很容易计算，因此问题转化为求：
$$f(a,b,c,n)=\sum_{i=0}^{n}\left\lfloor\frac{ai+b}{c}\right\rfloor\qquad\left(a\in[0,c),\ b\in[0,c)\right)$$
Part two
然后使用一个极其巧妙的思路（下面设 $a>0$，并记 $m=\left\lfloor\frac{an+b}{c}\right\rfloor$）：
$$
\begin{aligned}
f(a,b,c,n)
&=\sum_{i=0}^{n}\left\lfloor\frac{ai+b}{c}\right\rfloor\\
&=\sum_{i=0}^{n}\sum_{j=1}^{m}\left[j\le\left\lfloor\frac{ai+b}{c}\right\rfloor\right]\\
&=\sum_{i=0}^{n}\sum_{j=0}^{m-1}\left[jc+c\le ai+b\right]\\
&=\sum_{j=0}^{m-1}\sum_{i=0}^{n}\left[jc+c\le ai+b\right]\\
&=\sum_{j=0}^{m-1}\left(n+1-\sum_{i=0}^{n}\left[jc+c>ai+b\right]\right)\\
&=\sum_{j=0}^{m-1}\left(n-\left\lfloor\frac{jc+c-b-1}{a}\right\rfloor\right)\\
&=nm-\sum_{j=0}^{m-1}\left\lfloor\frac{jc+c-b-1}{a}\right\rfloor\\
&=nm-f(c,\,c-b-1,\,a,\,m-1)
\end{aligned}
$$
其中把下取整写成 $\sum_{j=1}^{m}[\,\cdot\,]$ 用到了 $\left\lfloor\frac{ai+b}{c}\right\rfloor\le m$（该式关于 $i$ 单调不减）。
倒数第三步用到：$jc+c>ai+b$ 等价于 $ai\le jc+c-b-1$。而 $b<c$ 保证右边非负，$j+1\le m$ 保证 $i=n$ 不满足该条件。所以满足条件的 $i\in[0,n]$ 恰有 $\left\lfloor\frac{jc+c-b-1}{a}\right\rfloor+1$ 个。
因此可以递归计算：新参数中分子系数 $c$ 不小于分母 $a$，需要先按 Part one 的方法取模化简，再继续递归，直到 $a=0$ 时直接算出答案。
不难发现 $(a,c)$ 的变化（先 $a\to a \bmod c$，再交换成 $(c,\ a \bmod c)$）与辗转相除完全相同，因此递归层数、也即时间复杂度为 $O(\log\max(a,c))$。
Part three
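剩下两个也是同样的思路：先令 $a'=a \bmod c,\ b'=b \bmod c$，然后快速计算下取整式外（即 $pi+q$ 部分展开后）的贡献；下取整式内的部分通过与上面类似的方法转化，推导出若干个能用类欧计算的式子相加的形式，然后递归计算即可。具体推导这里不再赘述，只给出结论。
记 $f,g,h$ 分别为 $\sum\lfloor\cdot\rfloor$、$\sum\lfloor\cdot\rfloor^2$、$\sum i\lfloor\cdot\rfloor$。设 $a,b<c$，$m=\left\lfloor\frac{an+b}{c}\right\rfloor$，$(f_1,g_1,h_1)$ 为参数 $(c,\,c-b-1,\,a,\,m-1)$ 下的三个和，则：
$$h=m\cdot\frac{n(n+1)}{2}-\frac{g_1+f_1}{2}$$
$$g=nm^2-2h_1-f_1$$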
Code
#include <bits/stdc++.h>
using namespace std;

// Assigns y to x when y is smaller; returns true iff x changed.
template <typename T> inline bool upmin(T &x, T y) { return y < x ? x = y, 1 : 0; }
// Assigns y to x when y is larger; returns true iff x changed.
template <typename T> inline bool upmax(T &x, T y) { return x < y ? x = y, 1 : 0; }

// Contest shorthands. Each directive must sit on its own line: when they were
// fused with the surrounding code, the rest of the line was swallowed by the
// directive and the file did not compile.
#define MP(A, B) make_pair(A, B)
#define PB(A) push_back(A)
#define SIZE(A) ((int)(A).size())
#define LEN(A) ((int)(A).length())
#define FOR(i, a, b) for (int i = (a); i < (b); ++i)
#define fi first
#define se second

typedef long long ll;
typedef unsigned long long ull;
typedef long double lod;
typedef pair<int, int> PR;
typedef vector<int> VI;

const lod eps = 1e-9;
const lod pi = acos(-1);
const int oo = 1 << 30;
const ll loo = 1ll << 60;
const int mods = 998244353;        // modulus applied to every answer
const int inv2 = (mods + 1) >> 1;  // inverse of 2 modulo mods
const int inv6 = (mods + 1) / 6;   // inverse of 6 modulo mods (exact: mods + 1 is divisible by 6)
const int MAXN = 600005;
const int INF = 0x3f3f3f3f;  // 1061109567
/*--------------------------------------------------------------------*/
// Buffered stdin/stdout helpers for bulk contest I/O.
namespace FastIO {
constexpr int SIZE = (1 << 21) + 1;  // capacity of each I/O buffer, in bytes
int num = 0;                         // digits pending in que (used by print)
char ibuf[SIZE], obuf[SIZE], que[65];
char *iS, *iT;                           // [iS, iT) is the unread part of ibuf
char *oS = obuf, *oT = obuf + SIZE - 1;  // next output slot / flush threshold

// Returns the next input byte as an unsigned char value, or EOF once stdin is
// exhausted. It is a function rather than a mid-line #define, and it returns
// int so that EOF is distinguishable and isalpha never sees a negative char.
inline int gc() {
    if (iS == iT) {
        iS = ibuf;
        iT = ibuf + std::fread(ibuf, 1, SIZE, stdin);
        if (iS == iT) return EOF;
    }
    return static_cast<unsigned char>(*iS++);
}

// Writes all buffered output to stdout and empties the buffer.
inline void flush() {
    std::fwrite(obuf, 1, oS - obuf, stdout);
    oS = obuf;
}

// Appends one character to the output buffer, flushing when it fills up.
inline void putc(char ch) {
    *oS++ = ch;
    if (oS == oT) flush();
}

// Skips to the next alphabetic character and stores it in ch;
// ch receives (char)EOF if input ends first.
inline void getc(char &ch) {
    int x = gc();
    while (x != EOF && !std::isalpha(x)) x = gc();
    ch = static_cast<char>(x);
}

// Reads the next run of letters into st[1..len] (1-indexed) and NUL-terminates
// it at st[len + 1], so the caller's buffer needs room for that terminator.
// If input ends before any letter, st[1] is '\0' (empty word).
inline void reads(char *st) {
    int len = 0;
    int x = gc();
    while (x != EOF && !std::isalpha(x)) x = gc();  // skip separators
    for (; x != EOF && std::isalpha(x); x = gc()) st[++len] = static_cast<char>(x);
    st[len + 1] = '\0';
}

// Parses the next integer into x. Non-digit characters before it are skipped,
// and a '-' among them makes the value negative. Stops at EOF instead of
// spinning forever (x is 0 if no digit was found).
template <class I> inline void read(I &x) {
    int sign = 1;
    int ch = gc();
    for (; ch != EOF && (ch < '0' || ch > '9'); ch = gc())
        if (ch == '-') sign = -1;
    for (x = 0; ch >= '0' && ch <= '9'; ch = gc()) x = x * 10 + (ch - '0');
    x *= sign;
}

// Appends the decimal representation of x to the output buffer.
template <class I> inline void print(I x) {
    if (x < 0) putc('-'), x = -x;
    if (!x) putc('0');
    while (x) que[++num] = static_cast<char>(x % 10 + '0'), x /= 10;
    while (num) putc(que[num--]);
}

// Flushes whatever output is still buffered when the program exits.
struct Flusher_ {
    ~Flusher_() { flush(); }
} io_Flusher_;
}  // namespace FastIO
using FastIO::print;
using FastIO::putc;
using FastIO::read;
using FastIO::reads;

// Result triple produced by solve() for one query (all values modulo mods):
//   f = sum floor((a*i+b)/c), g = sum floor((a*i+b)/c)^2,
//   h = sum i * floor((a*i+b)/c).
struct Node {
    int f = 0;
    int g = 0;
    int h = 0;
    Node() = default;
    Node(int x, int y, int z) : f(x), g(y), h(z) {}
};
// Sum of i over i = 0..x, modulo mods; x == -1 (the empty range) yields 0.
// x + 1ll keeps the whole product in long long, so it cannot overflow int.
int S1(int x) { return 1ll * x * (x + 1ll) / 2 % mods; }

// Sum of i^2 over i = 0..x, i.e. x(x+1)(2x+1)/6, modulo mods. The division by 6
// uses the modular inverse inv6. Each factor is formed in long long:
// `x * 2 + 1` in int arithmetic overflowed for x >= 2^30.
int S2(int x) {
    const long long xx1 = 1ll * x * (x + 1ll) % mods;
    const long long twoX1 = (2ll * x + 1) % mods;
    return xx1 * twoX1 % mods * inv6 % mods;
}

// Modular addition: requires x + y < 2 * mods, so one conditional subtraction
// suffices (callers rely on this when passing `mods - v` with v == 0).
int upd(int x, int y) { return x + y >= mods ? x + y - mods : x + y; }
// Class-Euclidean sums over i = 0..n, all taken modulo mods:
//   f = sum floor((a*i + b) / c)
//   g = sum floor((a*i + b) / c)^2
//   h = sum i * floor((a*i + b) / c)
// n == -1 denotes the empty range, for which every sum is 0. Each recursive
// call swaps the roles of a and c, so the depth follows Euclid's algorithm on (a, c).
// NOTE(review): assumes a, b >= 0 and c >= 1 -- negative inputs are not handled.
Node solve(int a, int b, int c, int n) {
    int p = 0;  // a / c if a, b get reduced below, else 0
    int q = 0;  // b / c if a, b get reduced below, else 0
    Node head;  // sums contributed by the polynomial part p*i + q

    // Step 1: floor((a*i+b)/c) = p*i + q + floor(((a%c)*i + b%c)/c).
    if (a >= c || b >= c) {
        p = a / c;
        q = b / c;
        const int terms = n + 1;  // number of indices i in [0, n]
        const int sumI = S1(n);   // sum of i
        const int sumI2 = S2(n);  // sum of i^2
        head.f = upd(1ll * p * sumI % mods, 1ll * q * terms % mods);
        head.g = upd(upd(1ll * p * p % mods * sumI2 % mods,
                         1ll * q * q % mods * terms % mods),
                     2ll * p * q % mods * sumI % mods);
        head.h = upd(1ll * p * sumI2 % mods, 1ll * q * sumI % mods);
        a %= c;  // equals a - p * c
        b %= c;  // equals b - q * c
    }
    // With a == 0 (and b < c now) the remaining floor term is 0 everywhere.
    if (a == 0) return head;

    // Step 2: count the lattice points column-wise instead of row-wise.
    // sub holds (f, g, h) for parameters (c, c - b - 1, a, m - 1).
    const int m = static_cast<int>((1ll * a * n + b) / c);
    const Node sub = solve(c, c - b - 1, a, m - 1);

    Node rest;  // sums for the reduced floor F = floor(((a%c)*i + b%c)/c)
    rest.f = upd(1ll * m * n % mods, mods - sub.f);
    rest.h = upd(1ll * m * S1(n) % mods,
                 mods - 1ll * upd(sub.g, sub.f) * inv2 % mods);
    // rest.g = 2p*sum(i*F) + 2q*sum(F) (cross terms of (p*i + q + F)^2)
    //        + sum(F^2), where sum(F^2) = m*m*n - 2*sub.h - sub.f.
    rest.g = upd(2ll * rest.h * p % mods, 2ll * rest.f * q % mods);
    rest.g = upd(rest.g, 1ll * m * m % mods * n % mods);
    rest.g = upd(rest.g, mods - 2ll * sub.h % mods);
    rest.g = upd(rest.g, mods - sub.f);

    return Node(upd(rest.f, head.f), upd(rest.g, head.g), upd(rest.h, head.h));
}
signed main() {
#ifndef ONLINE_JUDGEfreopen("a.in", "r", stdin);
#endifint Case;read(Case);while (Case --) {int n, a, b, c;read(n), read(a), read(b), read(c);Node ans = solve(a, b, c, n);print(ans.f), putc(' '), print(ans.g), putc(' '), print(ans.h), putc('\n');}return 0;
}