// union_find/lib.rs
1/// Union Find はグラフの連結成分を管理します。
2#[derive(Clone, Debug)]
3pub struct UnionFind {
4 nodes: Vec<NodeKind>,
5 groups: usize,
6}
7
/// Outcome of a [`UnionFind::unite`] call that actually merged two distinct components.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct UniteResult {
    /// Root of the merged component after the union.
    pub new_root: usize,
    /// Former root of the other component, now linked directly under `new_root`.
    pub child: usize,
}

/// State stored for each vertex of a `UnionFind`.
///
/// A vertex is either the root (representative) of its component, which then
/// records how many vertices the component has, or a non-root vertex that
/// points at its parent in the same tree.
#[derive(Debug, Clone, Copy)]
enum NodeKind {
    /// Representative of a component; `size` is the number of vertices in it.
    Root { size: usize },
    /// Non-root vertex; `parent` is the next vertex on the path to the root.
    Child { parent: usize },
}

20impl UnionFind {
21 /// 頂点数を `n` として初期化します。
22 pub fn new(n: usize) -> Self {
23 Self {
24 nodes: vec![NodeKind::Root { size: 1 }; n],
25 groups: n,
26 }
27 }
28
29 /// 頂点 `i` の属する連結成分の代表元を返します。
30 ///
31 /// # Examples
32 ///
33 /// ```
34 /// use union_find::UnionFind;
35 /// let mut uf = UnionFind::new(6);
36 /// uf.unite(0, 1);
37 /// uf.unite(1, 2);
38 /// uf.unite(3, 4);
39 ///
40 /// // [(0, 1, 2), (3, 4), (5)]
41 /// assert_eq!(uf.find(0), uf.find(0));
42 /// assert_eq!(uf.find(0), uf.find(1));
43 /// assert_eq!(uf.find(1), uf.find(2));
44 /// assert_eq!(uf.find(0), uf.find(2));
45 /// assert_eq!(uf.find(3), uf.find(4));
46 ///
47 /// assert_ne!(uf.find(0), uf.find(3));
48 /// assert_ne!(uf.find(0), uf.find(5));
49 /// ```
50 pub fn find(&mut self, i: usize) -> usize {
51 assert!(i < self.nodes.len());
52
53 match self.nodes[i] {
54 NodeKind::Root { .. } => i,
55 NodeKind::Child { parent } => {
56 let root = self.find(parent);
57 if root == parent {
58 // noop
59 } else {
60 // 経路圧縮
61 self.nodes[i] = NodeKind::Child { parent: root };
62 }
63 root
64 }
65 }
66 }
67
68 /// 頂点 `i` の属する連結成分と頂点 `j` の属する連結成分をつなげます。
69 ///
70 /// もともと `i` と `j` が同じ連結成分だった場合は `None` を返します。
71 ///
72 /// # Examples
73 ///
74 /// ```
75 /// use union_find::UnionFind;
76 /// let mut uf = UnionFind::new(6);
77 /// assert!(uf.unite(0, 1).is_some());
78 /// assert!(uf.unite(1, 2).is_some());
79 /// assert!(uf.unite(3, 4).is_some());
80 ///
81 /// // [(0, 1, 2), (3, 4), (5)]
82 /// assert!(uf.unite(0, 2).is_none());
83 /// assert!(uf.unite(3, 3).is_none());
84 ///
85 /// assert!(uf.unite(4, 5).is_some());
86 /// ```
87 pub fn unite(&mut self, i: usize, j: usize) -> Option<UniteResult> {
88 let i = self.find(i);
89 let j = self.find(j);
90 if i == j {
91 return None;
92 }
93
94 match (self.nodes[i], self.nodes[j]) {
95 (NodeKind::Root { size: i_size }, NodeKind::Root { size: j_size }) => {
96 self.groups -= 1;
97 let total = i_size + j_size;
98 // マージテク
99 if i_size >= j_size {
100 self.nodes[j] = NodeKind::Child { parent: i };
101 self.nodes[i] = NodeKind::Root { size: total };
102 Some(UniteResult {
103 new_root: i,
104 child: j,
105 })
106 } else {
107 self.nodes[i] = NodeKind::Child { parent: j };
108 self.nodes[j] = NodeKind::Root { size: total };
109 Some(UniteResult {
110 new_root: j,
111 child: i,
112 })
113 }
114 }
115 _ => unreachable!(),
116 }
117 }
118
119 /// 頂点 `i` の属する連結成分のサイズ (頂点数) を返します。
120 ///
121 /// # Examples
122 ///
123 /// ```
124 /// use union_find::UnionFind;
125 /// let mut uf = UnionFind::new(6);
126 /// uf.unite(0, 1);
127 /// uf.unite(1, 2);
128 /// uf.unite(3, 4);
129 ///
130 /// // [(0, 1, 2), (3, 4), (5)]
131 /// assert_eq!(uf.size(0), 3);
132 /// assert_eq!(uf.size(1), 3);
133 /// assert_eq!(uf.size(2), 3);
134 /// assert_eq!(uf.size(3), 2);
135 /// assert_eq!(uf.size(4), 2);
136 /// assert_eq!(uf.size(5), 1);
137 /// ```
138 pub fn size(&mut self, i: usize) -> usize {
139 let root = self.find(i);
140 match self.nodes[root] {
141 NodeKind::Root { size } => size,
142 _ => unreachable!(),
143 }
144 }
145
146 /// 頂点 `i` と頂点 `j` が同じ連結成分に属するかどうかを返します。
147 ///
148 /// # Examples
149 ///
150 /// ```
151 /// use union_find::UnionFind;
152 /// let mut uf = UnionFind::new(6);
153 /// assert!(uf.same(0, 0));
154 /// assert!(uf.same(3, 3));
155 /// assert!(uf.same(5, 5));
156 ///
157 /// uf.unite(0, 1);
158 /// uf.unite(1, 2);
159 /// uf.unite(3, 4);
160 ///
161 /// // [(0, 1, 2), (3, 4), (5)]
162 /// assert!(uf.same(0, 1));
163 /// assert!(uf.same(1, 2));
164 /// assert!(uf.same(0, 2));
165 /// assert!(uf.same(3, 4));
166 /// ```
167 pub fn same(&mut self, i: usize, j: usize) -> bool {
168 self.find(i) == self.find(j)
169 }
170
171 /// 連結成分数を返します。
172 ///
173 /// # Examples
174 ///
175 /// ```
176 /// use union_find::UnionFind;
177 /// let mut uf = UnionFind::new(6);
178 /// uf.unite(0, 1);
179 /// uf.unite(1, 2);
180 /// uf.unite(3, 4);
181 ///
182 /// // [(0, 1, 2), (3, 4), (5)]
183 /// assert_eq!(uf.count_groups(), 3);
184 /// ```
185 pub fn count_groups(&self) -> usize {
186 self.groups
187 }
188}
189
#[cfg(test)]
mod tests {
    use ::proptest::{collection, prelude::*};

