P3564 [POI2014]BAR-Salad Bar
给定一个长度为nnn的数组,里面元素只有111跟−1-1−1,问选出一个长度为lenlenlen的区间使得,这个区间的前缀和时刻大于零,后缀和时刻大于零,输出最大长度lenlenlen,
考虑枚举lll端点,我们可以二分出最大的rrr,满足pre_sumpre\_sumpre_sum时刻大于等于零,设为[l,r][l, r][l,r],
考虑枚举RRR端点,我们可以二分出最小的LLL,满足suc_sumsuc\_sumsuc_sum时刻大于等于零,设为[L,R][L, R][L,R],
则答案一定是在所有上述点对中的l,Rl, Rl,R中的一个,且有l≤R≤rl \leq R \leq rl≤R≤r,L≤l≤RL \leq l \leq RL≤l≤R,
假设我们已经把上述满足要求的两种点对都算出来了,考虑新开一个线段树,
我们把第二种点对[L,R][L, R][L,R],放进线段树上维护,在RRR点记录符合要求的最小的LLL,
考虑枚举[l,r][l, r][l,r]点对,在区间[l,r][l, r][l,r]中寻找一个最大的RRR,使得于RRR相对应的LLL,满足L≤lL \leq lL≤l,这个时候我们的答案就是R−l+1R - l + 1R−l+1的最大值了。
上述操作都利用STSTST表,然后二分一下即可,整体复杂度O(nlogn)O (n \log n)O(nlogn)。
#include <bits/stdc++.h>using namespace std;const int N = 1e6 + 10, logn = 20;int a[N], sum[N], b[N], Log[N], f[N][logn + 1], n;char str[N];vector<pair<int, int>> vt, v;void init() {Log[1] = 0, Log[2] = 1;for (int i = 3; i < N; i++) {Log[i] = Log[i / 2] + 1;}
}int main() {// freopen("in.txt", "r", stdin);// freopen("out.txt", "w", stdout);init();scanf("%d %s", &n, str + 1);for (int i = 1; i <= n; i++) {a[i] = str[i] == 'p' ? 1 : -1;}for (int i = 1; i <= n; i++) {sum[i] = a[i] + sum[i - 1], f[i][0] = sum[i];}for (int j = 1; j <= logn; j++) {for (int i = 1; i + (1 << j) - 1 <= n; i++) {f[i][j] = min(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);}}for (int i = 1; i <= n; i++) {if (a[i] == -1) {continue;}int l = i, r = n;while (l < r) {int mid = l + r + 1 >> 1, s = Log[mid - i + 1];if (min(f[i][s], f[mid - (1 << s) + 1][s]) >= sum[i - 1]) {l = mid;}else {r = mid - 1;}}// printf("%d %d\n", i, l);vt.push_back({i, l});}for (int i = 1; i <= n; i++) {sum[i] = a[n - i + 1] + sum[i - 1], f[i][0] = sum[i];}for (int j = 1; j <= logn; j++) {for (int i = 1; i + (1 << j) - 1 <= n; i++) {f[i][j] = min(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);}}memset(b, 0x3f, sizeof b);for (int i = 1; i <= n; i++) {if (a[n - i + 1] == -1) {continue;}int l = i, r = n;while (l < r) {int mid = l + r + 1 >> 1, s = Log[mid - i + 1];if (min(f[i][s], f[mid - (1 << s) + 1][s]) >= sum[i - 1]) {l = mid;}else {r = mid - 1;}}b[n - i + 1] = n - l + 1;// printf("%d %d\n", n - l + 1, n - i + 1);}for (int i = 1; i <= n; i++) {f[i][0] = b[i];}for (int j = 1; j <= logn; j++) {for (int i = 1; i + (1 << j) - 1 <= n; i++) {f[i][j] = min(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);}}int ans = 0;for (auto it : vt) {int L = it.first, R = it.second;int l = it.first, r = it.second;while (L < R) {// [mid + 1, r]int mid = L + R >> 1, s = Log[r - mid];if (min(f[mid + 1][s], f[r - (1 << s) + 1][s]) <= l) {L = mid + 1;}else {R = mid;}}ans = max(ans, L - l + 1);}printf("%d\n", ans);return 0;
}