Initial commit: code, paper, small artifacts
This commit is contained in:
588
Unified_CFM/model.py
Normal file
588
Unified_CFM/model.py
Normal file
@@ -0,0 +1,588 @@
|
||||
from __future__ import annotations
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchdiffeq import odeint
|
||||
|
||||
@torch.no_grad()
def _sinkhorn_coupling(C: torch.Tensor, reg: float = 0.05, n_iter: int = 20) -> torch.Tensor:
    """Pick, for every row of a cost matrix, its strongest Sinkhorn partner.

    Runs ``n_iter`` rounds of log-domain Sinkhorn scaling on the kernel
    ``exp(-C / reg)`` (columns first, then rows) and takes the row-wise
    argmax of the resulting log-plan.

    Args:
        C: Cost matrix, indexed as (row, column); it is cast to float32.
        reg: Entropic regularisation strength that divides the cost.
        n_iter: Number of alternating column/row scaling rounds.

    Returns:
        One column index per row of ``C``.  Rows are matched independently,
        so several rows may select the same column (the result is not
        guaranteed to be a permutation).
    """
    cost = C.float()
    log_kernel = -cost / reg
    n_rows = cost.shape[0]
    row_pot = torch.zeros(n_rows, device=cost.device)
    # Only read as-is when n_iter == 0; otherwise overwritten in round one.
    col_pot = torch.zeros(n_rows, device=cost.device)

    rounds_left = n_iter
    while rounds_left > 0:
        # Column scaling followed by row scaling, both in log space.
        col_pot = torch.logsumexp(log_kernel + row_pot[:, None], dim=0).neg()
        row_pot = torch.logsumexp(log_kernel + col_pot[None, :], dim=1).neg()
        rounds_left -= 1

    log_plan = row_pot[:, None] + log_kernel + col_pot[None, :]
    return torch.argmax(log_plan, dim=1)
