# Listing metadata: 301 lines, 11 KiB, Python
"""Modality translation network for WiFi-DensePose system."""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from typing import Dict, Any, List
|
|
|
|
|
|
class ModalityTranslationError(Exception):
    """Signal a failure while translating between modalities.

    Carries no state beyond what ``Exception`` already provides; it exists so
    callers can catch translation problems separately from other errors.
    """
|
|
|
|
|
|
class ModalityTranslationNetwork(nn.Module):
    """Neural network for translating CSI data to visual feature space.

    The encoder is a stack of ``Conv2d -> normalization -> activation ->
    Dropout2d`` blocks, one per entry of ``hidden_channels``; the first block
    uses the configured stride and every later block uses stride 2.  The
    decoder mirrors the encoder with stride-2 ``ConvTranspose2d`` blocks and
    ends with a ``Conv2d`` + ``Tanh`` projection to ``output_channels``.
    When ``use_attention`` is set, multi-head self-attention is applied to the
    deepest encoder feature map before decoding.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize modality translation network.

        Args:
            config: Configuration dictionary with network parameters.
                Required keys: ``input_channels``, ``hidden_channels``
                (non-empty sequence of positive widths), ``output_channels``.
                Optional keys and their defaults: ``kernel_size`` (3),
                ``stride`` (1), ``padding`` (1), ``dropout_rate`` (0.1),
                ``activation`` ('relu'), ``normalization`` ('batch'),
                ``use_attention`` (False), ``attention_heads`` (8).

        Raises:
            ValueError: If the configuration is missing a required field or
                contains an invalid value.
        """
        super().__init__()

        self._validate_config(config)
        self.config = config

        self.input_channels = config['input_channels']
        self.hidden_channels = config['hidden_channels']
        self.output_channels = config['output_channels']
        self.kernel_size = config.get('kernel_size', 3)
        self.stride = config.get('stride', 1)
        self.padding = config.get('padding', 1)
        self.dropout_rate = config.get('dropout_rate', 0.1)
        self.activation = config.get('activation', 'relu')
        self.normalization = config.get('normalization', 'batch')
        self.use_attention = config.get('use_attention', False)
        self.attention_heads = config.get('attention_heads', 8)

        # Encoder: CSI -> Feature space
        self.encoder = self._build_encoder()

        # Decoder: Feature space -> Visual-like features
        self.decoder = self._build_decoder()

        # Attention mechanism (only registered when requested)
        if self.use_attention:
            self.attention = self._build_attention()

        # Initialize weights
        self._initialize_weights()

    def _validate_config(self, config: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        Args:
            config: Configuration dictionary passed to ``__init__``.

        Raises:
            ValueError: If a required field is missing, a channel count is
                not positive, ``hidden_channels`` is not a non-empty
                sequence, or attention is enabled with a head count that
                does not evenly divide the last hidden width.
        """
        required_fields = ['input_channels', 'hidden_channels', 'output_channels']
        for field in required_fields:
            if field not in config:
                raise ValueError(f"Missing required field: {field}")

        if config['input_channels'] <= 0:
            raise ValueError("input_channels must be positive")

        hidden_channels = config['hidden_channels']
        # A scalar (e.g. 64) has no len(); report it as a config error
        # instead of leaking a TypeError to the caller.
        try:
            hidden_count = len(hidden_channels)
        except TypeError:
            hidden_count = 0
        if hidden_count == 0:
            raise ValueError("hidden_channels must be a non-empty list")

        for width in hidden_channels:
            try:
                is_positive = width > 0
            except TypeError:
                is_positive = False
            if not is_positive:
                raise ValueError("hidden_channels entries must be positive")

        if config['output_channels'] <= 0:
            raise ValueError("output_channels must be positive")

        if config.get('use_attention', False):
            # nn.MultiheadAttention requires embed_dim % num_heads == 0;
            # fail here with a message that names the config keys involved.
            heads = config.get('attention_heads', 8)
            embed_dim = hidden_channels[-1]
            if heads <= 0 or embed_dim % heads != 0:
                raise ValueError(
                    f"attention_heads ({heads}) must be positive and evenly divide "
                    f"the last hidden channel size ({embed_dim})"
                )

    def _build_encoder(self) -> nn.ModuleList:
        """Build encoder network.

        Returns:
            A ``ModuleList`` with one convolutional block per hidden width.
        """
        layers = nn.ModuleList()

        in_channels = self.input_channels

        for i, out_channels in enumerate(self.hidden_channels):
            layer_block = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=self.kernel_size,
                    # Only the first block honours the configured stride;
                    # all deeper blocks downsample by 2.
                    stride=self.stride if i == 0 else 2,
                    padding=self.padding,
                ),
                self._get_normalization(out_channels),
                self._get_activation(),
                nn.Dropout2d(self.dropout_rate),
            )
            layers.append(layer_block)
            in_channels = out_channels

        return layers

    def _build_decoder(self) -> nn.ModuleList:
        """Build decoder network.

        Returns:
            A ``ModuleList`` of transposed-convolution blocks (one per hidden
            width except the last, in reverse order) followed by the final
            ``Conv2d`` + ``Tanh`` output block.
        """
        layers = nn.ModuleList()

        # Start with the last hidden channel size
        in_channels = self.hidden_channels[-1]

        # Progressive upsampling (reverse of encoder)
        for out_channels in reversed(self.hidden_channels[:-1]):
            layer_block = nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=self.kernel_size,
                    stride=2,
                    padding=self.padding,
                    output_padding=1,
                ),
                self._get_normalization(out_channels),
                self._get_activation(),
                nn.Dropout2d(self.dropout_rate),
            )
            layers.append(layer_block)
            in_channels = out_channels

        # Final output layer
        final_layer = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.output_channels,
                kernel_size=self.kernel_size,
                padding=self.padding,
            ),
            nn.Tanh(),  # Normalize output
        )
        layers.append(final_layer)

        return layers

    def _get_normalization(self, channels: int) -> nn.Module:
        """Get normalization layer.

        Args:
            channels: Number of channels the layer normalizes.

        Returns:
            The layer selected by ``self.normalization`` ('batch',
            'instance' or 'layer'); any other value yields ``nn.Identity``.
        """
        if self.normalization == 'batch':
            return nn.BatchNorm2d(channels)
        elif self.normalization == 'instance':
            return nn.InstanceNorm2d(channels)
        elif self.normalization == 'layer':
            # 'layer' is realised as a GroupNorm with a single group.
            return nn.GroupNorm(1, channels)
        else:
            return nn.Identity()

    def _get_activation(self) -> nn.Module:
        """Get activation function.

        Returns:
            The module selected by ``self.activation`` ('relu',
            'leaky_relu' or 'gelu'); any other value falls back to ReLU.
        """
        if self.activation == 'relu':
            return nn.ReLU(inplace=True)
        elif self.activation == 'leaky_relu':
            return nn.LeakyReLU(0.2, inplace=True)
        elif self.activation == 'gelu':
            return nn.GELU()
        else:
            return nn.ReLU(inplace=True)

    def _build_attention(self) -> nn.Module:
        """Build attention mechanism.

        Returns:
            A batch-first ``nn.MultiheadAttention`` whose embedding size is
            the last hidden channel width.
        """
        return nn.MultiheadAttention(
            embed_dim=self.hidden_channels[-1],
            num_heads=self.attention_heads,
            dropout=self.dropout_rate,
            batch_first=True,
        )

    def _initialize_weights(self) -> None:
        """Initialize network weights.

        Convolution and transposed-convolution weights get Kaiming-normal
        initialization with zero bias; batch-norm layers get unit weight and
        zero bias.  Other modules keep their default initialization.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _apply_attention(self, features: torch.Tensor) -> tuple:
        """Run self-attention over the spatial positions of a feature map.

        Args:
            features: Tensor of shape (batch_size, channels, height, width).

        Returns:
            Tuple ``(attended, attention_weights)`` where ``attended`` has
            the same shape as ``features`` and ``attention_weights`` is the
            second value returned by ``self.attention``.
        """
        batch_size, channels, height, width = features.shape
        # Reshape for attention: (batch, seq_len, embed_dim)
        tokens = features.flatten(2).transpose(1, 2)
        attended, attention_weights = self.attention(tokens, tokens, tokens)
        # reshape (rather than view) tolerates any memory layout of the
        # attention output.
        restored = attended.transpose(1, 2).reshape(batch_size, channels, height, width)
        return restored, attention_weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network.

        Args:
            x: Input CSI tensor of shape (batch_size, channels, height, width)

        Returns:
            Translated features tensor

        Raises:
            ModalityTranslationError: If ``x`` is not 4-dimensional or its
                channel dimension does not match ``input_channels``.
        """
        # Validate input rank before indexing into the shape
        if x.dim() != 4:
            raise ModalityTranslationError(
                f"Expected 4D input (batch_size, channels, height, width), got {x.dim()}D tensor"
            )

        # Validate input shape
        if x.shape[1] != self.input_channels:
            raise ModalityTranslationError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")

        # Encode CSI data
        encoded_features = self.encode(x)

        # Decode to visual-like features
        decoded = self.decode(encoded_features)

        return decoded

    def encode(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Encode input through encoder layers.

        Args:
            x: Input tensor

        Returns:
            List of feature maps from each encoder layer
        """
        features = []
        current = x

        for layer in self.encoder:
            current = layer(current)
            features.append(current)

        return features

    def decode(self, encoded_features: List[torch.Tensor]) -> torch.Tensor:
        """Decode features through decoder layers.

        Only the last entry of ``encoded_features`` is consumed.

        Args:
            encoded_features: List of encoded feature maps

        Returns:
            Decoded output tensor

        Raises:
            ModalityTranslationError: If ``encoded_features`` is empty.
        """
        if not encoded_features:
            raise ModalityTranslationError("encoded_features must contain at least one feature map")

        # Start with the last encoded feature
        current = encoded_features[-1]

        # Apply attention if enabled
        if self.use_attention:
            current, _ = self._apply_attention(current)

        # Apply decoder layers
        for layer in self.decoder:
            current = layer(current)

        return current

    def compute_translation_loss(self, predicted: torch.Tensor, target: torch.Tensor, loss_type: str = 'mse') -> torch.Tensor:
        """Compute translation loss between predicted and target features.

        Args:
            predicted: Predicted feature tensor
            target: Target feature tensor
            loss_type: Type of loss ('mse', 'l1', 'smooth_l1'); any other
                value falls back to MSE.

        Returns:
            Computed loss tensor
        """
        if loss_type == 'mse':
            return F.mse_loss(predicted, target)
        elif loss_type == 'l1':
            return F.l1_loss(predicted, target)
        elif loss_type == 'smooth_l1':
            return F.smooth_l1_loss(predicted, target)
        else:
            return F.mse_loss(predicted, target)

    def get_feature_statistics(self, features: torch.Tensor) -> Dict[str, float]:
        """Get statistics of feature tensor.

        Args:
            features: Feature tensor to analyze

        Returns:
            Dictionary with keys 'mean', 'std', 'min', 'max' and 'sparsity'
            (the fraction of elements exactly equal to zero).
        """
        with torch.no_grad():
            return {
                'mean': features.mean().item(),
                'std': features.std().item(),
                'min': features.min().item(),
                'max': features.max().item(),
                'sparsity': (features == 0).float().mean().item(),
            }

    def get_intermediate_features(self, x: torch.Tensor) -> Dict[str, Any]:
        """Get intermediate features for visualization.

        Args:
            x: Input tensor

        Returns:
            Dictionary containing 'encoder_features' and 'decoder_features'
            (lists with one tensor per layer) and, when attention is
            enabled, 'attention_weights'.
        """
        result = {}

        # Get encoder features
        encoder_features = self.encode(x)
        result['encoder_features'] = encoder_features

        # Get decoder features
        decoder_features = []
        current = encoder_features[-1]

        if self.use_attention:
            current, attention_weights = self._apply_attention(current)
            result['attention_weights'] = attention_weights

        for layer in self.decoder:
            current = layer(current)
            decoder_features.append(current)

        result['decoder_features'] = decoder_features

        return result