Skip to main content

doubling/
lib.rs

1/// ダブリング
2///
3/// # Examples
4///
5/// ```
6/// use doubling::{Doubling, Transition, Value};
7///
8/// #[derive(Debug, PartialEq)]
9/// struct Sum(i64);
10///
11/// impl Value for Sum {
12///     fn op(&self, other: &Self) -> Self {
13///         Sum(self.0 + other.0)
14///     }
15/// }
16///
17/// struct E {
18///     to: usize,
19///     value: i64,
20/// }
21///
22/// // 0, 1, 2, 0, 1, 2, ...
23/// let n = 3;
24/// let to = vec![
25///     E { to: 1, value: 1 },
26///     E { to: 2, value: 10 },
27///     E { to: 0, value: 100 },
28/// ];
29/// let doubling = Doubling::new(n, 100, |i| {
30///     let e = &to[i];
31///     Transition::new(e.to, Sum(e.value))
32/// });
33///
34/// assert_eq!(
35///     doubling.fold(0, 4, Sum(0), |acc, t| Sum(acc.0 + t.value.0)),
36///     // 0 -> 1 -> 2 -> 0 -> 1
37///     Sum(1 + 10 + 100 + 1)
38/// );
39/// ```
40#[derive(Debug, Clone)]
41pub struct Doubling<V> {
42    transitions: Vec<Transition<V>>,
43    n_state: usize,
44    max_steps: usize,
45    log2_max_steps: usize,
46}
47
/// A transition out of a state: the state it leads to and the value carried along the way.
///
/// `step1` in [`Doubling::new`] returns single-step transitions, [`Doubling::get`] returns
/// `2^k`-step transitions, and the folding closure of [`Doubling::fold`] receives them.
///
/// `PartialEq`/`Eq` are derived (when `V` supports them) so transitions obtained from the
/// table can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Transition<V> {
    /// State reached after the transition.
    pub next: usize,
    /// Value accumulated along the transition.
    pub value: V,
}

impl<V> Transition<V> {
    /// Creates a transition to state `next` carrying `value`.
    pub fn new(next: usize, value: V) -> Self {
        Self { next, value }
    }
}
59
/// Binary operation used to combine the values of consecutive transitions.
///
/// When [`Doubling::new`] joins a transition `first` with the transition `second` that
/// follows it, the combined value is `first.value.op(&second.value)`, so operands always
/// appear in path order and `op` need not be commutative. It should be associative
/// (e.g. integer addition, `min`/`max`, string concatenation); otherwise a stored value
/// depends on how the path was split into halves.
pub trait Value {
    /// Combines `self` (the earlier part of a path) with `other` (the later part).
    fn op(&self, other: &Self) -> Self;
}

/// Trivial value carrying no information, for tables where only the reached state matters.
impl Value for () {
    fn op(&self, _rhs: &()) {}
}
67
68impl<V> Doubling<V>
69where
70    V: Value,
71{
72    /// ダブリングのテーブルを構築します。
73    ///
74    /// `step1(i)`は状態`i`から1回の遷移における
75    ///
76    /// - 遷移先の状態
77    /// - その遷移にともなう値
78    ///
79    /// を返す関数。
80    pub fn new<F>(n_state: usize, max_steps: usize, step1: F) -> Self
81    where
82        F: Fn(usize) -> Transition<V>,
83    {
84        assert!(max_steps > 0);
85
86        let log2_max_steps = max_steps.ilog2() as usize;
87
88        let mut transitions = Vec::with_capacity(n_state * (log2_max_steps + 1));
89        for i in 0..n_state {
90            let t = step1(i);
91
92            assert!(t.next < n_state);
93
94            transitions.push(t);
95        }
96
97        for k in 1..=log2_max_steps {
98            let offset = n_state * (k - 1);
99            for i in 0..n_state {
100                let t1 = &transitions[offset + i];
101                let t2 = &transitions[offset + t1.next];
102                transitions.push(Transition {
103                    next: t2.next,
104                    value: t1.value.op(&t2.value),
105                });
106            }
107        }
108
109        Self {
110            transitions,
111            n_state,
112            max_steps,
113            log2_max_steps,
114        }
115    }
116
117    /// 状態`start`から長さ`pow(2, k)`の遷移を返します。
118    pub fn get(&self, start: usize, k: usize) -> &Transition<V> {
119        assert!(start < self.n_state);
120        assert!(k <= self.log2_max_steps);
121
122        let offset = self.n_state * k;
123        &self.transitions[offset + start]
124    }
125
126    /// 状態`start`から`step`回の遷移、初期値`init`から始めて`f`で畳みこんだ結果を返します。
127    pub fn fold<A, F>(&self, start: usize, step: usize, init: A, mut f: F) -> A
128    where
129        F: FnMut(A, &Transition<V>) -> A,
130    {
131        assert!(start < self.n_state);
132        assert!(step <= self.max_steps);
133
134        let mut i = start;
135        let mut acc = init;
136        for k in 0..=self.log2_max_steps {
137            if step >> k & 1 == 1 {
138                let offset = self.n_state * k;
139                let t = &self.transitions[offset + i];
140                (i, acc) = (t.next, f(acc, t));
141            }
142        }
143
144        acc
145    }
146}
147
#[cfg(test)]
mod tests {
    use ::proptest::{collection, prelude::*};

    use super::*;

    /// Additive value: composing transitions sums their weights.
    #[derive(Debug, PartialEq)]
    struct Sum(i64);

