wifi-densepose/vendor/ruvector/crates/ruvector-gnn/src/scheduler.rs

532 lines
16 KiB
Rust

//! Learning rate scheduling for Graph Neural Networks
//!
//! Provides various learning rate scheduling strategies to prevent catastrophic
//! forgetting and optimize training dynamics in continual learning scenarios.
use std::f32::consts::PI;
/// Learning rate scheduling strategies
#[derive(Debug, Clone)]
pub enum SchedulerType {
/// Constant learning rate throughout training
Constant,
/// Step decay: multiply learning rate by gamma every step_size epochs
/// Formula: lr = base_lr * gamma^(epoch / step_size)
StepDecay { step_size: usize, gamma: f32 },
/// Exponential decay: multiply learning rate by gamma each epoch
/// Formula: lr = base_lr * gamma^epoch
Exponential { gamma: f32 },
/// Cosine annealing with warm restarts
/// Formula: lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * (epoch % t_max) / t_max))
CosineAnnealing { t_max: usize, eta_min: f32 },
/// Warmup phase followed by linear decay
/// Linearly increases lr from 0 to base_lr over warmup_steps,
/// then linearly decreases to 0 over remaining steps
WarmupLinear {
warmup_steps: usize,
total_steps: usize,
},
/// Reduce learning rate when a metric plateaus
/// Useful for online learning scenarios
ReduceOnPlateau {
factor: f32,
patience: usize,
min_lr: f32,
},
}
/// Learning rate scheduler for GNN training
///
/// Implements various scheduling strategies to control learning rate
/// during training, helping prevent catastrophic forgetting and
/// improve convergence.
#[derive(Debug, Clone)]
pub struct LearningRateScheduler {
scheduler_type: SchedulerType,
base_lr: f32,
current_lr: f32,
step_count: usize,
best_metric: f32,
patience_counter: usize,
}
impl LearningRateScheduler {
/// Creates a new learning rate scheduler
///
/// # Arguments
/// * `scheduler_type` - The scheduling strategy to use
/// * `base_lr` - The initial/base learning rate
///
/// # Example
/// ```
/// use ruvector_gnn::scheduler::{LearningRateScheduler, SchedulerType};
///
/// let scheduler = LearningRateScheduler::new(
/// SchedulerType::StepDecay { step_size: 10, gamma: 0.9 },
/// 0.001
/// );
/// ```
pub fn new(scheduler_type: SchedulerType, base_lr: f32) -> Self {
Self {
scheduler_type,
base_lr,
current_lr: base_lr,
step_count: 0,
best_metric: f32::INFINITY,
patience_counter: 0,
}
}
/// Advances the scheduler by one step and returns the new learning rate
///
/// For most schedulers, this should be called once per epoch.
/// For ReduceOnPlateau, use `step_with_metric` instead.
///
/// # Returns
/// The updated learning rate
pub fn step(&mut self) -> f32 {
self.step_count += 1;
self.current_lr = self.calculate_lr();
self.current_lr
}
/// Advances the scheduler with a metric value (for ReduceOnPlateau)
///
/// # Arguments
/// * `metric` - The metric value to monitor (e.g., validation loss)
///
/// # Returns
/// The updated learning rate
pub fn step_with_metric(&mut self, metric: f32) -> f32 {
self.step_count += 1;
match &self.scheduler_type {
SchedulerType::ReduceOnPlateau {
factor,
patience,
min_lr,
} => {
// Check if metric improved
if metric < self.best_metric - 1e-8 {
self.best_metric = metric;
self.patience_counter = 0;
} else {
self.patience_counter += 1;
// Reduce learning rate if patience exceeded
if self.patience_counter >= *patience {
self.current_lr = (self.current_lr * factor).max(*min_lr);
self.patience_counter = 0;
}
}
}
_ => {
// For non-plateau schedulers, just use step()
self.current_lr = self.calculate_lr();
}
}
self.current_lr
}
/// Gets the current learning rate without advancing the scheduler
pub fn get_lr(&self) -> f32 {
self.current_lr
}
/// Resets the scheduler to its initial state
pub fn reset(&mut self) {
self.current_lr = self.base_lr;
self.step_count = 0;
self.best_metric = f32::INFINITY;
self.patience_counter = 0;
}
/// Calculates the learning rate based on the current step and scheduler type
fn calculate_lr(&self) -> f32 {
match &self.scheduler_type {
SchedulerType::Constant => self.base_lr,
SchedulerType::StepDecay { step_size, gamma } => {
let decay_factor = (*gamma).powi((self.step_count / step_size) as i32);
self.base_lr * decay_factor
}
SchedulerType::Exponential { gamma } => {
let decay_factor = (*gamma).powi(self.step_count as i32);
self.base_lr * decay_factor
}
SchedulerType::CosineAnnealing { t_max, eta_min } => {
let cycle_step = self.step_count % t_max;
let cos_term = (PI * cycle_step as f32 / *t_max as f32).cos();
eta_min + 0.5 * (self.base_lr - eta_min) * (1.0 + cos_term)
}
SchedulerType::WarmupLinear {
warmup_steps,
total_steps,
} => {
if self.step_count < *warmup_steps {
// Warmup phase: linear increase
self.base_lr * (self.step_count as f32 / *warmup_steps as f32)
} else if self.step_count < *total_steps {
// Decay phase: linear decrease
let remaining_steps = *total_steps - self.step_count;
let total_decay_steps = *total_steps - *warmup_steps;
self.base_lr * (remaining_steps as f32 / total_decay_steps as f32)
} else {
// After total_steps, keep at 0
0.0
}
}
SchedulerType::ReduceOnPlateau { .. } => {
// For plateau scheduler, lr is updated in step_with_metric
self.current_lr
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
const EPSILON: f32 = 1e-6;
fn assert_close(a: f32, b: f32, msg: &str) {
assert!((a - b).abs() < EPSILON, "{}: {} != {}", msg, a, b);
}
#[test]
fn test_constant_scheduler() {
let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.01);
assert_close(scheduler.get_lr(), 0.01, "Initial LR");
for i in 1..=10 {
let lr = scheduler.step();
assert_close(lr, 0.01, &format!("Step {} LR", i));
}
}
#[test]
fn test_step_decay() {
let mut scheduler = LearningRateScheduler::new(
SchedulerType::StepDecay {
step_size: 5,
gamma: 0.5,
},
0.1,
);
assert_close(scheduler.get_lr(), 0.1, "Initial LR");
// Steps 1-4: no decay
for i in 1..=4 {
let lr = scheduler.step();
assert_close(lr, 0.1, &format!("Step {} LR", i));
}
// Step 5: first decay (0.1 * 0.5)
let lr = scheduler.step();
assert_close(lr, 0.05, "Step 5 LR (first decay)");
// Steps 6-9: maintain decayed rate
for i in 6..=9 {
let lr = scheduler.step();
assert_close(lr, 0.05, &format!("Step {} LR", i));
}
// Step 10: second decay (0.1 * 0.5^2)
let lr = scheduler.step();
assert_close(lr, 0.025, "Step 10 LR (second decay)");
}
#[test]
fn test_exponential_decay() {
let mut scheduler =
LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);
assert_close(scheduler.get_lr(), 0.1, "Initial LR");
let expected_lrs = vec![
0.1 * 0.9, // Step 1
0.1 * 0.81, // Step 2 (0.9^2)
0.1 * 0.729, // Step 3 (0.9^3)
];
for (i, expected) in expected_lrs.iter().enumerate() {
let lr = scheduler.step();
assert_close(lr, *expected, &format!("Step {} LR", i + 1));
}
}
#[test]
fn test_cosine_annealing() {
let mut scheduler = LearningRateScheduler::new(
SchedulerType::CosineAnnealing {
t_max: 10,
eta_min: 0.0,
},
1.0,
);
assert_close(scheduler.get_lr(), 1.0, "Initial LR");
// Cosine annealing formula: lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * cycle_step / t_max))
// cycle_step = step_count % t_max
// At step 5: cycle_step = 5, cos(pi * 5/10) = cos(pi/2) = 0, lr = 0 + 0.5 * 1 * (1 + 0) = 0.5
// At step 10: cycle_step = 0 (wrapped), cos(0) = 1, lr = 0 + 0.5 * 1 * (1 + 1) = 1.0 (restart)
for _ in 1..=5 {
scheduler.step();
}
assert_close(scheduler.get_lr(), 0.5, "Mid-cycle LR (step 5)");
// At step 9: cycle_step = 9, cos(pi * 9/10) ≈ -0.951, lr ≈ 0.025
for _ in 6..=9 {
scheduler.step();
}
let lr_step9 = scheduler.get_lr();
assert!(
lr_step9 < 0.1,
"Near end of cycle LR (step 9) should be small: {}",
lr_step9
);
// At step 10: warm restart (cycle_step = 0), LR goes back to base
scheduler.step();
assert_close(
scheduler.get_lr(),
1.0,
"Restart at step 10 (cycle_step = 0)",
);
// Continue new cycle
scheduler.step();
assert!(
scheduler.get_lr() < 1.0,
"Step 11 should be less than base LR"
);
}
#[test]
fn test_warmup_linear() {
let mut scheduler = LearningRateScheduler::new(
SchedulerType::WarmupLinear {
warmup_steps: 5,
total_steps: 10,
},
1.0,
);
assert_close(scheduler.get_lr(), 1.0, "Initial LR");
// Warmup phase: linear increase
scheduler.step();
assert_close(scheduler.get_lr(), 0.2, "Step 1 (warmup)");
scheduler.step();
assert_close(scheduler.get_lr(), 0.4, "Step 2 (warmup)");
scheduler.step();
assert_close(scheduler.get_lr(), 0.6, "Step 3 (warmup)");
scheduler.step();
assert_close(scheduler.get_lr(), 0.8, "Step 4 (warmup)");
scheduler.step();
assert_close(scheduler.get_lr(), 1.0, "Step 5 (warmup end)");
// Decay phase: linear decrease
scheduler.step();
assert_close(scheduler.get_lr(), 0.8, "Step 6 (decay)");
scheduler.step();
assert_close(scheduler.get_lr(), 0.6, "Step 7 (decay)");
scheduler.step();
assert_close(scheduler.get_lr(), 0.4, "Step 8 (decay)");
scheduler.step();
assert_close(scheduler.get_lr(), 0.2, "Step 9 (decay)");
scheduler.step();
assert_close(scheduler.get_lr(), 0.0, "Step 10 (decay end)");
// After total_steps
scheduler.step();
assert_close(scheduler.get_lr(), 0.0, "Step 11 (after total)");
}
#[test]
fn test_reduce_on_plateau() {
let mut scheduler = LearningRateScheduler::new(
SchedulerType::ReduceOnPlateau {
factor: 0.5,
patience: 3,
min_lr: 0.0001,
},
0.01,
);
assert_close(scheduler.get_lr(), 0.01, "Initial LR");
// Improving metrics: no reduction (sets best_metric, resets patience)
scheduler.step_with_metric(1.0);
assert_close(
scheduler.get_lr(),
0.01,
"Step 1 (first metric, sets baseline)",
);
scheduler.step_with_metric(0.9);
assert_close(scheduler.get_lr(), 0.01, "Step 2 (improving)");
// Plateau: metric not improving (patience counter: 1, 2, 3)
scheduler.step_with_metric(0.91);
assert_close(scheduler.get_lr(), 0.01, "Step 3 (plateau 1)");
scheduler.step_with_metric(0.92);
assert_close(scheduler.get_lr(), 0.01, "Step 4 (plateau 2)");
// patience=3 means after 3 non-improvements, reduce LR
// Step 5 is the 3rd non-improvement, so LR gets reduced
scheduler.step_with_metric(0.93);
assert_close(
scheduler.get_lr(),
0.005,
"Step 5 (patience exceeded, reduced)",
);
// Counter is reset after reduction, so we need 3 more non-improvements
scheduler.step_with_metric(0.94); // plateau 1 after reset
assert_close(scheduler.get_lr(), 0.005, "Step 6 (plateau 1 after reset)");
scheduler.step_with_metric(0.95); // plateau 2
assert_close(scheduler.get_lr(), 0.005, "Step 7 (plateau 2)");
scheduler.step_with_metric(0.96); // plateau 3 - triggers reduction
assert_close(scheduler.get_lr(), 0.0025, "Step 8 (reduced again)");
// Test min_lr floor
for _ in 0..20 {
scheduler.step_with_metric(1.0);
}
assert!(
scheduler.get_lr() >= 0.0001,
"LR should not go below min_lr"
);
}
#[test]
fn test_scheduler_reset() {
let mut scheduler =
LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.9 }, 0.1);
// Run for several steps
for _ in 0..5 {
scheduler.step();
}
assert!(scheduler.get_lr() < 0.1, "LR should have decayed");
// Reset and verify
scheduler.reset();
assert_close(scheduler.get_lr(), 0.1, "Reset LR");
assert_eq!(scheduler.step_count, 0, "Reset step count");
}
#[test]
fn test_scheduler_cloning() {
let scheduler1 = LearningRateScheduler::new(
SchedulerType::StepDecay {
step_size: 10,
gamma: 0.5,
},
0.01,
);
let mut scheduler2 = scheduler1.clone();
// Advance clone
scheduler2.step();
// Original should be unchanged
assert_close(scheduler1.get_lr(), 0.01, "Original LR");
assert_close(scheduler2.get_lr(), 0.01, "Clone LR after step");
}
#[test]
fn test_multiple_scheduler_types() {
let schedulers = vec![
(SchedulerType::Constant, 0.01),
(
SchedulerType::StepDecay {
step_size: 5,
gamma: 0.9,
},
0.01,
),
(SchedulerType::Exponential { gamma: 0.95 }, 0.01),
(
SchedulerType::CosineAnnealing {
t_max: 10,
eta_min: 0.001,
},
0.01,
),
(
SchedulerType::WarmupLinear {
warmup_steps: 5,
total_steps: 20,
},
0.01,
),
(
SchedulerType::ReduceOnPlateau {
factor: 0.5,
patience: 5,
min_lr: 0.0001,
},
0.01,
),
];
for (sched_type, base_lr) in schedulers {
let mut scheduler = LearningRateScheduler::new(sched_type, base_lr);
// All schedulers should start at base_lr
assert_close(scheduler.get_lr(), base_lr, "Initial LR for scheduler type");
// All schedulers should be able to step
let _ = scheduler.step();
assert!(scheduler.get_lr() >= 0.0, "LR should be non-negative");
}
}
#[test]
fn test_edge_cases() {
// Zero learning rate
let mut scheduler = LearningRateScheduler::new(SchedulerType::Constant, 0.0);
assert_close(scheduler.get_lr(), 0.0, "Zero LR");
scheduler.step();
assert_close(scheduler.get_lr(), 0.0, "Zero LR after step");
// Very small gamma
let mut scheduler =
LearningRateScheduler::new(SchedulerType::Exponential { gamma: 0.1 }, 1.0);
for _ in 0..10 {
scheduler.step();
}
assert!(scheduler.get_lr() > 0.0, "LR should remain positive");
assert!(scheduler.get_lr() < 1e-8, "LR should be very small");
}
}