use std::fmt;
use std::ops::{Bound, Index, RangeBounds};
/// A segment tree supporting point updates and range folds under a caller-supplied binary
/// operation.
///
/// For [`SegmentTree::fold`] to return the plain left-to-right product of the selected elements,
/// `multiply` must be associative and `e` must be its identity element. Neither property is
/// checked.
#[derive(Clone)]
pub struct SegmentTree<T, F> {
    /// Number of elements requested at construction; valid indices are `0..original_n`.
    original_n: usize,
    /// Number of leaves: `original_n` rounded up to a power of two.
    n: usize,
    /// Nodes in heap layout: the root is `dat[1]`, the children of node `k` are `dat[2k]` and
    /// `dat[2k + 1]`, and the leaves occupy `dat[n..2n]`. `dat[0]` is not used by any query or
    /// update.
    dat: Vec<T>,
    /// Initial value of every node and the seed of each fold accumulator.
    e: T,
    /// Binary operation, called as `multiply(left, right)`.
    multiply: F,
}

impl<T, F> SegmentTree<T, F>
where
    T: Clone,
    F: Fn(&T, &T) -> T,
{
    /// Creates a tree holding `n` elements, every one initialised to `e`.
    pub fn new(n: usize, e: T, multiply: F) -> Self {
        let original_n = n;
        // Rounding up to a power of two keeps the tree perfectly balanced; `0` rounds up to `1`.
        let n = original_n.next_power_of_two();
        let dat = vec![e.clone(); 2 * n];
        Self {
            original_n,
            n,
            dat,
            e,
            multiply,
        }
    }

    /// Returns a reference to the element at index `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i` is not less than the length given to [`SegmentTree::new`].
    pub fn get(&self, i: usize) -> &T {
        assert!(
            i < self.original_n,
            "index {i} out of range for segment tree of length {}",
            self.original_n
        );
        &self.dat[self.n + i]
    }

    /// Replaces the element at index `i` with `x` and refreshes every ancestor node.
    ///
    /// # Panics
    ///
    /// Panics if `i` is not less than the length given to [`SegmentTree::new`].
    pub fn set(&mut self, i: usize, x: T) {
        self.update(i, |_| x);
    }

    /// Replaces the element at index `i` with `f(&old)` and refreshes every ancestor node.
    ///
    /// # Panics
    ///
    /// Panics if `i` is not less than the length given to [`SegmentTree::new`].
    pub fn update<U>(&mut self, i: usize, f: U)
    where
        U: FnOnce(&T) -> T,
    {
        assert!(
            i < self.original_n,
            "index {i} out of range for segment tree of length {}",
            self.original_n
        );
        let mut k = self.n + i;
        let leaf = f(&self.dat[k]);
        self.dat[k] = leaf;
        // Walk up to the root, recomputing each parent from its two children.
        while k > 1 {
            k >>= 1;
            let merged = (self.multiply)(&self.dat[2 * k], &self.dat[2 * k + 1]);
            self.dat[k] = merged;
        }
    }

    /// Combines the elements selected by `range`, in index order, using `multiply`.
    ///
    /// An empty range yields `multiply(e, e)`.
    ///
    /// # Panics
    ///
    /// Panics if the range start is past its end, if the end is past the length of the tree, or
    /// if converting a bound to half-open form would overflow `usize`.
    pub fn fold(&self, range: impl RangeBounds<usize>) -> T {
        // Normalise to a half-open `[start, end)` pair. The `+ 1` conversions are checked so that
        // a bound of `usize::MAX` is rejected instead of wrapping to 0 in release builds.
        let start = match range.start_bound() {
            Bound::Included(&start) => start,
            Bound::Excluded(&start) => start
                .checked_add(1)
                .expect("range start bound overflows usize"),
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&end) => end
                .checked_add(1)
                .expect("range end bound overflows usize"),
            Bound::Excluded(&end) => end,
            Bound::Unbounded => self.original_n,
        };
        assert!(
            start <= end && end <= self.original_n,
            "range {start}..{end} is invalid for segment tree of length {}",
            self.original_n
        );
        self._fold(start, end)
    }

    /// Folds the half-open leaf range `[l, r)` bottom-up.
    ///
    /// Callers must have validated that `l <= r <= original_n`.
    fn _fold(&self, mut l: usize, mut r: usize) -> T {
        // Two accumulators keep operand order intact for non-commutative operations: nodes taken
        // from the left edge are appended to `acc_l`, nodes from the right edge are prepended to
        // `acc_r`.
        let mut acc_l = self.e.clone();
        let mut acc_r = self.e.clone();
        l += self.n;
        r += self.n;
        while l < r {
            // An odd `l` is a right child, so its parent would cover elements left of the range;
            // take the node itself and step past it.
            if l & 1 == 1 {
                acc_l = (self.multiply)(&acc_l, &self.dat[l]);
                l += 1;
            }
            // An odd exclusive bound `r` means `r - 1` is a left child whose parent would cover
            // elements right of the range; take that node itself.
            if r & 1 == 1 {
                r -= 1;
                acc_r = (self.multiply)(&self.dat[r], &acc_r);
            }
            l >>= 1;
            r >>= 1;
        }
        (self.multiply)(&acc_l, &acc_r)
    }
}
impl<T, F> Index<usize> for SegmentTree<T, F>
where
T: Clone,
F: Fn(&T, &T) -> T,
{
type Output = T;
fn index(&self, index: usize) -> &Self::Output {
self.get(index)
}
}
/// Formats the tree as the list of its elements in index order.
///
/// Only the `original_n` leaves that hold caller-visible elements are printed; the extra leaves
/// added to round the leaf count up to a power of two are omitted, as are the internal nodes.
impl<T, F> fmt::Debug for SegmentTree<T, F>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `original_n <= n`, so the upper bound never exceeds `dat.len() == 2 * n`.
        let leaves = &self.dat[self.n..self.n + self.original_n];
        // `debug_list` honours formatter flags such as `{:#?}`.
        f.debug_list().entries(leaves).finish()
    }
}
#[cfg(test)]
mod tests {
    // `super::` resolves wherever this file sits in the module tree, not only at the crate root.
    use super::SegmentTree;

