最近花了两天时间学了下树状数组,然后感觉堪比玄学,故写篇文章总结一下。
2024.08.07 UPD:添加了树状数组的树形结构解释,与树状数组上二分。


树状数组(binary indexed tree),也叫BIT,又以其发明人名字被称为Fenwick树。树状数组利用二进制下标来维护区间和,能够做到以logn\log{n}的复杂度进行单点修改和区间求和,同时配合差分数组等结构,还能做到区间修改、单点查询和区间修改、区间查询等功能,相比与线段树来说更为简单和优雅。

基本知识

树状数组最为重要的一个设定就是lowbitlowbit——二进制最低位的1及其后面的0所组成的数字,相比与前缀和,树状数组的每一位维护的是(ilowbit(i),i]\left(i-lowbit(i),i\right]的区间和,比如6,二进制为110,那么这一位维护的就是原数组6到4100的区间和(不包含4)。

那么要如何计算一个数的lowbitlowbit呢?如果用一个1作为check来判断每一位是否是1,那么代码大概长这样:

1
2
3
4
for(int i = 0; i <= __lg(x); i++) {
if(x & (1ll << i))
return 1ll << i;
}

很显然,不但运行速度慢,而且也不够简洁优雅,不符合树状数组的特性不是,所以我们需要一种更为简单的方法。因此,就有了下面这个公式\[lowbit(x)=x&-x\]为什么会这样呢?我们知道,在计算机中,负数采用补码形式存储,而补码就是取反加1,也就是将最低位1之后的所有位全部取反。所以只需要用这个简单的公式就能轻松地得到一个数的lowbitlowbit

基本操作

最简单的树状数组支持两种操作:单点修改区间查询。两种操作都能够以logn\log{n}级别的复杂度和5行以内的代码实现,十分方便。

首先来看修改操作。如果是普通的数组,那么我们只需要直接修改对应位置的数字就行了。但树状数组的每一位维护的是区间和,那么很显然我们在修改时应该把所有包含该位置的区间都更新一遍。那么哪些区间会包含需要修改的位置呢?假设需要修改的位置为i,那么一个树状数组所维护的区间(jlowbit(j),j]\left(j-lowbit(j),j\right]想要包含i就需要保证jlowbit(j)<ij-lowbit(j)<i并且j>=ij>=i,要直接找出这样的区间似乎有点抽象,我们先来举个简单的例子。

假设现在要修改的的位置i1001 0110(150),让我们从1001 0111开始枚举,显然这个数不满足jlowbit(j)<ij-lowbit(j)<i,排除。继续往后找,下一个数是1001 1000,减去lowbit1001 0000,显然满足要求,那么我们需要修改的第一个元素就找到了——1001 1000。继续往下找,显然1001 1000无论后面的3个0那个变成1都不满足条件(1001 1000本身就大于i,后面加1后再减去lowbit一定大于i)。那么再继续往后找,也就是1010 0000,显然满足,同时我们也可以发现后面无论如何添1都不满足……以此类推,假设数组长度是1111 1111,那么我们需要修改的元素下标如下

1
2
3
4
1001 0110 //自己显然可以,上面忘记写了……
1001 1000
1010 0000
1100 0000

发现规律了吗,整个下标变化的过程就是一个进位的过程,每次加的都是lowbit,直到到达数组边界。(应该有更为形象的理解方式,但我想不到……看别人写的也不能理解,就贴个链接吧……算法学习笔记(2) : 树状数组)
回到正题,既然知道了怎么找区间,那么代码就好写了:

1
2
3
4
5
void update(int x, int i) { //i为位置,x为改变的量
for(; i <= nmax; i += lowbit(i)) {
tree[i] += x;
}
}

相比于修改,求和就要简单很多。只需要利用树状数组维护的是区间和这一特性,把要求和的区间下标以二进制展开就行了。比如我们要求1001 0110的前缀和,那么我们只需要把它拆成以下几个区间:

1
2
3
4
1001 0110 --- 1001 0100
1001 0100 --- 1001 0000
1001 0000 --- 1000 0000
1000 0000 --- 0000 0000

每次拆区间就是减去lowbit,同时要注意到,由于树状数组维护的区间是左开右闭的,所以任何区间都无法包含0,也就是说,树状数组的下标必须从1开始当然要是不嫌麻烦可以每次求和都单独把0加上求和代码如下:

