前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >线段树模板

线段树模板

作者头像
EmoryHuang
发布2022-10-31 16:34:33
3080
发布2022-10-31 16:34:33
举报
文章被收录于专栏:EmoryHuang's Blog

线段树模板

线段树是算法竞赛中常用的用来维护 区间信息 的数据结构。

线段树可以在

的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。

线段树 + Lazy(数组)

代码语言:javascript
复制
class SegmentTree:
    def __init__(self, nums) -> None:
        self.n = len(nums)
        self.nums = nums
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)
        self.build(1, self.n, 1)

    def build(self, start, end, idx):
        # 对 [start, end] 区间建立线段树,当前根的编号为 idx
        if start == end:
            self.tree[idx] = self.nums[start - 1]
            return
        mid = start + ((end - start) >> 1)
        # 递归对左右区间建树
        self.build(start, mid, idx << 1)
        self.build(mid + 1, end, idx << 1 | 1)
        # 合并左右区间的结果
        self.pushup(idx)

    def query(self, start, end, idx, left, right):
        # [s, t] 为当前节点包含的区间, 当前根的编号为 idx
        # 查询 [left, right] 区间的结果

        # 当前区间为询问区间的子集时直接返回当前区间的和
        if left <= start and right >= end:
            return self.tree[idx]
        mid, sum = start + ((end - start) >> 1), 0
        self.pushdown(idx, mid - start + 1, end - mid)
        # 如果询问区间在左区间内,则递归查询左区间
        if left <= mid:
            sum += self.query(start, mid, idx << 1, left, right)
        # 如果询问区间在右区间内,则递归查询右区间
        if right > mid:
            sum += self.query(mid + 1, end, idx << 1 | 1, left, right)
        return sum

    def update(self, start, end, idx, left, right, val):
        # [s, t] 为当前节点包含的区间, 当前根的编号为 idx
        # 更新 [left, right] 区间的结果, 区间加上值 val

        # 当前区间为修改区间的子集时直接修改当前节点的值, 然后打标记, 结束修改
        if left <= start and right >= end:
            self.tree[idx] += (end - start + 1) * val
            self.lazy[idx] += val
            return
        mid = start + ((end - start) >> 1)
        self.pushdown(idx, mid - start + 1, end - mid)
        # 如果修改区间在左区间内,则递归更新左区间
        if left <= mid:
            self.update(start, mid, idx << 1, left, right, val)
        # 如果修改区间在右区间内,则递归更新右区间
        if right > mid:
            self.update(mid + 1, end, idx << 1 | 1, left, right, val)
        # 合并左右区间的结果
        self.pushup(idx)

    def pushup(self, idx):
        # 从儿子节点更新当前节点
        self.tree[idx] = self.tree[idx << 1] + self.tree[idx << 1 | 1]

    def pushdown(self, idx, ln, rn):
        # 当前根的编号为 idx, ln, rn 分别表示左右子树的节点数量
        # 从父节点更新当前节点, 下放懒惰标记
        if self.lazy[idx] != 0:
            # 更新当前节点两个子节点的值
            self.tree[idx << 1] += self.lazy[idx] * ln
            self.tree[idx << 1 | 1] += self.lazy[idx] * rn
            # 将标记下传给子节点
            self.lazy[idx << 1] += self.lazy[idx]
            self.lazy[idx << 1 | 1] += self.lazy[idx]
            # 清空当前节点的标记
            self.lazy[idx] = 0

线段树 + Lazy + 动态开点(类)

代码语言:javascript
复制
class SegmentTree:
    class Node:
        def __init__(self):
            self.left = None
            self.right = None
            self.val = 0
            self.lazy = 0

    def __init__(self) -> None:
        self.root = self.Node()

    @staticmethod
    def query(start: int, end: int, node: Node, left: int, right: int) -> int:
        # [s, t] 为当前节点包含的区间, 当前根为 node
        # 查询 [left, right] 区间的结果

        # 当前区间为询问区间的子集时直接返回当前区间的和
        if left <= start and right >= end:
            return node.val
        mid, sum = start + ((end - start) >> 1), 0
        SegmentTree.pushdown(node, mid - start + 1, end - mid)
        if left <= mid:
            sum += SegmentTree.query(start, mid, node.left, left, right)
        if right > mid:
            sum += SegmentTree.query(mid + 1, end, node.right, left, right)
        return sum

    @staticmethod
    def update(start: int, end: int, node: Node, left: int, right: int, val: int) -> None:
        # [s, t] 为当前节点包含的区间, 当前根为 node
        # 更新 [left, right] 区间值为 val

        # 当前区间为修改区间的子集时直接修改当前节点的值, 然后打标记, 结束修改
        if left <= start and right >= end:
            node.val += val * (end - start + 1)
            node.lazy += val
            return
        mid = start + ((end - start) >> 1)
        SegmentTree.pushdown(node, mid - start + 1, end - mid)
        if left <= mid:
            SegmentTree.update(start, mid, node.left, left, right, val)
        if right > mid:
            SegmentTree.update(mid + 1, end, node.right, left, right, val)
        SegmentTree.pushup(node)

    @staticmethod
    def pushup(node: Node):
        node.val = node.left.val + node.right.val

    @staticmethod
    def pushdown(node: Node, ln: int, rn: int):
        if node.left is None:
            node.left = SegmentTree.Node()
        if node.right is None:
            node.right = SegmentTree.Node()
        if node.lazy:
            # 更新当前节点两个子节点的值
            node.left.val += node.lazy * ln
            node.right.val += node.lazy * rn
            # 将标记下传给子节点
            node.left.lazy += node.lazy
            node.right.lazy += node.lazy
            # 清空当前节点的标记
            node.lazy = 0

参考资料

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-05-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 线段树模板
    • 线段树 + Lazy(数组)
      • 线段树 + Lazy + 动态开点(类)
        • 参考资料
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档