我们在刷题时经常会遇到和区间有关的题目,最典型的是区间求和,给你一个数组,求出从i到j这一段子区间的所有元素的和。比如leetcode 307.区域检索-数组可修改
这种问题很容易想到要用前缀和来做。
但是,如果我们的数组是变化的呢?
这时候前缀和就不怎么高效了,这时候需要考虑另外2种结构:树状数组和线段树。
线段树是一种非常灵活的数据结构,能够处理各种区间问题,那么,让我们一起来看看线段树到底是怎样的吧。
既然是处理区间,线段树需要将这个数组作为输入,如下图所示,数组转换为线段树
细心的童鞋可能已经发现了,线段树的叶子节点从左到右刚好是原数组。除了叶子节点,所有节点的值都是左右子树的和。根节点的值就是所有叶子节点的和。
规约是一种很常见的解决问题的方法,通过将一个问题转化为另外一个问题,进而能够更加清晰简易的解决。线段树也可以规约为如下问题:
如何将数组构造一个叶子节点是数组,父节点是左右孩子节点和的树?
如果你把LeetCode上树的题目都刷了,那么这个问题相当简单,无非就是后序遍历自底向上创建树。代码如下:
TreeNode * buildTree(vector<int>& nums, int l, int r)
{
if(l == r){
// 叶子节点
return new TreeNode(l, r, nums[l]);
}
// 后序构造
auto left = buildTree(nums, l, (l+r)/2);
auto right = buildTree(nums, (l+r)/2+1, r);
// 非叶子节点
return new TreeNode(l, r, left->v+right->v, left, right);
}
碰到叶子节点就特殊处理,每次进行二分操作,构造完左右孩子再构造自己。注意二分操作一般都是向下取整。
完成第一步构建,我们就掌握了线段树90%的内容,接下来让我们看看更新操作。如下代码:
void updateNode(TreeNode * n, int pos, int val){
if(!n)return;
if(n->l == n->r && n->l == pos){
n->v = val;
return;
}
int m = (n->l+n->r)/2;
// 后序遍历
if(pos <= m){
updateNode(n->left, pos, val);
}else{
updateNode(n->right, pos, val);
}
n->v = n->left->v + n->right->v;
}
更新操作首先是找到元素,然后更新父节点,依然是后序遍历。
然后是查询操作,查询操作要分几种情况,区间ij可能刚好在左右子树的一边,另外一种情况是跨越左右子树,一部分在左子树,一部分在右子树。这时候要拆分求和。如下代码:
int query(TreeNode * n, int l, int r){
if(!n)return 0;
if(n->l == l && n->r == r){
return n->v;
}
int m = (n->l + n->r)/2;
// 左右
if(r <= m){
return query(n->left, l, r);
}else if(l > m){
return query(n->right, l, r);
}
// 跨区间
return query(n->left, l, m) + query(n->right, m+1, r);
}
本篇文章带你了解了什么是线段树,然后重点是如何构造,对于查询和更新操作,也做了简单介绍。线段树还可以增减元素,以及更新区间每个元素的值等等。后面做题时再深入介绍。
这一期让大家久等了,制作过程有点坎坷,动画视频用了将近一周的业余时间才制作完成,主要是因为第一次做树方面的动画,碰到一些不好处理的地方,比如树形结构的位置绘制,另外一点是线段树本身代码量比较多,还有一点是制作过程中产生了一点对计算机理论的困惑,花了一些时间研究了下理论计算机知识,正在入坑ing,后面再和大家聊聊图灵祖师爷的理论计算基础。
LeetCode 307题是一个线段树的模板题,也被划分到中等难度的题目,线段树的其他题目都是难题,可见线段树真的有点难啃。题解源码:
class NumArray {
public:
struct TreeNode{
int l, r, v;
TreeNode * left = nullptr;
TreeNode * right = nullptr;
TreeNode(int l, int r, int v):l(l), r(r), v(v){
}
TreeNode(int l, int r, int v, TreeNode * left, TreeNode * right):l(l), r(r), v(v),
left(left), right(right)
{
}
};
TreeNode * root;
TreeNode * buildTree(vector<int>& nums, int l, int r)
{
if(l == r){
// 叶子节点
return new TreeNode(l, r, nums[l]);
}
// 后序构造
auto left = buildTree(nums, l, (l+r)/2);
auto right = buildTree(nums, (l+r)/2+1, r);
// 非叶子节点
return new TreeNode(l, r, left->v+right->v, left, right);
}
void dumpInternal(TreeNode * n, int d){
if(!n)return;
for(int i=0; i<d; i++)printf("--");
printf("%d(%d %d)\n", n->v, n->l, n->r);
dumpInternal(n->left, d+1);
dumpInternal(n->right, d+1);
}
void updateNode(TreeNode * n, int pos, int val){
if(!n)return;
if(n->l == n->r && n->l == pos){
n->v = val;
return;
}
int m = (n->l+n->r)/2;
// 后序遍历
if(pos <= m){
updateNode(n->left, pos, val);
}else{
updateNode(n->right, pos, val);
}
n->v = n->left->v + n->right->v;
}
int query(TreeNode * n, int l, int r){
if(!n)return 0;
if(n->l == l && n->r == r){
return n->v;
}
int m = (n->l + n->r)/2;
// 左右
if(r <= m){
return query(n->left, l, r);
}else if(l > m){
return query(n->right, l, r);
}
// 跨区间
return query(n->left, l, m) + query(n->right, m+1, r);
}
void dump(TreeNode * n){
// dumpInternal(n, 1);
}
NumArray(vector<int>& nums)
{
root = buildTree(nums, 0, nums.size()-1);
dump(root);
}
void update(int pos, int val)
{
updateNode(root, pos, val);
dump(root);
}
int sumRange(int l, int r)
{
return query(root, l, r);
}
};