factorials/lib.rs
1/// 階乗とその乗法逆元、そして二項係数を扱います。
2pub struct Factorial {
3 factorial: Vec<u64>,
4 inversion_of_factorial: Vec<u64>,
5 modulo: u64,
6}
7
8impl Factorial {
9 /// `1` 以上 `size` 未満の `n` について、`n` の階乗 (mod `modulo`) と、その乗法逆元を O(`size`) 時間で計算します。[参考](https://drken1215.hatenablog.com/entry/2018/06/08/210000)
10 ///
11 /// 逆元を正しく計算するためには
12 ///
13 /// - `modulo` が素数
14 /// - `modulo >= size`
15 ///
16 /// である必要があります。
17 ///
18 /// # Examples
19 ///
20 /// ```
21 /// use factorials::Factorial;
22 ///
23 /// let modulo = 1_000_000_000 + 7;
24 /// let f = Factorial::new(100, modulo);
25 /// for i in 1..100 {
26 /// assert_eq!(f.factorial(i) * f.inversion(i) % modulo, 1);
27 /// }
28 /// ```
29 ///
30 /// # Panics
31 ///
32 /// `modulo` が `size` より小さい場合パニックです。
33 ///
34 /// ```should_panic
35 /// use factorials::Factorial;
36 ///
37 /// let size = 100;
38 /// let modulo = 97;
39 /// Factorial::new(size, modulo);
40 /// ```
41 pub fn new(size: usize, modulo: u64) -> Self {
42 assert!(modulo >= size as u64);
43 let mut fac = vec![0; size];
44 let mut inv = vec![0; size];
45 let mut inv_of_fac = vec![0; size];
46 fac[0] = 1;
47 fac[1] = 1;
48 inv[1] = 1;
49 inv_of_fac[0] = 1;
50 inv_of_fac[1] = 1;
51 for i in 2..size {
52 let i_u64 = i as u64;
53 fac[i] = fac[i - 1] * i_u64 % modulo;
54 inv[i] = ((modulo - inv[(modulo as usize) % i]) * (modulo / i_u64)).rem_euclid(modulo);
55 inv_of_fac[i] = inv_of_fac[i - 1] * inv[i] % modulo;
56 }
57 Self {
58 factorial: fac,
59 inversion_of_factorial: inv_of_fac,
60 modulo,
61 }
62 }
63
64 /// `modulo` が素数でない場合パニックです。素数判定に O(sqrt(`modulo`)) 時間かかります。
65 ///
66 /// # Panics
67 ///
68 /// ```should_panic
69 /// use factorials::Factorial;
70 ///
71 /// let modulo = 42;
72 /// Factorial::new_checking_modulo_prime(10, 42);
73 /// ```
74 pub fn new_checking_modulo_prime(size: usize, modulo: u64) -> Self {
75 assert!(
76 (2..modulo)
77 .take_while(|&x| x * x <= modulo)
78 .all(|x| modulo % x != 0)
79 );
80 Self::new(size, modulo)
81 }
82
83 pub fn factorial(&self, n: usize) -> u64 {
84 assert!(n < self.factorial.len());
85 self.factorial[n]
86 }
87
88 pub fn inversion(&self, n: usize) -> u64 {
89 assert!(n < self.inversion_of_factorial.len());
90 self.inversion_of_factorial[n]
91 }
92
93 /// 二項係数を返します。
94 ///
95 /// # Examples
96 ///
97 /// ```
98 /// use factorials::Factorial;
99 ///
100 /// let f = Factorial::new_checking_modulo_prime(5, 107);
101 /// assert_eq!(f.binomial(4, 0), 1);
102 /// assert_eq!(f.binomial(4, 1), 4);
103 /// assert_eq!(f.binomial(4, 2), 6);
104 /// assert_eq!(f.binomial(4, 3), 4);
105 /// assert_eq!(f.binomial(4, 4), 1);
106 /// ```
107 ///
108 /// # Panics
109 ///
110 /// 以下の少なくともひとつが成り立つ場合パニックです。
111 ///
112 /// - `n` が構築時の `size` 以上
113 /// - `k` が構築時の `size` 以上
114 /// - `n` が `k` より小さい
115 ///
116 /// ```should_panic
117 /// use factorials::Factorial;
118 ///
119 /// let f = Factorial::new_checking_modulo_prime(5, 107);
120 /// f.binomial(3, 4); // n < k
121 /// ```
122 pub fn binomial(&self, n: usize, k: usize) -> u64 {
123 assert!(n < self.factorial.len());
124 assert!(k < self.inversion_of_factorial.len());
125 assert!(n >= k);
126 self.factorial(n) * self.inversion(k) % self.modulo * self.inversion(n - k) % self.modulo
127 }
128
129 /// [`binomial`] とほとんど同じですが `n` が `k` より小さいときパニックせずに `0` を返します。
130 ///
131 /// ```
132 /// use factorials::Factorial;
133 ///
134 /// let f = Factorial::new_checking_modulo_prime(5, 107);
135 /// assert_eq!(f.binomial_or_zero(3, 4), 0);
136 /// ```
137 ///
138 /// [`binomial`]: struct.Factorial.html#method.binomial
139 pub fn binomial_or_zero(&self, n: usize, k: usize) -> u64 {
140 assert!(n < self.factorial.len());
141 assert!(k < self.inversion_of_factorial.len());
142 if n < k {
143 return 0;
144 }
145 self.binomial(n, k)
146 }
147}
148
149#[cfg(test)]
150mod tests {
151 use super::Factorial;
152 #[test]
153 fn test_mod_is_103() {
154 let p = 103;
155 let f = Factorial::new(100, p);
156 for i in 1..100 {
157 assert_eq!(f.factorial(i) * f.inversion(i) % p, 1);
158 }
159 }
160
161 #[test]
162 fn test_binomial() {
163 let f = Factorial::new(6, 1_000_000_000 + 7);
164 let b: Vec<Vec<u64>> = (0..6)
165 .map(|n| (0..6).map(|k| f.binomial_or_zero(n, k)).collect())
166 .collect();
167 assert_eq!(
168 b,
169 vec![
170 vec![1, 0, 0, 0, 0, 0],
171 vec![1, 1, 0, 0, 0, 0],
172 vec![1, 2, 1, 0, 0, 0],
173 vec![1, 3, 3, 1, 0, 0],
174 vec![1, 4, 6, 4, 1, 0],
175 vec![1, 5, 10, 10, 5, 1],
176 ]
177 )
178 }
179}