1use std::{
2 cmp::{self, Ordering},
3 fmt,
4};
5
/// A single node of the AVL tree.
///
/// Besides its value and children, every node caches the height and the
/// number of nodes of the subtree rooted at it, so balancing and rank
/// queries can read those numbers instead of walking the subtree.
#[derive(Clone)]
struct Node<T> {
    /// The value stored in this node.
    x: T,
    /// Subtree of values ordered before `x`.
    left: Option<Box<Node<T>>>,
    /// Subtree of values ordered after `x`.
    right: Option<Box<Node<T>>>,
    /// Height of the subtree rooted here; a leaf has height 1.
    height: i32,
    /// Number of nodes in the subtree rooted here, this node included.
    size: usize,
}
14
15#[derive(Clone)]
16pub struct AvlTree<T> {
17 n: usize,
18 root: Option<Box<Node<T>>>,
19}
20
21impl<T> AvlTree<T> {
22 pub fn new() -> Self {
23 Self { n: 0, root: None }
24 }
25
26 pub fn len(&self) -> usize {
27 self.n
28 }
29
30 pub fn is_empty(&self) -> bool {
31 self.n == 0
32 }
33
34 fn new_node(x: T) -> Box<Node<T>> {
35 Box::new(Node {
36 x,
37 height: 1,
38 left: None,
39 right: None,
40 size: 1,
41 })
42 }
43
44 fn node_height(node: &Option<Box<Node<T>>>) -> i32 {
45 node.as_ref().map_or(0, |n| n.height)
46 }
47
48 fn node_size(node: &Option<Box<Node<T>>>) -> usize {
49 node.as_ref().map_or(0, |n| n.size)
50 }
51
52 fn balance_factor(node: &Node<T>) -> i32 {
53 Self::node_height(&node.left) - Self::node_height(&node.right)
54 }
55
56 fn update_height_and_size(node: &mut Node<T>) {
57 node.height = 1 + Self::node_height(&node.left).max(Self::node_height(&node.right));
58 node.size = 1 + Self::node_size(&node.left) + Self::node_size(&node.right);
59 }
60
61 fn rotate_right(mut root: Box<Node<T>>) -> Box<Node<T>> {
62 let mut left = root.left.take().unwrap();
72 let b = left.right.take();
73
74 root.left = b;
75 Self::update_height_and_size(&mut root);
76
77 left.right = Some(root);
78 Self::update_height_and_size(&mut left);
79
80 left
81 }
82
83 fn rotate_left(mut root: Box<Node<T>>) -> Box<Node<T>> {
84 let mut right = root.right.take().unwrap();
94 let b = right.left.take();
95
96 root.right = b;
97 Self::update_height_and_size(&mut root);
98
99 right.left = Some(root);
100 Self::update_height_and_size(&mut right);
101
102 right
103 }
104
105 fn rebalance(mut node: Box<Node<T>>) -> Box<Node<T>> {
106 Self::update_height_and_size(&mut node);
107
108 let balance = Self::balance_factor(&node);
109
110 if balance > 1 {
112 if let Some(left) = node.left.take() {
114 if Self::balance_factor(&left) < 0 {
115 node.left = Some(Self::rotate_left(left));
117 } else {
118 node.left = Some(left);
120 }
121 }
122 return Self::rotate_right(node);
123 }
124
125 if balance < -1 {
127 if let Some(right) = node.right.take() {
129 if Self::balance_factor(&right) > 0 {
130 node.right = Some(Self::rotate_right(right));
132 } else {
133 node.right = Some(right);
135 }
136 }
137 return Self::rotate_left(node);
138 }
139
140 node
141 }
142
143 pub fn into_sorted_vec(mut self) -> Vec<T> {
144 fn collect<T>(node: Option<Box<Node<T>>>, acc: &mut Vec<T>) {
145 if let Some(node) = node {
146 collect(node.left, acc);
147 acc.push(node.x);
148 collect(node.right, acc);
149 }
150 }
151
152 let mut result = Vec::with_capacity(self.n);
153 collect(self.root.take(), &mut result);
154 self.n = 0;
155 result
156 }
157}
158
159impl<T> AvlTree<T>
160where
161 T: cmp::Ord,
162{
163 fn find_last(&self, x: &T) -> Option<&Node<T>> {
164 let mut current = &self.root;
165 let mut last = Option::<&Node<T>>::None;
166
167 while let Some(node) = current {
168 last = Some(node);
169 match x.cmp(&node.x) {
170 Ordering::Less => current = &node.left,
171 Ordering::Greater => current = &node.right,
172 Ordering::Equal => return Some(node),
173 }
174 }
175
176 last
177 }
178
179 pub fn contains(&self, x: &T) -> bool {
181 self.find_last(x).is_some_and(|node| x.eq(&node.x))
182 }
183
184 pub fn insert(&mut self, x: T) -> bool {
186 let root = self.root.take();
187 let mut inserted = false;
188 self.root = Self::insert_recursive(root, x, &mut inserted);
189 if inserted {
190 self.n += 1;
191 }
192 inserted
193 }
194
195 fn insert_recursive(
196 root: Option<Box<Node<T>>>,
197 x: T,
198 inserted: &mut bool,
199 ) -> Option<Box<Node<T>>> {
200 let mut root = match root {
201 Some(root) => root,
202 None => {
203 *inserted = true;
204 return Some(Self::new_node(x));
205 }
206 };
207
208 match x.cmp(&root.x) {
209 Ordering::Less => {
210 root.left = Self::insert_recursive(root.left.take(), x, inserted);
211 }
212 Ordering::Greater => {
213 root.right = Self::insert_recursive(root.right.take(), x, inserted);
214 }
215 Ordering::Equal => return Some(root),
216 }
217
218 if *inserted {
219 Some(Self::rebalance(root))
220 } else {
221 Some(root)
222 }
223 }
224
225 pub fn remove(&mut self, x: &T) -> bool {
227 let root = self.root.take();
228 let mut removed = false;
229 self.root = Self::remove_recursive(root, x, &mut removed);
230 if removed {
231 self.n -= 1;
232 }
233 removed
234 }
235
236 fn remove_recursive(
237 root: Option<Box<Node<T>>>,
238 x: &T,
239 removed: &mut bool,
240 ) -> Option<Box<Node<T>>> {
241 let mut root = root?;
242
243 match x.cmp(&root.x) {
244 Ordering::Less => {
245 root.left = Self::remove_recursive(root.left.take(), x, removed);
246 }
247 Ordering::Greater => {
248 root.right = Self::remove_recursive(root.right.take(), x, removed);
249 }
250 Ordering::Equal => {
251 *removed = true;
252 return Self::remove_node(root);
253 }
254 }
255
256 if *removed {
257 Some(Self::rebalance(root))
258 } else {
259 Some(root)
260 }
261 }
262
263 fn remove_node(mut node: Box<Node<T>>) -> Option<Box<Node<T>>> {
264 match (node.left.take(), node.right.take()) {
265 (None, None) => None,
266 (None, Some(right)) => Some(right),
267 (Some(left), None) => Some(left),
268 (Some(left), Some(right)) => {
269 node.left = Some(left);
270 let (successor_value, new_right) = Self::extract_min(right);
271 node.x = successor_value;
272 node.right = new_right;
273 Some(Self::rebalance(node))
274 }
275 }
276 }
277
278 fn extract_min(mut node: Box<Node<T>>) -> (T, Option<Box<Node<T>>>) {
280 match node.left.take() {
281 None => (node.x, node.right.take()),
282 Some(left) => {
283 let (min_value, new_left) = Self::extract_min(left);
284 node.left = new_left;
285 (min_value, Some(Self::rebalance(node)))
286 }
287 }
288 }
289
290 pub fn le(&self, x: &T) -> Option<&T> {
292 let mut current = &self.root;
293 let mut result = None;
294
295 while let Some(node) = current {
296 match x.cmp(&node.x) {
297 Ordering::Less => current = &node.left,
298 Ordering::Greater => {
299 result = Some(&node.x);
300 current = &node.right;
301 }
302 Ordering::Equal => return Some(&node.x),
303 }
304 }
305
306 result
307 }
308
309 pub fn ge(&self, x: &T) -> Option<&T> {
311 let mut current = &self.root;
312 let mut result = None;
313
314 while let Some(node) = current {
315 match x.cmp(&node.x) {
316 Ordering::Less => {
317 result = Some(&node.x);
318 current = &node.left;
319 }
320 Ordering::Greater => current = &node.right,
321 Ordering::Equal => return Some(&node.x),
322 }
323 }
324
325 result
326 }
327
328 pub fn nth(&self, n: usize) -> Option<&T> {
330 if n >= self.len() {
331 return None;
332 }
333
334 let mut current = &self.root;
335 let mut n = n;
336
337 while let Some(node) = current {
338 let left_size = Self::node_size(&node.left);
339 match n.cmp(&left_size) {
340 Ordering::Less => current = &node.left,
341 Ordering::Equal => return Some(&node.x),
342 Ordering::Greater => {
343 n -= 1 + left_size;
344 current = &node.right;
345 }
346 }
347 }
348
349 unreachable!()
350 }
351
352 pub fn position(&self, x: &T) -> Result<usize, usize> {
355 let mut current = &self.root;
356 let mut count = 0;
357 let mut hit = false;
358
359 while let Some(node) = current {
360 match x.cmp(&node.x) {
361 Ordering::Less => current = &node.left,
362 Ordering::Equal => {
363 hit = true;
364 current = &node.left;
365 }
366 Ordering::Greater => {
367 count += 1 + Self::node_size(&node.left);
368 current = &node.right;
369 }
370 }
371 }
372
373 if hit { Ok(count) } else { Err(count) }
374 }
375}
376
377impl<T> Default for AvlTree<T> {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
383impl<T> fmt::Debug for AvlTree<T>
384where
385 T: fmt::Debug,
386{
387 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388 f.debug_list().entries(self.iter()).finish()
389 }
390}
391
392pub struct Iter<'a, T> {
393 stack: Vec<&'a Node<T>>,
394}
395
396impl<'a, T> Iter<'a, T> {
397 fn new(root: &'a Option<Box<Node<T>>>) -> Self {
398 let mut iter = Self { stack: Vec::new() };
399 iter.push_left_path(root);
400 iter
401 }
402
403 fn push_left_path(&mut self, mut node: &'a Option<Box<Node<T>>>) {
404 while let Some(n) = node {
405 self.stack.push(n);
406 node = &n.left;
407 }
408 }
409}
410
411impl<'a, T> Iterator for Iter<'a, T> {
412 type Item = &'a T;
413
414 fn next(&mut self) -> Option<Self::Item> {
415 let node = self.stack.pop()?;
416 let result = &node.x;
417 self.push_left_path(&node.right);
418 Some(result)
419 }
420}
421
422impl<T> AvlTree<T> {
423 pub fn iter(&self) -> Iter<'_, T> {
424 Iter::new(&self.root)
425 }
426}
427
#[cfg(test)]
mod tests {
    // `super` names the module that defines the tree whether this file is the
    // crate root or a submodule; a `crate::` path only works in the former.
    use super::{AvlTree, Node};