|
||||
|
||||
class SinusoidalTimeEmb(nn.Module):
    """Parameter-free sinusoidal embedding of a batch of scalar times.

    The output concatenates ``sin`` and ``cos`` of the input scaled by
    ``dim // 2`` geometrically spaced frequencies running from 1 down to
    1/10000.
    """

    def __init__(self, dim: int) -> None:
        """Store the embedding width.

        Args:
            dim: Width of the produced embedding.

        Raises:
            ValueError: If ``dim`` is odd (sin and cos halves must match).
        """
        super().__init__()
        if dim % 2:
            raise ValueError('time embedding dimension must be even')
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t``.

        Args:
            t: Time values; indexed here as a 1-D batch (``t[:, None]``).

        Returns:
            Tensor whose last dimension is ``[sin(t * f), cos(t * f)]`` with
            total width ``self.dim``.
        """
        n_freqs = self.dim // 2
        # max(..., 1) keeps the exponent finite when there is a single frequency.
        span = max(n_freqs - 1, 1)
        steps = torch.arange(n_freqs, device=t.device, dtype=t.dtype)
        freqs = torch.exp(-math.log(10000) * steps / span)
        phase = t[:, None] * freqs[None, :]
        return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
|
||||
|
||||
class AdaLNBlock(nn.Module):
|
||||
|
||||
def __init__(self, d_model: int, n_heads: int, mlp_ratio: float, cond_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
|
||||
self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
|
||||
self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
|
||||
hidden = int(d_model * mlp_ratio)
|
||||
self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))
|
||||
self.cond_proj = nn.Linear(cond_dim, 6 * d_model)
|
||||
nn.init.zeros_(self.cond_proj.weight)
|
||||
nn.init.zeros_(self.cond_proj.bias)
|
||||
|
||||
@staticmethod
|
||||
def _modulate(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
|
||||
return x * (1.0 + gamma[:, None, :]) + beta[:, None, :]
|
||||
|
||||
def forward(self, x: torch.Tensor, cond: torch.Tensor, key_padding_mask: torch.Tensor | None, attn_mask: torch.Tensor | None=None) -> torch.Tensor:
|
||||
(g1, b1, a1, g2, b2, a2) = self.cond_proj(cond).chunk(6, dim=-1)
|
||||
h = self._modulate(self.norm1(x), g1, b1)
|
||||
(attn_out, _) = self.attn(h, h, h, key_padding_mask=key_padding_mask, attn_mask=attn_mask, need_weights=False)
|
||||
x = x + a1[:, None, :] * attn_out
|
||||
h = self._modulate(self.norm2(x), g2, b2)
|
||||
return x + a2[:, None, :] * self.mlp(h)
|
||||
|
||||
class UnifiedVelocity(nn.Module):
|
||||
|
||||
def __init__(self, token_dim: int, seq_len: int, d_model: int=128, n_layers: int=4, n_heads: int=4, mlp_ratio: float=4.0, time_dim: int=64, reference_mode: str | None=None) -> None:
|
||||
super().__init__()
|
||||
if reference_mode not in (None, 'independent_token', 'block_diagonal', 'causal_packets', 'causal_all'):
|
||||
raise ValueError(f'unknown reference_mode={reference_mode!r}')
|
||||
self.token_dim = token_dim
|
||||
self.seq_len = seq_len
|
||||
self.reference_mode = reference_mode
|
||||
self.input_proj = nn.Linear(token_dim, d_model)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
|
||||
self.type_emb = nn.Embedding(2, d_model)
|
||||
nn.init.trunc_normal_(self.pos_emb, std=0.02)
|
||||
nn.init.normal_(self.type_emb.weight, std=0.02)
|
||||
self.time_emb = SinusoidalTimeEmb(time_dim)
|
||||
self.cond_mlp = nn.Sequential(nn.Linear(time_dim, d_model), nn.SiLU(), nn.Linear(d_model, d_model))
|
||||
self.blocks = nn.ModuleList([AdaLNBlock(d_model, n_heads, mlp_ratio, cond_dim=d_model) for _ in range(n_layers)])
|
||||
self.out_norm = nn.LayerNorm(d_model, elementwise_affine=False)
|
||||
self.out = nn.Linear(d_model, token_dim)
|
||||
nn.init.zeros_(self.out.weight)
|
||||
nn.init.zeros_(self.out.bias)
|
||||
type_ids = torch.ones(seq_len, dtype=torch.long)
|
||||
type_ids[0] = 0
|
||||
self.register_buffer('type_ids', type_ids, persistent=False)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, key_padding_mask: torch.Tensor | None=None, attn_mask_override: torch.Tensor | None=None) -> torch.Tensor:
|
||||
(B, L, _) = x.shape
|
||||
if L > self.seq_len:
|
||||
raise ValueError(f'sequence length {L} exceeds configured {self.seq_len}')
|
||||
if t.dim() == 0:
|
||||
t = t.expand(B)
|
||||
h = self.input_proj(x)
|
||||
h = h + self.pos_emb[:, :L, :]
|
||||
h = h + self.type_emb(self.type_ids[:L])[None, :, :]
|
||||
cond = self.cond_mlp(self.time_emb(t))
|
||||
if attn_mask_override is not None:
|
||||
attn_mask = attn_mask_override
|
||||
else:
|
||||
attn_mask = self._reference_attn_mask(L, x.device)
|
||||
for block in self.blocks:
|
||||
h = block(h, cond, key_padding_mask, attn_mask=attn_mask)
|
||||
return self.out(self.out_norm(h))
|
||||
|
||||
def _reference_attn_mask(self, L: int, device: torch.device) -> torch.Tensor | None:
|
||||
if self.reference_mode is None:
|
||||
return None
|
||||
if self.reference_mode == 'independent_token':
|
||||
return ~torch.eye(L, dtype=torch.bool, device=device)
|
||||
if self.reference_mode == 'block_diagonal':
|
||||
mask = torch.ones((L, L), dtype=torch.bool, device=device)
|
||||
mask[0, 0] = False
|
||||
if L > 1:
|
||||
mask[1:, 1:] = False
|
||||
return mask
|
||||
if self.reference_mode == 'causal_packets':
|
||||
mask = torch.zeros((L, L), dtype=torch.bool, device=device)
|
||||
if L > 1:
|
||||
packet_causal = torch.triu(torch.ones(L - 1, L - 1, dtype=torch.bool, device=device), diagonal=1)
|
||||
mask[1:, 1:] = packet_causal
|
||||
return mask
|
||||
if self.reference_mode == 'causal_all':
|
||||
return torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), diagonal=1)
|
||||
raise AssertionError(self.reference_mode)
|
||||
|
||||
@dataclass
|
||||
class UnifiedCFMConfig:
|
||||
T: int = 128
|
||||
packet_dim: int = 9
|
||||
flow_dim: int = 16
|
||||
token_dim: int | None = None
|
||||
d_model: int = 128
|
||||
n_layers: int = 4
|
||||
n_heads: int = 4
|
||||
mlp_ratio: float = 4.0
|
||||
time_dim: int = 64
|
||||
sigma: float = 0.1
|
||||
use_ot: bool = False
|
||||
reference_mode: str | None = None
|
||||
|
||||
class UnifiedTokenCFM(nn.Module):
|
||||
|
||||
def __init__(self, cfg: UnifiedCFMConfig) -> None:
    """Create the model from a configuration.

    Resolves the token width (``cfg.token_dim`` if set, otherwise one tag
    channel plus the wider of the flow and packet feature vectors), derives
    the sequence length ``cfg.T + 1`` and builds the velocity network.

    Args:
        cfg: Hyper-parameters; kept on the instance as ``self.cfg``.

    Raises:
        ValueError: If an explicit ``token_dim`` cannot hold the tag channel
            plus the widest feature vector.
    """
    super().__init__()
    self.cfg = cfg

    required_width = 1 + max(cfg.flow_dim, cfg.packet_dim)
    self.token_dim = cfg.token_dim or required_width
    if self.token_dim < required_width:
        raise ValueError('token_dim is too small for flow_dim/packet_dim')

    # One flow token followed by T packet tokens.
    self.seq_len = cfg.T + 1
    self.velocity = UnifiedVelocity(
        token_dim=self.token_dim,
        seq_len=self.seq_len,
        d_model=cfg.d_model,
        n_layers=cfg.n_layers,
        n_heads=cfg.n_heads,
        mlp_ratio=cfg.mlp_ratio,
        time_dim=cfg.time_dim,
        reference_mode=cfg.reference_mode,
    )
|
||||
|
||||
def build_tokens(self, flow: torch.Tensor, packets: torch.Tensor) -> torch.Tensor:
    """Pack a flow vector and its packets into one token sequence.

    Token 0 carries the flow features, tokens 1..T carry the packets.
    Channel 0 of every token is a type tag (-1 for the flow token, +1 for
    packet tokens); features start at channel 1 and unused trailing channels
    stay zero.

    Args:
        flow: Flow features; last dimension must equal ``cfg.flow_dim``.
        packets: Packet features of shape (batch, ``cfg.T``, ``cfg.packet_dim``).

    Returns:
        Tensor of shape (batch, T + 1, ``self.token_dim``) created with the
        dtype and device of ``packets``.

    Raises:
        ValueError: If any of the three sizes disagrees with the config.
    """
    B, T, Dp = packets.shape
    if T != self.cfg.T:
        raise ValueError(f'packet T={T} but config T={self.cfg.T}')
    if Dp != self.cfg.packet_dim:
        raise ValueError(f'packet_dim={Dp} but config packet_dim={self.cfg.packet_dim}')
    if flow.shape[-1] != self.cfg.flow_dim:
        raise ValueError(f'flow_dim={flow.shape[-1]} but config flow_dim={self.cfg.flow_dim}')

    flow_end = 1 + self.cfg.flow_dim
    packet_end = 1 + self.cfg.packet_dim
    tokens = packets.new_zeros((B, T + 1, self.token_dim))

    # Both names are views into ``tokens``; writes land in the result.
    flow_token = tokens[:, 0]
    packet_tokens = tokens[:, 1:]

    flow_token[:, 0] = -1.0
    flow_token[:, 1:flow_end] = flow
    packet_tokens[:, :, 0] = 1.0
    packet_tokens[:, :, 1:packet_end] = packets
    return tokens
|
||||
|
||||
def key_padding_mask(self, lens: torch.Tensor) -> torch.Tensor:
    """Return the boolean padding mask for a batch of packet counts.

    Args:
        lens: Number of real packets per sample.

    Returns:
        Boolean tensor of shape (batch, ``cfg.T`` + 1) in which ``True``
        marks padding.  Column 0 (the flow token) is never padding; packet
        slot ``i`` is padding when ``i >= lens``.
    """
    batch = lens.shape[0]
    slots = torch.arange(self.cfg.T, device=lens.device).unsqueeze(0)
    packet_is_real = slots < lens.unsqueeze(1)
    flow_is_real = torch.ones(batch, 1, dtype=torch.bool, device=lens.device)
    is_real = torch.cat((flow_is_real, packet_is_real), dim=1)
    return torch.logical_not(is_real)
|
||||
|
||||
def _loss_mask(self, lens: torch.Tensor) -> torch.Tensor:
    """Return a float32 mask that is 1.0 on real tokens and 0.0 on padding.

    This is the logical complement of ``key_padding_mask(lens)``.
    """
    padding = self.key_padding_mask(lens)
    return torch.logical_not(padding).float()
|
||||
|
||||
@staticmethod
def _masked_trimmed_mean(values: torch.Tensor, mask: torch.Tensor, trim_frac: float = 0.1) -> torch.Tensor:
    """Row-wise trimmed mean over the entries selected by ``mask``.

    For each row the entries with ``mask > 0`` are collected.  Rows with no
    selected entry yield 0; rows with fewer than five yield the plain mean.
    Otherwise the entries are sorted and the slice
    ``[int(trim_frac * n) : int((1 - trim_frac) * n)]`` is averaged, falling
    back to the mean of all entries if that slice is empty.

    Note that the two slice bounds are truncated independently, so the
    number of values dropped at the low and high end can differ (for
    example ``n = 5`` with ``trim_frac = 0.1`` drops only the largest value).

    Args:
        values: 2-D tensor, one row per sample.
        mask: Same layout as ``values``; entries > 0 are kept.
        trim_frac: Fraction to cut from each end of the sorted values.

    Returns:
        1-D tensor with one value per row, created via ``values.new_zeros``.
    """
    result = values.new_zeros(values.shape[0])
    for row, row_values in enumerate(values):
        selected = row_values[mask[row] > 0]
        count = selected.numel()
        if count == 0:
            continue
        if count >= 5:
            ordered, _ = torch.sort(selected)
            lo = int(trim_frac * count)
            hi = int((1.0 - trim_frac) * count)
            # An empty window falls back to averaging every sorted value.
            selected = ordered[lo:hi] if hi > lo else ordered
        result[row] = selected.mean()
    return result
|
||||
|
||||
@staticmethod
def _masked_median(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Row-wise median over the entries selected by ``mask``.

    Entries with ``mask > 0`` are kept.  An even number of kept entries
    yields the average of the two middle values; a row without any kept
    entry yields 0.

    Args:
        values: 2-D tensor, one row per sample.
        mask: Same layout as ``values``; entries > 0 are kept.

    Returns:
        1-D tensor with one value per row, created via ``values.new_zeros``.
    """
    result = values.new_zeros(values.shape[0])
    for row, row_values in enumerate(values):
        selected = row_values[mask[row] > 0]
        count = selected.numel()
        if count == 0:
            continue
        ordered = selected.sort().values
        upper = count // 2
        if count % 2 == 1:
            result[row] = ordered[upper]
        else:
            result[row] = 0.5 * (ordered[upper - 1] + ordered[upper])
    return result
