线段树模板

线段树是算法竞赛中常用的用来维护 区间信息 的数据结构。

线段树可以在O(logN)O(\log{N})的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。

线段树 + Lazy(数组)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class SegmentTree:
def __init__(self, nums) -> None:
self.n = len(nums)
self.nums = nums
self.tree = [0] * (4 * self.n)
self.lazy = [0] * (4 * self.n)
self.build(1, self.n, 1)

def build(self, start, end, idx):
# 对 [start, end] 区间建立线段树,当前根的编号为 idx
if start == end:
self.tree[idx] = self.nums[start - 1]
return
mid = start + ((end - start) >> 1)
# 递归对左右区间建树
self.build(start, mid, idx << 1)
self.build(mid + 1, end, idx << 1 | 1)
# 合并左右区间的结果
self.pushup(idx)

def query(self, start, end, idx, left, right):
# [s, t] 为当前节点包含的区间, 当前根的编号为 idx
# 查询 [left, right] 区间的结果

# 当前区间为询问区间的子集时直接返回当前区间的和
if left <= start and right >= end:
return self.tree[idx]
mid, sum = start + ((end - start) >> 1), 0
self.pushdown(idx, mid - start + 1, end - mid)
# 如果询问区间在左区间内,则递归查询左区间
if left <= mid:
sum += self.query(start, mid, idx << 1, left, right)
# 如果询问区间在右区间内,则递归查询右区间
if right > mid:
sum += self.query(mid + 1, end, idx << 1 | 1, left, right)
return sum

def update(self, start, end, idx, left, right, val):
# [s, t] 为当前节点包含的区间, 当前根的编号为 idx
# 更新 [left, right] 区间的结果, 区间加上值 val

# 当前区间为修改区间的子集时直接修改当前节点的值, 然后打标记, 结束修改
if left <= start and right >= end:
self.tree[idx] += (end - start + 1) * val
self.lazy[idx] += val
return
mid = start + ((end - start) >> 1)
self.pushdown(idx, mid - start + 1, end - mid)
# 如果修改区间在左区间内,则递归更新左区间
if left <= mid:
self.update(start, mid, idx << 1, left, right, val)
# 如果修改区间在右区间内,则递归更新右区间
if right > mid:
self.update(mid + 1, end, idx << 1 | 1, left, right, val)
# 合并左右区间的结果
self.pushup(idx)

def pushup(self, idx):
# 从儿子节点更新当前节点
self.tree[idx] = self.tree[idx << 1] + self.tree[idx << 1 | 1]

def pushdown(self, idx, ln, rn):
# 当前根的编号为 idx, ln, rn 分别表示左右子树的节点数量
# 从父节点更新当前节点, 下放懒惰标记
if self.lazy[idx] != 0:
# 更新当前节点两个子节点的值
self.tree[idx << 1] += self.lazy[idx] * ln
self.tree[idx << 1 | 1] += self.lazy[idx] * rn
# 将标记下传给子节点
self.lazy[idx << 1] += self.lazy[idx]
self.lazy[idx << 1 | 1] += self.lazy[idx]
# 清空当前节点的标记
self.lazy[idx] = 0

线段树 + Lazy + 动态开点(类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class SegmentTree:
class Node:
def __init__(self):
self.left = None
self.right = None
self.val = 0
self.lazy = 0

def __init__(self) -> None:
self.root = self.Node()

@staticmethod
def query(start: int, end: int, node: Node, left: int, right: int) -> int:
# [s, t] 为当前节点包含的区间, 当前根为 node
# 查询 [left, right] 区间的结果

# 当前区间为询问区间的子集时直接返回当前区间的和
if left <= start and right >= end:
return node.val
mid, sum = start + ((end - start) >> 1), 0
SegmentTree.pushdown(node, mid - start + 1, end - mid)
if left <= mid:
sum += SegmentTree.query(start, mid, node.left, left, right)
if right > mid:
sum += SegmentTree.query(mid + 1, end, node.right, left, right)
return sum

@staticmethod
def update(start: int, end: int, node: Node, left: int, right: int, val: int) -> None:
# [s, t] 为当前节点包含的区间, 当前根为 node
# 更新 [left, right] 区间值为 val

# 当前区间为修改区间的子集时直接修改当前节点的值, 然后打标记, 结束修改
if left <= start and right >= end:
node.val += val * (end - start + 1)
node.lazy += val
return
mid = start + ((end - start) >> 1)
SegmentTree.pushdown(node, mid - start + 1, end - mid)
if left <= mid:
SegmentTree.update(start, mid, node.left, left, right, val)
if right > mid:
SegmentTree.update(mid + 1, end, node.right, left, right, val)
SegmentTree.pushup(node)

@staticmethod
def pushup(node: Node):
node.val = node.left.val + node.right.val

@staticmethod
def pushdown(node: Node, ln: int, rn: int):
if node.left is None:
node.left = SegmentTree.Node()
if node.right is None:
node.right = SegmentTree.Node()
if node.lazy:
# 更新当前节点两个子节点的值
node.left.val += node.lazy * ln
node.right.val += node.lazy * rn
# 将标记下传给子节点
node.left.lazy += node.lazy
node.right.lazy += node.lazy
# 清空当前节点的标记
node.lazy = 0

参考资料