wifi-densepose/vendor/ruvector/crates/ruvllm/src/bitnet/trace.rs

558 lines
18 KiB
Rust

//! Structured JSONL Trace Output for BitNet Inference
//!
//! Provides structured tracing of inference decisions including MoE expert
//! routing, citation verification, refusal calibration, and coherence scoring.
//! All trace entries are serialized as JSONL (one JSON object per line) using
//! manual serialization (no serde dependency).
//!
//! ## Trace Fields
//!
//! Each `TraceEntry` captures per-token, per-layer diagnostics:
//! - **Routing**: Which experts were selected and whether they agree with a teacher
//! - **Citations**: Whether generated spans match source chunks
//! - **Refusal**: Whether the model correctly refused harmful prompts
//! - **Coherence**: Token-level coherence score
//! - **Stop Reason**: Why generation terminated
//!
//! ## Example
//!
//! ```rust,ignore
//! use ruvllm::bitnet::trace::{TraceWriter, TraceEntry, StopReason};
//!
//! let mut writer = TraceWriter::new(None);
//! writer.record(entry);
//! let jsonl = writer.to_jsonl();
//! ```
use std::collections::HashSet;
use std::path::PathBuf;
use crate::error::{Result, RuvLLMError};
// ============================================================================
// Trace Data Structures
// ============================================================================
/// Routing trace for a single token at a single layer.
///
/// Records which experts the model selected (top-K) and optionally
/// which experts a teacher model would have selected, enabling
/// routing agreement evaluation.
///
/// The type derives `Debug`, `Clone` and `PartialEq` so traces can be
/// logged, duplicated and compared directly in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct RoutingTrace {
    /// Expert indices selected by the student model (top-K)
    pub topk_expert_ids: Vec<usize>,
    /// Corresponding softmax weights for selected experts
    pub topk_weights: Vec<f32>,
    /// Expert indices from teacher model (if available)
    pub teacher_expert_ids: Option<Vec<usize>>,
    /// Corresponding teacher weights (if available)
    pub teacher_weights: Option<Vec<f32>>,
    /// Whether student and teacher selected the same expert set.
    ///
    /// This is a plain stored flag: the struct does not recompute it from
    /// the id lists, so the caller is responsible for keeping it in sync.
    pub agreement: bool,
}
/// Citation trace for a single generated span.
///
/// Records whether a generated text span can be traced back to a
/// source chunk, with Jaccard similarity as a quality metric.
///
/// The type derives `Debug`, `Clone` and `PartialEq` so traces can be
/// logged, duplicated and compared directly in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct CitationTrace {
    /// Source chunk identifier
    pub chunk_id: String,
    /// Generated text span
    pub span: String,
    /// Whether the citation was validated
    pub valid: bool,
    /// Word-level Jaccard similarity between span and source
    pub jaccard_score: f32,
}
/// Refusal calibration trace.
///
/// Records whether the model should have refused a prompt,
/// whether it actually did, and whether the decision was correct.
///
/// The type derives `Debug`, `Clone`, `PartialEq` and `Eq` so traces can
/// be logged, duplicated and compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RefusalTrace {
    /// Ground truth: should the model refuse this prompt?
    pub should_refuse: bool,
    /// Model behavior: did the model actually refuse?
    pub did_refuse: bool,
    /// Whether the model's refusal decision matched ground truth.
    ///
    /// Stored as given; the struct does not derive it from the other
    /// two flags.
    pub correct: bool,
}
/// Reason why generation stopped.
///
/// The type derives `Debug`, `Clone`, `PartialEq` and `Eq` so stop
/// reasons can be logged, duplicated and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopReason {
    /// End-of-sequence token generated
    Eos,
    /// Maximum generation length reached
    MaxLength,
    /// Model refused to generate (safety)
    Refusal,
    /// Coherence score dropped below threshold
    LowCoherence,
    /// An error occurred during generation; carries the error message
    Error(String),
}
/// A single trace entry capturing per-token, per-layer diagnostics.
///
/// Bundles the routing, citation, refusal, coherence and stop-reason
/// diagnostics recorded for one token at one layer. `to_json` serializes
/// the entry as a single JSON object, and `TraceWriter::to_jsonl` emits
/// one such object per line.
pub struct TraceEntry {
/// Unique identifier for the prompt being traced
pub prompt_id: String,
/// Token position in the generated sequence
pub token_idx: usize,
/// Transformer layer index
pub layer_idx: usize,
/// Expert routing diagnostics
pub routing: RoutingTrace,
/// Citation verification results (may be empty)
pub citations: Vec<CitationTrace>,
/// Refusal calibration result
pub refusal: RefusalTrace,
/// Token-level coherence score (0.0 to 1.0).
/// The range is a convention; this type does not validate it.
pub coherence_score: f32,
/// Why generation stopped at this token (if applicable).
/// NOTE: the field is not an `Option`, so every entry must carry a value.
pub stop_reason: StopReason,
/// Timestamp in milliseconds since epoch (supplied by the caller)
pub timestamp_ms: u64,
}
// ============================================================================
// Manual JSON Serialization
// ============================================================================
/// Escape a string so it can be embedded between double quotes in JSON.
///
/// Quotes, backslashes, newline, carriage return and tab get their short
/// two-character escapes; every other control character below U+0020 is
/// written as a `\u00XX` sequence. All remaining characters are copied
/// through unchanged. The surrounding quotes are NOT added here.
fn json_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len() + 2);
    for c in s.chars() {
        // Characters that have a dedicated short escape.
        let short_form = match c {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(seq) = short_form {
            escaped.push_str(seq);
        } else if u32::from(c) < 0x20 {
            // Remaining control characters must not appear raw in JSON strings.
            escaped.push_str(&format!("\\u{:04x}", u32::from(c)));
        } else {
            escaped.push(c);
        }
    }
    escaped
}
/// Render a slice of `usize` values as a compact JSON array.
///
/// Elements are separated by a single comma with no whitespace; an empty
/// slice yields `[]`.
fn json_usize_array(v: &[usize]) -> String {
    let mut rendered = String::from("[");
    for (position, value) in v.iter().enumerate() {
        if position > 0 {
            rendered.push(',');
        }
        rendered.push_str(&value.to_string());
    }
    rendered.push(']');
    rendered
}
/// Format a slice of `f32` values as a compact JSON array string.
///
/// Finite values are written with six decimal places. JSON has no
/// representation for NaN or infinities, so non-finite values are written
/// as `null` to keep every emitted line parseable. An empty slice
/// yields `[]`.
fn json_f32_array(v: &[f32]) -> String {
    let parts: Vec<String> = v
        .iter()
        .map(|x| {
            if x.is_finite() {
                format!("{:.6}", x)
            } else {
                // `{:.6}` would print "NaN" / "inf", which is not valid JSON.
                "null".to_string()
            }
        })
        .collect();
    format!("[{}]", parts.join(","))
}
impl RoutingTrace {
/// Serialize to a JSON object string.
pub fn to_json(&self) -> String {
let teacher_ids = match &self.teacher_expert_ids {
Some(ids) => json_usize_array(ids),
None => "null".to_string(),
};
let teacher_wts = match &self.teacher_weights {
Some(wts) => json_f32_array(wts),
None => "null".to_string(),
};
format!(
"{{\"topk_expert_ids\":{},\"topk_weights\":{},\"teacher_expert_ids\":{},\"teacher_weights\":{},\"agreement\":{}}}",
json_usize_array(&self.topk_expert_ids),
json_f32_array(&self.topk_weights),
teacher_ids,
teacher_wts,
self.agreement,
)
}
}
impl CitationTrace {
    /// Serialize this citation trace to a JSON object string.
    ///
    /// `chunk_id` and `span` are JSON-escaped and quoted, `valid` is a JSON
    /// boolean, and `jaccard_score` is written with six decimal places.
    /// JSON cannot represent NaN or infinities, so a non-finite score is
    /// written as `null` rather than producing an unparseable line.
    pub fn to_json(&self) -> String {
        let score = if self.jaccard_score.is_finite() {
            format!("{:.6}", self.jaccard_score)
        } else {
            // `{:.6}` would print "NaN" / "inf", which is not valid JSON.
            "null".to_string()
        };
        format!(
            "{{\"chunk_id\":\"{}\",\"span\":\"{}\",\"valid\":{},\"jaccard_score\":{}}}",
            json_escape(&self.chunk_id),
            json_escape(&self.span),
            self.valid,
            score,
        )
    }
}
impl RefusalTrace {
    /// Serialize this refusal trace to a JSON object string.
    ///
    /// All three flags are written as JSON booleans, in declaration order.
    pub fn to_json(&self) -> String {
        let Self {
            should_refuse,
            did_refuse,
            correct,
        } = self;
        format!(
            "{{\"should_refuse\":{},\"did_refuse\":{},\"correct\":{}}}",
            should_refuse, did_refuse, correct,
        )
    }
}
impl StopReason {
/// Serialize to a JSON string value.
pub fn to_json(&self) -> String {
match self {
StopReason::Eos => "\"eos\"".to_string(),
StopReason::MaxLength => "\"max_length\"".to_string(),
StopReason::Refusal => "\"refusal\"".to_string(),
StopReason::LowCoherence => "\"low_coherence\"".to_string(),
StopReason::Error(msg) => format!("\"error:{}\"", json_escape(msg)),
}
}
}
impl TraceEntry {
    /// Serialize this entry to a single-line JSON object string.
    ///
    /// Nested traces are serialized through their own `to_json` methods and
    /// citations are emitted as a JSON array. `coherence_score` is written
    /// with six decimal places. JSON cannot represent NaN or infinities, so
    /// a non-finite score is written as `null` rather than producing an
    /// unparseable line.
    pub fn to_json(&self) -> String {
        let citations_json: Vec<String> = self.citations.iter().map(|c| c.to_json()).collect();
        let coherence = if self.coherence_score.is_finite() {
            format!("{:.6}", self.coherence_score)
        } else {
            // `{:.6}` would print "NaN" / "inf", which is not valid JSON.
            "null".to_string()
        };
        format!(
            "{{\"prompt_id\":\"{}\",\"token_idx\":{},\"layer_idx\":{},\"routing\":{},\"citations\":[{}],\"refusal\":{},\"coherence_score\":{},\"stop_reason\":{},\"timestamp_ms\":{}}}",
            json_escape(&self.prompt_id),
            self.token_idx,
            self.layer_idx,
            self.routing.to_json(),
            citations_json.join(","),
            self.refusal.to_json(),
            coherence,
            self.stop_reason.to_json(),
            self.timestamp_ms,
        )
    }
}
// ============================================================================
// Trace Writer
// ============================================================================
/// Collects trace entries and writes them as JSONL.
///
/// Entries are buffered in memory via `record()`. They can then be
/// written to a file with `flush()` or retrieved as a JSONL string with
/// `to_jsonl()`. Neither operation removes entries from the buffer; use
/// `clear()` for that.
pub struct TraceWriter {
/// Entries recorded so far, in insertion order.
entries: Vec<TraceEntry>,
/// Destination file for `flush()`; `None` means in-memory only.
output_path: Option<PathBuf>,
}
impl TraceWriter {
    /// Build an empty trace writer.
    ///
    /// When `output_path` is `Some`, `flush()` writes to that file.
    /// When it is `None`, the entries can only be read back through
    /// `to_jsonl()` or `entries()`.
    pub fn new(output_path: Option<PathBuf>) -> Self {
        let entries = Vec::new();
        Self {
            entries,
            output_path,
        }
    }

