lowest_common_ancestor/
lib.rs

1use std::collections::VecDeque;
2
/// Lowest common ancestor (LCA) queries on a rooted tree, answered with binary lifting.
///
/// Construction takes O(n log n) time and memory; each [`get`](Self::get),
/// [`get_dist`](Self::get_dist) and [`kth_parent`](Self::kth_parent) query takes O(log n).
///
/// # Examples
/// ```
/// use lowest_common_ancestor::LowestCommonAncestor;
///
/// // 0 -- 2 -- 4
/// // |    |
/// // 1    3
///
/// let lca = LowestCommonAncestor::new(5, 0, &[(0, 1), (0, 2), (2, 3), (2, 4)]);
/// assert_eq!(lca.get(0, 1), 0);
/// assert_eq!(lca.get(0, 4), 0);
/// assert_eq!(lca.get(1, 1), 1);
/// assert_eq!(lca.get(1, 2), 0);
/// assert_eq!(lca.get(2, 3), 2);
/// assert_eq!(lca.get(3, 4), 2);
/// assert_eq!(lca.get_dist(1, 4), 3);
/// assert_eq!(lca.kth_parent(4, 2), Some(0));
/// assert_eq!(lca.kth_parent(4, 3), None);
/// ```
#[derive(Debug, Clone)]
pub struct LowestCommonAncestor {
    /// Number of vertices.
    n: usize,
    /// `ancestor[i][v]` is the `2^i`-th ancestor of `v`, or `ILLEGAL` if that lies above the root.
    ancestor: Vec<Vec<usize>>,
    /// `depth[v]` is the number of edges between `v` and the root.
    depth: Vec<usize>,
}

/// Sentinel vertex id meaning "no such vertex": the parent of the root, and any
/// ancestor that would lie above the root.
const ILLEGAL: usize = usize::MAX;

impl LowestCommonAncestor {
    /// Builds the structure for a tree with `n` vertices (`0..n`), rooted at `root`,
    /// whose undirected edges are `edges`.
    ///
    /// `edges` is expected to form a tree spanning all `n` vertices; this is not validated.
    ///
    /// # Panics
    /// Panics if `root >= n` or if an edge endpoint is `>= n`.
    pub fn new(n: usize, root: usize, edges: &[(usize, usize)]) -> Self {
        assert!(root < n);
        let mut g = vec![vec![]; n];
        for &(u, v) in edges {
            g[u].push(v);
            g[v].push(u);
        }

        // BFS from the root records every vertex's depth and immediate parent.
        // In a tree the only already-visited neighbour is the parent, so skipping
        // `prev` is enough to avoid revisiting vertices.
        let mut depth = vec![0; n];
        let mut parent = vec![ILLEGAL; n];
        let mut que = VecDeque::new();
        que.push_back((root, ILLEGAL));
        while let Some((curr, prev)) = que.pop_front() {
            for &next in &g[curr] {
                if next != prev {
                    depth[next] = depth[curr] + 1;
                    parent[next] = curr;
                    que.push_back((next, curr));
                }
            }
        }

        // ceil(log2(n)) rows (at least one): the jumps 1, 2, ..., 2^(rows-1) then
        // sum to at least n - 1, the largest possible depth.
        let table_size = if n == 1 {
            1
        } else {
            n.ilog2() as usize + usize::from(!n.is_power_of_two())
        };

        // Row i is derived from row i - 1: a jump of 2^i is two jumps of 2^(i-1).
        let mut ancestor: Vec<Vec<usize>> = Vec::with_capacity(table_size);
        ancestor.push(parent);
        for i in 1..table_size {
            let half = &ancestor[i - 1];
            let row: Vec<usize> = half
                .iter()
                .map(|&p| if p == ILLEGAL { ILLEGAL } else { half[p] })
                .collect();
            ancestor.push(row);
        }

        Self { n, ancestor, depth }
    }

    /// Returns the lowest common ancestor of `u` and `v`.
    ///
    /// # Panics
    /// Panics if `u >= n` or `v >= n`, or if `u` and `v` turn out to have no
    /// common ancestor (possible only when `edges` did not form a tree).
    pub fn get(&self, u: usize, v: usize) -> usize {
        assert!(u < self.n);
        assert!(v < self.n);
        // Lift the deeper vertex up to the depth of the shallower one.
        let (deep, shallow) = if self.depth[u] >= self.depth[v] { (u, v) } else { (v, u) };
        let mut u = self.climb(deep, self.depth[deep] - self.depth[shallow]);
        let mut v = shallow;
        if u == v {
            return u;
        }
        // Take the largest jumps that keep the two vertices apart; afterwards they
        // are distinct children of the LCA.
        for row in self.ancestor.iter().rev() {
            if row[u] != row[v] {
                u = row[u];
                v = row[v];
            }
        }
        let lca = self.ancestor[0][u];
        assert_ne!(lca, ILLEGAL);
        lca
    }

    /// Returns the distance between `u` and `v`, i.e. the number of edges on the
    /// path between them.
    ///
    /// # Panics
    /// Panics under the same conditions as [`get`](Self::get).
    pub fn get_dist(&self, u: usize, v: usize) -> usize {
        let w = self.get(u, v);
        self.depth[u] + self.depth[v] - self.depth[w] * 2
    }

    /// Returns the depth of vertex `u` (the number of edges between `u` and the root).
    ///
    /// # Panics
    /// Panics if `u >= n`.
    pub fn depth(&self, u: usize) -> usize {
        assert!(u < self.n);
        self.depth[u]
    }

    /// Returns the vertex reached by climbing `k` edges from `u` toward the root,
    /// or `None` if that would go past the root (`k > depth(u)`).
    ///
    /// `kth_parent(u, 0)` is `Some(u)`.
    ///
    /// # Panics
    /// Panics if `u >= n`.
    pub fn kth_parent(&self, u: usize, k: usize) -> Option<usize> {
        assert!(u < self.n);
        (k <= self.depth[u]).then(|| self.climb(u, k))
    }

    /// Climbs `k` edges up from `u` by chaining the power-of-two jumps selected by
    /// the set bits of `k`. The caller must guarantee `k <= self.depth[u]`, which
    /// also keeps every bit of `k` within the table.
    fn climb(&self, u: usize, k: usize) -> usize {
        debug_assert!(k <= self.depth[u]);
        self.ancestor
            .iter()
            .enumerate()
            .filter(|&(i, _)| (k >> i) & 1 == 1)
            .fold(u, |w, (_, row)| row[w])
    }
}
133
#[cfg(test)]
mod tests {
    use super::LowestCommonAncestor;

    /// With a single vertex, that vertex is the root and therefore its own LCA.
    #[test]
    fn single_node_test() {
        let root = 0;
        let tree = LowestCommonAncestor::new(1, root, &[]);
        assert_eq!(tree.get(root, root), root);
    }
}