rolling_hash/
lib.rs

1use std::{iter::FromIterator, ops};
2
/// Mask selecting the low 30 bits of a word.
const MASK30: u64 = (1 << 30) - 1;
/// Mask selecting the low 31 bits of a word.
const MASK31: u64 = (1 << 31) - 1;
/// The Mersenne prime 2^61 - 1; all hash arithmetic is done modulo this number.
const MOD: u64 = (1 << 61) - 1;
/// Mask selecting the low 61 bits of a word (numerically equal to `MOD`).
const MASK61: u64 = MOD;
/// A multiple of `MOD` added before the subtraction in `RollingHash::hash` so that,
/// for reduced operands, the difference stays non-negative.
const POSITIVIZER: u64 = 4 * MOD;
/// Base used by the `FromIterator` constructor.
const DEFAULT_BASE: u64 = 1_000_000_009;
9
/// Rolling hash over a sequence of `u64` values. After an O(length) precomputation,
/// the hash of any contiguous subsequence can be computed in O(1).
///
/// Hashes are polynomials in `BASE` taken modulo the Mersenne prime 2^61 - 1.
///
/// [Implementation reference (Japanese)](https://qiita.com/keymoon/items/11fac5627672a6d6a9f6)
#[derive(Debug, Clone)]
pub struct RollingHash<const BASE: u64> {
    /// Copy of the input sequence.
    xs: Vec<u64>,
    /// `hashes[i]` is the hash of the prefix `xs[..i]`; `hashes[0] == 0`.
    hashes: Vec<u64>,
    /// `pows[i]` is `BASE^i` modulo 2^61 - 1; `pows[0] == 1`.
    pows: Vec<u64>,
}
19
20impl<const BASE: u64> RollingHash<BASE> {
21    pub fn new(xs: &[u64]) -> Self {
22        let n = xs.len();
23        let xs = xs.to_vec();
24        let mut hashes = vec![0; n + 1];
25        let mut pows = vec![1; n + 1];
26        for (i, &x) in xs.iter().enumerate() {
27            // hashes[i + 1] = hashes[i] * BASE + x
28            hashes[i + 1] = calc_mod(mul(hashes[i], BASE) + x);
29            // pows[i + 1] = pows[i] * BASE
30            pows[i + 1] = calc_mod(mul(pows[i], BASE));
31        }
32        Self { xs, hashes, pows }
33    }
34
35    pub fn len(&self) -> usize {
36        self.xs.len()
37    }
38
39    pub fn is_empty(&self) -> bool {
40        self.xs.is_empty()
41    }
42
43    pub fn base(&self) -> u64 {
44        BASE
45    }
46
47    pub fn at(&self, i: usize) -> u64 {
48        assert!(i < self.len());
49        self.xs[i]
50    }
51
52    /// 部分文字列のハッシュ値を返します。
53    pub fn hash(&self, range: ops::Range<usize>) -> u64 {
54        let l = range.start;
55        let r = range.end;
56        assert!(l <= r);
57        assert!(r <= self.hashes.len());
58        // hashes[r] - hashes[l] * pows[r - l]
59        // = (xs[0] * BASE ^ (r - 1) + xs[1] * BASE ^ (r - 2) + ... + xs[r - 1])
60        //   - (xs[0] * BASE ^ (l - 1) + xs[1] * BASE ^ (l - 2) + ... + xs[l - 1]) * BASE ^ (r - l)
61        // = xs[l] * BASE ^ (r - l - 1) + xs[l + 1] * BASE ^ (r - l - 2) + ... + xs[r - 1]
62        calc_mod(self.hashes[r] + POSITIVIZER - mul(self.hashes[l], self.pows[r - l]))
63    }
64
65    pub fn substring(&self, range: ops::Range<usize>) -> Substring<BASE> {
66        let len = range.end - range.start;
67        let hash = self.hash(range);
68        Substring::new(hash, len)
69    }
70
71    pub fn position(&self, sub: &Substring<BASE>) -> Option<usize> {
72        if sub.len > self.len() {
73            return None;
74        }
75        (0..=self.len() - sub.len).find(|&i| self.hash(i..(i + sub.len)) == sub.hash)
76    }
77}
78
/// A substring summarized by its hash value and its length.
///
/// The derived `PartialEq` compares both fields, so substrings of different
/// lengths never compare equal even if their hash values coincide.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Substring<const BASE: u64> {
    /// Hash value, e.g. as returned by `RollingHash::hash`.
    hash: u64,
    /// Number of elements in the substring.
    len: usize,
}
84
85impl<const BASE: u64> Substring<BASE> {
86    pub fn new(hash: u64, len: usize) -> Self {
87        Self { hash, len }
88    }
89
90    pub fn hash(&self) -> u64 {
91        self.hash
92    }
93
94    pub fn len(&self) -> usize {
95        self.len
96    }
97
98    pub fn is_empty(&self) -> bool {
99        self.len == 0
100    }
101}
102
103impl<T> FromIterator<T> for RollingHash<DEFAULT_BASE>
104where
105    T: Into<u64>,
106{
107    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
108        let xs = iter.into_iter().map(|x| x.into()).collect::<Vec<_>>();
109        Self::new(&xs)
110    }
111}
112
113fn mul(a: u64, b: u64) -> u64 {
114    let au = a >> 31;
115    let ad = a & MASK31;
116    let bu = b >> 31;
117    let bd = b & MASK31;
118    let mid = ad * bu + au * bd;
119    let midu = mid >> 30;
120    let midd = mid & MASK30;
121    au * bu * 2 + midu + (midd << 31) + ad * bd
122}
123
124fn calc_mod(x: u64) -> u64 {
125    let xu = x >> 61;
126    let xd = x & MASK61;
127    let mut res = xu + xd;
128    if res >= MOD {
129        res -= MOD;
130    }
131    res
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference implementation: the polynomial hash of `xs` in base `base`,
    /// computed directly in 128-bit arithmetic modulo 2^61 - 1.
    fn naive_hash(xs: &[u64], base: u64) -> u64 {
        let m = MOD as u128;
        xs.iter()
            .fold(0u128, |acc, &x| (acc * base as u128 + x as u128) % m) as u64
    }

    #[test]
    fn test_hash() {
        let rh1 = RollingHash::from_iter("abcd".bytes());
        let rh2 = RollingHash::from_iter("xxbcyy".bytes());
        assert_eq!(
            rh1.hash(1..3), // a"bc"d
            rh2.hash(2..4), // xx"bc"yy
        );
    }

    #[test]
    fn test_hash_matches_naive() {
        // Pins actual hash values, not just equality between two computed hashes,
        // so a broken `mul`/`calc_mod` cannot go unnoticed.
        let bytes = "hello, world".bytes().map(u64::from).collect::<Vec<_>>();
        let rh = RollingHash::<DEFAULT_BASE>::new(&bytes);
        for l in 0..=bytes.len() {
            for r in l..=bytes.len() {
                assert_eq!(rh.hash(l..r), naive_hash(&bytes[l..r], DEFAULT_BASE));
            }
        }
    }

    #[test]
    fn test_with_base() {
        let rh1 = RollingHash::<1_000_000_007>::new(&[1, 2, 3]);
        let rh2 = RollingHash::<998_244_353>::new(&[1, 2, 3]);

        assert_ne!(rh1.hash(0..3), rh2.hash(0..3));
        assert_eq!(rh1.hash(0..3), naive_hash(&[1, 2, 3], 1_000_000_007));
        assert_eq!(rh2.hash(0..3), naive_hash(&[1, 2, 3], 998_244_353));
    }

    #[test]
    fn test_empty() {
        let rh = RollingHash::<DEFAULT_BASE>::new(&[]);
        assert!(rh.is_empty());
        assert_eq!(rh.len(), 0);
        assert_eq!(rh.hash(0..0), 0);
    }

    #[test]
    fn test_position() {
        let rh = RollingHash::from_iter("abcabc".bytes());
        let sub = rh.substring(1..4); // "bca"
        assert_eq!(rh.position(&sub), Some(1));
    }

    #[test]
    fn test_position_pattern_longer_than_text() {
        let short = RollingHash::from_iter("abc".bytes());
        let long = RollingHash::from_iter("abcabc".bytes());
        assert_eq!(short.position(&long.substring(0..6)), None);
    }
}