|
||||
|
||||
def compute_loss(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, *, lambda_flow: float=0.0, lambda_packet: float=0.0, packet_mask_ratio: float=0.5, return_components: bool=False) -> torch.Tensor | dict[str, torch.Tensor]:
|
||||
x1 = self.build_tokens(flow, packets)
|
||||
B = x1.shape[0]
|
||||
x0 = torch.randn_like(x1)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
if self.cfg.use_ot:
|
||||
flat0 = (x0 * mask[:, :, None]).reshape(B, -1)
|
||||
flat1 = (x1 * mask[:, :, None]).reshape(B, -1)
|
||||
col = _sinkhorn_coupling(torch.cdist(flat0.float(), flat1.float()))
|
||||
x1 = x1[col]
|
||||
flow = flow[col]
|
||||
packets = packets[col]
|
||||
lens = lens[col]
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
t = torch.rand(B, device=x1.device)
|
||||
x_t = (1.0 - t[:, None, None]) * x0 + t[:, None, None] * x1
|
||||
if self.cfg.sigma > 0:
|
||||
std = self.cfg.sigma * torch.sqrt(t * (1.0 - t))[:, None, None]
|
||||
x_t = x_t + std * torch.randn_like(x_t)
|
||||
target = x1 - x0
|
||||
pred = self.velocity(x_t, t, key_padding_mask=kpm)
|
||||
sq = (pred - target).square().mean(dim=-1)
|
||||
per_sample = (sq * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
main_loss = per_sample.mean()
|
||||
aux_flow_loss = x1.new_zeros(())
|
||||
aux_packet_loss = x1.new_zeros(())
|
||||
if lambda_flow > 0.0:
|
||||
x_t_mf = x_t.clone()
|
||||
x_t_mf[:, 0, :] = 0.0
|
||||
pred_mf = self.velocity(x_t_mf, t, key_padding_mask=kpm)
|
||||
err = (pred_mf[:, 0] - target[:, 0]).square().mean(dim=-1)
|
||||
aux_flow_loss = err.mean()
|
||||
if lambda_packet > 0.0:
|
||||
packet_real = mask[:, 1:] > 0
|
||||
rand_draw = torch.rand(packet_real.shape, device=x1.device)
|
||||
mask_pkt = (rand_draw < packet_mask_ratio) & packet_real
|
||||
pkt_mask_full = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x1.device), mask_pkt], dim=1)
|
||||
x_t_mp = x_t.clone()
|
||||
x_t_mp[pkt_mask_full] = 0.0
|
||||
pred_mp = self.velocity(x_t_mp, t, key_padding_mask=kpm)
|
||||
sq_mp = (pred_mp - target).square().mean(dim=-1)
|
||||
mask_f = pkt_mask_full.float()
|
||||
denom = mask_f.sum(dim=-1).clamp_min(1.0)
|
||||
aux_packet_loss = ((sq_mp * mask_f).sum(dim=-1) / denom).mean()
|
||||
total = main_loss + lambda_flow * aux_flow_loss + lambda_packet * aux_packet_loss
|
||||
if return_components:
|
||||
return {'total': total, 'main': main_loss.detach(), 'aux_flow': aux_flow_loss.detach(), 'aux_packet': aux_packet_loss.detach()}
|
||||
return total
|
||||
|
||||
@torch.no_grad()
def velocity_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...] = (0.5, 0.75, 1.0)) -> dict[str, torch.Tensor]:
    """Mean squared predicted velocity at the data point, averaged over times.

    The network is evaluated on the clean tokens at every ``t`` in
    ``t_eval``; the per-token energy is the channel-mean of ``v ** 2``.

    Args:
        flow: Flow features, see ``build_tokens``.
        packets: Packet features, see ``build_tokens``.
        lens: Real packet count per sample.
        t_eval: Times at which the velocity is evaluated.

    Returns:
        Dict of per-sample tensors: ``'velocity_total'`` (mean over real
        tokens), ``'velocity_flow'`` (token 0) and ``'velocity_packet'``
        (mean over real packets), each averaged over ``t_eval``.
    """
    x = self.build_tokens(flow, packets)
    mask = self._loss_mask(lens)
    kpm = mask == 0
    n_samples = x.shape[0]

    packet_mask = mask[:, 1:]
    packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
    token_count = mask.sum(dim=-1).clamp_min(1.0)

    total = torch.zeros(n_samples, device=x.device)
    flow_s = torch.zeros_like(total)
    packet_s = torch.zeros_like(total)
    for t_val in t_eval:
        t = torch.full((n_samples,), float(t_val), device=x.device)
        energy = self.velocity(x, t, key_padding_mask=kpm).square().mean(dim=-1)
        total = total + (energy * mask).sum(dim=-1) / token_count
        flow_s = flow_s + energy[:, 0]
        packet_s = packet_s + (energy[:, 1:] * packet_mask).sum(dim=-1) / packet_count

    n_times = float(len(t_eval))
    return {
        'velocity_total': total / n_times,
        'velocity_flow': flow_s / n_times,
        'velocity_packet': packet_s / n_times,
    }
