1use std::collections::VecDeque;
2
3use doubling::{Doubling, Transition};
4
5#[derive(Debug, Clone)]
24pub struct LowestCommonAncestor {
25 n: usize,
26 doubling: Doubling<()>,
27 depth: Vec<usize>,
28}
29
30impl LowestCommonAncestor {
31 pub fn new(n: usize, root: usize, edges: &[(usize, usize)]) -> Self {
33 assert!(root < n);
34 let mut g = vec![vec![]; n];
35 for &(u, v) in edges {
36 assert!(u < n);
37 assert!(v < n);
38 g[u].push(v);
39 g[v].push(u);
40 }
41
42 let mut depth = vec![0; n];
43 let mut parent = vec![None; n];
44 let mut que = VecDeque::new();
45 depth[root] = 0;
46 parent[root] = None;
47 que.push_back((root, None));
48 while let Some((curr, prev)) = que.pop_front() {
49 for &next in &g[curr] {
50 if prev.is_some_and(|prev| prev == next) {
51 continue;
52 }
53 depth[next] = depth[curr] + 1;
54 parent[next] = Some(curr);
55 que.push_back((next, Some(curr)));
56 }
57 }
58
59 let sentinel = n;
60 let doubling = Doubling::new(n + 1, (n - 1).max(1), |i| {
61 if i < n {
62 let next = parent[i].unwrap_or(sentinel);
63 Transition::new(next, ())
64 } else {
65 Transition::new(sentinel, ())
66 }
67 });
68
69 Self { n, doubling, depth }
70 }
71
72 pub fn get(&self, u: usize, v: usize) -> usize {
74 assert!(u < self.n);
75 assert!(v < self.n);
76
77 if self.n == 1 {
78 assert_eq!(u, 0);
79 assert_eq!(v, 0);
80 return 0;
81 }
82
83 let (u, v) = if self.depth[u] >= self.depth[v] {
84 (u, v)
85 } else {
86 (v, u)
87 };
88 assert!(self.depth[u] >= self.depth[v]);
89
90 let u = self
91 .doubling
92 .fold(u, self.depth[u] - self.depth[v], u, |_, t| t.next);
93
94 assert_eq!(self.depth[u], self.depth[v]);
95
96 if u == v {
97 return u;
98 }
99
100 let (mut u, mut v) = (u, v);
101 let log = self.n.ilog2() as usize + usize::from(!self.n.is_power_of_two());
102 for k in (0..log).rev() {
103 let au = self.doubling.get(u, k).next;
104 let av = self.doubling.get(v, k).next;
105 if au != av {
106 u = au;
107 v = av;
108 }
109 }
110
111 let lca = self.doubling.get(u, 0).next;
112 assert_ne!(lca, self.n);
113
114 lca
115 }
116
117 pub fn get_dist(&self, u: usize, v: usize) -> usize {
119 let w = self.get(u, v);
120 self.depth[u] + self.depth[v] - self.depth[w] * 2
121 }
122
123 pub fn depth(&self, u: usize) -> usize {
125 self.depth[u]
126 }
127
128 pub fn kth_parent(&self, u: usize, k: usize) -> Option<usize> {
130 assert!(u < self.n);
131 if k > self.depth[u] {
132 return None;
133 }
134
135 let result = self.doubling.fold(u, k, u, |_, t| t.next);
136 if result == self.n { None } else { Some(result) }
138 }
139}
140
#[cfg(test)]
mod tests {
    use crate::LowestCommonAncestor;

    /// A one-vertex tree: the root is its own lowest common ancestor.
    #[test]
    fn single_node_test() {
        let tree = LowestCommonAncestor::new(1, 0, &[]);
        assert_eq!(tree.get(0, 0), 0);
    }

    /// On the path 0 - 1 - 2 - 3 - 4 rooted at 0, vertex `u` has depth `u`, so
    /// its `k`-th parent is `u - k` for `k <= u` and does not exist for `k = u + 1`.
    #[test]
    fn test_kth_parent() {
        let path = [(0, 1), (1, 2), (2, 3), (3, 4)];
        let tree = LowestCommonAncestor::new(5, 0, &path);

        for u in 0..5 {
            for k in 0..=u {
                assert_eq!(tree.kth_parent(u, k), Some(u - k));
            }
            assert_eq!(tree.kth_parent(u, u + 1), None);
        }
    }
}