    /// Recursively verifies the subtree rooted at `node` and returns its
    /// actual `(height, size)`.
    ///
    /// Panics if a cached `height` or `size` is stale, or if the two child
    /// heights differ by more than one.
    fn check_invariants<T>(node: &Option<Box<Node<T>>>) -> (i32, usize) {
        let Some(node) = node else {
            return (0, 0);
        };
        let (left_height, left_size) = check_invariants(&node.left);
        let (right_height, right_size) = check_invariants(&node.right);
        let height = 1 + left_height.max(right_height);
        let size = 1 + left_size + right_size;
        assert_eq!(node.height, height, "cached height is stale");
        assert_eq!(node.size, size, "cached size is stale");
        assert!(
            (left_height - right_height).abs() <= 1,
            "node violates the AVL balance invariant"
        );
        (height, size)
    }

    #[test]
    fn test_avl_insert() {
        let mut avl = AvlTree::default();
        assert!(avl.insert(42));
        assert!(!avl.insert(42));
        assert_eq!(avl.len(), 1);
    }

    #[test]
    fn test_avl_remove() {
        let mut avl = AvlTree::default();
        avl.insert(42);
        assert!(!avl.remove(&41));
        assert!(avl.remove(&42));
        assert!(!avl.remove(&42));
        assert!(avl.is_empty());
    }

