// union_find/lib.rs

1/// Union Find はグラフの連結成分を管理します。
2#[derive(Clone, Debug)]
3pub struct UnionFind {
4    nodes: Vec<NodeKind>,
5    groups: usize,
6}
7
/// Outcome of a [`UnionFind::unite`] call that actually merged two components.
///
/// Plain data, so it is `Copy` and comparable; callers can store or match on it freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct UniteResult {
    /// Representative of the merged component after the union.
    pub new_root: usize,
    /// Former representative that was attached beneath `new_root`; it is no longer a root.
    pub child: usize,
}

/// Internal state stored for each vertex.
#[derive(Debug, Clone, Copy)]
enum NodeKind {
    /// The vertex represents its component; `size` is the number of vertices in it.
    Root { size: usize },
    /// The vertex is not a representative; `parent` is the next vertex toward the root.
    Child { parent: usize },
}

20impl UnionFind {
21    /// 頂点数を `n` として初期化します。
22    pub fn new(n: usize) -> Self {
23        Self {
24            nodes: vec![NodeKind::Root { size: 1 }; n],
25            groups: n,
26        }
27    }
28
29    /// 頂点 `i` の属する連結成分の代表元を返します。
30    ///
31    /// # Examples
32    ///
33    /// ```
34    /// use union_find::UnionFind;
35    /// let mut uf = UnionFind::new(6);
36    /// uf.unite(0, 1);
37    /// uf.unite(1, 2);
38    /// uf.unite(3, 4);
39    ///
40    /// // [(0, 1, 2), (3, 4), (5)]
41    /// assert_eq!(uf.find(0), uf.find(0));
42    /// assert_eq!(uf.find(0), uf.find(1));
43    /// assert_eq!(uf.find(1), uf.find(2));
44    /// assert_eq!(uf.find(0), uf.find(2));
45    /// assert_eq!(uf.find(3), uf.find(4));
46    ///
47    /// assert_ne!(uf.find(0), uf.find(3));
48    /// assert_ne!(uf.find(0), uf.find(5));
49    /// ```
50    pub fn find(&mut self, i: usize) -> usize {
51        assert!(i < self.nodes.len());
52
53        match self.nodes[i] {
54            NodeKind::Root { .. } => i,
55            NodeKind::Child { parent } => {
56                let root = self.find(parent);
57                if root == parent {
58                    // noop
59                } else {
60                    // 経路圧縮
61                    self.nodes[i] = NodeKind::Child { parent: root };
62                }
63                root
64            }
65        }
66    }
67
68    /// 頂点 `i` の属する連結成分と頂点 `j` の属する連結成分をつなげます。
69    ///
70    /// もともと `i` と `j` が同じ連結成分だった場合は `None` を返します。
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use union_find::UnionFind;
76    /// let mut uf = UnionFind::new(6);
77    /// assert!(uf.unite(0, 1).is_some());
78    /// assert!(uf.unite(1, 2).is_some());
79    /// assert!(uf.unite(3, 4).is_some());
80    ///
81    /// // [(0, 1, 2), (3, 4), (5)]
82    /// assert!(uf.unite(0, 2).is_none());
83    /// assert!(uf.unite(3, 3).is_none());
84    ///
85    /// assert!(uf.unite(4, 5).is_some());
86    /// ```
87    pub fn unite(&mut self, i: usize, j: usize) -> Option<UniteResult> {
88        let i = self.find(i);
89        let j = self.find(j);
90        if i == j {
91            return None;
92        }
93
94        match (self.nodes[i], self.nodes[j]) {
95            (NodeKind::Root { size: i_size }, NodeKind::Root { size: j_size }) => {
96                self.groups -= 1;
97                let total = i_size + j_size;
98                // マージテク
99                if i_size >= j_size {
100                    self.nodes[j] = NodeKind::Child { parent: i };
101                    self.nodes[i] = NodeKind::Root { size: total };
102                    Some(UniteResult {
103                        new_root: i,
104                        child: j,
105                    })
106                } else {
107                    self.nodes[i] = NodeKind::Child { parent: j };
108                    self.nodes[j] = NodeKind::Root { size: total };
109                    Some(UniteResult {
110                        new_root: j,
111                        child: i,
112                    })
113                }
114            }
115            _ => unreachable!(),
116        }
117    }
118
119    /// 頂点 `i` の属する連結成分のサイズ (頂点数) を返します。
120    ///
121    /// # Examples
122    ///
123    /// ```
124    /// use union_find::UnionFind;
125    /// let mut uf = UnionFind::new(6);
126    /// uf.unite(0, 1);
127    /// uf.unite(1, 2);
128    /// uf.unite(3, 4);
129    ///
130    /// // [(0, 1, 2), (3, 4), (5)]
131    /// assert_eq!(uf.size(0), 3);
132    /// assert_eq!(uf.size(1), 3);
133    /// assert_eq!(uf.size(2), 3);
134    /// assert_eq!(uf.size(3), 2);
135    /// assert_eq!(uf.size(4), 2);
136    /// assert_eq!(uf.size(5), 1);
137    /// ```
138    pub fn size(&mut self, i: usize) -> usize {
139        let root = self.find(i);
140        match self.nodes[root] {
141            NodeKind::Root { size } => size,
142            _ => unreachable!(),
143        }
144    }
145
146    /// 頂点 `i` と頂点 `j` が同じ連結成分に属するかどうかを返します。
147    ///  
148    /// # Examples
149    ///
150    /// ```
151    /// use union_find::UnionFind;
152    /// let mut uf = UnionFind::new(6);
153    /// assert!(uf.same(0, 0));
154    /// assert!(uf.same(3, 3));
155    /// assert!(uf.same(5, 5));
156    ///
157    /// uf.unite(0, 1);
158    /// uf.unite(1, 2);
159    /// uf.unite(3, 4);
160    ///
161    /// // [(0, 1, 2), (3, 4), (5)]
162    /// assert!(uf.same(0, 1));
163    /// assert!(uf.same(1, 2));
164    /// assert!(uf.same(0, 2));
165    /// assert!(uf.same(3, 4));
166    /// ```
167    pub fn same(&mut self, i: usize, j: usize) -> bool {
168        self.find(i) == self.find(j)
169    }
170
171    /// 連結成分数を返します。
172    ///
173    /// # Examples
174    ///
175    /// ```
176    /// use union_find::UnionFind;
177    /// let mut uf = UnionFind::new(6);
178    /// uf.unite(0, 1);
179    /// uf.unite(1, 2);
180    /// uf.unite(3, 4);
181    ///
182    /// // [(0, 1, 2), (3, 4), (5)]
183    /// assert_eq!(uf.count_groups(), 3);
184    /// ```
185    pub fn count_groups(&self) -> usize {
186        self.groups
187    }
188}
189
#[cfg(test)]
mod tests {
    use ::proptest::{collection, prelude::*};

