首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >P3384 树链剖分 点操作

P3384 树链剖分 点操作

作者头像
用户2965768
发布2019-08-29 10:00:07
3020
发布2019-08-29 10:00:07
举报
文章被收录于专栏:wymwym

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

输入格式

第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。

接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。

接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)

接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:

操作1: 1 x y z

操作2: 2 x y

操作3: 3 x z

操作4: 4 x

输出格式

输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模

树状数组和线段树版本

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 100003;

vector<int> adj[N];
int n,m,r,p,cnt;
// 重儿子编号 深度 父结点编号 子树大小 
int son[N],depth[N],fa[N],size[N];
//结点u 第几个被遍历 所在重链顶点 点权值 
int id[N],top[N],w[N];
ll c1[N],c2[N];
//树状数组 
inline int lowbit(int x){
	return x&(-x); 
}
inline void add(int l,int r,int x){
	x%=p;
	int ad1 = (ll)(l-1)*x%p;
	int ad2 = (ll)r*x%p;
	for(int t = l;t<=n;t+=lowbit(t)){
		c1[t] = (c1[t]+x)%p;
		c2[t] = (c2[t]+ad1)%p; 
	}
	for(int t = r+1;t<=n;t+=lowbit(t)){
		c1[t] = (c1[t] - x)%p;
		c1[t] = (c1[t] + p)%p;
		c2[t] = (c2[t] - ad2)%p;
		c2[t] = (c2[t] + p)%p;	
	}
}
inline int qwq(int i){
	int res = 0;
	for(int t=i;t>0;t-=lowbit(t)){
		res = (res+(ll)i*c1[t]%p)%p;
		res = (res - c2[t])%p;
		res = (res+p)%p;
	}
	return res;
}

inline int query(int l,int r){
	int res = (qwq(r) - qwq(l-1))%p;
	return (res+p)%p;
}
//以上是树状数组 
void dfs1(int u,int f){//u 为当前结点 f为父结点 
	fa[u] = f;
	size[u] = 1;//子树大小要算上子树的根结点 也就是u 
	depth[u] = depth[f] + 1;//比父亲深度大1 
	int v,t = -1,l = adj[u].size();
	for(int i=0;i<l;++i){//遍历连接u的点v 
		v = adj[u][i];
		if(v==f)continue;
		dfs1(v,u);
		size[u]+=size[v];
		if(size[v]>t){
		//如果这个子树大小比已找到的还大,那就更新已找到的 
			t = size[v];
			son[u] = v;
		}	
	}
}
void dfs2(int u,int f){// f为u所在重链的顶端 
	top[u] = f;
	id[u] = ++cnt;
	if(w[u]!=0)
		add(id[u],id[u],w[u]);//树状数组维护区间和  
	if(son[u]==0)return ;//重儿子编号为0意味着没有儿子 返回 
	dfs2(son[u],f);//先从重儿子dfs 这样可以使得一条重链上的id连续 
	int v,l = adj[u].size();
	for(int i=0;i<l;++i){
		v = adj[u][i];
		if(v==son[u]||v==fa[u])continue;
		dfs2(v,v);//由于是轻儿子 所以其所在重链顶端结点是自己 
	}
}

