wifi-densepose/benchmarks/wiflow-std/remote/sweep/model_compact.py

333 lines
14 KiB
Python

"""Configurable compact variants of the WiFlow-STD pose model (ADR-152 efficiency sweep).
This is a parameterized copy of upstream models/{pose_model,tcn,convnet,attention}.py
(DY2434/WiFlow @ 06899d29, Apache-2.0). upstream/ is NOT modified. Deviations from
upstream, all forced by shrinking channels and documented per variant in run_sweep.py:
1. TCN grouped-conv groups: upstream hardcodes groups=20, which does not divide
the compact channel counts (e.g. 270, 135, 85). Rule here:
- groups_mode='gcd20': per-conv groups = gcd(channels, 20) (== 20 wherever
upstream's choice is valid, incl. the 540-ch input conv; falls back to the
largest common divisor with 20 otherwise).
- groups_mode='depthwise': groups = channels (tiny variant only).
2. Conv2d downsampling strides: upstream uses 4 stride-(1,2) blocks because
240/2^4 = 15 == n_keypoints. With smaller TCN output widths that would leave
<15 rows and AdaptiveAvgPool2d((15,1)) would duplicate rows across keypoints.
Rule: halve the width only while the result stays >= 15 (stride-2 blocks
first, stride-1 after). Full model: 240 -> 4 halvings = upstream exactly.
3. input_pw_groups (tiny only): the dense 540->c pointwise + residual downsample
in TCN block 1 cost 2*540*c params (a ~117k floor that alone exceeds the
tiny <100k budget). tiny groups these two convs (groups=4; 4 | gcd(540, 68)).
4. Decoder mid-channels: upstream 64->32; here c_last -> max(c_last // 2, 4).
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def tcn_groups(channels: int, mode: str) -> int:
if mode == 'depthwise':
return channels
if mode == 'gcd20':
return math.gcd(channels, 20)
raise ValueError(mode)
# ---------------------------------------------------------------- TCN (copy of tcn.py)
class Chomp1d(nn.Module):
def __init__(self, chomp_size):
super().__init__()
self.chomp_size = chomp_size
def forward(self, x):
return x[:, :, :-self.chomp_size].contiguous()
class CompactGroupedTemporalBlock(nn.Module):
"""Upstream InnerGroupedTemporalBlock with parameterized groups."""
def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding,
dropout=0.2, groups_mode='gcd20', pw_groups=1):
super().__init__()
g_in = tcn_groups(n_inputs, groups_mode)
g_out = tcn_groups(n_outputs, groups_mode)
self.groups = (g_in, g_out)
self.pw_groups = pw_groups
self.conv1_group = nn.Conv1d(n_inputs, n_inputs, kernel_size, stride=stride,
padding=padding, dilation=dilation,
groups=g_in, bias=False)
self.chomp1 = Chomp1d(padding) if padding > 0 else nn.Identity()
self.bn1_group = nn.BatchNorm1d(n_inputs)
self.relu1_group = nn.SiLU(inplace=True)
self.conv1_pw = nn.Conv1d(n_inputs, n_outputs, 1, groups=pw_groups, bias=False)
self.bn1_pw = nn.BatchNorm1d(n_outputs)
self.relu1_pw = nn.SiLU(inplace=True)
self.dropout1 = nn.Dropout(dropout)
self.conv2_group = nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=1,
padding=padding, dilation=dilation,
groups=g_out, bias=False)
self.chomp2 = Chomp1d(padding) if padding > 0 else nn.Identity()
self.bn2_group = nn.BatchNorm1d(n_outputs)
self.relu2_group = nn.SiLU(inplace=True)
self.conv2_pw = nn.Conv1d(n_outputs, n_outputs, 1, bias=False)
self.bn2_pw = nn.BatchNorm1d(n_outputs)
self.relu2_pw = nn.SiLU(inplace=True)
self.dropout2 = nn.Dropout(dropout)
self.downsample = nn.Sequential(
nn.Conv1d(n_inputs, n_outputs, 1, groups=pw_groups, bias=False),
nn.BatchNorm1d(n_outputs)
) if n_inputs != n_outputs else nn.Identity()
def forward(self, x):
res = self.downsample(x)
out = self.conv1_group(x)
out = self.chomp1(out)
out = self.bn1_group(out)
out = self.relu1_group(out)
out = self.conv1_pw(out)
out = self.bn1_pw(out)
out = self.relu1_pw(out)
out = self.dropout1(out)
out = self.conv2_group(out)
out = self.chomp2(out)
out = self.bn2_group(out)
out = self.relu2_group(out)
out = self.conv2_pw(out)
out = self.bn2_pw(out)
out = self.relu2_pw(out)
out = self.dropout2(out)
return F.silu(out + res)
class CompactTemporalBlock(nn.Module):
def __init__(self, num_inputs, num_channels, kernel_size=3, dropout=0.2,
groups_mode='gcd20', input_pw_groups=1):
super().__init__()
layers = []
for i, out_channels in enumerate(num_channels):
dilation_size = 2 ** i
in_channels = num_inputs if i == 0 else num_channels[i - 1]
layers.append(CompactGroupedTemporalBlock(
in_channels, out_channels, kernel_size, stride=1,
dilation=dilation_size, padding=(kernel_size - 1) * dilation_size,
dropout=dropout, groups_mode=groups_mode,
pw_groups=input_pw_groups if i == 0 else 1))
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
# ------------------------------------------------------- Conv2d path (copy of convnet.py)
class AsymmetricConvBlock(nn.Module):
"""Upstream block with parameterized width stride (upstream: always (1,2))."""
def __init__(self, in_channels, out_channels, dropout=0.3, stride_w=2):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3),
stride=(1, stride_w), padding=(0, 1)),
nn.BatchNorm2d(out_channels),
nn.SiLU(inplace=True),
nn.Dropout2d(dropout),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
nn.BatchNorm2d(out_channels),
nn.SiLU(inplace=True),
nn.Dropout2d(dropout),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
nn.BatchNorm2d(out_channels)
)
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1,
stride=(1, stride_w), bias=False),
nn.BatchNorm2d(out_channels)
)
self.activation = nn.SiLU(inplace=True)
def forward(self, x):
return self.activation(self.block(x) + self.downsample(x))
class ConvBlock1(nn.Module):
def __init__(self, in_channels, out_channels, dropout=0.3):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
nn.BatchNorm2d(out_channels),
nn.SiLU(inplace=True),
nn.Dropout2d(dropout),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
nn.BatchNorm2d(out_channels),
nn.SiLU(inplace=True),
nn.Dropout2d(dropout),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
nn.BatchNorm2d(out_channels)
)
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels)
)
self.activation = nn.SiLU(inplace=True)
def forward(self, x):
return self.activation(self.block(x) + self.downsample(x))
# ----------------------------------------------------- attention (verbatim attention.py)
class AxialAttention(nn.Module):
def __init__(self, in_planes, out_planes, groups=8, stride=1, bias=False, width=False):
assert (in_planes % groups == 0) and (out_planes % groups == 0)
super().__init__()
self.in_planes = in_planes
self.out_planes = out_planes
self.groups = groups
self.group_planes = out_planes // groups
self.stride = stride
self.bias = bias
self.width = width
self.qkv_transform = nn.Conv1d(in_planes, out_planes * 3, kernel_size=1,
stride=1, padding=0, bias=False)
self.bn_qkv = nn.BatchNorm1d(out_planes * 3)
self.bn_similarity = nn.BatchNorm2d(groups)
self.bn_output = nn.BatchNorm1d(out_planes)
if stride > 1:
self.pooling = nn.AvgPool2d(stride, stride=stride)
nn.init.normal_(self.qkv_transform.weight.data, 0, math.sqrt(1. / self.in_planes))
def forward(self, x):
if self.width:
x = x.permute(0, 2, 1, 3)
else:
x = x.permute(0, 3, 1, 2)
N, W, C, H = x.shape
x = x.contiguous().view(N * W, C, H)
qkv = self.bn_qkv(self.qkv_transform(x))
qkv = qkv.reshape(N * W, 3, self.out_planes, H).permute(1, 0, 2, 3)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q.reshape(N * W, self.groups, self.group_planes, H)
k = k.reshape(N * W, self.groups, self.group_planes, H)
v = v.reshape(N * W, self.groups, self.group_planes, H)
qk = torch.einsum('bgci, bgcj->bgij', q, k)
qk = self.bn_similarity(qk)
similarity = F.softmax(qk, dim=-1)
sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
sv = sv.reshape(N * W, self.out_planes, H)
out = self.bn_output(sv)
out = out.view(N, W, self.out_planes, H)
if self.width:
out = out.permute(0, 2, 1, 3)
else:
out = out.permute(0, 2, 3, 1)
if self.stride > 1:
out = self.pooling(out)
return out
class DualAxialAttention(nn.Module):
def __init__(self, in_planes, out_planes, groups=8, stride=1, bias=False):
super().__init__()
self.width_axis = AxialAttention(in_planes, out_planes, groups, stride, bias, width=True)
self.height_axis = AxialAttention(out_planes, out_planes, groups, stride, bias, width=False)
def forward(self, x):
return self.height_axis(self.width_axis(x))
# --------------------------------------------------------------- full model
def compute_strides(width: int, n_blocks: int, target: int = 15):
"""Halve width while result stays >= target (upstream: 240 -> 4 halvings -> 15)."""
strides = []
for _ in range(n_blocks):
nxt = (width + 1) // 2 # conv k=3 s=2 p=1: out = ceil(in/2)
if nxt >= target:
strides.append(2)
width = nxt
else:
strides.append(1)
return strides, width
class CompactWiFlowPoseModel(nn.Module):
"""Parameterized upstream WiFlowPoseModel.
Upstream config == tcn_channels=[540,440,340,240], conv_channels=[8,16,32,64],
attn_groups=8, groups_mode='gcd20' (gcd(c,20)==20 for all upstream channels),
input_pw_groups=1 -> identical architecture, 2,225,042 params.
"""
def __init__(self, tcn_channels, conv_channels, attn_groups,
groups_mode='gcd20', input_pw_groups=1, dropout=0.3,
num_subcarriers=540, num_keypoints=15):
super().__init__()
self.tcn = CompactTemporalBlock(
num_inputs=num_subcarriers, num_channels=tcn_channels, kernel_size=3,
dropout=dropout, groups_mode=groups_mode, input_pw_groups=input_pw_groups)
self.up = ConvBlock1(1, conv_channels[0])
strides, self.final_width = compute_strides(
tcn_channels[-1], len(conv_channels), target=num_keypoints)
self.conv_strides = strides
self.residual_blocks = nn.ModuleList()
in_channels = conv_channels[0]
for out_channels, s in zip(conv_channels, strides):
self.residual_blocks.append(
AsymmetricConvBlock(in_channels, out_channels, stride_w=s))
in_channels = out_channels
c_last = conv_channels[-1]
self.attention = DualAxialAttention(c_last, c_last, groups=attn_groups)
c_mid = max(c_last // 2, 4)
self.decoder = nn.Sequential(
nn.Conv2d(c_last, c_mid, kernel_size=3, padding=1),
nn.BatchNorm2d(c_mid),
nn.SiLU(inplace=True),
nn.Conv2d(c_mid, 2, kernel_size=1),
nn.BatchNorm2d(2),
nn.SiLU(inplace=True)
)
self.avg_pool = nn.AdaptiveAvgPool2d((num_keypoints, 1))
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.BatchNorm1d, nn.LayerNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
# [B, 540, 20]
x = self.tcn(x) # [B, C_tcn, 20]
x = x.transpose(1, 2).unsqueeze(1) # [B, 1, 20, C_tcn]
x = self.up(x)
for block in self.residual_blocks:
x = block(x) # [B, C_conv, 20, W']
x = x.permute(0, 1, 3, 2) # [B, C_conv, W', 20]
x = self.attention(x)
x = self.decoder(x) # [B, 2, W', 20]
x = self.avg_pool(x).squeeze(-1) # [B, 2, 15]
return x.transpose(1, 2) # [B, 15, 2]
def describe(model: 'CompactWiFlowPoseModel'):
params = sum(p.numel() for p in model.parameters())
tcn_g = [blk.groups for blk in model.tcn.network]
return {'params': params, 'tcn_groups_per_block': tcn_g,
'conv_strides': model.conv_strides, 'final_width': model.final_width}