mod_int/
lib.rs

//! `ModInt` is a struct that performs the four basic arithmetic operations on integers modulo `p`.
//!
//! ```
//! use mod_int::ModInt1000000007;
//! let p = 1000000007_i64;
//! let (a, b, c) = (1000000001, 1000000005, 100000006);
//! let x = (123 * a % p * b % p - c).rem_euclid(p);
//! let y = ModInt1000000007::new(123) * a * b - c;
//! assert_eq!(x, y.val());
//! ```

12use std::fmt::Debug;
13use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
14
15use ext_gcd::ext_gcd;
16
/// An integer modulo `M`, always stored in canonical form `0 <= x < M`.
///
/// Because the representative is canonical, the derived `PartialEq`/`Eq`/`Hash`
/// coincide with equality modulo `M`, and `Default` yields the residue `0`.
///
/// `M` is expected to be positive and small enough that `(M - 1) * (M - 1)` fits
/// in `i64`, because products are formed in `i64` before being reduced.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ModInt<const M: i64>(i64);

20impl<const M: i64> ModInt<M> {
21    /// 整数を `0 <= x < modulo` に正規化してインスタンスを作ります。
22    pub fn new(x: i64) -> Self {
23        if 0 <= x && x < M {
24            Self::new_raw(x)
25        } else {
26            Self::new_raw(x.rem_euclid(M))
27        }
28    }
29
30    fn new_raw(x: i64) -> Self {
31        debug_assert!(0 <= x && x < M);
32        Self(x)
33    }
34
35    /// `ModInt` に格納されている値を返します。
36    ///
37    /// # Examples
38    /// ```
39    /// use mod_int::ModInt1000000007;
40    /// assert_eq!(ModInt1000000007::new(123).val(), 123);
41    /// ```
42    pub fn val(self) -> i64 {
43        self.0
44    }
45
46    /// 法を返します。
47    ///
48    /// # Examples
49    /// ```
50    /// use mod_int::{ModInt1000000007, ModInt998244353};
51    /// assert_eq!(ModInt1000000007::modulo(), 1000000007);
52    /// assert_eq!(ModInt998244353::modulo(), 998244353);
53    /// ```
54    pub fn modulo() -> i64 {
55        M
56    }
57
58    /// 二分累乗法で `x^exp % p` を計算します。
59    ///
60    /// # Examples
61    /// ```
62    /// use mod_int::ModInt1000000007;
63    /// use std::iter::repeat;
64    /// let (x, exp, p) = (123, 100_u32, 1000000007);
65    /// let y = repeat(x).take(exp as usize).fold(1, |acc, x| acc * x % p);
66    /// assert_eq!(y, ModInt1000000007::new(x).pow(exp).val());
67    /// ```
68    pub fn pow(self, exp: u32) -> Self {
69        let mut res = 1;
70        let mut base = self.0;
71        let mut exp = exp;
72        while exp > 0 {
73            if exp & 1 == 1 {
74                res *= base;
75                res %= M;
76            }
77            base *= base;
78            base %= M;
79            exp >>= 1;
80        }
81        Self::new_raw(res)
82    }
83
84    /// `x * y % p = 1` となる `y` を返します。
85    ///
86    /// # Examples
87    /// ```
88    /// use mod_int::ModInt1000000007;
89    /// let (x, p) = (2, ModInt1000000007::modulo());
90    /// let y = ModInt1000000007::new(x).inv().val();
91    /// assert_eq!(x * y % p, 1);
92    /// ```
93    ///
94    /// ```should_panic
95    /// use mod_int::ModInt1000000007;
96    /// ModInt1000000007::new(0).inv(); // panic
97    /// ```
98    ///
99    /// ```should_panic
100    /// use mod_int::ModInt;
101    /// // 6 * n % 10 : 0, 6, 2, 8, 4, 0, 6, 2, 8, 4
102    /// ModInt::<10>::new(6).inv(); // panic
103    /// ```
104    pub fn inv(self) -> Self {
105        assert_ne!(self.0, 0, "Don't divide by zero!");
106        let (x, _, g) = ext_gcd(self.0, M);
107        assert_eq!(g, 1, "{} is not prime!", M);
108        Self::new(x)
109    }
110}
111
112impl<const M: i64, T: Into<ModInt<M>>> AddAssign<T> for ModInt<M> {
113    fn add_assign(&mut self, rhs: T) {
114        self.0 += rhs.into().0;
115        debug_assert!(0 <= self.0 && self.0 <= (M - 1) * 2);
116        if self.0 >= M {
117            self.0 -= M;
118        }
119    }
120}
121
122impl<const M: i64, T: Into<ModInt<M>>> Add<T> for ModInt<M> {
123    type Output = ModInt<M>;
124    fn add(self, rhs: T) -> Self::Output {
125        let mut result = self;
126        result += rhs.into();
127        result
128    }
129}
130
131impl<const M: i64, T: Into<ModInt<M>>> SubAssign<T> for ModInt<M> {
132    fn sub_assign(&mut self, rhs: T) {
133        self.0 -= rhs.into().0;
134        debug_assert!(-(M - 1) <= self.0 && self.0 < M);
135        if self.0 < 0 {
136            self.0 += M;
137        }
138    }
139}
140
141impl<const M: i64, T: Into<ModInt<M>>> Sub<T> for ModInt<M> {
142    type Output = ModInt<M>;
143    fn sub(self, rhs: T) -> Self::Output {
144        let mut result = self;
145        result -= rhs.into();
146        result
147    }
148}
149
150impl<const M: i64, T: Into<ModInt<M>>> MulAssign<T> for ModInt<M> {
151    fn mul_assign(&mut self, rhs: T) {
152        self.0 *= rhs.into().0;
153        if self.0 >= M {
154            self.0 %= M;
155        }
156    }
157}
158
159impl<const M: i64, T: Into<ModInt<M>>> Mul<T> for ModInt<M> {
160    type Output = ModInt<M>;
161    fn mul(self, rhs: T) -> Self::Output {
162        let mut result = self;
163        result *= rhs.into();
164        result
165    }
166}
167
168#[allow(clippy::suspicious_op_assign_impl)]
169impl<const M: i64, T: Into<ModInt<M>>> DivAssign<T> for ModInt<M> {
170    fn div_assign(&mut self, rhs: T) {
171        *self *= rhs.into().inv();
172    }
173}
174
175impl<const M: i64, T: Into<ModInt<M>>> Div<T> for ModInt<M> {
176    type Output = ModInt<M>;
177    fn div(self, rhs: T) -> Self::Output {
178        let mut result = self;
179        result /= rhs.into();
180        result
181    }
182}
183
184macro_rules! impl_from_int {
185    ($($t:ty),+) => {
186        $(
187            impl<const M: i64> From<$t> for ModInt<M> {
188                fn from(x: $t) -> Self {
189                    Self::new(i64::from(x))
190                }
191            }
192        )+
193    };
194}
195
196impl_from_int!(i8, i16, i32, i64, u8, u16, u32);
197
198macro_rules! impl_from_large_int {
199    ($($t:ty),+) => {
200        $(
201            impl<const M: i64> From<$t> for ModInt<M> {
202                fn from(x: $t) -> Self {
203                    Self::new((x % (M as $t)) as i64)
204                }
205            }
206        )+
207    };
208}
209
210impl_from_large_int!(u64, usize, isize);
211
212pub type ModInt1000000007 = ModInt<1_000_000_007>;
213pub type ModInt998244353 = ModInt<998_244_353>;
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every operator (and its `*Assign` form) against plain `i64`
    /// arithmetic for all operand pairs in `0..50`, with modulus 19.
    #[test]
    fn ops_test() {
        type Mint = ModInt<19>;
        let p = Mint::modulo();

        let pairs = (0..50_i64).flat_map(|lhs| (0..50_i64).map(move |rhs| (lhs, rhs)));
        for (lhs, rhs) in pairs {
            // Addition, by value and in place.
            let sum = (lhs + rhs) % p;
            assert_eq!((Mint::new(lhs) + Mint::new(rhs)).val(), sum);
            let mut acc = Mint::new(lhs);
            acc += rhs;
            assert_eq!(acc.val(), sum);

            // Subtraction, by value and in place.
            let diff = (lhs - rhs).rem_euclid(p);
            assert_eq!((Mint::new(lhs) - Mint::new(rhs)).val(), diff);
            let mut acc = Mint::new(lhs);
            acc -= rhs;
            assert_eq!(acc.val(), diff);

            // Multiplication, by value and in place.
            let prod = lhs * rhs % p;
            assert_eq!((Mint::new(lhs) * Mint::new(rhs)).val(), prod);
            let mut acc = Mint::new(lhs);
            acc *= rhs;
            assert_eq!(acc.val(), prod);

            // Division is only defined for divisors that are non-zero mod p.
            if rhs % p == 0 {
                continue;
            }
            // Brute-force the quotient q with rhs * q == lhs (mod p).
            let quot = (0..p).find(|&q| lhs % p == rhs * q % p).unwrap();
            assert_eq!((Mint::new(lhs) / Mint::new(rhs)).val(), quot);
            let mut acc = Mint::new(lhs);
            acc /= rhs;
            assert_eq!(acc.val(), quot);
        }
    }
}