    /// Append one trace entry to the in-memory buffer.
    pub fn record(&mut self, entry: TraceEntry) {
        self.entries.push(entry);
    }

    /// Write every recorded entry to the configured output file.
    ///
    /// The file is overwritten on each call with the full JSONL rendering
    /// of the buffer; the buffer itself is left untouched.
    ///
    /// # Errors
    ///
    /// Returns `RuvLLMError::Config` when no output path was configured and
    /// `RuvLLMError::Model` when the file cannot be written.
    pub fn flush(&mut self) -> Result<()> {
        let target = self.output_path.as_deref().ok_or_else(|| {
            RuvLLMError::Config("No output path configured for trace writer".to_string())
        })?;
        let payload = self.to_jsonl();
        std::fs::write(target, payload.as_bytes())
            .map_err(|e| RuvLLMError::Model(format!("Failed to write trace file: {}", e)))
    }

    /// Render all recorded entries as a JSONL string.
    ///
    /// Each entry becomes one JSON line terminated by `\n`. With no entries
    /// the result is the empty string.
    pub fn to_jsonl(&self) -> String {
        let mut jsonl = String::new();
        for entry in &self.entries {
            jsonl.push_str(&entry.to_json());
            jsonl.push('\n');
        }
        jsonl
    }

    /// Borrow the recorded entries in insertion order.
    pub fn entries(&self) -> &[TraceEntry] {
        self.entries.as_slice()
    }