int queryPath(int u,int v){
	int res = 0;
	while(top[u]!=top[v]){
		if(depth[top[u]]<depth[top[v]])
			swap(u,v);//深度大的优先跳 保证能跳到一条重链上 
		res = (res + query(id[top[u]],id[u]))%p;
		u = fa[top[u]];
	}
	if(depth[u]>depth[v]) swap(u,v);
	res = (res + query(id[u],id[v]))%p;
	return res; 
}
int addPath(int u,int v,int k){
	k%=p;
	while(top[u]!=top[v]){
		if(depth[top[u]]<depth[top[v]])
			swap(u,v);
		add(id[top[u]],id[u],k);
		u = fa[top[u]];
	}
	if(depth[u]>depth[v]) swap(u,v);
	add(id[u],id[v],k);
}
int querySon(int u){
	//id[u]到id[u]+size[u]-1 的所有子结点id ,下同 
	return query(id[u],id[u]+size[u] - 1);
}
void addSon(int u,int k){
	k%=p;
	add(id[u],id[u]+size[u]-1,k);
}
inline int lca(int u,int v){//求lca
    while(top[u]!=top[v]){
        if(depth[top[u]]<depth[top[v]])
            swap(u,v);
        u = fa[top[u]];    
    }
    if(depth[u]<depth[v]) return u;
    return v;
}
inline void read(int &x){
    x = 0;
    char c = getchar();
    while(c<'0'||c>'9') c = getchar();
    while(c>='0'&&c<='9'){
        x = (x<<3)+(x<<1)+(c^48);
        c = getchar();
    }
}

