feat(pie/PmgAstar): support threading a context along

This commit is contained in:
Ellen Emilia Anna Zscheile 2025-03-23 10:10:51 +01:00
parent 13f2400c45
commit 8e3be44e18
1 changed files with 103 additions and 61 deletions

View File

@ -18,7 +18,7 @@ use num_traits::float::TotalOrder;
/// A walk task /// A walk task
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Task<B: NavmeshBase> { pub struct Task<B: NavmeshBase, Ctx> {
/// index of current goal /// index of current goal
pub goal_idx: usize, pub goal_idx: usize,
@ -42,11 +42,14 @@ pub struct Task<B: NavmeshBase> {
/// the introduction position re: `selected_node` /// the introduction position re: `selected_node`
pub cur_intro: usize, pub cur_intro: usize,
/// associated context (ignored during comparisons)
pub context: Ctx,
} }
/// Results after a [`Task`] is done. /// Results after a [`Task`] is done.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TaskResult<B: NavmeshBase> { pub struct TaskResult<B: NavmeshBase, Ctx> {
/// index of current goal /// index of current goal
pub goal_idx: usize, pub goal_idx: usize,
@ -61,13 +64,32 @@ pub struct TaskResult<B: NavmeshBase> {
/// the introduction position re: `target` /// the introduction position re: `target`
pub cur_intro: usize, pub cur_intro: usize,
/// the associated context
pub context: Ctx,
}
pub trait EvaluateNavmesh<B: NavmeshBase, Ctx>:
Fn(NavmeshRef<B>, &Ctx, EdgeIndex<NavmeshIndex<B::PrimalNodeIndex>>) -> Option<(B::Scalar, Ctx)>
{
}
impl<B, Ctx, F> EvaluateNavmesh<B, Ctx> for F
where
B: NavmeshBase,
F: Fn(
NavmeshRef<B>,
&Ctx,
EdgeIndex<NavmeshIndex<B::PrimalNodeIndex>>,
) -> Option<(B::Scalar, Ctx)>,
{
} }
/// The main path search data structure /// The main path search data structure
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PmgAstar<B: NavmeshBase> { pub struct PmgAstar<B: NavmeshBase, Ctx> {
/// task queue, ordered by costs ascending /// task queue, ordered by costs ascending
pub queue: BinaryHeap<Task<B>>, pub queue: BinaryHeap<Task<B, Ctx>>,
// constant data // constant data
pub nodes: pub nodes:
@ -78,7 +100,7 @@ pub struct PmgAstar<B: NavmeshBase> {
pub goals: Box<[PreparedGoal<B>]>, pub goals: Box<[PreparedGoal<B>]>,
} }
impl<B: NavmeshBase> Task<B> impl<B: NavmeshBase, Ctx> Task<B, Ctx>
where where
B::Scalar: num_traits::Float, B::Scalar: num_traits::Float,
{ {
@ -91,7 +113,7 @@ where
} }
} }
impl<B: NavmeshBase> PartialEq for Task<B> impl<B: NavmeshBase, Ctx> PartialEq for Task<B, Ctx>
where where
B::PrimalNodeIndex: Ord, B::PrimalNodeIndex: Ord,
B::EtchedPath: PartialOrd, B::EtchedPath: PartialOrd,
@ -115,7 +137,7 @@ where
} }
} }
impl<B: NavmeshBase> Eq for Task<B> impl<B: NavmeshBase, Ctx> Eq for Task<B, Ctx>
where where
B::PrimalNodeIndex: Ord, B::PrimalNodeIndex: Ord,
B::EtchedPath: PartialOrd, B::EtchedPath: PartialOrd,
@ -125,7 +147,7 @@ where
} }
// tasks are ordered such that smaller costs and higher goal indices are ordered as being larger (better) // tasks are ordered such that smaller costs and higher goal indices are ordered as being larger (better)
impl<B: NavmeshBase> Ord for Task<B> impl<B: NavmeshBase, Ctx> Ord for Task<B, Ctx>
where where
B::PrimalNodeIndex: Ord, B::PrimalNodeIndex: Ord,
B::EtchedPath: PartialOrd, B::EtchedPath: PartialOrd,
@ -153,7 +175,7 @@ where
} }
} }
impl<B: NavmeshBase> PartialOrd for Task<B> impl<B: NavmeshBase, Ctx> PartialOrd for Task<B, Ctx>
where where
B::PrimalNodeIndex: Ord, B::PrimalNodeIndex: Ord,
B::EtchedPath: PartialOrd, B::EtchedPath: PartialOrd,
@ -165,7 +187,9 @@ where
} }
} }
impl<B: NavmeshBase<Scalar = Scalar>, Scalar: num_traits::Float + core::iter::Sum> PmgAstar<B> { impl<B: NavmeshBase<Scalar = Scalar>, Scalar: num_traits::Float + core::iter::Sum, Ctx>
PmgAstar<B, Ctx>
{
fn estimate_remaining_goals_costs(&self, start_goal_idx: usize) -> Scalar { fn estimate_remaining_goals_costs(&self, start_goal_idx: usize) -> Scalar {
self.goals self.goals
.get(start_goal_idx + 1..) .get(start_goal_idx + 1..)
@ -180,13 +204,14 @@ where
B::Scalar: num_traits::Float + core::iter::Sum, B::Scalar: num_traits::Float + core::iter::Sum,
{ {
/// start processing the goal /// start processing the goal
fn start_pmga<'a, F: Fn(NavmeshRef<B>) -> Option<B::Scalar>>( fn start_pmga<'a, Ctx, F: EvaluateNavmesh<B, Ctx>>(
&'a self, &'a self,
navmesh: NavmeshRef<'a, B>, navmesh: NavmeshRef<'a, B>,
goal_idx: usize, goal_idx: usize,
env: &'a PmgAstar<B>, env: &'a PmgAstar<B, Ctx>,
context: &'a Ctx,
evaluate_navmesh: &'a F, evaluate_navmesh: &'a F,
) -> Option<impl Iterator<Item = Task<B>> + 'a> { ) -> Option<impl Iterator<Item = Task<B, Ctx>> + 'a> {
let source = NavmeshIndex::Primal(self.source.clone()); let source = NavmeshIndex::Primal(self.source.clone());
let estimated_remaining_goals = env.estimate_remaining_goals_costs(goal_idx); let estimated_remaining_goals = env.estimate_remaining_goals_costs(goal_idx);
Some( Some(
@ -206,6 +231,7 @@ where
} }
}) })
.flat_map(move |(neigh, epi, edge_len)| { .flat_map(move |(neigh, epi, edge_len)| {
let eidx = EdgeIndex::from((source.clone(), neigh.clone()));
let source = source.clone(); let source = source.clone();
// A*-like remaining costs estimation // A*-like remaining costs estimation
let estimated_remaining = let estimated_remaining =
@ -220,7 +246,8 @@ where
navmesh.access_edge_paths_mut(epi).with_borrow_mut(|mut j| { navmesh.access_edge_paths_mut(epi).with_borrow_mut(|mut j| {
j.insert(i, RelaxedPath::Normal(self.label.clone())) j.insert(i, RelaxedPath::Normal(self.label.clone()))
}); });
evaluate_navmesh(navmesh.as_ref()).map(|costs| Task { evaluate_navmesh(navmesh.as_ref(), context, eidx.clone()).map(
|(costs, context)| Task {
goal_idx, goal_idx,
costs, costs,
estimated_remaining, estimated_remaining,
@ -229,27 +256,26 @@ where
selected_node: neigh.clone(), selected_node: neigh.clone(),
prev_node: source.clone(), prev_node: source.clone(),
cur_intro: edge_len - i, cur_intro: edge_len - i,
}) context,
},
)
}) })
}), }),
) )
} }
} }
impl<B: NavmeshBase> Task<B> impl<B: NavmeshBase, Ctx> Task<B, Ctx>
where where
B::EtchedPath: PartialOrd, B::EtchedPath: PartialOrd,
B::GapComment: Clone + PartialOrd, B::GapComment: Clone + PartialOrd,
B::Scalar: num_traits::Float + num_traits::float::TotalOrder, B::Scalar: num_traits::Float + num_traits::float::TotalOrder,
{ {
pub fn run<F>( pub fn run<F: EvaluateNavmesh<B, Ctx>>(
self, self,
env: &mut PmgAstar<B>, env: &mut PmgAstar<B, Ctx>,
evaluate_navmesh: F, evaluate_navmesh: F,
) -> ControlFlow<TaskResult<B>, (Self, Vec<NavmeshIndex<B::PrimalNodeIndex>>)> ) -> ControlFlow<TaskResult<B, Ctx>, (Self, Vec<NavmeshIndex<B::PrimalNodeIndex>>)> {
where
F: Fn(NavmeshRef<B>) -> Option<B::Scalar>,
{
if let NavmeshIndex::Primal(primal) = &self.selected_node { if let NavmeshIndex::Primal(primal) = &self.selected_node {
if env.goals[self.goal_idx].target.contains(primal) { if env.goals[self.goal_idx].target.contains(primal) {
let Self { let Self {
@ -261,6 +287,7 @@ where
prev_node, prev_node,
cur_intro, cur_intro,
selected_node: _, selected_node: _,
context,
} = self; } = self;
return ControlFlow::Break(TaskResult { return ControlFlow::Break(TaskResult {
goal_idx, goal_idx,
@ -268,6 +295,7 @@ where
edge_paths, edge_paths,
prev_node, prev_node,
cur_intro, cur_intro,
context,
}); });
} else { } else {
panic!("wrong primal node selected"); panic!("wrong primal node selected");
@ -278,14 +306,11 @@ where
} }
/// progress to the next step, splitting the task into new tasks (make sure to call `done` beforehand) /// progress to the next step, splitting the task into new tasks (make sure to call `done` beforehand)
fn progress<F>( fn progress<F: EvaluateNavmesh<B, Ctx>>(
&self, &self,
env: &mut PmgAstar<B>, env: &mut PmgAstar<B, Ctx>,
evaluate_navmesh: F, evaluate_navmesh: F,
) -> Vec<NavmeshIndex<B::PrimalNodeIndex>> ) -> Vec<NavmeshIndex<B::PrimalNodeIndex>> {
where
F: Fn(NavmeshRef<B>) -> Option<B::Scalar>,
{
let goal_idx = self.goal_idx; let goal_idx = self.goal_idx;
let navmesh = NavmeshRef { let navmesh = NavmeshRef {
nodes: &env.nodes, nodes: &env.nodes,
@ -320,6 +345,7 @@ where
edges: &env.edges, edges: &env.edges,
edge_paths: &mut edge_paths, edge_paths: &mut edge_paths,
}; };
let eidx = EdgeIndex::from((self.selected_node.clone(), neigh.clone()));
let cur_intro = navmesh let cur_intro = navmesh
.edge_data_mut(self.selected_node.clone(), neigh.clone()) .edge_data_mut(self.selected_node.clone(), neigh.clone())
.unwrap() .unwrap()
@ -331,7 +357,8 @@ where
x.len() - stop_data.insert_pos - 1 x.len() - stop_data.insert_pos - 1
}); });
ret.push(neigh.clone()); ret.push(neigh.clone());
evaluate_navmesh(navmesh.as_ref()).map(|costs| Task { evaluate_navmesh(navmesh.as_ref(), &self.context, eidx).map(|(costs, context)| {
Task {
goal_idx, goal_idx,
costs, costs,
estimated_remaining, estimated_remaining,
@ -340,6 +367,8 @@ where
selected_node: neigh.clone(), selected_node: neigh.clone(),
prev_node: self.selected_node.clone(), prev_node: self.selected_node.clone(),
cur_intro, cur_intro,
context,
}
}) })
})); }));
@ -348,8 +377,10 @@ where
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct IntermedResult<B: NavmeshBase> { pub struct IntermedResult<B: NavmeshBase, Ctx> {
pub edge_paths: Box<[EdgePaths<B::EtchedPath, B::GapComment>]>, pub edge_paths: Box<[EdgePaths<B::EtchedPath, B::GapComment>]>,
// TODO: maybe avoid these clones?
pub context: Ctx,
pub goal_idx: usize, pub goal_idx: usize,
pub forks: Vec<NavmeshIndex<B::PrimalNodeIndex>>, pub forks: Vec<NavmeshIndex<B::PrimalNodeIndex>>,
@ -358,7 +389,7 @@ pub struct IntermedResult<B: NavmeshBase> {
pub maybe_finished_goal: Option<B::Scalar>, pub maybe_finished_goal: Option<B::Scalar>,
} }
impl<B> PmgAstar<B> impl<B, Ctx> PmgAstar<B, Ctx>
where where
B: NavmeshBase, B: NavmeshBase,
B::EtchedPath: PartialOrd, B::EtchedPath: PartialOrd,
@ -370,14 +401,12 @@ where
+ num_traits::float::TotalOrder, + num_traits::float::TotalOrder,
{ {
/// * `evaluate_navmesh` calculates the exact cost of a given navmesh (lower cost is better) /// * `evaluate_navmesh` calculates the exact cost of a given navmesh (lower cost is better)
pub fn new<F>( pub fn new<F: EvaluateNavmesh<B, Ctx>>(
navmesh: &Navmesh<B>, navmesh: &Navmesh<B>,
goals: Vec<Goal<B::PrimalNodeIndex, B::EtchedPath>>, goals: Vec<Goal<B::PrimalNodeIndex, B::EtchedPath>>,
context: &Ctx,
evaluate_navmesh: F, evaluate_navmesh: F,
) -> Self ) -> Self {
where
F: Fn(NavmeshRef<B>) -> Option<B::Scalar>,
{
let mut this = Self { let mut this = Self {
queue: BinaryHeap::new(), queue: BinaryHeap::new(),
goals: goals goals: goals
@ -400,7 +429,7 @@ where
edge_paths: &navmesh.edge_paths, edge_paths: &navmesh.edge_paths,
}; };
let tmp = if let Some(iter) = let tmp = if let Some(iter) =
first_goal.start_pmga(navmesh, 0, &this, &evaluate_navmesh) first_goal.start_pmga(navmesh, 0, &this, context, &evaluate_navmesh)
{ {
iter.collect() iter.collect()
} else { } else {
@ -418,16 +447,19 @@ where
} }
/// run one step of the path-search /// run one step of the path-search
pub fn step<F>( pub fn step<F: EvaluateNavmesh<B, Ctx>>(
&mut self, &mut self,
evaluate_navmesh: F, evaluate_navmesh: F,
) -> ControlFlow< ) -> ControlFlow<
Option<(B::Scalar, Box<[EdgePaths<B::EtchedPath, B::GapComment>]>)>, Option<(
IntermedResult<B>, B::Scalar,
Box<[EdgePaths<B::EtchedPath, B::GapComment>]>,
Ctx,
)>,
IntermedResult<B, Ctx>,
> >
where where
B::PrimalNodeIndex: core::fmt::Debug, B::PrimalNodeIndex: core::fmt::Debug,
F: Fn(NavmeshRef<B>) -> Option<B::Scalar>,
{ {
let Some(task) = self.queue.pop() else { let Some(task) = self.queue.pop() else {
log::info!("found no complete result"); log::info!("found no complete result");
@ -450,7 +482,11 @@ where
edge_count, edge_count,
taskres.costs taskres.costs
); );
return ControlFlow::Break(Some((taskres.costs, taskres.edge_paths))); return ControlFlow::Break(Some((
taskres.costs,
taskres.edge_paths,
taskres.context,
)));
} }
Some(next_goal) => { Some(next_goal) => {
// prepare next goal // prepare next goal
@ -460,9 +496,13 @@ where
edge_count, edge_count,
taskres.costs, taskres.costs,
); );
let mut tmp = if let Some(iter) = let mut tmp = if let Some(iter) = next_goal.start_pmga(
next_goal.start_pmga(navmesh, next_goal_idx, self, &evaluate_navmesh) navmesh,
{ next_goal_idx,
self,
&taskres.context,
&evaluate_navmesh,
) {
iter.collect() iter.collect()
} else { } else {
BinaryHeap::new() BinaryHeap::new()
@ -473,6 +513,7 @@ where
goal_idx: taskres.goal_idx, goal_idx: taskres.goal_idx,
forks, forks,
edge_paths: taskres.edge_paths, edge_paths: taskres.edge_paths,
context: taskres.context,
selected_node: NavmeshIndex::Primal(next_goal.source.clone()), selected_node: NavmeshIndex::Primal(next_goal.source.clone()),
maybe_finished_goal: Some(taskres.costs), maybe_finished_goal: Some(taskres.costs),
} }
@ -485,6 +526,7 @@ where
goal_idx: task.goal_idx, goal_idx: task.goal_idx,
forks, forks,
edge_paths: task.edge_paths, edge_paths: task.edge_paths,
context: task.context,
selected_node: task.selected_node, selected_node: task.selected_node,
maybe_finished_goal: None, maybe_finished_goal: None,
} }