treap/
lib.rs

1use std::{
2    cmp::{self, Ordering},
3    fmt,
4    marker::PhantomData,
5};
6
7use rand::{RngCore, SeedableRng, rngs::StdRng};
8
/// A single treap node: a binary-search-tree node that additionally carries a
/// heap priority and the size of the subtree rooted at it.
struct Node<T> {
    /// The element stored in this node.
    x: T,
    /// Heap priority. Insertion and removal rotate nodes so that a child never
    /// has a larger priority than its parent.
    priority: u64,
    /// Left subtree; the search routines descend here for smaller elements.
    left: Option<Box<Node<T>>>,
    /// Right subtree; the search routines descend here for larger elements.
    right: Option<Box<Node<T>>>,
    /// Number of nodes in the subtree rooted here, this node included.
    size: usize,
}
16
17pub struct Treap<T, R> {
18    n: usize,
19    root: Option<Box<Node<T>>>,
20    rng: R,
21}
22
23impl<T, R> Treap<T, R> {
24    pub fn new(rng: R) -> Self {
25        Self {
26            n: 0,
27            root: None,
28            rng,
29        }
30    }
31
32    pub fn len(&self) -> usize {
33        self.n
34    }
35
36    pub fn is_empty(&self) -> bool {
37        self.n == 0
38    }
39
40    fn new_node(x: T, priority: u64) -> Box<Node<T>> {
41        Box::new(Node {
42            x,
43            priority,
44            left: None,
45            right: None,
46            size: 1,
47        })
48    }
49
50    fn rotate_right(mut root: Box<Node<T>>) -> Box<Node<T>> {
51        //         root                    left
52        //         |                       |
53        //     +---+---+               +---+---+
54        //     |       |               |       |
55        //    left     c       ->      a      root
56        //     |                              |
57        // +---+---+                      +---+---+
58        // |       |                      |       |
59        // a       b                      b       c
60        let mut left = root.left.take().unwrap();
61        let b = left.right.take();
62        root.left = b;
63
64        root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
65        left.size = 1 + Self::node_size(&left.left) + root.size;
66
67        left.right = Some(root);
68        left
69    }
70
71    fn rotate_left(mut root: Box<Node<T>>) -> Box<Node<T>> {
72        //      root                        right
73        //      |                           |
74        //  +---+---+                   +---+---+
75        //  |       |                   |       |
76        //  a      right        ->     root      c
77        //          |                   |
78        //      +---+---+           +---+---+
79        //      |       |           |       |
80        //      b       c           a       b
81        let mut right = root.right.take().unwrap();
82        let b = right.left.take();
83        root.right = b;
84
85        root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
86        right.size = 1 + root.size + Self::node_size(&right.right);
87
88        right.left = Some(root);
89        right
90    }
91
92    fn node_size(node: &Option<Box<Node<T>>>) -> usize {
93        node.as_ref().map_or(0, |n| n.size)
94    }
95
96    pub fn into_sorted_vec(mut self) -> Vec<T> {
97        fn collect<T>(node: Option<Box<Node<T>>>, acc: &mut Vec<T>) {
98            if let Some(node) = node {
99                collect(node.left, acc);
100                acc.push(node.x);
101                collect(node.right, acc);
102            }
103        }
104
105        let mut result = Vec::with_capacity(self.n);
106        collect(self.root.take(), &mut result);
107        self.n = 0;
108        result
109    }
110}
111
112impl<T, R> Treap<T, R>
113where
114    R: RngCore,
115{
116    fn gen_priority(&mut self) -> u64 {
117        self.rng.next_u64()
118    }
119}
120
121impl<T, R> Treap<T, R>
122where
123    T: cmp::Ord,
124{
125    fn find_last(&self, x: &T) -> Option<&Node<T>> {
126        let mut current = &self.root;
127        let mut last = Option::<&Node<T>>::None;
128
129        while let Some(node) = current {
130            last = Some(node);
131            match x.cmp(&node.x) {
132                Ordering::Less => current = &node.left,
133                Ordering::Greater => current = &node.right,
134                Ordering::Equal => return Some(node),
135            }
136        }
137
138        last
139    }
140
141    /// 集合にxが含まれるかを返す。
142    pub fn contains(&self, x: &T) -> bool {
143        self.find_last(x).is_some_and(|node| x.eq(&node.x))
144    }
145
146    /// xを削除する。集合にxが含まれていた場合trueを返す。
147    pub fn remove(&mut self, x: &T) -> bool {
148        let root = self.root.take();
149        let mut removed = false;
150        self.root = Self::remove_recursive(root, x, &mut removed);
151        if removed {
152            self.n -= 1;
153        }
154        removed
155    }
156
157    fn remove_recursive(
158        root: Option<Box<Node<T>>>,
159        x: &T,
160        removed: &mut bool,
161    ) -> Option<Box<Node<T>>> {
162        let mut root = root?;
163
164        match x.cmp(&root.x) {
165            Ordering::Less => {
166                root.left = Self::remove_recursive(root.left.take(), x, removed);
167                if *removed {
168                    root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
169                }
170                Some(root)
171            }
172            Ordering::Greater => {
173                root.right = Self::remove_recursive(root.right.take(), x, removed);
174                if *removed {
175                    root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
176                }
177                Some(root)
178            }
179            Ordering::Equal => {
180                *removed = true;
181                Self::remove_node(root)
182            }
183        }
184    }
185
186    fn remove_node(mut node: Box<Node<T>>) -> Option<Box<Node<T>>> {
187        match (&node.left, &node.right) {
188            (None, None) => None,
189            (None, Some(_)) => node.right.take(),
190            (Some(_), None) => node.left.take(),
191            (Some(left), Some(right)) => {
192                if left.priority > right.priority {
193                    let mut new_root = Self::rotate_right(node);
194                    new_root.right = Self::remove_node(new_root.right.take().unwrap());
195                    new_root.size =
196                        1 + Self::node_size(&new_root.left) + Self::node_size(&new_root.right);
197                    Some(new_root)
198                } else {
199                    let mut new_root = Self::rotate_left(node);
200                    new_root.left = Self::remove_node(new_root.left.take().unwrap());
201                    new_root.size =
202                        1 + Self::node_size(&new_root.left) + Self::node_size(&new_root.right);
203                    Some(new_root)
204                }
205            }
206        }
207    }
208
209    /// x以下の最大の要素を返す
210    pub fn le(&self, x: &T) -> Option<&T> {
211        let mut current = &self.root;
212        let mut result = None;
213
214        while let Some(node) = current {
215            match x.cmp(&node.x) {
216                Ordering::Less => current = &node.left,
217                Ordering::Greater => {
218                    result = Some(&node.x);
219                    current = &node.right;
220                }
221                Ordering::Equal => return Some(&node.x),
222            }
223        }
224
225        result
226    }
227
228    /// x以上の最小の要素を返す
229    pub fn ge(&self, x: &T) -> Option<&T> {
230        let mut current = &self.root;
231        let mut result = None;
232
233        while let Some(node) = current {
234            match x.cmp(&node.x) {
235                Ordering::Less => {
236                    result = Some(&node.x);
237                    current = &node.left;
238                }
239                Ordering::Greater => current = &node.right,
240                Ordering::Equal => return Some(&node.x),
241            }
242        }
243
244        result
245    }
246
247    /// 0-indexedでn番目の要素を返す
248    pub fn nth(&self, n: usize) -> Option<&T> {
249        if n >= self.len() {
250            return None;
251        }
252
253        let mut current = &self.root;
254        let mut n = n;
255
256        while let Some(node) = current {
257            let left_size = Self::node_size(&node.left);
258            match n.cmp(&left_size) {
259                Ordering::Less => current = &node.left,
260                Ordering::Equal => return Some(&node.x),
261                Ordering::Greater => {
262                    n -= 1 + left_size;
263                    current = &node.right;
264                }
265            }
266        }
267
268        unreachable!()
269    }
270
271    /// xより小さい要素の個数を返す
272    /// 集合がxを含む場合Ok, xを含まない場合Err
273    pub fn position(&self, x: &T) -> Result<usize, usize> {
274        let mut current = &self.root;
275        let mut count = 0;
276        let mut hit = false;
277
278        while let Some(node) = current {
279            match x.cmp(&node.x) {
280                Ordering::Less => current = &node.left,
281                Ordering::Equal => {
282                    hit = true;
283                    current = &node.left;
284                }
285                Ordering::Greater => {
286                    count += 1 + Self::node_size(&node.left);
287                    current = &node.right;
288                }
289            }
290        }
291
292        if hit { Ok(count) } else { Err(count) }
293    }
294}
295
296impl<T, R> Treap<T, R>
297where
298    T: cmp::Ord,
299    R: RngCore,
300{
301    /// xを追加する。集合にxが含まれていなかった場合trueを返す。
302    pub fn insert(&mut self, x: T) -> bool {
303        let root = self.root.take();
304        let mut inserted = false;
305        self.root = self.insert_recursive(root, x, &mut inserted);
306        if inserted {
307            self.n += 1;
308        }
309        inserted
310    }
311
312    fn insert_recursive(
313        &mut self,
314        root: Option<Box<Node<T>>>,
315        x: T,
316        inserted: &mut bool,
317    ) -> Option<Box<Node<T>>> {
318        let mut root = match root {
319            Some(root) => root,
320            None => {
321                *inserted = true;
322                return Some(Self::new_node(x, self.gen_priority()));
323            }
324        };
325
326        match x.cmp(&root.x) {
327            Ordering::Less => {
328                root.left = self.insert_recursive(root.left.take(), x, inserted);
329                if *inserted {
330                    root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
331
332                    if let Some(left) = &root.left
333                        && left.priority > root.priority
334                    {
335                        return Some(Self::rotate_right(root));
336                    }
337                }
338                Some(root)
339            }
340            Ordering::Greater => {
341                root.right = self.insert_recursive(root.right.take(), x, inserted);
342                if *inserted {
343                    root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
344
345                    if let Some(right) = &root.right
346                        && right.priority > root.priority
347                    {
348                        return Some(Self::rotate_left(root));
349                    }
350                }
351                Some(root)
352            }
353            Ordering::Equal => Some(root),
354        }
355    }
356}
357
358impl<T> Default for Treap<T, StdRng> {
359    fn default() -> Self {
360        Self::new(StdRng::seed_from_u64(12233344455555))
361    }
362}
363
364impl<T, R> fmt::Debug for Treap<T, R>
365where
366    T: fmt::Debug,
367{
368    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369        f.debug_list().entries(self.iter()).finish()
370    }
371}
372
373pub struct Iter<'a, T> {
374    stack: Vec<&'a Node<T>>,
375    _phantom: PhantomData<&'a T>,
376}
377
378impl<'a, T> Iter<'a, T> {
379    fn new(root: &'a Option<Box<Node<T>>>) -> Self {
380        let mut iter = Self {
381            stack: Vec::new(),
382            _phantom: PhantomData,
383        };
384        iter.push_left_path(root);
385        iter
386    }
387
388    fn push_left_path(&mut self, mut node: &'a Option<Box<Node<T>>>) {
389        while let Some(n) = node {
390            self.stack.push(n);
391            node = &n.left;
392        }
393    }
394}
395
396impl<'a, T> Iterator for Iter<'a, T> {
397    type Item = &'a T;
398
399    fn next(&mut self) -> Option<Self::Item> {
400        let node = self.stack.pop()?;
401        let result = &node.x;
402        self.push_left_path(&node.right);
403        Some(result)
404    }
405}
406
407impl<T, R> Treap<T, R> {
408    pub fn iter(&self) -> Iter<'_, T> {
409        Iter::new(&self.root)
410    }
411}
412
#[cfg(test)]
mod tests {
    use crate::Treap;

