文章目录
- 树状数组
- lowbit
- 线段树与树状数组
- 单点修改
- 区间查询
- 区间修改
- 区间求和
- 二维树状数组
- 离线树状数组
- 例题
- POJ:stars
- MooFest
- [SDOI2009]HH的项链
- Turing Tree
- Counting Sequences
- Zip-line
树状数组
用于快速高效的计算与前缀和相关的信息
lowbit
int lowbit( int i ) { return i & -i; }
lowbit\rm lowbitlowbit:返回xxx二进制最低位为111的位置的值
e.g.
40=101000
,lowbit(40)=8
线段树与树状数组
因为涉及lowbit\rm lowbitlowbit,所以树状数组的下标一定从111开始,而不是000
线段树用mid=(l+r)>>1
进行log\loglog的优化
树状数组的通过±lowbit(i)±\rm lowbit(i)±lowbit(i)进行二进制位的进/退111
时间复杂度同样都是O(nlogn)O(n\log n)O(nlogn)
但一般来说树状数组的空间都是O(N)O(N)O(N),不会像线段树有N<<2N<<2N<<2的大空间
线段树因为其结构原因有更多的应用:优化建图,线段树分治.........
但是树状数组就比较死板了,就是跟静态区间/单点挂钩
那么应用广点,消耗的代价(更大空间)多点也是可以理解的了
很多情况下树状数组和线段树没有什么区别,可以互换
单点修改
void add( int i, int val ) {for( ;i <= n;i += lowbit( i ) )t[i] += val;
}
区间查询
int query( int i ) {//求的是[1,i]的前缀和int ans = 0;for( ;i;i -= lowbit( i ) ) ans += t[i];return ans;
}
int query( int l, int r ) {return query( r ) - query( l - 1 );
}
一般的写法都是维护前缀和,所以修改是+lowbit+\rm lowbit+lowbit,查询是−lowbit-\rm lowbit−lowbit
但有些时候题目反而是跟后缀挂钩,这个时候有两种选择
-
强制后缀转前缀
每次传入iii的时候,暴力变成i=n−i+1i=n-i+1i=n−i+1,然后进行
add
query
的操作 -
直接反转使用树状数组
修改直接−lowbit-\rm lowbit−lowbit,查询+lowbit+\rm lowbit+lowbit
只要维护了意义上的相对即可
区间修改
原序列a1,a2,...,ana_1,a_2,...,a_na1,a2,...,an,定义差分数组ci=ai−ai−1c_i=a_i-a_{i-1}ci=ai−ai−1, 则ai=∑j=1icja_i=\sum_{j=1}^ic_jai=∑j=1icj
那么修改区间[l,r][l,r][l,r],加上www,相当于在clc_lcl加www,在cr+1c_{r+1}cr+1减www
这就将区间修改转化成了两次的单点修改
达到[l,r][l,r][l,r]区间的数加www的效果
区间求和
∑i=1nai=∑i=1n∑j=1icj=∑i=1n(n−i+1)ci=(c1)+(c1+c2)+...+(c1+c2+...+cn)\sum_{i=1}^na_i=\sum_{i=1}^n\sum_{j=1}^ic_j=\sum_{i=1}^n(n-i+1)c_i \\=(c_1)+(c_1+c_2)+...+(c_1+c_2+...+c_n) i=1∑nai=i=1∑nj=1∑icj=i=1∑n(n−i+1)ci=(c1)+(c1+c2)+...+(c1+c2+...+cn)
=(n+1)∗(c1+c2+...+cn)−(c1⏞1+c2+c2⏞2+...+cn+cn⏞n)=(n+1)*(c_1+c_2+...+c_n)-(\overbrace{c_1}^{1}+\overbrace{c_2+c_2}^{2}+...\overbrace{+c_n+c_n}^{n}) =(n+1)∗(c1+c2+...+cn)−(c11+c2+c22+...+cn+cnn)
=(n+1)∗∑i=1nci−∑i=1nci∗i=(n+1)*\sum_{i=1}^nc_i-\sum_{i=1}^nc_i*i =(n+1)∗i=1∑nci−i=1∑nci∗i
所以只需要用两个树状数组,分别维护cic_ici和ci∗ic_i*ici∗i即可
LOJ:区间修改区间查询
#include <cstdio>
#define int long long
#define maxn 1000005
int n, Q;
int a[maxn], t1[maxn], t2[maxn];int lowbit( int x ) { return x & -x; }void modify( int x, int val ) {for( int i = x;i <= n;i += lowbit( i ) )t1[i] += val, t2[i] += val * x;
}int query( int x ) {int ans = 0;for( int i = x;i;i -= lowbit( i ) )ans += ( x + 1 ) * t1[i] - t2[i];return ans;
}signed main() {scanf( "%lld %lld", &n, &Q );for( int i = 1;i <= n;i ++ ) {scanf( "%lld", &a[i] );modify( i, a[i] - a[i - 1] );}int opt, l, r, x;while( Q -- ) {scanf( "%lld %lld %lld", &opt, &l, &r );if( opt & 1 ) {scanf( "%lld", &x );modify( l, x ), modify( r + 1, -x );}else printf( "%lld\n", query( r ) - query( l - 1 ) );}return 0;
}
二维树状数组
既然树状数组是前缀和的工具,那么二维树状数组就相当于与二维差分
树状数组嵌树状数组的感觉,查询就是用二维差分计算围成的面积
void modify( int x, int y, int val ) {for( int i = x;i <= n;i += lowbit( i ) )for( int j = y;j <= m;j += lowbit( j ) )t[i][j] += val;
}
int query( int x, int y ) {int ans = 0;for( int i = x;i;i -= lowbit( i ) )for( int j = y;j;j -= lowbit( j ) )ans += t[i][j];return ans;
}modify( x1, y1, k );
query( x2, y2 ) - query( x2, y1 - 1 ) - query( x1 - 1, y2 ) + query( x1 - 1, y1 - 1 );
离线树状数组
离线树状数组求区间不同数的个数/值和
都是将询问按照l,r\rm l,rl,r排序,然后记录iii的上一个/下一个位置
将指针拨到询问的端点处,删去上一个位置/加入下一个位置
从而做到111的个数差,满足不同数只记录一次的要求,自然树状数组就能维护
例题的最后两题就是如此,看代码比较清晰能够理解
例题
POJ:stars
POJ2352
题目已经保证了yyy递增,那么树状数组维护xxx,每次查询比xxx小的星星有多少个即可
#include <cstdio>
#include <iostream>
using namespace std;
#define maxn 32005
int n, N;
int ans[maxn], t[maxn], x[maxn], y[maxn];int lowbit( int i ) { return i & -i; }void modify( int i ) {for( ;i <= N;i += lowbit( i ) ) t[i] ++;
}int query( int i ) {int ret = 0;for( ;i;i -= lowbit( i ) ) ret += t[i];return ret;
}int main() {scanf( "%d", &n );for( int i = 1;i <= n;i ++ ) {scanf( "%d %d", &x[i], &y[i] );x[i] ++, y[i] ++, N = max( N, x[i] );//x,y值域包含0 树状数组不能从0开始 所以整体+1}for( int i = 1;i <= n;i ++ )ans[query( x[i] )] ++, modify( x[i] );for( int i = 0;i < n;i ++ )printf( "%d\n", ans[i] );return 0;
}
MooFest
POJ1990
按vvv排序,维护两个树状数组,一个是小于当前位置的位置和,一个是大于当前位置的位置和,这样就避免了距离带的绝对值
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
#define int long long
#define maxn 20005
struct node {int val, pos;node(){}node( int Val, int Pos ) {val = Val, pos = Pos;}
}cow[maxn];
struct Node {int cnt, sumd;Node(){}Node( int Cnt, int Sumd ) {cnt = Cnt, sumd = Sumd;}
}t1[maxn], t2[maxn];
int n, N;int lowbit( int i ) { return i & -i; }void modify1( int i, int val ) {for( ;i <= N;i += lowbit( i ) )t1[i].cnt ++, t1[i].sumd += val;
}void modify2( int i, int val ) {for( ;i;i -= lowbit( i ) )t2[i].cnt ++, t2[i].sumd += val;
}Node query1( int i ) {Node ans( 0, 0 );for( ;i;i -= lowbit( i ) )ans.cnt += t1[i].cnt, ans.sumd += t1[i].sumd;return ans;
}Node query2( int i ) {Node ans( 0, 0 );for( ;i <= N;i += lowbit( i ) )ans.cnt += t2[i].cnt, ans.sumd += t2[i].sumd;return ans;
}bool cmp( node x, node y ) { return x.val < y.val; } signed main() {scanf( "%lld", &n );for( int i = 1;i <= n;i ++ ) {scanf( "%lld %lld", &cow[i].val, &cow[i].pos );N = max( cow[i].pos, N );}N ++;sort( cow + 1, cow + n + 1, cmp );int ans = 0;for( int i = 1;i <= n;i ++ ) {Node t = query1( cow[i].pos );ans += ( cow[i].pos * t.cnt - t.sumd ) * cow[i].val;t = query2( cow[i].pos );ans += ( t.sumd - t.cnt * cow[i].pos ) * cow[i].val;modify1( cow[i].pos, cow[i].pos );modify2( cow[i].pos, cow[i].pos );}printf( "%lld\n", ans );return 0;
}
[SDOI2009]HH的项链
Luogu1972
离线树状数组求区间不同数个数
将询问按lll排序,对于每个位置iii记录下一个与该位置值相等的位置,每一次到iii就把下一次的位置加进去
询问区间左端点以前的自然都要加,这样区间查询相减,就知道下一次的位置在不在区间内,就恰好为111
#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 1000005
struct node {int l, r, id;
}q[maxn];
int n, m;
int a[maxn], t[maxn], lst[maxn], nxt[maxn], ans[maxn];
bool vis[maxn];void read( int &x ) {x = 0; char s = getchar();while( s < '0' or s > '9' ) s = getchar();while( '0' <= s and s <= '9' ) {x = ( x << 1 ) + ( x << 3 ) + ( s ^ 48 );s = getchar();}
}int lowbit( int i ) { return i & -i; }void add( int i ) {for( ;i < maxn;i += lowbit( i ) ) t[i] ++;
}int query( int i ) {int ret = 0;for( ;i;i -= lowbit( i ) ) ret += t[i];return ret;
}int main() {read( n );for( int i = 1;i <= n;i ++ ) read( a[i] );read( m );for( int i = 1;i <= m;i ++ ) read( q[i].l ), read( q[i].r ), q[i].id = i;sort( q + 1, q + m + 1, []( node x, node y ) { return x.l < y.l; } );for( int i = 1;i <= n;i ++ )if( ! vis[a[i]] ) add( i ), vis[a[i]] = 1;for( int i = n;i;i -- ) {if( ! lst[a[i]] ) nxt[i] = maxn;else nxt[i] = lst[a[i]];lst[a[i]] = i;}int pos = 1;for( int i = 1;i <= m;i ++ ) {while( pos < q[i].l ) add( nxt[pos] ), pos ++;ans[q[i].id] = query( q[i].r ) - query( q[i].l - 1 );}for( int i = 1;i <= m;i ++ )printf( "%d\n", ans[i] );return 0;
}
Turing Tree
HDU3333
离线树状数组求区间不同数的和
与不同数的个数一致的思路,这里值域比较大,就记录下标
按照rrr排序也可
#include <map>
#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 30005
#define maxm 100005
#define int long long
struct node {int l, r, id;
}q[maxm];
map < int, int > lst;
int T, n, m;
int a[maxn], t[maxn], ans[maxm];int lowbit( int i ) { return i & -i; }void add( int i, int val ) {for( ;i <= n;i += lowbit( i ) ) t[i] += val;
}int query( int i ) {int ret = 0;for( ;i;i -= lowbit( i ) ) ret += t[i];return ret;
}signed main() {scanf( "%lld", &T );while( T -- ) {scanf( "%lld", &n );for( int i = 1;i <= n;i ++ )scanf( "%lld", &a[i] ), t[i] = 0;scanf( "%lld", &m );for( int i = 1;i <= m;i ++ )scanf( "%lld %lld", &q[i].l, &q[i].r ), q[i].id = i;sort( q + 1, q + m + 1, []( node x, node y ) { return x.r < y.r; } );lst.clear();int j = 1;for( int i = 1;i <= m;i ++ ) {for( ;j <= q[i].r;j ++ ) {if( lst[a[j]] ) add( lst[a[j]], -a[j] );add( j, a[j] ), lst[a[j]] = j;}ans[q[i].id] = query( q[i].r ) - query( q[i].l - 1 );}for( int i = 1;i <= m;i ++ )printf( "%lld\n", ans[i] );}return 0;
}
Counting Sequences
HDU3450
很简单的dpdpdp转移都能想到
dpi:dp_i:dpi: 最后一位选iii的完美子序列个数,cnti:icnt_i:icnti:i前面与iii距离不超过ddd的个数
则dpi=∑j,∣xi−xj∣≤ddpj+cntidp_i=\sum_{j,|x_i-x_j|\le d}dp_j+cnt_idpi=∑j,∣xi−xj∣≤ddpj+cnti
因为规定个数至少要是222,但jjj成为第一个和后面的iii组成新111个子序列是不会被计算在dpjdp_jdpj里面的,所以单独开一个cntcntcnt记录
用值域树状数组维护,查询query(x+d)−query(x−d−1)\rm query(x+d)-query(x-d-1)query(x+d)−query(x−d−1)
但是这道题的难度就在于值域过大,树状数组内存根本开不下
把nnn个位置离散化,查询就直接找离散化数组中最大的不超过该值的
用upper_bound
查就可以了
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
#define maxn 100005
#define mod 9901
int n, d, m;
int t1[maxn], t2[maxn], x[maxn], pos[maxn];int lowbit( int i ) { return i & -i; }void add( int i, int val ) {for( ;i <= m;i += lowbit( i ) )t1[i] ++, t2[i] += val;
}pair < int, int > query( int i ) {int ans1 = 0, ans2 = 0;for( ;i > 0;i -= lowbit( i ) ) ans1 += t1[i], ans2 += t2[i];return make_pair( ans1, ans2 );
}int find( int p ) {return upper_bound( pos + 1, pos + m + 1, p ) - pos - 1;
}int main() {while( ~ scanf( "%d %d", &n, &d ) ) {for( int i = 1;i <= n;i ++ )scanf( "%d", &x[i] ), pos[i] = x[i], t1[i] = t2[i] = 0;sort( pos + 1, pos + n + 1 );m = unique( pos + 1, pos + n + 1 ) - pos - 1;int ans = 0;for( int i = 1;i <= n;i ++ ) {pair < int, int > r = query( find( x[i] + d ) );pair < int, int > l = query( find( x[i] - d - 1 ) );pair < int, int > t = make_pair( r.first - l.first, r.second - l.second );t.first = ( t.first + mod ) % mod;t.second = ( t.second + mod ) % mod;ans = ( ans + t.first + t.second ) % mod;x[i] = lower_bound( pos + 1, pos + m + 1, x[i] ) - pos;add( x[i], ( t.first + t.second ) % mod );}printf( "%d\n", ans );}return 0;
}
Zip-line
CF650D
预处理以iii开始/结尾的最长上升子序列li/ril_i/r_ili/ri
查询时,考虑答案有三种变化
-
往iii前找比新值xxx小的hjh_jhj的最大值rjr_jrj,往iii后找比xxx大的hjh_jhj的最大值ljl_jlj
lj+rj’+1l_j+r_j’+1lj+rj’+1 把三段拼起来,如果比答案大,就输出
-
iii一定在LIS\rm LISLIS中,答案就−1-1−1
-
iii可以不在LIS\rm LISLIS中,答案不变
e.g.
1 2 2 3
,2
是可以不在LIS\rm LISLIS中的,因为作用与另外一个等效
那现在就是求出哪些数是一定在LIS\rm LISLIS中
- 求出包含iii的最长上升子序列,如果与LIS\rm LISLIS相等,那么iii可以在LIS中
- 对于一个可以在LIS的数iii,如果iii前面≥i\ge i≥i的数可以在LIS中,iii可以不在LIS中
- 对于一个可以在LIS的数iii,如果iii后面≤i\le i≤i的数可以在LIS中,iii可以不在LIS中
对于xxx,求ai<xa_i< xai<x的最大值bib_ibi,离散化树状数组解决
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxn 800005
struct node { int pos, val, id, l, r; }q[maxn];
int n, m, N, ans;
int t[maxn], x[maxn], h[maxn], St[maxn], Ed[maxn], ret[maxn], cnt[maxn];int lowbit( int i ) { return i & -i; }void add( int i, int val ) {for( ;i <= N;i += lowbit( i ) ) t[i] = max( t[i], val );
}int query( int i ) {int ans = 0;for( ;i;i -= lowbit( i ) ) ans = max( ans, t[i] );return ans;
}int main() {scanf( "%d %d", &n, &m );for( int i = 1;i <= n;i ++ )scanf( "%d", &h[i] ), x[++ N] = h[i];for( int i = 1;i <= m;i ++ ) {scanf( "%d %d", &q[i].pos, &q[i].val );q[i].id = i, x[++ N] = q[i].val;}sort( x + 1, x + N + 1 );N = unique( x + 1, x + N + 1 ) - x - 1;for( int i = 1;i <= n;i ++ )h[i] = lower_bound( x + 1, x + N + 1, h[i] ) - x;for( int i = 1;i <= m;i ++ )q[i].val = lower_bound( x + 1, x + N + 1, q[i].val ) - x;for( int i = 1;i <= n;i ++ )Ed[i] = query( h[i] - 1 ) + 1, add( h[i], Ed[i] );memset( t, 0, sizeof( t ) );for( int i = n;i;i -- )St[i] = query( N - h[i] ) + 1, add( N - h[i] + 1, St[i] );memset( t, 0, sizeof( t ) );for( int i = 1;i <= n;i ++ )ans = max( ans, St[i] + Ed[i] - 1 );for( int i = 1;i <= n;i ++ )if( Ed[i] + St[i] - 1 == ans )++ cnt[Ed[i]];sort( q + 1, q + m + 1, []( node x, node y ) { return x.pos < y.pos; } );int j = 1;for( int i = 1;i <= m;i ++ ) {for( ;j < q[i].pos;j ++ ) add( h[j], Ed[j] );q[i].l = query( q[i].val - 1 );}memset( t, 0, sizeof( t ) );j = n;for( int i = m;i;i -- ) {for( ;j > q[i].pos;j -- ) add( N - h[j] + 1, St[j] );q[i].r = query( N - q[i].val );}for( int i = 1;i <= m;i ++ )if( q[i].l + q[i].r + 1 > ans )ret[q[i].id] = q[i].l + q[i].r + 1;for( int i = 1;i <= m;i ++ )if( ! ret[q[i].id] ) {if( Ed[q[i].pos] + St[q[i].pos] - 1 == ans and cnt[Ed[q[i].pos]] == 1 and q[i].l + q[i].r + 1 < ans )ret[q[i].id] = ans - 1;else ret[q[i].id] = ans;}for( int i = 1;i <= m;i ++ )printf( "%d\n", ret[i] );return 0;
}