    use super::*;

    #[test]
    fn test_basic() {
        // 0 -- 1 -- 2
        // 3 -- 4
        let mut uf = UnionFind::new(6);
        assert!(uf.unite(0, 1).is_some());
        assert!(uf.unite(1, 2).is_some());
        assert!(uf.unite(3, 4).is_some());

        assert!(uf.same(0, 1));
        assert!(uf.same(1, 2));
        assert!(uf.same(0, 2));
        assert!(uf.same(3, 4));
        assert!(!uf.same(0, 3));

        assert_eq!(uf.size(0), 3);
        assert_eq!(uf.size(3), 2);
        assert_eq!(uf.size(5), 1);
        assert_eq!(uf.count_groups(), 3);
    }

    #[test]
    fn unite_result_reports_surviving_root() {
        let mut uf = UnionFind::new(4);
        assert_eq!(uf.count_groups(), 4);

        // Equal sizes: the root of the first argument survives.
        let r = uf.unite(0, 1).expect("0 and 1 start in different components");
        assert_eq!(r.new_root, 0);
        assert_eq!(r.child, 1);

        // Larger component wins regardless of argument order.
        let r = uf.unite(2, 0).expect("2 and 0 are in different components");
        assert_eq!(r.new_root, 0);
        assert_eq!(r.child, 2);

        assert_eq!(uf.count_groups(), 2);
        assert!(uf.unite(1, 2).is_none());
        assert_eq!(uf.count_groups(), 2);
    }

    #[test]
    fn empty_structure_has_no_groups() {
        let uf = UnionFind::new(0);
        assert_eq!(uf.count_groups(), 0);
    }

    #[test]
    #[should_panic]
    fn find_out_of_range_panics() {
        let mut uf = UnionFind::new(3);
        uf.find(3);
    }

    prop_compose! {
        fn uf_operations()(n in 1_usize..=20)
                         (n in Just(n),
                          operations in collection::vec((0..n, 0..n), 0..=50))
                         -> (usize, Vec<(usize, usize)>) {
            (n, operations)
        }
    }

    proptest! {
        #[test]
        fn unite_makes_same((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                uf.unite(i, j);
                prop_assert!(uf.same(i, j));
            }
        }

        #[test]
        fn unite_result_is_consistent((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                let was_same = uf.same(i, j);
                let groups_before = uf.count_groups();
                match uf.unite(i, j) {
                    None => {
                        prop_assert!(was_same);
                        prop_assert_eq!(uf.count_groups(), groups_before);
                    }
                    Some(UniteResult { new_root, child }) => {
                        prop_assert!(!was_same);
                        prop_assert_ne!(new_root, child);
                        prop_assert_eq!(uf.find(i), new_root);
                        prop_assert_eq!(uf.find(j), new_root);
                        prop_assert_eq!(uf.find(child), new_root);
                        prop_assert_eq!(uf.count_groups(), groups_before - 1);
                    }
                }
            }
        }

        #[test]
        fn same_is_transitive((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                uf.unite(i, j);
            }

            for i in 0..n {
                for j in 0..n {
                    for k in 0..n {
                        if uf.same(i, j) && uf.same(j, k) {
                            prop_assert!(uf.same(i, k));
                        }
                    }
                }
            }
        }

        #[test]
        fn size_sum_equals_n((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                uf.unite(i, j);
            }

            let mut roots = std::collections::HashSet::new();
            for i in 0..n {
                roots.insert(uf.find(i));
            }

            let total_size: usize = roots.iter().map(|&r| uf.size(r)).sum();
            prop_assert_eq!(total_size, n);
        }

        #[test]
        fn count_groups_matches_distinct_roots((n, operations) in uf_operations()) {
            let mut uf = UnionFind::new(n);
            for (i, j) in operations {
                uf.unite(i, j);
            }

            let roots: std::collections::HashSet<usize> = (0..n).map(|i| uf.find(i)).collect();
            prop_assert_eq!(roots.len(), uf.count_groups());
        }
    }
}