4747
思路:
线段树
先求出 mex(1,1), mex(1,2), mex(1,3), ..., mex(1,n)(随右端点单调不降),并将这些 mex 放进线段树里维护区间和
然后求出 next[i],表示 a[i] 在位置 i 之后下一次出现的位置(若之后不再出现,则 next[i] = n+1)
然后从前往后依次删数:删除 a[i] 之前,先把线段树的总和(即所有以 i 为左端点的区间的 mex 之和)累加进答案。删掉 a[i] 的影响是:设 l 为第一个 mex 大于 a[i] 的位置,r 为 next[i],若 l < r,则 l 到 r-1 之间的 mex 都变为 a[i](这些区间中不再含有 a[i],而它们原来的 mex 大于 a[i])
因此线段树只需支持区间赋值(懒标记),并维护区间最大值(方便查找第一个大于 a[i] 的位置)和区间和即可
代码:
// Sum of mex over all subarrays, via a segment tree with range assignment.
//
// For the current left endpoint i, position j of the tree stores mex(a[i..j])
// (and 0 for j < i); these values are non-decreasing in j. The answer adds the
// tree total for every i. Advancing i -> i+1 "deletes" a[i]: every position in
// [first j whose mex > a[i], nxt[i] - 1] no longer contains a[i], so its mex
// drops to a[i].
//
// Input: repeated test cases "n a1 .. an", terminated by n <= 0 or EOF.
// Output: one line per test case with the sum.
#include <algorithm>
#include <cstdio>
#include <map>
#include <vector>

using LL = long long;

constexpr int N = 200000 + 5;  // positions 1..n are stored; requires n < N
constexpr int NO_LAZY = -1;    // only values >= 0 are ever assigned, so -1 never collides

int a[N], nxt[N], mex[N];
int mx[N << 2];    // maximum over the node's range
int lazy[N << 2];  // pending "assign this value" tag, or NO_LAZY
LL sum[N << 2];    // sum over the node's range

// Assigns x to every one of the len positions covered by node rt.
static void apply_assign(int rt, int len, int x) {
    mx[rt] = x;
    sum[rt] = 1LL * x * len;
    lazy[rt] = x;
}

// Recomputes node rt from its two children.
void push_up(int rt) {
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    mx[rt] = std::max(mx[rt << 1], mx[rt << 1 | 1]);
}

// Propagates a pending assignment of node rt (covering len positions) to its
// children. The left child covers the larger half when len is odd.
void push_down(int rt, int len) {
    if (lazy[rt] == NO_LAZY) return;
    apply_assign(rt << 1, len - (len >> 1), lazy[rt]);
    apply_assign(rt << 1 | 1, len >> 1, lazy[rt]);
    lazy[rt] = NO_LAZY;
}

// Builds node rt over [l, r] from mex[], clearing lazy tags that may be left
// over from a previous test case.
void build(int rt, int l, int r) {
    lazy[rt] = NO_LAZY;
    if (l == r) {
        mx[rt] = mex[l];
        sum[rt] = mex[l];
        return;
    }
    int m = (l + r) >> 1;
    build(rt << 1, l, m);
    build(rt << 1 | 1, m + 1, r);
    push_up(rt);
}

// Assigns x to every position in [L, R]; node rt covers [l, r].
void update(int x, int L, int R, int rt, int l, int r) {
    if (L <= l && r <= R) {
        apply_assign(rt, r - l + 1, x);
        return;
    }
    push_down(rt, r - l + 1);
    int m = (l + r) >> 1;
    if (L <= m) update(x, L, R, rt << 1, l, m);
    if (R > m) update(x, L, R, rt << 1 | 1, m + 1, r);
    push_up(rt);
}

// Returns the sum over [L, R]; node rt covers [l, r].
LL query(int L, int R, int rt, int l, int r) {
    if (L <= l && r <= R) return sum[rt];
    push_down(rt, r - l + 1);
    int m = (l + r) >> 1;
    LL ans = 0;
    if (L <= m) ans += query(L, R, rt << 1, l, m);
    if (R > m) ans += query(L, R, rt << 1 | 1, m + 1, r);
    return ans;
}

// Returns the leftmost position in [l, r] whose value is greater than x.
// Precondition: mx[rt] > x (callers check mx[1] before the first call).
int Find(int x, int rt, int l, int r) {
    if (l == r) return l;
    push_down(rt, r - l + 1);
    int m = (l + r) >> 1;
    if (mx[rt << 1] > x) return Find(x, rt << 1, l, m);
    return Find(x, rt << 1 | 1, m + 1, r);
}

int main() {
    int n;
    // Stop on EOF, malformed input, or a non-positive n (build() needs n >= 1).
    while (std::scanf("%d", &n) == 1 && n > 0) {
        for (int i = 1; i <= n; i++) std::scanf("%d", &a[i]);

        // Prefix mex(1, j). A prefix of length j has mex <= j <= n, so values
        // outside [0, n] can never be the mex and are ignored.
        std::vector<char> seen(n + 2, 0);
        int cur = 0;
        for (int i = 1; i <= n; i++) {
            if (a[i] >= 0 && a[i] <= n) seen[a[i]] = 1;
            while (seen[cur]) cur++;
            mex[i] = cur;
        }

        // nxt[i]: next position after i holding the value a[i], or n + 1.
        std::map<int, int> last;
        for (int i = n; i >= 1; i--) {
            auto it = last.find(a[i]);
            nxt[i] = (it == last.end()) ? n + 1 : it->second;
            last[a[i]] = i;
        }

        build(1, 1, n);
        LL ans = 0;
        for (int i = 1; i <= n; i++) {
            // The tree now holds mex(i, j) for j >= i, and 0 for j < i.
            ans += query(1, n, 1, 1, n);
            int v = a[i];
            // Removing a[i] can only lower mex values that exceed it; a
            // negative value is never a mex, so removing it changes nothing.
            if (v < 0 || mx[1] <= v) continue;
            int l = Find(v, 1, 1, n);
            int r = nxt[i];
            if (l < r) update(v, l, r - 1, 1, 1, n);
        }
        std::printf("%lld\n", ans);
    }
    return 0;
}