1use std::fmt;
2use std::ops::{Bound, Index, RangeBounds};
3
/// A segment tree over the monoid `(T, multiply, e)`.
///
/// Holds `len()` elements and supports point updates and range folds in
/// O(log n) calls to `multiply`, plus the binary searches [`max_right`] and
/// [`min_left`].
///
/// `multiply` is expected to be associative with `e` as its identity: unused
/// padding leaves are filled with `e`, and partial results are combined in a
/// tree-shaped rather than strictly left-to-right order. Commutativity is NOT
/// required — operands are always combined in index order (left, right).
///
/// [`max_right`]: SegmentTree::max_right
/// [`min_left`]: SegmentTree::min_left
#[derive(Clone)]
pub struct SegmentTree<T, F> {
    /// Number of logical elements requested by the caller.
    original_n: usize,
    /// Number of leaves: `original_n` rounded up to a power of two (at least 1).
    n: usize,
    /// Implicit binary tree: node `k` has children `2k` and `2k + 1`, leaves
    /// live at `n..2n`, and index 0 is unused.
    dat: Vec<T>,
    /// Identity element of `multiply`.
    e: T,
    /// Combines two adjacent segments, left operand first.
    multiply: F,
}

impl<T, F> SegmentTree<T, F>
where
    T: Clone,
    F: Fn(&T, &T) -> T,
{
    /// Creates a tree of `n` elements, each initialised to `e`.
    pub fn new(n: usize, e: T, multiply: F) -> Self {
        let original_n = n;
        let n = n.next_power_of_two();
        Self {
            original_n,
            n,
            dat: vec![e.clone(); n * 2],
            e,
            multiply,
        }
    }

    /// Creates a tree whose elements are `values`, in index order.
    ///
    /// Equivalent to [`SegmentTree::new`] followed by `set(i, values[i])` for
    /// every `i`, but builds all internal nodes bottom-up with a single
    /// `multiply` call each (O(n) instead of O(n log n)).
    pub fn from_vec(values: Vec<T>, e: T, multiply: F) -> Self {
        let original_n = values.len();
        let n = original_n.next_power_of_two();

        let mut dat = Vec::with_capacity(2 * n);
        // Placeholders for the internal nodes (and unused index 0).
        dat.resize(n, e.clone());
        dat.extend(values);
        // Padding leaves past `original_n` hold the identity.
        dat.resize(2 * n, e.clone());
        for k in (1..n).rev() {
            dat[k] = multiply(&dat[2 * k], &dat[2 * k + 1]);
        }

        Self {
            original_n,
            n,
            dat,
            e,
            multiply,
        }
    }

    /// Returns the number of elements in the tree.
    pub fn len(&self) -> usize {
        self.original_n
    }

    /// Returns `true` if the tree holds no elements.
    pub fn is_empty(&self) -> bool {
        self.original_n == 0
    }

    /// Returns a reference to element `i`.
    ///
    /// # Panics
    /// If `i >= len()`.
    pub fn get(&self, i: usize) -> &T {
        assert!(
            i < self.original_n,
            "index {i} out of bounds for length {}",
            self.original_n
        );
        &self.dat[i + self.n]
    }

    /// Replaces element `i` with `x`.
    ///
    /// # Panics
    /// If `i >= len()`.
    pub fn set(&mut self, i: usize, x: T) {
        self.update(i, |_| x);
    }

    /// Replaces element `i` with `f(&old_value)` and refreshes its ancestors.
    ///
    /// # Panics
    /// If `i >= len()`.
    pub fn update<U>(&mut self, i: usize, f: U)
    where
        U: FnOnce(&T) -> T,
    {
        assert!(
            i < self.original_n,
            "index {i} out of bounds for length {}",
            self.original_n
        );
        let mut k = i + self.n;
        self.dat[k] = f(&self.dat[k]);
        // Recompute every ancestor up to (and including) the root at index 1.
        while k > 1 {
            k >>= 1;
            self.dat[k] = (self.multiply)(&self.dat[2 * k], &self.dat[2 * k + 1]);
        }
    }

    /// Returns the product `a[start] * a[start + 1] * ... * a[end - 1]` of the
    /// elements in `range`, or `e` for an empty range.
    ///
    /// # Panics
    /// If the range is decreasing or extends past `len()`.
    pub fn fold(&self, range: impl RangeBounds<usize>) -> T {
        // Saturating (not wrapping) so that a bound of `usize::MAX` is rejected
        // by the assertion below instead of wrapping to 0 in release builds.
        let start = match range.start_bound() {
            Bound::Included(&start) => start,
            Bound::Excluded(&start) => start.saturating_add(1),
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&end) => end.saturating_add(1),
            Bound::Excluded(&end) => end,
            Bound::Unbounded => self.original_n,
        };
        assert!(
            start <= end && end <= self.original_n,
            "range {start}..{end} out of bounds for length {}",
            self.original_n
        );
        self.fold_half_open(start, end)
    }

    /// Searches to the right of `l` for where the predicate `f` stops holding.
    ///
    /// Returns an `r` with `l <= r <= len()` such that `f(&fold(l..r))` is true
    /// and either `r == len()` or `f(&fold(l..r + 1))` is false. When `f` is
    /// monotone (true up to some point, then false), this is the largest `r`
    /// for which `f(&fold(l..r))` holds.
    ///
    /// # Panics
    /// If `l > len()`, or if `f(&e)` is false.
    pub fn max_right<P>(&self, l: usize, f: P) -> usize
    where
        P: Fn(&T) -> bool,
    {
        assert!(
            l <= self.original_n,
            "l = {l} out of bounds for length {}",
            self.original_n
        );
        assert!(f(&self.e), "f(e) must be true");

        if l == self.original_n {
            return self.original_n;
        }

        let mut node = l + self.n;
        let mut sum = self.e.clone();

        loop {
            // While `node` is a left child, its parent starts at the same
            // position, so test the larger segment instead.
            while node % 2 == 0 {
                node >>= 1;
            }

            let candidate = (self.multiply)(&sum, &self.dat[node]);
            if !f(&candidate) {
                // The boundary lies inside `node`: descend, stepping to the
                // right sibling whenever the left child can be absorbed.
                while node < self.n {
                    node *= 2;
                    let candidate = (self.multiply)(&sum, &self.dat[node]);
                    if f(&candidate) {
                        sum = candidate;
                        node += 1;
                    }
                }
                return node - self.n;
            }

            sum = candidate;
            node += 1;

            // Stepping onto a power of two means `node` was the rightmost node
            // of its level: the whole suffix `l..` has been absorbed.
            if node.is_power_of_two() {
                break;
            }
        }

        self.original_n
    }

    /// Searches to the left of `r` for where the predicate `f` stops holding.
    ///
    /// Returns an `l` with `0 <= l <= r` such that `f(&fold(l..r))` is true and
    /// either `l == 0` or `f(&fold(l - 1..r))` is false. When `f` is monotone,
    /// this is the smallest `l` for which `f(&fold(l..r))` holds.
    ///
    /// # Panics
    /// If `r > len()`, or if `f(&e)` is false.
    pub fn min_left<P>(&self, r: usize, f: P) -> usize
    where
        P: Fn(&T) -> bool,
    {
        assert!(
            r <= self.original_n,
            "r = {r} out of bounds for length {}",
            self.original_n
        );
        assert!(f(&self.e), "f(e) must be true");

        if r == 0 {
            return 0;
        }

        let mut node = r + self.n;
        let mut sum = self.e.clone();

        loop {
            node -= 1;
            // While `node` is a right child (and not the root), its parent ends
            // at the same position, so test the larger segment instead.
            while node > 1 && node % 2 == 1 {
                node >>= 1;
            }

            let candidate = (self.multiply)(&self.dat[node], &sum);
            if !f(&candidate) {
                // The boundary lies inside `node`: descend, stepping to the
                // left sibling whenever the right child can be absorbed.
                while node < self.n {
                    node = 2 * node + 1;
                    let candidate = (self.multiply)(&self.dat[node], &sum);
                    if f(&candidate) {
                        sum = candidate;
                        node -= 1;
                    }
                }
                return node + 1 - self.n;
            }

            sum = candidate;

            // A power of two is the leftmost node of its level: the whole
            // prefix `..r` has been absorbed.
            if node.is_power_of_two() {
                break;
            }
        }

        0
    }

    /// Folds the half-open leaf range `l..r`; bounds are assumed validated.
    fn fold_half_open(&self, l: usize, r: usize) -> T {
        // Separate accumulators keep left-to-right operand order without
        // requiring `multiply` to be commutative.
        let mut acc_l = self.e.clone();
        let mut acc_r = self.e.clone();
        let mut l = l + self.n;
        let mut r = r + self.n;
        while l < r {
            if l & 1 == 1 {
                acc_l = (self.multiply)(&acc_l, &self.dat[l]);
                l += 1;
            }
            if r & 1 == 1 {
                r -= 1;
                acc_r = (self.multiply)(&self.dat[r], &acc_r);
            }
            l >>= 1;
            r >>= 1;
        }
        (self.multiply)(&acc_l, &acc_r)
    }
}
198
199impl<T, F> Index<usize> for SegmentTree<T, F>
200where
201 T: Clone,
202 F: Fn(&T, &T) -> T,
203{
204 type Output = T;
205
206 fn index(&self, index: usize) -> &Self::Output {
207 self.get(index)
208 }
209}
210
211impl<T, F> fmt::Debug for SegmentTree<T, F>
212where
213 T: fmt::Debug,
214{
215 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216 write!(f, "{:?}", &self.dat[self.n..])
217 }
218}
219
#[cfg(test)]
mod tests {
    // `super` resolves wherever this file lives in the crate, unlike `crate::`,
    // which requires `SegmentTree` to be exported at the crate root.
    use super::SegmentTree;

