1use std::{
2 cmp::{self, Ordering},
3 fmt,
4 marker::PhantomData,
5};
6
7use rand::{RngCore, SeedableRng, rngs::StdRng};
8
/// One treap node: a stored element, the heap priority that decides its depth,
/// its two subtrees, and the number of elements in the subtree rooted here.
struct Node<T> {
    /// Element count of this subtree, this node included; refreshed whenever
    /// the subtree's shape changes.
    size: usize,
    /// Heap key; after insertions and removals a parent's priority is never
    /// lower than a child's.
    priority: u64,
    /// The stored element.
    x: T,
    /// Subtree holding the elements ordered before `x`.
    left: Option<Box<Node<T>>>,
    /// Subtree holding the elements ordered after `x`.
    right: Option<Box<Node<T>>>,
}
16
17pub struct Treap<T, R> {
18 n: usize,
19 root: Option<Box<Node<T>>>,
20 rng: R,
21}
22
23impl<T, R> Treap<T, R> {
24 pub fn new(rng: R) -> Self {
25 Self {
26 n: 0,
27 root: None,
28 rng,
29 }
30 }
31
32 pub fn len(&self) -> usize {
33 self.n
34 }
35
36 pub fn is_empty(&self) -> bool {
37 self.n == 0
38 }
39
40 fn new_node(x: T, priority: u64) -> Box<Node<T>> {
41 Box::new(Node {
42 x,
43 priority,
44 left: None,
45 right: None,
46 size: 1,
47 })
48 }
49
50 fn rotate_right(mut root: Box<Node<T>>) -> Box<Node<T>> {
51 let mut left = root.left.take().unwrap();
61 let b = left.right.take();
62 root.left = b;
63
64 root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
65 left.size = 1 + Self::node_size(&left.left) + root.size;
66
67 left.right = Some(root);
68 left
69 }
70
71 fn rotate_left(mut root: Box<Node<T>>) -> Box<Node<T>> {
72 let mut right = root.right.take().unwrap();
82 let b = right.left.take();
83 root.right = b;
84
85 root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
86 right.size = 1 + root.size + Self::node_size(&right.right);
87
88 right.left = Some(root);
89 right
90 }
91
92 fn node_size(node: &Option<Box<Node<T>>>) -> usize {
93 node.as_ref().map_or(0, |n| n.size)
94 }
95
96 pub fn into_sorted_vec(mut self) -> Vec<T> {
97 fn collect<T>(node: Option<Box<Node<T>>>, acc: &mut Vec<T>) {
98 if let Some(node) = node {
99 collect(node.left, acc);
100 acc.push(node.x);
101 collect(node.right, acc);
102 }
103 }
104
105 let mut result = Vec::with_capacity(self.n);
106 collect(self.root.take(), &mut result);
107 self.n = 0;
108 result
109 }
110}
111
112impl<T, R> Treap<T, R>
113where
114 R: RngCore,
115{
116 fn gen_priority(&mut self) -> u64 {
117 self.rng.next_u64()
118 }
119}
120
121impl<T, R> Treap<T, R>
122where
123 T: cmp::Ord,
124{
125 fn find_last(&self, x: &T) -> Option<&Node<T>> {
126 let mut current = &self.root;
127 let mut last = Option::<&Node<T>>::None;
128
129 while let Some(node) = current {
130 last = Some(node);
131 match x.cmp(&node.x) {
132 Ordering::Less => current = &node.left,
133 Ordering::Greater => current = &node.right,
134 Ordering::Equal => return Some(node),
135 }
136 }
137
138 last
139 }
140
141 pub fn contains(&self, x: &T) -> bool {
143 self.find_last(x).is_some_and(|node| x.eq(&node.x))
144 }
145
146 pub fn remove(&mut self, x: &T) -> bool {
148 let root = self.root.take();
149 let mut removed = false;
150 self.root = Self::remove_recursive(root, x, &mut removed);
151 if removed {
152 self.n -= 1;
153 }
154 removed
155 }
156
157 fn remove_recursive(
158 root: Option<Box<Node<T>>>,
159 x: &T,
160 removed: &mut bool,
161 ) -> Option<Box<Node<T>>> {
162 let mut root = root?;
163
164 match x.cmp(&root.x) {
165 Ordering::Less => {
166 root.left = Self::remove_recursive(root.left.take(), x, removed);
167 if *removed {
168 root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
169 }
170 Some(root)
171 }
172 Ordering::Greater => {
173 root.right = Self::remove_recursive(root.right.take(), x, removed);
174 if *removed {
175 root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
176 }
177 Some(root)
178 }
179 Ordering::Equal => {
180 *removed = true;
181 Self::remove_node(root)
182 }
183 }
184 }
185
186 fn remove_node(mut node: Box<Node<T>>) -> Option<Box<Node<T>>> {
187 match (&node.left, &node.right) {
188 (None, None) => None,
189 (None, Some(_)) => node.right.take(),
190 (Some(_), None) => node.left.take(),
191 (Some(left), Some(right)) => {
192 if left.priority > right.priority {
193 let mut new_root = Self::rotate_right(node);
194 new_root.right = Self::remove_node(new_root.right.take().unwrap());
195 new_root.size =
196 1 + Self::node_size(&new_root.left) + Self::node_size(&new_root.right);
197 Some(new_root)
198 } else {
199 let mut new_root = Self::rotate_left(node);
200 new_root.left = Self::remove_node(new_root.left.take().unwrap());
201 new_root.size =
202 1 + Self::node_size(&new_root.left) + Self::node_size(&new_root.right);
203 Some(new_root)
204 }
205 }
206 }
207 }
208
209 pub fn le(&self, x: &T) -> Option<&T> {
211 let mut current = &self.root;
212 let mut result = None;
213
214 while let Some(node) = current {
215 match x.cmp(&node.x) {
216 Ordering::Less => current = &node.left,
217 Ordering::Greater => {
218 result = Some(&node.x);
219 current = &node.right;
220 }
221 Ordering::Equal => return Some(&node.x),
222 }
223 }
224
225 result
226 }
227
228 pub fn ge(&self, x: &T) -> Option<&T> {
230 let mut current = &self.root;
231 let mut result = None;
232
233 while let Some(node) = current {
234 match x.cmp(&node.x) {
235 Ordering::Less => {
236 result = Some(&node.x);
237 current = &node.left;
238 }
239 Ordering::Greater => current = &node.right,
240 Ordering::Equal => return Some(&node.x),
241 }
242 }
243
244 result
245 }
246
247 pub fn nth(&self, n: usize) -> Option<&T> {
249 if n >= self.len() {
250 return None;
251 }
252
253 let mut current = &self.root;
254 let mut n = n;
255
256 while let Some(node) = current {
257 let left_size = Self::node_size(&node.left);
258 match n.cmp(&left_size) {
259 Ordering::Less => current = &node.left,
260 Ordering::Equal => return Some(&node.x),
261 Ordering::Greater => {
262 n -= 1 + left_size;
263 current = &node.right;
264 }
265 }
266 }
267
268 unreachable!()
269 }
270
271 pub fn position(&self, x: &T) -> Result<usize, usize> {
274 let mut current = &self.root;
275 let mut count = 0;
276 let mut hit = false;
277
278 while let Some(node) = current {
279 match x.cmp(&node.x) {
280 Ordering::Less => current = &node.left,
281 Ordering::Equal => {
282 hit = true;
283 current = &node.left;
284 }
285 Ordering::Greater => {
286 count += 1 + Self::node_size(&node.left);
287 current = &node.right;
288 }
289 }
290 }
291
292 if hit { Ok(count) } else { Err(count) }
293 }
294}
295
296impl<T, R> Treap<T, R>
297where
298 T: cmp::Ord,
299 R: RngCore,
300{
301 pub fn insert(&mut self, x: T) -> bool {
303 let root = self.root.take();
304 let mut inserted = false;
305 self.root = self.insert_recursive(root, x, &mut inserted);
306 if inserted {
307 self.n += 1;
308 }
309 inserted
310 }
311
312 fn insert_recursive(
313 &mut self,
314 root: Option<Box<Node<T>>>,
315 x: T,
316 inserted: &mut bool,
317 ) -> Option<Box<Node<T>>> {
318 let mut root = match root {
319 Some(root) => root,
320 None => {
321 *inserted = true;
322 return Some(Self::new_node(x, self.gen_priority()));
323 }
324 };
325
326 match x.cmp(&root.x) {
327 Ordering::Less => {
328 root.left = self.insert_recursive(root.left.take(), x, inserted);
329 if *inserted {
330 root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
331
332 if let Some(left) = &root.left
333 && left.priority > root.priority
334 {
335 return Some(Self::rotate_right(root));
336 }
337 }
338 Some(root)
339 }
340 Ordering::Greater => {
341 root.right = self.insert_recursive(root.right.take(), x, inserted);
342 if *inserted {
343 root.size = 1 + Self::node_size(&root.left) + Self::node_size(&root.right);
344
345 if let Some(right) = &root.right
346 && right.priority > root.priority
347 {
348 return Some(Self::rotate_left(root));
349 }
350 }
351 Some(root)
352 }
353 Ordering::Equal => Some(root),
354 }
355 }
356}
357
358impl<T> Default for Treap<T, StdRng> {
359 fn default() -> Self {
360 Self::new(StdRng::seed_from_u64(12233344455555))
361 }
362}
363
364impl<T, R> fmt::Debug for Treap<T, R>
365where
366 T: fmt::Debug,
367{
368 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369 f.debug_list().entries(self.iter()).finish()
370 }
371}
372
373pub struct Iter<'a, T> {
374 stack: Vec<&'a Node<T>>,
375 _phantom: PhantomData<&'a T>,
376}
377
378impl<'a, T> Iter<'a, T> {
379 fn new(root: &'a Option<Box<Node<T>>>) -> Self {
380 let mut iter = Self {
381 stack: Vec::new(),
382 _phantom: PhantomData,
383 };
384 iter.push_left_path(root);
385 iter
386 }
387
388 fn push_left_path(&mut self, mut node: &'a Option<Box<Node<T>>>) {
389 while let Some(n) = node {
390 self.stack.push(n);
391 node = &n.left;
392 }
393 }
394}
395
396impl<'a, T> Iterator for Iter<'a, T> {
397 type Item = &'a T;
398
399 fn next(&mut self) -> Option<Self::Item> {
400 let node = self.stack.pop()?;
401 let result = &node.x;
402 self.push_left_path(&node.right);
403 Some(result)
404 }
405}
406
407impl<T, R> Treap<T, R> {
408 pub fn iter(&self) -> Iter<'_, T> {
409 Iter::new(&self.root)
410 }
411}
412
413#[cfg(test)]
414mod tests {
415 use crate::Treap;
416
417 #[test]
418 fn test_treap_insert() {
419 let mut treap = Treap::default();
420 assert_eq!(treap.insert(42), true);
421 assert_eq!(treap.insert(42), false);
422 }
423
424 #[test]
425 fn test_treap_remove() {
426 let mut treap = Treap::default();
427 treap.insert(42);
428 assert_eq!(treap.remove(&41), false);
429 assert_eq!(treap.remove(&42), true);
430 assert_eq!(treap.remove(&42), false);
431 }
432
433 #[test]
434 fn test_treap_contains() {
435 let mut treap = Treap::default();
436 treap.insert(42);
437 assert_eq!(treap.contains(&42), true);
438 assert_eq!(treap.contains(&24), false);
439 }
440
441 #[test]
442 fn test_treap_le() {
443 let mut treap = Treap::default();
444 treap.insert(42);
445 assert_eq!(treap.le(&41), None);
446 assert_eq!(treap.le(&42), Some(&42));
447 assert_eq!(treap.le(&43), Some(&42));
448 }
449
450 #[test]
451 fn test_treap_ge() {
452 let mut treap = Treap::default();
453 treap.insert(42);
454 assert_eq!(treap.ge(&41), Some(&42));
455 assert_eq!(treap.ge(&42), Some(&42));
456 assert_eq!(treap.ge(&43), None);
457 }
458
459 #[test]
460 fn test_treap_nth() {
461 let mut treap = Treap::default();
462 treap.insert(1);
463 treap.insert(2);
464 treap.insert(4);
465 treap.insert(8);
466 assert_eq!(treap.nth(0), Some(&1));
467 assert_eq!(treap.nth(1), Some(&2));
468 assert_eq!(treap.nth(2), Some(&4));
469 assert_eq!(treap.nth(3), Some(&8));
470 assert_eq!(treap.nth(4), None);
471 }
472
473 #[test]
474 fn test_treap_position() {
475 let mut treap = Treap::default();
476 treap.insert(1);
477 treap.insert(2);
478 treap.insert(4);
479 treap.insert(8);
480 assert_eq!(treap.position(&0), Err(0));
481 assert_eq!(treap.position(&1), Ok(0));
482 assert_eq!(treap.position(&2), Ok(1));
483 assert_eq!(treap.position(&3), Err(2));
484 assert_eq!(treap.position(&4), Ok(2));
485 assert_eq!(treap.position(&5), Err(3));
486 assert_eq!(treap.position(&6), Err(3));
487 assert_eq!(treap.position(&7), Err(3));
488 assert_eq!(treap.position(&8), Ok(3));
489 assert_eq!(treap.position(&9), Err(4));
490 }
491
492 #[test]
493 fn test_treap_iter() {
494 let mut treap = Treap::default();
495 treap.insert(3);
496 treap.insert(1);
497 treap.insert(4);
498 treap.insert(5);
499 treap.insert(9);
500 treap.insert(2);
501
502 let values: Vec<_> = treap.iter().collect();
503 assert_eq!(values, vec![&1, &2, &3, &4, &5, &9]);
504 }
505
506 #[test]
507 fn test_treap_into_sorted_vec() {
508 let mut treap = Treap::default();
509 treap.insert(3);
510 treap.insert(1);
511 treap.insert(4);
512 treap.insert(5);
513 treap.insert(9);
514 treap.insert(2);
515
516 assert_eq!(treap.into_sorted_vec(), vec![1, 2, 3, 4, 5, 9]);
517 }
518}