1
2
3
4
5
6
7
int sum(int i) {
int res = 0;
for(; i > 0; i -= lowbit(i)) {
res += tree[i];
}
return res;
}

如果想要求区间[l,r][l,r]的和只需要用sum(r)sum(l1)sum(r)-sum(l-1)就行了。

那么,至此树状数组的基本操作就介绍完了。

UPD:树状数组的树形结构

在经过了一年不眠不休的思考后(,我终于理解了树状数组的树到底是什么样的,对于各种操作的理解在结合了树的结构之后就会变得非常自然。

树状数组之所以被称为线段树青春版,是因为其结构确实跟线段树差不多,回想一下线段树的结构,每次都是将区间左右均分为左儿子和右儿子,树状数组其实也是类似的,只不过多了一些限制条件而已,具体来说,每段区间的长度都是2的次幂,比如1,2,4,8等,这是由树状数组的性质决定的,因为每个节点维护的是(xlowbit(x),x](x-lowbit(x),x]区间,其长度也就是lowbit(x)lowbit(x),所以一个树状数组的结构大概长这样
fenwick1
当然如果真长这样的话树状数组的空间就是O(nlog(n))O(n\log(n)),而不是O(n)O(n)了,跟线段树也就没区别了,因此,树状数组做了一个优化,对于每个节点的左右儿子,只存储左儿子,优化后的树状数组长这样
fenwick2
再标上节点对应的下标后,是不是感觉一下就熟悉了起来。这样还有一个好处,那就是任一节点的下标xx加上lowbit(x)lowbit(x)——也就是所存储的区间长度,或者说那不存在的右边的兄弟的长度,就正好是父节点的下标,这就是为什么在更新时是不断加上lowbit(x)lowbit(x),因为这样就可以不断的找到父节点,从而更新整棵树,这有点类似与zkw线段树的自底向上的更新方式。

例题及扩展功能

单点修改,区间查询

首先是洛谷上的两道模板题:
P3374
这就是最简单的树状数组,单点修改,区间查询,多次询问,直接上模板就行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <bits/stdc++.h>
#define lowbit(x) ((x) & -x)
using namespace std;
int tree[500010];
inline void update(int x, int i, int n)
{
for (; i <= n; i += lowbit(i))
tree[i] += x;
}
inline int sum(int r)
{
int res = 0;
for (; r > 0; r -= lowbit(r))
res += tree[r];
return res;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; i++) {
int a;
cin >> a;
update(a, i, n);
}
while (m--) {
int op, x, y;
cin >> op >> x >> y;
if (op == 1)
update(y, x, n);
else
cout << sum(y) - sum(x - 1) << '\n';
}
}

(突然发现忘记说怎么初始化树状数组了,只需要在输入的时候update每一位就行了,可以理解为初始有一个所有项都为0的数组,每次输入都相当于是单点修改)


区间修改,单点查询

另一道模板题是P3368

这题就涉及到了树状数组的另一个用法,区间修改,单点查询。这乍一看跟树状数组的基本操作不能说没有关系,只能说是南辕北辙。那么这个时候就需要另一种数据结构了——差分数组。

在差分数组里,对区间的修改只需要修改区间两头的数字,而单点的查询则是求前缀和,这样就跟树状数组的基本操作对应起来了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
#define lowbit(x) ((x) & -x)
using namespace std;
typedef long long ll;
int tree[500010], a[500010];
inline void update(int x, int i, int n)
{
for (; i <= n; i += lowbit(i))
tree[i] += x;
}
inline int sum(int r)
{
int res = 0;
for (; r > 0; r -= lowbit(r))
res += tree[r];
return res;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i <= n; i++)
update(a[i] - a[i - 1], i, n);
while (m--) {
int op;
cin >> op;
if (op == 1) {
int l, r, k;
cin >> l >> r >> k;
update(k, l, n);
update(-k, r + 1, n);
} else {
int x;
cin >> x;
cout << sum(x) << '\n';
}
}
}

区间修改,区间查询

既然已经有了单点修改、区间查询和区间修改、单点查询,那么有没有区间修改、区间查询呢?答案是有的。

