589 lines
30 KiB
Python
589 lines
30 KiB
Python
from __future__ import annotations
|
|
import math
|
|
from dataclasses import dataclass
|
|
import torch
|
|
import torch.nn as nn
|
|
from torchdiffeq import odeint
|
|
|
|
@torch.no_grad()
def _sinkhorn_coupling(C: torch.Tensor, reg: float = 0.05, n_iter: int = 20) -> torch.Tensor:
    """Turn a cost matrix into a hard row -> column assignment via entropic OT.

    Runs ``n_iter`` log-domain Sinkhorn updates on the kernel ``exp(-C / reg)``
    with unit row and column weights, then returns, for each row ``i``, the
    column holding the largest entry of row ``i`` of the resulting plan.

    The result is an index tensor of length ``C.shape[0]``.  It is NOT
    guaranteed to be a permutation: several rows may select the same column.

    Args:
        C: cost matrix; cast to float32 before use.
        reg: entropic regularisation strength (smaller -> closer to hard OT).
        n_iter: number of Sinkhorn iterations.
    """
    log_kernel = -C.float() / reg
    n_rows = log_kernel.shape[0]
    # Dual potentials in log space; the column potential is recomputed before
    # use on every iteration, and only matters as-is when n_iter == 0.
    log_row = torch.zeros(n_rows, device=log_kernel.device)
    log_col = torch.zeros(n_rows, device=log_kernel.device)
    for _ in range(n_iter):
        log_col = -torch.logsumexp(log_kernel + log_row[:, None], dim=0)
        log_row = -torch.logsumexp(log_kernel + log_col[None, :], dim=1)
    log_plan = log_row[:, None] + log_kernel + log_col[None, :]
    return log_plan.argmax(dim=1)
|
|
|
|
class SinusoidalTimeEmb(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of a time value.

    The output concatenates ``sin(t * f_k)`` and ``cos(t * f_k)`` for ``dim // 2``
    frequencies ``f_k`` spaced geometrically from ``1`` down to ``1 / max_period``.

    Args:
        dim: output embedding width; must be even.
        max_period: ratio between the fastest and slowest frequency.  The
            default (10000) reproduces the original hard-coded behavior.

    Raises:
        ValueError: if ``dim`` is odd or ``max_period`` is not positive.
    """

    def __init__(self, dim: int, max_period: float = 10000.0) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError('time embedding dimension must be even')
        if max_period <= 0:
            raise ValueError('max_period must be positive')
        self.dim = dim
        self.max_period = max_period

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t``.

        ``t`` may have any shape (including 0-dim); the result has shape
        ``(*t.shape, dim)``.  For a 1-D ``t`` of length ``B`` this is ``(B, dim)``.
        """
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / max(half - 1, 1)
        )
        # Broadcasting over a trailing axis supports scalar and batched times alike.
        args = t[..., None] * freqs
        return torch.cat([args.sin(), args.cos()], dim=-1)
