rolling_hash/
lib.rs

1use std::{iter::FromIterator, ops};
2
/// Low 30 bits set; `mul` uses it to split the cross term of the product.
const MASK30: u64 = 0x3FFF_FFFF;
/// Low 31 bits set; `mul` uses it to split each operand into high and low halves.
const MASK31: u64 = 0x7FFF_FFFF;
/// The modulus of the hash: the Mersenne prime 2^61 - 1.
const MOD: u64 = (1 << 61) - 1;
/// Low 61 bits set; numerically equal to `MOD` and used by `calc_mod` as a bit mask.
const MASK61: u64 = MOD;
/// A multiple of `MOD` that is added before a subtraction so that the
/// intermediate value does not go below zero.
const POSITIVIZER: u64 = 4 * MOD;
/// Base of the polynomial hash.
const BASE: u64 = 1_000_000_009;
9
/// A rolling hash over a sequence of `u64` values.
///
/// After a precomputation that is linear in the length of the sequence, the
/// hash of any contiguous sub-sequence can be computed in O(1).
///
/// [Implementation reference](https://qiita.com/keymoon/items/11fac5627672a6d6a9f6)
#[derive(Clone, Debug)]
pub struct RollingHash {
    /// Copy of the hashed sequence.
    xs: Vec<u64>,
    /// Prefix hashes: `hashes[i]` is the hash of the first `i` elements, so
    /// this vector is one element longer than `xs`.
    hashes: Vec<u64>,
    /// Powers of the hash base modulo 2^61 - 1: `pows[i]` is `BASE^i`.
    pows: Vec<u64>,
}
19
20impl<T> FromIterator<T> for RollingHash
21where
22    T: Into<u64>,
23{
24    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
25        let xs = iter.into_iter().map(|x| x.into()).collect::<Vec<_>>();
26        Self::new(&xs)
27    }
28}
29
30impl RollingHash {
31    pub fn new(xs: &[u64]) -> Self {
32        let n = xs.len();
33        let xs = xs.to_vec();
34        let mut hashes = vec![0; n + 1];
35        let mut pows = vec![1; n + 1];
36        for (i, &x) in xs.iter().enumerate() {
37            // hashes[i + 1] = hashes[i] * BASE + x
38            hashes[i + 1] = calc_mod(mul(hashes[i], BASE) + x);
39            // pows[i + 1] = pows[i] * BASE
40            pows[i + 1] = calc_mod(mul(pows[i], BASE));
41        }
42        Self { xs, hashes, pows }
43    }
44
45    pub fn len(&self) -> usize {
46        self.xs.len()
47    }
48
49    pub fn is_empty(&self) -> bool {
50        self.xs.is_empty()
51    }
52
53    pub fn at(&self, i: usize) -> u64 {
54        assert!(i < self.len());
55        self.xs[i]
56    }
57
58    /// 部分文字列のハッシュ値を返します。
59    pub fn hash(&self, range: ops::Range<usize>) -> u64 {
60        let l = range.start;
61        let r = range.end;
62        assert!(l <= r);
63        assert!(r <= self.hashes.len());
64        // hashes[r] - hashes[l] * pows[r - l]
65        // = (xs[0] * BASE ^ (r - 1) + xs[1] * BASE ^ (r - 2) + ... + xs[r - 1])
66        //   - (xs[0] * BASE ^ (l - 1) + xs[1] * BASE ^ (l - 2) + ... + xs[l - 1]) * BASE ^ (r - l)
67        // = xs[l] * BASE ^ (r - l - 1) + xs[l + 1] * BASE ^ (r - l - 2) + ... + xs[r - 1]
68        calc_mod(self.hashes[r] + POSITIVIZER - mul(self.hashes[l], self.pows[r - l]))
69    }
70
71    /// self が other の部分文字列かどうかを返します。
72    ///
73    /// O(other.len())
74    ///
75    /// # Examples
76    /// ```
77    /// use rolling_hash::RollingHash;
78    /// let rh1 = RollingHash::from_iter("abcd".bytes());
79    /// let rh2 = RollingHash::from_iter("xxabcdyy".bytes());
80    /// assert!(rh1.is_substring(&rh2));
81    /// ```
82    // 出現位置をすべて返すようにしたほうがいいかも
83    pub fn is_substring(&self, other: &Self) -> bool {
84        for j in 0..other.len() {
85            if j + self.len() > other.len() {
86                break;
87            }
88            if self.hash(0..self.len()) == other.hash(j..(j + self.len())) {
89                return true;
90            }
91        }
92        false
93    }
94}
95
96fn mul(a: u64, b: u64) -> u64 {
97    let au = a >> 31;
98    let ad = a & MASK31;
99    let bu = b >> 31;
100    let bd = b & MASK31;
101    let mid = ad * bu + au * bd;
102    let midu = mid >> 30;
103    let midd = mid & MASK30;
104    au * bu * 2 + midu + (midd << 31) + ad * bd
105}
106
107fn calc_mod(x: u64) -> u64 {
108    let xu = x >> 61;
109    let xd = x & MASK61;
110    let mut res = xu + xd;
111    if res >= MOD {
112        res -= MOD;
113    }
114    res
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal sub-sequences of different inputs must produce equal hashes.
    #[test]
    fn test_hash() {
        let left: RollingHash = "abcd".bytes().collect();
        let right: RollingHash = "xxbcyy".bytes().collect();

        let bc_in_left = left.hash(1..3); // a"bc"d
        let bc_in_right = right.hash(2..4); // xx"bc"yy
        assert_eq!(bc_in_left, bc_in_right);
    }

    /// A needle located at the very end of the haystack is found.
    #[test]
    fn test_is_substring() {
        let needle: RollingHash = "xyz".bytes().collect();
        let haystack: RollingHash = "abcxyz".bytes().collect();
        assert!(needle.is_substring(&haystack));
    }
}