1use std::ops::{Add, Range, Sub};
2
/// Two-dimensional prefix sums over a rectangular grid.
///
/// After `O(h * w)` preprocessing in [`CumulativeSum2D::new`], the sum of any
/// axis-aligned rectangle can be queried in `O(1)` with [`CumulativeSum2D::sum`].
///
/// `T::default()` is returned for empty rectangles, so it is expected to be the
/// additive identity (zero) of `T`.
#[derive(Debug, Clone)]
pub struct CumulativeSum2D<T> {
    /// Number of rows of the source grid.
    h: usize,
    /// Number of columns of the source grid (0 when the grid has no rows).
    w: usize,
    /// `cum_sum[i][j]` holds the sum of `grid[y][x]` for `y <= i`, `x <= j`.
    cum_sum: Vec<Vec<T>>,
}

impl<T> CumulativeSum2D<T>
where
    T: Clone + Copy + Default + Add<Output = T> + Sub<Output = T>,
{
    /// Builds the prefix-sum table for `grid`.
    ///
    /// A grid with no rows is accepted. On such a table only empty ranges can be
    /// queried, and they return `T::default()`.
    ///
    /// # Panics
    ///
    /// Panics if the rows of `grid` do not all have the same length.
    pub fn new(grid: &[Vec<T>]) -> Self {
        let h = grid.len();
        let w = grid.first().map_or(0, Vec::len);
        for row in grid {
            assert_eq!(row.len(), w, "all rows of the grid must have the same length");
        }

        let mut cum_sum = grid.to_vec();

        // Horizontal pass: turn every row into its own running sum.
        for row in &mut cum_sum {
            for j in 1..w {
                row[j] = row[j] + row[j - 1];
            }
        }

        // Vertical pass: add the (already finished) row above into each row.
        // `split_at_mut` lets us read row `i - 1` while mutating row `i`.
        for i in 1..h {
            let (done, rest) = cum_sum.split_at_mut(i);
            let above = &done[i - 1];
            for (cell, &up) in rest[0].iter_mut().zip(above) {
                *cell = up + *cell;
            }
        }

        Self { h, w, cum_sum }
    }

    /// Returns the sum of `grid[y][x]` for all `y` in `y_range` and `x` in `x_range`.
    ///
    /// If either range is empty (`start >= end`), `T::default()` is returned
    /// without any bounds checking.
    ///
    /// # Panics
    ///
    /// Panics if a non-empty range extends past the grid, i.e. if
    /// `y_range.end > h` or `x_range.end > w`.
    pub fn sum(&self, y_range: Range<usize>, x_range: Range<usize>) -> T {
        let Range { start: y_start, end: y_end } = y_range;
        let Range { start: x_start, end: x_end } = x_range;
        if y_start >= y_end || x_start >= x_end {
            return T::default();
        }
        assert!(
            y_end <= self.h,
            "row range end {} exceeds grid height {}",
            y_end,
            self.h
        );
        assert!(
            x_end <= self.w,
            "column range end {} exceeds grid width {}",
            x_end,
            self.w
        );

        // Sum of the rectangle [0, y_end) x [0, x_end).
        let full = self.cum_sum[y_end - 1][x_end - 1];
        match (y_start.checked_sub(1), x_start.checked_sub(1)) {
            (None, None) => full,
            (Some(top), None) => full - self.cum_sum[top][x_end - 1],
            (None, Some(left)) => full - self.cum_sum[y_end - 1][left],
            (Some(top), Some(left)) => {
                // Order matters for unsigned `T`: subtracting the band above first,
                // then adding the corner back, then subtracting the band to the left
                // keeps every intermediate value within [0, full] for non-negative
                // grids. Adding the corner first could overflow (a panic in debug
                // builds) even when the final answer fits.
                full - self.cum_sum[top][x_end - 1] + self.cum_sum[top][left]
                    - self.cum_sum[y_end - 1][left]
            }
        }
    }
}
88
#[cfg(test)]
mod tests {
    use super::CumulativeSum2D;

    /// Checks every rectangle query, including empty and reversed ranges,
    /// against a brute-force summation over the same cells.
    #[test]
    fn test() {
        let grid: Vec<Vec<u32>> = vec![
            vec![3, 1, 4, 1, 5],
            vec![9, 2, 6, 5, 3],
            vec![5, 8, 9, 7, 9],
            vec![3, 2, 3, 8, 4],
        ];
        let table = CumulativeSum2D::new(&grid);
        let rows = grid.len();
        let cols = grid[0].len();

        for y_start in 0..=rows {
            for y_end in 0..=rows {
                for x_start in 0..=cols {
                    for x_end in 0..=cols {
                        // Reversed ranges iterate zero times, so they contribute 0.
                        let expected: u32 = (y_start..y_end)
                            .flat_map(|y| {
                                let row = &grid[y];
                                (x_start..x_end).map(move |x| row[x])
                            })
                            .sum();
                        let actual = table.sum(y_start..y_end, x_start..x_end);
                        assert_eq!(expected, actual);
                    }
                }
            }
        }
    }
}