前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >数据结构--树--线段树(Segment Tree)

数据结构--树--线段树(Segment Tree)

作者头像
Michael阿明
发布2020-07-13 16:19:16
1.1K0
发布2020-07-13 16:19:16
举报

上图 from 熊掌搜索

  • 类似数据结构:树状数组

1. 概念

线段树是一种二叉树,是用来表示一个区间的树:

  • 常常用来查询区间的:和、最小值、最大值
  • 树结点中存放不是普通二叉树的值,其结点结构如下
代码语言:javascript
复制
class TreeNode
{
public:
    int sum;//区间和
    int MAX;//区间最大的
    int MIN;//区间最小的
    int start, end;//区间左右端点
    TreeNode *left, *right;//左右节点
    TreeNode(int s, int e, int v):start(s),end(e),sum(v)
    {
        left = right = NULL;
        MAX = v;
        MIN = v;
    }
};

2. 建树

  • 传入数组,及其左右极限端点
  • 自底向上建树
代码语言:javascript
复制
	TreeNode* build(vector<int>& A, int L, int R)
    {
        if(L > R)
            return NULL;
        TreeNode* rt = new TreeNode(L,R,A[L]);
        if(L == R)
            return rt;
        int mid = L+((R-L)>>1);//对半分开
        rt->left = build(A,L,mid);
        rt->right = build(A,mid+1,R);
        rt->sum = 0;
        if(rt->left)
        {
            rt->sum += rt->left->sum;
            rt->MAX = max(rt->MAX, rt->left->MAX);
            rt->MIN = min(rt->MIN, rt->left->MIN);
        }
        if(rt->right)
        {
            rt->sum += rt->right->sum;
            rt->MAX = max(rt->MAX, rt->right->MAX);
            rt->MIN = min(rt->MIN, rt->right->MIN);
        }
        return rt;
    }

3. 查询

  • 时间复杂度:
O(\log n)
代码语言:javascript
复制
	vector<int> query(TreeNode *rt, int s, int e)//查询区间的sum,min,max
    {
        if(s > rt->end || e < rt->start)
            return {0, INT_MAX, INT_MIN};//没有交集
        if(s <= rt->start && rt->end <= e)
            return {rt->sum, rt->MIN, rt->MAX};//完全包含区间,取其值
        //不完全包含,左右查找
        vector<int> l = query(rt->left, s, e);
        vector<int> r = query(rt->right,s, e);
        //汇总信息
        vector<int> summary(3);
        summary[0] = l[0] + r[0];
        summary[1] = min(l[1], r[1]);
        summary[2] = max(l[2], r[2]);
        return summary;
    }

4. 修改

  • 时间复杂度:
O(\log n)
代码语言:javascript
复制
	void modify(TreeNode *rt, int id, int val)
    {
        if(rt->start == rt->end)
        {	//叶子节点
            rt->sum = val;//和为自身
            rt->MAX = val;
            rt->MIN = val;
            data[id] = val;
            return;
        }
        int mid = (rt->start + rt->end)/2;
        if(id > mid)
            modify(rt->right, id, val);
        else
            modify(rt->left, id, val);
        root->sum = 0;
        if(rt->left)
        {
            rt->sum += rt->left->sum;
            rt->MAX = max(rt->MAX, rt->left->MAX);
            rt->MIN = min(rt->MIN, rt->left->MIN);
        }
        if(rt->right)
        {
            rt->sum += rt->right->sum;
            rt->MAX = max(rt->MAX, rt->right->MAX);
            rt->MIN = min(rt->MIN, rt->right->MIN);
        }
    }

5. 完整代码及测试

代码语言:javascript
复制
/**
 * @description: 线段树
 * @author: michael ming
 * @date: 2020/3/13 0:21
 * @modified by:
 * @Website: https://michael.blog.csdn.net/
 */
#include<vector>
#include<iostream>
#include<climits>
using namespace std;
class TreeNode
{
public:
    int sum;//区间和
    int MAX;//区间最大的
    int MIN;//区间最小的
    int start, end;//区间左右端点
    TreeNode *left, *right;//左右节点
    TreeNode(int s, int e, int v):start(s),end(e),sum(v)
    {
        left = right = NULL;
        MAX = v;
        MIN = v;
    }
};
class SegmentTree
{
public:
    TreeNode* root;
    vector<int> data;
    SegmentTree(vector<int>& A)
    {
        root = build(A, 0, A.size()-1);
        data = A;
    }
    ~SegmentTree()
    {
        destroy(root);
    }

