dijkstra/
lib.rs

1use std::cmp::Reverse;
2use std::collections::BinaryHeap;
3use std::fmt::Debug;
4use std::ops::Add;
5
6/// グラフの辺を表すトレイトです。
7pub trait Edge<T> {
8    fn from(&self) -> usize;
9    fn to(&self) -> usize;
10    /// 始点から [`from`] までの距離 `d` を受け取り、この辺を辿って [`to`] へ行く最短距離を求めます。[`dijkstra`] が正しく動くように、この関数は次の条件を満たように実装してください。[参考情報](https://fetburner.hatenablog.com/entry/2021/02/28/200020)。
11    ///
12    /// - `dist(d)` は `d` 以上である
13    /// - `dist(d)` は `d` について (広義) 単調増加である
14    ///
15    /// 使用例は [ABC192E](https://atcoder.jp/contests/abc192/submissions/26105492) をどうぞ。
16    ///
17    /// [`from`]: trait.Edge.html#tymethod.from
18    /// [`to`]: trait.Edge.html#tymethod.to
19    /// [`dijkstra`]: fn.dijkstra.html
20    fn dist(&self, d: T) -> T;
21}
22
23/// 長さが定数の辺です。
24#[derive(Copy, Clone)]
25pub struct ConstEdge<T> {
26    from: usize,
27    to: usize,
28    cost: T,
29}
30
31impl<T> ConstEdge<T> {
32    pub fn new(from: usize, to: usize, cost: T) -> Self {
33        Self { from, to, cost }
34    }
35}
36
37impl<T> Edge<T> for ConstEdge<T>
38where
39    T: Copy + Add<Output = T>,
40{
41    fn from(&self) -> usize {
42        self.from
43    }
44    fn to(&self) -> usize {
45        self.to
46    }
47    fn dist(&self, d: T) -> T {
48        d + self.cost
49    }
50}
51
52/// `dijkstra` はあるひとつの頂点から全ての頂点への最短距離を計算します。
53///
54/// 返り値 `(d, prev)` はそれぞれ以下です。
55///
56/// - `d[t]`: `s` から `t` までの最短距離
57/// - `prev[t]`: `s` を根とする最短経路木における `t` の親頂点
58///
59/// `prev` をゴールの頂点からたどることで、最短経路を復元できます。
60///
61/// `s` から `t` への経路が存在しない場合 `d[t]`、`prev[t]` は `None` です。
62///
63/// # Examples
64/// ```
65/// use dijkstra::{Edge, ConstEdge, dijkstra};
66/// let edges = vec![
67///     ConstEdge::new(0, 1, 1),
68///     ConstEdge::new(0, 2, 1),
69///     ConstEdge::new(1, 2, 1),
70///     ConstEdge::new(2, 3, 1),
71/// ];
72/// //
73/// //     0 -----> 1 -----> 2 -----> 3
74/// //     |                 ^
75/// //     |                 |
76/// //     +-----------------+
77/// //
78/// let (d, prev) = dijkstra(4, &edges, 0);
79/// assert_eq!(d[0], Some(0));
80/// assert_eq!(d[1], Some(1));
81/// assert_eq!(d[2], Some(1));
82/// assert_eq!(d[3], Some(2));
83/// assert_eq!(prev[0], None);
84/// assert_eq!(prev[1], Some(0));
85/// assert_eq!(prev[2], Some(0));
86/// assert_eq!(prev[3], Some(2));
87/// ```
88pub fn dijkstra<E, T>(n: usize, edges: &[E], s: usize) -> (Vec<Option<T>>, Vec<Option<usize>>)
89where
90    E: Edge<T> + Clone,
91    T: Copy + Add<Output = T> + Default + Ord + Debug,
92{
93    let mut adj = vec![vec![]; n];
94    for e in edges {
95        adj[e.from()].push(e);
96    }
97    let mut dist = vec![None; n];
98    let mut heap = BinaryHeap::new();
99    let mut prev = vec![None; n];
100    dist[s] = Some(T::default());
101    heap.push((Reverse(T::default()), s));
102    while let Some((Reverse(d), v)) = heap.pop() {
103        #[allow(clippy::comparison_chain)]
104        match dist[v] {
105            Some(dv) => {
106                if dv < d {
107                    continue;
108                } else if dv > d {
109                    unreachable!();
110                } else {
111                    assert_eq!(dv, d);
112                }
113            }
114            None => unreachable!(),
115        }
116        for e in &adj[v] {
117            let next_d = e.dist(d);
118            let to = e.to();
119            match dist[to] {
120                Some(dt) if dt <= next_d => {
121                    continue;
122                }
123                _ => {
124                    dist[to] = Some(next_d);
125                    prev[to] = Some(v);
126                    heap.push((Reverse(next_d), to));
127                }
128            }
129        }
130    }
131    (dist, prev)
132}
133
134#[cfg(test)]
135mod tests {
136    use crate::{dijkstra, ConstEdge};
137    use rand::distributions::Uniform;
138    use rand::prelude::*;
139
140    #[allow(clippy::many_single_char_names)]
141    fn generate(n: usize, m: usize) -> Vec<(usize, usize, u64)> {
142        let nodes = Uniform::from(0..n);
143        let costs = Uniform::from(0..=1_000_000_000);
144        let mut rng = thread_rng();
145        (0..m)
146            .map(|_| {
147                let a = nodes.sample(&mut rng);
148                let b = nodes.sample(&mut rng);
149                let c = costs.sample(&mut rng);
150                (a, b, c)
151            })
152            .take(m)
153            .collect()
154    }
155
156    const INF: u64 = std::u64::MAX;
157
158    fn floyd_warshall(n: usize, edges: &Vec<(usize, usize, u64)>) -> Vec<u64> {
159        let mut d = vec![vec![INF; n]; n];
160        for i in 0..n {
161            d[i][i] = 0;
162        }
163        for &(a, b, c) in edges {
164            d[a][b] = d[a][b].min(c);
165        }
166        for k in 0..n {
167            for i in 0..n {
168                for j in 0..n {
169                    d[i][j] = d[i][j].min(d[i][k].saturating_add(d[k][j]));
170                }
171            }
172        }
173        d[0].clone()
174    }
175
176    #[test]
177    fn random_test() {
178        for n in 1..=10 {
179            for m in 0..=n * n {
180                let edges = generate(n, m);
181                let dd = floyd_warshall(n, &edges);
182                let edges = edges
183                    .into_iter()
184                    .map(|(a, b, c)| ConstEdge::new(a, b, c))
185                    .collect::<Vec<_>>();
186                let (d, _) = dijkstra(n, &edges, 0);
187                for v in 0..n {
188                    assert_eq!(d[v].unwrap_or(INF), dd[v]);
189                }
190            }
191        }
192    }
193}