|
||||
|
||||
@torch.no_grad()
def trajectory_metrics(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int = 16) -> dict[str, torch.Tensor]:
    """Integrate the learned field backwards from the data and summarise the path.

    Starting from the clean tokens at ``t = 1``, ``n_steps`` explicit Euler
    steps ``z <- z - v * dt`` are taken at times ``1, 1 - dt, ...``.  Along
    the way the method accumulates, per sample:

    * kinetic energy: channel-mean of ``v ** 2`` times ``dt``;
    * curvature: channel-mean of the squared velocity change between
      consecutive steps;
    * arc length: norm of the masked step divided by ``sqrt(#real tokens)``;
    * per-packet squared velocity change divided by the squared average
      speed of the two steps (floored at 1e-6), later reduced by mean,
      median and trimmed mean over the real packets.

    After the last step the RMS magnitude of the end point is reported.

    Args:
        flow: Flow features, see ``build_tokens``.
        packets: Packet features, see ``build_tokens``.
        lens: Real packet count per sample.
        n_steps: Number of Euler steps (``dt = 1 / n_steps``).

    Returns:
        Dict of per-sample tensors keyed by metric name.
    """
    z = self.build_tokens(flow, packets)
    mask = self._loss_mask(lens)
    kpm = mask == 0
    B = z.shape[0]
    dt = 1.0 / n_steps
    dev = z.device

    packet_mask = mask[:, 1:]
    packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
    token_count = mask.sum(dim=-1).clamp_min(1.0)

    total_arc, total_ke, flow_ke, packet_ke, total_curv, flow_curv, packet_curv = (
        torch.zeros(B, device=dev) for _ in range(7)
    )
    packet_kappa2_speed2 = torch.zeros(B, max(z.shape[1] - 1, 0), device=dev)
    v_prev = None
    v_prev_norm = None

    for k in range(n_steps):
        t = torch.full((B,), 1.0 - k * dt, device=dev)
        v = self.velocity(z, t, key_padding_mask=kpm)
        v_sq = v.square()
        e = v_sq.mean(dim=-1)
        v_norm = v_sq.sum(dim=-1).clamp_min(1e-12).sqrt()

        total_ke = total_ke + (e * mask).sum(dim=-1) / token_count * dt
        flow_ke = flow_ke + e[:, 0] * dt
        packet_ke = packet_ke + (e[:, 1:] * packet_mask).sum(dim=-1) / packet_count * dt

        # Difference-based quantities need a previous step.
        if k > 0:
            dv_sq = (v - v_prev).square()
            dve = dv_sq.mean(dim=-1)
            total_curv = total_curv + (dve * mask).sum(dim=-1) / token_count
            flow_curv = flow_curv + dve[:, 0]
            packet_curv = packet_curv + (dve[:, 1:] * packet_mask).sum(dim=-1) / packet_count

            dv2_sum = dv_sq[:, 1:].sum(dim=-1)
            v_avg = 0.5 * (v_norm[:, 1:] + v_prev_norm[:, 1:])
            packet_kappa2_speed2 = packet_kappa2_speed2 + dv2_sum / v_avg.square().clamp_min(1e-06)
        v_prev, v_prev_norm = v, v_norm

        z_next = z - v * dt
        step = (z_next - z) * mask[:, :, None]
        total_arc = total_arc + step.reshape(B, -1).norm(dim=-1) / mask.sum(dim=-1).sqrt()
        z = z_next

    # End-point magnitudes, normalised by the number of real entries.
    z_real = z * mask[:, :, None]
    terminal = z_real.reshape(B, -1).norm(dim=-1) / (mask.sum(dim=-1) * self.token_dim).clamp_min(1.0).sqrt()
    terminal_flow = z[:, 0].norm(dim=-1) / math.sqrt(self.token_dim)
    packets_real = z[:, 1:] * mask[:, 1:, None]
    terminal_packet = packets_real.reshape(B, -1).norm(dim=-1) / (packet_count * self.token_dim).sqrt()

    kappa2_speed2_mean = (packet_kappa2_speed2 * packet_mask).sum(dim=-1) / packet_count
    kappa2_speed2_median = self._masked_median(packet_kappa2_speed2, packet_mask)
    kappa2_speed2_trimmed = self._masked_trimmed_mean(packet_kappa2_speed2, packet_mask)

    return {
        'terminal_norm': terminal,
        'terminal_flow': terminal_flow,
        'terminal_packet': terminal_packet,
        'arc_length': total_arc,
        'kinetic_energy': total_ke,
        'kinetic_flow': flow_ke,
        'kinetic_packet': packet_ke,
        'curvature_total': total_curv,
        'curvature_flow': flow_curv,
        'curvature_packet': packet_curv,
        'kappa2_speed2norm_packet_mean': kappa2_speed2_mean,
        'kappa2_speed2norm_packet_median': kappa2_speed2_median,
        'kappa2_speed2norm_packet_trimmed10_mean': kappa2_speed2_trimmed,
    }
