avl_tree/
lib.rs

1use std::{
2    cmp::{self, Ordering},
3    fmt,
4};
5
/// A single node of the AVL tree.
///
/// Besides its element, every node caches the height and the element count
/// of the subtree rooted at it, so rebalancing and order-statistic queries
/// can read them without walking the subtree.
#[derive(Clone)]
struct Node<T> {
    /// The element stored in this node.
    x: T,
    /// Height of the subtree rooted here; a leaf has height 1.
    height: i32,
    /// Subtree of elements that compare less than `x`.
    left: Option<Box<Node<T>>>,
    /// Subtree of elements that compare greater than `x`.
    right: Option<Box<Node<T>>>,
    /// Number of elements in the subtree rooted here, this node included.
    size: usize,
}
14
15#[derive(Clone)]
16pub struct AvlTree<T> {
17    n: usize,
18    root: Option<Box<Node<T>>>,
19}
20
21impl<T> AvlTree<T> {
22    pub fn new() -> Self {
23        Self { n: 0, root: None }
24    }
25
26    pub fn len(&self) -> usize {
27        self.n
28    }
29
30    pub fn is_empty(&self) -> bool {
31        self.n == 0
32    }
33
34    fn new_node(x: T) -> Box<Node<T>> {
35        Box::new(Node {
36            x,
37            height: 1,
38            left: None,
39            right: None,
40            size: 1,
41        })
42    }
43
44    fn node_height(node: &Option<Box<Node<T>>>) -> i32 {
45        node.as_ref().map_or(0, |n| n.height)
46    }
47
48    fn node_size(node: &Option<Box<Node<T>>>) -> usize {
49        node.as_ref().map_or(0, |n| n.size)
50    }
51
52    fn balance_factor(node: &Node<T>) -> i32 {
53        Self::node_height(&node.left) - Self::node_height(&node.right)
54    }
55
56    fn update_height_and_size(node: &mut Node<T>) {
57        node.height = 1 + Self::node_height(&node.left).max(Self::node_height(&node.right));
58        node.size = 1 + Self::node_size(&node.left) + Self::node_size(&node.right);
59    }
60
61    fn rotate_right(mut root: Box<Node<T>>) -> Box<Node<T>> {
62        //         root                    left
63        //         |                       |
64        //     +---+---+               +---+---+
65        //     |       |               |       |
66        //    left     c       ->      a      root
67        //     |                              |
68        // +---+---+                      +---+---+
69        // |       |                      |       |
70        // a       b                      b       c
71        let mut left = root.left.take().unwrap();
72        let b = left.right.take();
73
74        root.left = b;
75        Self::update_height_and_size(&mut root);
76
77        left.right = Some(root);
78        Self::update_height_and_size(&mut left);
79
80        left
81    }
82
83    fn rotate_left(mut root: Box<Node<T>>) -> Box<Node<T>> {
84        //      root                        right
85        //      |                           |
86        //  +---+---+                   +---+---+
87        //  |       |                   |       |
88        //  a      right        ->     root      c
89        //          |                   |
90        //      +---+---+           +---+---+
91        //      |       |           |       |
92        //      b       c           a       b
93        let mut right = root.right.take().unwrap();
94        let b = right.left.take();
95
96        root.right = b;
97        Self::update_height_and_size(&mut root);
98
99        right.left = Some(root);
100        Self::update_height_and_size(&mut right);
101
102        right
103    }
104
105    fn rebalance(mut node: Box<Node<T>>) -> Box<Node<T>> {
106        Self::update_height_and_size(&mut node);
107
108        let balance = Self::balance_factor(&node);
109
110        // Left Heavy
111        if balance > 1 {
112            // Left-Right case: check if we need double rotation
113            if let Some(left) = node.left.take() {
114                if Self::balance_factor(&left) < 0 {
115                    // Left-Right case: rotate left child left first
116                    node.left = Some(Self::rotate_left(left));
117                } else {
118                    // Left-Left case: put it back
119                    node.left = Some(left);
120                }
121            }
122            return Self::rotate_right(node);
123        }
124
125        // Right Heavy
126        if balance < -1 {
127            // Right-Left case: check if we need double rotation
128            if let Some(right) = node.right.take() {
129                if Self::balance_factor(&right) > 0 {
130                    // Right-Left case: rotate right child right first
131                    node.right = Some(Self::rotate_right(right));
132                } else {
133                    // Right-Right case: put it back
134                    node.right = Some(right);
135                }
136            }
137            return Self::rotate_left(node);
138        }
139
140        node
141    }
142
143    pub fn into_sorted_vec(mut self) -> Vec<T> {
144        fn collect<T>(node: Option<Box<Node<T>>>, acc: &mut Vec<T>) {
145            if let Some(node) = node {
146                collect(node.left, acc);
147                acc.push(node.x);
148                collect(node.right, acc);
149            }
150        }
151
152        let mut result = Vec::with_capacity(self.n);
153        collect(self.root.take(), &mut result);
154        self.n = 0;
155        result
156    }
157}
158
159impl<T> AvlTree<T>
160where
161    T: cmp::Ord,
162{
163    fn find_last(&self, x: &T) -> Option<&Node<T>> {
164        let mut current = &self.root;
165        let mut last = Option::<&Node<T>>::None;
166
167        while let Some(node) = current {
168            last = Some(node);
169            match x.cmp(&node.x) {
170                Ordering::Less => current = &node.left,
171                Ordering::Greater => current = &node.right,
172                Ordering::Equal => return Some(node),
173            }
174        }
175
176        last
177    }
178
179    /// 集合にxが含まれるかを返す。
180    pub fn contains(&self, x: &T) -> bool {
181        self.find_last(x).is_some_and(|node| x.eq(&node.x))
182    }
183
184    /// xを追加する。集合にxが含まれていなかった場合trueを返す。
185    pub fn insert(&mut self, x: T) -> bool {
186        let root = self.root.take();
187        let mut inserted = false;
188        self.root = Self::insert_recursive(root, x, &mut inserted);
189        if inserted {
190            self.n += 1;
191        }
192        inserted
193    }
194
195    fn insert_recursive(
196        root: Option<Box<Node<T>>>,
197        x: T,
198        inserted: &mut bool,
199    ) -> Option<Box<Node<T>>> {
200        let mut root = match root {
201            Some(root) => root,
202            None => {
203                *inserted = true;
204                return Some(Self::new_node(x));
205            }
206        };
207
208        match x.cmp(&root.x) {
209            Ordering::Less => {
210                root.left = Self::insert_recursive(root.left.take(), x, inserted);
211            }
212            Ordering::Greater => {
213                root.right = Self::insert_recursive(root.right.take(), x, inserted);
214            }
215            Ordering::Equal => return Some(root),
216        }
217
218        if *inserted {
219            Some(Self::rebalance(root))
220        } else {
221            Some(root)
222        }
223    }
224
225    /// xを削除する。集合にxが含まれていた場合trueを返す。
226    pub fn remove(&mut self, x: &T) -> bool {
227        let root = self.root.take();
228        let mut removed = false;
229        self.root = Self::remove_recursive(root, x, &mut removed);
230        if removed {
231            self.n -= 1;
232        }
233        removed
234    }
235
236    fn remove_recursive(
237        root: Option<Box<Node<T>>>,
238        x: &T,
239        removed: &mut bool,
240    ) -> Option<Box<Node<T>>> {
241        let mut root = root?;
242
243        match x.cmp(&root.x) {
244            Ordering::Less => {
245                root.left = Self::remove_recursive(root.left.take(), x, removed);
246            }
247            Ordering::Greater => {
248                root.right = Self::remove_recursive(root.right.take(), x, removed);
249            }
250            Ordering::Equal => {
251                *removed = true;
252                return Self::remove_node(root);
253            }
254        }
255
256        if *removed {
257            Some(Self::rebalance(root))
258        } else {
259            Some(root)
260        }
261    }
262
263    fn remove_node(mut node: Box<Node<T>>) -> Option<Box<Node<T>>> {
264        match (node.left.take(), node.right.take()) {
265            (None, None) => None,
266            (None, Some(right)) => Some(right),
267            (Some(left), None) => Some(left),
268            (Some(left), Some(right)) => {
269                node.left = Some(left);
270                let (successor_value, new_right) = Self::extract_min(right);
271                node.x = successor_value;
272                node.right = new_right;
273                Some(Self::rebalance(node))
274            }
275        }
276    }
277
278    // Extract the minimum value from a subtree and return (value, remaining_tree)
279    fn extract_min(mut node: Box<Node<T>>) -> (T, Option<Box<Node<T>>>) {
280        match node.left.take() {
281            None => (node.x, node.right.take()),
282            Some(left) => {
283                let (min_value, new_left) = Self::extract_min(left);
284                node.left = new_left;
285                (min_value, Some(Self::rebalance(node)))
286            }
287        }
288    }
289
290    /// x以下の最大の要素を返す
291    pub fn le(&self, x: &T) -> Option<&T> {
292        let mut current = &self.root;
293        let mut result = None;
294
295        while let Some(node) = current {
296            match x.cmp(&node.x) {
297                Ordering::Less => current = &node.left,
298                Ordering::Greater => {
299                    result = Some(&node.x);
300                    current = &node.right;
301                }
302                Ordering::Equal => return Some(&node.x),
303            }
304        }
305
306        result
307    }
308
309    /// x以上の最小の要素を返す
310    pub fn ge(&self, x: &T) -> Option<&T> {
311        let mut current = &self.root;
312        let mut result = None;
313
314        while let Some(node) = current {
315            match x.cmp(&node.x) {
316                Ordering::Less => {
317                    result = Some(&node.x);
318                    current = &node.left;
319                }
320                Ordering::Greater => current = &node.right,
321                Ordering::Equal => return Some(&node.x),
322            }
323        }
324
325        result
326    }
327
328    /// 0-indexedでn番目の要素を返す
329    pub fn nth(&self, n: usize) -> Option<&T> {
330        if n >= self.len() {
331            return None;
332        }
333
334        let mut current = &self.root;
335        let mut n = n;
336
337        while let Some(node) = current {
338            let left_size = Self::node_size(&node.left);
339            match n.cmp(&left_size) {
340                Ordering::Less => current = &node.left,
341                Ordering::Equal => return Some(&node.x),
342                Ordering::Greater => {
343                    n -= 1 + left_size;
344                    current = &node.right;
345                }
346            }
347        }
348
349        unreachable!()
350    }
351
352    /// xより小さい要素の個数を返す
353    /// 集合がxを含む場合Ok, xを含まない場合Err
354    pub fn position(&self, x: &T) -> Result<usize, usize> {
355        let mut current = &self.root;
356        let mut count = 0;
357        let mut hit = false;
358
359        while let Some(node) = current {
360            match x.cmp(&node.x) {
361                Ordering::Less => current = &node.left,
362                Ordering::Equal => {
363                    hit = true;
364                    current = &node.left;
365                }
366                Ordering::Greater => {
367                    count += 1 + Self::node_size(&node.left);
368                    current = &node.right;
369                }
370            }
371        }
372
373        if hit { Ok(count) } else { Err(count) }
374    }
375}
376
377impl<T> Default for AvlTree<T> {
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
383impl<T> fmt::Debug for AvlTree<T>
384where
385    T: fmt::Debug,
386{
387    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388        f.debug_list().entries(self.iter()).finish()
389    }
390}
391
392pub struct Iter<'a, T> {
393    stack: Vec<&'a Node<T>>,
394}
395
396impl<'a, T> Iter<'a, T> {
397    fn new(root: &'a Option<Box<Node<T>>>) -> Self {
398        let mut iter = Self { stack: Vec::new() };
399        iter.push_left_path(root);
400        iter
401    }
402
403    fn push_left_path(&mut self, mut node: &'a Option<Box<Node<T>>>) {
404        while let Some(n) = node {
405            self.stack.push(n);
406            node = &n.left;
407        }
408    }
409}
410
411impl<'a, T> Iterator for Iter<'a, T> {
412    type Item = &'a T;
413
414    fn next(&mut self) -> Option<Self::Item> {
415        let node = self.stack.pop()?;
416        let result = &node.x;
417        self.push_left_path(&node.right);
418        Some(result)
419    }
420}
421
422impl<T> AvlTree<T> {
423    pub fn iter(&self) -> Iter<'_, T> {
424        Iter::new(&self.root)
425    }
426}
427
#[cfg(test)]
mod tests {
    use crate::{AvlTree, Node};