    // Inserting a new element reports `true`; a duplicate reports `false`
    // and does not change the length.
    #[test]
    fn test_treap_insert() {
        let mut treap = Treap::default();
        assert!(treap.insert(42));
        assert!(!treap.insert(42));
        assert_eq!(treap.len(), 1);
    }

    // Removing reports whether the element was present.
    #[test]
    fn test_treap_remove() {
        let mut treap = Treap::default();
        treap.insert(42);
        assert!(!treap.remove(&41));
        assert!(treap.remove(&42));
        assert!(!treap.remove(&42));
        assert!(treap.is_empty());
    }

    // Removing many elements keeps ordering, length and rank queries consistent.
    #[test]
    fn test_treap_remove_many() {
        let mut treap = Treap::default();
        for x in 1..=20 {
            assert!(treap.insert(x));
        }
        for x in (2..=20).step_by(2) {
            assert!(treap.remove(&x));
        }
        assert_eq!(treap.len(), 10);

        let values: Vec<i32> = treap.iter().copied().collect();
        let expected: Vec<i32> = (1..=20).step_by(2).collect();
        assert_eq!(values, expected);
        assert_eq!(treap.nth(9), Some(&19));
        assert_eq!(treap.position(&10), Err(5));
    }

    #[test]
    fn test_treap_contains() {
        let mut treap = Treap::default();
        treap.insert(42);
        assert!(treap.contains(&42));
        assert!(!treap.contains(&24));
    }