    use super::*;

    #[test]
    fn test_basic() {
        // 0 -- 1 -- 2
        // 3 -- 4
        let mut uf = UnionFind::new(6);
        // Equal sizes: the first argument's root wins.
        assert_eq!(uf.unite(0, 1).map(|r| (r.new_root, r.child)), Some((0, 1)));
        assert!(uf.unite(1, 2).is_some());
        assert!(uf.unite(3, 4).is_some());
        // Already connected, or a self-loop: nothing to merge.
        assert!(uf.unite(0, 2).is_none());
        assert!(uf.unite(5, 5).is_none());

        assert!(uf.same(0, 1));
        assert!(uf.same(1, 2));
        assert!(uf.same(0, 2));
        assert!(uf.same(3, 4));
        assert!(!uf.same(0, 3));

        assert_eq!(uf.size(0), 3);
        assert_eq!(uf.size(3), 2);
        assert_eq!(uf.size(5), 1);
        assert_eq!(uf.count_groups(), 3);
    }

    #[test]
    fn test_empty() {
        let uf = UnionFind::new(0);
        assert_eq!(uf.count_groups(), 0);
    }

    #[test]
    #[should_panic]
    fn find_out_of_range_panics() {
        let mut uf = UnionFind::new(3);
        uf.find(3);
    }

    prop_compose! {
        fn uf_operations()(n in 1_usize..=20)
            (n in Just(n),
             operations in collection::vec((0..n, 0..n), 0..=50))
            -> (usize, Vec<(usize, usize)>) {
            (n, operations)
        }
    }

    proptest! {
        #[test]
        fn unite_makes_same((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                uf.unite(i, j);
                prop_assert!(uf.same(i, j));
            }
        }

        #[test]
        fn same_is_transitive((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                uf.unite(i, j);
            }

            for i in 0..n {
                for j in 0..n {
                    for k in 0..n {
                        if uf.same(i, j) && uf.same(j, k) {
                            prop_assert!(uf.same(i, k));
                        }
                    }
                }
            }
        }

        #[test]
        fn size_sum_equals_n((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                uf.unite(i, j);
            }

            let mut roots = std::collections::HashSet::new();
            for i in 0..n {
                roots.insert(uf.find(i));
            }

            let total_size: usize = roots.iter().map(|&r| uf.size(r)).sum();
            prop_assert_eq!(total_size, n);
        }

        #[test]
        fn count_groups_matches_distinct_roots((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                let was_same = uf.same(i, j);
                let before = uf.count_groups();
                // `unite` merges exactly when `i` and `j` were in different components.
                prop_assert_eq!(uf.unite(i, j).is_none(), was_same);
                let expected = if was_same { before } else { before - 1 };
                prop_assert_eq!(uf.count_groups(), expected);
            }

            let roots: std::collections::HashSet<usize> = (0..n).map(|v| uf.find(v)).collect();
            prop_assert_eq!(uf.count_groups(), roots.len());
        }

        #[test]
        fn unite_result_reports_roots((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                let size_i = uf.size(i);
                let size_j = uf.size(j);
                if let Some(UniteResult { new_root, child }) = uf.unite(i, j) {
                    prop_assert_eq!(uf.find(i), new_root);
                    prop_assert_eq!(uf.find(j), new_root);
                    prop_assert_eq!(uf.find(child), new_root);
                    prop_assert!(child != new_root);
                    prop_assert_eq!(uf.size(new_root), size_i + size_j);
                }
            }
        }
    }
}