前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >线段树模板

线段树模板

原创
作者头像
Unlezhou
发布2023-07-25 19:50:39
2540
发布2023-07-25 19:50:39
举报
文章被收录于专栏:Unclezhou's Blog

概述:

线段树是算法竞赛中常用的数据结构(虽然考场中很少用,毕竟调起来麻烦,区间求和用树状树组还是更加方便代码也短)。

线段树可以在O(logN)的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。简略的描述一下算法思路,线段树是一个二叉树,树的每一个节点存储的都是一个区间内的值(根据具体的题目而定),每个父结点的值由两个子结点的值决定。

但是普通的二分思想并不能体现线段树的精髓所在,线段树的精髓就在于它的懒标记,具体往下看。

算法的实现:

//建议初学者先看无懒标记版,在最下面。

这里以洛谷P3372的区间求和为例

个人习惯:
代码语言:javascript
复制
#define pl tr<<1 //左儿子
#define pr tr<<1|1 //右儿子
建树(build)
代码语言:javascript
复制
struct segmentTree{
	int l,r; //查询的区间范围
	long long sum ,lz; //区间和,懒标记
}t[N<<2];//要开4*N的大小

void build(int l,int r,int tr){
	t[tr].l=l;t[tr].r=r;
	if(l==r) {t[tr].sum=a[l];return;} //如果区间内只有一个树,则赋值,返回
	int mid=(l+r)>>1;
	build(l,mid,pl); //建左区间
	build(mid+1,r,pr); //建右区间
	pushup(tr); //关键操作,通过最下层来更新到上层
}
上放(pushup)
代码语言:javascript
复制
void pushup(int tr){
	t[tr].sum=t[pl].sum+t[pr].sum; //由两个子结点的值更新父结点的值
}
下放(pushdown)

懒标记解释:带有懒标记的值是已经处理完成的确认的值。

代码语言:javascript
复制
void pushdown(int tr){
	if(t[tr].lz){
		t[pl].sum+=t[tr].lz*(t[pl].r-t[pl].l+1);//左儿子的值加上懒标记的值*区间内数的个数
		t[pr].sum+=t[tr].lz*(t[pr].r-t[pr].l+1);//右儿子的值加上懒标记的值*区间内树的个数
		t[pl].lz+=t[tr].lz;//懒标记下放
		t[pr].lz+=t[tr].lz;//懒标记下放
		t[tr].lz=0;//将父结点的懒标记清零
	}
}
更新(update)

update中的pushup()是我当时学习该算法时的没理解的一个地方,并不是直接更新每个结点的值,而是最后通过pushup()来更新父结点

代码语言:javascript
复制
void update(int l,int r,int tr,int num){
	if(l<=t[tr].l&&t[tr].r<=r) {t[tr].sum+=num*(t[tr].r-t[tr].l+1);t[tr].lz+=num;return;}
	pushdown(tr);//上一行是指如果该区间在查询区间内,则更新该区间值即懒标记,并且返回。(因为有懒标记),如果不包含则懒标记下放
	int mid=(t[tr].l+t[tr].r)>>1;//二分
	if(l<=mid) update(l,r,pl,num); //如果左儿子一部分在查询区间内,更新左儿子
	if(mid<r) update(l,r,pr,num); //如果右儿子一部分在查询区间内,更新右儿子
	pushup(tr);//关键的一步
}
查询(query)
代码语言:javascript
复制
long long query(int l,int r,int tr){
	long long ans=0;
	if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum; 
	pushdown(tr);
	int mid=(t[tr].l+t[tr].r)>>1;
	if(l<=mid) ans+=query(l,r,pl);
	if(mid<r) ans+=query(l,r,pr);
	return ans;
}

例题与示例程序:

1.区间求和

洛谷P3372

代码语言:javascript
复制
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
#define pl tr<<1
#define pr tr<<1|1
using namespace std;
const int N=1e5+10;
int n,m,a[100010],x,y,k,q;
struct segmentTree{
	int l,r,lz;
	long long sum;
}t[N<<2];
void pushup(int tr){
	t[tr].sum=t[pl].sum+t[pr].sum;
}
void pushdown(int tr){
	if(t[tr].lz){
		t[pl].sum+=t[tr].lz*(t[pl].r-t[pl].l+1);
		t[pr].sum+=t[tr].lz*(t[pr].r-t[pr].l+1);
		t[pl].lz+=t[tr].lz;
		t[pr].lz+=t[tr].lz;
		t[tr].lz=0;
	}
}
void build(int l,int r,int tr){
	t[tr].l=l,t[tr].r=r;
	if(l==r){t[tr].sum=a[r];return;}
	int mid=(l+r)>>1;
	build(l,mid,pl);
	build(mid+1,r,pr);
	pushup(tr);
}
void update(int l,int r,int tr,int num){
	if(l<=t[tr].l&&t[tr].r<=r) {t[tr].sum+=num*(t[tr].r-t[tr].l+1);t[tr].lz+=num;return;}
	pushdown(tr);
	int mid=(t[tr].r+t[tr].l)>>1;
	if(l<=mid)update(l,r,pl,num);
	if(mid<r)update(l,r,pr,num);
	pushup(tr);
}
long long query(int l,int r,int tr){
	long long ans=0;
	if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum;
	pushdown(tr);
	int mid=(t[tr].r+t[tr].l)>>1;
	if(l<=mid) ans+=query(l,r,pl);
	if(mid<r) ans+=query(l,r,pr);
	return ans;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)scanf("%d",&a[i]);
		build(1,n,1);
	for(int i=1;i<=m;i++){
		scanf("%d%d%d",&q,&x,&y);
		if(q==1){
			scanf("%d",&k);
			update(x,y,1,k);
		}
		else{
			printf("%lld\n",query(x,y,1));
		}
	}
    return 0;
}
2.区间求乘积

