题目
代码
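下面两份代码的唯一区别在于 modify 中对递归分支的划分方式：最简单的是"存在性"划分——只要目标区间与左（右）子区间有交集（l <= mid / r > mid）就递归进入该子区间；另一种是与 query 相同的"三段式"划分——目标区间完全在左半（r <= mid）、完全在右半（l > mid）、或跨越中点（两边都递归）。详见代码。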
存在性
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for segment sums and lazy increments.
using ll = long long;
// Array capacity: room for indices 1..100000 plus a small safety margin.
constexpr int N = 100000 + 10;
// Input values, stored 1-indexed as a[1..n].
int a[N];
// One segment-tree node.
struct node
{
    int l, r;  // inclusive index range [l, r] covered by this node
    ll sum;    // sum over [l, r]; already includes this node's own pending increment
    ll add;    // per-element increment not yet pushed down to the two children
};
// Heap-style storage: the root is tr[1], and the children of u are tr[2u] and tr[2u+1].
node tr[4 * N];
// Recompute node u's sum from the sums of its two children.
void pushup(int u)
{
    const int left_child = u * 2;
    const int right_child = u * 2 + 1;
    tr[u].sum = tr[left_child].sum + tr[right_child].sum;
}
void pushdown(int i)
{node &u = tr[i], &l = tr[i << 1], &r = tr[i << 1 | 1];if(u.add){l.add += u.add; r.add += u.add;l.sum += (l.r - l.l + 1) * u.add; r.sum += (r.r - r.l + 1) * u.add;u.add = 0;}
}
// Build the subtree rooted at u over a[l..r]. Leaves take their value from a[];
// internal nodes start with no pending increment and take their sum from their children.
void build(int u, int l, int r)
{
    if (l == r)
    {
        tr[u] = {l, r, a[l], 0};
        return;
    }
    tr[u] = {l, r};  // sum and add are value-initialised to 0
    const int mid = (l + r) >> 1;
    build(2 * u, l, mid);
    build(2 * u + 1, mid + 1, r);
    pushup(u);
}
void modify(int u, int l, int r, int v)
{if(l <= tr[u].l && tr[u].r <= r){tr[u].add += v;tr[u].sum += (ll)(tr[u].r - tr[u].l + 1) * v;return;}pushdown(u);int mid = tr[u].l + tr[u].r >> 1;if(l <= mid) modify(u << 1, l, r, v);if(r > mid) modify(u << 1 | 1, l, r, v);pushup(u);
}
// Return the sum of [l, r] restricted to the subtree rooted at u.
// Uses the three-way split: entirely left of mid, entirely right, or straddling it.
ll query(int u, int l, int r)
{
    const node &cur = tr[u];
    if (l <= cur.l && cur.r <= r) return cur.sum;

    pushdown(u);
    const int mid = (cur.l + cur.r) >> 1;
    if (r <= mid) return query(2 * u, l, r);
    if (l > mid) return query(2 * u + 1, l, r);

    // Straddles mid: evaluate left then right, then combine.
    ll total = query(2 * u, l, r);
    total += query(2 * u + 1, l, r);
    return total;
}
// Reads n, the values a[1..n], then m commands.
// "C l r d" adds d to a[l..r]; any other command letter prints the sum of a[l..r].
int main()
{
    // Only iostreams are used, so drop C-stdio sync and untie cin/cout for bulk I/O speed.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, m;
    if (!(cin >> n >> m)) return 0;
    for (int i = 1; i <= n; i++) cin >> a[i];
    // build(1, 1, 0) would recurse forever (l never reaches r), so skip it for an empty array.
    if (n > 0) build(1, 1, n);
    while (m--)
    {
        char op;
        int l, r;
        // Stop on truncated input instead of acting on an uninitialised op / zeroed l, r.
        if (!(cin >> op >> l >> r)) break;
        if (op == 'C')
        {
            int d;
            if (!(cin >> d)) break;
            modify(1, l, r, d);
        }
        else
        {
            cout << query(1, l, r) << '\n';
        }
    }
    return 0;
}
三段式
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for segment sums and lazy increments.
using ll = long long;
// Array capacity: room for indices 1..100000 plus a small safety margin.
constexpr int N = 100000 + 10;
// Input values, stored 1-indexed as a[1..n].
int a[N];
// One segment-tree node.
struct node
{
    int l, r;  // inclusive index range [l, r] covered by this node
    ll sum;    // sum over [l, r]; already includes this node's own pending increment
    ll add;    // per-element increment not yet pushed down to the two children
};
// Heap-style storage: the root is tr[1], and the children of u are tr[2u] and tr[2u+1].
node tr[4 * N];
// Recompute node u's sum from the sums of its two children.
void pushup(int u)
{
    const int left_child = u * 2;
    const int right_child = u * 2 + 1;
    tr[u].sum = tr[left_child].sum + tr[right_child].sum;
}
void pushdown(int i)
{node &u = tr[i], &l = tr[i << 1], &r = tr[i << 1 | 1];if(u.add){l.add += u.add; r.add += u.add;l.sum += (l.r - l.l + 1) * u.add; r.sum += (r.r - r.l + 1) * u.add;u.add = 0;}
}
// Build the subtree rooted at u over a[l..r]. Leaves take their value from a[];
// internal nodes start with no pending increment and take their sum from their children.
void build(int u, int l, int r)
{
    if (l == r)
    {
        tr[u] = {l, r, a[l], 0};
        return;
    }
    tr[u] = {l, r};  // sum and add are value-initialised to 0
    const int mid = (l + r) >> 1;
    build(2 * u, l, mid);
    build(2 * u + 1, mid + 1, r);
    pushup(u);
}
void modify(int u, int l, int r, int v)
{if(l <= tr[u].l && tr[u].r <= r){tr[u].add += v;tr[u].sum += (ll)(tr[u].r - tr[u].l + 1) * v;return;}pushdown(u);int mid = tr[u].l + tr[u].r >> 1;if(r <= mid) modify(u << 1, l, r, v);else if(l > mid) modify(u << 1 | 1, l, r, v);else{modify(u << 1, l, r, v);modify(u << 1 | 1, l, r, v);}pushup(u);
}
// Return the sum of [l, r] restricted to the subtree rooted at u.
// Uses the three-way split: entirely left of mid, entirely right, or straddling it.
ll query(int u, int l, int r)
{
    const node &cur = tr[u];
    if (l <= cur.l && cur.r <= r) return cur.sum;

    pushdown(u);
    const int mid = (cur.l + cur.r) >> 1;
    if (r <= mid) return query(2 * u, l, r);
    if (l > mid) return query(2 * u + 1, l, r);

    // Straddles mid: evaluate left then right, then combine.
    ll total = query(2 * u, l, r);
    total += query(2 * u + 1, l, r);
    return total;
}
// Reads n, the values a[1..n], then m commands.
// "C l r d" adds d to a[l..r]; any other command letter prints the sum of a[l..r].
int main()
{
    // Only iostreams are used, so drop C-stdio sync and untie cin/cout for bulk I/O speed.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, m;
    if (!(cin >> n >> m)) return 0;
    for (int i = 1; i <= n; i++) cin >> a[i];
    // build(1, 1, 0) would recurse forever (l never reaches r), so skip it for an empty array.
    if (n > 0) build(1, 1, n);
    while (m--)
    {
        char op;
        int l, r;
        // Stop on truncated input instead of acting on an uninitialised op / zeroed l, r.
        if (!(cin >> op >> l >> r)) break;
        if (op == 'C')
        {
            int d;
            if (!(cin >> d)) break;
            modify(1, l, r, d);
        }
        else
        {
            cout << query(1, l, r) << '\n';
        }
    }
    return 0;
}