基本概念
  1. 最近公共祖先:在一棵有根树中,一对结点\((x, y)\)的最近公共祖先(\(LCA\),\(Lowest \ Common \ Ancestor\))定义为它们深度最大的\(x, y\)共同的祖先结点,即距离\(x\)和\(y\)最近的公共祖先结点。
算法思路

原始算法

显然,我们可以从结点\(x\)开始向上回溯,直到遇到根结点,并将沿途经过的结点都打上标记。接着,我们再从结点\(y\)开始向上回溯,遇到的第一个被打过标记的结点就是\(x\)和\(y\)的最近公共祖先。

时间复杂度\(O(n)\)。

简单优化

分别从\(x\)和\(y\)开始回溯显然会浪费很多时间。我们可以这样优化:事先预处理好每个结点的深度。假设\(x\)是给出的顶点中深度较大的顶点,那么,我们从结点\(x\)开始不断向上回溯,直到其深度与给出的另一个结点\(y\)相同为止。

接着,我们令\(x\)和\(y\)一起向上回溯。当\(x\)和\(y\)第一次相等的时候,说明此时已经回溯到了\(x\)和\(y\)的最近公共祖先,直接返回即可。

时间复杂度\(O(n)\),常数较原始算法有所缩小。

倍增优化

单词询问\(O(n)\)的时间复杂度仍然不可接受,考虑单次询问\(O(logn)\)的做法。

我们在原来的算法中,每次都只向上回溯单个结点,这也就是时间复杂度的瓶颈所在。要想时间复杂度优化到对数阶,必须优化回溯的步数。通过\(log\)可知, 可以考虑用倍增的思想优化。

假设我们尝试向上回溯\(2^i\)步。如果回溯后会访问到深度比根结点更小的结点,也就是溢出,那么考虑回溯\(2^{i - 1}\)步,再不行就回溯\(2^{i - 2}\)步……否则回溯\(2^i\)步。从大到小枚举\(i\),一定可以保证正确性(可以用二进制表示任何数),并且每次走的步数一定是可行范围内最大的步数。由此,我们在对齐\(x, y\)和回溯的时候可以用倍增的思想优化,从而将单次询问的时间复杂度优化到\(O(logn)\)。

使用倍增优化需要预处理好两个信息:结点的深度和结点祖先。设\(d_i\)为结点\(i\)的深度,显然\(d_i = d_{f_i} + 1\)。设\(f_{i, j}\)为结点\(i\)的第\(2^j\)辈祖先的编号,则\(f_{i, j} = f_{f_{i, j - 1, j - 1}}\)(先回溯\(2^{j - 1}\)步,到达第\(2^{j - 1}\)辈祖先,再回溯\(2^{j - 1}\)步,相当于回溯\(2 \times 2^{j - 1} = 2^j\)步,到达第\(2^j\)辈祖先)。

预处理时间复杂度为\(O(nlogn)\),单次查询时间复杂度为\(O(logn)\)。假如题目仅查询一次,则非倍增优化的算法更优;否则,倍增优化可以大大提高时间效率。

模板

​题目链接​

参考代码如下:

#include <cstdio>
using namespace std;
#define maxn 500005
#define maxm 1000005

struct node
{
int to, nxt;
}edge[maxm];

int n, m, s, cnt;
int head[maxn], f[maxn][40], d[maxn];

void add_edge(int u, int v)
{
cnt++;
edge[cnt].to = v;
edge[cnt].nxt = head[u];
head[u] = cnt;
}

void dfs(int u)
{
d[u] = d[f[u][0]] + 1;
for (int i = head[u]; i; i = edge[i].nxt)
{
int v = edge[i].to;
if (v != f[u][0])
{
f[v][0] = u;
dfs(v);
}
}
}

int lca(int x, int y)
{
if (d[x] < d[y])
{
int temp = x;
x = y;
y = temp;
}
int k = 0;
while (1 << (k + 1) <= d[x])
k++;
for (int i = k; i >= 0; i--)
if (d[f[x][i]] >= d[y])
x = f[x][i];
if (x == y)
return x;
for (int i = k; i >= 0; i--)
{
if (f[x][i] != f[y][i])
{
x = f[x][i];
y = f[y][i];
}
}
return f[x][0];
}

int main()
{
int u, v;
scanf("%d%d%d", &n, &m, &s);
for (int i = 1; i <= n - 1; i++)
{
scanf("%d%d", &u, &v);
add_edge(u, v);
add_edge(v, u);
}
dfs(s);
for (int j = 1; (1 << j) <= n; j++)
for (int i = 1; i <= n; i++)
f[i][j] = f[f[i][j - 1]][j - 1];
for (int i = 1; i <= m; i++)
{
scanf("%d%d", &u, &v);
printf("%d\n", lca(u, v));
}
return 0;
}


例题选讲

树上距离

​例题链接​

给定树上的两个结点\(x, y\),试求\(x\)到\(y\)的路径长度。

设\(d_{i}\)为从根结点到结点\(i\)的路径长度。设\(k\)为\(lca(x, y)\)。显然,从\(x\)到\(y\)的路径长度为\(d_{x} + d_{y} - 2 \times d_{k}\)。如下图,\(d_{x} + d_{y} - 2 \times d_{k} = (A + C) + (B + C) - 2 \times C = A + B\),也就是\(x\)到\(y\)的距离。