    #[test]
    fn test_avl_insert() {
        let mut avl = AvlTree::default();
        assert!(avl.insert(42));
        assert!(!avl.insert(42));
        assert_eq!(avl.len(), 1);
    }

    #[test]
    fn test_avl_remove() {
        let mut avl = AvlTree::default();
        avl.insert(42);
        assert!(!avl.remove(&41));
        assert!(avl.remove(&42));
        assert!(!avl.remove(&42));
        assert!(avl.is_empty());
    }

    #[test]
    fn test_avl_contains() {
        let mut avl = AvlTree::default();
        avl.insert(42);
        assert!(avl.contains(&42));
        assert!(!avl.contains(&24));
    }

    #[test]
    fn test_avl_le() {
        let mut avl = AvlTree::default();
        avl.insert(42);
        assert_eq!(avl.le(&41), None);
        assert_eq!(avl.le(&42), Some(&42));
        assert_eq!(avl.le(&43), Some(&42));
    }

    #[test]
    fn test_avl_ge() {
        let mut avl = AvlTree::default();
        avl.insert(42);
        assert_eq!(avl.ge(&41), Some(&42));
        assert_eq!(avl.ge(&42), Some(&42));
        assert_eq!(avl.ge(&43), None);
    }

    #[test]
    fn test_avl_nth() {
        let mut avl = AvlTree::default();
        avl.insert(1);
        avl.insert(2);
        avl.insert(4);
        avl.insert(8);
        assert_eq!(avl.nth(0), Some(&1));
        assert_eq!(avl.nth(1), Some(&2));
        assert_eq!(avl.nth(2), Some(&4));
        assert_eq!(avl.nth(3), Some(&8));
        assert_eq!(avl.nth(4), None);
    }