    void destroy(TreeNode* rt)
    {
        if(!rt) return;
        destroy(rt->left);
        destroy(rt->right);
        delete rt;
    }

    TreeNode* build(vector<int>& A, int L, int R)
    {
        if(L > R)
            return NULL;
        TreeNode* rt = new TreeNode(L,R,A[L]);
        if(L == R)
            return rt;
        int mid = L+((R-L)>>1);//对半分开
        rt->left = build(A,L,mid);
        rt->right = build(A,mid+1,R);
        rt->sum = 0;
        if(rt->left)
        {
            rt->sum += rt->left->sum;
            rt->MAX = max(rt->MAX, rt->left->MAX);
            rt->MIN = min(rt->MIN, rt->left->MIN);
        }
        if(rt->right)
        {
            rt->sum += rt->right->sum;
            rt->MAX = max(rt->MAX, rt->right->MAX);
            rt->MIN = min(rt->MIN, rt->right->MIN);
        }
        return rt;
    }

    vector<int> query(TreeNode *rt, int s, int e)//查询区间的sum,min,max
    {
        if(s > rt->end || e < rt->start)
            return {0, INT_MAX, INT_MIN};//没有交集
        if(s <= rt->start && rt->end <= e)
            return {rt->sum, rt->MIN, rt->MAX};//完全包含区间,取其值
        //不完全包含,左右查找
        vector<int> l = query(rt->left, s, e);
        vector<int> r = query(rt->right,s, e);
        //汇总信息
        vector<int> summary(3);
        summary[0] = l[0] + r[0];
        summary[1] = min(l[1], r[1]);
        summary[2] = max(l[2], r[2]);
        return summary;
    }

    void modify(TreeNode *rt, int id, int val)
    {
        if(rt->start == rt->end)
        {	//叶子节点
            rt->sum = val;//和为自身
            rt->MAX = val;
            rt->MIN = val;
            data[id] = val;
            return;
        }
        int mid = (rt->start + rt->end)/2;
        if(id > mid)
            modify(rt->right, id, val);
        else
            modify(rt->left, id, val);
        root->sum = 0;
        if(rt->left)
        {
            rt->sum += rt->left->sum;
            rt->MAX = max(rt->MAX, rt->left->MAX);
            rt->MIN = min(rt->MIN, rt->left->MIN);
        }
        if(rt->right)
        {
            rt->sum += rt->right->sum;
            rt->MAX = max(rt->MAX, rt->right->MAX);
            rt->MIN = min(rt->MIN, rt->right->MIN);
        }
    }
};
//-------------test---------------------
void printVec(vector<int> &a)
{
    for(auto& ai : a)
        cout << ai << " ";
    cout << endl;
}

int main()
{
    vector<int> v = {1,2,7,8,5};
    printVec(v);

    cout << "建立线段树" << endl;
    SegmentTree sgtree(v);
    printVec(sgtree.data);

    cout << "查询区间的sum,MIN,MAX" << endl;
    vector<int> qy_res = sgtree.query(sgtree.root,1,3);
    printVec(qy_res);

    cout << "修改某位置的值" << endl;
    sgtree.modify(sgtree.root,1,100);
    printVec(sgtree.data);

    cout << "查询区间的sum,MIN,MAX" << endl;
    qy_res = sgtree.query(sgtree.root,1,3);
    printVec(qy_res);
    return 0;
}

运行结果:valgrind ./a.out

代码语言:javascript
复制
==16895== Memcheck, a memory error detector
==16895== Copyright (C) 2002-2017, and GNU GPL'd, by Julian Seward et al.
==16895== Using Valgrind-3.14.0 and LibVEX; rerun with -h for copyright info
==16895== Command: ./a.out
==16895== 
1 2 7 8 5 
建立线段树
1 2 7 8 5 
查询区间的sum,MIN,MAX
17 2 8 
修改某位置的值
1 100 7 8 5 
查询区间的sum,MIN,MAX
115 7 100 
==16895== 
==16895== HEAP SUMMARY:
==16895==     in use at exit: 0 bytes in 0 blocks
==16895==   total heap usage: 29 allocs, 29 frees, 616 bytes allocated
==16895== 
==16895== All heap blocks were freed -- no leaks are possible
==16895== 
==16895== For counts of detected and suppressed errors, rerun with: -v
==16895== ERROR SUMMARY: 0 errors from 0 contexts (suppressed: 0 from 0)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020-03-13 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 概念
  • 2. 建树
  • 3. 查询
  • 4. 修改
  • 5. 完整代码及测试
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档