首先,既然要区间修改,那我们显然还是要用差分数组,但是要如何做到区间查询呢?假如我们要求ii的前缀和,那么就需要计算n=1isum(n)\sum_{n=1}^i{sum(n)}的值。如果一个一个的计算太浪费时间,需要寻找一种更为高效的计算方式。

不难发现,用差分数组求ii前缀和的过程中,第一个元素用到了ii次(每一个sum(n)sum(n)都包含第一个元素),第二个元素用到了i1i-1次……以此类推,也就是说上面的公式可以转变为下面这个公式\[ \sum_{n=1}^i{tree[n]\times(i-n+1)} \]稍微变形一下:\[ (i+1)\times \sum_{n=1}^i{tree[n]}-\sum_{n=1}^i{tree[n]\times n} \]假如我们还维护了另一个树状数组treei[n]=tree[n]×ntreei[n]=tree[n]\times n那么这个公式就可以变成这样\[ (i+1)\times\sum_{n=i}^i{tree[n]}-\sum_{n=1}^i{treei[n]} \]也就是两个树状数组的前缀和,这样就把原本O(n×logn)O(n\times\log{n})的逐项求和转变为了O(logn)O(\log{n})的树状数组求前缀和,大大提高了速度。基本代码如下:(找不到例题……)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
const int N = 1e6;
int tree[N], treei[N];
void update(int x, int i) {
for(; i <= N; i += lowbit(i)) {
tree[i] += x;
}
}
int sum(int tree[], int i) {
int res = 0;
for(; i > 0; i -= lowbit(i)) {
res += tree[i];
}
return res;
}
void updatei(int x, int i) {//更新treei
for(; i <= N; i += lowbit(i)) {
treei[i] += i * x; //其实就是把x变成了x*i……
}
}
int query(int l, int r) { //求区间和
int ans = (r + 1) * sum(tree, r) - sum(treei, r);
ans -= (l - 1 + 1) * sum(tree, l - 1) - sum(treei, l - 1);
return ans;
}

区间最值

除了上述的各式各样的区间修改查询以外,树状数组还有一个跟区间有关的功能——维护区间最值。
假如我们需要维护区间的最大值,那么正常的想法或许会是每次更新的时候用最大值更新,也就是这样:

1
2
3
4
5
void update(int x, int i) {
for(; i <= n; i += lowbit(i)) {
tree[i] = max(tree[i], x);
}
}

但这样会带来一个问题,如果一个本来是区间最大值的数被替换成了较小的数,那么对应的树状数组显然无法改变。

如果想要把原数组的变化实时的反映到树状数组中,最为简单暴力的方法就是每次原数组改变的时候都重新建一遍树状数组,也就是把树状数组清零然后重新update一遍,这个方法虽然暴力,但是有效。我们可以在此基础上做一些优化。

首先,显然我们并不需要每一个区间都清零,只需要把包含修改位置的区间清零并重新update废话,同时,由于修改不影响不包含修改位置的区间的正确性,这让我们自然可以想到是否能够将要修改的区间划成几个不包含修改位置的子区间,然后用子区间的值去更新,这样就不需要遍历然后用每一个值去更新。

那么,总结一下优化的方案:

  1. 只修改包含修改位置的区间;
  2. 用小区间的值去更新大区间的值;

方案一其实就是树状数组正常的更新,重点在于方案二。如何将其划分为子区间呢?假设现在要修改的位置i1001 0000,那么第一个要修改的区间为(10000000,10010000](1000 0000,1001 0000]。我们模拟一遍清零并重新update的过程,首先先用修改过的值更新一遍:

1
tree[i] = a[i];

那么接下来就是把区间(10000000,10001111](1000 0000,1000 1111]拆成小区间,结果如下:

1
2
3
4
1000 1111 --- 1000 1110
1000 1110 --- 1000 1100
1000 1100 --- 1000 1000
1000 1000 --- 1000 0000

也就是说,需要用的几个值为1000 1111,1000 1110,1000 1100,1000 1000。规律应该很明显了,分别是原数20-2^021-2^1\cdots2k-2^k,一直到lowbit(i)lowbit(i)(不包含lowbit)。

