Skip to main content

lowest_common_ancestor/
lib.rs

1use std::collections::VecDeque;
2
3use doubling::{Doubling, Transition};
4
5/// 根付き木の LCA です。
6///
7/// # Examples
8/// ```
9/// use lowest_common_ancestor::LowestCommonAncestor;
10///
11/// // 0 -- 2 -- 4
12/// // |    |
13/// // 1    3
14///
15/// let lca = LowestCommonAncestor::new(5, 0, &[(0, 1), (0, 2), (2, 3), (2, 4)]);
16/// assert_eq!(lca.get(0, 1), 0);
17/// assert_eq!(lca.get(0, 4), 0);
18/// assert_eq!(lca.get(1, 1), 1);
19/// assert_eq!(lca.get(1, 2), 0);
20/// assert_eq!(lca.get(2, 3), 2);
21/// assert_eq!(lca.get(3, 4), 2);
22/// ```
23#[derive(Debug, Clone)]
24pub struct LowestCommonAncestor {
25    n: usize,
26    doubling: Doubling<()>,
27    depth: Vec<usize>,
28}
29
30impl LowestCommonAncestor {
31    /// 頂点数 `n`, 根 `root`, 木をなす無向辺の集合 `edges` を渡します。
32    pub fn new(n: usize, root: usize, edges: &[(usize, usize)]) -> Self {
33        assert!(root < n);
34        let mut g = vec![vec![]; n];
35        for &(u, v) in edges {
36            assert!(u < n);
37            assert!(v < n);
38            g[u].push(v);
39            g[v].push(u);
40        }
41
42        let mut depth = vec![0; n];
43        let mut parent = vec![None; n];
44        let mut que = VecDeque::new();
45        depth[root] = 0;
46        parent[root] = None;
47        que.push_back((root, None));
48        while let Some((curr, prev)) = que.pop_front() {
49            for &next in &g[curr] {
50                if prev.is_some_and(|prev| prev == next) {
51                    continue;
52                }
53                depth[next] = depth[curr] + 1;
54                parent[next] = Some(curr);
55                que.push_back((next, Some(curr)));
56            }
57        }
58
59        let sentinel = n;
60        let doubling = Doubling::new(n + 1, (n - 1).max(1), |i| {
61            if i < n {
62                let next = parent[i].unwrap_or(sentinel);
63                Transition::new(next, ())
64            } else {
65                Transition::new(sentinel, ())
66            }
67        });
68
69        Self { n, doubling, depth }
70    }
71
72    /// `u` と `v` の LCA を返します。
73    pub fn get(&self, u: usize, v: usize) -> usize {
74        assert!(u < self.n);
75        assert!(v < self.n);
76
77        if self.n == 1 {
78            assert_eq!(u, 0);
79            assert_eq!(v, 0);
80            return 0;
81        }
82
83        let (u, v) = if self.depth[u] >= self.depth[v] {
84            (u, v)
85        } else {
86            (v, u)
87        };
88        assert!(self.depth[u] >= self.depth[v]);
89
90        let u = self
91            .doubling
92            .fold(u, self.depth[u] - self.depth[v], u, |_, t| t.next);
93
94        assert_eq!(self.depth[u], self.depth[v]);
95
96        if u == v {
97            return u;
98        }
99
100        let (mut u, mut v) = (u, v);
101        let log = self.n.ilog2() as usize + usize::from(!self.n.is_power_of_two());
102        for k in (0..log).rev() {
103            let au = self.doubling.get(u, k).next;
104            let av = self.doubling.get(v, k).next;
105            if au != av {
106                u = au;
107                v = av;
108            }
109        }
110
111        let lca = self.doubling.get(u, 0).next;
112        assert_ne!(lca, self.n);
113
114        lca
115    }
116
117    /// `u` と `v` の距離 (頂点間にある辺の数) を返します。
118    pub fn get_dist(&self, u: usize, v: usize) -> usize {
119        let w = self.get(u, v);
120        self.depth[u] + self.depth[v] - self.depth[w] * 2
121    }
122
123    /// 頂点 `u` の深さを返します。
124    pub fn depth(&self, u: usize) -> usize {
125        self.depth[u]
126    }
127
128    /// 頂点 `u` から根の方向に `k` 本の辺を登って着く頂点を返します。
129    pub fn kth_parent(&self, u: usize, k: usize) -> Option<usize> {
130        assert!(u < self.n);
131        if k > self.depth[u] {
132            return None;
133        }
134
135        let result = self.doubling.fold(u, k, u, |_, t| t.next);
136        // n is sentinel
137        if result == self.n { None } else { Some(result) }
138    }
139}
140
#[cfg(test)]
mod tests {
    use crate::LowestCommonAncestor;

    #[test]
    fn single_node_test() {
        // A tree made of nothing but its root.
        let tree = LowestCommonAncestor::new(1, 0, &[]);
        assert_eq!(tree.get(0, 0), 0);
    }

    #[test]
    fn test_kth_parent() {
        // Path 0 - 1 - 2 - 3 - 4 rooted at 0, so vertex `v` sits at depth `v`.
        let path = [(0, 1), (1, 2), (2, 3), (3, 4)];
        let tree = LowestCommonAncestor::new(5, 0, &path);

        for v in 0..5 {
            // Climbing k <= depth edges from v lands on v - k ...
            for k in 0..=v {
                assert_eq!(tree.kth_parent(v, k), Some(v - k), "v = {v}, k = {k}");
            }
            // ... and one more step goes past the root.
            assert_eq!(tree.kth_parent(v, v + 1), None, "v = {v}");
        }
    }
}
180}