|
||||
|
||||
@torch.no_grad()
def score_profile_vt(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...] = (0.1, 0.3, 0.5, 0.7, 0.9, 1.0)) -> dict[str, torch.Tensor]:
    """Mean squared predicted velocity at the data point, reported per time.

    Same quantities as ``velocity_score`` but without averaging over time:
    each ``t`` in ``t_eval`` contributes three entries whose key ends in
    ``t`` followed by ``round(10 * t)`` as two digits (``0.3 -> 't03'``,
    ``1.0 -> 't10'``).

    Args:
        flow: Flow features, see ``build_tokens``.
        packets: Packet features, see ``build_tokens``.
        lens: Real packet count per sample.
        t_eval: Times to evaluate.  They must map to distinct key suffixes,
            i.e. differ after rounding to one decimal.

    Returns:
        Dict with keys ``velocity_total_<tag>``, ``velocity_flow_<tag>`` and
        ``velocity_packet_<tag>`` for every time, each a per-sample tensor.

    Raises:
        ValueError: If two times share a suffix; previously the later one
            silently overwrote the earlier one.
    """
    # Resolve and validate the key suffixes before spending any forward pass.
    tags = [f't{int(round(t_val * 10)):02d}' for t_val in t_eval]
    if len(set(tags)) != len(tags):
        raise ValueError(
            f'score_profile_vt: t_eval={tuple(t_eval)!r} contains values that round to the same '
            f'key suffix {tags!r}; results would overwrite each other'
        )

    x = self.build_tokens(flow, packets)
    mask = self._loss_mask(lens)
    kpm = mask == 0
    packet_mask = mask[:, 1:]
    packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
    token_count = mask.sum(dim=-1).clamp_min(1.0)

    out: dict[str, torch.Tensor] = {}
    for t_val, tag in zip(t_eval, tags):
        t = torch.full((x.shape[0],), float(t_val), device=x.device)
        v = self.velocity(x, t, key_padding_mask=kpm)
        e = v.square().mean(dim=-1)
        out[f'velocity_total_{tag}'] = (e * mask).sum(dim=-1) / token_count
        out[f'velocity_flow_{tag}'] = e[:, 0]
        out[f'velocity_packet_{tag}'] = (e[:, 1:] * packet_mask).sum(dim=-1) / packet_count
    return out
|
||||
|
||||
@torch.no_grad()
def consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float = 0.5) -> dict[str, torch.Tensor]:
    """Measure how much the prediction changes when part of the input is hidden.

    Three forward passes are made at time ``t_eval``: on the full tokens,
    with the flow token zeroed, and with every real packet token zeroed.

    Args:
        flow: Flow features, see ``build_tokens``.
        packets: Packet features, see ``build_tokens``.
        lens: Real packet count per sample.
        t_eval: Time at which the network is evaluated.

    Returns:
        Dict of per-sample tensors: ``'flow_consistency'`` (channel-mean
        squared change of the flow-token prediction when the flow token is
        hidden), ``'packet_consistency'`` (the same for packet predictions
        when the packets are hidden, averaged over real packets) and
        ``'consistency_total'`` (their sum).
    """
    x = self.build_tokens(flow, packets)
    mask = self._loss_mask(lens)
    kpm = mask == 0
    B = x.shape[0]
    t = torch.full((B,), float(t_eval), device=x.device)
    packet_mask = mask[:, 1:]
    packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)

    v_full = self.velocity(x, t, key_padding_mask=kpm)

    # Hide the flow token.
    without_flow = x.clone()
    without_flow[:, 0, :] = 0.0
    v_mf = self.velocity(without_flow, t, key_padding_mask=kpm)
    flow_cons = (v_full[:, 0] - v_mf[:, 0]).square().mean(dim=-1)

    # Hide every real packet token (never the flow token).
    hide = mask > 0
    hide[:, 0] = False
    without_packets = x.masked_fill(hide[:, :, None], 0.0)
    v_mp = self.velocity(without_packets, t, key_padding_mask=kpm)
    change = (v_full - v_mp).square().mean(dim=-1)
    packet_cons = (change[:, 1:] * packet_mask).sum(dim=-1) / packet_count

    return {
        'flow_consistency': flow_cons,
        'packet_consistency': packet_cons,
        'consistency_total': flow_cons + packet_cons,
    }
