wifi-densepose/vendor/ruvector/examples/benchmarks/src/temporal.rs

2319 lines
85 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Temporal Reasoning Benchmark Framework
//!
//! Implements temporal constraint solving and benchmarking based on:
//! - TimePuzzles benchmark methodology
//! - Tool-augmented iterative temporal reasoning
//! - Calendar math and cross-cultural date systems
use anyhow::{anyhow, Result};
use chrono::{Datelike, NaiveDate, Weekday};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Temporal constraint types.
///
/// Each variant is a predicate over one candidate date; `TemporalPuzzle::check_date`
/// accepts a date only when every constraint of the puzzle holds.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum TemporalConstraint {
/// Date is exactly this value
Exact(NaiveDate),
/// Date is strictly after this date (the given date itself does not match)
After(NaiveDate),
/// Date is strictly before this date (the given date itself does not match)
Before(NaiveDate),
/// Date is between two dates (inclusive on both ends)
Between(NaiveDate, NaiveDate),
/// Date is on a specific day of week
DayOfWeek(Weekday),
/// Date is exactly N days after the named reference date.
/// The string is a key into the puzzle's `references` map.
DaysAfter(String, i64),
/// Date is exactly N days before the named reference date.
/// The string is a key into the puzzle's `references` map.
DaysBefore(String, i64),
/// Date is in a specific month (compared against `Datelike::month`)
InMonth(u32),
/// Date is in a specific year
InYear(i32),
/// Date is a specific day of month (compared against `Datelike::day`)
DayOfMonth(u32),
/// Relative to a named event (e.g., "Easter", "Chinese New Year").
/// The event date is looked up in the puzzle's `references` map and the
/// signed day offset is added to it.
RelativeToEvent(String, i64),
}
/// A temporal puzzle with constraints.
///
/// Built fluently through `TemporalPuzzle::new` and the `with_*` methods.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TemporalPuzzle {
/// Unique puzzle ID
pub id: String,
/// Human-readable description
pub description: String,
/// Constraints that define the puzzle; a date is a solution only if it
/// satisfies all of them
pub constraints: Vec<TemporalConstraint>,
/// Named reference dates, looked up by `DaysAfter`, `DaysBefore` and
/// `RelativeToEvent` constraints
pub references: HashMap<String, NaiveDate>,
/// Valid solution dates (for evaluation); an empty list means any solver
/// output is scored as correct
pub solutions: Vec<NaiveDate>,
/// Difficulty level (1-10). `with_difficulty` clamps into that range, but
/// the field is public and is not validated on direct assignment.
pub difficulty: u8,
/// Tags for categorization
pub tags: Vec<String>,
/// Multi-dimensional difficulty vector (None = use scalar difficulty)
pub difficulty_vector: Option<crate::timepuzzles::DifficultyVector>,
}
impl TemporalPuzzle {
/// Create a new puzzle
pub fn new(id: impl Into<String>, description: impl Into<String>) -> Self {
Self {
id: id.into(),
description: description.into(),
constraints: Vec::new(),
references: HashMap::new(),
solutions: Vec::new(),
difficulty: 5,
tags: Vec::new(),
difficulty_vector: None,
}
}
/// Add a constraint
pub fn with_constraint(mut self, constraint: TemporalConstraint) -> Self {
self.constraints.push(constraint);
self
}
/// Add a reference date
pub fn with_reference(mut self, name: impl Into<String>, date: NaiveDate) -> Self {
self.references.insert(name.into(), date);
self
}
/// Set solution dates
pub fn with_solutions(mut self, solutions: Vec<NaiveDate>) -> Self {
self.solutions = solutions;
self
}
/// Set difficulty
pub fn with_difficulty(mut self, difficulty: u8) -> Self {
self.difficulty = difficulty.min(10).max(1);
self
}
/// Add tags
pub fn with_tags(mut self, tags: Vec<String>) -> Self {
self.tags = tags;
self
}
/// Check if a date satisfies all constraints
pub fn check_date(&self, date: NaiveDate) -> Result<bool> {
for constraint in &self.constraints {
if !self.check_constraint(date, constraint)? {
return Ok(false);
}
}
Ok(true)
}
/// Check a single constraint
fn check_constraint(&self, date: NaiveDate, constraint: &TemporalConstraint) -> Result<bool> {
match constraint {
TemporalConstraint::Exact(d) => Ok(date == *d),
TemporalConstraint::After(d) => Ok(date > *d),
TemporalConstraint::Before(d) => Ok(date < *d),
TemporalConstraint::Between(start, end) => Ok(date >= *start && date <= *end),
TemporalConstraint::DayOfWeek(dow) => Ok(date.weekday() == *dow),
TemporalConstraint::DaysAfter(ref_name, days) => {
let ref_date = self
.references
.get(ref_name)
.ok_or_else(|| anyhow!("Unknown reference: {}", ref_name))?;
let target = *ref_date + chrono::Duration::days(*days);
Ok(date == target)
}
TemporalConstraint::DaysBefore(ref_name, days) => {
let ref_date = self
.references
.get(ref_name)
.ok_or_else(|| anyhow!("Unknown reference: {}", ref_name))?;
let target = *ref_date - chrono::Duration::days(*days);
Ok(date == target)
}
TemporalConstraint::InMonth(month) => Ok(date.month() == *month),
TemporalConstraint::InYear(year) => Ok(date.year() == *year),
TemporalConstraint::DayOfMonth(day) => Ok(date.day() == *day),
TemporalConstraint::RelativeToEvent(event_name, days) => {
// Look up event in references
let event_date = self
.references
.get(event_name)
.ok_or_else(|| anyhow!("Unknown event: {}", event_name))?;
let target = *event_date + chrono::Duration::days(*days);
Ok(date == target)
}
}
}
/// Solve the puzzle by searching date space
pub fn solve(&self, search_range: (NaiveDate, NaiveDate)) -> Result<Vec<NaiveDate>> {
let mut solutions = Vec::new();
let mut current = search_range.0;
while current <= search_range.1 {
if self.check_date(current)? {
solutions.push(current);
}
current = current.succ_opt().unwrap_or(current);
}
Ok(solutions)
}
}
/// Puzzle solver with tool augmentation.
///
/// `steps` and `tool_calls` are per-solve counters: `solve` resets both to
/// zero on entry and leaves the final counts in place afterwards.
#[derive(Clone, Debug)]
pub struct TemporalSolver {
/// Enable calendar math tool (when set, `solve` first rewrites relative
/// constraints to explicit dates and counts one tool call)
pub calendar_tool: bool,
/// Enable web search tool
// NOTE(review): not read by any code visible in this part of the file.
pub web_search_tool: bool,
/// Maximum steps allowed (upper bound on dates examined by the scan loop)
pub max_steps: usize,
/// Current step count
pub steps: usize,
/// Tool call count
pub tool_calls: usize,
/// Stop after finding the first valid solution (early termination)
pub stop_after_first: bool,
/// Skip to matching weekday (advance by 7 days instead of 1)
pub skip_weekday: Option<Weekday>,
/// Constraint propagation pre-pass mode (controlled by PolicyKernel)
pub prepass_mode: PrepassMode,
}
impl Default for TemporalSolver {
    /// Baseline solver: calendar tool on, web search off, a 100-step budget,
    /// and every scan optimisation disabled.
    fn default() -> Self {
        Self {
            // Tooling switches.
            calendar_tool: true,
            web_search_tool: false,
            // Scan behaviour.
            max_steps: 100,
            stop_after_first: false,
            skip_weekday: None,
            prepass_mode: PrepassMode::Off,
            // Per-solve counters start at zero.
            steps: 0,
            tool_calls: 0,
        }
    }
}
impl TemporalSolver {
/// Create solver with tools
pub fn with_tools(calendar: bool, web_search: bool) -> Self {
Self {
calendar_tool: calendar,
web_search_tool: web_search,
stop_after_first: false,
skip_weekday: None,
prepass_mode: PrepassMode::Off,
..Default::default()
}
}
/// Constraint propagation pre-pass: tighten the search range
/// using InMonth, DayOfMonth, and DayOfWeek constraints.
///
/// This is the key sublinear optimization. Instead of scanning
/// every day in the range, we compute valid date sets directly:
///
/// 1. InMonth(m) + InYear(y) → range shrinks to that month (≤31 days)
/// 2. DayOfMonth(d) + bounded range → jump directly to matching days
/// 3. DayOfWeek(w) already handled by skip_weekday, but propagation
/// can further restrict: e.g., Month(2) + DayOfWeek(Mon) in a year
/// has only 4-5 candidates.
///
/// Returns (tightened_start, tightened_end, direct_candidates).
/// If direct_candidates is non-empty, skip the scan entirely.
fn propagate_constraints(
&self,
puzzle: &TemporalPuzzle,
range_start: NaiveDate,
range_end: NaiveDate,
) -> (NaiveDate, NaiveDate, Vec<NaiveDate>) {
let mut start = range_start;
let mut end = range_end;
// Extract constraint features
let mut target_month: Option<u32> = None;
let mut target_dom: Option<u32> = None;
let mut target_dow: Option<Weekday> = None;
let mut target_year: Option<i32> = None;
for c in &puzzle.constraints {
match c {
TemporalConstraint::InMonth(m) => {
target_month = Some(*m);
}
TemporalConstraint::DayOfMonth(d) => {
target_dom = Some(*d);
}
TemporalConstraint::DayOfWeek(w) => {
target_dow = Some(*w);
}
TemporalConstraint::InYear(y) => {
target_year = Some(*y);
}
_ => {}
}
}
// Tighten by month + year
if let (Some(m), Some(y)) = (target_month, target_year) {
let month_start = NaiveDate::from_ymd_opt(y, m, 1);
let month_end = if m == 12 {
NaiveDate::from_ymd_opt(y, 12, 31)
} else {
NaiveDate::from_ymd_opt(y, m + 1, 1).and_then(|d| d.pred_opt())
};
if let (Some(ms), Some(me)) = (month_start, month_end) {
if ms > start {
start = ms;
}
if me < end {
end = me;
}
}
} else if let Some(m) = target_month {
// Month without year: tighten to first occurrence in range
let year = start.year();
if let Some(ms) = NaiveDate::from_ymd_opt(year, m, 1) {
if ms >= start && ms <= end {
start = ms;
// End of that month
let me = if m == 12 {
NaiveDate::from_ymd_opt(year, 12, 31)
} else {
NaiveDate::from_ymd_opt(year, m + 1, 1).and_then(|d| d.pred_opt())
};
if let Some(me) = me {
if me < end {
end = me;
}
}
}
}
}
// Direct solve: DayOfMonth within a tight range
if let Some(dom) = target_dom {
if (end - start).num_days() <= 366 {
let mut candidates = Vec::new();
let mut y = start.year();
let mut m = start.month();
loop {
if let Some(d) = NaiveDate::from_ymd_opt(y, m, dom) {
if d >= start && d <= end {
// Verify against ALL constraints before adding
if puzzle.check_date(d).unwrap_or(false) {
candidates.push(d);
}
}
if d > end {
break;
}
}
// Next month
m += 1;
if m > 12 {
m = 1;
y += 1;
}
if NaiveDate::from_ymd_opt(y, m, 1)
.map(|d| d > end)
.unwrap_or(true)
{
break;
}
}
if !candidates.is_empty() {
return (start, end, candidates);
}
}
}
// Direct solve: DayOfWeek within a tight range (≤60 days → ≤9 candidates)
// Only in Full mode — Light mode does InMonth/DayOfMonth only
if self.prepass_mode == PrepassMode::Full {
if let Some(dow) = target_dow {
let range_days = (end - start).num_days();
if range_days <= 60 && range_days >= 0 {
let mut candidates = Vec::new();
let mut d = start;
while d.weekday() != dow && d <= end {
d = d.succ_opt().unwrap_or(d);
}
while d <= end {
if puzzle.check_date(d).unwrap_or(false) {
candidates.push(d);
}
d = d + chrono::Duration::days(7);
}
if !candidates.is_empty() {
return (start, end, candidates);
}
}
}
}
(start, end, Vec::new())
}
/// Solve a puzzle with step tracking.
///
/// Three-phase solve:
/// 1. Constraint propagation: tighten range, attempt direct solve
/// 2. If direct candidates found: verify and return (sublinear)
/// 3. Otherwise: scan with optional weekday skip (linear/7x)
pub fn solve(&mut self, puzzle: &TemporalPuzzle) -> Result<SolverResult> {
self.steps = 0;
self.tool_calls = 0;
let start_time = std::time::Instant::now();
// Rewrite constraints to explicit dates if calendar tool enabled
let effective_puzzle = if self.calendar_tool {
self.tool_calls += 1;
self.rewrite_constraints(puzzle)?
} else {
puzzle.clone()
};
// Determine search range from effective (rewritten) constraints
let range = self.determine_search_range(&effective_puzzle)?;
// ─── Phase 1: Constraint propagation (if enabled) ────────────────
let (prop_start, prop_end, direct_candidates) = match self.prepass_mode {
PrepassMode::Off => (range.0, range.1, Vec::new()),
PrepassMode::Light | PrepassMode::Full => {
self.propagate_constraints(&effective_puzzle, range.0, range.1)
}
};
// ─── Phase 2: Direct solve (sublinear) ──────────────────────────
if !direct_candidates.is_empty() {
self.steps = direct_candidates.len();
self.tool_calls += 1; // propagation counts as a tool call
let latency = start_time.elapsed();
let correct = if puzzle.solutions.is_empty() {
true
} else {
puzzle
.solutions
.iter()
.all(|s| direct_candidates.contains(s) || *s < prop_start || *s > prop_end)
};
return Ok(SolverResult {
puzzle_id: puzzle.id.clone(),
solved: !direct_candidates.is_empty(),
correct,
solutions: direct_candidates,
steps: self.steps,
tool_calls: self.tool_calls,
latency_ms: latency.as_millis() as u64,
});
}
// ─── Phase 3: Scan (linear or weekday-skip) ─────────────────────
let mut found_solutions = Vec::new();
let mut current = prop_start; // Use propagated (tighter) range
// Advance to first matching weekday if skipping enabled
if let Some(target_dow) = self.skip_weekday {
while current.weekday() != target_dow && current <= prop_end {
current = current.succ_opt().unwrap_or(current);
}
}
while current <= prop_end && self.steps < self.max_steps {
self.steps += 1;
if effective_puzzle.check_date(current)? {
found_solutions.push(current);
if self.stop_after_first {
break;
}
}
if self.skip_weekday.is_some() {
current = current + chrono::Duration::days(7);
} else {
current = match current.succ_opt() {
Some(d) => d,
None => break,
};
}
}
let latency = start_time.elapsed();
// Check correctness
let correct = if puzzle.solutions.is_empty() {
true
} else {
puzzle
.solutions
.iter()
.all(|s| found_solutions.contains(s) || *s < prop_start || *s > prop_end)
};
Ok(SolverResult {
puzzle_id: puzzle.id.clone(),
solved: !found_solutions.is_empty(),
correct,
solutions: found_solutions,
steps: self.steps,
tool_calls: self.tool_calls,
latency_ms: latency.as_millis() as u64,
})
}
/// Derive an inclusive `(lower, upper)` search window from the puzzle's
/// absolute constraints.
///
/// The window starts as 1900-01-01 ..= 2100-12-31. Constraints are applied
/// in order: `Exact` resets both bounds to its date, while `After`,
/// `Before`, `Between` and `InYear` can only narrow the window. All other
/// constraint kinds are ignored here. The result may be empty
/// (`lower > upper`) when the constraints contradict each other.
fn determine_search_range(&self, puzzle: &TemporalPuzzle) -> Result<(NaiveDate, NaiveDate)> {
    let mut lower = NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
    let mut upper = NaiveDate::from_ymd_opt(2100, 12, 31).unwrap();

    for constraint in &puzzle.constraints {
        match constraint {
            TemporalConstraint::Exact(day) => {
                lower = *day;
                upper = *day;
            }
            TemporalConstraint::After(day) => {
                // Strictly after: the earliest admissible date is the next
                // day (or `day` itself at the end of the calendar).
                lower = lower.max(day.succ_opt().unwrap_or(*day));
            }
            TemporalConstraint::Before(day) => {
                // Strictly before: the latest admissible date is the
                // previous day (or `day` itself at the start of the calendar).
                upper = upper.min(day.pred_opt().unwrap_or(*day));
            }
            TemporalConstraint::Between(from, to) => {
                lower = lower.max(*from);
                upper = upper.min(*to);
            }
            TemporalConstraint::InYear(year) => {
                // An unrepresentable year leaves the corresponding bound as is.
                let jan_first = NaiveDate::from_ymd_opt(*year, 1, 1).unwrap_or(lower);
                let dec_last = NaiveDate::from_ymd_opt(*year, 12, 31).unwrap_or(upper);
                lower = lower.max(jan_first);
                upper = upper.min(dec_last);
            }
            _ => {}
        }
    }

    Ok((lower, upper))
}
/// Rewrite relative constraints to explicit dates
fn rewrite_constraints(&self, puzzle: &TemporalPuzzle) -> Result<TemporalPuzzle> {
let mut new_puzzle = puzzle.clone();
let mut new_constraints = Vec::new();
for constraint in &puzzle.constraints {
match constraint {
TemporalConstraint::DaysAfter(ref_name, days) => {
if let Some(ref_date) = puzzle.references.get(ref_name) {
let target = *ref_date + chrono::Duration::days(*days);
new_constraints.push(TemporalConstraint::Exact(target));
} else {
new_constraints.push(constraint.clone());
}
}
TemporalConstraint::DaysBefore(ref_name, days) => {
if let Some(ref_date) = puzzle.references.get(ref_name) {
let target = *ref_date - chrono::Duration::days(*days);
new_constraints.push(TemporalConstraint::Exact(target));
} else {
new_constraints.push(constraint.clone());
}
}
TemporalConstraint::RelativeToEvent(event_name, days) => {
if let Some(event_date) = puzzle.references.get(event_name) {
let target = *event_date + chrono::Duration::days(*days);
new_constraints.push(TemporalConstraint::Exact(target));
} else {
new_constraints.push(constraint.clone());
}
}
_ => new_constraints.push(constraint.clone()),
}
}
new_puzzle.constraints = new_constraints;
Ok(new_puzzle)
}
}
/// Result from solving a puzzle
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SolverResult {
/// ID of the puzzle this result belongs to
pub puzzle_id: String,
/// True when the solver returned at least one date
pub solved: bool,
/// True when every expected solution was either returned or lies outside
/// the range the solver searched; always true for puzzles that list no
/// expected solutions
pub correct: bool,
/// Dates the solver returned
pub solutions: Vec<NaiveDate>,
/// Dates examined by the scan, or the number of direct candidates when the
/// pre-pass solved the puzzle
pub steps: usize,
/// Tool invocations counted during the solve
pub tool_calls: usize,
/// Wall-clock solve time in whole milliseconds
pub latency_ms: u64,
}
/// Benchmark configuration.
///
/// The defaults (see the `Default` impl) are 50 puzzles, difficulty 1-10,
/// calendar tool on, web search off, 100 steps and density 3.
// NOTE(review): the code that consumes these settings is not in this part of
// the file; the field descriptions below restate the original intent.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BenchmarkConfig {
/// Number of puzzles to run
pub num_puzzles: usize,
/// Difficulty range (presumably inclusive `(min, max)` — confirm against the generator)
pub difficulty_range: (u8, u8),
/// Enable calendar tool
pub calendar_tool: bool,
/// Enable web search tool
pub web_search_tool: bool,
/// Maximum steps per puzzle
pub max_steps: usize,
/// Constraint density (1-5)
pub constraint_density: u8,
}
impl Default for BenchmarkConfig {
    /// Standard benchmark run: 50 puzzles across the full difficulty range,
    /// calendar tool only, 100 steps per puzzle, medium constraint density.
    fn default() -> Self {
        Self {
            // Workload.
            num_puzzles: 50,
            difficulty_range: (1, 10),
            constraint_density: 3,
            // Solver settings.
            max_steps: 100,
            calendar_tool: true,
            web_search_tool: false,
        }
    }
}
/// Benchmark results: aggregate statistics plus the per-puzzle results they
/// were computed from (see `BenchmarkResults::from_results`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BenchmarkResults {
/// Configuration the benchmark was run with
pub config: BenchmarkConfig,
/// Number of per-puzzle results aggregated
pub total_puzzles: usize,
/// Results flagged `solved`
pub solved_count: usize,
/// Results flagged `correct`
pub correct_count: usize,
/// `correct_count / total_puzzles`
pub accuracy: f64,
/// Mean `steps` per puzzle
pub avg_steps: f64,
/// Mean `tool_calls` per puzzle
pub avg_tool_calls: f64,
/// Mean `latency_ms` per puzzle
pub avg_latency_ms: f64,
/// The individual per-puzzle results
pub results: Vec<SolverResult>,
}
impl BenchmarkResults {
    /// Aggregate individual solver results into benchmark statistics.
    ///
    /// `accuracy` is the fraction of results flagged `correct`; the `avg_*`
    /// fields are arithmetic means over all results. For an empty `results`
    /// list every ratio is reported as `0.0` rather than `NaN` (the outcome
    /// of dividing zero by zero), so the value stays comparable and
    /// serializes as a regular number.
    pub fn from_results(config: BenchmarkConfig, results: Vec<SolverResult>) -> Self {
        let total = results.len();
        let solved = results.iter().filter(|r| r.solved).count();
        let correct = results.iter().filter(|r| r.correct).count();

        // Divide by the number of results, guarding the empty case.
        let per_puzzle = |sum: f64| if total == 0 { 0.0 } else { sum / total as f64 };

        let avg_steps = per_puzzle(results.iter().map(|r| r.steps as f64).sum::<f64>());
        let avg_tools = per_puzzle(results.iter().map(|r| r.tool_calls as f64).sum::<f64>());
        let avg_latency = per_puzzle(results.iter().map(|r| r.latency_ms as f64).sum::<f64>());
        let accuracy = per_puzzle(correct as f64);

        Self {
            config,
            total_puzzles: total,
            solved_count: solved,
            correct_count: correct,
            accuracy,
            avg_steps,
            avg_tool_calls: avg_tools,
            avg_latency_ms: avg_latency,
            results,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Absolute constraints: year + month + day-of-month pin down one date, and
// the same day in a different month must be rejected.
#[test]
fn test_simple_puzzle() {
let puzzle = TemporalPuzzle::new("test-1", "Find a date in January 2024")
.with_constraint(TemporalConstraint::InYear(2024))
.with_constraint(TemporalConstraint::InMonth(1))
.with_constraint(TemporalConstraint::DayOfMonth(15));
let expected = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
assert!(puzzle.check_date(expected).unwrap());
assert!(!puzzle
.check_date(NaiveDate::from_ymd_opt(2024, 2, 15).unwrap())
.unwrap());
}
// Relative constraint: `DaysAfter` resolves against a named reference date.
#[test]
fn test_relative_constraint() {
let base = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
let puzzle = TemporalPuzzle::new("test-2", "Find a date 10 days after New Year")
.with_reference("new_year", base)
.with_constraint(TemporalConstraint::DaysAfter("new_year".to_string(), 10));
let expected = NaiveDate::from_ymd_opt(2024, 1, 11).unwrap();
assert!(puzzle.check_date(expected).unwrap());
}
// End-to-end: with the calendar tool enabled the relative constraint is
// rewritten to an exact date, so the solver finds exactly one solution.
#[test]
fn test_solver_with_rewriting() {
let base = NaiveDate::from_ymd_opt(2024, 6, 15).unwrap();
let puzzle = TemporalPuzzle::new("test-3", "Find date relative to event")
.with_reference("event", base)
.with_constraint(TemporalConstraint::DaysAfter("event".to_string(), 5))
.with_solutions(vec![NaiveDate::from_ymd_opt(2024, 6, 20).unwrap()]);
let mut solver = TemporalSolver::with_tools(true, false);
let result = solver.solve(&puzzle).unwrap();
assert!(result.solved);
assert!(result.correct);
assert_eq!(result.solutions.len(), 1);
}
}
// ============================================================================
// Adaptive Solver with ReasoningBank Learning
// ============================================================================
use crate::reasoning_bank::{ReasoningBank, Strategy, Trajectory, Verdict};
use crate::timepuzzles::DifficultyVector;
// ═══════════════════════════════════════════════════════════════════════════
// PolicyKernel — learned skip-mode selection
// ═══════════════════════════════════════════════════════════════════════════
/// Skip mode for the temporal solver scan loop.
/// All modes have access to all skip modes.
/// What differs is the *policy* that selects the mode.
///
/// The `Display` form of each variant ("none", "weekday", "hybrid") is also
/// the key under which `PolicyKernel` stores that mode's statistics.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum SkipMode {
/// Linear scan: check every date in range (1-day increments)
None,
/// Weekday skip: advance by 7 days when DayOfWeek constraint is present
Weekday,
/// Hybrid: weekday skip for initial scan, then full refinement pass
/// around candidates to catch near-misses under noise
Hybrid,
}
impl Default for SkipMode {
fn default() -> Self {
SkipMode::None
}
}
impl std::fmt::Display for SkipMode {
    /// Writes the lowercase mode name: "none", "weekday" or "hybrid".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            SkipMode::None => "none",
            SkipMode::Weekday => "weekday",
            SkipMode::Hybrid => "hybrid",
        };
        f.write_str(label)
    }
}
/// Context features for PolicyKernel decisions.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PolicyContext {
/// Number of dates in the posterior (search range).
/// The fixed policy discounts this by 30 per distractor; speculation is
/// only considered when it is at least 61.
pub posterior_range: usize,
/// Number of distractor constraints in the puzzle
pub distractor_count: usize,
/// Whether a DayOfWeek constraint is present; without one every policy
/// selects `SkipMode::None`
pub has_day_of_week: bool,
/// Whether noise was injected
pub noisy: bool,
/// Difficulty vector components
pub difficulty: DifficultyVector,
/// Recent false-hit density (rolling window)
pub recent_false_hit_rate: f64,
}
/// Outcome of a skip-mode decision for learning.
///
/// Passed to `PolicyKernel::record_outcome` to update the statistics of the
/// mode that was used.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SkipOutcome {
/// The skip mode that was used
pub mode: SkipMode,
/// Whether the solve was correct
pub correct: bool,
/// Steps taken (normalized against 200 for the cost signal)
pub steps: usize,
/// Whether this was an early commit that turned out wrong
pub early_commit_wrong: bool,
/// Initial candidate count (for normalized penalty)
pub initial_candidates: usize,
/// Remaining candidates at commit time (for normalized penalty)
pub remaining_at_commit: usize,
}
/// Per-context skip-mode statistics for learned policy.
///
/// Two-signal model for Thompson Sampling:
/// 1. **Safety posterior**: Beta(alpha_safety, beta_safety)
/// Updated by whether the commit was correct (not just solved).
/// Drives exploration toward safe arms.
/// 2. **Cost signal**: EMA of normalized step cost.
/// Captures efficiency without contaminating the safety posterior.
///
/// Final score = sample_safety - lambda * cost_ema
/// This separates "is it safe?" (explored by Thompson Sampling)
/// from "is it cheap?" (deterministic penalty).
///
/// `alpha_safety` and `beta_safety` hold observed evidence only; the uniform
/// Beta(1, 1) prior is added on read by `safety_beta`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SkipModeStats {
/// Number of recorded outcomes for this mode
pub attempts: usize,
/// Number of recorded outcomes that were correct
pub successes: usize,
/// Sum of steps over all recorded outcomes
pub total_steps: usize,
/// Number of recorded outcomes flagged as a wrong early commit
pub early_commit_wrongs: usize,
/// Accumulated normalized early-commit penalty (remaining/initial fractions)
pub early_commit_penalty_sum: f64,
/// Safety posterior alpha: correct commits
pub alpha_safety: f64,
/// Safety posterior beta: incorrect commits + early wrongs
pub beta_safety: f64,
/// Cost EMA: exponential moving average of normalized step cost
pub cost_ema: f64,
}
/// Lambda: weight of cost penalty in Thompson score
/// (`score = safety_sample - THOMPSON_LAMBDA * cost_ema`).
/// Higher = more cost-sensitive, lower = more safety-focused.
const THOMPSON_LAMBDA: f64 = 0.3;
/// EMA smoothing factor for the cost signal: the weight given to the newest
/// observation. With 0.1, each update keeps 90% of the previous average, so
/// the estimate moves slowly and reflects longer-term history.
const COST_EMA_ALPHA: f64 = 0.1;
impl SkipModeStats {
    /// Composite reward kept for backward compatibility and diagnostics.
    ///
    /// Combines accuracy (weight 0.5), a bonus of up to 0.3 for averaging
    /// fewer than 200 steps, and a penalty of up to 0.2 for the average
    /// early-commit penalty. Never negative; 0.5 when nothing was recorded.
    pub fn reward(&self) -> f64 {
        if self.attempts == 0 {
            return 0.5;
        }
        let n = self.attempts as f64;
        let accuracy = self.successes as f64 / n;
        let mean_steps = self.total_steps as f64 / n;
        let cost_bonus = 0.3 * (1.0 - mean_steps / 200.0).max(0.0);
        let mean_penalty = self.early_commit_penalty_sum / n;
        let robustness_penalty = 0.2 * mean_penalty.min(1.0);
        (accuracy * 0.5 + cost_bonus - robustness_penalty).max(0.0)
    }

    /// Parameters `(alpha, beta)` of the safety Beta posterior.
    ///
    /// The stored counts are offset by the uniform Beta(1, 1) prior:
    /// alpha counts safe commits (correct and not an early wrong commit),
    /// beta counts unsafe ones.
    pub fn safety_beta(&self) -> (f64, f64) {
        let alpha = self.alpha_safety + 1.0;
        let beta = self.beta_safety + 1.0;
        (alpha, beta)
    }

    /// Variance of the safety Beta posterior, `ab / ((a+b)^2 (a+b+1))`.
    /// High variance = high uncertainty = speculative dual-path trigger.
    pub fn safety_variance(&self) -> f64 {
        let (a, b) = self.safety_beta();
        let total = a + b;
        (a * b) / (total.powi(2) * (total + 1.0))
    }

    /// Fold one outcome into the safety posterior.
    ///
    /// A correct outcome that was not an early wrong commit adds 1 to alpha.
    /// Anything else adds 1 to beta, and an early wrong commit adds a further
    /// 0.5 on top (1.5 in total) because those are the dangerous ones.
    pub fn update_safety(&mut self, correct: bool, early_commit_wrong: bool) {
        match (correct, early_commit_wrong) {
            (true, false) => self.alpha_safety += 1.0,
            (_, true) => {
                self.beta_safety += 1.0;
                self.beta_safety += 0.5;
            }
            (false, false) => self.beta_safety += 1.0,
        }
    }

    /// Fold one normalized step cost into the cost EMA.
    ///
    /// While `attempts` is at most 1 the EMA is seeded with the observation
    /// itself; afterwards the usual `alpha * x + (1 - alpha) * ema` update is
    /// applied. This reads `attempts`, so it relies on the caller having
    /// already counted the attempt.
    pub fn update_cost(&mut self, normalized_steps: f64) {
        self.cost_ema = if self.attempts <= 1 {
            normalized_steps
        } else {
            COST_EMA_ALPHA * normalized_steps + (1.0 - COST_EMA_ALPHA) * self.cost_ema
        };
    }
}
/// Constraint propagation pre-pass mode.
///
/// Controls whether the solver runs arc-consistency before scanning.
/// Selectable by PolicyKernel — kept off by default to preserve
/// learning gradient. If prepass always wins, increase generator
/// ambiguity to restore gradient.
///
/// In `TemporalSolver`, `Light` and `Full` both run the propagation pass;
/// only `Full` additionally attempts the direct DayOfWeek solve.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum PrepassMode {
/// No constraint propagation (default)
Off,
/// Cheap local pruning: InMonth+DayOfMonth only
Light,
/// Full arc consistency: InMonth+DayOfMonth+DayOfWeek
Full,
}
impl Default for PrepassMode {
fn default() -> Self {
PrepassMode::Off
}
}
impl std::fmt::Display for PrepassMode {
    /// Writes the lowercase mode name: "off", "light" or "full".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            PrepassMode::Off => "off",
            PrepassMode::Light => "light",
            PrepassMode::Full => "full",
        };
        f.write_str(label)
    }
}
/// Metrics from constraint propagation pre-pass.
///
/// All counters are running totals and start at zero (`Default`).
// NOTE(review): the code that updates these counters is not in this part of
// the file.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PrepassMetrics {
/// Total pre-pass invocations
pub invocations: usize,
/// Total candidates pruned by pre-pass
pub pruned_candidates: usize,
/// Total steps the pre-pass itself took
pub prepass_steps: usize,
/// Estimated scan steps saved by pre-pass
pub scan_steps_saved: usize,
/// Number of direct solves (scan skipped entirely)
pub direct_solves: usize,
}
/// PolicyKernel: decides skip_mode based on context.
///
/// Three policy levels:
/// - **Fixed** (Mode A): deterministic heuristic based on posterior_range + distractor_count
/// - **Compiled** (Mode B): compiler-suggested skip_mode from CompiledSolveConfig
/// - **Learned** (Mode C): contextual stats drive selection, adapts from outcomes
///
/// Note that the derived `Default` zero-initializes `epsilon` and
/// `rng_state`, whereas `PolicyKernel::new` sets them to 0.15 and 42.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PolicyKernel {
/// Per-context bucket → per-skip-mode stats (for learned policy).
/// The inner key is the `Display` name of the skip mode.
pub context_stats: HashMap<String, HashMap<String, SkipModeStats>>,
/// Early commit penalty accumulator
pub early_commit_penalties: f64,
/// Total early commits tracked
pub early_commits_total: usize,
/// Total early commits that were wrong
pub early_commits_wrong: usize,
/// Exploration rate (legacy, not used by Thompson Sampling)
pub epsilon: f64,
/// RNG state (seeded for deterministic Thompson Sampling).
/// Private, but still part of the serialized form.
rng_state: u64,
/// Constraint propagation pre-pass mode
pub prepass: PrepassMode,
/// Pre-pass metrics
pub prepass_metrics: PrepassMetrics,
/// Speculative dual-path attempts
pub speculative_attempts: usize,
/// Speculative dual-path wins (second arm was better)
pub speculative_arm2_wins: usize,
}
impl PolicyKernel {
pub fn new() -> Self {
Self {
epsilon: 0.15,
rng_state: 42,
..Default::default()
}
}
/// `k` of the fixed policy: one distractor reduces the effective range by
/// 30 days.
const BASELINE_K: usize = 30;
/// `T` of the fixed policy: minimum effective range at which weekday
/// skipping is chosen.
const BASELINE_T: usize = 140;
/// Fixed baseline policy (Mode A):
/// Uses risk_score = R - k*D where R=posterior_range, D=distractor_count.
///
/// Constants (fixed, not learned — Mode A is the control arm):
/// k = 30 (one distractor reduces effective range by ~30 days)
/// T = 140 (threshold: skip only when effective range justifies it)
///
/// Rationale: distractors make the search space noisier, so a rational
/// fixed agent should be *more cautious* (less likely to skip) when
/// distractors are present. This is the conservative-under-distractors
/// baseline that Mode C must learn to outperform.
///
/// Decision:
/// If no DayOfWeek: None (nothing to skip to)
/// Else risk_score = R - 30*D (saturating at 0)
/// risk_score >= 140 → Weekday (large range, few distractors)
/// risk_score < 140 → None (small range or distractor-heavy)
pub fn fixed_policy(ctx: &PolicyContext) -> SkipMode {
    if !ctx.has_day_of_week {
        return SkipMode::None;
    }
    // Saturating arithmetic: an extreme distractor count drives the effective
    // range to zero instead of overflowing the multiplication.
    let distractor_discount = Self::BASELINE_K.saturating_mul(ctx.distractor_count);
    let effective_range = ctx.posterior_range.saturating_sub(distractor_discount);
    if effective_range >= Self::BASELINE_T {
        SkipMode::Weekday
    } else {
        SkipMode::None
    }
}
/// Compiled policy (Mode B).
///
/// Returns the compiler-suggested skip mode when one is supplied; with no
/// suggestion the decision is delegated to `fixed_policy`.
pub fn compiled_policy(ctx: &PolicyContext, compiled_skip: Option<SkipMode>) -> SkipMode {
    match compiled_skip {
        Some(suggested) => suggested,
        None => Self::fixed_policy(ctx),
    }
}
/// Learned policy (Mode C): two-signal Thompson Sampling.
///
/// For each skip mode in the context's bucket:
/// - Signal 1 (safety): draw a sample from the mode's safety Beta
///   posterior, which naturally explores uncertain arms.
/// - Signal 2 (cost): subtract the deterministic penalty
///   `THOMPSON_LAMBDA * cost_ema`.
///
/// The mode with the highest score is returned; ties keep the order
/// none, weekday, hybrid. Without a DayOfWeek constraint the result is
/// always `SkipMode::None` and no random numbers are consumed.
///
/// Side effects: advances the internal RNG state and creates an empty
/// stats bucket for a context that has not been seen before.
pub fn learned_policy(&mut self, ctx: &PolicyContext) -> SkipMode {
    if !ctx.has_day_of_week {
        return SkipMode::None;
    }

    let arms = [
        ("none", SkipMode::None),
        ("weekday", SkipMode::Weekday),
        ("hybrid", SkipMode::Hybrid),
    ];
    let bucket = Self::context_bucket(ctx);

    // Snapshot the sampling parameters first, so that the borrow of the
    // stats map has ended before the RNG (which needs `&mut self`) is used.
    let mut snapshots: Vec<(SkipMode, f64, f64, f64)> = Vec::with_capacity(arms.len());
    {
        let per_mode = self.context_stats.entry(bucket).or_default();
        for (key, mode) in arms {
            let stats = per_mode.get(key).cloned().unwrap_or_default();
            let (alpha, beta) = stats.safety_beta();
            snapshots.push((mode, alpha, beta, stats.cost_ema));
        }
    }

    // Sample and score each arm, in declaration order.
    let mut ranked: Vec<(SkipMode, f64)> = Vec::with_capacity(snapshots.len());
    for (mode, alpha, beta, cost_ema) in snapshots {
        let safety_draw = self.sample_beta(alpha, beta);
        ranked.push((mode, safety_draw - THOMPSON_LAMBDA * cost_ema));
    }

    // Highest score first; the sort is stable, so ties keep arm order.
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
    ranked
        .into_iter()
        .next()
        .map(|(mode, _)| mode)
        .unwrap_or(SkipMode::None)
}
/// Check if speculation is warranted for Mode C.
///
/// Returns Some((arm1, arm2)) if:
/// 1. Top two arms are within `delta` of each other, AND
/// 2. Safety variance of the top arm is above threshold
///
/// Otherwise returns None (single-path is sufficient).
pub fn should_speculate(&mut self, ctx: &PolicyContext) -> Option<(SkipMode, SkipMode)> {
if !ctx.has_day_of_week {
return None;
}
// Only speculate in medium/large range with distractors or noise
if ctx.posterior_range < 61 || (ctx.distractor_count == 0 && !ctx.noisy) {
return None;
}
let bucket = Self::context_bucket(ctx);
let modes = ["none", "weekday", "hybrid"];
// Collect params first to avoid double mutable borrow
let params: Vec<(SkipMode, f64, f64, f64, f64)> = {
let stats_map = self.context_stats.entry(bucket).or_default();
modes
.iter()
.map(|mode_name| {
let stats = stats_map.get(*mode_name).cloned().unwrap_or_default();
let (alpha, beta) = stats.safety_beta();
let variance = stats.safety_variance();
let mode = match *mode_name {
"weekday" => SkipMode::Weekday,
"hybrid" => SkipMode::Hybrid,
_ => SkipMode::None,
};
(mode, alpha, beta, stats.cost_ema, variance)
})
.collect()
};
// Now sample with self.sample_beta() — no conflicting borrow
let mut scored: Vec<(SkipMode, f64, f64)> = params
.into_iter()
.map(|(mode, alpha, beta, cost_ema, variance)| {
let safety_sample = self.sample_beta(alpha, beta);
let score = safety_sample - THOMPSON_LAMBDA * cost_ema;
(mode, score, variance)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
if scored.len() >= 2 {
let (ref arm1, score1, var1) = scored[0];
let (ref arm2, score2, _) = scored[1];
let delta = 0.15;
let var_threshold = 0.02; // Beta(1,1) has var≈0.083, so 0.02 = moderate certainty
if (score1 - score2).abs() < delta && var1 > var_threshold {
return Some((arm1.clone(), arm2.clone()));
}
}
None
}
/// Sample from Beta(alpha, beta) using the gamma-ratio method:
/// if X ~ Gamma(alpha) and Y ~ Gamma(beta), then X / (X + Y) ~ Beta(alpha, beta).
///
/// In this file alpha and beta come from `SkipModeStats::safety_beta`, which
/// adds 1 to each stored count. Falls back to 0.5 in the degenerate case
/// where both gamma draws are zero. Deterministic given internal rng_state.
fn sample_beta(&mut self, alpha: f64, beta: f64) -> f64 {
    let gamma_a = self.sample_gamma(alpha);
    let gamma_b = self.sample_gamma(beta);
    let total = gamma_a + gamma_b;
    if total == 0.0 {
        0.5
    } else {
        gamma_a / total
    }
}
/// Sample from Gamma(shape, 1) using Marsaglia & Tsang's method.
///
/// For `shape < 1` the sample is obtained from a Gamma(shape + 1) draw
/// scaled by `U^(1/shape)`. Otherwise candidates `d * v` are generated from
/// a standard normal draw and accepted via the squeeze test or the full
/// logarithmic test; rejected candidates are retried.
///
/// The uniform draws are floored at 1e-10 so that `ln` and `powf` never see
/// zero. The sequence of RNG calls determines the output, so statement order
/// in this function must not be changed.
// NOTE(review): a NaN `shape` fails every comparison below and would make the
// loop spin forever; callers in view always pass `count + 1.0` — confirm no
// NaN can reach this through deserialized stats.
fn sample_gamma(&mut self, shape: f64) -> f64 {
if shape < 1.0 {
// Boost: Gamma(shape) = Gamma(shape+1) * U^(1/shape)
let u = self.next_f64().max(1e-10);
return self.sample_gamma(shape + 1.0) * u.powf(1.0 / shape);
}
// Marsaglia–Tsang constants derived from the shape.
let d = shape - 1.0 / 3.0;
let c = 1.0 / (9.0 * d).sqrt();
loop {
let x = self.next_standard_normal();
let v = (1.0 + c * x).powi(3);
// Candidates with non-positive v are outside the support; redraw.
if v <= 0.0 {
continue;
}
let u = self.next_f64().max(1e-10);
// Squeeze test
if u < 1.0 - 0.0331 * x * x * x * x {
return d * v;
}
// Full acceptance test in log space.
if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
return d * v;
}
}
}
/// Standard normal sample via the Box-Muller transform.
///
/// Consumes two uniforms from `next_f64`. The first is floored at 1e-10 so
/// that its logarithm stays finite.
fn next_standard_normal(&mut self) -> f64 {
    let radial_u = self.next_f64().max(1e-10);
    let angular_u = self.next_f64();
    let radius = (-2.0 * radial_u.ln()).sqrt();
    let angle = 2.0 * std::f64::consts::PI * angular_u;
    radius * angle.cos()
}
/// Scale factor applied to the normalized early-commit penalty.
const PENALTY_SCALE: f64 = 1.0;

/// Record the outcome of a skip-mode decision.
///
/// Updates the per-bucket, per-mode statistics (attempts, steps, successes,
/// safety posterior, cost EMA) and the kernel-wide early-commit counters.
///
/// The early-commit penalty is only charged when the commit was *wrong*:
///   penalty = (remaining_at_commit / initial_candidates) * PENALTY_SCALE
///
/// A wrong commit made while 90% of the candidates were still unscanned
/// costs about 0.90; one made with only 5% left unscanned costs about 0.05.
/// When `initial_candidates` is zero the penalty falls back to
/// `1 - min(steps / 200, 1)`.
///
/// `early_commits_total` is incremented for every recorded outcome, so it
/// acts as the denominator of `early_commit_rate`.
pub fn record_outcome(&mut self, ctx: &PolicyContext, outcome: &SkipOutcome) {
    // Signal 2 input: steps normalized into 0..=1 against a 200-step scale.
    let step_fraction = (outcome.steps as f64 / 200.0).min(1.0);

    let stats = self
        .context_stats
        .entry(Self::context_bucket(ctx))
        .or_default()
        .entry(outcome.mode.to_string())
        .or_default();

    stats.attempts += 1;
    stats.total_steps += outcome.steps;
    if outcome.correct {
        stats.successes += 1;
    }

    // Two-signal model: safety posterior first, then the cost EMA.
    stats.update_safety(outcome.correct, outcome.early_commit_wrong);
    stats.update_cost(step_fraction);

    if outcome.early_commit_wrong {
        let penalty = if outcome.initial_candidates > 0 {
            (outcome.remaining_at_commit as f64 / outcome.initial_candidates as f64)
                * Self::PENALTY_SCALE
        } else {
            1.0 - step_fraction
        };
        stats.early_commit_wrongs += 1;
        stats.early_commit_penalty_sum += penalty;
        self.early_commits_wrong += 1;
        self.early_commit_penalties += penalty;
    }

    self.early_commits_total += 1;
}
/// Share of recorded outcomes that were wrong early commits.
///
/// Returns 0.0 while no outcome has been recorded.
pub fn early_commit_rate(&self) -> f64 {
    match self.early_commits_total {
        0 => 0.0,
        total => self.early_commits_wrong as f64 / total as f64,
    }
}
/// Build a context bucket key for stats grouping (public for witnesses).
///
/// Public wrapper around the private `context_bucket`; returns exactly the
/// same `"range:distractors:noise"` key.
pub fn context_bucket_static(ctx: &PolicyContext) -> String {
// Pure delegation: the key format is defined in `context_bucket`.
Self::context_bucket(ctx)
}
/// Build the stats-grouping key for a policy context.
///
/// The key has the form `"<range>:<distractors>:<noise>"`:
/// - range: `small` (0..=60), `medium` (61..=180), `large` (anything above)
/// - distractors: `clean` (0), `some` (1), `heavy` (2 or more)
/// - noise: `noisy` or `clean`
///
/// 3 range × 3 distractor × 2 noise = 18 buckets: fine enough for the bandit
/// to learn per-context preferences, coarse enough to gather samples.
fn context_bucket(ctx: &PolicyContext) -> String {
    let range_label = if ctx.posterior_range <= 60 {
        "small"
    } else if ctx.posterior_range <= 180 {
        "medium"
    } else {
        "large"
    };

    let distractor_label = if ctx.distractor_count == 0 {
        "clean"
    } else if ctx.distractor_count == 1 {
        "some"
    } else {
        "heavy"
    };

    let noise_label = match ctx.noisy {
        true => "noisy",
        false => "clean",
    };

    [range_label, distractor_label, noise_label].join(":")
}
/// Advance the xorshift generator (shifts 13 / 7 / 17) and return the new
/// state divided by `u64::MAX`.
fn next_f64(&mut self) -> f64 {
    // A zero state would stay zero under xorshift, so start from at least 1.
    let seed = self.rng_state.max(1);
    let step1 = seed ^ (seed << 13);
    let step2 = step1 ^ (step1 >> 7);
    let step3 = step2 ^ (step2 << 17);
    self.rng_state = step3;
    (step3 as f64) / (u64::MAX as f64)
}
/// Print a human-readable diagnostic summary to stdout.
///
/// Covers the early-commit counters, prepass metrics (only when the prepass
/// was invoked), speculation results (only when attempted) and one line per
/// (context bucket, skip mode) pair with its safety posterior, cost EMA and
/// reward. Buckets are printed in `HashMap` iteration order.
pub fn print_diagnostics(&self) {
    println!();
    println!(" PolicyKernel Diagnostics (Thompson Sampling, two-signal)");

    let wrong_pct = self.early_commit_rate() * 100.0;
    println!(
        " Early commits: {}/{} wrong ({:.1}%)",
        self.early_commits_wrong, self.early_commits_total, wrong_pct
    );
    println!(" Accumulated penalty: {:.2}", self.early_commit_penalties);
    println!(" Prepass mode: {}", self.prepass);

    let prepass = &self.prepass_metrics;
    if prepass.invocations > 0 {
        println!(" Prepass: {} invocations, {} direct solves, {} candidates pruned, {} scan steps saved",
            prepass.invocations, prepass.direct_solves,
            prepass.pruned_candidates, prepass.scan_steps_saved);
    }

    if self.speculative_attempts > 0 {
        let arm2_pct =
            self.speculative_arm2_wins as f64 / self.speculative_attempts as f64 * 100.0;
        println!(
            " Speculation: {} attempts, {} arm2 wins ({:.0}%)",
            self.speculative_attempts, self.speculative_arm2_wins, arm2_pct
        );
    }

    println!(" Context buckets: {}", self.context_stats.len());
    for (bucket, modes) in &self.context_stats {
        println!(" {}", bucket);
        for (mode, stats) in modes {
            let (alpha, beta) = stats.safety_beta();
            let cost_ema = stats.cost_ema;
            let reward = stats.reward();
            println!(
                " {:<8} n={:<4} safe=Beta({:.1},{:.1}) cost_ema={:.3} reward={:.3}",
                mode, stats.attempts, alpha, beta, cost_ema, reward
            );
        }
    }
}
}
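// NOTE: the description below is about `AdaptiveSolver`, which is defined
// further down in this file (after the StrategyRouter section). It is kept
// as a plain comment so that it does not become part of the rustdoc of the
// item that follows.
//
// Adaptive temporal solver with learning capabilities
//
// Uses ReasoningBank to:
// - Track solution trajectories
// - Learn from successes and failures
// - Adapt strategy based on puzzle characteristics
// - Achieve sublinear regret through experience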
// ═══════════════════════════════════════════════════════════════════════════
// KnowledgeCompiler — constraint signature → compiled solve config
// ═══════════════════════════════════════════════════════════════════════════
/// Compiled solver configuration for a known constraint signature.
///
/// Entries are created and refreshed by `KnowledgeCompiler::compile_from_bank`
/// and adjusted afterwards by `record_success` / `record_failure`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompiledSolveConfig {
/// Whether to use calendar rewriting
pub use_rewriting: bool,
/// Minimum steps that succeeded for this signature.
/// NOTE: despite the field name, `compile_from_bank` keeps a running
/// *minimum* here, not a maximum.
pub max_steps: usize,
/// Average steps across all successes (for bounded trial budget)
pub avg_steps: f64,
/// Number of successful observations compiled
pub observations: usize,
/// Expected correctness. Set to `false` by `record_failure` once two
/// counterexamples have been recorded, which disables the entry.
pub expected_correct: bool,
/// Stop after first solution (early termination for known single-solution puzzles)
pub stop_after_first: bool,
/// Hit count (how often this config was used and succeeded).
/// `compile_from_bank` seeds it with `observations`; `record_success`
/// increments it.
pub hit_count: usize,
/// Counterexample count (failures on this signature)
pub counterexample_count: usize,
/// Compiled skip mode suggestion (for Mode B policy).
/// `compile_from_bank` picks `Weekday` when the signature contains a
/// `DayOfWeek` constraint and `None` otherwise.
pub compiled_skip_mode: SkipMode,
}
impl CompiledSolveConfig {
    /// Confidence as a Laplace-smoothed success rate:
    /// `(hits + 1) / (hits + counterexamples + 2)`.
    ///
    /// With no hits and no counterexamples this is 0.5.
    pub fn confidence(&self) -> f64 {
        match self.hit_count + self.counterexample_count {
            0 => 0.5,
            trials => (self.hit_count as f64 + 1.0) / (trials as f64 + 2.0),
        }
    }

    /// Trial budget: bounded step limit for Strategy Zero.
    ///
    /// - With more than two observations and `avg_steps > 1`, the base
    ///   budget is `avg_steps * 2` (headroom for variance).
    /// - Otherwise the base budget is `max_steps`, raised to at least 10.
    ///
    /// The base is then floored at 10 and finally capped at 25% of
    /// `external_limit`. The cap is applied last, so it wins when
    /// `external_limit / 4` is below 10.
    pub fn trial_budget(&self, external_limit: usize) -> usize {
        let has_enough_data = self.observations > 2 && self.avg_steps > 1.0;
        let base = if has_enough_data {
            (self.avg_steps * 2.0) as usize
        } else {
            self.max_steps.max(10)
        };
        let ceiling = external_limit / 4;
        base.max(10).min(ceiling)
    }
}
/// Version tag placed at the front of every compiler signature.
///
/// Signature version: v1 (difficulty:sorted_constraints)
/// Change this when canonicalization rules change.
const COMPILER_SIG_VERSION: &str = "v1";
/// KnowledgeCompiler: learns constraint-signature → optimal solve config.
/// Consulted as "Strategy Zero" before any other strategy runs.
///
/// NOTE: the derived `Default` leaves `confidence_threshold` at 0.0;
/// `KnowledgeCompiler::new` sets it to 0.7.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct KnowledgeCompiler {
/// Compiled constraint signature → config
pub signature_cache: HashMap<String, CompiledSolveConfig>,
/// Cache hits (counted by `lookup`)
pub hits: usize,
/// Cache misses (counted by `lookup`)
pub misses: usize,
/// False hits (compiled config tried but solve was wrong)
pub false_hits: usize,
/// Steps saved by successful Strategy Zero (vs estimated fallback cost).
/// Signed because an individual success can cost more than the estimate.
pub steps_saved: i64,
/// Confidence threshold for attempting Strategy Zero
pub confidence_threshold: f64,
}
impl KnowledgeCompiler {
/// Create an empty compiler with a Strategy Zero confidence threshold of 0.7.
///
/// All other fields take their `Default` values (empty cache, zero counters).
pub fn new() -> Self {
Self {
confidence_threshold: 0.7,
..Default::default()
}
}
/// Build constraint signature from puzzle features.
/// Includes version prefix for cache safety across refactors.
pub fn signature(puzzle: &TemporalPuzzle) -> String {
let mut sig_parts: Vec<String> = puzzle
.constraints
.iter()
.map(|c| constraint_type_name(c))
.collect();
sig_parts.sort();
format!(
"{}:{}:{}",
COMPILER_SIG_VERSION,
puzzle.difficulty,
sig_parts.join(",")
)
}
/// Compile knowledge from a ReasoningBank's trajectories.
///
/// Only trajectories whose verdict reports success and that have at least
/// one attempt contribute; the first attempt's step count is the sample.
/// For each signature the entry keeps the minimum step count (in
/// `max_steps`), a running average, and the number of observations.
///
/// NOTE(review): every call walks all trajectories in `bank`, so compiling
/// the same bank repeatedly counts the same trajectories again, and
/// `hit_count` is overwritten with `observations` each time.
pub fn compile_from_bank(&mut self, bank: &ReasoningBank) {
    for traj in &bank.trajectories {
        let succeeded = traj
            .verdict
            .as_ref()
            .map_or(false, |verdict| verdict.is_success());
        if !succeeded {
            continue;
        }

        let steps = match traj.attempts.first() {
            Some(first_attempt) => first_attempt.steps,
            None => continue,
        };

        // Versioned signature built from the sorted constraint type names.
        let mut kinds = traj.constraint_types.clone();
        kinds.sort();
        let sig = format!(
            "{}:{}:{}",
            COMPILER_SIG_VERSION,
            traj.difficulty,
            kinds.join(",")
        );

        // Weekday skipping is only suggested when a DayOfWeek constraint exists.
        let suggested_skip = if kinds.iter().any(|kind| kind == "DayOfWeek") {
            SkipMode::Weekday
        } else {
            SkipMode::None
        };

        let entry = self
            .signature_cache
            .entry(sig)
            .or_insert_with(|| CompiledSolveConfig {
                use_rewriting: true,
                max_steps: steps,
                avg_steps: 0.0,
                observations: 0,
                expected_correct: true,
                stop_after_first: true,
                hit_count: 0,
                counterexample_count: 0,
                compiled_skip_mode: suggested_skip,
            });

        // Keep the minimum successful step count.
        entry.max_steps = entry.max_steps.min(steps);

        // Incremental mean over all observations so far.
        let seen = entry.observations as f64;
        entry.avg_steps = (entry.avg_steps * seen + steps as f64) / (seen + 1.0);
        entry.observations += 1;

        // Compiled from successful trajectories → seed confidence.
        entry.hit_count = entry.observations;
    }
}
/// Look up the compiled config for a puzzle's signature.
///
/// Increments `hits` when an entry exists and `misses` otherwise.
/// Returns `None` on a cache miss.
pub fn lookup(&mut self, puzzle: &TemporalPuzzle) -> Option<&CompiledSolveConfig> {
    let sig = Self::signature(puzzle);
    // One map lookup; the counters are separate fields, so they can be
    // updated while the entry reference is alive.
    let found = self.signature_cache.get(&sig);
    match found {
        Some(_) => self.hits += 1,
        None => self.misses += 1,
    }
    found
}
/// Record a counterexample: Strategy Zero failed on this puzzle's signature.
///
/// `false_hits` is always incremented. If the signature is cached, its
/// counterexample count goes up and, from the second counterexample on,
/// the entry is quarantined by clearing `expected_correct`.
pub fn record_failure(&mut self, puzzle: &TemporalPuzzle) {
    self.false_hits += 1;

    let sig = Self::signature(puzzle);
    let config = match self.signature_cache.get_mut(&sig) {
        Some(config) => config,
        None => return,
    };

    config.counterexample_count += 1;
    let quarantined = config.counterexample_count >= 2;
    if quarantined {
        config.expected_correct = false;
    }
}
/// Record a successful Strategy Zero hit and the steps it saved.
///
/// The fallback cost is estimated as `avg_steps * 2` (a full scan is
/// typically ~2x an early-terminated one), or `max_steps` when no average
/// is available. Does nothing if the signature is not cached.
pub fn record_success(&mut self, puzzle: &TemporalPuzzle, actual_steps: usize) {
    let sig = Self::signature(puzzle);
    let estimated_fallback = match self.signature_cache.get_mut(&sig) {
        None => return,
        Some(config) => {
            config.hit_count += 1;
            if config.avg_steps > 0.0 {
                (config.avg_steps * 2.0) as i64
            } else {
                config.max_steps as i64
            }
        }
    };
    self.steps_saved += estimated_fallback - actual_steps as i64;
}
/// Fraction of `lookup` calls that found a cached entry (0.0 before any lookup).
pub fn hit_rate(&self) -> f64 {
    match self.hits + self.misses {
        0 => 0.0,
        lookups => self.hits as f64 / lookups as f64,
    }
}
/// Number of distinct signatures currently held in `signature_cache`.
pub fn cache_size(&self) -> usize {
self.signature_cache.len()
}
/// Print diagnostic summary: per-signature stats, false hit distribution.
///
/// Shows at most 15 signatures, those with the most counterexamples first
/// (ties ordered by signature so the output is deterministic), followed by
/// totals for disabled entries, false hits and steps saved.
pub fn print_diagnostics(&self) {
    println!();
    println!(" Compiler Diagnostics (cache_size={})", self.cache_size());
    println!(
        " {:<40} {:>5} {:>5} {:>6} {:>8} {:>6}",
        "Signature", "Obs", "Hits", "Fails", "AvgStep", "Conf"
    );
    println!(" {}", "-".repeat(72));

    let mut entries: Vec<_> = self.signature_cache.iter().collect();
    entries.sort_by(|a, b| {
        b.1.counterexample_count
            .cmp(&a.1.counterexample_count)
            // HashMap iteration order is arbitrary; break ties by signature.
            .then_with(|| a.0.cmp(b.0))
    });

    for (sig, config) in entries.iter().take(15) {
        // Truncate by characters, not bytes: byte-slicing a `String` panics
        // when the cut is not on a char boundary.
        let short_sig: String = sig.chars().take(38).collect();
        println!(
            " {:<40} {:>5} {:>5} {:>6} {:>7.1} {:>.3}",
            short_sig,
            config.observations,
            config.hit_count,
            config.counterexample_count,
            config.avg_steps,
            config.confidence()
        );
    }

    // Summary
    let total_configs = self.signature_cache.len();
    let disabled = self
        .signature_cache
        .values()
        .filter(|c| !c.expected_correct)
        .count();
    let total_false_hits: usize = self
        .signature_cache
        .values()
        .map(|c| c.counterexample_count)
        .sum();
    let false_hit_sigs = self
        .signature_cache
        .values()
        .filter(|c| c.counterexample_count > 0)
        .count();
    let false_hit_pct = if total_configs > 0 {
        false_hit_sigs as f64 / total_configs as f64 * 100.0
    } else {
        0.0
    };

    println!();
    println!(
        " Total signatures: {}, disabled: {}",
        total_configs, disabled
    );
    println!(
        " False hits: {} across {} signatures ({:.1}% of sigs)",
        total_false_hits, false_hit_sigs, false_hit_pct
    );
    println!(" Steps saved by compiler: {}", self.steps_saved);
}
}
// ═══════════════════════════════════════════════════════════════════════════
// StrategyRouter — contextual bandit for strategy selection
// ═══════════════════════════════════════════════════════════════════════════
/// Context bucket key for the bandit.
///
/// Used as the outer key of `StrategyRouter::arms`, hence the `Hash`/`Eq`
/// derives.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub struct RoutingContext {
/// Constraint family (sorted constraint types).
/// `StrategyRouter::context` also de-duplicates the names and joins them with ",".
pub constraint_family: String,
/// Difficulty bucket (1-3=easy, 4-7=mid, 8-10=hard).
/// `StrategyRouter::context` stores these as 1, 2 and 3; any difficulty
/// outside 1..=7 lands in bucket 3.
pub difficulty_bucket: u8,
/// Whether input is noisy
pub noisy: bool,
}
/// Per-arm stats in the contextual bandit.
///
/// One instance exists per (routing context, strategy name) pair and is
/// updated by `StrategyRouter::update`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ArmStats {
/// Number of times this arm was used.
pub pulls: usize,
/// Number of pulls that produced a correct result.
pub successes: usize,
/// Sum of solver steps over all pulls.
pub total_steps: usize,
/// Correct results among the pulls flagged as noisy.
pub noise_successes: usize,
/// Number of pulls flagged as noisy.
pub noise_pulls: usize,
}
impl ArmStats {
    /// Composite reward of this arm.
    ///
    /// - An arm that was never pulled gets a fixed prior of 0.5.
    /// - Otherwise: `0.5 * success_rate`
    ///   plus a cost bonus of up to 0.3 (average steps of 0 → 0.3,
    ///   100 or more → 0; no bonus when no steps were recorded)
    ///   plus a robustness bonus of up to 0.2 (success rate on noisy pulls).
    pub fn reward(&self) -> f64 {
        if self.pulls == 0 {
            return 0.5;
        }
        let pulls = self.pulls as f64;
        let success_rate = self.successes as f64 / pulls;

        // Lower average steps = higher reward, normalized to ~0..0.3.
        let cost_bonus = match self.total_steps {
            0 => 0.0,
            steps => 0.3 * (1.0 - (steps as f64 / pulls) / 100.0).max(0.0),
        };

        let robustness_bonus = match self.noise_pulls {
            0 => 0.0,
            noisy => 0.2 * (self.noise_successes as f64 / noisy as f64),
        };

        success_rate * 0.5 + cost_bonus + robustness_bonus
    }
}
/// Adaptive strategy router using contextual bandit.
/// Learns per-context ordering and budget allocation for strategies.
///
/// NOTE: the derived `Default` zeroes `epsilon`, `min_observations` and
/// `rng_state`; `StrategyRouter::new` sets them to 0.15, 10 and 42.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct StrategyRouter {
/// Per-context, per-strategy arm stats
pub arms: HashMap<RoutingContext, HashMap<String, ArmStats>>,
/// Exploration rate (epsilon-greedy)
pub epsilon: f64,
/// Minimum exploration observations before dropping a strategy
pub min_observations: usize,
/// RNG state for exploration (xorshift state advanced by `next_f64`)
rng_state: u64,
}
impl StrategyRouter {
/// Create a router with no arm statistics, a 15% exploration rate, a
/// 10-pull minimum before a strategy may be dropped, and a fixed RNG seed
/// of 42.
pub fn new() -> Self {
Self {
arms: HashMap::new(),
epsilon: 0.15,
min_observations: 10,
rng_state: 42,
}
}
/// Build routing context from puzzle features.
pub fn context(puzzle: &TemporalPuzzle, noisy: bool) -> RoutingContext {
let mut families: Vec<String> = puzzle
.constraints
.iter()
.map(|c| constraint_type_name(c))
.collect();
families.sort();
families.dedup();
let difficulty_bucket = match puzzle.difficulty {
1..=3 => 1,
4..=7 => 2,
_ => 3,
};
RoutingContext {
constraint_family: families.join(","),
difficulty_bucket,
noisy,
}
}
/// Select the best strategy for a context.
/// Returns ordered list of (strategy_name, budget_fraction).
pub fn select(&mut self, ctx: &RoutingContext, available: &[String]) -> Vec<(String, f64)> {
// Epsilon-greedy: explore with probability epsilon
let r = self.next_f64();
if r < self.epsilon {
// Explore: random permutation
let mut shuffled = available.to_vec();
for i in (1..shuffled.len()).rev() {
let j = (self.next_f64() * (i + 1) as f64) as usize;
shuffled.swap(i, j.min(i));
}
return shuffled
.into_iter()
.map(|s| (s, 1.0 / available.len() as f64))
.collect();
}
// Exploit: rank by reward, filter out strategies with zero success after min_observations
let arm_map = self.arms.entry(ctx.clone()).or_default();
let mut ranked: Vec<(String, f64)> = available
.iter()
.map(|s| {
let stats = arm_map.get(s).cloned().unwrap_or_default();
let should_drop = stats.pulls >= self.min_observations && stats.successes == 0;
let reward = if should_drop { -1.0 } else { stats.reward() };
(s.clone(), reward)
})
.collect();
ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
// Filter out dropped strategies (reward < 0), keep at least one
let mut result: Vec<(String, f64)> =
ranked.into_iter().filter(|(_, r)| *r >= 0.0).collect();
if result.is_empty() {
result = vec![(available[0].clone(), 1.0)];
}
// Allocate budget: best gets 60%, rest split remainder
let n = result.len();
result.iter_mut().enumerate().for_each(|(i, (_, budget))| {
*budget = if i == 0 {
0.6
} else {
0.4 / (n - 1).max(1) as f64
};
});
result
}
/// Update the arm statistics for `strategy` in context `ctx` after a solve
/// attempt.
///
/// Every call counts as one pull and adds `steps`; `correct` adds a
/// success. When `noisy` is set the noise-specific counters are updated too.
pub fn update(
    &mut self,
    ctx: &RoutingContext,
    strategy: &str,
    correct: bool,
    steps: usize,
    noisy: bool,
) {
    let stats = self
        .arms
        .entry(ctx.clone())
        .or_default()
        .entry(strategy.to_string())
        .or_default();

    let success = usize::from(correct);
    stats.pulls += 1;
    stats.total_steps += steps;
    stats.successes += success;
    if noisy {
        stats.noise_pulls += 1;
        stats.noise_successes += success;
    }
}
/// Advance the xorshift generator (shifts 13 / 7 / 17) and return the new
/// state divided by `u64::MAX`.
fn next_f64(&mut self) -> f64 {
    // A zero state would stay zero under xorshift, so start from at least 1.
    let seed = self.rng_state.max(1);
    let step1 = seed ^ (seed << 13);
    let step2 = step1 ^ (step1 >> 7);
    let step3 = step2 ^ (step2 << 17);
    self.rng_state = step3;
    (step3 as f64) / (u64::MAX as f64)
}
}
// ═══════════════════════════════════════════════════════════════════════════
// AdaptiveSolver
// ═══════════════════════════════════════════════════════════════════════════
/// Adaptive temporal solver with learning capabilities.
///
/// Wraps a `TemporalSolver` and records every solve as a trajectory in a
/// `ReasoningBank`, which in turn supplies the strategy for later puzzles.
/// Optional layers, each switched by its own flag, are the
/// `KnowledgeCompiler` ("Strategy Zero") and the `StrategyRouter` bandit;
/// the `PolicyKernel` decides the skip mode in every configuration.
pub struct AdaptiveSolver {
/// Internal solver
solver: TemporalSolver,
/// ReasoningBank for learning
pub reasoning_bank: ReasoningBank,
/// Current strategy
current_strategy: Strategy,
/// Total episodes completed
pub episodes: usize,
/// When set, solve() uses this step limit instead of the strategy's
pub external_step_limit: Option<usize>,
/// KnowledgeCompiler for Strategy Zero (compiled solve configs)
pub compiler: KnowledgeCompiler,
/// Whether to use the compiler as Strategy Zero
pub compiler_enabled: bool,
/// Adaptive strategy router (contextual bandit)
pub router: StrategyRouter,
/// Whether to use the adaptive router instead of fixed strategy selection
pub router_enabled: bool,
/// PolicyKernel for skip-mode decisions (all modes use this)
pub policy_kernel: PolicyKernel,
/// Whether the current puzzle is noisy (set by caller before solve)
pub noisy_hint: bool,
}
// `Default` simply forwards to `AdaptiveSolver::new`, so both construction
// paths yield the same initial state.
impl Default for AdaptiveSolver {
fn default() -> Self {
Self::new()
}
}
impl AdaptiveSolver {
/// Create a new adaptive solver
pub fn new() -> Self {
Self {
solver: TemporalSolver::default(),
reasoning_bank: ReasoningBank::new(),
current_strategy: Strategy::default(),
episodes: 0,
external_step_limit: None,
compiler: KnowledgeCompiler::new(),
compiler_enabled: false,
router: StrategyRouter::new(),
router_enabled: false,
policy_kernel: PolicyKernel::new(),
noisy_hint: false,
}
}
/// Create with pre-trained ReasoningBank
///
/// Every other field gets its initial value: default solver and strategy,
/// zero episodes, no external step limit, fresh compiler / router / policy
/// kernel, both optional layers disabled and `noisy_hint` cleared.
/// The compiler is NOT compiled from the bank here; call `recompile` for that.
pub fn with_reasoning_bank(reasoning_bank: ReasoningBank) -> Self {
Self {
solver: TemporalSolver::default(),
reasoning_bank,
current_strategy: Strategy::default(),
episodes: 0,
external_step_limit: None,
compiler: KnowledgeCompiler::new(),
compiler_enabled: false,
router: StrategyRouter::new(),
router_enabled: false,
policy_kernel: PolicyKernel::new(),
noisy_hint: false,
}
}
/// Recompile knowledge from the current ReasoningBank.
///
/// Feeds all trajectories of `reasoning_bank` into the compiler via
/// `KnowledgeCompiler::compile_from_bank`.
pub fn recompile(&mut self) {
self.compiler.compile_from_bank(&self.reasoning_bank);
}
/// Get mutable reference to the internal solver for configuration.
///
/// NOTE: `solve` overwrites several solver settings on every call
/// (`skip_weekday`, `calendar_tool`, `max_steps`, `stop_after_first`,
/// `prepass_mode`), so changes to those do not persist across solves.
pub fn solver_mut(&mut self) -> &mut TemporalSolver {
&mut self.solver
}
/// Build a PolicyContext from puzzle features.
///
/// - `posterior_range`: day span of the first `Between` constraint
///   (negative spans clamp to 0), or 365 when the puzzle has none.
/// - `distractor_count`: result of `count_distractors`.
/// - `has_day_of_week`: whether any `DayOfWeek` constraint is present.
/// - `noisy`: the caller-provided `noisy_hint`.
/// - `difficulty`: the puzzle's difficulty vector, or one derived from the
///   scalar difficulty when absent.
/// - `recent_false_hit_rate`: the kernel's current early-commit rate.
fn build_policy_context(&self, puzzle: &TemporalPuzzle) -> PolicyContext {
    // Single pass over the constraints: note any DayOfWeek and remember
    // the span of the first Between only.
    let mut has_day_of_week = false;
    let mut first_between_span: Option<usize> = None;
    for constraint in &puzzle.constraints {
        match constraint {
            TemporalConstraint::DayOfWeek(_) => has_day_of_week = true,
            TemporalConstraint::Between(start, end) if first_between_span.is_none() => {
                first_between_span = Some((*end - *start).num_days().max(0) as usize);
            }
            _ => {}
        }
    }

    let difficulty = match puzzle.difficulty_vector.clone() {
        Some(vector) => vector,
        None => DifficultyVector::from_scalar(puzzle.difficulty),
    };

    PolicyContext {
        posterior_range: first_between_span.unwrap_or(365),
        // Redundant constraints that don't narrow the search
        // (wider Between, redundant InYear, After well before range).
        distractor_count: count_distractors(puzzle),
        has_day_of_week,
        noisy: self.noisy_hint,
        difficulty,
        recent_false_hit_rate: self.policy_kernel.early_commit_rate(),
    }
}
/// Solve a puzzle with adaptive learning.
///
/// All modes have access to the same solver capabilities (including skip_weekday).
/// What differs is the **policy** that decides how to use them:
/// - Mode A (baseline): fixed heuristic policy
/// - Mode B (compiler): compiler-suggested policy
/// - Mode C (full): learned PolicyKernel policy
pub fn solve(&mut self, puzzle: &TemporalPuzzle) -> Result<SolverResult> {
// Reset solver state
self.solver.skip_weekday = None;
// Get constraint types for pattern matching
let constraint_types: Vec<String> = puzzle
.constraints
.iter()
.map(|c| constraint_type_name(c))
.collect();
// Build policy context (same for all modes)
let policy_ctx = self.build_policy_context(puzzle);
// ─── PolicyKernel: decide skip_mode (all modes participate) ──────
let skip_mode = if self.router_enabled {
// Mode C: learned policy
self.policy_kernel.learned_policy(&policy_ctx)
} else if self.compiler_enabled {
// Mode B: compiler-suggested policy
let compiled_skip = self
.compiler
.lookup(puzzle)
.map(|config| config.compiled_skip_mode.clone());
PolicyKernel::compiled_policy(&policy_ctx, compiled_skip)
} else {
// Mode A: fixed baseline policy
PolicyKernel::fixed_policy(&policy_ctx)
};
// Apply skip_mode to solver
match &skip_mode {
SkipMode::None => {
self.solver.skip_weekday = None;
}
SkipMode::Weekday => {
self.solver.skip_weekday = puzzle.constraints.iter().find_map(|c| match c {
TemporalConstraint::DayOfWeek(w) => Some(*w),
_ => None,
});
}
SkipMode::Hybrid => {
// Hybrid: use weekday skip for initial scan (set here),
// then do a refinement pass below if needed.
// Force minimum evidence: never stop_after_first in Hybrid mode.
self.solver.skip_weekday = puzzle.constraints.iter().find_map(|c| match c {
TemporalConstraint::DayOfWeek(w) => Some(*w),
_ => None,
});
// Hybrid safety: disable early termination so solver checks
// all matching weekdays before committing
self.solver.stop_after_first = false;
}
}
// Accumulated steps across all attempts (Strategy Zero + fallback)
let mut extra_steps: usize = 0;
let mut extra_tool_calls: usize = 0;
// ─── Strategy Zero: KnowledgeCompiler (bounded trial) ────────────
if self.compiler_enabled {
let conf_threshold = self.compiler.confidence_threshold;
let compiled = self.compiler.lookup(puzzle).map(|config| {
(
config.expected_correct,
config.confidence(),
config.trial_budget(self.external_step_limit.unwrap_or(400)),
config.use_rewriting,
config.stop_after_first,
)
});
if let Some((expected_correct, confidence, trial_budget, use_rewriting, stop_first)) =
compiled
{
if expected_correct && confidence >= conf_threshold {
self.solver.calendar_tool = use_rewriting;
self.solver.stop_after_first = stop_first;
self.solver.max_steps = trial_budget;
let start = std::time::Instant::now();
let result = self.solver.solve(puzzle)?;
let latency = start.elapsed().as_millis() as u64;
self.solver.stop_after_first = false;
if result.correct {
self.compiler.record_success(puzzle, result.steps);
let mut trajectory = Trajectory::new(&puzzle.id, puzzle.difficulty);
trajectory.constraint_types = constraint_types;
trajectory.latency_ms = latency;
let sol_str = result
.solutions
.first()
.map(|d| d.to_string())
.unwrap_or_else(|| "none".to_string());
let bucket_key = PolicyKernel::context_bucket_static(&policy_ctx);
trajectory.record_attempt_witnessed(
sol_str,
0.95,
result.steps,
result.tool_calls,
"compiler",
&skip_mode.to_string(),
&bucket_key,
);
trajectory.set_verdict(
Verdict::Success,
puzzle.solutions.first().map(|d| d.to_string()),
);
self.reasoning_bank.record_trajectory(trajectory);
self.episodes += 1;
// Record successful skip outcome
let outcome = SkipOutcome {
mode: skip_mode,
correct: true,
steps: result.steps,
early_commit_wrong: false,
initial_candidates: policy_ctx.posterior_range,
remaining_at_commit: 0,
};
self.policy_kernel.record_outcome(&policy_ctx, &outcome);
if self.router_enabled {
let ctx = StrategyRouter::context(puzzle, false);
self.router
.update(&ctx, "compiler", true, result.steps, false);
}
return Ok(result);
} else {
extra_steps += result.steps;
extra_tool_calls += result.tool_calls;
self.compiler.record_failure(puzzle);
// Record early commit wrong if solver claimed solved but was wrong
if result.solved && !result.correct {
// Estimate remaining: initial minus steps scanned
let remaining = policy_ctx.posterior_range.saturating_sub(result.steps);
let outcome = SkipOutcome {
mode: skip_mode.clone(),
correct: false,
steps: result.steps,
early_commit_wrong: true,
initial_candidates: policy_ctx.posterior_range,
remaining_at_commit: remaining,
};
self.policy_kernel.record_outcome(&policy_ctx, &outcome);
}
}
}
}
}
// ─── Strategy Selection (fixed or router) ───────────────────────
if self.router_enabled {
let ctx = StrategyRouter::context(puzzle, false);
let available = vec![
"default".to_string(),
"aggressive".to_string(),
"conservative".to_string(),
"adaptive".to_string(),
];
let ranked = self.router.select(&ctx, &available);
if let Some((top_strategy, _)) = ranked.first() {
self.current_strategy = self
.reasoning_bank
.strategy_from_name(top_strategy, puzzle.difficulty);
}
} else {
self.current_strategy = self
.reasoning_bank
.get_strategy(puzzle.difficulty, &constraint_types);
}
// Configure solver based on strategy (external limit overrides strategy)
self.solver.calendar_tool = self.current_strategy.use_rewriting;
self.solver.max_steps = self
.external_step_limit
.unwrap_or(self.current_strategy.max_steps);
self.solver.stop_after_first = false;
// Wire prepass mode from PolicyKernel
self.solver.prepass_mode = self.policy_kernel.prepass.clone();
// Create trajectory for this puzzle
let mut trajectory = Trajectory::new(&puzzle.id, puzzle.difficulty);
trajectory.constraint_types = constraint_types;
// Solve the puzzle
let start = std::time::Instant::now();
let mut result = self.solver.solve(puzzle)?;
trajectory.latency_ms = start.elapsed().as_millis() as u64;
// Track prepass metrics if enabled
if self.policy_kernel.prepass != PrepassMode::Off {
self.policy_kernel.prepass_metrics.invocations += 1;
// Direct solve: steps < 15 and correct means propagation worked
if result.steps <= 15 && result.correct && result.solved {
self.policy_kernel.prepass_metrics.direct_solves += 1;
// Estimate scan steps saved
let would_have_scanned = policy_ctx.posterior_range;
self.policy_kernel.prepass_metrics.scan_steps_saved += would_have_scanned;
}
// Estimate pruned candidates
let actual_range = (result.steps as f64 * 7.0) as usize; // rough
let saved = policy_ctx.posterior_range.saturating_sub(actual_range);
self.policy_kernel.prepass_metrics.pruned_candidates += saved;
}
// ─── Hybrid refinement pass ──────────────────────────────────────
// If Hybrid mode was used and we found solutions via weekday skip,
// do a narrow linear scan around each candidate to catch near-misses.
if skip_mode == SkipMode::Hybrid && !result.solutions.is_empty() {
let mut refined_solutions = result.solutions.clone();
self.solver.skip_weekday = None; // Linear for refinement
let saved_max = self.solver.max_steps;
self.solver.max_steps = 14; // Check ±7 days around each candidate
for candidate in &result.solutions {
let refine_start = *candidate - chrono::Duration::days(7);
let refine_end = *candidate + chrono::Duration::days(7);
let refine_puzzle = TemporalPuzzle {
id: puzzle.id.clone(),
description: puzzle.description.clone(),
constraints: puzzle.constraints.clone(),
references: puzzle.references.clone(),
solutions: puzzle.solutions.clone(),
difficulty: puzzle.difficulty,
tags: puzzle.tags.clone(),
difficulty_vector: puzzle.difficulty_vector.clone(),
};
// Manually search the refinement window
let mut cur = refine_start;
while cur <= refine_end {
if let Ok(true) = refine_puzzle.check_date(cur) {
if !refined_solutions.contains(&cur) {
refined_solutions.push(cur);
}
}
cur = match cur.succ_opt() {
Some(d) => d,
None => break,
};
result.steps += 1;
}
}
self.solver.max_steps = saved_max;
result.solutions = refined_solutions;
// Re-check correctness after refinement
result.correct = if puzzle.solutions.is_empty() {
true
} else {
puzzle
.solutions
.iter()
.all(|s| result.solutions.contains(s))
};
}
// Accumulate overhead from failed Strategy Zero attempt
result.steps += extra_steps;
result.tool_calls += extra_tool_calls;
// Record attempt
let solution_str = result
.solutions
.first()
.map(|d| d.to_string())
.unwrap_or_else(|| "none".to_string());
let confidence = self.calculate_confidence(&result, puzzle);
let bucket_key = PolicyKernel::context_bucket_static(&policy_ctx);
trajectory.record_attempt_witnessed(
solution_str,
confidence,
result.steps,
result.tool_calls,
&self.current_strategy.name,
&skip_mode.to_string(),
&bucket_key,
);
// Determine verdict
let verdict = if result.correct {
if confidence >= 0.9 {
Verdict::Success
} else {
Verdict::Acceptable
}
} else if result.solved {
Verdict::Suboptimal {
reason: "Solution found but incorrect".to_string(),
delta: 1.0 - confidence,
}
} else if confidence < self.current_strategy.confidence_threshold {
Verdict::LowConfidence
} else {
Verdict::Failed
};
trajectory.set_verdict(verdict, puzzle.solutions.first().map(|d| d.to_string()));
// ─── Record PolicyKernel outcome ─────────────────────────────────
let early_commit_wrong = result.solved && !result.correct;
let remaining = policy_ctx.posterior_range.saturating_sub(result.steps);
let outcome = SkipOutcome {
mode: skip_mode,
correct: result.correct,
steps: result.steps,
early_commit_wrong,
initial_candidates: policy_ctx.posterior_range,
remaining_at_commit: remaining,
};
self.policy_kernel.record_outcome(&policy_ctx, &outcome);
// Update router stats
if self.router_enabled {
let ctx = StrategyRouter::context(puzzle, false);
self.router.update(
&ctx,
&self.current_strategy.name,
result.correct,
result.steps,
false,
);
}
// Record trajectory for learning
self.reasoning_bank.record_trajectory(trajectory);
self.episodes += 1;
Ok(result)
}
/// Calculate confidence in a result
fn calculate_confidence(&self, result: &SolverResult, puzzle: &TemporalPuzzle) -> f64 {
let mut confidence = 0.5;
// Higher confidence if solved quickly
if result.solved {
confidence += 0.2;
if result.steps < self.solver.max_steps / 2 {
confidence += 0.1;
}
}
// Higher confidence with tool use on complex puzzles
if result.tool_calls > 0 && puzzle.difficulty > 5 {
confidence += 0.1;
}
// Lower confidence if took many steps
if result.steps > self.solver.max_steps * 3 / 4 {
confidence -= 0.1;
}
// Adjust based on learned calibration
let calibrated_threshold = self
.reasoning_bank
.calibration
.get_threshold(puzzle.difficulty);
if confidence >= calibrated_threshold {
confidence += 0.05;
}
confidence.min(1.0).max(0.0)
}
/// Get learning progress
///
/// Thin delegate: returns whatever the ReasoningBank reports.
pub fn learning_progress(&self) -> crate::reasoning_bank::LearningProgress {
self.reasoning_bank.learning_progress()
}
/// Get hints for a puzzle
///
/// `constraint_types` are constraint type names, presumably the same
/// strings `constraint_type_name` produces (confirm against callers).
/// The lookup itself is delegated to the ReasoningBank.
pub fn get_hints(&self, constraint_types: &[String]) -> Vec<String> {
self.reasoning_bank.get_hints(constraint_types)
}
}
/// Count distractor constraints in a puzzle.
///
/// A distractor is a constraint that is likely redundant (doesn't narrow
/// the search much). Counted, in constraint order:
/// - every `Between`, `InYear` or `DayOfWeek` after the first of its kind;
/// - every `After` that appears later in the list than a `Between` and
///   whose date lies more than 14 days before the start of the puzzle's
///   first `Between`.
///
/// Public so the generator can tag puzzles with their distractor count.
pub fn count_distractors(puzzle: &TemporalPuzzle) -> usize {
    // Start of the first Between in the whole list; it does not change
    // during the scan, so look it up once.
    let first_between_start = puzzle.constraints.iter().find_map(|c| match c {
        TemporalConstraint::Between(start, _) => Some(*start),
        _ => None,
    });

    let mut distractors: usize = 0;
    let (mut seen_between, mut seen_inyear, mut seen_dow) = (false, false, false);

    for constraint in &puzzle.constraints {
        match constraint {
            TemporalConstraint::Between(..) => {
                // Redundant Between (wider or duplicate)
                distractors += usize::from(seen_between);
                seen_between = true;
            }
            TemporalConstraint::InYear(..) => {
                // Redundant InYear
                distractors += usize::from(seen_inyear);
                seen_inyear = true;
            }
            TemporalConstraint::DayOfWeek(..) => {
                // Redundant DayOfWeek
                distractors += usize::from(seen_dow);
                seen_dow = true;
            }
            TemporalConstraint::After(date) => {
                // After a date well before the Between range → distractor.
                // Only applies once a Between has already been encountered.
                if let (true, Some(between_start)) = (seen_between, first_between_start) {
                    if *date < between_start - chrono::Duration::days(14) {
                        distractors += 1;
                    }
                }
            }
            _ => {}
        }
    }

    distractors
}
/// Name of a constraint's variant (payload ignored), used for pattern
/// matching and signature building.
fn constraint_type_name(constraint: &TemporalConstraint) -> String {
    let name: &'static str = match constraint {
        TemporalConstraint::Exact(..) => "Exact",
        TemporalConstraint::After(..) => "After",
        TemporalConstraint::Before(..) => "Before",
        TemporalConstraint::Between(..) => "Between",
        TemporalConstraint::DayOfWeek(..) => "DayOfWeek",
        TemporalConstraint::DaysAfter(..) => "DaysAfter",
        TemporalConstraint::DaysBefore(..) => "DaysBefore",
        TemporalConstraint::InMonth(..) => "InMonth",
        TemporalConstraint::InYear(..) => "InYear",
        TemporalConstraint::DayOfMonth(..) => "DayOfMonth",
        TemporalConstraint::RelativeToEvent(..) => "RelativeToEvent",
    };
    name.to_string()
}