1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
use std::fmt;
use std::ops::{Bound, Index, RangeBounds};

/// __注意⚠__ この実装は遅いので time limit の厳しい問題には代わりに ACL のセグメントツリーを使うこと。
///
/// セグメントツリーです。
#[derive(Clone)]
pub struct SegmentTree<T, F> {
    original_n: usize,
    n: usize,
    dat: Vec<T>,
    e: T,
    multiply: F,
}

// https://hcpc-hokudai.github.io/archive/structure_segtree_001.pdf
impl<T, F> SegmentTree<T, F>
where
    T: Clone,
    F: Fn(&T, &T) -> T,
{
    /// 長さ `n` の列を初期値 `e` で初期化します。
    ///
    /// `multiply` は fold に使う二項演算です。
    pub fn new(n: usize, e: T, multiply: F) -> Self {
        let original_n = n;
        let n = n.next_power_of_two();
        Self {
            original_n,
            n,
            dat: vec![e.clone(); n * 2], // dat[0] is unused
            e,
            multiply,
        }
    }

    /// 列の `i` 番目の要素を取得します。
    pub fn get(&self, i: usize) -> &T {
        assert!(i < self.original_n);
        &self.dat[i + self.n]
    }

    /// 列の `i` 番目の要素を `x` で更新します。
    pub fn set(&mut self, i: usize, x: T) {
        self.update(i, |_| x);
    }

    /// 列の `i` 番目の要素を `f` で更新します。
    pub fn update<U>(&mut self, i: usize, f: U)
    where
        U: FnOnce(&T) -> T,
    {
        assert!(i < self.original_n);
        let mut k = i + self.n;
        self.dat[k] = f(&self.dat[k]);
        while k > 1 {
            k >>= 1;
            self.dat[k] = (self.multiply)(&self.dat[k << 1], &self.dat[k << 1 | 1]);
        }
    }

    /// `range` が `l..r` として、`multiply(l番目の要素, multiply(..., multiply(r-2番目の要素, r-1番目の要素)))` の値を返します。
    pub fn fold(&self, range: impl RangeBounds<usize>) -> T {
        let start = match range.start_bound() {
            Bound::Included(&start) => start,
            Bound::Excluded(&start) => start + 1,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&end) => end + 1,
            Bound::Excluded(&end) => end,
            Bound::Unbounded => self.original_n,
        };
        assert!(start <= end && end <= self.original_n);
        self._fold(start, end)
    }

    fn _fold(&self, mut l: usize, mut r: usize) -> T {
        let mut acc_l = self.e.clone();
        let mut acc_r = self.e.clone();
        l += self.n;
        r += self.n;
        while l < r {
            if l & 1 == 1 {
                // 右の子だったらいま足しておかないといけない
                // 左の子だったら祖先のどれかで足されるのでよい
                acc_l = (self.multiply)(&acc_l, &self.dat[l]);
                l += 1;
            }
            if r & 1 == 1 {
                // r が exclusive であることに注意する
                r -= 1;
                acc_r = (self.multiply)(&self.dat[r], &acc_r);
            }
            l >>= 1;
            r >>= 1;
        }
        (self.multiply)(&acc_l, &acc_r)
    }
}

impl<T, F> Index<usize> for SegmentTree<T, F>
where
    T: Clone,
    F: Fn(&T, &T) -> T,
{
    type Output = T;

    fn index(&self, index: usize) -> &Self::Output {
        self.get(index)
    }
}

impl<T, F> fmt::Debug for SegmentTree<T, F>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", &self.dat[self.n..])
    }
}

#[cfg(test)]
mod tests {
    use crate::SegmentTree;

    #[test]
    fn test() {
        let s = "abcdefgh";
        let mut seg = SegmentTree::new(s.len(), String::new(), |a, b| format!("{a}{b}"));
        for (i, c) in s.chars().enumerate() {
            seg.set(i, c.to_string());
        }

        for i in 0..s.len() {
            assert_eq!(s[..i], seg.fold(..i));
            assert_eq!(s[i..], seg.fold(i..));
        }

        for i in 0..s.len() {
            for j in i..s.len() {
                assert_eq!(s[i..j], seg.fold(i..j));
                if j + 1 < s.len() {
                    assert_eq!(s[i..=j], seg.fold(i..=j));
                }
            }
        }
    }

    #[test]
    fn single_element() {
        let mut seg = SegmentTree::new(1, 0, |a, b| a + b);
        assert_eq!(seg[0], 0);
        seg.set(0, 42);
        assert_eq!(seg[0], 42);
    }
}