//! Pattern matching for threat detection use aimds_core::{DetectionResult, Result, ThreatSeverity, ThreatType}; use aho_corasick::AhoCorasick; use chrono::Utc; use dashmap::DashMap; use regex::RegexSet; use std::sync::Arc; use midstreamer_temporal_compare::{TemporalComparator, Sequence, ComparisonAlgorithm}; use uuid::Uuid; /// Pattern matcher using multiple detection strategies pub struct PatternMatcher { /// Fast string matching for known patterns aho_corasick: Arc, /// Regex patterns for complex matching regex_set: Arc, /// Temporal comparison for behavioral patterns (using i32 for character codes) temporal_comparator: TemporalComparator, /// Pattern cache for performance cache: Arc>, } impl PatternMatcher { /// Create a new pattern matcher with default patterns pub fn new() -> Result { let patterns = Self::default_patterns(); let regexes = Self::default_regexes(); let aho_corasick = AhoCorasick::new(patterns) .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?; let regex_set = RegexSet::new(regexes) .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?; Ok(Self { aho_corasick: Arc::new(aho_corasick), regex_set: Arc::new(regex_set), temporal_comparator: TemporalComparator::new(1000, 1000), // cache_size, max_length cache: Arc::new(DashMap::new()), }) } /// Match patterns in the input text pub async fn match_patterns(&self, input: &str) -> Result { // Check cache first let hash = blake3::hash(input.as_bytes()); let input_hash = hash.to_hex().to_string(); if let Some(cached) = self.cache.get(&input_hash) { return Ok(cached.clone()); } // Perform pattern matching let mut matched_patterns = Vec::new(); let mut max_severity = ThreatSeverity::Low; let mut threat_type = ThreatType::Unknown; // Fast string matching for mat in self.aho_corasick.find_iter(input) { let pattern_id = mat.pattern().as_usize(); matched_patterns.push(format!("pattern_{}", pattern_id)); // Update severity based on pattern if pattern_id < 10 { max_severity = ThreatSeverity::Critical; threat_type = ThreatType::PromptInjection; } } // Regex matching let regex_matches = self.regex_set.matches(input); for pattern_id in regex_matches.iter() { matched_patterns.push(format!("regex_{}", pattern_id)); if pattern_id < 5 { max_severity = std::cmp::max(max_severity, ThreatSeverity::High); threat_type = ThreatType::JailbreakAttempt; } } // Temporal analysis for behavioral patterns let temporal_score = self.analyze_temporal_patterns(input).await?; // Calculate confidence based on matches let confidence = self.calculate_confidence(&matched_patterns, temporal_score); let result = DetectionResult { id: Uuid::new_v4(), timestamp: Utc::now(), severity: max_severity, threat_type, confidence, input_hash: input_hash.clone(), matched_patterns, context: serde_json::json!({ "temporal_score": temporal_score, "input_length": input.len(), }), }; // Cache the result self.cache.insert(input_hash, result.clone()); Ok(result) } /// Analyze temporal patterns using Midstream's temporal comparator async fn analyze_temporal_patterns(&self, input: &str) -> Result { // Convert input to temporal sequence for DTW analysis (using i32 for char codes) let mut input_sequence = Sequence::new(); for (idx, ch) in input.chars().take(1000).enumerate() { input_sequence.push(ch as i32, idx as u64); } // Use temporal-compare DTW (validated: 7.8ms performance) // Compare against known malicious temporal patterns let threat_sequences = Self::threat_temporal_sequences(); let mut max_similarity: f64 = 0.0; for threat_seq in threat_sequences { match self.temporal_comparator.compare( &input_sequence, &threat_seq, ComparisonAlgorithm::DTW, ) { Ok(result) => { // Convert distance to similarity (lower distance = higher similarity) let similarity = 1.0 / (1.0 + result.distance); max_similarity = max_similarity.max(similarity); } Err(_) => continue, } } Ok(max_similarity) } /// Known threat temporal sequences for DTW comparison fn threat_temporal_sequences() -> Vec> { vec![ // Prompt injection temporal pattern Self::str_to_sequence("ignore previous instructions"), // Jailbreak attempt pattern Self::str_to_sequence("you are no longer bound by"), // System prompt override pattern Self::str_to_sequence("system: you must now"), ] } /// Helper to convert string to Sequence fn str_to_sequence(s: &str) -> Sequence { let mut seq = Sequence::new(); for (idx, ch) in s.chars().enumerate() { seq.push(ch as i32, idx as u64); } seq } /// Calculate confidence score fn calculate_confidence(&self, patterns: &[String], temporal_score: f64) -> f64 { let pattern_score = (patterns.len() as f64 * 0.1).min(0.7); let combined = (pattern_score * 0.6) + (temporal_score * 0.4); combined.min(1.0) } /// Default threat patterns fn default_patterns() -> Vec<&'static str> { vec![ "ignore previous instructions", "disregard all prior", "forget everything", "system prompt", "admin mode", "developer mode", "jailbreak", "unrestricted mode", "bypass filter", "override safety", ] } /// Default regex patterns fn default_regexes() -> Vec<&'static str> { vec![ r"(?i)ignore\s+(all|previous|prior)\s+instructions", r"(?i)system\s*:\s*you\s+are", r"(?i)act\s+as\s+(an?\s+)?unrestricted", r"(?i)pretend\s+you\s+are\s+(not\s+)?bound", r"(?i)disregard\s+your\s+(programming|rules)", ] } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_pattern_matcher_creation() { let matcher = PatternMatcher::new(); assert!(matcher.is_ok()); } #[tokio::test] async fn test_simple_pattern_match() { let matcher = PatternMatcher::new().unwrap(); let result = matcher .match_patterns("Please ignore previous instructions") .await .unwrap(); assert!(!result.matched_patterns.is_empty()); assert!(result.confidence > 0.0); } #[tokio::test] async fn test_safe_input() { let matcher = PatternMatcher::new().unwrap(); let result = matcher .match_patterns("What is the weather today?") .await .unwrap(); assert!(result.matched_patterns.is_empty()); } }