|
||||
|
||||
def jacobian_hutchinson(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.5,), n_eps: int=4, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]:
|
||||
x = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
B = x.shape[0]
|
||||
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
|
||||
total = torch.zeros(B, device=x.device)
|
||||
flow_j = torch.zeros(B, device=x.device)
|
||||
packet_j = torch.zeros(B, device=x.device)
|
||||
n_draws = n_eps * len(t_eval)
|
||||
for t_val in t_eval:
|
||||
t_current = torch.full((B,), float(t_val), device=x.device)
|
||||
for _ in range(n_eps):
|
||||
x_req = x.detach().clone().requires_grad_(True)
|
||||
v = self.velocity(x_req, t_current, key_padding_mask=kpm)
|
||||
eps = torch.randn(v.shape, device=v.device, generator=generator)
|
||||
(g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False)
|
||||
e = g.square().mean(dim=-1)
|
||||
total = total + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
flow_j = flow_j + e[:, 0]
|
||||
packet_j = packet_j + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
|
||||
return {'jacobian_total': (total / n_draws).detach(), 'jacobian_flow': (flow_j / n_draws).detach(), 'jacobian_packet': (packet_j / n_draws).detach()}
|
||||
|
||||
@torch.no_grad()
def pna_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int = 16, flow_masked: bool = False) -> dict[str, torch.Tensor]:
    """Per-token normalised velocity change along the backward trajectory.

    Starting from the clean tokens at ``t = 1``, ``n_steps`` explicit Euler
    steps ``z <- z - v * dt`` are taken.  For every pair of consecutive
    steps the squared change of a token's velocity is divided by the squared
    average of the two speeds (floored at 1e-6) and summed over the path.

    Args:
        flow: Flow features, see ``build_tokens``.
        packets: Packet features, see ``build_tokens``.
        lens: Real packet count per sample.
        n_steps: Number of Euler steps (``dt = 1 / n_steps``).
        flow_masked: If true, the flow token is zeroed before the first step
            and again after every step.

    Returns:
        Dict of per-sample tensors: median, mean, max and trimmed mean of
        the score over the real packets, plus ``'pna_flow'`` for token 0.
        A sample without any real packet gets 0 for all four packet
        statistics (the max used to be ``-inf`` in that case).
    """
    eps_v2 = 1e-06
    dt = 1.0 / n_steps
    z = self.build_tokens(flow, packets)
    if flow_masked:
        z = z.clone()
        z[:, 0, :] = 0.0
    mask = self._loss_mask(lens)
    kpm = mask == 0
    B, L, _ = z.shape

    pna = torch.zeros(B, L, device=z.device)
    v_prev: torch.Tensor | None = None
    v_norm_prev: torch.Tensor | None = None
    for k in range(n_steps):
        t = torch.full((B,), 1.0 - k * dt, device=z.device)
        v = self.velocity(z, t, key_padding_mask=kpm)
        v_norm = (v.square().sum(dim=-1) + 1e-12).sqrt()
        if v_prev is not None:
            dv2 = (v - v_prev).square().sum(dim=-1)
            v_avg2 = (0.5 * (v_norm + v_norm_prev)).square().clamp_min(eps_v2)
            pna = pna + dv2 / v_avg2
        v_prev = v
        v_norm_prev = v_norm
        z = z - v * dt
        if flow_masked:
            # Keep the flow token hidden for the next evaluation.
            z[:, 0, :] = 0.0

    flow_pna = pna[:, 0]
    packet_pna = pna[:, 1:]
    packet_mask = mask[:, 1:]
    n_real = packet_mask.sum(dim=-1)
    packet_count = n_real.clamp_min(1.0)

    pna_median = self._masked_median(packet_pna, packet_mask)
    pna_mean = (packet_pna * packet_mask).sum(dim=-1) / packet_count
    masked_for_max = packet_pna.masked_fill(packet_mask == 0, float('-inf'))
    pna_max = masked_for_max.max(dim=-1).values
    # With no real packet the max over an all -inf row is -inf; report 0 so
    # it agrees with the other statistics and stays finite downstream.
    pna_max = torch.where(n_real > 0, pna_max, torch.zeros_like(pna_max))
    pna_trimmed = self._masked_trimmed_mean(packet_pna, packet_mask)

    return {
        'pna_packet_median': pna_median,
        'pna_packet_mean': pna_mean,
        'pna_packet_max': pna_max,
        'pna_packet_trimmed10_mean': pna_trimmed,
        'pna_flow': flow_pna,
    }