    /// Checks `fold` against string slicing for every prefix, suffix and
    /// (half-open and inclusive) sub-range. String concatenation is not
    /// commutative, so any operand-order mistake shows up here.
    #[test]
    fn test() {
        let s = "abcdefgh";
        let mut seg = SegmentTree::new(s.len(), String::new(), |a, b| format!("{a}{b}"));
        for (i, c) in s.chars().enumerate() {
            seg.set(i, c.to_string());
        }

        assert_eq!(s, seg.fold(..));
        for i in 0..=s.len() {
            assert_eq!(s[..i], seg.fold(..i));
            assert_eq!(s[i..], seg.fold(i..));
        }

        for i in 0..=s.len() {
            for j in i..=s.len() {
                assert_eq!(s[i..j], seg.fold(i..j));
            }
            for j in i..s.len() {
                assert_eq!(s[i..=j], seg.fold(i..=j));
            }
        }
    }

    #[test]
    fn single_element() {
        let mut seg = SegmentTree::new(1, 0, |a, b| a + b);
        assert_eq!(seg[0], 0);
        seg.set(0, 42);
        assert_eq!(seg[0], 42);
    }

    #[test]
    fn empty_tree() {
        let seg = SegmentTree::new(0, 0, |a, b| a + b);
        assert_eq!(seg.fold(..), 0);
        assert_eq!(seg.max_right(0, |&sum| sum <= 0), 0);
        assert_eq!(seg.min_left(0, |&sum| sum <= 0), 0);
    }

