P6669 [清华集训2016] 组合数问题
题意:
给你n,m,k,问有多少对(i,j)满足K∣CijK|C_{i}^{j}K∣Cij
(Cij是k的倍数C_{i}^{j}是k的倍数Cij是k的倍数)
n,m<=1e18
题解:
n和m非常大,非常非常大,很容易想到用卢卡斯来化简
Cnmmodp=Cn/pm/p∗Cn%pm%pC_{n}^{m}\bmod p=C_{n/p}^{m/p}*C_{n\%p}^{m\%p}Cnmmodp=Cn/pm/p∗Cn%pm%p
对于i%p的范围就会很小,但是i/p的范围有可能还是很大,所以将Ci/pj/pC_{i/p}^{j/p}Ci/pj/p继续用卢卡斯化简。
这样一直操作,n/p,取n%p,然后再n/p…,这不就相当于是将n转化成p进制吗?可以好好思考一下
最后式子变成:Cnmmodp=∏i=0kCnimiC_{n}^{m}\bmod p=\prod_{i=0}^{k}C_{n_{i}}^{m_{i}}Cnmmodp=∏i=0kCnimi
n=nk∗pk+nk−1∗pk−1+...+n0n=n_{k}*p^{k}+n_{k-1}*p^{k-1}+...+n_{0}n=nk∗pk+nk−1∗pk−1+...+n0
m同理
如果CnmC_{n}^{m}Cnm是k的倍数,说明CnmmodpC_{n}^{m}\bmod pCnmmodp等于0,也就是那个累乘为0,就说明存在某一项Cnimi=0C_{ni}^{mi}=0Cnimi=0
现在我们开始考虑CnimiC_{ni}^{mi}Cnimi的情况,如何求其数量,我们可以先求出所有组合数C的情况,然后减去C非0的情况,那么剩下的就是C为0的情况
所有组合数C的情况就是(m+1)∗(m+2)/2+(n−m)∗(m+1)(m+1)*(m+2)/2+(n-m)*(m+1)(m+1)∗(m+2)/2+(n−m)∗(m+1),这个很好推,写出式子就有了
那C非0的情况如何求?
用数位dp来做,设dp[i][j][k]:表示考虑第i位,j和k为0或1,j为1表示第i-1位n已经取上界(本位取值存在限制),k为1表示第i-1位m取上界。为0则表示未取上界(本位取值无限制)
当第i位n不取上届时,第i位之后的每一位就可以随便取,一定会小于(数位dp思想)
详细看代码
代码:
// Problem: P6669 [清华集训2016] 组合数问题
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P6669
// Memory Limit: 500 MB
// Time Limit: 1000 ms
// Data:2021-08-27 14:38:38
// By Jozky#include <bits/stdc++.h>
#include <unordered_map>
#define debug(a, b) printf("%s = %d\n", a, b);
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
clock_t startTime, endTime;
//Fe~Jozky
const ll INF_ll= 1e18;
const int INF_int= 0x3f3f3f3f;
void read(){};
template <typename _Tp, typename... _Tps> void read(_Tp& x, _Tps&... Ar)
{x= 0;char c= getchar();bool flag= 0;while (c < '0' || c > '9')flag|= (c == '-'), c= getchar();while (c >= '0' && c <= '9')x= (x << 3) + (x << 1) + (c ^ 48), c= getchar();if (flag)x= -x;read(Ar...);
}
template <typename T> inline void write(T x)
{if (x < 0) {x= ~(x - 1);putchar('-');}if (x > 9)write(x / 10);putchar(x % 10 + '0');
}
void rd_test()
{
#ifdef LOCALstartTime= clock();freopen("in.txt", "r", stdin);
#endif
}
void Time_test()
{
#ifdef LOCALendTime= clock();printf("\nRun Time:%lfs\n", (double)(endTime - startTime) / CLOCKS_PER_SEC);
#endif
}
int t, k;
const ll mod= 1e9 + 7;
const int maxn= 2000;
int b[maxn];
int c[maxn];
int cnt1= 0;
int cnt2= 0;
int f[maxn][2][2];
ll Sum(ll a, ll b)
{ll ans= 0;while (b) {if (b & 1)ans= (ans + a) % mod;a= (a + a) % mod;b>>= 1;}return ans % mod;
}
ll poww(ll a, ll b)
{ll ans= 1ll;while (b) {if (b & 1)ans= Sum(ans, a) % mod;a= Sum(a, a) % mod;b>>= 1;}return ans % mod;
}
ll solve(int len, int nup, int mup) //数位dp
{if (!len)return 1;if (f[len][nup][mup] != -1)return f[len][nup][mup] % mod;ll ans= 0;int l, r;//如果上一位到了上界,本位取值范围是0到b[len]//如果上一位没有到上界,本位就可以随便取值,范围是0到k-1l= nup ? b[len] : k - 1;r= mup ? c[len] : k - 1;for (int i= 0; i <= l; i++) {for (int j= 0; j <= i && j <= r; j++) {ans= (ans + solve(len - 1, nup && (i == l), mup && (j == r))) % mod;}}return f[len][nup][mup]= ans % mod;
}
int main()
{//rd_test();read(t, k);while (t--) {ll n, m;scanf("%lld%lld", &n, &m);cnt1= 0;cnt2= 0;memset(f, -1, sizeof(f));memset(b, 0, sizeof(b));memset(c, 0, sizeof(c));m= min(n, m);ll sum= ((((((m + 1) % mod * ((m + 2) % mod)) % mod) * (poww(2ll, mod - 2) % mod)) % mod) + (((n - m) % mod) * ((m + 1) % mod)) % mod) % mod;//cout << "sum=" << sum << endl;ll P= mod;while (n) {b[++cnt1]= n % k;n/= k;}while (m) {c[++cnt2]= m % k;m/= k;}printf("%lld\n", (sum - solve(cnt1, 1, 1) + mod) % mod);}//Time_test();
}