227 lines
7.5 KiB
Rust
227 lines
7.5 KiB
Rust
//! Pattern matching for threat detection
|
|
|
|
use aimds_core::{DetectionResult, Result, ThreatSeverity, ThreatType};
|
|
use aho_corasick::AhoCorasick;
|
|
use chrono::Utc;
|
|
use dashmap::DashMap;
|
|
use regex::RegexSet;
|
|
use std::sync::Arc;
|
|
use midstreamer_temporal_compare::{TemporalComparator, Sequence, ComparisonAlgorithm};
|
|
use uuid::Uuid;
|
|
|
|
/// Pattern matcher using multiple detection strategies
|
|
pub struct PatternMatcher {
|
|
/// Fast string matching for known patterns
|
|
aho_corasick: Arc<AhoCorasick>,
|
|
/// Regex patterns for complex matching
|
|
regex_set: Arc<RegexSet>,
|
|
/// Temporal comparison for behavioral patterns (using i32 for character codes)
|
|
temporal_comparator: TemporalComparator<i32>,
|
|
/// Pattern cache for performance
|
|
cache: Arc<DashMap<String, DetectionResult>>,
|
|
}
|
|
|
|
impl PatternMatcher {
|
|
/// Create a new pattern matcher with default patterns
|
|
pub fn new() -> Result<Self> {
|
|
let patterns = Self::default_patterns();
|
|
let regexes = Self::default_regexes();
|
|
|
|
let aho_corasick = AhoCorasick::new(patterns)
|
|
.map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;
|
|
|
|
let regex_set = RegexSet::new(regexes)
|
|
.map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;
|
|
|
|
Ok(Self {
|
|
aho_corasick: Arc::new(aho_corasick),
|
|
regex_set: Arc::new(regex_set),
|
|
temporal_comparator: TemporalComparator::new(1000, 1000), // cache_size, max_length
|
|
cache: Arc::new(DashMap::new()),
|
|
})
|
|
}
|
|
|
|
/// Match patterns in the input text
|
|
pub async fn match_patterns(&self, input: &str) -> Result<DetectionResult> {
|
|
// Check cache first
|
|
let hash = blake3::hash(input.as_bytes());
|
|
let input_hash = hash.to_hex().to_string();
|
|
if let Some(cached) = self.cache.get(&input_hash) {
|
|
return Ok(cached.clone());
|
|
}
|
|
|
|
// Perform pattern matching
|
|
let mut matched_patterns = Vec::new();
|
|
let mut max_severity = ThreatSeverity::Low;
|
|
let mut threat_type = ThreatType::Unknown;
|
|
|
|
// Fast string matching
|
|
for mat in self.aho_corasick.find_iter(input) {
|
|
let pattern_id = mat.pattern().as_usize();
|
|
matched_patterns.push(format!("pattern_{}", pattern_id));
|
|
|
|
// Update severity based on pattern
|
|
if pattern_id < 10 {
|
|
max_severity = ThreatSeverity::Critical;
|
|
threat_type = ThreatType::PromptInjection;
|
|
}
|
|
}
|
|
|
|
// Regex matching
|
|
let regex_matches = self.regex_set.matches(input);
|
|
for pattern_id in regex_matches.iter() {
|
|
matched_patterns.push(format!("regex_{}", pattern_id));
|
|
|
|
if pattern_id < 5 {
|
|
max_severity = std::cmp::max(max_severity, ThreatSeverity::High);
|
|
threat_type = ThreatType::JailbreakAttempt;
|
|
}
|
|
}
|
|
|
|
// Temporal analysis for behavioral patterns
|
|
let temporal_score = self.analyze_temporal_patterns(input).await?;
|
|
|
|
// Calculate confidence based on matches
|
|
let confidence = self.calculate_confidence(&matched_patterns, temporal_score);
|
|
|
|
let result = DetectionResult {
|
|
id: Uuid::new_v4(),
|
|
timestamp: Utc::now(),
|
|
severity: max_severity,
|
|
threat_type,
|
|
confidence,
|
|
input_hash: input_hash.clone(),
|
|
matched_patterns,
|
|
context: serde_json::json!({
|
|
"temporal_score": temporal_score,
|
|
"input_length": input.len(),
|
|
}),
|
|
};
|
|
|
|
// Cache the result
|
|
self.cache.insert(input_hash, result.clone());
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Analyze temporal patterns using Midstream's temporal comparator
|
|
async fn analyze_temporal_patterns(&self, input: &str) -> Result<f64> {
|
|
// Convert input to temporal sequence for DTW analysis (using i32 for char codes)
|
|
let mut input_sequence = Sequence::new();
|
|
for (idx, ch) in input.chars().take(1000).enumerate() {
|
|
input_sequence.push(ch as i32, idx as u64);
|
|
}
|
|
|
|
// Use temporal-compare DTW (validated: 7.8ms performance)
|
|
// Compare against known malicious temporal patterns
|
|
let threat_sequences = Self::threat_temporal_sequences();
|
|
|
|
let mut max_similarity: f64 = 0.0;
|
|
for threat_seq in threat_sequences {
|
|
match self.temporal_comparator.compare(
|
|
&input_sequence,
|
|
&threat_seq,
|
|
ComparisonAlgorithm::DTW,
|
|
) {
|
|
Ok(result) => {
|
|
// Convert distance to similarity (lower distance = higher similarity)
|
|
let similarity = 1.0 / (1.0 + result.distance);
|
|
max_similarity = max_similarity.max(similarity);
|
|
}
|
|
Err(_) => continue,
|
|
}
|
|
}
|
|
|
|
Ok(max_similarity)
|
|
}
|
|
|
|
/// Known threat temporal sequences for DTW comparison
|
|
fn threat_temporal_sequences() -> Vec<Sequence<i32>> {
|
|
vec![
|
|
// Prompt injection temporal pattern
|
|
Self::str_to_sequence("ignore previous instructions"),
|
|
// Jailbreak attempt pattern
|
|
Self::str_to_sequence("you are no longer bound by"),
|
|
// System prompt override pattern
|
|
Self::str_to_sequence("system: you must now"),
|
|
]
|
|
}
|
|
|
|
/// Helper to convert string to Sequence
|
|
fn str_to_sequence(s: &str) -> Sequence<i32> {
|
|
let mut seq = Sequence::new();
|
|
for (idx, ch) in s.chars().enumerate() {
|
|
seq.push(ch as i32, idx as u64);
|
|
}
|
|
seq
|
|
}
|
|
|
|
/// Calculate confidence score
|
|
fn calculate_confidence(&self, patterns: &[String], temporal_score: f64) -> f64 {
|
|
let pattern_score = (patterns.len() as f64 * 0.1).min(0.7);
|
|
let combined = (pattern_score * 0.6) + (temporal_score * 0.4);
|
|
combined.min(1.0)
|
|
}
|
|
|
|
/// Default threat patterns
|
|
fn default_patterns() -> Vec<&'static str> {
|
|
vec![
|
|
"ignore previous instructions",
|
|
"disregard all prior",
|
|
"forget everything",
|
|
"system prompt",
|
|
"admin mode",
|
|
"developer mode",
|
|
"jailbreak",
|
|
"unrestricted mode",
|
|
"bypass filter",
|
|
"override safety",
|
|
]
|
|
}
|
|
|
|
/// Default regex patterns
|
|
fn default_regexes() -> Vec<&'static str> {
|
|
vec![
|
|
r"(?i)ignore\s+(all|previous|prior)\s+instructions",
|
|
r"(?i)system\s*:\s*you\s+are",
|
|
r"(?i)act\s+as\s+(an?\s+)?unrestricted",
|
|
r"(?i)pretend\s+you\s+are\s+(not\s+)?bound",
|
|
r"(?i)disregard\s+your\s+(programming|rules)",
|
|
]
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[tokio::test]
|
|
async fn test_pattern_matcher_creation() {
|
|
let matcher = PatternMatcher::new();
|
|
assert!(matcher.is_ok());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_simple_pattern_match() {
|
|
let matcher = PatternMatcher::new().unwrap();
|
|
let result = matcher
|
|
.match_patterns("Please ignore previous instructions")
|
|
.await
|
|
.unwrap();
|
|
|
|
assert!(!result.matched_patterns.is_empty());
|
|
assert!(result.confidence > 0.0);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_safe_input() {
|
|
let matcher = PatternMatcher::new().unwrap();
|
|
let result = matcher
|
|
.match_patterns("What is the weather today?")
|
|
.await
|
|
.unwrap();
|
|
|
|
assert!(result.matched_patterns.is_empty());
|
|
}
|
|
}
|