树链剖分详解

0.模板题在此

点击此处前往洛谷模板题(【模板】树链剖分)

1.处理的问题

  • 将一棵树从$x$节点到$y$节点的最短路径上每个节点权值增加$v$
  • 求一棵树从$x​$节点到$y​$节点的最短路径上每个节点的权值和
  • 将一棵树以$x$节点为根的子树内的每个节点权值增加$v​$
  • 求一棵树以$x​$节点为根的子树内的每个节点的权值和

2.相关概念

  • 重儿子:对于每一个非叶子节点,它的儿子中,以某一节点为根的子树的节点数最大的那个儿子,叫做该节点的重儿子
  • 轻儿子:对于每一个非叶子节点,它的儿子中,非重儿子的其他儿子,叫做该节点的轻儿子
  • 叶子节点没有重儿子也没有轻儿子【因为他没有儿子(ˉ▽ˉ;)…
  • 重边:对于每一个非叶子节点,连接它和它的重儿子的边,叫做重边
  • 轻边:一棵树中,除去重边,剩下的边,叫做轻边
  • 重链:相邻的重边连接起来的,连接一串重儿子的链,叫做重链
    • 对于每一个叶子节点,如果它是轻儿子,则有一条以它为起点的长度为1的重链
    • 每一条重链以轻儿子为起点(根节点可视为轻儿子)

【以上非常抽象,接下来看图】

图解

这样好理解多了吧(十分担忧哪里画错了,如果有请评论告诉我)

3.两个$dfs$

是两个dfs​,不是一个dfs

咳咳……这两个$dfs​$大概是树剖的精髓所在了吧。

$dfs1()$

这一个$dfs$要处理以下几件事情:

  • 记录每一个节点的深度$dep[x]​$
  • 记录每一个节点的父亲$f[x]$
  • 记录每一个非叶子节点的子树大小$size[x]$
  • 记录每一个非叶子节点的重儿子编号$son[x]$
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
inline void dfs1(int fa,int x,int depth)
{
dep[x]=depth;//记录每一个节点的深度dep[x]
f[x]=fa;//记录每一个节点的父亲f[x]
size[x]=1;//初始化每一个节点的子树大小size[x],先把自己加上去
int maxs=-1;//maxs记录以重儿子为根的子树大小,此处初始化便于比较更新
for (register int i=head[x];i;i=nxt[i])
{
if (endd[i]==fa)
continue;//如果是父亲就continue
dfs1(x,endd[i],depth+1);//递归dfs其儿子
size[x]+=size[endd[i]];//把以这个儿子为根的子树大小加上
if (size[endd[i]]>maxs)//比较此儿子为根的子树大小与maxs并更新
{
son[x]=endd[i];
maxs=size[endd[i]];
}
}
return;
}

$dfs2()$

这个$dfs$要处理以下几件事情:

  • 赋予每个节点新编号$id[x]$
  • 把每个节点的点权赋值到新编号上$nw[x]$
  • 记录每个节点所在的重链的顶端$top[x]​$
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
inline void dfs2(int x,int topf)//topf是当前这条重链的顶端
{
id[x]=++num;//赋予每个节点新编号id[x]
nw[num]=w[x];//把每个节点的点权赋值到新编号上nw[x]
top[x]=topf;//记录每个节点所在的重链的顶端top[x]
if (!son[x])//如果这个节点没有儿子,则返回
return;
dfs2(son[x],topf);//先处理重儿子,把这个儿子接到以topf为开头的重链上
for (register int i=head[x];i;i=nxt[i])//遍历新儿子并处理
{
if (endd[i]==f[x]||endd[i]==son[x])
continue;//如果是父节点或重儿子则continue
dfs2(endd[i],endd[i]);//每一个轻儿子都有一条以它开头的重链
}
return;
}

4.区间处理

预处理完了以后,我们要干正事了。首先我们会发现以下几点:

  • $dfs2$中我们是延一条重链,先处理重儿子再处理轻儿子的,所以每一条重链中,节点的新编号是连续的
  • 因为是$dfs$处理的,所以,每一个子树中的新编号也是连续的

【借助下面这张新编号图更好理解】

图解

(我真的是被我自己的手写数字给丑到了……)

这样有什么用呢?

我们考虑我们要处理的问题:

