dijkstra/
lib.rs

1use std::cmp::Reverse;
2use std::collections::BinaryHeap;
3use std::fmt::Debug;
4use std::ops::Add;
5
6/// グラフの辺を表すトレイトです。
7pub trait Edge<T> {
8    fn from(&self) -> usize;
9    fn to(&self) -> usize;
10    /// 始点から [`from`] までの距離 `d` を受け取り、この辺を辿って [`to`] へ行く最短距離を求めます。[`dijkstra`] が正しく動くように、この関数は次の[条件](https://fetburner.hatenablog.com/entry/2021/02/28/200020)を満たように実装してください。
11    ///
12    /// - `dist(d)` は `d` 以上である
13    /// - `dist(d)` は `d` について (広義) 単調増加である
14    ///
15    /// 使用例: [ABC192E](https://atcoder.jp/contests/abc192/submissions/26105492)
16    ///
17    /// [`from`]: trait.Edge.html#tymethod.from
18    /// [`to`]: trait.Edge.html#tymethod.to
19    /// [`dijkstra`]: fn.dijkstra.html
20    fn dist(&self, d: &T) -> T;
21}
22
23/// 長さが定数の辺です。
24pub struct ConstEdge<T> {
25    from: usize,
26    to: usize,
27    cost: T,
28}
29
30impl<T> ConstEdge<T> {
31    pub fn new(from: usize, to: usize, cost: T) -> Self {
32        Self { from, to, cost }
33    }
34}
35
36impl<T> Edge<T> for ConstEdge<T>
37where
38    T: Copy + Add<Output = T>,
39{
40    fn from(&self) -> usize {
41        self.from
42    }
43    fn to(&self) -> usize {
44        self.to
45    }
46    fn dist(&self, d: &T) -> T {
47        *d + self.cost
48    }
49}
50
51/// `dijkstra` はあるひとつの頂点から全ての頂点への最短距離を計算します。
52///
53/// 返り値 `(dist, prev)` はそれぞれ以下です。
54///
55/// - `dist[t]`: `s` から `t` までの最短距離
56/// - `prev[t]`: `s` を根とする最短経路木における `t` の親頂点
57///
58/// `prev` をゴールの頂点からたどることで、最短経路を復元できます。
59///
60/// `s` から `t` への経路が存在しない場合 `dist[t]`、`prev[t]` は `None` です。
61///
62/// # Examples
63/// ```
64/// use dijkstra::{Edge, ConstEdge, dijkstra};
65/// let edges = vec![
66///     ConstEdge::new(0, 1, 1),
67///     ConstEdge::new(0, 2, 1),
68///     ConstEdge::new(1, 2, 1),
69///     ConstEdge::new(2, 3, 1),
70/// ];
71/// //
72/// //     0 -----> 1 -----> 2 -----> 3
73/// //     |                 ^
74/// //     |                 |
75/// //     +-----------------+
76/// //
77/// let (dist, prev) = dijkstra(4, &edges, 0);
78/// assert_eq!(dist[0], Some(0));
79/// assert_eq!(dist[1], Some(1));
80/// assert_eq!(dist[2], Some(1));
81/// assert_eq!(dist[3], Some(2));
82/// assert_eq!(prev[0], None);
83/// assert_eq!(prev[1], Some(0));
84/// assert_eq!(prev[2], Some(0));
85/// assert_eq!(prev[3], Some(2));
86/// ```
87pub fn dijkstra<E, T>(n: usize, edges: &[E], s: usize) -> (Vec<Option<T>>, Vec<Option<usize>>)
88where
89    E: Edge<T>,
90    T: Clone + Default + Ord + Debug,
91{
92    let mut adj = vec![vec![]; n];
93    for e in edges {
94        adj[e.from()].push(e);
95    }
96    let mut dist = vec![Option::<T>::None; n];
97    let mut heap = BinaryHeap::new();
98    let mut prev = vec![None; n];
99    dist[s] = Some(T::default());
100    heap.push((Reverse(T::default()), s));
101    while let Some((Reverse(d), v)) = heap.pop() {
102        #[allow(clippy::comparison_chain)]
103        match &dist[v] {
104            Some(dv) => {
105                if dv < &d {
106                    continue;
107                } else if dv > &d {
108                    unreachable!();
109                } else {
110                    assert_eq!(dv, &d);
111                }
112            }
113            None => unreachable!(),
114        }
115        for e in &adj[v] {
116            let next_d = e.dist(&d);
117            let to = e.to();
118            match &dist[to] {
119                Some(dt) if dt <= &next_d => {
120                    continue;
121                }
122                _ => {
123                    dist[to] = Some(next_d.clone());
124                    prev[to] = Some(v);
125                    heap.push((Reverse(next_d), to));
126                }
127            }
128        }
129    }
130    (dist, prev)
131}
132
133#[cfg(test)]
134mod tests {
135    use crate::{ConstEdge, dijkstra};
136    use rand::distributions::Uniform;
137    use rand::prelude::*;
138
139    #[allow(clippy::many_single_char_names)]
140    fn generate(n: usize, m: usize) -> Vec<(usize, usize, u64)> {
141        let nodes = Uniform::from(0..n);
142        let costs = Uniform::from(0..=1_000_000_000);
143        let mut rng = thread_rng();
144        (0..m)
145            .map(|_| {
146                let a = nodes.sample(&mut rng);
147                let b = nodes.sample(&mut rng);
148                let c = costs.sample(&mut rng);
149                (a, b, c)
150            })
151            .take(m)
152            .collect()
153    }
154
155    const INF: u64 = std::u64::MAX;
156
157    fn floyd_warshall(n: usize, edges: &Vec<(usize, usize, u64)>) -> Vec<u64> {
158        let mut d = vec![vec![INF; n]; n];
159        for i in 0..n {
160            d[i][i] = 0;
161        }
162        for &(a, b, c) in edges {
163            d[a][b] = d[a][b].min(c);
164        }
165        for k in 0..n {
166            for i in 0..n {
167                for j in 0..n {
168                    d[i][j] = d[i][j].min(d[i][k].saturating_add(d[k][j]));
169                }
170            }
171        }
172        d[0].clone()
173    }
174
175    #[test]
176    fn random_test() {
177        for n in 1..=10 {
178            for m in 0..=n * n {
179                let edges = generate(n, m);
180                let dd = floyd_warshall(n, &edges);
181                let edges = edges
182                    .into_iter()
183                    .map(|(a, b, c)| ConstEdge::new(a, b, c))
184                    .collect::<Vec<_>>();
185                let (d, _) = dijkstra(n, &edges, 0);
186                for v in 0..n {
187                    assert_eq!(d[v].unwrap_or(INF), dd[v]);
188                }
189            }
190        }
191    }
192}