给定一个森林,初始形态是一棵以 1 为根的树,要求进行以下操作。
对于 100\% 的数据,满足 1<=N<=50,000;1<=M<=50,000;1<=a_i<=10^6;1<=D<=100;1<=U,V<=N
对于有 Link 与 Cut 的动态树问题,考虑使用 LCT。
设这条路径是 a_1,a_2,a_3,\dots,a_{sz},其中 sz 是路径的长度。
对于每个点的期望,我们考虑路径上的每个点会被计算几次。
路径上的第 i 个点权值为 a_i,显然当1\leq l\leq i,i\leq r \leq sz时会被计算到。(l,r表示路径上选择的两个点)
那么这个点会被计算到 i\times (sz-i+1) 次,所以产生的贡献就是 i\times (sz-i+1)\times a_i。
故易知期望值就是
Q:为什么是 C_{sz+1}^2 而不是 C_{sz}^2 呢?
A:因为任选两个点可以是相同。
然后让我们来考虑维护这个又臭又长的答案。
首先考虑分母 C_{sz+1}^2,这个特别好维护,只需要维护下每个点的 sz 即可,最后 split 下就好了。
然后我们考虑分子如何维护,也就是说如何从左子树与右子树来得到这个答案。
我们设左子树的路径是 a_1,a_2,\dots,a_{mid-1},路径长度是 mid-1。
然后,当前点的权值是 a_{mid}。
同理,右子树的路径是 a_{mid+1},a_{mid+2},\dots,a_{sz},路径长度是 sz-mid
那么,当前点的应得答案就是 \sum\limits_{i=1}^{sz}i\times(sz-i+1)\times a_i
左子树的答案就是 \sum\limits_{i=1}^{mid-1}i\times(mid-i)\times a_i当前点的贡献是 a_{mid}\times mid\times (sz-mid+1)
右子树的答案就是 \sum\limits_{i=mid+1}^{sz-mid}(i-mid)\times (sz-i+1)\times a_i
然后你就会发现,如果把当前点应得的答案拆分下:
然后将其与左右子树的答案作差可得:
此时,为了更佳的观赏效果,我们重新设一下。
设左子树的路径是 b_1,b_2,\dots,b_{sz_b},其中 sz_b 表示左子树的路径长度。
同理,我们设右子树的路径是 c_1,c_2,\dots,c_{sz_c},其中 sz_c 表示右子树的路径长度。
最后,设我们这个点的权值为 a。
那么其实差值就是:
所以,我们就可以维护下 lsum=\sum\limits_{i=1}^{sz}i\times a_i,rsum=\sum\limits_{i=1}^{sz} (sz-i+1)\times a_i 以快速 pushup 答案。
那么问题来了,我们如何维护 lsum 与 rsum 呢?
很显然,(其中 s 表示所有 a_i 的和)
简单的来说,思路就是从左右儿子继承并注意前面的系数即可。
好了,至此,我们已经完成了如何从左右儿子得到自己,现在我们要考虑操作 3 带来的影响。
为了方便阅读,这里再次给出操作 3 的概括:把x 到 y 的路径上每个点权值加上 d。
那么其实就是:
然后就完事了
#include<bits/stdc++.h>
#define LL long long
using namespace std;
inline int read(){int res=0,f=1;char ch=getchar();while(!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();while(isdigit(ch)) res=(res<<3)+(res<<1)+ch-'0',ch=getchar();return res*f;}
inline void write(LL x){if(x<0) putchar('-'),x=-x;if(x<10) putchar(x+'0');else write(x/10),putchar(x%10+'0');}
struct node{bool tag;LL v,s,ch[2],lsum,rsum,ans,add,sz,fa;}tr[50010];
int n,m,stk[50010],top,op,x,y;
LL a[50010],d,Ans,Div,g;
inline bool isroot(int x){return (!x)(tr[tr[x].fa].ch[0]!=x&&tr[tr[x].fa].ch[1]!=x);}
inline void flip(int x){swap(tr[x].lsum,tr[x].rsum),swap(tr[x].ch[0],tr[x].ch[1]),tr[x].tag^=1;}
inline void add(int x,int v){
tr[x].v+=v,tr[x].add+=v,tr[x].s+=tr[x].sz*v;
tr[x].lsum+=tr[x].sz*(tr[x].sz+1)*v>>1;
tr[x].rsum+=tr[x].sz*(tr[x].sz+1)*v>>1;
tr[x].ans+=tr[x].sz*(tr[x].sz+1)*(tr[x].sz+2)/6*v;
}
inline void pushdown(int x){
if(tr[x].tag){if(tr[x].ch[0]) flip(tr[x].ch[0]);if(tr[x].ch[1]) flip(tr[x].ch[1]);tr[x].tag=0;}
if(tr[x].add){if(tr[x].ch[0]) add(tr[x].ch[0],tr[x].add);if(tr[x].ch[1]) add(tr[x].ch[1],tr[x].add);tr[x].add=0;}
}
inline void pushup(int x){
tr[x].s=tr[tr[x].ch[0]].s+tr[tr[x].ch[1]].s+tr[x].v;
tr[x].sz=tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+1;
tr[x].lsum=tr[tr[x].ch[0]].lsum+tr[x].v*(tr[tr[x].ch[0]].sz+1)+tr[tr[x].ch[1]].lsum+tr[tr[x].ch[1]].s*(tr[tr[x].ch[0]].sz+1);
tr[x].rsum=tr[tr[x].ch[1]].rsum+tr[x].v*(tr[tr[x].ch[1]].sz+1)+tr[tr[x].ch[0]].rsum+tr[tr[x].ch[0]].s*(tr[tr[x].ch[1]].sz+1);
tr[x].ans=tr[tr[x].ch[0]].ans+tr[tr[x].ch[1]].ans+tr[x].v*(tr[tr[x].ch[0]].sz+1)*(tr[tr[x].ch[1]].sz+1)+tr[tr[x].ch[0]].lsum*(tr[tr[x].ch[1]].sz+1)+tr[tr[x].ch[1]].rsum*(tr[tr[x].ch[0]].sz+1);
}
inline void rotate(int x){
int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x,v=tr[x].ch[!k];
if(!isroot(y)) tr[z].ch[tr[z].ch[1]==y]=x;
tr[x].ch[!k]=y,tr[y].ch[k]=v;
if(v) tr[v].fa=y;
tr[y].fa=x,tr[x].fa=z;
pushup(y),pushup(x);
}
inline void splay(int x){
int y=x,z=0;top=0;stk[++top]=y;
while(!isroot(y)) stk[++top]=y=tr[y].fa;
while(top) pushdown(stk[top--]);
while(!isroot(x)){
y=tr[x].fa,z=tr[y].fa;
if(!isroot(y)) rotate((tr[y].ch[0]==x)^(tr[z].ch[0]==y)?x:y);
rotate(x);
}
pushup(x);
}
inline void access(int x){for(int y=0;x;x=tr[y=x].fa) splay(x),tr[x].ch[1]=y,pushup(x);}
inline void makeroot(int x){access(x),splay(x),flip(x);}
inline int findroot(int x){access(x),splay(x);while(tr[x].ch[0]) x=tr[x].ch[0];return x;}
inline void split(int x,int y){makeroot(x),access(y),splay(x);}
inline void link(int x,int y){makeroot(x);makeroot(y);tr[y].fa=x;}
inline void cut(int x,int y){split(x,y);tr[y].fa=tr[x].ch[1]=0;}
inline LL gcd(LL a,LL b){return !b?a:gcd(b,a%b);}
int main(){
n=read(),m=read();
for(int i=1;i<=n;i++) tr[i].lsum=tr[i].rsum=tr[i].v=tr[i].s=tr[i].ans=a[i]=read(),tr[i].sz=1;
for(int x,y,i=1;i<n;i++) x=read(),y=read(),link(x,y);
for(int i=1;i<=m;i++){
op=read(),x=read(),y=read();
if(op==1){if(findroot(x)==findroot(y)) cut(x,y);}
else if(op==2){if(findroot(x)!=findroot(y)) link(x,y);}
else if(op==3){d=read();if(findroot(x)==findroot(y)) split(x,y),add(x,d);}
else if(op==4){if(findroot(x)!=findroot(y)){puts("-1");continue ;}split(x,y),Ans=tr[x].ans,Div=tr[x].sz*(tr[x].sz+1)>>1,g=gcd(Ans,Div),write(Ans/g),putchar('/'),write(Div/g),putchar('\n');}
}
}