diff options
Diffstat (limited to 'third_party/rust/petgraph/src/astar.rs')
-rw-r--r-- | third_party/rust/petgraph/src/astar.rs | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/third_party/rust/petgraph/src/astar.rs b/third_party/rust/petgraph/src/astar.rs new file mode 100644 index 0000000000..56b3e2eb0c --- /dev/null +++ b/third_party/rust/petgraph/src/astar.rs @@ -0,0 +1,175 @@ +use std::collections::hash_map::Entry::{Occupied, Vacant}; +use std::collections::{BinaryHeap, HashMap}; + +use std::hash::Hash; + +use super::visit::{EdgeRef, GraphBase, IntoEdges, VisitMap, Visitable}; +use crate::scored::MinScored; + +use crate::algo::Measure; + +/// \[Generic\] A* shortest path algorithm. +/// +/// Computes the shortest path from `start` to `finish`, including the total path cost. +/// +/// `finish` is implicitly given via the `is_goal` callback, which should return `true` if the +/// given node is the finish node. +/// +/// The function `edge_cost` should return the cost for a particular edge. Edge costs must be +/// non-negative. +/// +/// The function `estimate_cost` should return the estimated cost to the finish for a particular +/// node. For the algorithm to find the actual shortest path, it should be admissible, meaning that +/// it should never overestimate the actual cost to get to the nearest goal node. Estimate costs +/// must also be non-negative. +/// +/// The graph should be `Visitable` and implement `IntoEdges`. +/// +/// # Example +/// ``` +/// use petgraph::Graph; +/// use petgraph::algo::astar; +/// +/// let mut g = Graph::new(); +/// let a = g.add_node((0., 0.)); +/// let b = g.add_node((2., 0.)); +/// let c = g.add_node((1., 1.)); +/// let d = g.add_node((0., 2.)); +/// let e = g.add_node((3., 3.)); +/// let f = g.add_node((4., 2.)); +/// g.extend_with_edges(&[ +/// (a, b, 2), +/// (a, d, 4), +/// (b, c, 1), +/// (b, f, 7), +/// (c, e, 5), +/// (e, f, 1), +/// (d, e, 1), +/// ]); +/// +/// // Graph represented with the weight of each edge +/// // Edges with '*' are part of the optimal path. +/// // +/// // 2 1 +/// // a ----- b ----- c +/// // | 4* | 7 | +/// // d f | 5 +/// // | 1* | 1* | +/// // \------ e ------/ +/// +/// let path = astar(&g, a, |finish| finish == f, |e| *e.weight(), |_| 0); +/// assert_eq!(path, Some((6, vec![a, d, e, f]))); +/// ``` +/// +/// Returns the total cost + the path of subsequent `NodeId` from start to finish, if one was +/// found. +pub fn astar<G, F, H, K, IsGoal>( + graph: G, + start: G::NodeId, + mut is_goal: IsGoal, + mut edge_cost: F, + mut estimate_cost: H, +) -> Option<(K, Vec<G::NodeId>)> +where + G: IntoEdges + Visitable, + IsGoal: FnMut(G::NodeId) -> bool, + G::NodeId: Eq + Hash, + F: FnMut(G::EdgeRef) -> K, + H: FnMut(G::NodeId) -> K, + K: Measure + Copy, +{ + let mut visited = graph.visit_map(); + let mut visit_next = BinaryHeap::new(); + let mut scores = HashMap::new(); + let mut path_tracker = PathTracker::<G>::new(); + + let zero_score = K::default(); + scores.insert(start, zero_score); + visit_next.push(MinScored(estimate_cost(start), start)); + + while let Some(MinScored(_, node)) = visit_next.pop() { + if is_goal(node) { + let path = path_tracker.reconstruct_path_to(node); + let cost = scores[&node]; + return Some((cost, path)); + } + + // Don't visit the same node several times, as the first time it was visited it was using + // the shortest available path. + if !visited.visit(node) { + continue; + } + + // This lookup can be unwrapped without fear of panic since the node was necessarily scored + // before adding him to `visit_next`. + let node_score = scores[&node]; + + for edge in graph.edges(node) { + let next = edge.target(); + if visited.is_visited(&next) { + continue; + } + + let mut next_score = node_score + edge_cost(edge); + + match scores.entry(next) { + Occupied(ent) => { + let old_score = *ent.get(); + if next_score < old_score { + *ent.into_mut() = next_score; + path_tracker.set_predecessor(next, node); + } else { + next_score = old_score; + } + } + Vacant(ent) => { + ent.insert(next_score); + path_tracker.set_predecessor(next, node); + } + } + + let next_estimate_score = next_score + estimate_cost(next); + visit_next.push(MinScored(next_estimate_score, next)); + } + } + + None +} + +struct PathTracker<G> +where + G: GraphBase, + G::NodeId: Eq + Hash, +{ + came_from: HashMap<G::NodeId, G::NodeId>, +} + +impl<G> PathTracker<G> +where + G: GraphBase, + G::NodeId: Eq + Hash, +{ + fn new() -> PathTracker<G> { + PathTracker { + came_from: HashMap::new(), + } + } + + fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) { + self.came_from.insert(node, previous); + } + + fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> { + let mut path = vec![last]; + + let mut current = last; + while let Some(&previous) = self.came_from.get(¤t) { + path.push(previous); + current = previous; + } + + path.reverse(); + + path + } +} |