ablation: add Group A (aggregator) + Group B (architecture) infrastructure
Extends MixedCFMConfig with 5 backwards-compatible flags (use_flow_token,
n_packet_tokens, disc_as_cont, cont_as_disc + cont_n_bins) so existing
JANUS-full checkpoints load with 0 missing/unexpected keys.
Adds:
- 60 ablation training configs (5 variants × 4 datasets × 3 seeds)
- scripts/ablation/{generate_configs.py, run_groupB.sh, run_cross_groupB.sh,
smoke_test.sh} — config generation + GPU drivers
- scripts/aggregate/aggregate_ablation{,_cross,_cross_B}.py — produces
within-dataset and cross-dataset (3×3) ablation tables with 3-seed mean
± 95% t-CI plus optional paired DeLong p-values
README updated with ablation section pointing at
artifacts/ablation/ABLATION_SUMMARY.md.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -19,6 +19,7 @@ AdaLNBlock = _unified.AdaLNBlock
|
||||
SinusoidalTimeEmb = _unified.SinusoidalTimeEmb
|
||||
_sinkhorn_coupling = _unified._sinkhorn_coupling
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixedCFMConfig:
|
||||
T: int = 64
|
||||
@@ -40,6 +41,11 @@ class MixedCFMConfig:
|
||||
lambda_disc: float = 1.0
|
||||
disc_path: str = 'uniform'
|
||||
disc_embed_scale: float = 1.0
|
||||
# ---- B-group ablation flags (defaults preserve JANUS-full behavior) ----
|
||||
use_flow_token: bool = True # B1: False removes the [FLOW] token
|
||||
n_packet_tokens: int = -1 # B2: 0 removes packet tokens entirely; -1 = use cfg.T
|
||||
disc_as_cont: bool = False # B3: feed 6 disc bits through CFM head as continuous values
|
||||
cont_as_disc: bool = False # B4: quantize 3 cont channels into n_disc_classes bins (mask-pred only)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if len(self.cont_pkt_idx) != self.n_cont_pkt:
|
||||
@@ -48,10 +54,13 @@ class MixedCFMConfig:
|
||||
raise ValueError('disc_pkt_idx length mismatch n_disc_pkt')
|
||||
if self.disc_path != 'uniform':
|
||||
raise NotImplementedError(f'disc_path={self.disc_path}')
|
||||
if self.disc_as_cont and self.cont_as_disc:
|
||||
raise ValueError('disc_as_cont and cont_as_disc are mutually exclusive')
|
||||
|
||||
|
||||
class MixedVelocity(nn.Module):
|
||||
|
||||
def __init__(self, token_dim: int, seq_len: int, n_disc: int, n_classes: int, d_model: int=128, n_layers: int=4, n_heads: int=4, mlp_ratio: float=4.0, time_dim: int=64, reference_mode: str | None=None) -> None:
|
||||
def __init__(self, token_dim: int, seq_len: int, n_disc: int, n_classes: int, d_model: int=128, n_layers: int=4, n_heads: int=4, mlp_ratio: float=4.0, time_dim: int=64, reference_mode: str | None=None, has_flow_token: bool=True) -> None:
|
||||
super().__init__()
|
||||
if reference_mode not in (None, 'causal_packets', 'causal_all'):
|
||||
raise ValueError(f'reference_mode={reference_mode!r}')
|
||||
@@ -60,6 +69,7 @@ class MixedVelocity(nn.Module):
|
||||
self.n_disc = n_disc
|
||||
self.n_classes = n_classes
|
||||
self.reference_mode = reference_mode
|
||||
self.has_flow_token = has_flow_token
|
||||
self.input_proj = nn.Linear(token_dim, d_model)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
|
||||
self.type_emb = nn.Embedding(2, d_model)
|
||||
@@ -70,12 +80,15 @@ class MixedVelocity(nn.Module):
|
||||
self.blocks = nn.ModuleList([AdaLNBlock(d_model, n_heads, mlp_ratio, cond_dim=d_model) for _ in range(n_layers)])
|
||||
self.out_norm = nn.LayerNorm(d_model, elementwise_affine=False)
|
||||
self.head_v = nn.Linear(d_model, token_dim)
|
||||
self.head_disc = nn.Linear(d_model, n_disc * n_classes)
|
||||
# head_disc only meaningful when n_disc > 0
|
||||
out_disc = max(n_disc * n_classes, 1)
|
||||
self.head_disc = nn.Linear(d_model, out_disc)
|
||||
for layer in (self.head_v, self.head_disc):
|
||||
nn.init.zeros_(layer.weight)
|
||||
nn.init.zeros_(layer.bias)
|
||||
type_ids = torch.ones(seq_len, dtype=torch.long)
|
||||
type_ids[0] = 0
|
||||
if has_flow_token and seq_len >= 1:
|
||||
type_ids[0] = 0
|
||||
self.register_buffer('type_ids', type_ids, persistent=False)
|
||||
|
||||
def _attn_mask(self, L: int, device: torch.device) -> torch.Tensor | None:
|
||||
@@ -83,8 +96,11 @@ class MixedVelocity(nn.Module):
|
||||
return None
|
||||
if self.reference_mode == 'causal_packets':
|
||||
mask = torch.zeros((L, L), dtype=torch.bool, device=device)
|
||||
if L > 1:
|
||||
mask[1:, 1:] = torch.triu(torch.ones(L - 1, L - 1, dtype=torch.bool, device=device), diagonal=1)
|
||||
offset = 1 if self.has_flow_token else 0
|
||||
if L > offset:
|
||||
M = L - offset
|
||||
if M > 1:
|
||||
mask[offset:, offset:] = torch.triu(torch.ones(M, M, dtype=torch.bool, device=device), diagonal=1)
|
||||
return mask
|
||||
return torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), diagonal=1)
|
||||
|
||||
@@ -100,143 +116,339 @@ class MixedVelocity(nn.Module):
|
||||
h = block(h, cond, key_padding_mask, attn_mask=attn_mask)
|
||||
h = self.out_norm(h)
|
||||
v = self.head_v(h)
|
||||
d = self.head_disc(h).view(B, L, self.n_disc, self.n_classes)
|
||||
if self.n_disc > 0:
|
||||
d = self.head_disc(h).view(B, L, self.n_disc, self.n_classes)
|
||||
else:
|
||||
d = h.new_zeros((B, L, 0, self.n_classes))
|
||||
return (v, d)
|
||||
|
||||
|
||||
class MixedTokenCFM(nn.Module):
|
||||
|
||||
def __init__(self, cfg: MixedCFMConfig) -> None:
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
cont_size = cfg.n_cont_pkt + cfg.n_disc_pkt
|
||||
# Effective packet count (B2: n_packet_tokens=0 → no packets)
|
||||
self.eff_T = cfg.T if cfg.n_packet_tokens < 0 else int(cfg.n_packet_tokens)
|
||||
if not cfg.use_flow_token and self.eff_T == 0:
|
||||
raise ValueError('cannot disable both FLOW token and packet tokens')
|
||||
# Effective per-packet feature split
|
||||
if cfg.disc_as_cont:
|
||||
# B3: 9 cont, 0 disc (CFM head only)
|
||||
self.eff_n_cont = cfg.n_cont_pkt + cfg.n_disc_pkt
|
||||
self.eff_n_disc = 0
|
||||
elif cfg.cont_as_disc:
|
||||
# B4: 0 cont, 9 disc (mask-pred head only)
|
||||
self.eff_n_cont = 0
|
||||
self.eff_n_disc = cfg.n_cont_pkt + cfg.n_disc_pkt
|
||||
else:
|
||||
self.eff_n_cont = cfg.n_cont_pkt
|
||||
self.eff_n_disc = cfg.n_disc_pkt
|
||||
cont_size = self.eff_n_cont + self.eff_n_disc
|
||||
# Token layout: [type_flag(1) | flow_dim or cont_size]
|
||||
self.token_dim = cfg.token_dim or 1 + max(cfg.flow_dim, cont_size)
|
||||
if self.token_dim < 1 + max(cfg.flow_dim, cont_size):
|
||||
raise ValueError('token_dim too small')
|
||||
self.seq_len = cfg.T + 1
|
||||
self.velocity = MixedVelocity(token_dim=self.token_dim, seq_len=self.seq_len, n_disc=cfg.n_disc_pkt, n_classes=cfg.n_disc_classes, d_model=cfg.d_model, n_layers=cfg.n_layers, n_heads=cfg.n_heads, mlp_ratio=cfg.mlp_ratio, time_dim=cfg.time_dim, reference_mode=cfg.reference_mode)
|
||||
self.seq_len = (1 if cfg.use_flow_token else 0) + self.eff_T
|
||||
self.velocity = MixedVelocity(
|
||||
token_dim=self.token_dim, seq_len=self.seq_len,
|
||||
n_disc=self.eff_n_disc, n_classes=cfg.n_disc_classes,
|
||||
d_model=cfg.d_model, n_layers=cfg.n_layers, n_heads=cfg.n_heads,
|
||||
mlp_ratio=cfg.mlp_ratio, time_dim=cfg.time_dim,
|
||||
reference_mode=cfg.reference_mode, has_flow_token=cfg.use_flow_token,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# token assembly #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _embed_disc(self, x_disc_int: torch.Tensor) -> torch.Tensor:
|
||||
n = self.cfg.n_disc_classes
|
||||
s = self.cfg.disc_embed_scale
|
||||
return (x_disc_int.float() - 0.5) * s
|
||||
if n <= 1:
|
||||
return x_disc_int.float() * 0.0
|
||||
# Map integers in [0, n-1] to centered floats in [-s/2, +s/2].
|
||||
# Backwards-compatible with old (x - 0.5)*s formula when n=2.
|
||||
return (x_disc_int.float() / (n - 1) - 0.5) * s
|
||||
|
||||
def _flow_dim(self) -> int:
|
||||
return self.cfg.flow_dim
|
||||
|
||||
def build_tokens(self, flow: torch.Tensor, packets_cont: torch.Tensor, x_disc_t_int: torch.Tensor) -> torch.Tensor:
|
||||
(B, T, Cp) = packets_cont.shape
|
||||
assert T == self.cfg.T and Cp == self.cfg.n_cont_pkt
|
||||
z = packets_cont.new_zeros((B, T + 1, self.token_dim))
|
||||
z[:, 0, 0] = -1.0
|
||||
z[:, 0, 1:1 + self.cfg.flow_dim] = flow
|
||||
z[:, 1:, 0] = 1.0
|
||||
z[:, 1:, 1:1 + self.cfg.n_cont_pkt] = packets_cont
|
||||
z[:, 1:, 1 + self.cfg.n_cont_pkt:1 + self.cfg.n_cont_pkt + self.cfg.n_disc_pkt] = self._embed_disc(x_disc_t_int)
|
||||
"""Assemble [B, seq_len, token_dim].
|
||||
|
||||
packets_cont: [B, eff_T, eff_n_cont] (may be empty in last dim)
|
||||
x_disc_t_int: [B, eff_T, eff_n_disc] integer ids in [0, n_disc_classes-1]
|
||||
"""
|
||||
B = flow.shape[0]
|
||||
device = flow.device
|
||||
T = self.eff_T
|
||||
z = flow.new_zeros((B, self.seq_len, self.token_dim))
|
||||
cur = 0
|
||||
if self.cfg.use_flow_token:
|
||||
z[:, 0, 0] = -1.0 # type flag
|
||||
z[:, 0, 1:1 + self._flow_dim()] = flow
|
||||
cur = 1
|
||||
if T > 0:
|
||||
z[:, cur:cur + T, 0] = 1.0 # type flag
|
||||
base = 1
|
||||
if self.eff_n_cont > 0:
|
||||
z[:, cur:cur + T, base:base + self.eff_n_cont] = packets_cont
|
||||
base += self.eff_n_cont
|
||||
if self.eff_n_disc > 0:
|
||||
z[:, cur:cur + T, base:base + self.eff_n_disc] = self._embed_disc(x_disc_t_int)
|
||||
return z
|
||||
|
||||
def key_padding_mask(self, lens: torch.Tensor) -> torch.Tensor:
|
||||
B = lens.shape[0]
|
||||
idx = torch.arange(self.cfg.T, device=lens.device)[None, :]
|
||||
packet_real = idx < lens[:, None]
|
||||
real = torch.cat([torch.ones(B, 1, dtype=torch.bool, device=lens.device), packet_real], dim=1)
|
||||
device = lens.device
|
||||
T = self.eff_T
|
||||
pieces = []
|
||||
if self.cfg.use_flow_token:
|
||||
pieces.append(torch.ones(B, 1, dtype=torch.bool, device=device))
|
||||
if T > 0:
|
||||
idx = torch.arange(T, device=device)[None, :]
|
||||
pieces.append(idx < lens[:, None])
|
||||
real = torch.cat(pieces, dim=1) if pieces else torch.ones(B, 0, dtype=torch.bool, device=device)
|
||||
return ~real
|
||||
|
||||
def _loss_mask(self, lens: torch.Tensor) -> torch.Tensor:
|
||||
return (~self.key_padding_mask(lens)).float()
|
||||
|
||||
def compute_loss(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, *, return_components: bool=False) -> torch.Tensor | dict[str, torch.Tensor]:
|
||||
(B, T, _) = packets_cont.shape
|
||||
device = packets_cont.device
|
||||
# ------------------------------------------------------------------ #
|
||||
# B4 helper: quantize cont -> integer bins #
|
||||
# ------------------------------------------------------------------ #
|
||||
def quantize_cont(self, packets_cont: torch.Tensor, bin_edges: torch.Tensor) -> torch.Tensor:
|
||||
"""packets_cont [B, T, n_cont_orig] (already z-scored); bin_edges [n_cont_orig, n_classes-1]
|
||||
returns int64 [B, T, n_cont_orig] in [0, n_classes-1]."""
|
||||
B, T, C = packets_cont.shape
|
||||
out = torch.zeros((B, T, C), dtype=torch.long, device=packets_cont.device)
|
||||
for c in range(C):
|
||||
edges = bin_edges[c] # [n_classes-1]
|
||||
# bucketize: returns 0..n for n edges
|
||||
out[:, :, c] = torch.bucketize(packets_cont[:, :, c].contiguous(), edges)
|
||||
out.clamp_(0, self.cfg.n_disc_classes - 1)
|
||||
return out
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Loss #
|
||||
# ------------------------------------------------------------------ #
|
||||
def compute_loss(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, *, return_components: bool=False, cont_bin_edges: torch.Tensor | None=None) -> torch.Tensor | dict[str, torch.Tensor]:
|
||||
cfg = self.cfg
|
||||
B = flow.shape[0]
|
||||
T = self.eff_T
|
||||
device = flow.device
|
||||
|
||||
# Resolve effective cont/disc tensors per ablation mode
|
||||
if cfg.disc_as_cont:
|
||||
# 9 cont = original 3 cont + 6 disc-as-float
|
||||
disc_as_cont_float = self._embed_disc(packets_disc) if T > 0 else None
|
||||
if T > 0:
|
||||
eff_cont = torch.cat([packets_cont, disc_as_cont_float], dim=-1) if cfg.n_cont_pkt > 0 else disc_as_cont_float
|
||||
else:
|
||||
eff_cont = packets_cont.new_zeros((B, 0, 0))
|
||||
eff_disc_int = torch.zeros((B, T, 0), dtype=torch.long, device=device)
|
||||
elif cfg.cont_as_disc:
|
||||
# 0 cont, 9 disc: quantize cont via supplied bin_edges
|
||||
if T > 0:
|
||||
if cont_bin_edges is None:
|
||||
raise ValueError('cont_as_disc requires cont_bin_edges')
|
||||
cont_int = self.quantize_cont(packets_cont, cont_bin_edges)
|
||||
eff_disc_int = torch.cat([cont_int, packets_disc.long()], dim=-1)
|
||||
else:
|
||||
eff_disc_int = torch.zeros((B, 0, self.eff_n_disc), dtype=torch.long, device=device)
|
||||
eff_cont = flow.new_zeros((B, T, 0))
|
||||
else:
|
||||
eff_cont = packets_cont if T > 0 else packets_cont.new_zeros((B, 0, cfg.n_cont_pkt))
|
||||
eff_disc_int = packets_disc.long() if T > 0 else torch.zeros((B, 0, cfg.n_disc_pkt), dtype=torch.long, device=device)
|
||||
|
||||
# Build x_1 (data tokens; mask-pred path uses zero ids for disc at packet positions during CFM regression)
|
||||
zero_disc = torch.zeros_like(eff_disc_int)
|
||||
x_1_cont = self.build_tokens(flow, eff_cont, zero_disc)
|
||||
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
x_1_cont = self.build_tokens(flow, packets_cont, torch.zeros_like(packets_disc))
|
||||
|
||||
x_0_cont = torch.randn_like(x_1_cont)
|
||||
if self.cfg.use_ot:
|
||||
|
||||
if cfg.use_ot:
|
||||
flat0 = (x_0_cont * mask[:, :, None]).reshape(B, -1)
|
||||
flat1 = (x_1_cont * mask[:, :, None]).reshape(B, -1)
|
||||
col = _sinkhorn_coupling(torch.cdist(flat0.float(), flat1.float()))
|
||||
x_1_cont = x_1_cont[col]
|
||||
packets_cont = packets_cont[col]
|
||||
eff_cont = eff_cont[col] if eff_cont.numel() > 0 else eff_cont
|
||||
eff_disc_int = eff_disc_int[col] if eff_disc_int.numel() > 0 else eff_disc_int
|
||||
packets_disc = packets_disc[col]
|
||||
flow = flow[col]
|
||||
lens = lens[col]
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
|
||||
t = torch.rand(B, device=device)
|
||||
x_t_cont = (1.0 - t[:, None, None]) * x_0_cont + t[:, None, None] * x_1_cont
|
||||
if self.cfg.sigma > 0:
|
||||
std = self.cfg.sigma * torch.sqrt(t * (1.0 - t))[:, None, None]
|
||||
if cfg.sigma > 0:
|
||||
std = cfg.sigma * torch.sqrt(t * (1.0 - t))[:, None, None]
|
||||
x_t_cont = x_t_cont + std * torch.randn_like(x_t_cont)
|
||||
target_cont = x_1_cont - x_0_cont
|
||||
u = torch.rand(B, T, self.cfg.n_disc_pkt, device=device)
|
||||
keep = u < t[:, None, None]
|
||||
rand_disc = torch.randint(0, self.cfg.n_disc_classes, packets_disc.shape, device=device)
|
||||
x_disc_t = torch.where(keep, packets_disc, rand_disc)
|
||||
disc_start = 1 + self.cfg.n_cont_pkt
|
||||
x_t_full = x_t_cont.clone()
|
||||
x_t_full[:, 1:, disc_start:disc_start + self.cfg.n_disc_pkt] = self._embed_disc(x_disc_t)
|
||||
|
||||
# Disc corruption schedule (mask-pred): keep fraction t of true labels
|
||||
if T > 0 and self.eff_n_disc > 0:
|
||||
u = torch.rand(B, T, self.eff_n_disc, device=device)
|
||||
keep = u < t[:, None, None]
|
||||
rand_disc = torch.randint(0, cfg.n_disc_classes, eff_disc_int.shape, device=device)
|
||||
x_disc_t = torch.where(keep, eff_disc_int, rand_disc)
|
||||
disc_start = (1 if cfg.use_flow_token else 0) + 0 # placeholder; overwritten below
|
||||
# Where in x_t_full do disc embeds go?
|
||||
# Within each packet token: [type(1) | cont(eff_n_cont) | disc(eff_n_disc) | pad...]
|
||||
disc_start_in_token = 1 + self.eff_n_cont
|
||||
cur_offset = 1 if cfg.use_flow_token else 0
|
||||
x_t_full = x_t_cont.clone()
|
||||
x_t_full[:, cur_offset:cur_offset + T, disc_start_in_token:disc_start_in_token + self.eff_n_disc] = self._embed_disc(x_disc_t)
|
||||
else:
|
||||
x_t_full = x_t_cont
|
||||
x_disc_t = eff_disc_int # unused
|
||||
keep = None
|
||||
|
||||
(v_pred, d_logits) = self.velocity(x_t_full, t, key_padding_mask=kpm)
|
||||
|
||||
# CFM regression loss on cont slots (mask out disc slots)
|
||||
v_err = (v_pred - target_cont).square()
|
||||
v_err[:, :, disc_start:disc_start + self.cfg.n_disc_pkt] = 0.0
|
||||
if T > 0 and self.eff_n_disc > 0:
|
||||
disc_start_in_token = 1 + self.eff_n_cont
|
||||
cur_offset = 1 if cfg.use_flow_token else 0
|
||||
v_err[:, cur_offset:cur_offset + T, disc_start_in_token:disc_start_in_token + self.eff_n_disc] = 0.0
|
||||
v_per_token = v_err.mean(dim=-1)
|
||||
per_sample = (v_per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
L_cont = per_sample.mean()
|
||||
pkt_logits = d_logits[:, 1:]
|
||||
pkt_real = mask[:, 1:].bool()
|
||||
corrupt = ~keep & pkt_real[:, :, None]
|
||||
flat_logits = pkt_logits.reshape(-1, self.cfg.n_disc_classes)
|
||||
flat_targets = packets_disc.reshape(-1).long()
|
||||
flat_ce = F.cross_entropy(flat_logits, flat_targets, reduction='none')
|
||||
flat_ce = flat_ce.view(B, T, self.cfg.n_disc_pkt)
|
||||
flat_ce = flat_ce * corrupt.float()
|
||||
denom = corrupt.float().sum().clamp_min(1.0)
|
||||
L_disc = flat_ce.sum() / denom
|
||||
total = L_cont + self.cfg.lambda_disc * L_disc
|
||||
|
||||
# Mask-pred CE on corrupted disc positions
|
||||
if T > 0 and self.eff_n_disc > 0 and keep is not None:
|
||||
cur_offset = 1 if cfg.use_flow_token else 0
|
||||
pkt_logits = d_logits[:, cur_offset:cur_offset + T]
|
||||
pkt_real = mask[:, cur_offset:cur_offset + T].bool()
|
||||
corrupt = ~keep & pkt_real[:, :, None]
|
||||
flat_logits = pkt_logits.reshape(-1, cfg.n_disc_classes)
|
||||
flat_targets = eff_disc_int.reshape(-1).long()
|
||||
flat_ce = F.cross_entropy(flat_logits, flat_targets, reduction='none')
|
||||
flat_ce = flat_ce.view(B, T, self.eff_n_disc)
|
||||
flat_ce = flat_ce * corrupt.float()
|
||||
denom = corrupt.float().sum().clamp_min(1.0)
|
||||
L_disc = flat_ce.sum() / denom
|
||||
else:
|
||||
L_disc = L_cont.new_zeros(())
|
||||
|
||||
total = L_cont + cfg.lambda_disc * L_disc
|
||||
if return_components:
|
||||
return {'total': total, 'main': L_cont.detach(), 'aux_disc': L_disc.detach(), 'aux_flow': L_cont.new_zeros(()), 'aux_packet': L_cont.new_zeros(())}
|
||||
return {'total': total, 'main': L_cont.detach(), 'aux_disc': L_disc.detach(),
|
||||
'aux_flow': L_cont.new_zeros(()), 'aux_packet': L_cont.new_zeros(())}
|
||||
return total
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Scoring #
|
||||
# ------------------------------------------------------------------ #
|
||||
@torch.no_grad()
|
||||
def trajectory_metrics(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, n_steps: int=16) -> dict[str, torch.Tensor]:
|
||||
z = self.build_tokens(flow, packets_cont, packets_disc)
|
||||
def trajectory_metrics(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, n_steps: int=16, cont_bin_edges: torch.Tensor | None=None) -> dict[str, torch.Tensor]:
|
||||
cfg = self.cfg
|
||||
B = flow.shape[0]
|
||||
T = self.eff_T
|
||||
|
||||
# Build effective cont / disc tensors per ablation mode
|
||||
if cfg.disc_as_cont:
|
||||
disc_float = self._embed_disc(packets_disc) if T > 0 else None
|
||||
if T > 0:
|
||||
eff_cont = torch.cat([packets_cont, disc_float], dim=-1) if cfg.n_cont_pkt > 0 else disc_float
|
||||
else:
|
||||
eff_cont = packets_cont.new_zeros((B, 0, 0))
|
||||
eff_disc_int = torch.zeros((B, T, 0), dtype=torch.long, device=flow.device)
|
||||
elif cfg.cont_as_disc:
|
||||
if T > 0:
|
||||
if cont_bin_edges is None:
|
||||
raise ValueError('cont_as_disc requires cont_bin_edges at scoring time')
|
||||
cont_int = self.quantize_cont(packets_cont, cont_bin_edges)
|
||||
eff_disc_int = torch.cat([cont_int, packets_disc.long()], dim=-1)
|
||||
else:
|
||||
eff_disc_int = torch.zeros((B, 0, 0), dtype=torch.long, device=flow.device)
|
||||
eff_cont = flow.new_zeros((B, T, 0))
|
||||
else:
|
||||
eff_cont = packets_cont if T > 0 else packets_cont.new_zeros((B, 0, cfg.n_cont_pkt))
|
||||
eff_disc_int = packets_disc.long() if T > 0 else torch.zeros((B, 0, cfg.n_disc_pkt), dtype=torch.long, device=flow.device)
|
||||
|
||||
z = self.build_tokens(flow, eff_cont, eff_disc_int)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
B = z.shape[0]
|
||||
dt = 1.0 / n_steps
|
||||
disc_start = 1 + self.cfg.n_cont_pkt
|
||||
disc_end = disc_start + self.cfg.n_disc_pkt
|
||||
disc_embed = z[:, 1:, disc_start:disc_end].clone()
|
||||
|
||||
# Disc embed slot bounds (within token vector) for "freeze disc during ODE"
|
||||
cur_offset = 1 if cfg.use_flow_token else 0
|
||||
disc_start_in_token = 1 + self.eff_n_cont
|
||||
disc_end_in_token = disc_start_in_token + self.eff_n_disc
|
||||
if self.eff_n_disc > 0 and T > 0:
|
||||
disc_embed = z[:, cur_offset:cur_offset + T, disc_start_in_token:disc_end_in_token].clone()
|
||||
else:
|
||||
disc_embed = None
|
||||
|
||||
for k in range(n_steps):
|
||||
t_val = 1.0 - k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
(v, _) = self.velocity(z, t, key_padding_mask=kpm)
|
||||
v[:, :, disc_start:disc_end] = 0.0
|
||||
if self.eff_n_disc > 0 and T > 0:
|
||||
v[:, cur_offset:cur_offset + T, disc_start_in_token:disc_end_in_token] = 0.0
|
||||
z = z - v * dt
|
||||
z[:, 1:, disc_start:disc_end] = disc_embed
|
||||
if disc_embed is not None:
|
||||
z[:, cur_offset:cur_offset + T, disc_start_in_token:disc_end_in_token] = disc_embed
|
||||
|
||||
# Compute terminal-norm scores. Zero out the discrete embed slots so they don't pollute.
|
||||
z_real = z * mask[:, :, None]
|
||||
z_cont = z_real.clone()
|
||||
z_cont[:, 1:, disc_start:disc_end] = 0.0
|
||||
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
|
||||
terminal = z_cont.reshape(B, -1).norm(dim=-1) / (mask.sum(dim=-1) * self.token_dim).clamp_min(1.0).sqrt()
|
||||
terminal_flow = z_cont[:, 0].norm(dim=-1) / math.sqrt(self.token_dim)
|
||||
terminal_packet = (z_cont[:, 1:] * mask[:, 1:, None]).reshape(B, -1).norm(dim=-1) / (packet_count * self.token_dim).sqrt()
|
||||
return {'terminal_norm': terminal, 'terminal_flow': terminal_flow, 'terminal_packet': terminal_packet}
|
||||
if self.eff_n_disc > 0 and T > 0:
|
||||
z_cont[:, cur_offset:cur_offset + T, disc_start_in_token:disc_end_in_token] = 0.0
|
||||
|
||||
full_norm = z_cont.reshape(B, -1).norm(dim=-1) / (mask.sum(dim=-1) * self.token_dim).clamp_min(1.0).sqrt()
|
||||
out = {'terminal_norm': full_norm}
|
||||
if cfg.use_flow_token:
|
||||
out['terminal_flow'] = z_cont[:, 0].norm(dim=-1) / math.sqrt(self.token_dim)
|
||||
if T > 0:
|
||||
packet_count = mask[:, cur_offset:cur_offset + T].sum(dim=-1).clamp_min(1.0)
|
||||
out['terminal_packet'] = (z_cont[:, cur_offset:cur_offset + T] * mask[:, cur_offset:cur_offset + T, None]).reshape(B, -1).norm(dim=-1) / (packet_count * self.token_dim).sqrt()
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def disc_nll_score(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5) -> dict[str, torch.Tensor]:
|
||||
(B, T, _) = packets_cont.shape
|
||||
device = packets_cont.device
|
||||
def disc_nll_score(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5, cont_bin_edges: torch.Tensor | None=None) -> dict[str, torch.Tensor]:
|
||||
cfg = self.cfg
|
||||
B = flow.shape[0]
|
||||
T = self.eff_T
|
||||
device = flow.device
|
||||
if T == 0 or self.eff_n_disc == 0:
|
||||
return {} # no disc head to score
|
||||
|
||||
# Build effective disc int per mode
|
||||
if cfg.cont_as_disc:
|
||||
if cont_bin_edges is None:
|
||||
raise ValueError('cont_as_disc requires cont_bin_edges at scoring time')
|
||||
cont_int = self.quantize_cont(packets_cont, cont_bin_edges)
|
||||
eff_disc_int = torch.cat([cont_int, packets_disc.long()], dim=-1)
|
||||
eff_cont = flow.new_zeros((B, T, 0))
|
||||
ch_idx_list = list(cfg.cont_pkt_idx) + list(cfg.disc_pkt_idx)
|
||||
else:
|
||||
eff_disc_int = packets_disc.long()
|
||||
eff_cont = packets_cont
|
||||
ch_idx_list = list(cfg.disc_pkt_idx)
|
||||
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
z = self.build_tokens(flow, packets_cont, packets_disc)
|
||||
z = self.build_tokens(flow, eff_cont, eff_disc_int)
|
||||
t = torch.full((B,), float(t_eval), device=device)
|
||||
(_, d_logits) = self.velocity(z, t, key_padding_mask=kpm)
|
||||
pkt_logits = d_logits[:, 1:]
|
||||
flat_logits = pkt_logits.reshape(-1, self.cfg.n_disc_classes)
|
||||
flat_targets = packets_disc.reshape(-1).long()
|
||||
cur_offset = 1 if cfg.use_flow_token else 0
|
||||
pkt_logits = d_logits[:, cur_offset:cur_offset + T]
|
||||
flat_logits = pkt_logits.reshape(-1, cfg.n_disc_classes)
|
||||
flat_targets = eff_disc_int.reshape(-1).long()
|
||||
ce = F.cross_entropy(flat_logits, flat_targets, reduction='none')
|
||||
ce = ce.view(B, T, self.cfg.n_disc_pkt)
|
||||
pkt_real = mask[:, 1:].bool().float()
|
||||
ce = ce.view(B, T, self.eff_n_disc)
|
||||
pkt_real = mask[:, cur_offset:cur_offset + T].bool().float()
|
||||
per_sample = (ce.sum(dim=-1) * pkt_real).sum(dim=-1) / pkt_real.sum(dim=-1).clamp_min(1.0)
|
||||
per_ch = (ce * pkt_real[:, :, None]).sum(dim=1) / pkt_real.sum(dim=1).clamp_min(1.0)[:, None]
|
||||
out = {'disc_nll_total': per_sample}
|
||||
for (c, idx) in enumerate(self.cfg.disc_pkt_idx):
|
||||
for c, idx in enumerate(ch_idx_list):
|
||||
out[f'disc_nll_ch{idx}'] = per_ch[:, c]
|
||||
return out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user