概述
树状数组(Binary Indexed Tree,简称BIT),是一种数据结构,用于处理区间查询和更新问题。它是一种可以高效地在对数级别时间复杂度内进行单点更新和区间查询的数据结构。树状数组通常用于解决以下两类问题:
- 区间和查询:给定一个序列,查询序列中任意区间的和。
- 区间更新:给定一个序列,对序列中任意区间的值进行增加或减少。
问题引入
给定一个长度为n的数组,完成以下两种操作:
- 更新:将第x个数加上k;
- 查询:输出区间[x, y]内每个数的和。
我们很容易想到一种朴素做法,更新操作直接在原数组上操作,查询遍历一下即可,对应的时间复杂度分别为O(1)和O(n)。
当然,你也可能想到用前缀和数组来优化,这样的话更新操作的时间复杂度就是O(n),查询操作的复杂度为O(1)。
可以发现,两种做法中,要么查询是O(1),更新是O(n);要么更新是O(1),查询是O(n)。那么就有没有一种做法可以综合一下这两种朴素做法,然后整体时间复杂度可以降一个数量级呢?有的,对,就是树状数组。
lowbit函数
学习树状数组之前首先需要了解一下lowbit
函数。lowbit
函数的功能就是求某一个数的二进制表示中最低的一位1
所表示的数值。这个数值一定是2的幂。举个例子,x = 6
,它的二进制为110
,那么lowbit(x)
就返回2
,因为最后一个1
表示2
。再举个例子,lowbit(4) = 4
。
我们知道,负数的补码是它的反码+1。当然,还有一种快捷求法就是,从右往左数第一个1及其右边的0不动,剩下的位取反。这时候,我们如果让它和原数进行二进制与操作,就能得到最后的一个1及其后面的0。例如,6的二进制为0110,-6的补码为1010,它们两个做与运算就能消掉最后一个1前面的所有位。用代码表示如下:
int lowbit(int x)
{return x & -x;
}
树状数组的思想
首先要明确树状数组里存的是什么。假设原数组是arr
,我们需要维护一个新的树状数组c
,c
数组里的每一位存的是arr
中对应下标开始往前数lowbit(下标)
个数的和。例如,c[6]
的下标为6,并且lowbit(6) = 2
,所以c[6]
存的就是arr
中从第6项开始往前数2个数的和,即arr[5] + arr[6]
。因此,相比前缀和数组,树状数组可以说存的是区间和。
查询
明白了树状数组存的是什么,就可以用树状数组来求前缀和了。因为查询操作还是要通过两个前缀和做差来得到任意区间的和。
因为树状数组存的是区间和,我们通过不同的区间拼凑出一个完整的前缀区间就能计算前缀和。还是以6为例,6的二进制为110,可以写成100 + 10
,即4 + 2
。根据树状数组的定义,c[6]
存的是arr[5] + arr[6]
。得到第一个区间和后,减去lowbit,即6 - lowbit(6) = 6 - 2 = 4
。而c[4]
存的是arr中第1项到第4项的和,这是因为lowbit(4) = 4
。这两段拼起来正好得到第6项的前缀和。
因此,用树状数组求第x项前缀和可以用下面的代码表示:
int sum(int x, int c[])
{int res = 0;for (; x > 0; x -= lowbit(x))res += c[x];return res;
}
更新
如果理解了上述过程,我们其实能发现,树状数组求前缀和本质上就是将下标展开成二进制,根据二进制位上的1来求和,从而实现对数级别的复杂度。树状数组用图来表示就是像下面这样。
其中,1到12是树状数组的下标,上面的横条表示了这一项对应arr
数组中的区间和。我们从这张图中可以得到树状数组的如下性质:
- 下面层的下标只要补上自己的lowbit值就可以得到上面层的下标(图中的虚线指出了什么是上面层)。注意,是上面层的下标,而不是上一层的下标,这个性质就是更新操作的依据;
例如,下标6只要加上lowbit(6)
,也就是2,就能跳到自己的上面层,也就是8。之所以8是上面层,是因为它们的区间产生了重叠。加上lowbit值会让最低位的1往高位移动,其所代表的幂会指数增长,远大于加上的值。所以上面层必定包含下面层。如果不能理解记住这个性质就好。
理解了这一点,就可以明白更新操作了。如果在arr
数组上进行更新操作,很简单,只要修改第x项就可以了。但是树状数组表示的是区间和,修改了这一项会影响到很多包含这一项的区间。因此,在c
数组上,所有包含第x项的区间和都要修改。
更新了arr
的第x项,首先影响到的就是c[x]
。因为c[x]
所代表的区间和长度至少为1,即必定包含arr[x]
。然后就是上面的性质所说的上面层了。我们通过不断加上lowbit值往上层跳,不断更新c数组,就能实现对数级别复杂度的更新操作。代码如下:
void update(int x, int val, int c[], int n)
{for (; x <= n; x += lowbit(x))c[x] += val;
}
代码实现
输入格式
第一行输入两个整数n和m,分别表示数组长度和操作的次数;
第二行输入n个整数表示数组;
接下来m行,每行输入一个字符ch和两个整数x,y。ch='F'
表示查询x到y这段闭区间的和;ch='S'
表示第x个元素加上y。
输出格式
对于每个查询,输出结果。
样例输入
5 6
1 2 3 4 5
F 1 3
S 1 2
F 1 3
S 2 3
F 1 2
F 1 5
样例输出
6
8
8
20
完整代码实现如下:
#include <iostream>using namespace std;const int MAX = 1e6;
int c[MAX]; // c[i]表示从第i个元素向前数lowbit(i)个元素,这一段的和,包括c[i]int lowbit(int x)
{return x & -x;
}/*** @brief 求下标为x的前缀和** @param x* @return int*/
int sum(int x, int c[])
{int res = 0;for (; x > 0; x -= lowbit(x))res += c[x];return res;
}/*** @brief 原数组x的位置上的数加上了val,所以要维护c数组** @param x* @param val* @param c* @param n*/
void update(int x, int val, int c[], int n)
{for (; x <= n; x += lowbit(x))c[x] += val;
}int main()
{int n, m;cin >> n >> m;for (int i = 1; i <= n; i++){int x;cin >> x;update(i, x, c, n);}while (m--){int x, y;char ch;cin >> ch >> x >> y;switch (ch){case 'F':cout << sum(y, c) - sum(x - 1, c) << endl;break;case 'S':update(x, y, c, n);break;}}return 0;
}