from __future__ import annotations import math from dataclasses import dataclass import torch import torch.nn as nn from torchdiffeq import odeint @torch.no_grad() def _sinkhorn_coupling(C: torch.Tensor, reg: float=0.05, n_iter: int=20) -> torch.Tensor: C = C.float() log_k = -C / reg B = C.shape[0] log_u = torch.zeros(B, device=C.device) log_v = torch.zeros(B, device=C.device) for _ in range(n_iter): log_v = -torch.logsumexp(log_k + log_u.unsqueeze(1), dim=0) log_u = -torch.logsumexp(log_k + log_v.unsqueeze(0), dim=1) log_p = log_u.unsqueeze(1) + log_k + log_v.unsqueeze(0) return log_p.argmax(dim=1) class SinusoidalTimeEmb(nn.Module): def __init__(self, dim: int) -> None: super().__init__() if dim % 2 != 0: raise ValueError('time embedding dimension must be even') self.dim = dim def forward(self, t: torch.Tensor) -> torch.Tensor: half = self.dim // 2 freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device, dtype=t.dtype) / max(half - 1, 1)) args = t[:, None] * freqs[None, :] return torch.cat([args.sin(), args.cos()], dim=-1) class AdaLNBlock(nn.Module): def __init__(self, d_model: int, n_heads: int, mlp_ratio: float, cond_dim: int) -> None: super().__init__() self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False) self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True) self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False) hidden = int(d_model * mlp_ratio) self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model)) self.cond_proj = nn.Linear(cond_dim, 6 * d_model) nn.init.zeros_(self.cond_proj.weight) nn.init.zeros_(self.cond_proj.bias) @staticmethod def _modulate(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor: return x * (1.0 + gamma[:, None, :]) + beta[:, None, :] def forward(self, x: torch.Tensor, cond: torch.Tensor, key_padding_mask: torch.Tensor | None, attn_mask: torch.Tensor | None=None) -> torch.Tensor: (g1, b1, a1, g2, b2, a2) = self.cond_proj(cond).chunk(6, dim=-1) h = self._modulate(self.norm1(x), g1, b1) (attn_out, _) = self.attn(h, h, h, key_padding_mask=key_padding_mask, attn_mask=attn_mask, need_weights=False) x = x + a1[:, None, :] * attn_out h = self._modulate(self.norm2(x), g2, b2) return x + a2[:, None, :] * self.mlp(h) class UnifiedVelocity(nn.Module): def __init__(self, token_dim: int, seq_len: int, d_model: int=128, n_layers: int=4, n_heads: int=4, mlp_ratio: float=4.0, time_dim: int=64, reference_mode: str | None=None) -> None: super().__init__() if reference_mode not in (None, 'independent_token', 'block_diagonal', 'causal_packets', 'causal_all'): raise ValueError(f'unknown reference_mode={reference_mode!r}') self.token_dim = token_dim self.seq_len = seq_len self.reference_mode = reference_mode self.input_proj = nn.Linear(token_dim, d_model) self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model)) self.type_emb = nn.Embedding(2, d_model) nn.init.trunc_normal_(self.pos_emb, std=0.02) nn.init.normal_(self.type_emb.weight, std=0.02) self.time_emb = SinusoidalTimeEmb(time_dim) self.cond_mlp = nn.Sequential(nn.Linear(time_dim, d_model), nn.SiLU(), nn.Linear(d_model, d_model)) self.blocks = nn.ModuleList([AdaLNBlock(d_model, n_heads, mlp_ratio, cond_dim=d_model) for _ in range(n_layers)]) self.out_norm = nn.LayerNorm(d_model, elementwise_affine=False) self.out = nn.Linear(d_model, token_dim) nn.init.zeros_(self.out.weight) nn.init.zeros_(self.out.bias) type_ids = torch.ones(seq_len, dtype=torch.long) type_ids[0] = 0 self.register_buffer('type_ids', type_ids, persistent=False) def forward(self, x: torch.Tensor, t: torch.Tensor, key_padding_mask: torch.Tensor | None=None, attn_mask_override: torch.Tensor | None=None) -> torch.Tensor: (B, L, _) = x.shape if L > self.seq_len: raise ValueError(f'sequence length {L} exceeds configured {self.seq_len}') if t.dim() == 0: t = t.expand(B) h = self.input_proj(x) h = h + self.pos_emb[:, :L, :] h = h + self.type_emb(self.type_ids[:L])[None, :, :] cond = self.cond_mlp(self.time_emb(t)) if attn_mask_override is not None: attn_mask = attn_mask_override else: attn_mask = self._reference_attn_mask(L, x.device) for block in self.blocks: h = block(h, cond, key_padding_mask, attn_mask=attn_mask) return self.out(self.out_norm(h)) def _reference_attn_mask(self, L: int, device: torch.device) -> torch.Tensor | None: if self.reference_mode is None: return None if self.reference_mode == 'independent_token': return ~torch.eye(L, dtype=torch.bool, device=device) if self.reference_mode == 'block_diagonal': mask = torch.ones((L, L), dtype=torch.bool, device=device) mask[0, 0] = False if L > 1: mask[1:, 1:] = False return mask if self.reference_mode == 'causal_packets': mask = torch.zeros((L, L), dtype=torch.bool, device=device) if L > 1: packet_causal = torch.triu(torch.ones(L - 1, L - 1, dtype=torch.bool, device=device), diagonal=1) mask[1:, 1:] = packet_causal return mask if self.reference_mode == 'causal_all': return torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), diagonal=1) raise AssertionError(self.reference_mode) @dataclass class UnifiedCFMConfig: T: int = 128 packet_dim: int = 9 flow_dim: int = 16 token_dim: int | None = None d_model: int = 128 n_layers: int = 4 n_heads: int = 4 mlp_ratio: float = 4.0 time_dim: int = 64 sigma: float = 0.1 use_ot: bool = False reference_mode: str | None = None class UnifiedTokenCFM(nn.Module): def __init__(self, cfg: UnifiedCFMConfig) -> None: super().__init__() self.cfg = cfg self.token_dim = cfg.token_dim or 1 + max(cfg.flow_dim, cfg.packet_dim) if self.token_dim < 1 + max(cfg.flow_dim, cfg.packet_dim): raise ValueError('token_dim is too small for flow_dim/packet_dim') self.seq_len = cfg.T + 1 self.velocity = UnifiedVelocity(token_dim=self.token_dim, seq_len=self.seq_len, d_model=cfg.d_model, n_layers=cfg.n_layers, n_heads=cfg.n_heads, mlp_ratio=cfg.mlp_ratio, time_dim=cfg.time_dim, reference_mode=cfg.reference_mode) def build_tokens(self, flow: torch.Tensor, packets: torch.Tensor) -> torch.Tensor: (B, T, Dp) = packets.shape if T != self.cfg.T: raise ValueError(f'packet T={T} but config T={self.cfg.T}') if Dp != self.cfg.packet_dim: raise ValueError(f'packet_dim={Dp} but config packet_dim={self.cfg.packet_dim}') if flow.shape[-1] != self.cfg.flow_dim: raise ValueError(f'flow_dim={flow.shape[-1]} but config flow_dim={self.cfg.flow_dim}') z = packets.new_zeros((B, T + 1, self.token_dim)) z[:, 0, 0] = -1.0 z[:, 0, 1:1 + self.cfg.flow_dim] = flow z[:, 1:, 0] = 1.0 z[:, 1:, 1:1 + self.cfg.packet_dim] = packets return z def key_padding_mask(self, lens: torch.Tensor) -> torch.Tensor: B = lens.shape[0] idx = torch.arange(self.cfg.T, device=lens.device)[None, :] packet_real = idx < lens[:, None] real = torch.cat([torch.ones(B, 1, dtype=torch.bool, device=lens.device), packet_real], dim=1) return ~real def _loss_mask(self, lens: torch.Tensor) -> torch.Tensor: return (~self.key_padding_mask(lens)).float() @staticmethod def _masked_trimmed_mean(values: torch.Tensor, mask: torch.Tensor, trim_frac: float=0.1) -> torch.Tensor: out = values.new_zeros(values.shape[0]) for i in range(values.shape[0]): v = values[i][mask[i] > 0] if v.numel() == 0: continue if v.numel() < 5: out[i] = v.mean() continue v_sorted = torch.sort(v).values lo = int(trim_frac * v_sorted.numel()) hi = int((1.0 - trim_frac) * v_sorted.numel()) if hi <= lo: out[i] = v_sorted.mean() else: out[i] = v_sorted[lo:hi].mean() return out @staticmethod def _masked_median(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: out = values.new_zeros(values.shape[0]) for i in range(values.shape[0]): v = values[i][mask[i] > 0] if v.numel() == 0: continue v_sorted = torch.sort(v).values mid = v_sorted.numel() // 2 if v_sorted.numel() % 2: out[i] = v_sorted[mid] else: out[i] = 0.5 * (v_sorted[mid - 1] + v_sorted[mid]) return out def compute_loss(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, *, lambda_flow: float=0.0, lambda_packet: float=0.0, packet_mask_ratio: float=0.5, return_components: bool=False) -> torch.Tensor | dict[str, torch.Tensor]: x1 = self.build_tokens(flow, packets) B = x1.shape[0] x0 = torch.randn_like(x1) mask = self._loss_mask(lens) kpm = mask == 0 if self.cfg.use_ot: flat0 = (x0 * mask[:, :, None]).reshape(B, -1) flat1 = (x1 * mask[:, :, None]).reshape(B, -1) col = _sinkhorn_coupling(torch.cdist(flat0.float(), flat1.float())) x1 = x1[col] flow = flow[col] packets = packets[col] lens = lens[col] mask = self._loss_mask(lens) kpm = mask == 0 t = torch.rand(B, device=x1.device) x_t = (1.0 - t[:, None, None]) * x0 + t[:, None, None] * x1 if self.cfg.sigma > 0: std = self.cfg.sigma * torch.sqrt(t * (1.0 - t))[:, None, None] x_t = x_t + std * torch.randn_like(x_t) target = x1 - x0 pred = self.velocity(x_t, t, key_padding_mask=kpm) sq = (pred - target).square().mean(dim=-1) per_sample = (sq * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) main_loss = per_sample.mean() aux_flow_loss = x1.new_zeros(()) aux_packet_loss = x1.new_zeros(()) if lambda_flow > 0.0: x_t_mf = x_t.clone() x_t_mf[:, 0, :] = 0.0 pred_mf = self.velocity(x_t_mf, t, key_padding_mask=kpm) err = (pred_mf[:, 0] - target[:, 0]).square().mean(dim=-1) aux_flow_loss = err.mean() if lambda_packet > 0.0: packet_real = mask[:, 1:] > 0 rand_draw = torch.rand(packet_real.shape, device=x1.device) mask_pkt = (rand_draw < packet_mask_ratio) & packet_real pkt_mask_full = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x1.device), mask_pkt], dim=1) x_t_mp = x_t.clone() x_t_mp[pkt_mask_full] = 0.0 pred_mp = self.velocity(x_t_mp, t, key_padding_mask=kpm) sq_mp = (pred_mp - target).square().mean(dim=-1) mask_f = pkt_mask_full.float() denom = mask_f.sum(dim=-1).clamp_min(1.0) aux_packet_loss = ((sq_mp * mask_f).sum(dim=-1) / denom).mean() total = main_loss + lambda_flow * aux_flow_loss + lambda_packet * aux_packet_loss if return_components: return {'total': total, 'main': main_loss.detach(), 'aux_flow': aux_flow_loss.detach(), 'aux_packet': aux_packet_loss.detach()} return total @torch.no_grad() def velocity_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.5, 0.75, 1.0)) -> dict[str, torch.Tensor]: x = self.build_tokens(flow, packets) mask = self._loss_mask(lens) kpm = mask == 0 total = torch.zeros(x.shape[0], device=x.device) flow_s = torch.zeros_like(total) packet_s = torch.zeros_like(total) packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0) for t_val in t_eval: t = torch.full((x.shape[0],), float(t_val), device=x.device) v = self.velocity(x, t, key_padding_mask=kpm) e = v.square().mean(dim=-1) total = total + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) flow_s = flow_s + e[:, 0] packet_s = packet_s + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count denom = float(len(t_eval)) return {'velocity_total': total / denom, 'velocity_flow': flow_s / denom, 'velocity_packet': packet_s / denom} @torch.no_grad() def trajectory_metrics(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16) -> dict[str, torch.Tensor]: z = self.build_tokens(flow, packets) mask = self._loss_mask(lens) kpm = mask == 0 B = z.shape[0] dt = 1.0 / n_steps total_arc = torch.zeros(B, device=z.device) total_ke = torch.zeros(B, device=z.device) flow_ke = torch.zeros(B, device=z.device) packet_ke = torch.zeros(B, device=z.device) total_curv = torch.zeros(B, device=z.device) flow_curv = torch.zeros(B, device=z.device) packet_curv = torch.zeros(B, device=z.device) packet_kappa2_speed2 = torch.zeros(B, max(z.shape[1] - 1, 0), device=z.device) packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0) v_prev = None v_prev_norm = None for k in range(n_steps): t_val = 1.0 - k * dt t = torch.full((B,), t_val, device=z.device) v = self.velocity(z, t, key_padding_mask=kpm) e = v.square().mean(dim=-1) v_norm = v.square().sum(dim=-1).clamp_min(1e-12).sqrt() total_ke = total_ke + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) * dt flow_ke = flow_ke + e[:, 0] * dt packet_ke = packet_ke + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count * dt if v_prev is not None: dv = v - v_prev dve = dv.square().mean(dim=-1) total_curv = total_curv + (dve * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) flow_curv = flow_curv + dve[:, 0] packet_curv = packet_curv + (dve[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count dv2_sum = dv[:, 1:].square().sum(dim=-1) assert v_prev_norm is not None v_avg = 0.5 * (v_norm[:, 1:] + v_prev_norm[:, 1:]) packet_kappa2_speed2 = packet_kappa2_speed2 + dv2_sum / v_avg.square().clamp_min(1e-06) v_prev = v v_prev_norm = v_norm z_new = z - v * dt dz = (z_new - z) * mask[:, :, None] total_arc = total_arc + dz.reshape(B, -1).norm(dim=-1) / mask.sum(dim=-1).sqrt() z = z_new z_masked = z * mask[:, :, None] terminal = z_masked.reshape(B, -1).norm(dim=-1) / (mask.sum(dim=-1) * self.token_dim).clamp_min(1.0).sqrt() terminal_flow = z[:, 0].norm(dim=-1) / math.sqrt(self.token_dim) terminal_packet = (z[:, 1:] * mask[:, 1:, None]).reshape(B, -1).norm(dim=-1) / (packet_count * self.token_dim).sqrt() packet_mask = mask[:, 1:] kappa2_speed2_mean = (packet_kappa2_speed2 * packet_mask).sum(dim=-1) / packet_count kappa2_speed2_median = self._masked_median(packet_kappa2_speed2, packet_mask) kappa2_speed2_trimmed = self._masked_trimmed_mean(packet_kappa2_speed2, packet_mask) return {'terminal_norm': terminal, 'terminal_flow': terminal_flow, 'terminal_packet': terminal_packet, 'arc_length': total_arc, 'kinetic_energy': total_ke, 'kinetic_flow': flow_ke, 'kinetic_packet': packet_ke, 'curvature_total': total_curv, 'curvature_flow': flow_curv, 'curvature_packet': packet_curv, 'kappa2_speed2norm_packet_mean': kappa2_speed2_mean, 'kappa2_speed2norm_packet_median': kappa2_speed2_median, 'kappa2_speed2norm_packet_trimmed10_mean': kappa2_speed2_trimmed} @torch.no_grad() def score_profile_vt(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.1, 0.3, 0.5, 0.7, 0.9, 1.0)) -> dict[str, torch.Tensor]: x = self.build_tokens(flow, packets) mask = self._loss_mask(lens) kpm = mask == 0 packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0) out: dict[str, torch.Tensor] = {} for t_val in t_eval: t = torch.full((x.shape[0],), float(t_val), device=x.device) v = self.velocity(x, t, key_padding_mask=kpm) e = v.square().mean(dim=-1) tag = f't{int(round(t_val * 10)):02d}' out[f'velocity_total_{tag}'] = (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) out[f'velocity_flow_{tag}'] = e[:, 0] out[f'velocity_packet_{tag}'] = (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count return out @torch.no_grad() def consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5) -> dict[str, torch.Tensor]: x = self.build_tokens(flow, packets) mask = self._loss_mask(lens) kpm = mask == 0 B = x.shape[0] packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0) t = torch.full((B,), float(t_eval), device=x.device) v_full = self.velocity(x, t, key_padding_mask=kpm) x_mf = x.clone() x_mf[:, 0, :] = 0.0 v_mf = self.velocity(x_mf, t, key_padding_mask=kpm) flow_cons = (v_full[:, 0] - v_mf[:, 0]).square().mean(dim=-1) x_mp = x.clone() pkt_mask_full = mask[:, 1:] > 0 idx_pkt_mask = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x.device), pkt_mask_full], dim=1) x_mp[idx_pkt_mask] = 0.0 v_mp = self.velocity(x_mp, t, key_padding_mask=kpm) diff = (v_full - v_mp).square().mean(dim=-1) packet_cons = (diff[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count return {'flow_consistency': flow_cons, 'packet_consistency': packet_cons, 'consistency_total': flow_cons + packet_cons} def jacobian_hutchinson(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.5,), n_eps: int=4, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]: x = self.build_tokens(flow, packets) mask = self._loss_mask(lens) kpm = mask == 0 B = x.shape[0] packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0) total = torch.zeros(B, device=x.device) flow_j = torch.zeros(B, device=x.device) packet_j = torch.zeros(B, device=x.device) n_draws = n_eps * len(t_eval) for t_val in t_eval: t_current = torch.full((B,), float(t_val), device=x.device) for _ in range(n_eps): x_req = x.detach().clone().requires_grad_(True) v = self.velocity(x_req, t_current, key_padding_mask=kpm) eps = torch.randn(v.shape, device=v.device, generator=generator) (g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False) e = g.square().mean(dim=-1) total = total + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) flow_j = flow_j + e[:, 0] packet_j = packet_j + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count return {'jacobian_total': (total / n_draws).detach(), 'jacobian_flow': (flow_j / n_draws).detach(), 'jacobian_packet': (packet_j / n_draws).detach()} @torch.no_grad() def pna_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16, flow_masked: bool=False) -> dict[str, torch.Tensor]: eps_v2 = 1e-06 dt = 1.0 / n_steps z = self.build_tokens(flow, packets) if flow_masked: z = z.clone() z[:, 0, :] = 0.0 mask = self._loss_mask(lens) kpm = mask == 0 (B, L, _) = z.shape pna = torch.zeros(B, L, device=z.device) v_prev: torch.Tensor | None = None v_norm_prev: torch.Tensor | None = None for k in range(n_steps): t_val = 1.0 - k * dt t = torch.full((B,), t_val, device=z.device) v = self.velocity(z, t, key_padding_mask=kpm) v_norm = (v.square().sum(dim=-1) + 1e-12).sqrt() if v_prev is not None: dv2 = (v - v_prev).square().sum(dim=-1) v_avg2 = (0.5 * (v_norm + v_norm_prev)).square().clamp_min(eps_v2) pna = pna + dv2 / v_avg2 v_prev = v v_norm_prev = v_norm z = z - v * dt if flow_masked: z[:, 0, :] = 0.0 flow_pna = pna[:, 0] packet_pna = pna[:, 1:] packet_mask = mask[:, 1:] packet_count = packet_mask.sum(dim=-1).clamp_min(1.0) pna_median = self._masked_median(packet_pna, packet_mask) pna_mean = (packet_pna * packet_mask).sum(dim=-1) / packet_count masked_for_max = packet_pna.masked_fill(packet_mask == 0, float('-inf')) pna_max = masked_for_max.max(dim=-1).values pna_trimmed = self._masked_trimmed_mean(packet_pna, packet_mask) return {'pna_packet_median': pna_median, 'pna_packet_mean': pna_mean, 'pna_packet_max': pna_max, 'pna_packet_trimmed10_mean': pna_trimmed, 'pna_flow': flow_pna} @torch.no_grad() def causal_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5) -> dict[str, torch.Tensor]: x = self.build_tokens(flow, packets) mask = self._loss_mask(lens) kpm = mask == 0 (B, L, _) = x.shape t = torch.full((B,), float(t_eval), device=x.device) v_full = self.velocity(x, t, key_padding_mask=kpm) causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), diagonal=1) v_causal = self.velocity(x, t, key_padding_mask=kpm, attn_mask_override=causal) diff = (v_full - v_causal).square().mean(dim=-1) flow_surprisal = diff[:, 0] packet_diff = diff[:, 1:] packet_mask = mask[:, 1:] packet_count = packet_mask.sum(dim=-1).clamp_min(1.0) packet_mean = (packet_diff * packet_mask).sum(dim=-1) / packet_count packet_median = self._masked_median(packet_diff, packet_mask) masked_for_max = packet_diff.masked_fill(packet_mask == 0, float('-inf')) packet_max = masked_for_max.max(dim=-1).values packet_trimmed = self._masked_trimmed_mean(packet_diff, packet_mask) total = (diff * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) return {'causal_surprisal_total': total, 'causal_surprisal_flow': flow_surprisal, 'causal_surprisal_packet_mean': packet_mean, 'causal_surprisal_packet_median': packet_median, 'causal_surprisal_packet_max': packet_max, 'causal_surprisal_packet_trimmed10_mean': packet_trimmed} @torch.no_grad() def direction_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.2, 0.4, 0.6, 0.8, 1.0)) -> dict[str, torch.Tensor]: x = self.build_tokens(flow, packets) mask = self._loss_mask(lens) kpm = mask == 0 (B, L, _) = x.shape t_eval = tuple(t_eval) if len(t_eval) < 2: raise ValueError('direction_consistency_score needs >=2 t values') prev_v: torch.Tensor | None = None drift = x.new_zeros(B, L) n_pairs = len(t_eval) - 1 for t_val in t_eval: t = torch.full((B,), float(t_val), device=x.device) v = self.velocity(x, t, key_padding_mask=kpm) if prev_v is not None: num = (prev_v * v).sum(dim=-1) denom = prev_v.norm(dim=-1).clamp_min(1e-08) * v.norm(dim=-1).clamp_min(1e-08) cos = num / denom drift = drift + (1.0 - cos) prev_v = v drift = drift / max(n_pairs, 1) flow_drift = drift[:, 0] packet_drift = drift[:, 1:] packet_mask = mask[:, 1:] packet_count = packet_mask.sum(dim=-1).clamp_min(1.0) packet_mean = (packet_drift * packet_mask).sum(dim=-1) / packet_count packet_median = self._masked_median(packet_drift, packet_mask) masked_for_max = packet_drift.masked_fill(packet_mask == 0, float('-inf')) packet_max = masked_for_max.max(dim=-1).values packet_trimmed = self._masked_trimmed_mean(packet_drift, packet_mask) total = (drift * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) return {'direction_drift_total': total, 'direction_drift_flow': flow_drift, 'direction_drift_packet_mean': packet_mean, 'direction_drift_packet_median': packet_median, 'direction_drift_packet_max': packet_max, 'direction_drift_packet_trimmed10_mean': packet_trimmed} def inverse_flow_nll_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16, n_eps: int=4, compute_divergence: bool=True, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]: z = self.build_tokens(flow, packets) mask = self._loss_mask(lens) kpm = mask == 0 (B, L, D) = z.shape dt = 1.0 / n_steps accum_div = torch.zeros(B, device=z.device) if compute_divergence: for k in range(n_steps): t_val = 1.0 - k * dt t = torch.full((B,), t_val, device=z.device) z_req = z.detach().clone().requires_grad_(True) v = self.velocity(z_req, t, key_padding_mask=kpm) div_step = torch.zeros(B, device=z.device) for j in range(n_eps): eps = torch.randn_like(v) eps_masked = eps * mask[:, :, None] retain = j < n_eps - 1 (g,) = torch.autograd.grad(outputs=v, inputs=z_req, grad_outputs=eps_masked, retain_graph=retain, create_graph=False) div_step = div_step + (eps_masked * g).sum(dim=(1, 2)) div_step = div_step / float(n_eps) accum_div = accum_div + div_step * dt with torch.no_grad(): z = (z_req - v * dt).detach() else: with torch.no_grad(): for k in range(n_steps): t_val = 1.0 - k * dt t = torch.full((B,), t_val, device=z.device) v = self.velocity(z, t, key_padding_mask=kpm) z = z - v * dt with torch.no_grad(): z_masked = z * mask[:, :, None] n_real = mask.sum(dim=-1).clamp_min(1.0) x0_quadratic = z_masked.reshape(B, -1).square().sum(dim=-1) / (n_real * float(D)) nll_x0_only = x0_quadratic nll_div_only = accum_div / (n_real * float(D)) nll_full = nll_x0_only + nll_div_only return {'nll_x0_only': nll_x0_only.detach(), 'nll_div_only': nll_div_only.detach(), 'nll_full': nll_full.detach()} def jacobian_spectral_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5, n_eps: int=4, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]: x = self.build_tokens(flow, packets) mask = self._loss_mask(lens) kpm = mask == 0 (B, L, D) = x.shape t = torch.full((B,), float(t_eval), device=x.device) packet_mask = mask[:, 1:] packet_count = packet_mask.sum(dim=-1).clamp_min(1.0) norms_total: list[torch.Tensor] = [] norms_flow: list[torch.Tensor] = [] norms_packet: list[torch.Tensor] = [] for _ in range(n_eps): x_req = x.detach().clone().requires_grad_(True) v = self.velocity(x_req, t, key_padding_mask=kpm) eps = torch.randn(v.shape, device=v.device, generator=generator) (g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False) e = g.square().mean(dim=-1) n_total = (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) n_flow = e[:, 0] n_packet = (e[:, 1:] * packet_mask).sum(dim=-1) / packet_count norms_total.append(n_total.detach()) norms_flow.append(n_flow.detach()) norms_packet.append(n_packet.detach()) def _spectral_summary(samples: list[torch.Tensor]) -> dict[str, torch.Tensor]: stack = torch.stack(samples, dim=1) mean = stack.mean(dim=1).clamp_min(1e-12) mx = stack.max(dim=1).values mn = stack.min(dim=1).values logfro = torch.log(mean) aniso = mx / mean min_over_max = mn / mx.clamp_min(1e-12) p = stack / stack.sum(dim=1, keepdim=True).clamp_min(1e-12) entropy = -(p * p.clamp_min(1e-12).log()).sum(dim=1) eff_rank = torch.exp(entropy) return {'logfro': logfro, 'anisotropy': aniso, 'min_over_max': min_over_max, 'eff_rank': eff_rank} out: dict[str, torch.Tensor] = {} for (tag, samples) in (('total', norms_total), ('flow', norms_flow), ('packet', norms_packet)): summ = _spectral_summary(samples) for (stat_name, val) in summ.items(): out[f'jac_{stat_name}_{tag}'] = val return out @torch.no_grad() def sample(self, n: int, lens: torch.Tensor, device: torch.device, n_steps: int=50, method: str='euler') -> torch.Tensor: z = torch.randn(n, self.seq_len, self.token_dim, device=device) ts = torch.linspace(0.0, 1.0, n_steps + 1, device=device) kpm = self.key_padding_mask(lens.to(device)) def f(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor: return self.velocity(x, t.expand(x.shape[0]), key_padding_mask=kpm) if method == 'euler': for i in range(n_steps): z = z + f(ts[i], z) * (ts[i + 1] - ts[i]) return z return odeint(f, z, ts, method=method)[-1] def param_count(self) -> int: return sum((p.numel() for p in self.parameters()))