    #[test]
    fn test_avl_position() {
        let mut avl = AvlTree::default();
        avl.insert(1);
        avl.insert(2);
        avl.insert(4);
        avl.insert(8);
        assert_eq!(avl.position(&0), Err(0));
        assert_eq!(avl.position(&1), Ok(0));
        assert_eq!(avl.position(&2), Ok(1));
        assert_eq!(avl.position(&3), Err(2));
        assert_eq!(avl.position(&4), Ok(2));
        assert_eq!(avl.position(&5), Err(3));
        assert_eq!(avl.position(&6), Err(3));
        assert_eq!(avl.position(&7), Err(3));
        assert_eq!(avl.position(&8), Ok(3));
        assert_eq!(avl.position(&9), Err(4));
    }

    #[test]
    fn test_avl_iter() {
        let mut avl = AvlTree::default();
        avl.insert(3);
        avl.insert(1);
        avl.insert(4);
        avl.insert(5);
        avl.insert(9);
        avl.insert(2);

        let values: Vec<_> = avl.iter().collect();
        assert_eq!(values, vec![&1, &2, &3, &4, &5, &9]);
    }

    #[test]
    fn test_avl_into_sorted_vec() {
        let mut avl = AvlTree::default();
        avl.insert(3);
        avl.insert(1);
        avl.insert(4);
        avl.insert(5);
        avl.insert(9);
        avl.insert(2);

        assert_eq!(avl.into_sorted_vec(), vec![1, 2, 3, 4, 5, 9]);
    }