|
||||
|
||||
@torch.no_grad()
def causal_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float = 0.5) -> dict[str, torch.Tensor]:
    """Compare the prediction under the model's own mask with a causal one.

    The velocity is evaluated twice at time ``t_eval``: once with the
    attention mask implied by the model's ``reference_mode`` and once with a
    strictly causal mask (position ``i`` may attend to positions ``0..i``)
    passed as ``attn_mask_override``.  The per-token score is the
    channel-mean squared difference between the two predictions.

    Args:
        flow: Flow features, see ``build_tokens``.
        packets: Packet features, see ``build_tokens``.
        lens: Real packet count per sample.
        t_eval: Time at which the network is evaluated.

    Returns:
        Dict of per-sample tensors: the score averaged over real tokens, at
        the flow token, and its mean / median / max / trimmed mean over the
        real packets.  A sample without any real packet gets 0 for all four
        packet statistics (the max used to be ``-inf`` in that case).
    """
    x = self.build_tokens(flow, packets)
    mask = self._loss_mask(lens)
    kpm = mask == 0
    B, L, _ = x.shape
    t = torch.full((B,), float(t_eval), device=x.device)

    v_full = self.velocity(x, t, key_padding_mask=kpm)
    # True above the diagonal = future positions are blocked.
    causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), diagonal=1)
    v_causal = self.velocity(x, t, key_padding_mask=kpm, attn_mask_override=causal)
    diff = (v_full - v_causal).square().mean(dim=-1)

    flow_surprisal = diff[:, 0]
    packet_diff = diff[:, 1:]
    packet_mask = mask[:, 1:]
    n_real = packet_mask.sum(dim=-1)
    packet_count = n_real.clamp_min(1.0)

    packet_mean = (packet_diff * packet_mask).sum(dim=-1) / packet_count
    packet_median = self._masked_median(packet_diff, packet_mask)
    masked_for_max = packet_diff.masked_fill(packet_mask == 0, float('-inf'))
    packet_max = masked_for_max.max(dim=-1).values
    # With no real packet the max over an all -inf row is -inf; report 0 so
    # it agrees with the other statistics and stays finite downstream.
    packet_max = torch.where(n_real > 0, packet_max, torch.zeros_like(packet_max))
    packet_trimmed = self._masked_trimmed_mean(packet_diff, packet_mask)
    total = (diff * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)

    return {
        'causal_surprisal_total': total,
        'causal_surprisal_flow': flow_surprisal,
        'causal_surprisal_packet_mean': packet_mean,
        'causal_surprisal_packet_median': packet_median,
        'causal_surprisal_packet_max': packet_max,
        'causal_surprisal_packet_trimmed10_mean': packet_trimmed,
    }
|
||||
|
||||
@torch.no_grad()
def direction_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...] = (0.2, 0.4, 0.6, 0.8, 1.0)) -> dict[str, torch.Tensor]:
    """Average change of direction of the predicted velocity across times.

    The velocity is evaluated on the clean tokens at every ``t`` in
    ``t_eval``.  For each pair of consecutive times the per-token score is
    ``1 - cos(v_prev, v)`` (norms floored at 1e-8); scores are averaged over
    the ``len(t_eval) - 1`` pairs.

    Args:
        flow: Flow features, see ``build_tokens``.
        packets: Packet features, see ``build_tokens``.
        lens: Real packet count per sample.
        t_eval: At least two evaluation times, compared in the given order.

    Returns:
        Dict of per-sample tensors: the drift averaged over real tokens, at
        the flow token, and its mean / median / max / trimmed mean over the
        real packets.  A sample without any real packet gets 0 for all four
        packet statistics (the max used to be ``-inf`` in that case).

    Raises:
        ValueError: If fewer than two times are given.
    """
    x = self.build_tokens(flow, packets)
    mask = self._loss_mask(lens)
    kpm = mask == 0
    B, L, _ = x.shape
    t_eval = tuple(t_eval)
    if len(t_eval) < 2:
        raise ValueError('direction_consistency_score needs >=2 t values')

    prev_v: torch.Tensor | None = None
    drift = x.new_zeros(B, L)
    n_pairs = len(t_eval) - 1
    for t_val in t_eval:
        t = torch.full((B,), float(t_val), device=x.device)
        v = self.velocity(x, t, key_padding_mask=kpm)
        if prev_v is not None:
            num = (prev_v * v).sum(dim=-1)
            denom = prev_v.norm(dim=-1).clamp_min(1e-08) * v.norm(dim=-1).clamp_min(1e-08)
            drift = drift + (1.0 - num / denom)
        prev_v = v
    drift = drift / max(n_pairs, 1)

    flow_drift = drift[:, 0]
    packet_drift = drift[:, 1:]
    packet_mask = mask[:, 1:]
    n_real = packet_mask.sum(dim=-1)
    packet_count = n_real.clamp_min(1.0)

    packet_mean = (packet_drift * packet_mask).sum(dim=-1) / packet_count
    packet_median = self._masked_median(packet_drift, packet_mask)
    masked_for_max = packet_drift.masked_fill(packet_mask == 0, float('-inf'))
    packet_max = masked_for_max.max(dim=-1).values
    # With no real packet the max over an all -inf row is -inf; report 0 so
    # it agrees with the other statistics and stays finite downstream.
    packet_max = torch.where(n_real > 0, packet_max, torch.zeros_like(packet_max))
    packet_trimmed = self._masked_trimmed_mean(packet_drift, packet_mask)
    total = (drift * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)

    return {
        'direction_drift_total': total,
        'direction_drift_flow': flow_drift,
        'direction_drift_packet_mean': packet_mean,
        'direction_drift_packet_median': packet_median,
        'direction_drift_packet_max': packet_max,
        'direction_drift_packet_trimmed10_mean': packet_trimmed,
    }