    /// Drop every recorded entry, leaving the output path unchanged.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
}
// ============================================================================
// Utility Functions
// ============================================================================
/// Compute word-level Jaccard similarity between two strings.
///
/// Splits both strings on whitespace into sets of distinct words and
/// computes the Jaccard index: `|A intersect B| / |A union B|`.
/// Comparison is exact (case- and punctuation-sensitive), and repeated
/// words count once.
///
/// # Arguments
///
/// * `a` - First string
/// * `b` - Second string
///
/// # Returns
///
/// Jaccard similarity in [0.0, 1.0]. Returns 1.0 if both strings contain
/// no words (empty or whitespace-only), and 0.0 if exactly one does.
pub fn jaccard_similarity(a: &str, b: &str) -> f32 {
    let set_a: HashSet<&str> = a.split_whitespace().collect();
    let set_b: HashSet<&str> = b.split_whitespace().collect();
    if set_a.is_empty() && set_b.is_empty() {
        return 1.0;
    }
    let intersection = set_a.intersection(&set_b).count();
    // Inclusion-exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|. At least one set is
    // non-empty past the guard above, so the union is never zero here.
    let union = set_a.len() + set_b.len() - intersection;
    intersection as f32 / union as f32
}
/// Report whether student and teacher routing picked the same experts.
///
/// Both slices are reduced to sets of expert indices, so ordering and
/// repeated indices are ignored. Two empty slices agree.
///
/// # Arguments
///
/// * `model` - Expert indices selected by the student model
/// * `teacher` - Expert indices selected by the teacher model
pub fn check_routing_agreement(model: &[usize], teacher: &[usize]) -> bool {
    let distinct = |ids: &[usize]| ids.iter().copied().collect::<HashSet<usize>>();
    distinct(model) == distinct(teacher)
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper to create a minimal, fully-populated trace entry for testing.
    fn make_entry(prompt_id: &str, token_idx: usize, layer_idx: usize) -> TraceEntry {
        TraceEntry {
            prompt_id: prompt_id.to_string(),
            token_idx,
            layer_idx,
            routing: RoutingTrace {
                topk_expert_ids: vec![0, 2],
                topk_weights: vec![0.6, 0.4],
                teacher_expert_ids: Some(vec![0, 2]),
                teacher_weights: Some(vec![0.55, 0.45]),
                agreement: true,
            },
            citations: vec![CitationTrace {
                chunk_id: "doc-1".to_string(),
                span: "the quick fox".to_string(),
                valid: true,
                jaccard_score: 0.85,
            }],
            refusal: RefusalTrace {
                should_refuse: false,
                did_refuse: false,
                correct: true,
            },
            coherence_score: 0.92,
            stop_reason: StopReason::Eos,
            timestamp_ms: 1700000000000,
        }
    }

    #[test]
    fn test_json_serialization_valid() {
        let entry = make_entry("prompt-1", 0, 0);
        let json = entry.to_json();
        // Should start with { and end with }
        assert!(json.starts_with('{'), "JSON should start with {{");
        assert!(json.ends_with('}'), "JSON should end with }}");
        // Should contain key fields
        assert!(json.contains("\"prompt_id\":\"prompt-1\""));
        assert!(json.contains("\"token_idx\":0"));
        assert!(json.contains("\"layer_idx\":0"));
        assert!(json.contains("\"coherence_score\":"));
        assert!(json.contains("\"stop_reason\":\"eos\""));
    }

    #[test]
    fn test_jsonl_one_per_line() {
        let mut writer = TraceWriter::new(None);
        writer.record(make_entry("p1", 0, 0));
        writer.record(make_entry("p1", 1, 0));
        writer.record(make_entry("p2", 0, 0));
        let jsonl = writer.to_jsonl();
        let lines: Vec<&str> = jsonl.trim_end().split('\n').collect();
        assert_eq!(lines.len(), 3, "JSONL should have 3 lines for 3 entries");
        // Each line should look like a JSON object (starts with {, ends with })
        for (i, line) in lines.iter().enumerate() {
            assert!(
                line.starts_with('{') && line.ends_with('}'),
                "Line {} is not valid JSON: {}",
                i,
                line
            );
        }
    }

    #[test]
    fn test_jaccard_identical() {
        let score = jaccard_similarity("the quick brown fox", "the quick brown fox");
        assert!(
            (score - 1.0).abs() < 1e-6,
            "Identical strings should have Jaccard = 1.0, got {}",
            score
        );
    }

    #[test]
    fn test_jaccard_disjoint() {
        let score = jaccard_similarity("alpha beta gamma", "delta epsilon zeta");
        assert!(
            score.abs() < 1e-6,
            "Disjoint strings should have Jaccard = 0.0, got {}",
            score
        );
    }

    #[test]
    fn test_jaccard_partial() {
        // "the quick" and "the slow" share "the" out of {"the", "quick", "slow"}
        let score = jaccard_similarity("the quick", "the slow");
        let expected = 1.0 / 3.0; // intersection=1, union=3
        assert!(
            (score - expected).abs() < 1e-6,
            "Partial overlap: expected {}, got {}",
            expected,
            score
        );
    }

    #[test]
    fn test_routing_agreement_same() {
        assert!(
            check_routing_agreement(&[0, 2, 5], &[5, 0, 2]),
            "Same expert set (different order) should agree"
        );
    }

    #[test]
    fn test_routing_agreement_different() {
        assert!(
            !check_routing_agreement(&[0, 2], &[0, 3]),
            "Different expert sets should not agree"
        );
    }

    #[test]
    fn test_flush_and_readback() {
        // Include the process id in the file name so concurrent test runs on
        // the same machine do not overwrite each other's output.
        let file_name = format!("bitnet_trace_test_{}.jsonl", std::process::id());
        let path = std::env::temp_dir().join(file_name);
        let mut writer = TraceWriter::new(Some(path.clone()));
        writer.record(make_entry("flush-test", 0, 0));
        writer.record(make_entry("flush-test", 1, 1));
        writer.flush().unwrap();
        let readback = std::fs::read_to_string(&path);
        // Clean up before asserting so a failing assertion cannot leave the
        // temporary file behind.
        let _ = std::fs::remove_file(&path);
        let contents = readback.unwrap();
        let lines: Vec<&str> = contents.trim_end().split('\n').collect();
        assert_eq!(lines.len(), 2, "Flushed file should have 2 lines");
        for line in &lines {
            assert!(line.starts_with('{') && line.ends_with('}'));
        }
    }

    #[test]
    fn test_stop_reason_serialization() {
        assert_eq!(StopReason::Eos.to_json(), "\"eos\"");
        assert_eq!(StopReason::MaxLength.to_json(), "\"max_length\"");
        assert_eq!(StopReason::Refusal.to_json(), "\"refusal\"");
        assert_eq!(StopReason::LowCoherence.to_json(), "\"low_coherence\"");
        let error_json = StopReason::Error("timeout".to_string()).to_json();
        assert_eq!(error_json, "\"error:timeout\"");
    }

    #[test]
    fn test_clear_entries() {
        let mut writer = TraceWriter::new(None);
        writer.record(make_entry("p1", 0, 0));
        assert_eq!(writer.entries().len(), 1);
        writer.clear();
        assert_eq!(writer.entries().len(), 0);
        assert_eq!(writer.to_jsonl(), "");
    }

    #[test]
    fn test_json_escape_special_chars() {
        let entry = TraceEntry {
            prompt_id: "test\"with\\special\nnewline".to_string(),
            token_idx: 0,
            layer_idx: 0,
            routing: RoutingTrace {
                topk_expert_ids: vec![],
                topk_weights: vec![],
                teacher_expert_ids: None,
                teacher_weights: None,
                agreement: false,
            },
            citations: vec![],
            refusal: RefusalTrace {
                should_refuse: false,
                did_refuse: false,
                correct: true,
            },
            coherence_score: 0.0,
            stop_reason: StopReason::Eos,
            timestamp_ms: 0,
        };
        let json = entry.to_json();
        // The escaped prompt_id should not contain raw quotes or newlines
        assert!(!json.contains("test\"with"), "Raw quote should be escaped");
        assert!(
            json.contains("test\\\"with"),
            "Quote should be escaped as \\\""
        );
        assert!(json.contains("\\n"), "Newline should be escaped as \\n");
    }
}