【模板】倍增LCA

还是很好理解的做法

建好边之后,从根节点开始$dfs$,处理好每个点的深度和它往上跳$2^k$层的祖先。

1
2
3
4
5
6
7
8
9
10
11
inline void dfs(int u,int fa)
{
depth[u]=depth[fa]+1;
f[u][0]=fa;
for (register int i=1;(1<<i)<=depth[u];++i)
f[u][i]=f[f[u][i-1]][i-1];
for (register int i=head[u];i;i=nxt[i])
if (endd[i]!=fa)
dfs(endd[i],u);
return;
}

这些都处理完了,再就是$lca$了。假设要求的是$x$和$y$两个点的$lca$。先找出$x$和$y$之间深度较大的那个点,把它往上跳,跳到和另一个点深度相同。然后两个点一起往上跳。这个过程使用循环实现。循环变量$i$从$Log[depth[x]]$开始一直循环到$0$。每次判断$f[x][i]$和$f[y][i]$的关系,如果不相等那么就把$x$和$y$跳到$f[x][i]$和$f[y][i]$的位置。最后两者肯定跳到$lca$的下面那个点,于是返回$f[x][0]$。

1
2
3
4
5
6
7
8
9
10
11
12
13
inline int lca(int x,int y)
{
if (depth[x]<depth[y])
swap(x,y);
while (depth[x]>depth[y])
x=f[x][Log[depth[x]-depth[y]]];
if (x==y)
return x;
for (register int i=Log[depth[x]];~i;--i)
if (f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}

这里写的有点烂……画个图会很好理解

全代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define MAXN 500005
#define MAXM 1000005
using namespace std;
int n,m,s,a,b;
int cnt=0,head[MAXN],endd[MAXM],nxt[MAXM];
int depth[MAXN],f[MAXN][20];
int Log[MAXN];

inline void getnum(int &num)
{
num=0;
char c=getchar();
while (!isdigit(c))
c=getchar();
while (isdigit(c))
{
num=(num<<3)+(num<<1)+(c&15);
c=getchar();
}
return;
}

inline void add_edge(int x,int y)
{
++cnt;
endd[cnt]=y,nxt[cnt]=head[x];
head[x]=cnt;
return;
}

inline void dfs(int u,int fa)
{
depth[u]=depth[fa]+1;
f[u][0]=fa;
for (register int i=1;(1<<i)<=depth[u];++i)
f[u][i]=f[f[u][i-1]][i-1];
for (register int i=head[u];i;i=nxt[i])
if (endd[i]!=fa)
dfs(endd[i],u);
return;
}

inline int lca(int x,int y)
{
if (depth[x]<depth[y])
swap(x,y);
while (depth[x]>depth[y])
x=f[x][Log[depth[x]-depth[y]]];
if (x==y)
return x;
for (register int i=Log[depth[x]];~i;--i)
if (f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}

int main()
{
getnum(n),getnum(m),getnum(s);
for (register int i=1;i<n;++i)
{
getnum(a),getnum(b);
add_edge(a,b);
add_edge(b,a);
}
dfs(s,0);
Log[1]=0;
for (register int i=2;i<=n;++i)
Log[i]=Log[i/2]+1;
while (m--)
{
getnum(a),getnum(b);
printf("%d\n",lca(a,b));
}
return 0;
}
-------------本文结束,感谢阅读-------------
0%