/** * * Copied and substantially modified from petgraph's scored.rs and algo/astar.rs. * * Copyright (c) 2015 **/ use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::{BinaryHeap, HashMap, VecDeque}; use std::hash::Hash; use petgraph::algo::Measure; use petgraph::visit::{EdgeRef, GraphBase, IntoEdgeReferences, IntoEdges}; use thiserror::Error; use std::cmp::Ordering; use crate::step::Step; #[derive(Copy, Clone, Debug)] pub struct MinScored(pub K, pub T); impl PartialEq for MinScored { #[inline] fn eq(&self, other: &MinScored) -> bool { self.cmp(other) == Ordering::Equal } } impl Eq for MinScored {} impl PartialOrd for MinScored { #[inline] fn partial_cmp(&self, other: &MinScored) -> Option { Some(self.cmp(other)) } } impl Ord for MinScored { #[inline] fn cmp(&self, other: &MinScored) -> Ordering { let a = &self.0; let b = &other.0; if a == b { Ordering::Equal } else if a < b { Ordering::Greater } else if a > b { Ordering::Less } else if a.ne(a) && b.ne(b) { // these are the NaN cases Ordering::Equal } else if a.ne(a) { // Order NaN less, so that it is last in the MinScore order Ordering::Less } else { Ordering::Greater } } } #[derive(Debug)] pub struct PathTracker where G: GraphBase, G::NodeId: Eq + Hash, { came_from: HashMap, } impl PathTracker where G: GraphBase, G::NodeId: Eq + Hash, { fn new() -> PathTracker { PathTracker { came_from: HashMap::new(), } } fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) { self.came_from.insert(node, previous); } pub fn reconstruct_path_to(&self, last: G::NodeId) -> Vec { let mut path = vec![last]; let mut current = last; while let Some(&previous) = self.came_from.get(¤t) { path.push(previous); current = previous; } path.reverse(); path } } pub trait AstarStrategy where G: GraphBase, G::NodeId: Eq + Hash, for<'a> &'a G: IntoEdges + MakeEdgeRef, K: Measure + Copy, { fn is_goal(&mut self, graph: &G, node: G::NodeId, tracker: &PathTracker) -> Option; fn place_probe<'a>( &mut self, graph: &'a G, edge: <&'a G as IntoEdgeReferences>::EdgeRef, ) -> Option; fn remove_probe<'a>(&mut self, graph: &'a G); fn estimate_cost(&mut self, graph: &G, node: G::NodeId) -> K; } pub trait MakeEdgeRef: IntoEdgeReferences { fn edge_ref(&self, edge_id: Self::EdgeId) -> Self::EdgeRef; } pub struct Astar where G: GraphBase, G::NodeId: Eq + Hash, for<'a> &'a G: IntoEdges + MakeEdgeRef, K: Measure + Copy, { pub graph: G, pub visit_next: BinaryHeap>, pub scores: HashMap, pub estimate_scores: HashMap, pub path_tracker: PathTracker, pub maybe_curr_node: Option, // FIXME: To work around edge references borrowing from the graph we collect then reiterate over tem. pub edge_ids: VecDeque, // TODO: Rewrite this to be a well-designed state machine. pub is_probing: bool, } #[derive(Error, Debug, Clone)] pub enum AstarError { #[error("A* search found no path")] NotFound, } #[derive(Debug)] pub enum AstarStatus where G: GraphBase, G::NodeId: Eq + Hash, for<'a> &'a G: IntoEdges + MakeEdgeRef, K: Measure + Copy, { Probing, Probed, Visited, Finished(K, Vec, R), } impl TryInto<(K, Vec, R)> for AstarStatus where G: GraphBase, G::NodeId: Eq + Hash, for<'a> &'a G: IntoEdges + MakeEdgeRef, K: Measure + Copy, { type Error = (); fn try_into(self) -> Result<(K, Vec, R), ()> { match self { AstarStatus::Probing => Err(()), AstarStatus::Probed => Err(()), AstarStatus::Visited => Err(()), AstarStatus::Finished(cost, path, result) => Ok((cost, path, result)), } } } impl Astar where G: GraphBase, G::NodeId: Eq + Hash, for<'a> &'a G: IntoEdges + MakeEdgeRef, K: Measure + Copy, { pub fn new(graph: G, start: G::NodeId, strategy: &mut impl AstarStrategy) -> Self { let mut this = Self { graph, visit_next: BinaryHeap::new(), scores: HashMap::new(), estimate_scores: HashMap::new(), path_tracker: PathTracker::::new(), maybe_curr_node: None, edge_ids: VecDeque::new(), is_probing: false, }; let zero_score = K::default(); this.scores.insert(start, zero_score); this.visit_next.push(MinScored( strategy.estimate_cost(&&this.graph, start), start, )); this } } impl> Step, AstarError, (K, Vec, R)> for Astar where G: GraphBase, G::NodeId: Eq + Hash, for<'a> &'a G: IntoEdges + MakeEdgeRef, K: Measure + Copy, { fn step(&mut self, strategy: &mut S) -> Result, AstarError> { if let Some(curr_node) = self.maybe_curr_node { if self.is_probing { strategy.remove_probe(&self.graph); self.is_probing = false; } if let Some(edge_id) = self.edge_ids.pop_front() { // This lookup can be unwrapped without fear of panic since the node was // necessarily scored before adding it to `visit_next`. let node_score = self.scores[&curr_node]; let edge = (&self.graph).edge_ref(edge_id); if let Some(edge_cost) = strategy.place_probe(&self.graph, edge) { let next = edge.target(); let next_score = node_score + edge_cost; match self.scores.entry(next) { Occupied(mut entry) => { // No need to add neighbors that we have already reached through a // shorter path than now. if *entry.get() <= next_score { return Ok(AstarStatus::Probed); } entry.insert(next_score); } Vacant(entry) => { entry.insert(next_score); } } self.path_tracker.set_predecessor(next, curr_node); let next_estimate_score = next_score + strategy.estimate_cost(&self.graph, next); self.visit_next.push(MinScored(next_estimate_score, next)); self.is_probing = true; return Ok(AstarStatus::Probing); } return Ok(AstarStatus::Probed); } self.maybe_curr_node = None; } let Some(MinScored(estimate_score, node)) = self.visit_next.pop() else { return Err(AstarError::NotFound); }; if let Some(result) = strategy.is_goal(&self.graph, node, &self.path_tracker) { let path = self.path_tracker.reconstruct_path_to(node); let cost = self.scores[&node]; return Ok(AstarStatus::Finished(cost, path, result)); } match self.estimate_scores.entry(node) { Occupied(mut entry) => { // If the node has already been visited with an equal or lower score than // now, then we do not need to re-visit it. if *entry.get() <= estimate_score { return Ok(AstarStatus::Visited); } entry.insert(estimate_score); } Vacant(entry) => { entry.insert(estimate_score); } } self.maybe_curr_node = Some(node); self.edge_ids = self.graph.edges(node).map(|edge| edge.id()).collect(); Ok(AstarStatus::Visited) } }