|
|
|
|
class AdaLNBlock(nn.Module):
    """Pre-norm transformer block with adaptive-LayerNorm conditioning.

    ``cond`` is projected to six ``d_model``-wide vectors: shift/scale/gate for
    the self-attention sub-layer and shift/scale/gate for the MLP sub-layer.
    The projection is zero-initialised, so every gate starts at zero and the
    block initially returns its input unchanged.

    Args:
        d_model: token width.
        n_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``d_model``.
        cond_dim: width of the conditioning vector.
    """

    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float, cond_dim: int) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))
        self.cond_proj = nn.Linear(cond_dim, 6 * d_model)
        nn.init.zeros_(self.cond_proj.weight)
        nn.init.zeros_(self.cond_proj.bias)

    @staticmethod
    def _modulate(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """Return ``x * (1 + gamma) + beta`` with per-sample ``gamma``/``beta`` broadcast over tokens."""
        return x * (1.0 + gamma[:, None, :]) + beta[:, None, :]

    def _attend(self, h: torch.Tensor, key_padding_mask: torch.Tensor | None, attn_mask: torch.Tensor | None) -> torch.Tensor:
        """Self-attention over ``h`` that never leaves a query without a visible key.

        When a boolean ``(L, L)`` ``attn_mask`` and a boolean ``(B, L)``
        ``key_padding_mask`` together hide every key from some query (e.g. a
        padded token under a diagonal-only mask), the softmax over that row
        sees only ``-inf`` and can produce NaN, depending on the PyTorch
        version/backend.  Such NaNs reach real tokens in later layers
        (``0 * NaN`` in the attention values) and the backward pass.  In that
        case both masks are merged into a per-sample ``(B * n_heads, L, L)``
        mask in which the fully hidden queries may attend to themselves only.
        Every other query keeps exactly the same visible keys.
        """
        if (
            key_padding_mask is not None
            and attn_mask is not None
            and key_padding_mask.dtype == torch.bool
            and attn_mask.dtype == torch.bool
            and attn_mask.dim() == 2
        ):
            # True = key hidden from the query (MultiheadAttention bool-mask semantics).
            blocked = attn_mask[None, :, :] | key_padding_mask[:, None, :]  # (B, L, L)
            dead = blocked.all(dim=-1)  # (B, L): queries that would see no key
            if bool(dead.any()):
                eye = torch.eye(blocked.shape[-1], dtype=torch.bool, device=blocked.device)
                blocked = blocked & ~(dead[:, :, None] & eye[None, :, :])
                # MultiheadAttention expects 3-D masks laid out batch-major: index b * n_heads + head.
                merged = blocked.repeat_interleave(self.attn.num_heads, dim=0)
                (attn_out, _) = self.attn(h, h, h, attn_mask=merged, need_weights=False)
                return attn_out
        (attn_out, _) = self.attn(h, h, h, key_padding_mask=key_padding_mask, attn_mask=attn_mask, need_weights=False)
        return attn_out

    def forward(self, x: torch.Tensor, cond: torch.Tensor, key_padding_mask: torch.Tensor | None, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Apply gated attention and MLP residual updates to ``x`` conditioned on ``cond``.

        Args:
            x: tokens, ``(B, L, d_model)`` (batch-first attention).
            cond: one conditioning vector per sample, ``(B, cond_dim)``.
            key_padding_mask: optional ``(B, L)`` mask, True for keys to ignore.
            attn_mask: optional attention mask, True for disallowed query/key pairs.
        """
        (g1, b1, a1, g2, b2, a2) = self.cond_proj(cond).chunk(6, dim=-1)
        h = self._modulate(self.norm1(x), g1, b1)
        x = x + a1[:, None, :] * self._attend(h, key_padding_mask, attn_mask)
        h = self._modulate(self.norm2(x), g2, b2)
        return x + a2[:, None, :] * self.mlp(h)
|
|
|
|
class UnifiedVelocity(nn.Module):
    """Transformer velocity network over a (first token + remaining tokens) sequence.

    Tokens are linearly projected to ``d_model`` and receive a learned
    positional embedding plus a learned type embedding (type 0 for position 0,
    type 1 for every later position).  ``n_layers`` ``AdaLNBlock`` layers,
    conditioned on a sinusoidal embedding of ``t`` passed through a small MLP,
    follow.  The output projection is zero-initialised, so a freshly built
    network predicts zero velocity.

    ``reference_mode`` selects a fixed attention pattern (see
    ``_reference_attn_mask``); ``None`` means unrestricted attention.

    Raises:
        ValueError: for an unknown ``reference_mode``.
    """

    def __init__(self, token_dim: int, seq_len: int, d_model: int = 128, n_layers: int = 4, n_heads: int = 4, mlp_ratio: float = 4.0, time_dim: int = 64, reference_mode: str | None = None) -> None:
        super().__init__()
        allowed_modes = (None, 'independent_token', 'block_diagonal', 'causal_packets', 'causal_all')
        if reference_mode not in allowed_modes:
            raise ValueError(f'unknown reference_mode={reference_mode!r}')
        self.token_dim = token_dim
        self.seq_len = seq_len
        self.reference_mode = reference_mode

        # Input embedding.  Module creation and init calls stay in this order
        # so that seeded construction draws random numbers identically.
        self.input_proj = nn.Linear(token_dim, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
        self.type_emb = nn.Embedding(2, d_model)
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        nn.init.normal_(self.type_emb.weight, std=0.02)

        # Time conditioning and transformer trunk.
        self.time_emb = SinusoidalTimeEmb(time_dim)
        self.cond_mlp = nn.Sequential(nn.Linear(time_dim, d_model), nn.SiLU(), nn.Linear(d_model, d_model))
        self.blocks = nn.ModuleList(
            AdaLNBlock(d_model, n_heads, mlp_ratio, cond_dim=d_model) for _ in range(n_layers)
        )

        # Zero-initialised read-out.
        self.out_norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.out = nn.Linear(d_model, token_dim)
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

        # Type id per position: 0 for the first token, 1 for the rest.
        type_ids = torch.ones(seq_len, dtype=torch.long)
        type_ids[0] = 0
        self.register_buffer('type_ids', type_ids, persistent=False)

    def forward(self, x: torch.Tensor, t: torch.Tensor, key_padding_mask: torch.Tensor | None = None, attn_mask_override: torch.Tensor | None = None) -> torch.Tensor:
        """Predict a velocity for every token of ``x``.

        Args:
            x: ``(B, L, token_dim)`` tokens with ``L <= seq_len``.
            t: 0-dim time (shared by the batch) or one time per sample.
            key_padding_mask: forwarded to every block; True marks keys to ignore.
            attn_mask_override: if given, used instead of the ``reference_mode`` mask.

        Returns:
            ``(B, L, token_dim)`` velocities.

        Raises:
            ValueError: if ``L`` exceeds ``seq_len``.
        """
        batch, length, _ = x.shape
        if length > self.seq_len:
            raise ValueError(f'sequence length {length} exceeds configured {self.seq_len}')
        if t.dim() == 0:
            t = t.expand(batch)
        hidden = (
            self.input_proj(x)
            + self.pos_emb[:, :length, :]
            + self.type_emb(self.type_ids[:length])[None, :, :]
        )
        cond = self.cond_mlp(self.time_emb(t))
        attn_mask = (
            self._reference_attn_mask(length, x.device)
            if attn_mask_override is None
            else attn_mask_override
        )
        for block in self.blocks:
            hidden = block(hidden, cond, key_padding_mask, attn_mask=attn_mask)
        return self.out(self.out_norm(hidden))

    def _reference_attn_mask(self, L: int, device: torch.device) -> torch.Tensor | None:
        """Build the boolean ``(L, L)`` mask for ``reference_mode`` (True = query may NOT see key).

        * ``None``: no mask (returns ``None``).
        * ``'independent_token'``: each token sees only itself.
        * ``'block_diagonal'``: token 0 sees only itself; tokens 1.. see only each other.
        * ``'causal_packets'``: token 0 sees everything; token i >= 1 sees token 0
          and tokens 1..i.
        * ``'causal_all'``: token i sees tokens 0..i.
        """
        mode = self.reference_mode
        if mode is None:
            return None
        if mode == 'causal_all':
            return torch.ones(L, L, dtype=torch.bool, device=device).triu(diagonal=1)
        if mode == 'independent_token':
            return ~torch.eye(L, dtype=torch.bool, device=device)
        if mode == 'causal_packets':
            hidden_pairs = torch.zeros((L, L), dtype=torch.bool, device=device)
            if L > 1:
                hidden_pairs[1:, 1:] = torch.ones(L - 1, L - 1, dtype=torch.bool, device=device).triu(diagonal=1)
            return hidden_pairs
        if mode == 'block_diagonal':
            hidden_pairs = torch.ones((L, L), dtype=torch.bool, device=device)
            hidden_pairs[0, 0] = False
            if L > 1:
                hidden_pairs[1:, 1:] = False
            return hidden_pairs
        raise AssertionError(mode)
|
|
|
|
@dataclass
class UnifiedCFMConfig:
    """Hyper-parameters for ``UnifiedTokenCFM``.

    Field order is part of the positional-construction interface; keep it.
    """

    # --- data layout -------------------------------------------------------
    T: int = 128  # packet tokens per sample; the model sequence is T + 1 long
    packet_dim: int = 9  # features per packet token
    flow_dim: int = 16  # features of the single flow token
    token_dim: int | None = None  # falsy -> 1 + max(flow_dim, packet_dim)

    # --- velocity transformer ---------------------------------------------
    d_model: int = 128
    n_layers: int = 4
    n_heads: int = 4
    mlp_ratio: float = 4.0
    time_dim: int = 64  # sinusoidal time-embedding width (must be even)

    # --- flow matching -----------------------------------------------------
    sigma: float = 0.1  # x_t noise std is sigma * sqrt(t * (1 - t)); 0 disables it
    use_ot: bool = False  # re-pair noise/data with Sinkhorn in compute_loss
    reference_mode: str | None = None  # fixed attention pattern for the velocity net
|
|
|
|
class UnifiedTokenCFM(nn.Module):
    """Conditional flow matching over a unified flow-token + packet-token sequence.

    ``build_tokens`` packs one flow feature vector and ``cfg.T`` packet feature
    vectors into a ``(B, T + 1, token_dim)`` tensor, and a single
    ``UnifiedVelocity`` network predicts the velocity of every token.  Besides
    the training loss, the class exposes scoring/metric methods that return
    dicts of per-sample ``(B,)`` tensors derived from the learned field.

    ``lens[b]`` is the number of real packets of sample ``b``.  Packet
    positions ``>= lens[b]`` are padding and are excluded from every average.
    Token 0 (the flow token) is always real.
    """

    def __init__(self, cfg: UnifiedCFMConfig) -> None:
        super().__init__()
        self.cfg = cfg
        min_token_dim = 1 + max(cfg.flow_dim, cfg.packet_dim)
        # A falsy cfg.token_dim (None or 0) selects the smallest width that fits.
        self.token_dim = cfg.token_dim or min_token_dim
        if self.token_dim < min_token_dim:
            raise ValueError('token_dim is too small for flow_dim/packet_dim')
        self.seq_len = cfg.T + 1
        self.velocity = UnifiedVelocity(
            token_dim=self.token_dim,
            seq_len=self.seq_len,
            d_model=cfg.d_model,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            mlp_ratio=cfg.mlp_ratio,
            time_dim=cfg.time_dim,
            reference_mode=cfg.reference_mode,
        )

    # ------------------------------------------------------------ tokens/masks

    def build_tokens(self, flow: torch.Tensor, packets: torch.Tensor) -> torch.Tensor:
        """Pack features into a ``(B, T + 1, token_dim)`` tensor.

        Token 0: channel 0 = -1, channels ``1..flow_dim`` = ``flow``.
        Tokens 1..T: channel 0 = +1, channels ``1..packet_dim`` = ``packets``.
        Remaining channels are zero.

        Raises:
            ValueError: if the packet length/width or flow width disagrees with ``cfg``.
        """
        (B, T, Dp) = packets.shape
        if T != self.cfg.T:
            raise ValueError(f'packet T={T} but config T={self.cfg.T}')
        if Dp != self.cfg.packet_dim:
            raise ValueError(f'packet_dim={Dp} but config packet_dim={self.cfg.packet_dim}')
        if flow.shape[-1] != self.cfg.flow_dim:
            raise ValueError(f'flow_dim={flow.shape[-1]} but config flow_dim={self.cfg.flow_dim}')
        z = packets.new_zeros((B, T + 1, self.token_dim))
        z[:, 0, 0] = -1.0
        z[:, 0, 1:1 + self.cfg.flow_dim] = flow
        z[:, 1:, 0] = 1.0
        z[:, 1:, 1:1 + self.cfg.packet_dim] = packets
        return z

    def key_padding_mask(self, lens: torch.Tensor) -> torch.Tensor:
        """Return a ``(B, T + 1)`` bool mask that is True at padded packet positions.

        Packet ``i`` (0-based) of sample ``b`` is real iff ``i < lens[b]``; the
        flow token is never padded.
        """
        B = lens.shape[0]
        idx = torch.arange(self.cfg.T, device=lens.device)[None, :]
        packet_real = idx < lens[:, None]
        real = torch.cat([torch.ones(B, 1, dtype=torch.bool, device=lens.device), packet_real], dim=1)
        return ~real

    def _loss_mask(self, lens: torch.Tensor) -> torch.Tensor:
        """Float ``(B, T + 1)`` mask: 1.0 for real tokens, 0.0 for padding."""
        return (~self.key_padding_mask(lens)).float()

    # --------------------------------------------------------- masked reducers

    @staticmethod
    def _masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean of ``values`` over the last dim where ``mask`` is nonzero (denominator clamped to >= 1)."""
        return (values * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)

    @staticmethod
    def _masked_trimmed_mean(values: torch.Tensor, mask: torch.Tensor, trim_frac: float = 0.1) -> torch.Tensor:
        """Per-row mean of masked values after dropping ``trim_frac`` from each tail.

        Rows with fewer than 5 selected values use the plain mean; rows with
        none yield 0.
        """
        out = values.new_zeros(values.shape[0])
        for i in range(values.shape[0]):
            v = values[i][mask[i] > 0]
            if v.numel() == 0:
                continue
            if v.numel() < 5:
                out[i] = v.mean()
                continue
            v_sorted = torch.sort(v).values
            lo = int(trim_frac * v_sorted.numel())
            hi = int((1.0 - trim_frac) * v_sorted.numel())
            if hi <= lo:
                out[i] = v_sorted.mean()
            else:
                out[i] = v_sorted[lo:hi].mean()
        return out

    @staticmethod
    def _masked_median(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Per-row median of masked values (mean of the two middle values for even counts; 0 if empty)."""
        out = values.new_zeros(values.shape[0])
        for i in range(values.shape[0]):
            v = values[i][mask[i] > 0]
            if v.numel() == 0:
                continue
            v_sorted = torch.sort(v).values
            mid = v_sorted.numel() // 2
            if v_sorted.numel() % 2:
                out[i] = v_sorted[mid]
            else:
                out[i] = 0.5 * (v_sorted[mid - 1] + v_sorted[mid])
        return out

    def _packet_summary(self, values: torch.Tensor, packet_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (mean, median, max, 10%-trimmed mean) of per-packet ``values`` over real packets.

        The max is ``-inf`` for a sample without real packets.
        """
        mean = self._masked_mean(values, packet_mask)
        median = self._masked_median(values, packet_mask)
        peak = values.masked_fill(packet_mask == 0, float('-inf')).max(dim=-1).values
        trimmed = self._masked_trimmed_mean(values, packet_mask)
        return (mean, median, peak, trimmed)

    # ---------------------------------------------------------------- training

    def compute_loss(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, *, lambda_flow: float = 0.0, lambda_packet: float = 0.0, packet_mask_ratio: float = 0.5, return_components: bool = False) -> torch.Tensor | dict[str, torch.Tensor]:
        """Flow-matching training loss.

        Draws noise ``x0 ~ N(0, I)`` and ``t ~ U(0, 1)``, forms
        ``x_t = (1 - t) x0 + t x1`` (plus ``cfg.sigma * sqrt(t (1 - t))``
        Gaussian noise when ``cfg.sigma > 0``) and regresses the velocity onto
        ``x1 - x0``, averaged over real tokens.  With ``cfg.use_ot`` the data
        samples are re-paired with the noise via ``_sinkhorn_coupling``, which
        may pick the same data sample more than once.

        Auxiliary terms:
          * ``lambda_flow > 0``: zero the flow token of ``x_t`` and score the
            flow-token prediction only.
          * ``lambda_packet > 0``: zero a random ``packet_mask_ratio`` fraction
            of real packet tokens and score only those positions.

        Returns:
            The scalar total loss, or with ``return_components`` a dict with
            ``'total'`` and detached ``'main'``, ``'aux_flow'``, ``'aux_packet'``.
        """
        x1 = self.build_tokens(flow, packets)
        B = x1.shape[0]
        x0 = torch.randn_like(x1)
        mask = self._loss_mask(lens)
        if self.cfg.use_ot:
            real = mask[:, :, None]
            flat0 = (x0 * real).reshape(B, -1)
            flat1 = (x1 * real).reshape(B, -1)
            col = _sinkhorn_coupling(torch.cdist(flat0.float(), flat1.float()))
            # Only the re-paired tokens and their padding are needed below.
            x1 = x1[col]
            lens = lens[col]
            mask = self._loss_mask(lens)
        kpm = mask == 0
        t = torch.rand(B, device=x1.device)
        x_t = (1.0 - t[:, None, None]) * x0 + t[:, None, None] * x1
        if self.cfg.sigma > 0:
            std = self.cfg.sigma * torch.sqrt(t * (1.0 - t))[:, None, None]
            x_t = x_t + std * torch.randn_like(x_t)
        target = x1 - x0
        pred = self.velocity(x_t, t, key_padding_mask=kpm)
        main_loss = self._masked_mean((pred - target).square().mean(dim=-1), mask).mean()
        aux_flow_loss = x1.new_zeros(())
        aux_packet_loss = x1.new_zeros(())
        if lambda_flow > 0.0:
            x_t_mf = x_t.clone()
            x_t_mf[:, 0, :] = 0.0
            pred_mf = self.velocity(x_t_mf, t, key_padding_mask=kpm)
            aux_flow_loss = (pred_mf[:, 0] - target[:, 0]).square().mean(dim=-1).mean()
        if lambda_packet > 0.0:
            packet_real = mask[:, 1:] > 0
            rand_draw = torch.rand(packet_real.shape, device=x1.device)
            mask_pkt = (rand_draw < packet_mask_ratio) & packet_real
            pkt_mask_full = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x1.device), mask_pkt], dim=1)
            x_t_mp = x_t.clone()
            x_t_mp[pkt_mask_full] = 0.0
            pred_mp = self.velocity(x_t_mp, t, key_padding_mask=kpm)
            sq_mp = (pred_mp - target).square().mean(dim=-1)
            aux_packet_loss = self._masked_mean(sq_mp, pkt_mask_full.float()).mean()
        total = main_loss + lambda_flow * aux_flow_loss + lambda_packet * aux_packet_loss
        if return_components:
            return {'total': total, 'main': main_loss.detach(), 'aux_flow': aux_flow_loss.detach(), 'aux_packet': aux_packet_loss.detach()}
        return total

    # ----------------------------------------------------------------- scoring

    @torch.no_grad()
    def velocity_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...] = (0.5, 0.75, 1.0)) -> dict[str, torch.Tensor]:
        """Mean squared velocity at the data tokens, averaged over the times in ``t_eval``.

        Returns ``'velocity_total'`` (all real tokens), ``'velocity_flow'``
        (flow token) and ``'velocity_packet'`` (real packet tokens).
        """
        x = self.build_tokens(flow, packets)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        packet_mask = mask[:, 1:]
        total = torch.zeros(x.shape[0], device=x.device)
        flow_s = torch.zeros_like(total)
        packet_s = torch.zeros_like(total)
        for t_val in t_eval:
            t = torch.full((x.shape[0],), float(t_val), device=x.device)
            e = self.velocity(x, t, key_padding_mask=kpm).square().mean(dim=-1)
            total = total + self._masked_mean(e, mask)
            flow_s = flow_s + e[:, 0]
            packet_s = packet_s + self._masked_mean(e[:, 1:], packet_mask)
        denom = float(len(t_eval))
        return {'velocity_total': total / denom, 'velocity_flow': flow_s / denom, 'velocity_packet': packet_s / denom}

    @torch.no_grad()
    def trajectory_metrics(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int = 16) -> dict[str, torch.Tensor]:
        """Statistics of the backward Euler trajectory started at the data.

        Takes ``n_steps`` steps ``z <- z - v(z, t) * dt`` at
        ``t = 1, 1 - dt, ..., dt`` and accumulates kinetic energy (mean squared
        velocity times ``dt``), curvature (mean squared velocity change between
        consecutive steps), arc length of the real tokens, the norm of the final
        state (``terminal_*``) and, per packet, ``||dv||^2 / mean_speed^2``
        summarised by mean / median / 10%-trimmed mean.
        """
        z = self.build_tokens(flow, packets)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        packet_mask = mask[:, 1:]
        B = z.shape[0]
        dt = 1.0 / n_steps
        total_arc = torch.zeros(B, device=z.device)
        total_ke = torch.zeros(B, device=z.device)
        flow_ke = torch.zeros(B, device=z.device)
        packet_ke = torch.zeros(B, device=z.device)
        total_curv = torch.zeros(B, device=z.device)
        flow_curv = torch.zeros(B, device=z.device)
        packet_curv = torch.zeros(B, device=z.device)
        packet_kappa2_speed2 = torch.zeros(B, max(z.shape[1] - 1, 0), device=z.device)
        packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
        v_prev = None
        v_prev_norm = None
        for k in range(n_steps):
            t = torch.full((B,), 1.0 - k * dt, device=z.device)
            v = self.velocity(z, t, key_padding_mask=kpm)
            e = v.square().mean(dim=-1)
            v_norm = v.square().sum(dim=-1).clamp_min(1e-12).sqrt()
            total_ke = total_ke + self._masked_mean(e, mask) * dt
            flow_ke = flow_ke + e[:, 0] * dt
            packet_ke = packet_ke + self._masked_mean(e[:, 1:], packet_mask) * dt
            if v_prev is not None:
                dv = v - v_prev
                dve = dv.square().mean(dim=-1)
                total_curv = total_curv + self._masked_mean(dve, mask)
                flow_curv = flow_curv + dve[:, 0]
                packet_curv = packet_curv + self._masked_mean(dve[:, 1:], packet_mask)
                dv2_sum = dv[:, 1:].square().sum(dim=-1)
                v_avg = 0.5 * (v_norm[:, 1:] + v_prev_norm[:, 1:])
                packet_kappa2_speed2 = packet_kappa2_speed2 + dv2_sum / v_avg.square().clamp_min(1e-06)
            v_prev = v
            v_prev_norm = v_norm
            z_new = z - v * dt
            dz = (z_new - z) * mask[:, :, None]
            total_arc = total_arc + dz.reshape(B, -1).norm(dim=-1) / mask.sum(dim=-1).sqrt()
            z = z_new
        z_masked = z * mask[:, :, None]
        terminal = z_masked.reshape(B, -1).norm(dim=-1) / (mask.sum(dim=-1) * self.token_dim).clamp_min(1.0).sqrt()
        terminal_flow = z[:, 0].norm(dim=-1) / math.sqrt(self.token_dim)
        terminal_packet = (z[:, 1:] * mask[:, 1:, None]).reshape(B, -1).norm(dim=-1) / (packet_count * self.token_dim).sqrt()
        kappa2_speed2_mean = self._masked_mean(packet_kappa2_speed2, packet_mask)
        kappa2_speed2_median = self._masked_median(packet_kappa2_speed2, packet_mask)
        kappa2_speed2_trimmed = self._masked_trimmed_mean(packet_kappa2_speed2, packet_mask)
        return {
            'terminal_norm': terminal,
            'terminal_flow': terminal_flow,
            'terminal_packet': terminal_packet,
            'arc_length': total_arc,
            'kinetic_energy': total_ke,
            'kinetic_flow': flow_ke,
            'kinetic_packet': packet_ke,
            'curvature_total': total_curv,
            'curvature_flow': flow_curv,
            'curvature_packet': packet_curv,
            'kappa2_speed2norm_packet_mean': kappa2_speed2_mean,
            'kappa2_speed2norm_packet_median': kappa2_speed2_median,
            'kappa2_speed2norm_packet_trimmed10_mean': kappa2_speed2_trimmed,
        }

    @torch.no_grad()
    def score_profile_vt(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...] = (0.1, 0.3, 0.5, 0.7, 0.9, 1.0)) -> dict[str, torch.Tensor]:
        """Mean squared velocity at the data tokens, reported separately for every ``t`` in ``t_eval``.

        Keys are ``velocity_{total,flow,packet}_tNN`` with ``NN = round(10 * t)``.

        Raises:
            ValueError: if two different t values map to the same ``tNN`` tag
                (their results would otherwise silently overwrite each other).
        """
        x = self.build_tokens(flow, packets)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        packet_mask = mask[:, 1:]
        out: dict[str, torch.Tensor] = {}
        tag_source: dict[str, float] = {}
        for t_val in t_eval:
            tag = f't{int(round(t_val * 10)):02d}'
            if tag in tag_source and tag_source[tag] != float(t_val):
                raise ValueError(f'score_profile_vt: t values {tag_source[tag]!r} and {t_val!r} both map to tag {tag!r}; use t values on a 0.1 grid')
            tag_source[tag] = float(t_val)
            t = torch.full((x.shape[0],), float(t_val), device=x.device)
            e = self.velocity(x, t, key_padding_mask=kpm).square().mean(dim=-1)
            out[f'velocity_total_{tag}'] = self._masked_mean(e, mask)
            out[f'velocity_flow_{tag}'] = e[:, 0]
            out[f'velocity_packet_{tag}'] = self._masked_mean(e[:, 1:], packet_mask)
        return out

    @torch.no_grad()
    def consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float = 0.5) -> dict[str, torch.Tensor]:
        """How much velocities change when input tokens are zeroed, at time ``t_eval``.

        ``flow_consistency``: squared change of the flow-token velocity when the
        flow token is zeroed.  ``packet_consistency``: mean squared change of the
        real packet-token velocities when all real packet tokens are zeroed.
        """
        x = self.build_tokens(flow, packets)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        B = x.shape[0]
        t = torch.full((B,), float(t_eval), device=x.device)
        v_full = self.velocity(x, t, key_padding_mask=kpm)
        x_mf = x.clone()
        x_mf[:, 0, :] = 0.0
        v_mf = self.velocity(x_mf, t, key_padding_mask=kpm)
        flow_cons = (v_full[:, 0] - v_mf[:, 0]).square().mean(dim=-1)
        real_packets = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x.device), mask[:, 1:] > 0], dim=1)
        x_mp = x.clone()
        x_mp[real_packets] = 0.0
        v_mp = self.velocity(x_mp, t, key_padding_mask=kpm)
        diff = (v_full - v_mp).square().mean(dim=-1)
        packet_cons = self._masked_mean(diff[:, 1:], mask[:, 1:])
        return {'flow_consistency': flow_cons, 'packet_consistency': packet_cons, 'consistency_total': flow_cons + packet_cons}

    def jacobian_hutchinson(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...] = (0.5,), n_eps: int = 4, generator: torch.Generator | None = None) -> dict[str, torch.Tensor]:
        """Monte-Carlo size of the velocity Jacobian at the data tokens.

        For every ``t`` in ``t_eval`` and each of ``n_eps`` Gaussian probes
        ``eps`` (drawn from ``generator`` when given), computes the
        vector-Jacobian product ``J^T eps`` and averages its squared entries per
        input token.  Probes are zeroed on padded tokens, so padded outputs
        cannot leak into the scores of real tokens.  Autograd is enabled
        internally, so this may be called under ``torch.no_grad()``.

        Raises:
            ValueError: if ``n_eps < 1`` or ``t_eval`` is empty.
        """
        if n_eps < 1 or len(t_eval) == 0:
            raise ValueError('jacobian_hutchinson needs n_eps >= 1 and a non-empty t_eval')
        x = self.build_tokens(flow, packets)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        packet_mask = mask[:, 1:]
        eps_mask = mask[:, :, None]
        B = x.shape[0]
        total = torch.zeros(B, device=x.device)
        flow_j = torch.zeros(B, device=x.device)
        packet_j = torch.zeros(B, device=x.device)
        n_draws = n_eps * len(t_eval)
        with torch.enable_grad():
            for t_val in t_eval:
                t_current = torch.full((B,), float(t_val), device=x.device)
                for _ in range(n_eps):
                    x_req = x.detach().clone().requires_grad_(True)
                    v = self.velocity(x_req, t_current, key_padding_mask=kpm)
                    eps = torch.randn(v.shape, device=v.device, dtype=v.dtype, generator=generator) * eps_mask
                    (g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False)
                    e = g.square().mean(dim=-1)
                    total = total + self._masked_mean(e, mask)
                    flow_j = flow_j + e[:, 0]
                    packet_j = packet_j + self._masked_mean(e[:, 1:], packet_mask)
        return {'jacobian_total': (total / n_draws).detach(), 'jacobian_flow': (flow_j / n_draws).detach(), 'jacobian_packet': (packet_j / n_draws).detach()}

    @torch.no_grad()
    def pna_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int = 16, flow_masked: bool = False) -> dict[str, torch.Tensor]:
        """Per-token normalised acceleration along the backward Euler trajectory.

        Integrates ``z <- z - v(z, t) * dt`` from ``t = 1`` for ``n_steps`` steps
        and accumulates, per token, ``||v_k - v_{k-1}||^2 / (mean speed)^2``.
        With ``flow_masked`` the flow token is zeroed initially and after each
        step.  Returns mean/median/max/trimmed-mean over real packets and the
        flow-token value.
        """
        eps_v2 = 1e-06
        dt = 1.0 / n_steps
        z = self.build_tokens(flow, packets)
        if flow_masked:
            z = z.clone()
            z[:, 0, :] = 0.0
        mask = self._loss_mask(lens)
        kpm = mask == 0
        (B, L, _) = z.shape
        pna = torch.zeros(B, L, device=z.device)
        v_prev: torch.Tensor | None = None
        v_norm_prev: torch.Tensor | None = None
        for k in range(n_steps):
            t = torch.full((B,), 1.0 - k * dt, device=z.device)
            v = self.velocity(z, t, key_padding_mask=kpm)
            v_norm = (v.square().sum(dim=-1) + 1e-12).sqrt()
            if v_prev is not None:
                dv2 = (v - v_prev).square().sum(dim=-1)
                v_avg2 = (0.5 * (v_norm + v_norm_prev)).square().clamp_min(eps_v2)
                pna = pna + dv2 / v_avg2
            v_prev = v
            v_norm_prev = v_norm
            z = z - v * dt
            if flow_masked:
                z[:, 0, :] = 0.0
        (mean, median, peak, trimmed) = self._packet_summary(pna[:, 1:], mask[:, 1:])
        return {'pna_packet_median': median, 'pna_packet_mean': mean, 'pna_packet_max': peak, 'pna_packet_trimmed10_mean': trimmed, 'pna_flow': pna[:, 0]}

    @torch.no_grad()
    def causal_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float = 0.5) -> dict[str, torch.Tensor]:
        """Squared velocity change when attention is forced causal (token i sees 0..i) at ``t_eval``.

        Compares the model's default attention against the causal override and
        reports the total over real tokens, the flow-token value and
        mean/median/max/trimmed-mean over real packets.
        """
        x = self.build_tokens(flow, packets)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        (B, L, _) = x.shape
        t = torch.full((B,), float(t_eval), device=x.device)
        v_full = self.velocity(x, t, key_padding_mask=kpm)
        causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), diagonal=1)
        v_causal = self.velocity(x, t, key_padding_mask=kpm, attn_mask_override=causal)
        diff = (v_full - v_causal).square().mean(dim=-1)
        (mean, median, peak, trimmed) = self._packet_summary(diff[:, 1:], mask[:, 1:])
        total = self._masked_mean(diff, mask)
        return {
            'causal_surprisal_total': total,
            'causal_surprisal_flow': diff[:, 0],
            'causal_surprisal_packet_mean': mean,
            'causal_surprisal_packet_median': median,
            'causal_surprisal_packet_max': peak,
            'causal_surprisal_packet_trimmed10_mean': trimmed,
        }

    @torch.no_grad()
    def direction_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...] = (0.2, 0.4, 0.6, 0.8, 1.0)) -> dict[str, torch.Tensor]:
        """Average ``1 - cos`` between velocities at the data for consecutive times in ``t_eval``.

        Raises:
            ValueError: if fewer than two t values are given.
        """
        x = self.build_tokens(flow, packets)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        (B, L, _) = x.shape
        t_eval = tuple(t_eval)
        if len(t_eval) < 2:
            raise ValueError('direction_consistency_score needs >=2 t values')
        prev_v: torch.Tensor | None = None
        drift = x.new_zeros(B, L)
        n_pairs = len(t_eval) - 1
        for t_val in t_eval:
            t = torch.full((B,), float(t_val), device=x.device)
            v = self.velocity(x, t, key_padding_mask=kpm)
            if prev_v is not None:
                num = (prev_v * v).sum(dim=-1)
                denom = prev_v.norm(dim=-1).clamp_min(1e-08) * v.norm(dim=-1).clamp_min(1e-08)
                drift = drift + (1.0 - num / denom)
            prev_v = v
        drift = drift / max(n_pairs, 1)
        (mean, median, peak, trimmed) = self._packet_summary(drift[:, 1:], mask[:, 1:])
        total = self._masked_mean(drift, mask)
        return {
            'direction_drift_total': total,
            'direction_drift_flow': drift[:, 0],
            'direction_drift_packet_mean': mean,
            'direction_drift_packet_median': median,
            'direction_drift_packet_max': peak,
            'direction_drift_packet_trimmed10_mean': trimmed,
        }

    def inverse_flow_nll_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int = 16, n_eps: int = 4, compute_divergence: bool = True, generator: torch.Generator | None = None) -> dict[str, torch.Tensor]:
        """Negative log-likelihood estimate via the instantaneous change of variables.

        Integrates the ODE backwards from the data (``t = 1``) to ``t = 0`` with
        ``n_steps`` Euler steps.  With ``compute_divergence`` the divergence of
        the velocity, restricted to real tokens, is estimated at every step with
        ``n_eps`` Hutchinson probes (drawn from ``generator`` when given) and
        integrated over time.  Autograd is enabled internally, so this may be
        called under ``torch.no_grad()``.

        Returns (each divided by ``n_real_tokens * token_dim``):
            ``nll_x0_only``: ``0.5 * ||z0||^2`` over real tokens, the standard
                Gaussian NLL of the recovered noise without its constant;
            ``nll_div_only``: ``integral_0^1 div v dt``;
            ``nll_full``: their sum, i.e. ``-log p(x)`` up to that constant.

        Raises:
            ValueError: if ``compute_divergence`` and ``n_eps < 1``.
        """
        if compute_divergence and n_eps < 1:
            raise ValueError('inverse_flow_nll_score needs n_eps >= 1 when compute_divergence=True')
        z = self.build_tokens(flow, packets)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        real = mask[:, :, None]
        (B, L, D) = z.shape
        dt = 1.0 / n_steps
        accum_div = torch.zeros(B, device=z.device)
        if compute_divergence:
            with torch.enable_grad():
                for k in range(n_steps):
                    t = torch.full((B,), 1.0 - k * dt, device=z.device)
                    z_req = z.detach().clone().requires_grad_(True)
                    v = self.velocity(z_req, t, key_padding_mask=kpm)
                    div_step = torch.zeros(B, device=z.device)
                    for j in range(n_eps):
                        eps_masked = torch.randn(v.shape, device=v.device, dtype=v.dtype, generator=generator) * real
                        retain = j < n_eps - 1
                        (g,) = torch.autograd.grad(outputs=v, inputs=z_req, grad_outputs=eps_masked, retain_graph=retain, create_graph=False)
                        div_step = div_step + (eps_masked * g).sum(dim=(1, 2))
                    div_step = div_step / float(n_eps)
                    accum_div = accum_div + div_step * dt
                    with torch.no_grad():
                        z = (z_req - v * dt).detach()
        else:
            with torch.no_grad():
                for k in range(n_steps):
                    t = torch.full((B,), 1.0 - k * dt, device=z.device)
                    v = self.velocity(z, t, key_padding_mask=kpm)
                    z = z - v * dt
        with torch.no_grad():
            z_masked = z * real
            n_real = mask.sum(dim=-1).clamp_min(1.0)
            # -log N(z; 0, I) = 0.5 * ||z||^2 + const; the 0.5 keeps the prior
            # term on the same scale as the divergence term in nll_full.
            nll_x0_only = 0.5 * z_masked.reshape(B, -1).square().sum(dim=-1) / (n_real * float(D))
            nll_div_only = accum_div / (n_real * float(D))
            nll_full = nll_x0_only + nll_div_only
        return {'nll_x0_only': nll_x0_only.detach(), 'nll_div_only': nll_div_only.detach(), 'nll_full': nll_full.detach()}

    def jacobian_spectral_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float = 0.5, n_eps: int = 4, generator: torch.Generator | None = None) -> dict[str, torch.Tensor]:
        """Spread of random vector-Jacobian-product sizes at a single time ``t_eval``.

        For ``n_eps`` Gaussian probes (zeroed on padded tokens, drawn from
        ``generator`` when given) records the mean squared ``J^T eps`` for the
        total / flow / packet token groups and summarises the draws as
        ``logfro`` (log of the mean), ``anisotropy`` (max / mean),
        ``min_over_max`` and ``eff_rank`` (exp of the entropy of the normalised
        draws).  Keys are ``'jac_<stat>_<group>'``.  Autograd is enabled
        internally.

        Raises:
            ValueError: if ``n_eps < 1``.
        """
        if n_eps < 1:
            raise ValueError('jacobian_spectral_score needs n_eps >= 1')
        x = self.build_tokens(flow, packets)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        eps_mask = mask[:, :, None]
        (B, L, D) = x.shape
        t = torch.full((B,), float(t_eval), device=x.device)
        packet_mask = mask[:, 1:]
        norms_total: list[torch.Tensor] = []
        norms_flow: list[torch.Tensor] = []
        norms_packet: list[torch.Tensor] = []
        with torch.enable_grad():
            for _ in range(n_eps):
                x_req = x.detach().clone().requires_grad_(True)
                v = self.velocity(x_req, t, key_padding_mask=kpm)
                eps = torch.randn(v.shape, device=v.device, dtype=v.dtype, generator=generator) * eps_mask
                (g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False)
                e = g.square().mean(dim=-1)
                norms_total.append(self._masked_mean(e, mask).detach())
                norms_flow.append(e[:, 0].detach())
                norms_packet.append(self._masked_mean(e[:, 1:], packet_mask).detach())

        def _spectral_summary(samples: list[torch.Tensor]) -> dict[str, torch.Tensor]:
            # samples: n_eps tensors of shape (B,), stacked to (B, n_eps).
            stack = torch.stack(samples, dim=1)
            mean = stack.mean(dim=1).clamp_min(1e-12)
            mx = stack.max(dim=1).values
            mn = stack.min(dim=1).values
            p = stack / stack.sum(dim=1, keepdim=True).clamp_min(1e-12)
            entropy = -(p * p.clamp_min(1e-12).log()).sum(dim=1)
            return {'logfro': torch.log(mean), 'anisotropy': mx / mean, 'min_over_max': mn / mx.clamp_min(1e-12), 'eff_rank': torch.exp(entropy)}

        out: dict[str, torch.Tensor] = {}
        for (tag, samples) in (('total', norms_total), ('flow', norms_flow), ('packet', norms_packet)):
            for (stat_name, val) in _spectral_summary(samples).items():
                out[f'jac_{stat_name}_{tag}'] = val
        return out

    # ---------------------------------------------------------------- sampling

    @torch.no_grad()
    def sample(self, n: int, lens: torch.Tensor, device: torch.device, n_steps: int = 50, method: str = 'euler') -> torch.Tensor:
        """Generate ``n`` token sequences by integrating the ODE from noise (t=0) to t=1.

        ``method='euler'`` uses fixed-step Euler; any other value is passed to
        ``torchdiffeq.odeint``.  Returns ``(n, T + 1, token_dim)`` tokens,
        including padded positions.
        """
        z = torch.randn(n, self.seq_len, self.token_dim, device=device)
        ts = torch.linspace(0.0, 1.0, n_steps + 1, device=device)
        kpm = self.key_padding_mask(lens.to(device))

        def f(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
            return self.velocity(x, t.expand(x.shape[0]), key_padding_mask=kpm)

        if method == 'euler':
            for i in range(n_steps):
                z = z + f(ts[i], z) * (ts[i + 1] - ts[i])
            return z
        return odeint(f, z, ts, method=method)[-1]

    def param_count(self) -> int:
        """Total number of parameter elements."""
        return sum(p.numel() for p in self.parameters())