// lowest_common_ancestor/lib.rs
use std::collections::VecDeque;

/// Lowest-common-ancestor queries on a rooted tree, answered by binary lifting.
///
/// Building the table costs `O(n log n)` time and memory; every query walks
/// the table once or twice and therefore costs `O(log n)`.
#[derive(Debug, Clone)]
pub struct LowestCommonAncestor {
    /// Number of vertices; valid vertex ids are `0..n`.
    n: usize,
    /// `ancestor[i][v]` is the vertex reached by climbing `2^i` edges from `v`
    /// towards the root, or `ILLEGAL` when that climb would pass the root.
    ancestor: Vec<Vec<usize>>,
    /// `depth[v]` is the number of edges between the root and `v`.
    depth: Vec<usize>,
}

/// Sentinel stored in the ancestor table meaning "no such ancestor".
const ILLEGAL: usize = usize::MAX;

impl LowestCommonAncestor {
    /// Builds the structure for a tree with vertices `0..n`, rooted at `root`,
    /// whose undirected edges are listed in `edges`.
    ///
    /// The edges are expected to form a tree; this is not verified.
    ///
    /// # Panics
    ///
    /// Panics if `root >= n` or if an edge endpoint is `>= n`.
    pub fn new(n: usize, root: usize, edges: &[(usize, usize)]) -> Self {
        assert!(root < n);
        let mut g = vec![vec![]; n];
        for &(u, v) in edges {
            assert!(u < n && v < n, "edge endpoint out of range");
            g[u].push(v);
            g[v].push(u);
        }

        // Breadth-first search from the root to record each vertex's parent
        // and depth. The vertex we arrived from is skipped so that an edge is
        // never walked back immediately.
        let mut depth = vec![0; n];
        let mut parent = vec![ILLEGAL; n];
        let mut que = VecDeque::new();
        que.push_back((root, ILLEGAL));
        while let Some((curr, prev)) = que.pop_front() {
            for &next in &g[curr] {
                if next != prev {
                    depth[next] = depth[curr] + 1;
                    parent[next] = curr;
                    que.push_back((next, curr));
                }
            }
        }

        // ceil(log2(n)) rows (at least one), so that every depth difference
        // `< n` can be decomposed into powers of two covered by the table.
        let table_size = if n == 1 {
            1
        } else {
            n.ilog2() as usize + usize::from(!n.is_power_of_two())
        };

        // Row `i` is derived from row `i - 1`: climbing 2^i edges is climbing
        // 2^(i-1) edges twice.
        let mut ancestor = Vec::with_capacity(table_size);
        ancestor.push(parent);
        for i in 1..table_size {
            let prev = &ancestor[i - 1];
            let row: Vec<usize> = prev
                .iter()
                .map(|&p| if p == ILLEGAL { ILLEGAL } else { prev[p] })
                .collect();
            ancestor.push(row);
        }

        Self { n, ancestor, depth }
    }

    /// Returns the lowest common ancestor of `u` and `v`.
    ///
    /// # Panics
    ///
    /// Panics if `u >= n` or `v >= n`, or if no common ancestor is found
    /// (which indicates the input was not a tree connected to the root).
    pub fn get(&self, u: usize, v: usize) -> usize {
        assert!(u < self.n);
        assert!(v < self.n);
        // Make `u` the deeper (or equally deep) endpoint.
        let (mut u, mut v) = if self.depth[u] >= self.depth[v] {
            (u, v)
        } else {
            (v, u)
        };

        // Lift `u` until both vertices sit at the same depth.
        let depth_diff = self.depth[u] - self.depth[v];
        for (i, row) in self.ancestor.iter().enumerate() {
            if depth_diff >> i & 1 == 1 {
                u = row[u];
            }
        }
        if u == v {
            return u;
        }

        // Lift both vertices by the largest jumps that keep them distinct;
        // afterwards they are the two children of the answer.
        for row in self.ancestor.iter().rev() {
            if row[u] != row[v] {
                u = row[u];
                v = row[v];
            }
        }
        let lca = self.ancestor[0][u];
        assert_ne!(lca, ILLEGAL);
        lca
    }

    /// Returns the number of edges on the path between `u` and `v`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`LowestCommonAncestor::get`].
    pub fn get_dist(&self, u: usize, v: usize) -> usize {
        let w = self.get(u, v);
        self.depth[u] + self.depth[v] - self.depth[w] * 2
    }

    /// Returns the depth of `u`, i.e. its distance in edges from the root.
    ///
    /// # Panics
    ///
    /// Panics if `u >= n`.
    pub fn depth(&self, u: usize) -> usize {
        assert!(u < self.n);
        self.depth[u]
    }

    /// Returns the vertex reached by climbing `k` edges from `u` towards the
    /// root, or `None` if `u` has fewer than `k` ancestors.
    ///
    /// `kth_parent(u, 0)` is `Some(u)`.
    ///
    /// # Panics
    ///
    /// Panics if `u >= n`.
    pub fn kth_parent(&self, u: usize, k: usize) -> Option<usize> {
        assert!(u < self.n);
        // A vertex at depth `d` has exactly `d` proper ancestors. This also
        // guarantees `k < n`, so every set bit of `k` has a table row.
        if k > self.depth[u] {
            return None;
        }
        let mut u = u;
        for (i, row) in self.ancestor.iter().enumerate() {
            if k >> i & 1 == 1 {
                u = row[u];
                if u == ILLEGAL {
                    return None;
                }
            }
        }
        Some(u)
    }
}
133
#[cfg(test)]
mod tests {
    use crate::LowestCommonAncestor;

    /// A tree made of the root alone: the root is its own lowest common ancestor.
    #[test]
    fn single_node_test() {
        let no_edges: [(usize, usize); 0] = [];
        let tree = LowestCommonAncestor::new(1, 0, &no_edges);
        let answer = tree.get(0, 0);
        assert_eq!(answer, 0);
    }
}