    #[test]
    fn test_avl_balance() {
        // Walks the whole tree, checking that every node satisfies the AVL
        // balance condition and that its cached height and size agree with
        // its children.
        fn assert_all<T>(node: &Option<Box<Node<T>>>) {
            if let Some(node) = node {
                assert_all(&node.left);
                assert!(AvlTree::balance_factor(node).abs() <= 1);
                let expected_height =
                    1 + AvlTree::node_height(&node.left).max(AvlTree::node_height(&node.right));
                assert_eq!(node.height, expected_height);
                let expected_size =
                    1 + AvlTree::node_size(&node.left) + AvlTree::node_size(&node.right);
                assert_eq!(node.size, expected_size);
                assert_all(&node.right);
            }
        }

        let mut avl = AvlTree::default();
        for x in 0..1000 {
            avl.insert(x);
            assert_all(&avl.root);
        }

        // Removing every other element exercises rebalancing on removal,
        // including the two-children case of node deletion.
        for x in (0..1000).step_by(2) {
            assert!(avl.remove(&x));
            assert_all(&avl.root);
        }
        assert_eq!(avl.len(), 500);
        assert_eq!(AvlTree::node_size(&avl.root), avl.len());
        assert!(avl.iter().copied().eq((1..1000).step_by(2)));
    }
}