所以,在维护区间最值的过程中,我们只需要在原本的基础上,用一个数组存储原数组,在每次修改时先修改原数组,再用原数组的值去更新树状数组,最后用一个内循环用小区间的值去更新。由于内循环每次都是乘2,也就是二进制左移一位,复杂度为logn\log{n},故总体复杂度为O((logn)2)O((\log{n})^2)

1
2
3
4
5
6
7
8
9
void update(int x, int i) {
a[i] = x;
for(; i <= n; i += lowbit(i)) {
tree[i] = a[i];
for(int j = 1; j < lowbit(i); j <<= 1) {
tree[i] = max(tree[i], tree[i - j]);
}
}
}

说完了修改,接下来就是查询了。查询当然也不能像普通的树状数组那样查询,如果要查中间的某一段区间[x,y][x,y],也不能方便的直接sum(y)sum(x1)sum(y)-sum(x-1)。我们依然可以采用上面的思路,把需要查询的区间拆成几个小区间。比如要查询[10010000,10010110][1001 0000,1001 0110]的最值,那么第一个区间就是(10010100,10010110](1001 0100,1001 0110]。但这样直接拆又会带来一个问题好烦,如果区间改成[10010101,10010110][1001 0101,1001 0110],那么1001 0110所维护的区间就超过了查询的区间,也就无法保证结果的正确性。

也就是说,在这里我们要稳一手,不能直接拆。如何保证每次拆区间时不越界呢?最简单的方法当然是一次只拆走一个1,但这样就和遍历没有区别了。我们希望,在能拆区间时就拆成区间,只有在迫不得已时才拆走一个1,也就是在ylowbit(y)<xy-lowbit(y)<x的时候就拆走一个1,否则就按正常的区间拆分方式去拆,代码如下:

1
2
3
4
5
6
7
8
9
10
11
int query(int x, int y) {
int ans = 0;
while(x <= y) { //为什么是小于等于呢?别忘了树状数组的区间是左开右闭,所以最后要单独用a[x]的值更新一遍。
ans = max(ans, a[y]);
y--;
for(; y - lowbit(y) >= x; y -= lowbit(y)) {
ans = max(ans, tree[y]);
}
}
return ans;
}

例题P1440(其实这题的标准做法应该是单调队列,但我找不到别的题了……树状数组亲测会t三个点,但不会wa,可以用来检验代码是否有误)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <bits/stdc++.h>
#define lowbit(x) ((x) & -x)
using namespace std;
typedef long long ll;
int n;
int a[2000010], tree[2000010];
void update(int x, int i)
{
for (; i <= n; i += lowbit(i)) {
tree[i] = a[i];
for (int j = 1; j < lowbit(i); j <<= 1) {
tree[i] = min(tree[i], tree[i - j]); // 用子区间数据更新
}
}
}
int query(int x, int y)
{
int ans = INT_MAX;
while (y >= x) {
ans = min(a[y], ans);
y--;
for (; y - lowbit(y) >= x; y -= lowbit(y))
ans = min(tree[y], ans);
} // 把区间x,y划分为小区间
return ans;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int m;
cin >> n >> m;
for (int i = 1; i <= n; i++) {
cin >> a[i];
update(a[i], i);
}
cout << 0 << '\n';
for (int i = 2; i <= n; i++) {
cout << query(max(1, i - m), i - 1) << '\n';
}
}

求逆序对

最后一个用法(指我会的,其他更高级的用法还没学……太菜了😭)是统计逆序对数量,相比于上面几种用法而言,统计逆序对不需要对树状数组的原有操作做修改,而是利用树状数组快速求前缀和的特点,对原本的问题进行了巧妙地转换。

统计逆序对,一种思路是每次在数组后添加新成员时统计前面比它小的数有几个。我们可以用一个数组统计每一个数出现的次数,那么只需要在每一次添加新成员时统计前缀和,就可以算出总共有多少逆序对。同时,由于此方法不关心数据的具体大小而只关心数据之间的大小关系,所以可以采用离散化来减小内存开支。