    #[test]
    fn test_avl_contains() {
        let mut avl = AvlTree::default();
        avl.insert(42);
        assert!(avl.contains(&42));
        assert!(!avl.contains(&24));
    }

    #[test]
    fn test_avl_le() {
        let mut avl = AvlTree::default();
        avl.insert(42);
        assert_eq!(avl.le(&41), None);
        assert_eq!(avl.le(&42), Some(&42));
        assert_eq!(avl.le(&43), Some(&42));
    }

    #[test]
    fn test_avl_ge() {
        let mut avl = AvlTree::default();
        avl.insert(42);
        assert_eq!(avl.ge(&41), Some(&42));
        assert_eq!(avl.ge(&42), Some(&42));
        assert_eq!(avl.ge(&43), None);
    }

    #[test]
    fn test_avl_nth() {
        let mut avl = AvlTree::default();
        avl.insert(1);
        avl.insert(2);
        avl.insert(4);
        avl.insert(8);
        assert_eq!(avl.nth(0), Some(&1));
        assert_eq!(avl.nth(1), Some(&2));
        assert_eq!(avl.nth(2), Some(&4));
        assert_eq!(avl.nth(3), Some(&8));
        assert_eq!(avl.nth(4), None);
    }

    #[test]
    fn test_avl_position() {
        let mut avl = AvlTree::default();
        avl.insert(1);
        avl.insert(2);
        avl.insert(4);
        avl.insert(8);
        assert_eq!(avl.position(&0), Err(0));
        assert_eq!(avl.position(&1), Ok(0));
        assert_eq!(avl.position(&2), Ok(1));
        assert_eq!(avl.position(&3), Err(2));
        assert_eq!(avl.position(&4), Ok(2));
        assert_eq!(avl.position(&5), Err(3));
        assert_eq!(avl.position(&6), Err(3));
        assert_eq!(avl.position(&7), Err(3));
        assert_eq!(avl.position(&8), Ok(3));
        assert_eq!(avl.position(&9), Err(4));
    }

