1use std::fmt::Debug;
13use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
14
15use ext_gcd::ext_gcd;
16
/// An integer reduced modulo the compile-time constant `M`.
///
/// The wrapped `i64` is kept in the canonical range `0..M` by the constructors and
/// arithmetic operators defined in this file. Because that representation is
/// canonical, the derived `PartialEq`/`Eq`/`Hash` compare residues rather than the
/// raw inputs the values were built from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ModInt<const M: i64>(i64);
19
20impl<const M: i64> ModInt<M> {
21 pub fn new(x: i64) -> Self {
23 if 0 <= x && x < M {
24 Self::new_raw(x)
25 } else {
26 Self::new_raw(x.rem_euclid(M))
27 }
28 }
29
30 fn new_raw(x: i64) -> Self {
31 debug_assert!(0 <= x && x < M);
32 Self(x)
33 }
34
35 pub fn val(self) -> i64 {
43 self.0
44 }
45
46 pub fn modulo() -> i64 {
55 M
56 }
57
58 pub fn pow(self, exp: u32) -> Self {
69 let mut res = 1;
70 let mut base = self.0;
71 let mut exp = exp;
72 while exp > 0 {
73 if exp & 1 == 1 {
74 res *= base;
75 res %= M;
76 }
77 base *= base;
78 base %= M;
79 exp >>= 1;
80 }
81 Self::new_raw(res)
82 }
83
84 pub fn inv(self) -> Self {
105 assert_ne!(self.0, 0, "Don't divide by zero!");
106 let (x, _, g) = ext_gcd(self.0, M);
107 assert_eq!(g, 1, "{} is not prime!", M);
108 Self::new(x)
109 }
110}
111
112impl<const M: i64, T: Into<ModInt<M>>> AddAssign<T> for ModInt<M> {
113 fn add_assign(&mut self, rhs: T) {
114 self.0 += rhs.into().0;
115 debug_assert!(0 <= self.0 && self.0 <= (M - 1) * 2);
116 if self.0 >= M {
117 self.0 -= M;
118 }
119 }
120}
121
122impl<const M: i64, T: Into<ModInt<M>>> Add<T> for ModInt<M> {
123 type Output = ModInt<M>;
124 fn add(self, rhs: T) -> Self::Output {
125 let mut result = self;
126 result += rhs.into();
127 result
128 }
129}
130
131impl<const M: i64, T: Into<ModInt<M>>> SubAssign<T> for ModInt<M> {
132 fn sub_assign(&mut self, rhs: T) {
133 self.0 -= rhs.into().0;
134 debug_assert!(-(M - 1) <= self.0 && self.0 < M);
135 if self.0 < 0 {
136 self.0 += M;
137 }
138 }
139}
140
141impl<const M: i64, T: Into<ModInt<M>>> Sub<T> for ModInt<M> {
142 type Output = ModInt<M>;
143 fn sub(self, rhs: T) -> Self::Output {
144 let mut result = self;
145 result -= rhs.into();
146 result
147 }
148}
149
150impl<const M: i64, T: Into<ModInt<M>>> MulAssign<T> for ModInt<M> {
151 fn mul_assign(&mut self, rhs: T) {
152 self.0 *= rhs.into().0;
153 if self.0 >= M {
154 self.0 %= M;
155 }
156 }
157}
158
159impl<const M: i64, T: Into<ModInt<M>>> Mul<T> for ModInt<M> {
160 type Output = ModInt<M>;
161 fn mul(self, rhs: T) -> Self::Output {
162 let mut result = self;
163 result *= rhs.into();
164 result
165 }
166}
167
168#[allow(clippy::suspicious_op_assign_impl)]
169impl<const M: i64, T: Into<ModInt<M>>> DivAssign<T> for ModInt<M> {
170 fn div_assign(&mut self, rhs: T) {
171 *self *= rhs.into().inv();
172 }
173}
174
175impl<const M: i64, T: Into<ModInt<M>>> Div<T> for ModInt<M> {
176 type Output = ModInt<M>;
177 fn div(self, rhs: T) -> Self::Output {
178 let mut result = self;
179 result /= rhs.into();
180 result
181 }
182}
183
184macro_rules! impl_from_int {
185 ($($t:ty),+) => {
186 $(
187 impl<const M: i64> From<$t> for ModInt<M> {
188 fn from(x: $t) -> Self {
189 Self::new(i64::from(x))
190 }
191 }
192 )+
193 };
194}
195
196impl_from_int!(i8, i16, i32, i64, u8, u16, u32);
197
198macro_rules! impl_from_large_int {
199 ($($t:ty),+) => {
200 $(
201 impl<const M: i64> From<$t> for ModInt<M> {
202 fn from(x: $t) -> Self {
203 Self::new((x % (M as $t)) as i64)
204 }
205 }
206 )+
207 };
208}
209
210impl_from_large_int!(u64, usize, isize);
211
/// `ModInt` with the modulus fixed to 1_000_000_007.
pub type ModInt1000000007 = ModInt<1_000_000_007>;
/// `ModInt` with the modulus fixed to 998_244_353.
pub type ModInt998244353 = ModInt<998_244_353>;
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `+`, `-`, `*`, `/` and their assigning forms against plain `i64`
    /// arithmetic for every operand pair in `0..50`, modulo 19.
    #[test]
    fn ops_test() {
        const P: i64 = 19;
        type Mint = ModInt<{ P }>;

        let pairs = (0..50_i64).flat_map(|a| (0..50_i64).map(move |b| (a, b)));
        for (a, b) in pairs {
            let (lhs, rhs) = (Mint::new(a), Mint::new(b));

            // Addition.
            let want = (a + b) % P;
            assert_eq!((lhs + rhs).val(), want);
            let mut acc = lhs;
            acc += b;
            assert_eq!(acc.val(), want);

            // Subtraction.
            let want = (a - b).rem_euclid(P);
            assert_eq!((lhs - rhs).val(), want);
            let mut acc = lhs;
            acc -= b;
            assert_eq!(acc.val(), want);

            // Multiplication.
            let want = a * b % P;
            assert_eq!((lhs * rhs).val(), want);
            let mut acc = lhs;
            acc *= b;
            assert_eq!(acc.val(), want);

            // Division is only defined when the divisor is non-zero modulo P.
            if b % P == 0 {
                continue;
            }
            let want = (0..P).find(|&q| a % P == b * q % P).unwrap();
            assert_eq!((lhs / rhs).val(), want);
            let mut acc = lhs;
            acc /= b;
            assert_eq!(acc.val(), want);
        }
    }
}