segment_tree/
lib.rs

1use std::fmt;
2use std::ops::{Bound, Index, RangeBounds};
3
4/// __注意⚠__ この実装は遅いので time limit の厳しい問題には代わりに ACL のセグメントツリーを使うこと。
5///
6/// セグメントツリーです。
7#[derive(Clone)]
8pub struct SegmentTree<T, F> {
9    original_n: usize,
10    n: usize,
11    dat: Vec<T>,
12    e: T,
13    multiply: F,
14}
15
16// https://hcpc-hokudai.github.io/archive/structure_segtree_001.pdf
17impl<T, F> SegmentTree<T, F>
18where
19    T: Clone,
20    F: Fn(&T, &T) -> T,
21{
22    /// 長さ `n` の列を初期値 `e` で初期化します。
23    ///
24    /// `multiply` は fold に使う二項演算です。
25    pub fn new(n: usize, e: T, multiply: F) -> Self {
26        let original_n = n;
27        let n = n.next_power_of_two();
28        Self {
29            original_n,
30            n,
31            dat: vec![e.clone(); n * 2], // dat[0] is unused
32            e,
33            multiply,
34        }
35    }
36
37    /// 列の `i` 番目の要素を取得します。
38    pub fn get(&self, i: usize) -> &T {
39        assert!(i < self.original_n);
40        &self.dat[i + self.n]
41    }
42
43    /// 列の `i` 番目の要素を `x` で更新します。
44    pub fn set(&mut self, i: usize, x: T) {
45        self.update(i, |_| x);
46    }
47
48    /// 列の `i` 番目の要素を `f` で更新します。
49    pub fn update<U>(&mut self, i: usize, f: U)
50    where
51        U: FnOnce(&T) -> T,
52    {
53        assert!(i < self.original_n);
54        let mut k = i + self.n;
55        self.dat[k] = f(&self.dat[k]);
56        while k > 1 {
57            k >>= 1;
58            self.dat[k] = (self.multiply)(&self.dat[k << 1], &self.dat[k << 1 | 1]);
59        }
60    }
61
62    /// `range` が `l..r` として、`multiply(l番目の要素, multiply(..., multiply(r-2番目の要素, r-1番目の要素)))` の値を返します。
63    pub fn fold(&self, range: impl RangeBounds<usize>) -> T {
64        let start = match range.start_bound() {
65            Bound::Included(&start) => start,
66            Bound::Excluded(&start) => start + 1,
67            Bound::Unbounded => 0,
68        };
69        let end = match range.end_bound() {
70            Bound::Included(&end) => end + 1,
71            Bound::Excluded(&end) => end,
72            Bound::Unbounded => self.original_n,
73        };
74        assert!(start <= end && end <= self.original_n);
75        self._fold(start, end)
76    }
77
78    /// `f(fold(l..r)) = true` となる最大の `r` を返します。
79    ///
80    /// # Panics
81    ///
82    /// if `f(e) = false`
83    pub fn max_right<P>(&self, l: usize, f: P) -> usize
84    where
85        P: Fn(&T) -> bool,
86    {
87        assert!(l <= self.original_n);
88        assert!(f(&self.e), "f(e) must be true");
89
90        if l == self.original_n {
91            return self.original_n;
92        }
93
94        let mut l = l + self.n;
95        let mut sum = self.e.clone();
96
97        loop {
98            // l を含む区間の右端まで進む
99            while l % 2 == 0 {
100                l >>= 1;
101            }
102
103            let new_sum = (self.multiply)(&sum, &self.dat[l]);
104            if !f(&new_sum) {
105                while l < self.n {
106                    l <<= 1;
107                    let new_sum = (self.multiply)(&sum, &self.dat[l]);
108                    if f(&new_sum) {
109                        sum = new_sum;
110                        l += 1;
111                    }
112                }
113                return l - self.n;
114            }
115
116            sum = new_sum;
117            l += 1;
118
119            if (l & (l.wrapping_neg())) == l {
120                break;
121            }
122        }
123
124        self.original_n
125    }
126
127    /// `f(fold(l..r)) = true` となる最小の `l` を返します。
128    ///
129    /// # Panics
130    ///
131    /// if `f(e) = false`
132    pub fn min_left<P>(&self, r: usize, f: P) -> usize
133    where
134        P: Fn(&T) -> bool,
135    {
136        assert!(r <= self.original_n);
137        assert!(f(&self.e), "f(e) must be true");
138
139        if r == 0 {
140            return 0;
141        }
142
143        let mut r = r + self.n;
144        let mut sum = self.e.clone();
145
146        loop {
147            r -= 1;
148            while r > 1 && r % 2 == 1 {
149                r >>= 1;
150            }
151
152            let new_sum = (self.multiply)(&self.dat[r], &sum);
153            if !f(&new_sum) {
154                while r < self.n {
155                    r = r * 2 + 1;
156                    let new_sum = (self.multiply)(&self.dat[r], &sum);
157                    if f(&new_sum) {
158                        sum = new_sum;
159                        r -= 1;
160                    }
161                }
162                return r + 1 - self.n;
163            }
164
165            sum = new_sum;
166
167            if (r & (r.wrapping_neg())) == r {
168                break;
169            }
170        }
171
172        0
173    }
174
175    fn _fold(&self, mut l: usize, mut r: usize) -> T {
176        let mut acc_l = self.e.clone();
177        let mut acc_r = self.e.clone();
178        l += self.n;
179        r += self.n;
180        while l < r {
181            if l & 1 == 1 {
182                // 右の子だったらいま足しておかないといけない
183                // 左の子だったら祖先のどれかで足されるのでよい
184                acc_l = (self.multiply)(&acc_l, &self.dat[l]);
185                l += 1;
186            }
187            if r & 1 == 1 {
188                // r が exclusive であることに注意する
189                r -= 1;
190                acc_r = (self.multiply)(&self.dat[r], &acc_r);
191            }
192            l >>= 1;
193            r >>= 1;
194        }
195        (self.multiply)(&acc_l, &acc_r)
196    }
197}
198
199impl<T, F> Index<usize> for SegmentTree<T, F>
200where
201    T: Clone,
202    F: Fn(&T, &T) -> T,
203{
204    type Output = T;
205
206    fn index(&self, index: usize) -> &Self::Output {
207        self.get(index)
208    }
209}
210
211impl<T, F> fmt::Debug for SegmentTree<T, F>
212where
213    T: fmt::Debug,
214{
215    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        write!(f, "{:?}", &self.dat[self.n..])
217    }
218}
219
220#[cfg(test)]
221mod tests {
222    use crate::SegmentTree;
223
224    #[test]
225    fn test() {
226        let s = "abcdefgh";
227        let mut seg = SegmentTree::new(s.len(), String::new(), |a, b| format!("{a}{b}"));
228        for (i, c) in s.chars().enumerate() {
229            seg.set(i, c.to_string());
230        }
231
232        for i in 0..s.len() {
233            assert_eq!(s[..i], seg.fold(..i));
234            assert_eq!(s[i..], seg.fold(i..));
235        }
236
237        for i in 0..s.len() {
238            for j in i..s.len() {
239                assert_eq!(s[i..j], seg.fold(i..j));
240                if j + 1 < s.len() {
241                    assert_eq!(s[i..=j], seg.fold(i..=j));
242                }
243            }
244        }
245    }
246
247    #[test]
248    fn single_element() {
249        let mut seg = SegmentTree::new(1, 0, |a, b| a + b);
250        assert_eq!(seg[0], 0);
251        seg.set(0, 42);
252        assert_eq!(seg[0], 42);
253    }
254
255    #[test]
256    fn test_max_right() {
257        let n = 9;
258        let mut seg = SegmentTree::new(n, 0, |a, b| a + b);
259        let values = vec![3, 1, 4, 1, 5, 9, 2, 6, 5];
260        for (i, &v) in values.iter().enumerate() {
261            seg.set(i, v);
262        }
263
264        // 区間和
265        assert_eq!(seg.max_right(0, |&sum| sum < 9), 3); // 3 + 1 + 4 = 8
266        assert_eq!(seg.max_right(0, |&sum| sum <= 9), 4); // 3 + 1 + 4 + 1 = 9
267
268        assert_eq!(seg.max_right(1, |&sum| sum < 11), 4); // 1 + 4 + 1 = 6
269        assert_eq!(seg.max_right(1, |&sum| sum <= 11), 5); // 1 + 4 + 1 + 5 = 11
270
271        assert_eq!(seg.max_right(2, |&sum| sum < 4), 2);
272        assert_eq!(seg.max_right(2, |&sum| sum <= 4), 3);
273        assert_eq!(seg.max_right(2, |&sum| sum <= 100), n);
274
275        assert_eq!(seg.max_right(n, |&sum| sum <= 0), n);
276        assert_eq!(seg.max_right(n, |&sum| sum <= 100), n);
277    }
278
279    #[test]
280    fn test_min_left() {
281        let n = 9;
282        let mut seg = SegmentTree::new(n, 0, |a, b| a + b);
283        let values = vec![3, 1, 4, 1, 5, 9, 2, 6, 5];
284        for (i, &v) in values.iter().enumerate() {
285            seg.set(i, v);
286        }
287
288        // 区間和
289        assert_eq!(seg.min_left(n, |&sum| sum <= 22), 5); // 9 + 2 + 6 + 5 = 22
290        assert_eq!(seg.min_left(n, |&sum| sum < 22), 6); // 2 + 6 + 5 = 13
291
292        assert_eq!(seg.min_left(n - 1, |&sum| sum <= 27), 2); // 4 + 1 + 5 + 9 + 2 + 6 = 27
293        assert_eq!(seg.min_left(n - 1, |&sum| sum < 27), 3); // 1 + 5 + 9 + 2 + 6 = 23
294        assert_eq!(seg.min_left(n - 1, |&sum| sum < 100), 0);
295
296        assert_eq!(seg.min_left(0, |&sum| sum <= 0), 0);
297        assert_eq!(seg.min_left(0, |&sum| sum <= 100), 0);
298    }
299}