    #[test]
    fn test_avl_iter() {
        let mut avl = AvlTree::default();
        avl.insert(3);
        avl.insert(1);
        avl.insert(4);
        avl.insert(5);
        avl.insert(9);
        avl.insert(2);

        let values: Vec<_> = avl.iter().collect();
        assert_eq!(values, vec![&1, &2, &3, &4, &5, &9]);
    }

    #[test]
    fn test_avl_into_sorted_vec() {
        let mut avl = AvlTree::default();
        avl.insert(3);
        avl.insert(1);
        avl.insert(4);
        avl.insert(5);
        avl.insert(9);
        avl.insert(2);

        assert_eq!(avl.into_sorted_vec(), vec![1, 2, 3, 4, 5, 9]);
    }

    #[test]
    fn test_avl_debug() {
        let mut avl = AvlTree::<i32>::new();
        assert_eq!(format!("{avl:?}"), "[]");
        avl.insert(2);
        avl.insert(1);
        avl.insert(3);
        assert_eq!(format!("{avl:?}"), "[1, 2, 3]");
    }

    #[test]
    fn test_avl_balance() {
        let mut avl = AvlTree::default();
        for x in 0..1000 {
            avl.insert(x);
            let (_, size) = check_invariants(&avl.root);
            assert_eq!(size, avl.len());
        }
    }

    #[test]
    fn test_avl_invariants_after_removals() {
        let mut avl = AvlTree::new();
        // 37 is coprime with 200, so this inserts a permutation of 0..200.
        for i in 0..200 {
            assert!(avl.insert((i * 37) % 200));
        }
        assert_eq!(avl.len(), 200);

        for x in (0..200).step_by(3) {
            assert!(avl.remove(&x));
            let (_, size) = check_invariants(&avl.root);
            assert_eq!(size, avl.len());
        }

        let values: Vec<i32> = avl.iter().copied().collect();
        assert_eq!(values.len(), avl.len());
        assert!(values.windows(2).all(|pair| pair[0] < pair[1]));
        assert!(values.iter().all(|x| x % 3 != 0));
        for (rank, value) in values.iter().enumerate() {
            assert_eq!(avl.nth(rank), Some(value));
            assert_eq!(avl.position(value), Ok(rank));
        }
    }
}