题意:给出一棵树 op = 0 求u到v链上的和 op = 1 将u到v链上所有元素开根向下取整
树链剖分模板 + 势能线段树(势能线段树大佬博客)
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
const int MaxN = 1e5 + 5;
typedef long long LL;
struct edge{
int next,to;
}e[2 * MaxN];
struct Node{
LL sum;
int lazy,l,r;
}a[4 * MaxN];
int rt,n,m,r,cnt,head[MaxN],f[MaxN],d[MaxN],size[MaxN],son[MaxN],rk[MaxN],top[MaxN],id[MaxN];
LL v[MaxN];
void add(int x,int y){
e[++cnt].next = head[x];
e[cnt].to = y;
head[x] = cnt;
}
void dfs1(int u,int fa,int depth){ //当前节点、父节点、层次深度
f[u] = fa;
d[u] = depth;
size[u] = 1; //这个点本身size=1
for(int i = head[u];i;i = e[i].next){
int v = e[i].to;
if(v == fa) continue;
dfs1(v,u,depth + 1); //层次深度+1
size[u] += size[v]; //子节点的size已被处理,用它来更新父节点的size
if(size[v] > size[son[u]]) son[u] = v;//选取size最大的作为重儿子
}
}
void dfs2(int u,int t){ //当前节点、重链顶端
top[u] = t;
id[u] = ++cnt; //标记dfs序
rk[cnt] = u; //序号cnt对应节点u
if(!son[u]) return;
dfs2(son[u],t);
/*我们选择优先进入重儿子来保证一条重链上各个节点dfs序连续,
一个点和它的重儿子处于同一条重链,所以重儿子所在重链的顶端还是t*/
for(int i = head[u];i;i = e[i].next){
int v = e[i].to;
if(v != son[u] && v != f[u]) dfs2(v,v);
//一个点位于轻链底端,那么它的top必然是它本身
}
}
void pushup(int k){
a[k].sum = a[k * 2].sum + a[k * 2 + 1].sum;
}
void build(int k,int l,int r){
a[k].l = l;
a[k].r = r;
if(l == r){
// a[k].sum = 1;
a[k].sum = v[rk[l]];//rk[l]--这个结点id对应的原来标号 v[]题目初始权值
return;
}
int mid = l + r >> 1;
build(k * 2,l,mid);
build(k * 2 + 1,mid + 1,r);
pushup(k);
}
//以x为根节点的所有子树之和 query(id[x],id[x] + size[x] - 1,k = 1);
LL query(int l,int r,int k){
if(a[k].l >= l && a[k].r <= r){
return a[k].sum;
}
// if(a[k].lazy) pushdown(k);
int mid = a[k].l+a[k].r >> 1;
LL tot = 0;
if(mid >= l) tot += query(l,r,k * 2);
if(mid < r) tot += query(l,r,k * 2 + 1);
return tot;
}
//树从x到y最短路径上的和
LL sum(int x,int y){
LL ret = 0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ret += query(id[top[x]],id[x],1);
x = f[top[x]];
}
if(id[x] > id[y]) swap(x,y);
return (ret + query(id[x],id[y],1));
}
//以x为根节点的所有子树+c update(id[x],id[x] + size[x] - 1,c,k = 1);
void update(int l,int r,int k){
if(a[k].l == a[k].r){
// a[k].lazy++;
a[k].sum = sqrt(a[k].sum);
return;
}
// if(a[k].lazy) pushdown(k);
int mid = a[k].l+a[k].r >> 1;
if(mid >= l && a[k * 2].sum > 1) update(l,r,2 * k);
if(mid < r && a[k * 2 + 1].sum > 1) update(l,r,2 * k + 1);
pushup(k);
}
//[x,y]+c updates(x,y,k);
void updates(int x,int y){
while(top[x] != top[y]){
if(d[top[x]] < d[top[y]]) swap(x,y);
update(id[top[x]],id[x],1);//id[top[x]]~id[x]这一条重链+c
x = f[top[x]];
}
if(id[x] > id[y]) swap(x,y);
update(id[x],id[y],1);
}
void chan_p(int k,int l,int r,int x){
if(l == r){
a[k].sum = sqrt(a[k].sum);//1-0
return ;
}
// if(a[k].lazy) pushdown(k);
int mid = a[k].r + a[k].l >> 1;
if(x <= mid) chan_p(k * 2,l,mid,x);
else chan_p(k * 2 + 1,mid + 1,r,x);
pushup(k);
}
int main()
{
cnt = 0;
scanf("%d %d",&n,&m);//n points
for(int i = 1;i <= n; i++) scanf("%lld",&v[i]);
for(int i = 1;i < n; i++){
int x,y;
scanf("%d %d",&x,&y);
add(x,y);
add(y,x);
}
cnt = 0;
dfs1(1,0,1);//进入 dfs(root,0-fa,1-dep);
dfs2(1,1);//dfs2(root,root);
cnt = 0;
build(1,1,n);
for(int i = 1;i <= m; i++){
int op,x,y;
scanf("%d %d %d",&op,&x,&y);
if(op == 1){
printf("%lld\n",sum(x,y));
}
else{//单点修改
updates(x,y);
}
}
}