void print(int x){
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

int main(){
    int u,v;
    read(n),read(m),read(r),read(p);
    for(int i=1;i<=n;++i)
        read(w[i]);
    for(int i=1;i<n;++i){
        read(u),read(v);
        adj[u].push_back(v);
        adj[v].push_back(u);
    }    
    dfs1(r,0);
    dfs2(r,r);
    int ans,op,x,y,z;
    while(m--){
        read(op),read(x);
        if(op==1){
            read(y),read(z);
            addPath(x,y,z);
            continue;
        }
        if(op==2){
            read(y);
            ans = queryPath(x,y);
            print(ans);
            putchar('\n');
            continue;
        }
        if(op==3){
            read(z);
            addSon(x,z);
            continue;
        }
        ans = querySon(x);
        print(ans);
        putchar('\n');
    }
    return 0;
}
/*by SilverN*/
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<cmath>
#define LL long long
using namespace std;
const int mxn=100010;
int read(){
    int x=0,f=1;char ch=getchar();
    while(ch<'0' || ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0' && ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
//读入优化 
struct edge{int v,nxt;}e[mxn<<1];
int hd[mxn],mct=0;
void add_edge(int u,int v){e[++mct].v=v;e[mct].nxt=hd[u];hd[u]=mct;}
//邻接表 
struct node{
	//结点u的 父亲 重儿子  
    int fa,son;
    int size,dep,top;//子树大小 深度 重链顶点
    int w,e;//编号 对应线段树的结尾 
}tr[mxn];
//树剖结点 
struct segtree{
    LL smm,mk;
}st[mxn<<2];
int sz=0;
//线段树 
int n,M,rt,mod;
int w[mxn];//初始值 
//
void DFS1(int u){
    tr[u].size=1;tr[u].son=0;
    for(int i=hd[u];i;i=e[i].nxt){
        int v=e[i].v;
        if(v==tr[u].fa)continue;
        tr[v].fa=u;
        tr[v].dep=tr[u].dep+1;
        DFS1(v);
        tr[u].size+=tr[v].size;
        if(tr[v].size>tr[tr[u].son].size)
            tr[u].son=v;
    }
    return;
}
void DFS2(int u,int top){//当前点,当前链的顶点 
    tr[u].top=top;
    tr[u].w=++sz;//把树边挂上线段树 
    if(tr[u].son){
        DFS2(tr[u].son,top);//扩展搭建重链 
        for(int i=hd[u];i;i=e[i].nxt){
            int v=e[i].v;
            if(v!=tr[u].fa && v!=tr[u].son)
                DFS2(v,v);//搭建轻链 
        }
    }
    tr[u].e=sz;//当前点对应的线段树结尾
    return; 
}
void update(int L,int R,LL w,int l,int r,int k){//区间加值
     if(L<=l && r<=R){
         st[k].mk+=w;
         st[k].smm+=(r-l+1)*w;
         st[k].smm%=mod;
         return;
    }
    int mid=(l+r)>>1;
    if(st[k].mk){
        st[k<<1].mk+=st[k].mk;
        st[k<<1].smm+=st[k].mk*(mid-l+1);
        st[k<<1].smm%=mod;
        st[k<<1|1].mk+=st[k].mk;
        st[k<<1|1].smm+=st[k].mk*(r-mid);
        st[k<<1|1].smm%=mod;
        st[k].mk=0;
    }
    if(L<=mid)update(L,R,w,l,mid,k<<1);
    if(R>mid)update(L,R,w,mid+1,r,k<<1|1);
    st[k].smm=(st[k<<1].smm+st[k<<1|1].smm)%mod;
    return;
}
int query(int L,int R,int l,int r,int k){
    if(L<=l && r<=R)return st[k].smm;
    int mid=(l+r)>>1;
    if(st[k].mk){
        st[k<<1].mk+=st[k].mk;
        st[k<<1].smm+=st[k].mk*(mid-l+1);
        st[k<<1].smm%=mod;
        st[k<<1|1].mk+=st[k].mk;
        st[k<<1|1].smm+=st[k].mk*(r-mid);
        st[k<<1|1].smm%=mod;
        st[k].mk=0;
    }
    LL res=0;
    if(L<=mid)res=(res+query(L,R,l,mid,k<<1))%mod;
    if(R>mid)res=(res+query(L,R,mid+1,r,k<<1|1))%mod;
    return res%mod;
}
// 表示求树从x到y结点最短路径上所有节点的值之和 
int find(int x,int y){
    int f1=tr[x].top,f2=tr[y].top;
    int ans=0;
    while(f1!=f2){
        if(tr[f1].dep<tr[f2].dep){
            ans+=query(tr[f2].w,tr[y].w,1,n,1);
            y=tr[f2].fa;
            f2=tr[y].top;
        }
        else{
            ans+=query(tr[f1].w,tr[x].w,1,n,1);
            x=tr[f1].fa;
            f1=tr[x].top;
        }
        ans%=mod;
    }
    if(tr[x].dep<tr[y].dep)return ans+query(tr[x].w,tr[y].w,1,n,1);
    return ans+query(tr[y].w,tr[x].w,1,n,1);
}
void add(int x,int y,int k){//x到y的路径加k 
    int f1=tr[x].top,f2=tr[y].top;
    while(f1!=f2){
        if(tr[f1].dep<tr[f2].dep){
            update(tr[f2].w,tr[y].w,k,1,n,1);
            y=tr[f2].fa;
            f2=tr[y].top;
        }
        else{
            update(tr[f1].w,tr[x].w,k,1,n,1);
            x=tr[f1].fa;
            f1=tr[x].top;
        }
    }
    if(tr[x].dep<tr[y].dep) update(tr[x].w,tr[y].w,k,1,n,1);
    else update(tr[y].w,tr[x].w,k,1,n,1);
    return;
}
//
int main(){
    n=read();M=read();rt=read();mod=read();
    int i,j;
    for(i=1;i<=n;i++){w[i]=read();}
    int x,y;
    for(i=1;i<n;i++){
        x=read();y=read();
        add_edge(x,y);
        add_edge(y,x);
    }
    sz=tr[0].size=tr[rt].dep=0;
    //
    DFS1(rt);
    DFS2(rt,rt);
    for(i=1;i<=n;i++)update(tr[i].w,tr[i].w,w[i],1,n,1);
    int op;
    for(i=1;i<=M;i++){
        op=read();x=read();
        switch(op){
            case 4:{
                printf("%d\n",query(tr[x].w,tr[x].e,1,n,1)%mod);
                break;
            }
            case 3:{
                y=read();
                update(tr[x].w,tr[x].e,y,1,n,1);
                break;
            }
            case 2:{
                y=read();
                printf("%d\n",find(x,y)%mod);
                break;
            }
            case 1:{
                y=read();j=read();
                add(x,y,j);
                break;
            }
        }
    }
    return 0;
}
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年08月27日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 输入格式
  • 输出格式
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档