- Leetcode 3130. Find All Possible Stable Binary Arrays II
- 0. 序言
- 1. 算法思路
- 2. 代码实现
- 1. 第一版本
- 2. 第二版本
- 3. 第三版本
- 4. 第四版本
- 3. 算法优化
- 1. 算法实现一
- 2. 算法实现二
- 题目链接:3130. Find All Possible Stable Binary Arrays II
0. 序言
这道题和题目3129本质上就是一道题目,唯一的差别就是取值范围的差异,题目3129的范围在 [ 1 , 200 ] [1,200] [1,200],而这道题在 [ 1 , 1000 ] [1, 1000] [1,1000],因此复杂度会更高。
很不幸,我只搞定了题目3129,而这道题则是死活都没有搞定,最后是看了一下题目3129当中其他大佬们的高效率算法之后改写了一下通过了最后的测试,不过不幸的是,具体到思路上依然是没有看懂,真的是有点伤……
所以,这里的话我会现在前两部分讲一下我自己的算法思路,以及为了提升算法效率而做的优化,总体来说的话,在题目3129当中将算法效率提升了10倍左右,耗时从3727ms降至574ms,不过不幸的是依然无法通过题目3130的测试样例……
因此,如果有读者对这部分内容不感兴趣的话可以直接跳到第三部分看一下大佬们的高效算法即可。
1. 算法思路
这一题我整体的算法思路是当做一个数学上的排列组合问题来做的,显然,如果0和1的个数 n , m ≤ l i m i t n, m \leq limit n,m≤limit的话,那么显然我们可以直接给出答案 C n + m n C_{n+m}^{n} Cn+mn。
但问题就在于如果有 n , m > l i m i t n, m > limit n,m>limit的情况,此时就不能直接用数学方法解了,本来考虑如果将 l i m i t + 1 limit+1 limit+1(不妨简记 l i m i t + 1 = k limit+1=k limit+1=k)个元素进行绑定然后填充的方式,倒是可以直接计算组合数为: C n + m − k n ⋅ ( n + 1 ) C_{n+m-k}^{n} \cdot (n+1) Cn+m−kn⋅(n+1)。
不过这种情况仅限于 n < k n < k n<k且 k < m < 2 k k < m < 2k k<m<2k的情况,即要确保不可能存在两个组内的元素均不少于 k k k个,否则就会出现重复计数的情况。
最后,关于其他的情况,我们就是能用动态规划进行暴力求解了,就很繁琐。
2. 代码实现
1. 第一版本
首先,我们给出我们的第一版本的代码实现如下:
MOD = 10**9+7
FACTORIALS = [1 for _ in range(401)]
for i in range(1, 401):FACTORIALS[i] = i * FACTORIALS[i-1] % MODdef rev(x):return pow(x, -1, MOD)def C(n, m):if m < 0:return 0return FACTORIALS[n] * rev(FACTORIALS[m]) * rev(FACTORIALS[n-m]) % MODclass Solution:def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:@lru_cache(None)def dp(n,m,k,p):# n -> zero, m -> one, k -> pre count, p -> pre elementif n + (1-p) * k > limit * (m+1) or m + p * k > limit * (n+1):return 0elif n + (1-p) * k <= limit and m + p * k <= limit:ans = C(n+m, n)elif p == 0 and n + k <= limit and m > limit and m <= 2 * limit:ans = C(n+m, n) - C(n+m-limit-1, n) * (n+1)elif p == 1 and n > limit and m + k <= limit and n <= 2 * limit:ans = C(n+m, n) - C(n+m-limit-1, m) * (m+1)else:ans = 0if k*(1-p)+1 <= limit:ans = (ans + dp(n-1, m, k*(1-p)+1, 0)) % MODif k*p+1 <= limit:ans = (ans + dp(m-1, n, k*p+1, 0)) % MODreturn ans % MODans = dp(zero, one, 0, 0)return ans
这个实现基本就是翻译了一下我们的上述实现,在题目3129上的评测结果如下:耗时3727ms,占用内存756.5MB。
2. 第二版本
然后,我们注意到这里的n, m事实上是完全等价的,因此,我们就可以取消掉p这个元素,也就是无需再记录前一个元素是什么,直接对换n,m的值即可,选用是默认从n当中进行元素选择作为开头,这样就可以进一步提升cache的利用率了。
给出第二版代码实现如下:
MOD = 10**9+7
FACTORIALS = [1 for _ in range(401)]
for i in range(1, 401):FACTORIALS[i] = i * FACTORIALS[i-1] % MODdef rev(x):return pow(x, -1, MOD)def C(n, m):if m < 0:return 0return FACTORIALS[n] * rev(FACTORIALS[m]) * rev(FACTORIALS[n-m]) % MODclass Solution:def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:@lru_cache(None)def dp(n,m,k):if n + k > limit * (m+1) or m > limit * (n+1):return 0elif n + k <= limit and m <= limit:ans = C(n+m, n)elif n + k <= limit and m > limit and m <= 2 * limit:ans = C(n+m, n) - C(n+m-limit-1, n) * (n+1)else:ans = dp(m-1, n, 1)if k+1 <= limit:ans = (ans + dp(n-1, m, k+1)) % MODreturn ans % MODans = dp(zero, one, 0)return ans
提交代码评测得到:耗时3157ms,占用内存684.4MB。
3. 第三版本
然后,我们注意到这个k也很碍事,既然n,m的地位完全等价了,我们只需要默认每次都必须从n当中选择1到limit个元素即可,这样就可以去掉k这个参数了,可以进一步优化cache。
MOD = 10**9+7
FACTORIALS = [1 for _ in range(401)]
for i in range(1, 401):FACTORIALS[i] = i * FACTORIALS[i-1] % MODdef rev(x):return pow(x, -1, MOD)def C(n, m):if m < 0:return 0return FACTORIALS[n] * rev(FACTORIALS[m]) * rev(FACTORIALS[n-m]) % MODclass Solution:def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:@lru_cache(None)def dp(n,m):if n == 0:return 0elif m == 0:return 1 if n <= limit else 0elif n > limit * (m+1) or m > limit * n:return 0elif n <= limit and m <= limit:ans = C(n+m-1, m)elif n <= limit and m > limit and m <= 2 * limit:ans = C(n+m-1, m) - C(n+m-limit-2, n-1) * nelse:ans = 0for i in range(1, min(limit, n) + 1):ans = (ans + dp(m, n-i))return ans % MODans = (dp(zero, one) + dp(one, zero)) % MODreturn ans
提交代码评测得到:耗时707ms,占用内存32.1MB。
4. 第四版本
最后,我们还对上述 C n m C_{n}^{m} Cnm的实现进行了一下优化,具体来说的话就是每次都算pow(n, -1 MOD)
太耗时了,因此我们也像阶乘一样预先算好保存在一个数组当中即可。
给出python代码实现如下:
MOD = 10**9+7
FACTORIALS = [1 for _ in range(401)]
for i in range(1, 401):FACTORIALS[i] = i * FACTORIALS[i-1] % MODInv_FACTORIALS = [pow(x, -1, MOD) for x in FACTORIALS]def C(n: int, m: int):if m < 0:return 0return (FACTORIALS[n] * Inv_FACTORIALS[m] * Inv_FACTORIALS[n-m]) % MODclass Solution:def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:@lru_cache(None)def dp(n,m):if n == 0:return 0elif m == 0:return 1 if n <= limit else 0elif n > limit * (m+1) or m > limit * n:return 0elif n <= limit and m <= limit:ans = C(n+m-1, m)elif n <= limit and m > limit and m <= 2 * limit:ans = C(n+m-1, m) - C(n+m-limit-2, n-1) * nelse:ans = 0for i in range(1, min(limit, n) + 1):ans = (ans + dp(m, n-i))return ans % MODans = (dp(zero, one) + dp(one, zero)) % MODreturn ans
提交代码评测得到:耗时574ms,占用内存32MB。
3. 算法优化
不过可惜的是,即便如此,上述的算法依然无法通过题目3130的测试样例,还是会出现超时的情况。因此在下面的小节里面,我们摘录了两个大佬的两个算法实现,分别来自题目3129和题目3130的解答当中耗时最优的方法,然后稍微用我自己感觉更好理解的方式稍微翻译了一下,虽然我自己依然没有完全看明白具体的数学含义就是了。
不过插句题外话,虽然代码实现两个算法不太一样,不过从具体的思路以及参数计算来看,我觉得这俩实现很可能来自同一个大佬……
只能说,大佬牛逼……
1. 算法实现一
MOD = 10**9+7
FACTORIALS = [1 for _ in range(1001)]
for i in range(1, 1001):FACTORIALS[i] = i * FACTORIALS[i-1] % MODInv_FACTORIALS = [pow(x, -1, MOD) for x in FACTORIALS]def C(n: int, m: int):if m < 0:return 0return (FACTORIALS[n] * Inv_FACTORIALS[m] * Inv_FACTORIALS[n-m]) % MODclass Solution:def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:ans = 0N = zero + onemin_zero_group = (zero - 1) // limit + 1min_one_group = (one - 1) // limit + 1@lru_cache(None)def count(n, g, k):ans = C(n+g-1, g-1)r, flag = 1, -1while n - r * (k+1) >= 0:ans = (ans + flag * C(n - r*(k+1) + g-1, g-1) * C(g, r)) % MODr += 1flag *= -1return ansfor n in range(min_zero_group, zero+1):for m in range(n-1, n+1+1):if m < min_one_group or m > one:continueflag = 1 if n != m else 2ans = (ans + flag * count(zero-n, n, limit - 1) * count(one-m, m, limit - 1)) % MODreturn ans % MOD
提交代码评测得到:耗时72ms,占用内存17.1MB。
2. 算法实现二
MOD = 10**9+7
FACTORIALS = [1 for _ in range(1001)]
for i in range(1, 1001):FACTORIALS[i] = i * FACTORIALS[i-1] % MODInv_FACTORIALS = [pow(x, -1, MOD) for x in FACTORIALS]def C(n: int, m: int):if m < 0:return 0return (FACTORIALS[n] * Inv_FACTORIALS[m] * Inv_FACTORIALS[n-m]) % MODclass Solution:def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:ans = 0N = zero + onemin_zero_group = (zero - 1) // limit + 1min_one_group = (one - 1) // limit + 1def count(n, g, k):ans = C(n+g-1, g-1)r, flag = 1, -1while n - r * (k+1) >= 0:ans = (ans + flag * C(n - r*(k+1) + g-1, g-1) * C(g, r)) % MODr += 1flag *= -1return ansfor n in range(min_zero_group, zero+1):for m in range(n-1, n+1+1):if m < min_one_group or m > one:continueflag = 1 if n != m else 2ans = (ans + flag * count(zero-n, n, limit - 1) * count(one-m, m, limit - 1)) % MODreturn ans % MOD
提交代码评测得到:耗时94ms,占用内存16.7MB。