    // `le` returns the largest element not exceeding the query.
    #[test]
    fn test_treap_le() {
        let mut treap = Treap::default();
        treap.insert(42);
        assert_eq!(treap.le(&41), None);
        assert_eq!(treap.le(&42), Some(&42));
        assert_eq!(treap.le(&43), Some(&42));
    }

    // `ge` returns the smallest element not below the query.
    #[test]
    fn test_treap_ge() {
        let mut treap = Treap::default();
        treap.insert(42);
        assert_eq!(treap.ge(&41), Some(&42));
        assert_eq!(treap.ge(&42), Some(&42));
        assert_eq!(treap.ge(&43), None);
    }

    #[test]
    fn test_treap_nth() {
        let mut treap = Treap::default();
        treap.insert(1);
        treap.insert(2);
        treap.insert(4);
        treap.insert(8);
        assert_eq!(treap.nth(0), Some(&1));
        assert_eq!(treap.nth(1), Some(&2));
        assert_eq!(treap.nth(2), Some(&4));
        assert_eq!(treap.nth(3), Some(&8));
        assert_eq!(treap.nth(4), None);
    }

    // `position` counts the smaller elements: `Ok` on a hit, `Err` on a miss.
    #[test]
    fn test_treap_position() {
        let mut treap = Treap::default();
        treap.insert(1);
        treap.insert(2);
        treap.insert(4);
        treap.insert(8);
        assert_eq!(treap.position(&0), Err(0));
        assert_eq!(treap.position(&1), Ok(0));
        assert_eq!(treap.position(&2), Ok(1));
        assert_eq!(treap.position(&3), Err(2));
        assert_eq!(treap.position(&4), Ok(2));
        assert_eq!(treap.position(&5), Err(3));
        assert_eq!(treap.position(&6), Err(3));
        assert_eq!(treap.position(&7), Err(3));
        assert_eq!(treap.position(&8), Ok(3));
        assert_eq!(treap.position(&9), Err(4));
    }

    #[test]
    fn test_treap_iter() {
        let mut treap = Treap::default();
        treap.insert(3);
        treap.insert(1);
        treap.insert(4);
        treap.insert(5);
        treap.insert(9);
        treap.insert(2);

        let values: Vec<_> = treap.iter().collect();
        assert_eq!(values, vec![&1, &2, &3, &4, &5, &9]);
    }

    #[test]
    fn test_treap_into_sorted_vec() {
        let mut treap = Treap::default();
        treap.insert(3);
        treap.insert(1);
        treap.insert(4);
        treap.insert(5);
        treap.insert(9);
        treap.insert(2);

        assert_eq!(treap.len(), 6);
        assert_eq!(treap.into_sorted_vec(), vec![1, 2, 3, 4, 5, 9]);
    }
}