(有关线段树的定义来自LintCode网站的相关题目)
线段树是一棵二叉树,他的每个节点包含了两个额外的属性start
和end
用于表示该节点所代表的区间。start和end都是整数,并按照如下的方式赋值:
build
方法所给出。start=A.left, end=(A.left + A.right) / 2
。start=(A.left + A.right) / 2 + 1, end=A.right
。线段树(又称区间树), 是一种高级数据结构,他可以支持这样的一些操作:
样例
比如给定start=1, end=6
,对应的线段树为:
[1, 6]
/ \
[1, 3] [4, 6]
/ \ / \
[1, 2] [3,3] [4, 5] [6,6]
/ \ / \
[1,1] [2,2] [4,4] [5,5]
和之前的树形数据结构类似,我们需要先定义一个代表节点的数据结构(参考了LintCode的题目要求):
class SegmentTreeNode:
def __init__(self, start, end):
self.start, self.end = start, end
self.left, self.right = None, None
线段树的构造其实是一个很简单的递归过程。
class Solution:
# @param start, end: Denote an segment / interval
# @return: The root of Segment Tree
def build(self, start, end):
# write your code here
if start > end: return None
root = SegmentTreeNode(start, end)
root.left = None if start == end else self.build(start, (start + end) // 2)
root.right = None if start == end else self.build((start + end) // 2 + 1, end)
return root
纯粹的线段树并不能应用于太多的实际问题,一般来说线段树的节点除了start和end之外,还会有一个额外的属性值,我们以最大线段树为例,最大线段树的每一个节点还有一个代表区间中最大值的max属性,很显然:
node.max = max(node.left.max, node.right.max)
在构造的时候设置max属性为0就可以了
class SegmentTreeNode:
def __init__(self, start, end):
self.start, self.end, self.max = start, end, 0
self.left, self.right = None, None
class Solution:
def build(self, start, end):
if start > end: return None
root = SegmentTreeNode(start, end)
root.left = None if start == end else self.build(start, (start + end) // 2)
root.right = None if start == end else self.build((start + end) // 2 + 1, end)
root.max = 0
return root
线段树的修改方法modify
,接受三个参数root
、index
和value
。该方法将root为根的线段树中 [start,end] = [index,index] 的节点修改为了新的value,并确保在修改后,线段树的每个节点的max属性仍然具有正确的值。
样例
对于线段树:
[1, 4, max=3]
/ \
[1, 2, max=2] [3, 4, max=3]
/ \ / \
[1, 1, max=2], [2, 2, max=1], [3, 3, max=0], [4, 4, max=3]
如果调用 modify(root, 2, 4)
, 返回:
[1, 4, max=4]
/ \
[1, 2, max=4] [3, 4, max=3]
/ \ / \
[1, 1, max=2], [2, 2, max=4], [3, 3, max=0], [4, 4, max=3]
或 调用 modify(root, 4, 0)
, 返回:
[1, 4, max=2]
/ \
[1, 2, max=2] [3, 4, max=0]
/ \ / \
[1, 1, max=2], [2, 2, max=1], [3, 3, max=0], [4, 4, max=0]
算法并不是很复杂
代码实现如下:
class Solution:
def modify(self, root, index, value):
if root and root.start <= index <= root.end:
root.max = max(root.max, value)
self.modify(root.left, index, value)
self.modify(root.right, index, value)
其实在修改线段树的时候同时修改了受影响的区间节点,所以查询的时候只要直接找到满足条件的区间节点,返回要查询的属性值就可以了,值得注意的是如果查询的区间跨越了两个区间节点,要分成两个区间进行查询,这里以在最大线段树中查询一个区间的最大值为例:
class Solution:
def query(self, root, start, end):
mid = (root.start + root.end) // 2
if start == root.start and end == root.end:
return root.max
elif end <= mid:
return self.query(root.left, start, end)
elif start >= mid + 1:
return self.query(root.right, start, end)
else:
return max(self.query(root.left, start, mid),
self.query(root.right, mid + 1, end))
例如在如下的线段树中查询[2,3]区间的最大值:
[1, 4, max=4]
/ \
[1, 2, max=4] [3, 4, max=3]
/ \ / \
[1, 1, max=2], [2, 2, max=4], [3, 3, max=0], [4, 4, max=3]
因为[2,3]跨越了[1,4]的两个子区间[1,2]和[3,4],所以在[1,2]中寻找[2,2]的最大值为4,在[3,4]中寻找[3,3]的最大值为0,所以[2,3]的最大值为4.
完整的代码实现如下:
class SegmentTreeNode:
def __init__(self, start, end):
self.start, self.end, self.max = start, end, 0
self.left, self.right = None, None
class Solution:
def build(self, start, end):
if start > end: return None
root = SegmentTreeNode(start, end)
root.left = None if start == end else self.build(start, (start + end) // 2)
root.right = None if start == end else self.build((start + end) // 2 + 1, end)
root.max = 0
return root
def modify(self, root, index, value):
if root and root.start <= index <= root.end:
root.max = max(root.max, value)
self.modify(root.left, index, value)
self.modify(root.right, index, value)
def query(self, root, start, end):
mid = (root.start + root.end) // 2
if start == root.start and end == root.end:
return root.max
elif end <= mid:
return self.query(root.left, start, end)
elif start >= mid + 1:
return self.query(root.right, start, end)
else:
return max(self.query(root.left, start, mid),
self.query(root.right, mid + 1, end))
我们可以测试一下我们的代码:
线段树一般应用在基于区间的多次查询上,例如区间求和,只要修改一下修改和查询时的逻辑就可以了,节点增加一个属性total,修改的时候把包含index的区间total值加上value,查询的时候跨区间查询的话返回两个子区间的total值之和,代码实现如下:
class SegmentTreeNode:
def __init__(self, start, end):
self.start, self.end, self.total = start, end, 0
self.left, self.right = None, None
class Solution:
def build(self, start, end):
if start > end: return None
root = SegmentTreeNode(start, end)
root.left = None if start == end else self.build(start, (start + end) // 2)
root.right = None if start == end else self.build((start + end) // 2 + 1, end)
root.total = 0
return root
def modify(self, root, index, value):
if root and root.start <= index <= root.end:
root.total += value
self.modify(root.left, index, value)
self.modify(root.right, index, value)
def query(self, root, start, end):
mid = (root.start + root.end) // 2
if start == root.start and end == root.end:
return root.total
elif end <= mid:
return self.query(root.left, start, end)
elif start >= mid + 1:
return self.query(root.right, start, end)
else:
return self.query(root.left, start, mid) + self.query(root.right, mid + 1, end)
简单的测试一下我们的代码:
线段树是用于解决在区间内多次查询的很好的方案,实际应用很多,希望大家可以理解并掌握其用法。
最后祝大家享受生活,享受代码。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。