洛谷P3373

代码语言:javascript
复制
#include <iostream>
#include <stdio.h>
#include <algorithm>
#define pl tr<<1
#define pr tr<<1|1

using namespace std;
const int N=1e5+10;

int n,m,p,x,y,k,q;
int a[N];
struct segmentTree{
	int l,r;
	long long sum,add=0,mul=1;//add=加,mul=乘
}t[N<<2];
void pushup(int tr){
	t[tr].sum=(t[pl].sum+t[pr].sum)%p;
}
void pushdown(int tr){
	t[pl].sum=(t[tr].add*(t[pl].r-t[pl].l+1)%p+(t[pl].sum*t[tr].mul)%p)%p;
	t[pr].sum=(t[tr].add*(t[pr].r-t[pr].l+1)%p+(t[pr].sum*t[tr].mul)%p)%p;
	t[pl].add=(t[tr].mul*t[pl].add%p+t[tr].add)%p;
	t[pl].mul=t[tr].mul*t[pl].mul%p;
	t[pr].add=(t[tr].mul*t[pr].add%p+t[tr].add)%p;
	t[pr].mul=t[tr].mul*t[pr].mul%p;
	t[tr].add=0;t[tr].mul=1;
}
void build(int l,int r,int tr){
	t[tr].l=l;t[tr].r=r;
	if(l==r) {t[tr].sum=a[l];return;}
	else{
		int mid=(l+r)>>1;
		build(l,mid,pl);
		build(mid+1,r,pr);
		pushup(tr);
	}
}
void update1(int l,int r,int tr,int k){//add
	if(l<=t[tr].l&&t[tr].r<=r){
		t[tr].sum=(t[tr].sum+k*(t[tr].r-t[tr].l+1)%p)%p;
		t[tr].add=(t[tr].add+k%p)%p;
		return;
	}
	pushdown(tr);
	int mid=(t[tr].l+t[tr].r)>>1;
	if(l<=mid) update1(l,r,pl,k);
	if(mid<r) update1(l,r,pr,k);
	pushup(tr);
}
void update2(int l,int r,int tr,int k){//mul
	if(l<=t[tr].l&&t[tr].r<=r){
		t[tr].sum=(t[tr].sum*k)%p;
		t[tr].add=(t[tr].add*k)%p;
		t[tr].mul=(t[tr].mul*k)%p;
		return;
	}
	pushdown(tr);
	int mid=(t[tr].l+t[tr].r)>>1;
	if(l<=mid) update2(l,r,pl,k);
	if(mid<r) update2(l,r,pr,k);
	pushup(tr);
}
long long query(int l,int r,int tr){
	long long ans=0;
	if(l<=t[tr].l&&t[tr].r<=r) return t[tr].sum;
	int mid=(t[tr].l+t[tr].r)>>1;
	pushdown(tr);
	if(l<=mid) ans+=query(l,r,pl);
	if(mid<r) ans+=query(l,r,pr);
	return ans%p;
}
int main(){
	scanf("%d%d%d",&n,&m,&p);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	build(1,n,1);
	for(int i=1;i<=m;i++){
		scanf("%d%d%d",&q,&x,&y);
		if(q==1){
			scanf("%d",&k);
			update2(x,y,1,k);
		}
		else if(q==2){
			scanf("%d",&k);
			update1(x,y,1,k);
		}
		else {
			printf("%lld\n",query(x,y,1));
		}
	}
    return 0;
}

无懒标记版本:

代码语言:c++
复制
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>

using namespace std;
const int N=1e5+10;
int n,m,q,x,y,k;
int a[N];
struct segmenttree{
	int l,r,sum;
}t[N<<2];
void pushup(int tr){
	t[tr].sum=t[tr<<1].sum+t[tr<<1|1].sum;
}
void build(int l,int r,int tr){
	t[tr].l=l;t[tr].r=r;
	if(l==r) {t[tr].sum=a[r];return;}
	int mid=(l+r)>>1;
	build(l,mid,tr<<1);
	build(mid+1,r,tr<<1|1);
	pushup(tr);
}
void update(int l,int r,int tr,int num){
	int mid=(t[tr].l+t[tr].r)>>1;
	if(t[tr].l==t[tr].r) {
		t[tr].sum+=num;return;
	}
	if(l<=mid)update(l,r,tr<<1,num);
	if(mid<r)update(l,r,tr<<1|1,num);
	pushup(tr);
}
int query(int l,int r,int tr){
	int ans=0;
	if(t[tr].l>=l&&t[tr].r<=r) {return t[tr].sum;}
	int mid=(t[tr].l+t[tr].r)>>1;
	if(l<=mid) ans+=query(l,r,tr<<1);
	if(mid<r) ans+=query(l,r,tr<<1|1);
	return ans;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
	}
	build(1,n,1);
	for(int i=1;i<=m;i++){
		scanf("%d",&q);
		if(q==1){
			scanf("%d%d%d",&x,&y,&k);
			update(x,y,1,k);
		}
		else {
			scanf("%d%d",&x,&y);
			cout<<query(x,y,1)<<endl;
		}
	}
}

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 概述:
  • 算法的实现:
    • 个人习惯:
      • 建树(build)
        • 上放(pushup)
          • 下放(pushdown)
            • 更新(update)
              • 查询(query)
              • 例题与示例程序:
                • 1.区间求和
                  • 2.区间求乘积
                  • 无懒标记版本:
                  领券
                  问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档