例题P1908

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <bits/stdc++.h>
#define lowbit(x) ((x) & -x)
using namespace std;
typedef long long ll;
int a[500010], b[500010], tree[500010], n, m;
void update(int i)
{
for (; i < m; i += lowbit(i))
tree[i]++;
}
ll sum(int i)
{
ll res = 0;
for (; i > 0; i -= lowbit(i))
res += tree[i];
return res;
}
int bs(int x, int l, int r)
{
while (l <= r) {
int mid = (l + r) >> 1;
if (x > b[mid])
l = mid + 1;
else if (x < b[mid])
r = mid - 1;
else
return mid;
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n;
for (int i = 0; i < n; i++) {
cin >> a[i];
b[i] = a[i];
}
sort(b, b + n);
m = unique(b, b + n) - b;
for (int i = 0; i < n; i++) {
a[i] = bs(a[i], 0, m - 1) + 1;
}
ll ans = 0;
for (int i = n - 1; i >= 0; i--) {
update(a[i]);
ans += sum(a[i] - 1);
}
cout << ans << endl;
}

UPD:树状数组上二分

树状数组既然本质是线段树,那当然就可以在树上进行二分,具体来说,即在log(n)log(n)的复杂度下找出第一个前缀和大于某个值的下标,
不过,虽然说是二分,但我个人感觉从倍增的角度理解更容易,就类似于树上倍增求lca一样,我们只要从二进制高位到低位一位一位的枚举,小于等于就加,而由于树状数组的特性,我们可以很容易找出一段长度为2的次幂的区间的和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 不要问我为什么函数叫这名,问就是jiangly这么叫的,而我正好取名困难
int select(T sum)
{
int pos = 0;
T cur{}; // 当前的累计和
// 从高位向低位枚举
for (int i = 1 << std::__lg(n); i > 0; i >>= 1) {
if (pos + i <= n && sum >= cur + tr[pos + i]) { // tr[pos + i]维护的就是从pos+1开始的长度为i的区间
cur += tr[pos + i];
pos |= i;
}
}
return pos + 1; // pos其实是前缀和不大于sum的最大下标,加1就是大于,可根据实际需求返回对应值
}

既然有了二分,那我们就可以干一些更有趣的事,比如用树状数组实现平衡树。
当然,说是平衡树,其实就是模仿权值线段树实现了平衡树的功能而已,因此需要先离散化才能使用,功能比较局限。
洛谷P3369

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <bits/stdc++.h>
#define lowbit(x) ((x) & -(x))
using std::cin, std::cout, std::string;
using ll = long long;
using ull = unsigned long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const int mod = 998244353;
template <typename T>
struct Fenwick {
int n;
std::vector<T> tr;
Fenwick(int len) : tr(len + 2), n{len} {}
void update(int i, T x)
{
for (; i <= n; i += lowbit(i))
tr[i] += x;
}
T query(int i)
{
T res{};
for (; i > 0; i -= lowbit(i))
res += tr[i];
return res;
}
T query(int l, int r)
{
return query(r) - query(l - 1);
}
int select(T sum)
{
int pos = 0;
T cur{};
for (int i = 1 << std::__lg(n); i > 0; i >>= 1) {
if (pos + i <= n && sum >= cur + tr[pos + i]) {
cur += tr[pos + i];
pos |= i;
}
}
return pos;
}
int kth(int k) // 第k个数
{
return select(k - 1) + 1;
}
};
int main()
{
std::ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin >> n;
std::vector<std::pair<int, int>> ask(n);
std::vector<int> num;
num.reserve(n + 1);
for (auto &[op, x] : ask) {
cin >> op >> x;
if (op != 4)
num.push_back(x);
}
num.push_back(-1e9);
std::ranges::sort(num);
num.erase(std::unique(num.begin(), num.end()), num.end());

Fenwick<int> tr(num.size() + 1);

for (auto [op, x] : ask) {
x = op != 4 ? std::ranges::lower_bound(num, x) - num.begin() : x;
switch (op) {
case 1: // 插入
tr.update(x, 1);
break;
case 2: // 删除
tr.update(x, -1);
break;
case 3: // 排名
std::cout << tr.query(x - 1) + 1 << '\n';
break;
case 4: // 第k个数
std::cout << num[tr.kth(x)] << '\n';
break;
case 5: // 前驱
std::cout << num[tr.kth(tr.query(x - 1))] << '\n';
break;
case 6: // 后继
std::cout << num[tr.kth(tr.query(x) + 1)] << '\n';
break;
}
}
}

结尾

终于肝完了……😵敲了两天,累死我了……


img