fenwick_tree/
lib.rs

1use std::ops::{Bound, RangeBounds};
2
/// Fenwick Tree (Binary Indexed Tree) [http://hos.ac/slides/20140319_bit.pdf](http://hos.ac/slides/20140319_bit.pdf)
///
/// Maintains a sequence `a` of `n` elements and supports point updates
/// (`a[k] += x`) and range-sum queries, each in `O(log n)` time.
/// The value `e` passed to [`FenwickTree::new`] must be the identity of the
/// addition (e.g. `0` for integers); every element starts out equal to it.
///
/// # Examples
/// ```
/// use fenwick_tree::FenwickTree;
/// let mut ft = FenwickTree::new(5, 0);
/// ft.add(0, 1);
/// ft.add(2, 10);
/// ft.add(4, 100);
/// // [1, 0, 10, 0, 100]
/// assert_eq!(ft.sum(0..1), 1);
/// assert_eq!(ft.sum(0..2), 1);
/// assert_eq!(ft.sum(0..3), 11);
/// assert_eq!(ft.sum(2..4), 10);
/// assert_eq!(ft.sum(2..5), 110);
/// assert_eq!(ft.sum(0..5), 111);
/// ```
#[derive(Clone, Debug)]
pub struct FenwickTree<T> {
    /// Number of elements in the logical sequence `a`.
    n: usize,
    /// Additive identity; the starting value of every prefix sum.
    e: T,
    /// 1-indexed tree storage of length `n + 1`; `dat[0]` is never read.
    dat: Vec<T>,
}

impl<T> FenwickTree<T>
where
    T: Copy,
    T: std::ops::AddAssign,
    T: std::ops::SubAssign,
{
    /// Creates a tree over `n` elements, all equal to the identity `e`.
    pub fn new(n: usize, e: T) -> Self {
        Self {
            n,
            e,
            dat: vec![e; n + 1],
        }
    }

    /// Performs `a[k] += x` (0-indexed).
    ///
    /// # Panics
    /// Panics if `k >= n`.
    pub fn add(&mut self, k: usize, x: T) {
        assert!(k < self.n);
        // Tree positions are 1-indexed.
        let mut k = k + 1;
        while k <= self.n {
            self.dat[k] += x;
            // Adding the lowest set bit moves to the next node whose range covers `k`.
            k += 1 << k.trailing_zeros();
        }
    }

    /// Returns `a[0] + a[1] + ... + a[r - 1]` (the first `r` elements),
    /// or `e` when `r == 0`.
    fn prefix_sum(&self, r: usize) -> T {
        assert!(r <= self.n);
        let mut result = self.e;
        let mut k = r;
        while k >= 1 {
            result += self.dat[k];
            // Removing the lowest set bit jumps to the next disjoint block.
            k -= 1 << k.trailing_zeros();
        }
        result
    }

    /// Returns the sum of `a[i]` for every `i` in `range` (0-indexed).
    ///
    /// Any range form is accepted (`l..r`, `l..=r`, `l..`, `..r`, `..=r`, `..`);
    /// an empty range yields `e`.
    ///
    /// # Panics
    /// Panics if the range ends past `n`, or if it starts after it ends
    /// (mirroring slice indexing).
    pub fn sum(&self, range: impl RangeBounds<usize>) -> T {
        let start = match range.start_bound() {
            Bound::Included(&start) => start,
            // `checked_add` so `usize::MAX` cannot silently wrap to 0 in release builds.
            Bound::Excluded(&start) => start
                .checked_add(1)
                .expect("range start bound overflows usize"),
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&end) => end
                .checked_add(1)
                .expect("range end bound overflows usize"),
            Bound::Excluded(&end) => end,
            Bound::Unbounded => self.n,
        };
        assert!(
            end <= self.n,
            "range end index {} out of range for length {}",
            end,
            self.n
        );
        // Without this check a reversed range would return a negated (or,
        // for unsigned `T`, overflowing) difference instead of failing.
        assert!(
            start <= end,
            "range starts at index {} but ends at index {}",
            start,
            end
        );
        let mut result = self.prefix_sum(end);
        result -= self.prefix_sum(start);
        result
    }
}
80
#[cfg(test)]
mod tests {
    use super::FenwickTree;
    use rand::prelude::*;

    /// Randomized cross-check against a plain array: after every update, every
    /// range `l..r` with `0 <= l <= r <= n` must give the same sum.
    #[test]
    fn test() {
        let mut rng = thread_rng();
        for n in 1..=20 {
            let mut a = vec![0; n];
            let mut ft = FenwickTree::new(n, 0);
            for _ in 0..100 {
                let i = rng.gen_range(0, n);
                let x = rng.gen_range(-100, 100);
                a[i] += x;
                ft.add(i, x);
                // Enumerate all pairs so empty, single-element, multi-element
                // and full ranges are all exercised.
                for l in 0..=n {
                    for r in l..=n {
                        assert_eq!(a[l..r].iter().sum::<i32>(), ft.sum(l..r));
                    }
                }
            }
        }
    }

    /// Every `RangeBounds` form maps to the expected half-open interval.
    #[test]
    fn test_bound_kinds() {
        let mut ft = FenwickTree::new(5, 0);
        for (i, &x) in [1, 2, 3, 4, 5].iter().enumerate() {
            ft.add(i, x);
        }
        assert_eq!(ft.sum(..), 15);
        assert_eq!(ft.sum(1..=3), 9);
        assert_eq!(ft.sum(..2), 3);
        assert_eq!(ft.sum(..=2), 6);
        assert_eq!(ft.sum(3..), 9);
        assert_eq!(ft.sum(2..2), 0);
    }

    #[test]
    fn test_single() {
        let mut f = FenwickTree::new(1, 0);
        f.add(0, 123);
        assert_eq!(f.sum(0..1), 123);
    }
}