1.处理任意两节点之间的路径

设两节点之间,所在的重链的顶端更深的点为节点$x$,处理过程如下

  • 处理$x$节点到$x$节点所在的重链的顶端这一段区间
  • 把$x$节点跳到$x$所在的重链的顶端这一节点的父节点

重复执行这一步骤,直到$x$节点与$y$节点在同一条重链上,然后加上这段区间的点权和即可。

在处理这个问题的时候,我们会发现,因为我们前面注意到的点,每一条重链中,节点编号是连续的,所以可以考虑使用线段树进行连续编号计算区间和或处理区间加问题。

将一棵树从$x$节点到$y$节点的最短路径上每个节点权值增加$v$的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
inline void updrange(int x,int y,int v)
{
v%=p;
while (top[x]!=top[y])//当x节点和y节点不在同一条重链时,重复跳跃x点
{
if (dep[top[x]]<dep[top[y]])//“不妨设x点为两点中所在重链顶端深度深的那一个节点”
swap(x,y);
Add(1,1,n,id[top[x]],id[x],v);//线段树的区间加操作,把x点到x点所在重链顶端这一区间进行区间加
x=f[top[x]];//x跳到x所在重链顶端这一节点的父节点
}
if (dep[x]>dep[y])//“不妨设x点为深度较深的那一点”
swap(x,y);
Add(1,1,n,id[x],id[y],v);//此时x点和y点在同一重链上,使用线段树进行区间加
return;
}

求一棵树从$x$节点到$y$节点的最短路径上每个节点的权值和的代码如下,基本同理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
inline int qrange(int x,int y)
{
int ans=0;
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]])
swap(x,y);
ans=(ans+query(1,1,n,id[top[x]],id[x]))%p;
x=f[top[x]];
}
if (dep[x]>dep[y])
swap(x,y);
ans=(ans+query(1,1,n,id[x],id[y]))%p;
return ans;
}

每次的复杂度为$O(log^2n)$

2.处理以任意节点为根的子树

设此节点为$x$点。联想前面发现的,一棵子树中节点编号连续,还是可以使用线段树进行区间处理

将一棵树以$x$节点为根的子树内的每个节点权值增加$v$的代码如下:

1
2
3
4
5
inline void updson(int x,int v)
{
Add(1,1,n,id[x],id[x]+size[x]-1,v%p);//以节点x为根的子树编号区间为[id[x],id[x]+size[x]-1]
return;
}

求一棵树以$x$节点为根的子树内的每个节点的权值和的代码如下,同理:

1
2
3
4
inline int qson(int x)
{
return query(1,1,n,id[x],id[x]+size[x]-1);
}

5.代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define MAXN 100005
#define MAXM 200005
#define MAXT 4*MAXN
typedef long long LL;
using namespace std;

//----------------------------------------------------------------------------------

int n,m,r,p;
int w[MAXN],nw[MAXN],cnt=0,endd[MAXM],nxt[MAXM],head[MAXN];
int dep[MAXN],f[MAXN],size[MAXN],son[MAXN],id[MAXN],num=0,top[MAXN];

//----------------------------------------------------------------------------------

inline void getnum(int &num)
{
num=0;
char c=getchar();
while (!isdigit(c))
c=getchar();
while (isdigit(c))
{
num=(num<<3)+(num<<1)+(c&15);
c=getchar();
}
return;
}

//---------------------------------------------------------------------------------

struct SegTreeNode
{
int val,add_tag;
}segtree[MAXT];

inline void build_seg(int root,int l,int r)
{
segtree[root].add_tag=0;
if (l==r)
segtree[root].val=nw[l]%p;
else
{
int mid=(l+r)>>1;
build_seg(root<<1,l,mid);
build_seg(root<<1|1,mid+1,r);
segtree[root].val=(segtree[root<<1].val%p+segtree[root<<1|1].val%p)%p;
}
return;
}

inline void segtree_add(int root,int l,int r,int v)
{
segtree[root].add_tag+=v;
segtree[root].val+=(r-l+1)*v;
segtree[root].val%=p;
segtree[root].add_tag%=p;
return;
}

