[奇怪但有用的数据结构]线段树

线段树

(有关线段树的定义来自LintCode网站的相关题目

描述

线段树是一棵二叉树,他的每个节点包含了两个额外的属性startend用于表示该节点所代表的区间。start和end都是整数,并按照如下的方式赋值:

  • 根节点的 startendbuild 方法所给出。
  • 对于节点 A 的左儿子,有 start=A.left, end=(A.left + A.right) / 2
  • 对于节点 A 的右儿子,有 start=(A.left + A.right) / 2 + 1, end=A.right
  • 如果 start 等于 end, 那么该节点是叶子节点,不再有左右儿子。

说明

线段树(又称区间树), 是一种高级数据结构,他可以支持这样的一些操作:

  • 查找给定的点包含在了哪些区间内
  • 查找给定的区间包含了哪些点

样例

比如给定start=1, end=6,对应的线段树为:

               [1,  6]
             /        \
      [1,  3]           [4,  6]
      /     \           /     \
   [1, 2]  [3,3]     [4, 5]   [6,6]
   /    \           /     \
[1,1]   [2,2]     [4,4]   [5,5]

线段树的构造

和之前的树形数据结构类似,我们需要先定义一个代表节点的数据结构(参考了LintCode的题目要求):

class SegmentTreeNode:
    def __init__(self, start, end):
        self.start, self.end = start, end
        self.left, self.right = None, None

线段树的构造其实是一个很简单的递归过程。

class Solution:
    # @param start, end: Denote an segment / interval
    # @return: The root of Segment Tree
    def build(self, start, end):
        # write your code here
        if start > end: return None
        root = SegmentTreeNode(start, end)
        root.left = None if start == end else self.build(start, (start + end) // 2)
        root.right = None if start == end else self.build((start + end) // 2 + 1, end)
        return root
        

线段树的修改

最大线段树

纯粹的线段树并不能应用于太多的实际问题,一般来说线段树的节点除了start和end之外,还会有一个额外的属性值,我们以最大线段树为例,最大线段树的每一个节点还有一个代表区间中最大值的max属性,很显然:

node.max = max(node.left.max, node.right.max)

在构造的时候设置max属性为0就可以了

class SegmentTreeNode:
    def __init__(self, start, end):
        self.start, self.end, self.max = start, end, 0
        self.left, self.right = None, None
class Solution:
    def build(self, start, end):
        if start > end: return None
        root = SegmentTreeNode(start, end)
        root.left = None if start == end else self.build(start, (start + end) // 2)
        root.right = None if start == end else self.build((start + end) // 2 + 1, end)
        root.max = 0
        return root

线段树的修改

线段树的修改方法modify,接受三个参数rootindexvalue。该方法将root为根的线段树中 [start,end] = [index,index] 的节点修改为了新的value,并确保在修改后,线段树的每个节点的max属性仍然具有正确的值。

样例

对于线段树:

                      [1, 4, max=3]
                    /                \
        [1, 2, max=2]                [3, 4, max=3]
       /              \             /             \
[1, 1, max=2], [2, 2, max=1], [3, 3, max=0], [4, 4, max=3]

如果调用 modify(root, 2, 4), 返回:

                      [1, 4, max=4]
                    /                \
        [1, 2, max=4]                [3, 4, max=3]
       /              \             /             \
[1, 1, max=2], [2, 2, max=4], [3, 3, max=0], [4, 4, max=3]

调用 modify(root, 4, 0), 返回:

                      [1, 4, max=2]
                    /                \
        [1, 2, max=2]                [3, 4, max=0]
       /              \             /             \
[1, 1, max=2], [2, 2, max=1], [3, 3, max=0], [4, 4, max=0]

算法并不是很复杂

  1. 如果root节点为空直接返回
  2. 判断index是否在root区间内
  3. 是的话根据情况修改root区间的max值为value,再对root的左右节点分别进行modify操作,否的话直接返回

代码实现如下:

class Solution:
    def modify(self, root, index, value):
        if root and root.start <= index <= root.end:
            root.max = max(root.max, value)
            self.modify(root.left, index, value)
            self.modify(root.right, index, value)

线段树的查询

其实在修改线段树的时候同时修改了受影响的区间节点,所以查询的时候只要直接找到满足条件的区间节点,返回要查询的属性值就可以了,值得注意的是如果查询的区间跨越了两个区间节点,要分成两个区间进行查询,这里以在最大线段树中查询一个区间的最大值为例:

class Solution:    
     def query(self, root, start, end):
        mid = (root.start + root.end) // 2
        if start == root.start and end == root.end:
            return root.max
        elif end <= mid:
            return self.query(root.left, start, end)
        elif start >= mid + 1:
            return self.query(root.right, start, end)
        else:
            return max(self.query(root.left, start, mid),
                       self.query(root.right, mid + 1, end))

例如在如下的线段树中查询[2,3]区间的最大值:

                      [1, 4, max=4]
                    /                \
        [1, 2, max=4]                [3, 4, max=3]
       /              \             /             \
[1, 1, max=2], [2, 2, max=4], [3, 3, max=0], [4, 4, max=3]

因为[2,3]跨越了[1,4]的两个子区间[1,2]和[3,4],所以在[1,2]中寻找[2,2]的最大值为4,在[3,4]中寻找[3,3]的最大值为0,所以[2,3]的最大值为4.

完整的代码实现如下:

class SegmentTreeNode:
    def __init__(self, start, end):
        self.start, self.end, self.max = start, end, 0
        self.left, self.right = None, None
class Solution:
    def build(self, start, end):
        if start > end: return None
        root = SegmentTreeNode(start, end)
        root.left = None if start == end else self.build(start, (start + end) // 2)
        root.right = None if start == end else self.build((start + end) // 2 + 1, end)
        root.max = 0
        return root
    def modify(self, root, index, value):
        if root and root.start <= index <= root.end:
            root.max = max(root.max, value)
            self.modify(root.left, index, value)
            self.modify(root.right, index, value)
    def query(self, root, start, end):
        mid = (root.start + root.end) // 2
        if start == root.start and end == root.end:
            return root.max
        elif end <= mid:
            return self.query(root.left, start, end)
        elif start >= mid + 1:
            return self.query(root.right, start, end)
        else:
            return max(self.query(root.left, start, mid),
                       self.query(root.right, mid + 1, end))

我们可以测试一下我们的代码:

线段树的其他应用

线段树一般应用在基于区间的多次查询上,例如区间求和,只要修改一下修改和查询时的逻辑就可以了,节点增加一个属性total,修改的时候把包含index的区间total值加上value,查询的时候跨区间查询的话返回两个子区间的total值之和,代码实现如下:

class SegmentTreeNode:
    def __init__(self, start, end):
        self.start, self.end, self.total = start, end, 0
        self.left, self.right = None, None
class Solution:
    def build(self, start, end):
        if start > end: return None
        root = SegmentTreeNode(start, end)
        root.left = None if start == end else self.build(start, (start + end) // 2)
        root.right = None if start == end else self.build((start + end) // 2 + 1, end)
        root.total = 0
        return root
    def modify(self, root, index, value):
        if root and root.start <= index <= root.end:
            root.total += value
            self.modify(root.left, index, value)
            self.modify(root.right, index, value)
    def query(self, root, start, end):
        mid = (root.start + root.end) // 2
        if start == root.start and end == root.end:
            return root.total
        elif end <= mid:
            return self.query(root.left, start, end)
        elif start >= mid + 1:
            return self.query(root.right, start, end)
        else:
            return self.query(root.left, start, mid) + self.query(root.right, mid + 1, end)

简单的测试一下我们的代码:

结语

线段树是用于解决在区间内多次查询的很好的方案,实际应用很多,希望大家可以理解并掌握其用法。

最后祝大家享受生活,享受代码。

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏xx_Cc的学习总结专栏

iOS-RunTime,不再只是听说

2777
来自专栏工科狗和生物喵

【计算机本科补全计划】《C++ Primer》:数组全解!!

正文之前 其实我的《C++ Primer》 已经看到第五章了,但是因为码字比较费时间,所以暂时没有迅速更新实在是对不住,但是没办法, 总不能一天拿出五六个小时来...

34010
来自专栏小樱的经验随笔

COGS 1299. bplusa【听说比a+b还要水的大水题???】

1299. bplusa ☆   输入文件:bplusa.in   输出文件:bplusa.out 评测插件 时间限制:1 s   内存限制:128 MB ...

2708
来自专栏猿人谷

浅谈C/C++中的指针和数组(一)

                                                       浅谈C/C++中的指针和数组(一)       指...

1965
来自专栏deed博客

day02笔记

1252
来自专栏编程

Python的三个问题

第一,以下程序的执行结果是什么? deffoo(a=[]):a.append(1)printafoo()foo() 第二,以下程序的执行结果是什么? deffo...

1759
来自专栏java架构师

【SQL Server】系统学习之三:逻辑查询处理阶段-六段式

一、From阶段 针对连接说明: 1、笛卡尔积 2、on筛选器 插播:unknown=not unknuwn 缺失的值; 筛选器(on where having...

32511
来自专栏窗户

用awk写递归

看到自己很多年前写的一篇帖子,觉得有些意义,转录过来,稍加修改。 awk是一种脚本语言,语法接近C语言,我比较喜欢用,gawk甚至可以支持tcp/ip,用起来非...

2117
来自专栏程序员宝库

精心收集的 48 个 JavaScript 代码片段,仅需 30 秒就可理解

该项目来自于 Github 用户 Chalarangelo,目前已在 Github 上获得了 5000 多Star,精心收集了多达 48 个有用的 JavaSc...

34612
来自专栏小樱的经验随笔

线段树入门总结

线段树的入门级 总结       线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。       对于...

3546

扫码关注云+社区