|
||||
|
||||
def inverse_flow_nll_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16, n_eps: int=4, compute_divergence: bool=True, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]:
|
||||
z = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
(B, L, D) = z.shape
|
||||
dt = 1.0 / n_steps
|
||||
accum_div = torch.zeros(B, device=z.device)
|
||||
if compute_divergence:
|
||||
for k in range(n_steps):
|
||||
t_val = 1.0 - k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
z_req = z.detach().clone().requires_grad_(True)
|
||||
v = self.velocity(z_req, t, key_padding_mask=kpm)
|
||||
div_step = torch.zeros(B, device=z.device)
|
||||
for j in range(n_eps):
|
||||
eps = torch.randn_like(v)
|
||||
eps_masked = eps * mask[:, :, None]
|
||||
retain = j < n_eps - 1
|
||||
(g,) = torch.autograd.grad(outputs=v, inputs=z_req, grad_outputs=eps_masked, retain_graph=retain, create_graph=False)
|
||||
div_step = div_step + (eps_masked * g).sum(dim=(1, 2))
|
||||
div_step = div_step / float(n_eps)
|
||||
accum_div = accum_div + div_step * dt
|
||||
with torch.no_grad():
|
||||
z = (z_req - v * dt).detach()
|
||||
else:
|
||||
with torch.no_grad():
|
||||
for k in range(n_steps):
|
||||
t_val = 1.0 - k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
v = self.velocity(z, t, key_padding_mask=kpm)
|
||||
z = z - v * dt
|
||||
with torch.no_grad():
|
||||
z_masked = z * mask[:, :, None]
|
||||
n_real = mask.sum(dim=-1).clamp_min(1.0)
|
||||
x0_quadratic = z_masked.reshape(B, -1).square().sum(dim=-1) / (n_real * float(D))
|
||||
nll_x0_only = x0_quadratic
|
||||
nll_div_only = accum_div / (n_real * float(D))
|
||||
nll_full = nll_x0_only + nll_div_only
|
||||
return {'nll_x0_only': nll_x0_only.detach(), 'nll_div_only': nll_div_only.detach(), 'nll_full': nll_full.detach()}
|
||||
|
||||
def jacobian_spectral_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5, n_eps: int=4, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]:
|
||||
x = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
(B, L, D) = x.shape
|
||||
t = torch.full((B,), float(t_eval), device=x.device)
|
||||
packet_mask = mask[:, 1:]
|
||||
packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
|
||||
norms_total: list[torch.Tensor] = []
|
||||
norms_flow: list[torch.Tensor] = []
|
||||
norms_packet: list[torch.Tensor] = []
|
||||
for _ in range(n_eps):
|
||||
x_req = x.detach().clone().requires_grad_(True)
|
||||
v = self.velocity(x_req, t, key_padding_mask=kpm)
|
||||
eps = torch.randn(v.shape, device=v.device, generator=generator)
|
||||
(g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False)
|
||||
e = g.square().mean(dim=-1)
|
||||
n_total = (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
n_flow = e[:, 0]
|
||||
n_packet = (e[:, 1:] * packet_mask).sum(dim=-1) / packet_count
|
||||
norms_total.append(n_total.detach())
|
||||
norms_flow.append(n_flow.detach())
|
||||
norms_packet.append(n_packet.detach())
|
||||
|
||||
def _spectral_summary(samples: list[torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
stack = torch.stack(samples, dim=1)
|
||||
mean = stack.mean(dim=1).clamp_min(1e-12)
|
||||
mx = stack.max(dim=1).values
|
||||
mn = stack.min(dim=1).values
|
||||
logfro = torch.log(mean)
|
||||
aniso = mx / mean
|
||||
min_over_max = mn / mx.clamp_min(1e-12)
|
||||
p = stack / stack.sum(dim=1, keepdim=True).clamp_min(1e-12)
|
||||
entropy = -(p * p.clamp_min(1e-12).log()).sum(dim=1)
|
||||
eff_rank = torch.exp(entropy)
|
||||
return {'logfro': logfro, 'anisotropy': aniso, 'min_over_max': min_over_max, 'eff_rank': eff_rank}
|
||||
out: dict[str, torch.Tensor] = {}
|
||||
for (tag, samples) in (('total', norms_total), ('flow', norms_flow), ('packet', norms_packet)):
|
||||
summ = _spectral_summary(samples)
|
||||
for (stat_name, val) in summ.items():
|
||||
out[f'jac_{stat_name}_{tag}'] = val
|
||||
return out
|
||||
|
||||
@torch.no_grad()
def sample(self, n: int, lens: torch.Tensor, device: torch.device, n_steps: int = 50, method: str = 'euler') -> torch.Tensor:
    """Generate token sequences by integrating the velocity from noise.

    Standard-normal noise of shape (``n``, ``seq_len``, ``token_dim``) is
    transported from ``t = 0`` to ``t = 1`` on a uniform grid of
    ``n_steps + 1`` time points.

    Args:
        n: Number of sequences to draw.
        lens: Real packet count per generated sequence; used to build the
            key-padding mask (moved to ``device`` first).
        device: Device on which to sample.
        n_steps: Number of integration intervals.
        method: ``'euler'`` uses the explicit loop below; any other value is
            passed to ``odeint`` as its ``method`` on the same time grid.

    Returns:
        The tokens at ``t = 1``, shape (``n``, ``seq_len``, ``token_dim``).
    """
    z = torch.randn(n, self.seq_len, self.token_dim, device=device)
    ts = torch.linspace(0.0, 1.0, n_steps + 1, device=device)
    kpm = self.key_padding_mask(lens.to(device))

    def f(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """ODE right-hand side: the scalar time is broadcast to the batch."""
        return self.velocity(x, t.expand(x.shape[0]), key_padding_mask=kpm)

    if method != 'euler':
        # odeint returns the state at every grid point; keep the last one.
        return odeint(f, z, ts, method=method)[-1]

    for t_now, t_next in zip(ts[:-1], ts[1:]):
        z = z + f(t_now, z) * (t_next - t_now)
    return z
|
||||
|
||||
def param_count(self) -> int:
    """Return the total number of scalar entries over all parameters."""
    n_params = 0
    for param in self.parameters():
        n_params += param.numel()
    return n_params
|
||||
Reference in New Issue
Block a user