    /// String concatenation is non-commutative, so this checks that folds keep operand order
    /// for every prefix, suffix, half-open range and inclusive range.
    #[test]
    fn test() {
        let s = "abcdefgh";
        let mut seg = SegmentTree::new(s.len(), String::new(), |a, b| format!("{a}{b}"));
        for (i, c) in s.chars().enumerate() {
            seg.set(i, c.to_string());
        }
        // `0..=len` also covers the full prefix and the empty suffix.
        for i in 0..=s.len() {
            assert_eq!(s[..i], seg.fold(..i));
            assert_eq!(s[i..], seg.fold(i..));
        }
        for i in 0..s.len() {
            for j in i..s.len() {
                assert_eq!(s[i..j], seg.fold(i..j));
                // `j < len` always holds here, so the inclusive range is valid up to and
                // including the last element.
                assert_eq!(s[i..=j], seg.fold(i..=j));
            }
        }
    }

    /// A one-element tree starts at the identity and reflects a `set` through indexing.
    #[test]
    fn single_element() {
        let mut seg = SegmentTree::new(1, 0, |a, b| a + b);
        assert_eq!(seg[0], 0);
        seg.set(0, 42);
        assert_eq!(seg[0], 42);
    }

    /// A length that is not a power of two must still fold correctly.
    #[test]
    fn non_power_of_two_len() {
        let mut seg = SegmentTree::new(5, 0i64, |a, b| a + b);
        for i in 0..5 {
            seg.set(i, i as i64 + 1);
        }
        assert_eq!(seg.fold(..), 15);
        assert_eq!(seg.fold(1..4), 9);
        assert_eq!(seg.fold(4..=4), 5);
        assert_eq!(seg.fold(5..), 0);
    }
}