1use std::{iter::FromIterator, ops};
2
/// Low 30 bits set; `mul` uses it to split the folded cross term.
const MASK30: u64 = 0x3FFF_FFFF;
/// Low 31 bits set; `mul` uses it to split each operand into high and low parts.
const MASK31: u64 = 0x7FFF_FFFF;
/// The Mersenne prime 2^61 - 1; every hash is reduced modulo this value.
const MOD: u64 = (1 << 61) - 1;
/// Low 61 bits set. Numerically identical to `MOD`, but named separately because
/// `calc_mod` uses it as a bit mask rather than as the modulus.
const MASK61: u64 = MOD;
/// 4 * MOD: a multiple of the modulus added in `RollingHash::hash` before subtracting,
/// so the unsigned subtraction cannot underflow.
const POSITIVIZER: u64 = MOD << 2;
/// Fixed polynomial base.
/// NOTE(review): a compile-time base lets an adversary construct collisions; consider a
/// randomized base if inputs are untrusted.
const BASE: u64 = 1_000_000_009;
9
/// Polynomial rolling hash over a sequence of `u64` values, modulo 2^61 - 1.
///
/// Construction is O(n); afterwards the hash of any contiguous window is O(1).
#[derive(Clone, Debug)]
pub struct RollingHash {
    /// The hashed sequence itself.
    xs: Vec<u64>,
    /// Prefix hashes: `hashes[i]` is the hash of `xs[..i]`; holds `xs.len() + 1` entries.
    hashes: Vec<u64>,
    /// Base powers: `pows[i]` is `BASE^i` modulo 2^61 - 1; holds `xs.len() + 1` entries.
    pows: Vec<u64>,
}
19
20impl<T> FromIterator<T> for RollingHash
21where
22 T: Into<u64>,
23{
24 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
25 let xs = iter.into_iter().map(|x| x.into()).collect::<Vec<_>>();
26 Self::new(&xs)
27 }
28}
29
30impl RollingHash {
31 pub fn new(xs: &[u64]) -> Self {
32 let n = xs.len();
33 let xs = xs.to_vec();
34 let mut hashes = vec![0; n + 1];
35 let mut pows = vec![1; n + 1];
36 for (i, &x) in xs.iter().enumerate() {
37 hashes[i + 1] = calc_mod(mul(hashes[i], BASE) + x);
39 pows[i + 1] = calc_mod(mul(pows[i], BASE));
41 }
42 Self { xs, hashes, pows }
43 }
44
45 pub fn len(&self) -> usize {
46 self.xs.len()
47 }
48
49 pub fn is_empty(&self) -> bool {
50 self.xs.is_empty()
51 }
52
53 pub fn at(&self, i: usize) -> u64 {
54 assert!(i < self.len());
55 self.xs[i]
56 }
57
58 pub fn hash(&self, range: ops::Range<usize>) -> u64 {
60 let l = range.start;
61 let r = range.end;
62 assert!(l <= r);
63 assert!(r <= self.hashes.len());
64 calc_mod(self.hashes[r] + POSITIVIZER - mul(self.hashes[l], self.pows[r - l]))
69 }
70
71 pub fn is_substring(&self, other: &Self) -> bool {
84 for j in 0..other.len() {
85 if j + self.len() > other.len() {
86 break;
87 }
88 if self.hash(0..self.len()) == other.hash(j..(j + self.len())) {
89 return true;
90 }
91 }
92 false
93 }
94}
95
96fn mul(a: u64, b: u64) -> u64 {
97 let au = a >> 31;
98 let ad = a & MASK31;
99 let bu = b >> 31;
100 let bd = b & MASK31;
101 let mid = ad * bu + au * bd;
102 let midu = mid >> 30;
103 let midd = mid & MASK30;
104 au * bu * 2 + midu + (midd << 31) + ad * bd
105}
106
107fn calc_mod(x: u64) -> u64 {
108 let xu = x >> 61;
109 let xd = x & MASK61;
110 let mut res = xu + xd;
111 if res >= MOD {
112 res -= MOD;
113 }
114 res
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash() {
        // "bc" sits at 1..3 in "abcd" and at 2..4 in "xxbcyy"; equal windows hash equally.
        let left: RollingHash = "abcd".bytes().collect();
        let right: RollingHash = "xxbcyy".bytes().collect();
        assert_eq!(left.hash(1..3), right.hash(2..4));
    }

    #[test]
    fn test_is_substring() {
        let needle: RollingHash = "xyz".bytes().collect();
        let haystack: RollingHash = "abcxyz".bytes().collect();
        assert!(needle.is_substring(&haystack));
    }
}