最近公共祖先_i++

参考代码如下:

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define maxn 40005
#define maxm 80005

struct node
{
int to, nxt, w;
}edge[maxm];

int t, n, q, cnt;
int head[maxn], d[maxn], dis[maxn], f[maxn][20];

void init()
{
memset(head, 0, sizeof(head));
memset(d, 0, sizeof(d));
memset(dis, 0, sizeof(dis));
memset(f, 0, sizeof(f));
cnt = 0;
}

void add_edge(int u, int v, int w)
{
cnt++;
edge[cnt].to = v;
edge[cnt].w = w;
edge[cnt].nxt = head[u];
head[u] = cnt;
}

void dfs(int u)
{
d[u] = d[f[u][0]] + 1;
for (int i = head[u]; i; i = edge[i].nxt)
{
int v = edge[i].to;
if (v != f[u][0])
{
f[v][0] = u;
dis[v] = dis[u] + edge[i].w;
dfs(v);
}
}
}

int lca(int x, int y)
{
if (d[x] < d[y])
swap(x, y);
int k = 0;
while ((1 << (k + 1)) <= n)
k++;
for (int i = k; i >= 0; i--)
if (d[f[x][i]] >= d[y])
x = f[x][i];
if (x == y)
return x;
for (int i = k; i >= 0; i--)
{
if (f[x][i] != f[y][i])
{
x = f[x][i];
y = f[y][i];
}
}
return f[x][0];
}

int main()
{
int u, v, w, x;
scanf("%d", &t);
while (t--)
{
init();
scanf("%d%d", &n, &q);
for (int i = 1; i <= n - 1; i++)
{
scanf("%d%d%d", &u, &v, &w);
add_edge(u, v, w);
add_edge(v, u, w);
}
dfs(1);
for (int j = 1; (1 << j) <= n; j++)
for (int i = 1; i <= n; i++)
f[i][j] = f[f[i][j - 1]][j - 1];
for (int i = 1; i <= q; i++)
{
scanf("%d%d", &u, &v);
x = lca(u, v);
printf("%d\n", dis[u] + dis[v] - 2 * dis[x]);
}
}
return 0;
}


树上三点最短距离

​例题链接​

给定一棵有根树 \(T\) ,试在 \(T\) 中选出一个结点 \(k\) ,使得 \(k\) 到树上的三个点 \(x, y, z\) 的距离之和最小。

显然,设 \(a = lca(x, y), b = lca(y, z), c = lca(x, z)\),则 \(a, b, c\) 中一定有两个值相等,且最终结构一定如下图(结点不一定是父子关系):

最近公共祖先_图论_02

易知 我们一定会选择 \(a\) 点,设 \(d_{i}\) 为结点 \(i\) 到根结点的距离,则最终答案为 \(d_x + d_y + d_z - d_a - 2 \times d_b\) 。编号与另外两个结点不相同的结点即为我们选择的结点。 具体读者自证不难。

参考代码如下:

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define maxn 50005
#define maxm 100005

struct node
{
int to, nxt, w;
}edge[maxm];

int n, q, cnt;
int head[maxn], d[maxn], f[maxn][20];
long long dis[maxn];

void init()
{
memset(head, 0, sizeof(head));
memset(d, 0, sizeof(d));
memset(f, 0, sizeof(f));
memset(dis, 0, sizeof(dis));
cnt = 0;
}

void add_edge(int u, int v, int w)
{
cnt++;
edge[cnt].to = v;
edge[cnt].nxt = head[u];
edge[cnt].w = w;
head[u] = cnt;
}

void dfs(int u)
{
d[u] = d[f[u][0]] + 1;
for (int i = head[u]; i; i = edge[i].nxt)
{
int v = edge[i].to;
if (v != f[u][0])
{
f[v][0] = u;
dis[v] = dis[u] + edge[i].w;
dfs(v);
}
}
}

int lca(int x, int y)
{
if (d[x] < d[y])
swap(x, y);
int k = 0;
while ((1 << (k + 1)) <= n)
k++;
for (int i = k; i >= 0; i--)
if (d[f[x][i]] >= d[y])
x = f[x][i];
if (x == y)
return x;
for (int i = k; i >= 0; i--)
{
if (f[x][i] != f[y][i])
{
x = f[x][i];
y = f[y][i];
}
}
return f[x][0];
}

int main()
{
int u, v, w;
int x, y, z, a, b, c;
while (scanf("%d", &n) != EOF)
{
init();
for (int i = 1; i <= n - 1; i++)
{
scanf("%d%d%d", &u, &v, &w);
u++, v++;
add_edge(u, v, w);
add_edge(v, u, w);
}
dfs(1);
for (int j = 1; (1 << j) <= n; j++)
for (int i = 1; i <= n; i++)
f[i][j] = f[f[i][j - 1]][j - 1];
scanf("%d", &q);
for (int i = 1; i <= q; i++)
{
scanf("%d%d%d", &x, &y, &z);
x++, y++, z++;
a = lca(x, y);
b = lca(y, z);
c = lca(x, z);
printf("%lld\n", dis[x] + dis[y] + dis[z] - dis[a] - dis[b] - dis[c]);
}
printf("\n");
}
return 0;
}