树状数组学习笔记

前言

树状数组或二叉索引树(Binary Indexed Tree),又以其发明者命名为 Fenwick 树

它可以以 O(logn)O(\log n) 的时间得到任意前缀和 i=1jA[i],1<=j<=N\sum_{i=1}^j A[i], 1 <= j <= N,并同时支持在 O(logn)O(\log n) 时间内支持动态单点值的修改。空间复杂度 O(n)O(n)

使用场景

树状数组可以高效地实现如下两个操作:

  1. 数组前缀和的查询
  2. 单点更新

对于上面两个问题,如果我们不使用任何数据结构,仅依靠定义,「数组前缀和的查询」 的时间复杂度是 O(n)O(n),「单点更新」 的时间复杂度是 O(1)O(1)

利用数组实现前缀和,每次查询前缀和时间复杂度就变成了 O(1)O(1), 但是对于频繁更新的数组,每次重新计算前缀和,时间复杂度为 O(n)O(n)

树状数组简介

树状数组名字虽然又有树,又有数组,但是它实际上物理形式还是数组,不过每个节点的含义是树的关系。

如上图所示,以一个有 8 个元素的数组 A 为例, 在数组 A 之上建立一个数组 T, 数组 T 也就是树状数组。

节点意义

树状数组的下标从 1 开始计数。

在树状数组 T 中,所有的奇数下标的节点的含义是叶子节点,表示单点,它存的值是原数组相同下标存的值。

所有的偶数下标的节点均是父节点。父节点内存的是区间和,这个区间的左边界是该父节点最左边叶子节点对应的下标,右边界就是自己的下标。

索引 i 树状数组 T 来自数组 A 元素的个数
1 T1=A1T1 = A1 1
2 T2=T1+A2=A1+A2T2 = T1 + A2 = A1 + A2 2
3 T3=A3T3 = A3 1
4 T4=T2+T3+A4=A1+A2+A3+A4T4 = T2 + T3 + A4 = A1 + A2 + A3 + A4 4
5 T5=A5T5 = A5 1
6 T6=T5+A6=A5+A6T6 = T5 + A6 =A5 + A6 2
7 T7=A7T7 = A7 1
8 T8=T4+T6+T7+A8=A1+A2+A3+A4+A5+A6+A7+A8T8 = T4 + T6 + T7 + A8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8 8

数组 T 的索引与数组 A 的索引的关系

伟大的计算机科学家注意到上表中标注了「数组 T 中的元素来自数组 A 的元素个数」,它们的规律如下:

将数组 TT 的索引 ii 表示成二进制,从右向左数,遇到 1 则停止,数出 0 的个数记为 kk,则计算 2k2^k 就是「数组 T 中的元素来自数组 A 的个数」

示例

例: 当 i=5i=5 时,计算 kk

分析:因为 5 的二进制表示是 0000 0101,从右边向左边数,第 1 个是 1 ,因此 0 的个数是 0,此时 k=0k=0

因此我们可以得到如下结果:

索引 i i 的二进制表示 k 2^k 树状数组 T
1 0000 0001 0 1 T1=A1T1 = A1
2 0000 0010 1 2 T2=A1+A2T2 = A1 + A2
3 0000 0011 0 1 T3=A3T3 = A3
4 0000 0100 2 4 T4=A1+A2+A3+A4T4 = A1 + A2 + A3 + A4
5 0000 0101 0 1 T5=A5T5 = A5
6 0000 0110 1 2 T6=A5+A6T6 = A5 + A6
7 0000 0111 0 1 T7=A7T7 = A7
8 0000 1000 3 8 T8=A1+A2+A3+A4+A5+A6+A7+A8T8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8

通过 lowbit 高效计算 2^k

1
2
def lowbit(x: int) -> int:
return x & (-x)

通过 lowbit 函数, 我们可以很快计算得到 i 转换成二进制以后,末尾最后一个 1 代表的数值, 即 2k2^k

示例

例: 计算 lowbit(6)

x=6=(00000110)2x = 6 = (00000110)_2

x=x=(11111010)2-x = x_{补} = (11111010)_2

lowbit(x)=(00000110)2&(11111010)2=(00000010)2=2lowbit(x) = (00000110)_2 \& (11111010)_2 = (00000010)_2 = 2

单点更新

树状数组上的父子的下标满足 parent=son+2kparent = son + 2^k 关系,所以可以通过这个公式从叶子结点不断往上递归,直到访问到最大节点值为止,祖先结点最多为 log(n)\log(n) 个。