inline void pushdown(int root,int l,int r)
{
int mid=(l+r)>>1;
if (segtree[root].add_tag!=0)
{
segtree_add(root<<1,l,mid,segtree[root].add_tag);
segtree_add(root<<1|1,mid+1,r,segtree[root].add_tag);
segtree[root].add_tag=0;
}
return;
}

inline void Add(int root,int l,int r,int x,int y,int v)
{
if (l==x&&y==r)
{
segtree_add(root,l,r,v);
return;
}
int mid=(l+r)>>1;
pushdown(root,l,r);
if (y<=mid)
Add(root<<1,l,mid,x,y,v);
else
if (x>mid)
Add(root<<1|1,mid+1,r,x,y,v);
else
{
Add(root<<1,l,mid,x,mid,v);
Add(root<<1|1,mid+1,r,mid+1,y,v);
}
segtree[root].val=(segtree[root<<1].val+segtree[root<<1|1].val)%p;
return;
}

inline LL query(int root,int l,int r,int ql,int qr)
{
if(l==ql&&r==qr)
return segtree[root].val%p;
pushdown(root,l,r);
int mid=(l+r)>>1;
if (qr<=mid)
return query(root<<1,l,mid,ql,qr)%p;
if (ql>mid)
return query(root<<1|1,mid+1,r,ql,qr)%p;
return (query(root<<1,l,mid,ql,mid)%p+query(root<<1|1,mid+1,r,mid+1,qr)%p)%p;
}

//---------------------------------------------------------------------------------

inline void add_edge(int u,int v)
{
++cnt;
endd[cnt]=v,nxt[cnt]=head[u];
head[u]=cnt;
}

//---------------------------------------------------------------------------------

inline void dfs1(int fa,int x,int depth)
{
dep[x]=depth;
f[x]=fa;
size[x]=1;
int maxs=-1;
for (register int i=head[x];i;i=nxt[i])
{
if (endd[i]==fa)
continue;
dfs1(x,endd[i],depth+1);
size[x]+=size[endd[i]];
if (size[endd[i]]>maxs)
{
son[x]=endd[i];
maxs=size[endd[i]];
}
}
return;
}

inline void dfs2(int x,int topf)
{
id[x]=++num;
nw[num]=w[x];
top[x]=topf;
if (!son[x])
return;
dfs2(son[x],topf);
for (register int i=head[x];i;i=nxt[i])
{
if (endd[i]==f[x]||endd[i]==son[x])
continue;
dfs2(endd[i],endd[i]);
}
return;
}

//---------------------------------------------------------------------------------

inline int qrange(int x,int y)
{
int ans=0;
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]])
swap(x,y);
ans=(ans+query(1,1,n,id[top[x]],id[x]))%p;
x=f[top[x]];
}
if (dep[x]>dep[y])
swap(x,y);
ans=(ans+query(1,1,n,id[x],id[y]))%p;
return ans;
}

inline void updrange(int x,int y,int v)
{
v%=p;
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]])
swap(x,y);
Add(1,1,n,id[top[x]],id[x],v);
x=f[top[x]];
}
if (dep[x]>dep[y])
swap(x,y);
Add(1,1,n,id[x],id[y],v);
return;
}

inline int qson(int x)
{
return query(1,1,n,id[x],id[x]+size[x]-1);
}

inline void updson(int x,int v)
{
Add(1,1,n,id[x],id[x]+size[x]-1,v%p);
return;
}

//---------------------------------------------------------------------------------

int main()
{
getnum(n),getnum(m),getnum(r),getnum(p);
for (register int i=1;i<=n;++i)
getnum(w[i]);
for (register int i=1;i<n;++i)
{
int a,b;
getnum(a),getnum(b);
add_edge(a,b);
add_edge(b,a);
}
dfs1(0,r,1);
dfs2(r,r);
build_seg(1,1,n);
while (m--)
{
int ord,x,y,z;
getnum(ord);
if (ord==1)
{
getnum(x),getnum(y),getnum(z);
updrange(x,y,z);
}
else
if (ord==2)
{
getnum(x),getnum(y);
printf("%d\n",qrange(x,y));
}
else
if (ord==3)
{
getnum(x),getnum(z);
updson(x,z);
}
else
if (ord==4)
{
getnum(x);
printf("%d\n",qson(x));
}
}
return 0;
}
-------------本文结束,感谢阅读-------------
0%