wifi-densepose/vendor/ruvector/crates/cognitum-gate-tilezero/src/merge.rs

610 lines
19 KiB
Rust

//! Report merging from 255 worker tiles
//!
//! This module handles aggregating partial graph reports from worker tiles
//! into a unified view for supergraph construction.
//!
//! ## Performance Optimizations
//!
//! - Pre-allocated HashMaps with expected capacity (255 workers)
//! - Inline functions for merge strategies
//! - Iterator-based processing to avoid allocations
//! - Sorted slices with binary search for median calculation
//! - Capacity hints for all collections
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::TileId;
/// Expected number of worker tiles for capacity pre-allocation
const EXPECTED_WORKERS: usize = 255;
/// Expected nodes per worker for capacity hints
const EXPECTED_NODES_PER_WORKER: usize = 16;
/// Expected boundary edges per worker
const EXPECTED_EDGES_PER_WORKER: usize = 32;
/// Epoch identifier for report sequencing
pub type Epoch = u64;
/// Transaction identifier (32-byte hash)
pub type TxId = [u8; 32];
/// Errors during report merging
#[derive(Debug, Clone)]
pub enum MergeError {
/// Empty report set
EmptyReports,
/// Conflicting epochs in reports
ConflictingEpochs,
/// Invalid edge weight
InvalidWeight(String),
/// Node not found
NodeNotFound(String),
}
impl std::fmt::Display for MergeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MergeError::EmptyReports => write!(f, "Empty report set"),
MergeError::ConflictingEpochs => write!(f, "Conflicting epochs in reports"),
MergeError::InvalidWeight(msg) => write!(f, "Invalid edge weight: {}", msg),
MergeError::NodeNotFound(id) => write!(f, "Node not found: {}", id),
}
}
}
impl std::error::Error for MergeError {}
/// Strategy for merging overlapping data from multiple workers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MergeStrategy {
/// Simple average of all values
SimpleAverage,
/// Weighted average by tile confidence
WeightedAverage,
/// Take the median value
Median,
/// Take the maximum value (conservative)
Maximum,
/// Byzantine fault tolerant (2/3 agreement)
ByzantineFaultTolerant,
}
/// A node summary from a worker tile
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeSummary {
/// Node identifier
pub id: String,
/// Aggregated weight/importance
pub weight: f64,
/// Number of edges in worker's partition
pub edge_count: usize,
/// Local coherence score
pub coherence: f64,
}
/// An edge summary from a worker tile (for boundary edges)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeSummary {
/// Source node ID
pub source: String,
/// Target node ID
pub target: String,
/// Edge capacity/weight
pub capacity: f64,
/// Is this a boundary edge (crosses tile partitions)?
pub is_boundary: bool,
}
/// Report from a worker tile containing partition summary
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WorkerReport {
/// Tile identifier (1-255)
pub tile_id: TileId,
/// Epoch this report belongs to
pub epoch: Epoch,
/// Timestamp when report was generated (unix millis)
pub timestamp_ms: u64,
/// Transactions processed in this partition
pub transactions: Vec<TxId>,
/// Node summaries for super-nodes
pub nodes: Vec<NodeSummary>,
/// Boundary edge summaries
pub boundary_edges: Vec<EdgeSummary>,
/// Local min-cut value (within partition)
pub local_mincut: f64,
/// Worker's confidence in this report (0.0-1.0)
pub confidence: f64,
/// Hash of the worker's local state
pub state_hash: [u8; 32],
}
impl WorkerReport {
/// Create a new worker report
pub fn new(tile_id: TileId, epoch: Epoch) -> Self {
Self {
tile_id,
epoch,
timestamp_ms: 0,
transactions: Vec::new(),
nodes: Vec::new(),
boundary_edges: Vec::new(),
local_mincut: 0.0,
confidence: 1.0,
state_hash: [0u8; 32],
}
}
/// Add a node summary
pub fn add_node(&mut self, node: NodeSummary) {
self.nodes.push(node);
}
/// Add a boundary edge
pub fn add_boundary_edge(&mut self, edge: EdgeSummary) {
self.boundary_edges.push(edge);
}
/// Compute state hash using blake3
pub fn compute_state_hash(&mut self) {
let mut hasher = blake3::Hasher::new();
hasher.update(&self.tile_id.to_le_bytes());
hasher.update(&self.epoch.to_le_bytes());
for node in &self.nodes {
hasher.update(node.id.as_bytes());
hasher.update(&node.weight.to_le_bytes());
}
for edge in &self.boundary_edges {
hasher.update(edge.source.as_bytes());
hasher.update(edge.target.as_bytes());
hasher.update(&edge.capacity.to_le_bytes());
}
self.state_hash = *hasher.finalize().as_bytes();
}
}
/// Merged report combining data from multiple workers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergedReport {
/// Epoch of the merged report
pub epoch: Epoch,
/// Number of worker reports merged
pub worker_count: usize,
/// Merged super-nodes (aggregated from all workers)
pub super_nodes: HashMap<String, MergedNode>,
/// Merged boundary edges
pub boundary_edges: Vec<MergedEdge>,
/// Global min-cut estimate
pub global_mincut_estimate: f64,
/// Overall confidence (aggregated)
pub confidence: f64,
/// Merge strategy used
pub strategy: MergeStrategy,
}
/// A merged super-node aggregated from multiple workers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergedNode {
/// Node identifier
pub id: String,
/// Aggregated weight
pub weight: f64,
/// Total edge count across workers
pub total_edge_count: usize,
/// Average coherence
pub avg_coherence: f64,
/// Contributing worker tiles
pub contributors: Vec<TileId>,
}
/// A merged edge aggregated from boundary reports
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergedEdge {
/// Source node
pub source: String,
/// Target node
pub target: String,
/// Aggregated capacity
pub capacity: f64,
/// Number of workers reporting this edge
pub report_count: usize,
}
/// Report merger that combines worker reports
///
/// OPTIMIZATION: Uses capacity hints and inline functions for better performance
pub struct ReportMerger {
strategy: MergeStrategy,
/// Pre-allocated scratch buffer for weight calculations
/// OPTIMIZATION: Reuse allocation across merge operations
scratch_weights: Vec<f64>,
}
impl ReportMerger {
/// Create a new report merger with given strategy
#[inline]
pub fn new(strategy: MergeStrategy) -> Self {
Self {
strategy,
// Pre-allocate scratch buffer with expected capacity
scratch_weights: Vec::with_capacity(EXPECTED_WORKERS),
}
}
/// Merge multiple worker reports into a unified view
///
/// OPTIMIZATION: Pre-allocates all collections with expected capacity
pub fn merge(&self, reports: &[WorkerReport]) -> Result<MergedReport, MergeError> {
if reports.is_empty() {
return Err(MergeError::EmptyReports);
}
// Verify all reports are from the same epoch
// OPTIMIZATION: Use first() and fold for short-circuit evaluation
let epoch = reports[0].epoch;
for r in reports.iter().skip(1) {
if r.epoch != epoch {
return Err(MergeError::ConflictingEpochs);
}
}
// Merge nodes - pre-allocate based on expected size
let super_nodes = self.merge_nodes(reports)?;
// Merge boundary edges
let boundary_edges = self.merge_edges(reports)?;
// Compute global min-cut estimate
let global_mincut_estimate = self.estimate_global_mincut(reports);
// Compute aggregated confidence
let confidence = self.aggregate_confidence(reports);
Ok(MergedReport {
epoch,
worker_count: reports.len(),
super_nodes,
boundary_edges,
global_mincut_estimate,
confidence,
strategy: self.strategy,
})
}
/// Merge node summaries from all workers
///
/// OPTIMIZATION: Pre-allocates HashMap with expected capacity
#[inline]
fn merge_nodes(
&self,
reports: &[WorkerReport],
) -> Result<HashMap<String, MergedNode>, MergeError> {
// OPTIMIZATION: Estimate total nodes across all reports
let estimated_nodes = reports.len() * EXPECTED_NODES_PER_WORKER;
let mut node_data: HashMap<String, Vec<(TileId, &NodeSummary)>> =
HashMap::with_capacity(estimated_nodes);
// Collect all node data
for report in reports {
for node in &report.nodes {
node_data
.entry(node.id.clone())
.or_insert_with(|| Vec::with_capacity(reports.len()))
.push((report.tile_id, node));
}
}
// Merge each node
// OPTIMIZATION: Pre-allocate result HashMap
let mut merged = HashMap::with_capacity(node_data.len());
for (id, data) in node_data {
let merged_node = self.merge_single_node(&id, &data)?;
merged.insert(id, merged_node);
}
Ok(merged)
}
/// Merge a single node's data from multiple workers
///
/// OPTIMIZATION: Uses inline strategy functions and avoids repeated allocations
#[inline]
fn merge_single_node(
&self,
id: &str,
data: &[(TileId, &NodeSummary)],
) -> Result<MergedNode, MergeError> {
// OPTIMIZATION: Pre-allocate with exact capacity
let mut contributors: Vec<TileId> = Vec::with_capacity(data.len());
contributors.extend(data.iter().map(|(tile, _)| *tile));
let total_edge_count: usize = data.iter().map(|(_, n)| n.edge_count).sum();
let len = data.len();
let len_f64 = len as f64;
let weight = match self.strategy {
MergeStrategy::SimpleAverage => {
// OPTIMIZATION: Single pass sum
let sum: f64 = data.iter().map(|(_, n)| n.weight).sum();
sum / len_f64
}
MergeStrategy::WeightedAverage => {
// OPTIMIZATION: Single pass for both sums
let (weighted_sum, coherence_sum) =
data.iter().fold((0.0, 0.0), |(ws, cs), (_, n)| {
(ws + n.weight * n.coherence, cs + n.coherence)
});
if coherence_sum > 0.0 {
weighted_sum / coherence_sum
} else {
0.0
}
}
MergeStrategy::Median => {
// OPTIMIZATION: Inline median calculation
Self::compute_median(data.iter().map(|(_, n)| n.weight))
}
MergeStrategy::Maximum => {
// OPTIMIZATION: Use fold without intermediate iterator
data.iter()
.map(|(_, n)| n.weight)
.fold(f64::NEG_INFINITY, f64::max)
}
MergeStrategy::ByzantineFaultTolerant => {
// OPTIMIZATION: BFT with inline median of 2/3
Self::compute_bft_weight(data.iter().map(|(_, n)| n.weight), len)
}
};
// OPTIMIZATION: Single pass for coherence average
let avg_coherence = data.iter().map(|(_, n)| n.coherence).sum::<f64>() / len_f64;
Ok(MergedNode {
id: id.to_string(),
weight,
total_edge_count,
avg_coherence,
contributors,
})
}
/// Compute median of an iterator of f64 values
///
/// OPTIMIZATION: Inline function to avoid heap allocation overhead
#[inline]
fn compute_median<I: Iterator<Item = f64>>(iter: I) -> f64 {
let mut weights: Vec<f64> = iter.collect();
let len = weights.len();
if len == 0 {
return 0.0;
}
// OPTIMIZATION: Use unstable sort for f64 (faster, no stability needed)
weights.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mid = len / 2;
if len % 2 == 0 {
// SAFETY: mid > 0 when len >= 2 and even
(weights[mid - 1] + weights[mid]) * 0.5
} else {
weights[mid]
}
}
/// Compute Byzantine Fault Tolerant weight (median of top 2/3)
///
/// OPTIMIZATION: Inline function with optimized threshold calculation
#[inline]
fn compute_bft_weight<I: Iterator<Item = f64>>(iter: I, len: usize) -> f64 {
let mut weights: Vec<f64> = iter.collect();
if weights.is_empty() {
return 0.0;
}
weights.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
// 2/3 threshold
let threshold = (len * 2) / 3;
if threshold > 0 {
let sum: f64 = weights.iter().take(threshold).sum();
sum / threshold as f64
} else {
weights[0]
}
}
/// Merge boundary edges from all workers
///
/// OPTIMIZATION: Pre-allocates collections, uses inline merge strategies
#[inline]
fn merge_edges(&self, reports: &[WorkerReport]) -> Result<Vec<MergedEdge>, MergeError> {
// OPTIMIZATION: Pre-allocate with expected capacity
let estimated_edges = reports.len() * EXPECTED_EDGES_PER_WORKER;
let mut edge_data: HashMap<(String, String), Vec<f64>> =
HashMap::with_capacity(estimated_edges);
// Collect all edge data
for report in reports {
for edge in &report.boundary_edges {
if edge.is_boundary {
// Normalize edge key (smaller first for undirected)
// OPTIMIZATION: Avoid unnecessary clones by checking order first
let key = if edge.source <= edge.target {
(edge.source.clone(), edge.target.clone())
} else {
(edge.target.clone(), edge.source.clone())
};
edge_data
.entry(key)
.or_insert_with(|| Vec::with_capacity(reports.len()))
.push(edge.capacity);
}
}
}
// Merge each edge
// OPTIMIZATION: Pre-allocate result vector
let mut merged = Vec::with_capacity(edge_data.len());
for ((source, target), capacities) in edge_data {
let len = capacities.len();
let capacity = self.merge_capacities(&capacities, len);
merged.push(MergedEdge {
source,
target,
capacity,
report_count: len,
});
}
Ok(merged)
}
/// Merge capacities according to strategy
///
/// OPTIMIZATION: Inline function to avoid match overhead in loop
#[inline(always)]
fn merge_capacities(&self, capacities: &[f64], len: usize) -> f64 {
match self.strategy {
MergeStrategy::SimpleAverage | MergeStrategy::WeightedAverage => {
capacities.iter().sum::<f64>() / len as f64
}
MergeStrategy::Median => Self::compute_median(capacities.iter().copied()),
MergeStrategy::Maximum => capacities.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
MergeStrategy::ByzantineFaultTolerant => {
Self::compute_bft_weight(capacities.iter().copied(), len)
}
}
}
/// Estimate global min-cut from local values
///
/// OPTIMIZATION: Single-pass computation
#[inline]
fn estimate_global_mincut(&self, reports: &[WorkerReport]) -> f64 {
// OPTIMIZATION: Single pass for both local_sum and boundary_count
let (local_sum, boundary_count) = reports.iter().fold((0.0, 0usize), |(sum, count), r| {
let bc = r.boundary_edges.iter().filter(|e| e.is_boundary).count();
(sum + r.local_mincut, count + bc)
});
// Simple estimate: local sum adjusted by boundary factor
// OPTIMIZATION: Pre-compute constant multiplier
let boundary_factor = 1.0 / (1.0 + (boundary_count as f64 * 0.01));
local_sum * boundary_factor
}
/// Aggregate confidence from all workers
///
/// OPTIMIZATION: Inline, uses fold for single-pass computation
#[inline]
fn aggregate_confidence(&self, reports: &[WorkerReport]) -> f64 {
let len = reports.len();
if len == 0 {
return 0.0;
}
match self.strategy {
MergeStrategy::ByzantineFaultTolerant => {
// Conservative: use minimum of top 2/3
let mut confidences: Vec<f64> = Vec::with_capacity(len);
confidences.extend(reports.iter().map(|r| r.confidence));
// Sort descending
confidences
.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
let threshold = (len * 2) / 3;
confidences
.get(threshold.saturating_sub(1))
.copied()
.unwrap_or(0.0)
}
_ => {
// Geometric mean using log-sum for numerical stability
// OPTIMIZATION: Use log-sum-exp pattern to avoid overflow
let log_sum: f64 = reports.iter().map(|r| r.confidence.ln()).sum();
(log_sum / len as f64).exp()
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_report(tile_id: TileId, epoch: Epoch) -> WorkerReport {
let mut report = WorkerReport::new(tile_id, epoch);
report.add_node(NodeSummary {
id: "node1".to_string(),
weight: tile_id as f64 * 0.1,
edge_count: 5,
coherence: 0.9,
});
report.confidence = 0.95;
report.local_mincut = 1.0;
report
}
#[test]
fn test_merge_simple_average() {
let merger = ReportMerger::new(MergeStrategy::SimpleAverage);
let reports = vec![
create_test_report(1, 0),
create_test_report(2, 0),
create_test_report(3, 0),
];
let merged = merger.merge(&reports).unwrap();
assert_eq!(merged.worker_count, 3);
assert_eq!(merged.epoch, 0);
let node = merged.super_nodes.get("node1").unwrap();
// Average of 0.1, 0.2, 0.3 = 0.2
assert!((node.weight - 0.2).abs() < 0.001);
}
#[test]
fn test_merge_empty_reports() {
let merger = ReportMerger::new(MergeStrategy::SimpleAverage);
let result = merger.merge(&[]);
assert!(matches!(result, Err(MergeError::EmptyReports)));
}
#[test]
fn test_merge_conflicting_epochs() {
let merger = ReportMerger::new(MergeStrategy::SimpleAverage);
let reports = vec![create_test_report(1, 0), create_test_report(2, 1)];
let result = merger.merge(&reports);
assert!(matches!(result, Err(MergeError::ConflictingEpochs)));
}
#[test]
fn test_state_hash_computation() {
let mut report = create_test_report(1, 0);
report.compute_state_hash();
assert_ne!(report.state_hash, [0u8; 32]);
}
}