示例

例: 修改 A[3]A[3]​, 分析对数组 TT​ 产生的变化。

从图中我们可以看出 A[3]A[3] 的父结点以及祖先结点依次是 T[3]T[3]T[4]T[4]T[8]T[8] ,所以修改了 A[3]A[3] 以后 T[3]T[3]T[4]T[4]T[8]T[8] 的值也要修改。

对于 T[3]:3+lowbit(3)=4T[3]: 3 + lowbit(3) = 4, 4 为 T[3]T[3] 父节点 T[4]T[4] 的下标。
对于 T[4]:4+lowbit(4)=8T[4]: 4 + lowbit(4) = 8, 8 为 T[4]T[4] 父节点 T[8]T[8] 的下标。

代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def update(index: int, delta: int) -> None:
'''单点更新:从下到上, 最多到 size, 可以取等

Args:
-------
index: int
数组下标
delta: int
变化值 = 更新以后的值 - 原始值
'''
while index <= size:
tree[index] += delta
index += lowbit(index)

def lowbit(x: int) -> int:
return x & (-x)

前缀和查询

树状数组中查询 [1, i] 区间内的和。按照节点的含义,可以得出下面的关系:

query(i)=A1+A2++Ai=A1+A2+Ai2k+Ai2k+1++Ai=A1+A2+Ai2k+Ti=query(i2k)+Ti=query(ilowbit(i))+Ti\begin{aligned} query(i) &= A_1 + A_2 + \cdots + A_i \\ &= A_1 + A_2 + A_{i-2^k} + A_{i-2^k+1} + \cdots + A_i \\ &= A_1 + A_2 + A_{i-2^k} + T_i \\ &= query(i-2^k) + T_i \\ &= query(i-lowbit(i)) + T_i \end{aligned}

ilowbit(i)i - lowbit(i)ii 的二进制中末尾的 1 去掉,最多有 log(i)\log(i) 个 1,所以查询操作最坏的时间复杂度是 O(logn)O(log n)

示例

例: 求前缀和(6)。

从图中我们可以看出 前缀和(6) = T[6] + T[4]

对于 T[6]:6lowbit(6)=4T[6]: 6 - lowbit(6) = 4, 4 为 T[6]T[6] 的上一个非叶子结点 T[4]T[4] 的下标。

代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
def query(index: int) -> int:
'''查询前缀和:从上到下,最少到 1,可以取等

Args:
-------
index: int
前缀的最大索引,即查询区间 [0, index] 的所有元素之和
'''
res = 0
while index > 0:
res += tree[index]
index -= lowbit(index)
return res

树状数组的初始化

树状数组的初始化可以通过「单点更新」来实现:

1
2
3
4
5
6
class NumArray:
def __init__(self, nums: List[int]):
self.size = len(nums)
self.tree = [0] * (len(nums) + 1)
for i, num in enumerate(nums, 1):
self.update(i, num)

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class FenwickTree:
def __init__(self, nums):
self.size = len(nums)
self.tree = [0] * (len(nums) + 1)
for i, num in enumerate(nums, 1):
self.update(i, num)

def lowbit(self, index):
return index & (-index)

# 单点更新:从下到上,最多到 size,可以取等
def update(self, index, delta):
while index <= self.size:
self.tree[index] += delta
index += self.lowbit(index)

# 区间查询:从上到下,最少到 1,可以取等
def query(self, index):
res = 0
while index > 0:
res += self.tree[index]
index -= self.lowbit(index)
return res

应用

分析: 该题只涉及「单点修改」和「区间求和」,属于「树状数组」的经典应用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class NumArray:
def __init__(self, nums: List[int]):
self.nums = nums
self.size = len(nums)
self.tree = [0] * (len(nums) + 1)
for i, num in enumerate(nums, 1):
self.insert(i, num)

def lowbit(self, x):
return x & (-x)

# 单点更新:从下到上,最多到 size,可以取等
def insert(self, index: int, val: int) -> None:
while index <= self.size:
self.tree[index] += val
index += self.lowbit(index)

# 区间查询:从上到下,最少到 1,可以取等
def query(self, index: int) -> int:
res = 0
while index > 0:
res += self.tree[index]
index -= self.lowbit(index)
return res

def update(self, index: int, val: int) -> None:
x = index + 1
while x <= self.size:
self.tree[x] += val - self.nums[index]
x += self.lowbit(x)
self.nums[index] = val

def sumRange(self, left: int, right: int) -> int:
return self.query(right + 1) - self.query(left)

参考资料