    #[test]
    fn test_max_right() {
        let values = [3, 1, 4, 1, 5, 9, 2, 6, 5];
        let n = values.len();
        let mut seg = SegmentTree::new(n, 0, |a, b| a + b);
        for (i, &v) in values.iter().enumerate() {
            seg.set(i, v);
        }

        assert_eq!(seg.max_right(0, |&sum| sum < 9), 3);
        assert_eq!(seg.max_right(0, |&sum| sum <= 9), 4);
        assert_eq!(seg.max_right(1, |&sum| sum < 11), 4);
        assert_eq!(seg.max_right(1, |&sum| sum <= 11), 5);
        assert_eq!(seg.max_right(2, |&sum| sum < 4), 2);
        assert_eq!(seg.max_right(2, |&sum| sum <= 4), 3);
        assert_eq!(seg.max_right(2, |&sum| sum <= 100), n);

        assert_eq!(seg.max_right(n, |&sum| sum <= 0), n);
        assert_eq!(seg.max_right(n, |&sum| sum <= 100), n);
    }

    #[test]
    fn test_min_left() {
        let values = [3, 1, 4, 1, 5, 9, 2, 6, 5];
        let n = values.len();
        let mut seg = SegmentTree::new(n, 0, |a, b| a + b);
        for (i, &v) in values.iter().enumerate() {
            seg.set(i, v);
        }

        assert_eq!(seg.min_left(n, |&sum| sum <= 22), 5);
        assert_eq!(seg.min_left(n, |&sum| sum < 22), 6);
        assert_eq!(seg.min_left(n - 1, |&sum| sum <= 27), 2);
        assert_eq!(seg.min_left(n - 1, |&sum| sum < 27), 3);
        assert_eq!(seg.min_left(n - 1, |&sum| sum < 100), 0);

        assert_eq!(seg.min_left(0, |&sum| sum <= 0), 0);
        assert_eq!(seg.min_left(0, |&sum| sum <= 100), 0);
    }
}