    impl Value for Sum {
        fn op(&self, other: &Self) -> Self {
            Sum(self.0 + other.0)
        }
    }

    /// Concatenation: order-sensitive, so operands combined in the wrong order are caught.
    impl Value for String {
        fn op(&self, other: &Self) -> Self {
            format!("{}{}", self, other)
        }
    }

    #[test]
    fn test_cycle() {
        struct E {
            to: usize,
            value: i64,
        }

        // 0, 1, 2, 0, 1, 2, ...
        let n = 3;
        let to = vec![
            E { to: 1, value: 1 },
            E { to: 2, value: 10 },
            E { to: 0, value: 100 },
        ];
        let doubling = Doubling::new(n, 100, |i| {
            let e = &to[i];
            Transition::new(e.to, Sum(e.value))
        });

        assert_eq!(
            doubling.fold(0, 0, Sum(0), |acc, t| Sum(acc.0 + t.value.0)),
            Sum(0)
        );
        assert_eq!(
            doubling.fold(0, 1, Sum(0), |acc, t| Sum(acc.0 + t.value.0)),
            Sum(1)
        );
        assert_eq!(
            doubling.fold(0, 2, Sum(0), |acc, t| Sum(acc.0 + t.value.0)),
            Sum(1 + 10)
        );
        assert_eq!(
            doubling.fold(0, 3, Sum(0), |acc, t| Sum(acc.0 + t.value.0)),
            Sum(1 + 10 + 100)
        );
        assert_eq!(
            doubling.fold(0, 4, Sum(0), |acc, t| Sum(acc.0 + t.value.0)),
            Sum(1 + 10 + 100 + 1)
        );
    }

    #[test]
    fn test_get() {
        // 0 -> 1 -> 2 -> 0 -> ..., every step weighs 1.
        // After 2^k steps from state 0 we are at state (2^k mod 3) with weight 2^k.
        let n = 3;
        let to = vec![1, 2, 0];
        let doubling = Doubling::new(n, 100, |i| Transition::new(to[i], Sum(1)));

        let t = doubling.get(0, 0);
        assert_eq!(t.next, 1);
        assert_eq!(t.value, Sum(1));

        let t = doubling.get(0, 1);
        assert_eq!(t.next, 2);
        assert_eq!(t.value, Sum(2));

        let t = doubling.get(0, 2);
        assert_eq!(t.next, 1);
        assert_eq!(t.value, Sum(4));

        let t = doubling.get(0, 3);
        assert_eq!(t.next, 2);
        assert_eq!(t.value, Sum(8));
    }

    /// `fold` must agree with a naive step-by-step walk, both on the reached state and on
    /// the in-order concatenation of step values, for every start and every step count.
    #[test]
    fn test_fold_matches_naive_walk() {
        // Tail 0 -> 1 feeding into the cycle 2 -> 3 -> 4 -> 2, plus a self-loop at 5.
        let nexts = [1, 2, 3, 4, 2, 5];
        let n = nexts.len();
        let max_steps = 37;
        let label = |i: usize| char::from(b'a' + i as u8);
        let doubling = Doubling::new(n, max_steps, |i| {
            Transition::new(nexts[i], label(i).to_string())
        });

        for start in 0..n {
            for step in 0..=max_steps {
                let (state, path) = doubling.fold(
                    start,
                    step,
                    (start, String::new()),
                    |(_, acc): (usize, String), t: &Transition<String>| (t.next, acc + &t.value),
                );

                let mut expected_state = start;
                let mut expected_path = String::new();
                for _ in 0..step {
                    expected_path.push(label(expected_state));
                    expected_state = nexts[expected_state];
                }

                assert_eq!(state, expected_state, "start={}, step={}", start, step);
                assert_eq!(path, expected_path, "start={}, step={}", start, step);
            }
        }
    }

    proptest! {
        #[test]
        fn test_fold_associativity(
            (n_state, max_steps, nexts, values, start, step1, step2) in (1_usize..=10, 1_usize..=100)
                .prop_flat_map(|(n_state, max_steps)| {
                    (
                        Just(n_state),
                        Just(max_steps),
                        collection::vec(0..n_state, n_state),
                        collection::vec(proptest::char::range('a', 'z'), n_state),
                    )
                })
                .prop_flat_map(|(n_state, max_steps, nexts, values)| {
                    (
                        Just(n_state),
                        Just(max_steps),
                        Just(nexts),
                        Just(values),
                        0..n_state,
                        0..=max_steps,
                    )
                })
                .prop_flat_map(|(n_state, max_steps, nexts, values, start, step1)| {
                    (
                        Just(n_state),
                        Just(max_steps),
                        Just(nexts),
                        Just(values),
                        Just(start),
                        Just(step1),
                        0..=(max_steps - step1),
                    )
                })
        ) {
            let doubling = Doubling::new(n_state, max_steps, |i| {
                Transition::new(nexts[i], values[i].to_string())
            });

            #[derive(Debug, Clone, PartialEq)]
            struct Acc {
                value: String,
                state: usize,
            }

            let init = Acc {
                value: String::new(),
                state: start,
            };
            let f = |acc: Acc, t: &Transition<String>| Acc {
                value: format!("{}{}", acc.value, t.value),
                state: t.next,
            };

            let combined = doubling.fold(start, step1 + step2, init.clone(), f);

            let intermediate = doubling.fold(start, step1, init.clone(), f);
            let split = doubling.fold(intermediate.state, step2, intermediate.clone(), f);

            // Compare the whole accumulator: both the collected values and the final state.
            prop_assert_eq!(combined, split);
        }
    }
}