245 lines
12 KiB
Python
245 lines
12 KiB
Python
from __future__ import annotations
|
|
import math
|
|
from dataclasses import dataclass, field
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import importlib.util as _ilu
|
|
import sys as _sys
|
|
from pathlib import Path as _Path
|
|
# Key under which the sibling Unified_CFM model module is cached in sys.modules,
# so every importer in the process shares one instance of it.
_UNIFIED_NAME = 'unified_cfm_model'


def _load_unified_module(name: str):
    """Return the shared ``Unified_CFM/model.py`` module, loading it on first use.

    The file is located in the ``Unified_CFM`` directory that sits next to this
    file's parent directory, imported by path (not as a package member), and
    cached in ``sys.modules`` under *name*.

    Raises:
        ImportError: if no import spec / loader can be built for the file.
        Exception: anything raised while executing the module.  In that case the
            half-initialised entry is removed from ``sys.modules`` so a later
            import retries instead of silently reusing a broken module.
    """
    if name in _sys.modules:
        return _sys.modules[name]
    path = _Path(__file__).resolve().parents[1] / 'Unified_CFM' / 'model.py'
    spec = _ilu.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f'cannot build an import spec for {path}')
    module = _ilu.module_from_spec(spec)
    # Register before executing, as the regular import system does, so imports of
    # `name` performed while model.py runs resolve to this same module object.
    _sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        _sys.modules.pop(name, None)
        raise
    return module


_unified = _load_unified_module(_UNIFIED_NAME)

# Building blocks shared with the unified model.
AdaLNBlock = _unified.AdaLNBlock
SinusoidalTimeEmb = _unified.SinusoidalTimeEmb
_sinkhorn_coupling = _unified._sinkhorn_coupling
|
|
|
|
@dataclass
class MixedCFMConfig:
    """Hyper-parameters for ``MixedTokenCFM`` and its velocity network.

    ``__post_init__`` only checks that the index tuples agree with the channel
    counts and that the discrete path is supported; ``reference_mode`` is
    validated when the velocity network is constructed.
    """

    # Number of packet positions per sample (the model adds one flow token in front).
    T: int = 64
    # Width of the flow-level feature vector written onto the flow token.
    flow_dim: int = 20
    # Continuous / discrete feature channels carried by each packet token.
    n_cont_pkt: int = 3
    n_disc_pkt: int = 6
    # Source feature indices of those channels; lengths must equal the counts above.
    # disc_pkt_idx also names the per-channel outputs of MixedTokenCFM.disc_nll_score.
    cont_pkt_idx: tuple[int, ...] = (0, 1, 8)
    disc_pkt_idx: tuple[int, ...] = (2, 3, 4, 5, 6, 7)
    # Number of classes of every discrete channel.
    n_disc_classes: int = 2
    # Token width; None selects 1 + max(flow_dim, n_cont_pkt + n_disc_pkt).
    token_dim: int | None = None
    # Velocity-transformer size.
    d_model: int = 128
    n_layers: int = 4
    n_heads: int = 4
    mlp_ratio: float = 4.0
    time_dim: int = 64
    # Scale of the Gaussian perturbation sigma * sqrt(t * (1 - t)) added to x_t.
    sigma: float = 0.1
    # Re-pair noise and data inside a minibatch with a Sinkhorn coupling.
    use_ot: bool = False
    # Attention masking: None, 'causal_packets' or 'causal_all'.
    reference_mode: str | None = None
    # Weight of the discrete cross-entropy term in the total loss.
    lambda_disc: float = 1.0
    # Discrete corruption path; only 'uniform' is implemented.
    disc_path: str = 'uniform'
    # Multiplier in the discrete embedding (class_id - 0.5) * disc_embed_scale.
    disc_embed_scale: float = 1.0

    def __post_init__(self) -> None:
        """Reject inconsistent index tuples and unsupported discrete paths.

        Raises:
            ValueError: if an index tuple's length differs from its channel count
                (continuous checked first, then discrete).
            NotImplementedError: if ``disc_path`` is anything but ``'uniform'``.
        """
        length_checks = (
            (self.cont_pkt_idx, self.n_cont_pkt, 'cont_pkt_idx length mismatch n_cont_pkt'),
            (self.disc_pkt_idx, self.n_disc_pkt, 'disc_pkt_idx length mismatch n_disc_pkt'),
        )
        for indices, expected, message in length_checks:
            if len(indices) != expected:
                raise ValueError(message)
        if self.disc_path == 'uniform':
            return
        raise NotImplementedError(f'disc_path={self.disc_path}')
|
|
|
|
class MixedVelocity(nn.Module):
    """Transformer predicting a continuous velocity and discrete logits per token.

    Tokens of width ``token_dim`` are projected to ``d_model`` and offset by a
    learned positional embedding plus a two-way type embedding (type 0 at
    position 0, type 1 at every later position).  They are then passed through
    ``n_layers`` ``AdaLNBlock`` layers conditioned on the time embedding, and read
    out by two zero-initialised linear heads.
    """

    def __init__(self, token_dim: int, seq_len: int, n_disc: int, n_classes: int, d_model: int = 128, n_layers: int = 4, n_heads: int = 4, mlp_ratio: float = 4.0, time_dim: int = 64, reference_mode: str | None = None) -> None:
        """Build the network.

        Raises:
            ValueError: if ``reference_mode`` is not None, 'causal_packets' or
                'causal_all'.
        """
        super().__init__()
        allowed_modes = (None, 'causal_packets', 'causal_all')
        if reference_mode not in allowed_modes:
            raise ValueError(f'reference_mode={reference_mode!r}')
        self.token_dim, self.seq_len = token_dim, seq_len
        self.n_disc, self.n_classes = n_disc, n_classes
        self.reference_mode = reference_mode

        # Token embedding.  The construction / initialisation order below is kept
        # fixed on purpose: the initialisers draw from the global RNG, so
        # reordering them would change seeded initial weights.
        self.input_proj = nn.Linear(token_dim, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
        self.type_emb = nn.Embedding(2, d_model)
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        nn.init.normal_(self.type_emb.weight, std=0.02)

        # Time conditioning shared by every block.
        self.time_emb = SinusoidalTimeEmb(time_dim)
        self.cond_mlp = nn.Sequential(
            nn.Linear(time_dim, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.blocks = nn.ModuleList(
            AdaLNBlock(d_model, n_heads, mlp_ratio, cond_dim=d_model) for _ in range(n_layers)
        )

        # Read-out heads, zero-initialised so an untrained model outputs zero
        # velocity and all-zero (uniform) logits.
        self.out_norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.head_v = nn.Linear(d_model, token_dim)
        self.head_disc = nn.Linear(d_model, n_disc * n_classes)
        for param in (self.head_v.weight, self.head_v.bias, self.head_disc.weight, self.head_disc.bias):
            nn.init.zeros_(param)

        # Token-type ids: 0 for position 0, 1 elsewhere (non-persistent buffer).
        type_ids = torch.ones(seq_len, dtype=torch.long)
        type_ids[0] = 0
        self.register_buffer('type_ids', type_ids, persistent=False)

    def _attn_mask(self, L: int, device: torch.device) -> torch.Tensor | None:
        """Return the boolean ``(L, L)`` attention mask for ``reference_mode``, or None.

        Entries set to True are the ones masked out — presumably following the
        ``torch.nn.MultiheadAttention`` convention inside ``AdaLNBlock`` (confirm).
        'causal_all' masks every key after the query; 'causal_packets' does the
        same among positions >= 1 but leaves position 0 fully visible and
        fully attending.
        """
        mode = self.reference_mode
        if mode is None:
            return None
        causal = torch.ones(L, L, dtype=torch.bool, device=device).triu(diagonal=1)
        if mode == 'causal_packets':
            # Column 0 is already unmasked by the strict upper triangle; also let
            # row 0 see every position.  Slicing keeps L == 0 well-defined.
            causal[:1, :] = False
        return causal

    def forward(self, x: torch.Tensor, t: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Predict per-token velocity and discrete logits.

        Args:
            x: ``(B, L, token_dim)`` tokens with ``L <= seq_len``.
            t: time, either a 0-d tensor (broadcast to the batch) or one value
                per sample.
            key_padding_mask: forwarded unchanged to every block.

        Returns:
            ``(v, d)`` with ``v`` of shape ``(B, L, token_dim)`` and ``d`` of shape
            ``(B, L, n_disc, n_classes)``.
        """
        batch, length, _ = x.shape
        if t.dim() == 0:
            t = t.expand(batch)
        positions = self.pos_emb[:, :length, :]
        types = self.type_emb(self.type_ids[:length]).unsqueeze(0)
        hidden = self.input_proj(x) + positions + types

        cond = self.cond_mlp(self.time_emb(t))
        mask = self._attn_mask(length, x.device)
        for blk in self.blocks:
            hidden = blk(hidden, cond, key_padding_mask, attn_mask=mask)
        hidden = self.out_norm(hidden)

        velocity = self.head_v(hidden)
        logits = self.head_disc(hidden).view(batch, length, self.n_disc, self.n_classes)
        return (velocity, logits)
|
|
|
|
class MixedTokenCFM(nn.Module):
    """Flow matching over mixed continuous/discrete flow+packet token sequences.

    Each sample becomes ``T + 1`` tokens of width ``token_dim``.  Token 0 is the
    flow token ``[-1, flow, 0...]``.  Tokens ``1..T`` are packet tokens
    ``[+1, packets_cont, embed(packets_disc), 0...]``.  Continuous channels are
    trained by velocity regression on a linear noise->data path; the discrete
    channels of packet tokens use a uniform-corruption path with a cross-entropy
    head.
    """

    def __init__(self, cfg: MixedCFMConfig) -> None:
        """Build the velocity network for ``cfg``.

        Raises:
            ValueError: if ``cfg.token_dim`` is narrower than
                ``1 + max(flow_dim, n_cont_pkt + n_disc_pkt)``, or if
                ``MixedVelocity`` rejects ``cfg.reference_mode``.
        """
        super().__init__()
        self.cfg = cfg
        cont_size = cfg.n_cont_pkt + cfg.n_disc_pkt
        # Channel 0 is the token-type flag; the payload must fit the wider of the
        # flow vector and the packet (continuous + discrete) vector.
        min_token_dim = 1 + max(cfg.flow_dim, cont_size)
        # A falsy token_dim (None or 0) selects the minimum width.
        self.token_dim = cfg.token_dim or min_token_dim
        if self.token_dim < min_token_dim:
            raise ValueError('token_dim too small')
        # One flow token followed by T packet tokens.
        self.seq_len = cfg.T + 1
        self.velocity = MixedVelocity(
            token_dim=self.token_dim,
            seq_len=self.seq_len,
            n_disc=cfg.n_disc_pkt,
            n_classes=cfg.n_disc_classes,
            d_model=cfg.d_model,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            mlp_ratio=cfg.mlp_ratio,
            time_dim=cfg.time_dim,
            reference_mode=cfg.reference_mode,
        )

    def _disc_channels(self) -> slice:
        """Channel slice holding the discrete embedding -- meaningful on packet tokens only."""
        start = 1 + self.cfg.n_cont_pkt
        return slice(start, start + self.cfg.n_disc_pkt)

    def _embed_disc(self, x_disc_int: torch.Tensor) -> torch.Tensor:
        """Embed integer class ids as ``(id - 0.5) * disc_embed_scale`` (float32)."""
        s = self.cfg.disc_embed_scale
        return (x_disc_int.float() - 0.5) * s

    def build_tokens(self, flow: torch.Tensor, packets_cont: torch.Tensor, x_disc_t_int: torch.Tensor) -> torch.Tensor:
        """Assemble the ``(B, T + 1, token_dim)`` token tensor.

        Args:
            flow: flow features written to ``z[:, 0, 1:1 + flow_dim]``
                (so presumably ``(B, flow_dim)``).
            packets_cont: ``(B, T, n_cont_pkt)`` continuous packet features.
            x_disc_t_int: integer class ids written (embedded) into the
                ``(B, T, n_disc_pkt)`` discrete slots of the packet tokens.

        Returns:
            Tokens in ``packets_cont``'s dtype/device; unused channels are zero.
        """
        (B, T, Cp) = packets_cont.shape
        assert T == self.cfg.T and Cp == self.cfg.n_cont_pkt
        z = packets_cont.new_zeros((B, T + 1, self.token_dim))
        # Flow token: type flag -1, then the flow vector.
        z[:, 0, 0] = -1.0
        z[:, 0, 1:1 + self.cfg.flow_dim] = flow
        # Packet tokens: type flag +1, continuous features, discrete embedding.
        z[:, 1:, 0] = 1.0
        z[:, 1:, 1:1 + self.cfg.n_cont_pkt] = packets_cont
        z[:, 1:, self._disc_channels()] = self._embed_disc(x_disc_t_int)
        return z

    def key_padding_mask(self, lens: torch.Tensor) -> torch.Tensor:
        """Return a ``(B, T + 1)`` bool mask that is True at padded packet positions.

        The flow token (position 0) is never padded; packet ``i`` is real iff
        ``i < lens[b]``.
        """
        B = lens.shape[0]
        idx = torch.arange(self.cfg.T, device=lens.device)[None, :]
        packet_real = idx < lens[:, None]
        real = torch.cat([torch.ones(B, 1, dtype=torch.bool, device=lens.device), packet_real], dim=1)
        return ~real

    def _loss_mask(self, lens: torch.Tensor) -> torch.Tensor:
        """Float ``(B, T + 1)`` mask: 1.0 for real tokens, 0.0 for padding."""
        return (~self.key_padding_mask(lens)).float()

    def compute_loss(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, *, return_components: bool = False) -> torch.Tensor | dict[str, torch.Tensor]:
        """Return the training loss for one batch.

        Continuous part: draw noise ``x_0`` and ``t ~ U(0, 1)``, form
        ``x_t = (1 - t) x_0 + t x_1`` (plus ``sigma * sqrt(t (1 - t))`` Gaussian
        noise when ``sigma > 0``) and regress the predicted velocity onto
        ``x_1 - x_0``, averaged over real tokens.  The discrete-embedding
        channels of packet tokens are excluded from this regression.

        Discrete part: each discrete value is kept with probability ``t`` and
        otherwise replaced by a uniformly random class; the cross-entropy of the
        discrete head is averaged over corrupted entries of real packets.

        Args:
            flow: flow features (see ``build_tokens``).
            packets_cont: ``(B, T, n_cont_pkt)`` continuous packet features.
            packets_disc: integer class ids, ``B * T * n_disc_pkt`` values.
            lens: ``(B,)`` number of real packets per sample.
            return_components: return a dict instead of the scalar total.

        Returns:
            ``L_cont + lambda_disc * L_disc``; with ``return_components`` a dict with
            'total', detached 'main' (L_cont) and 'aux_disc' (L_disc), and
            zero-valued 'aux_flow' / 'aux_packet' placeholders.
        """
        (B, T, _) = packets_cont.shape
        device = packets_cont.device
        disc = self._disc_channels()
        mask = self._loss_mask(lens)
        kpm = mask == 0
        # The data endpoint's discrete slots are placeholders (embedding of class 0):
        # they are overwritten in x_t_full and excluded from the velocity loss.
        x_1_cont = self.build_tokens(flow, packets_cont, torch.zeros_like(packets_disc))
        x_0_cont = torch.randn_like(x_1_cont)
        if self.cfg.use_ot:
            # Minibatch OT: re-pair data with noise using distances over real tokens.
            flat0 = (x_0_cont * mask[:, :, None]).reshape(B, -1)
            flat1 = (x_1_cont * mask[:, :, None]).reshape(B, -1)
            col = _sinkhorn_coupling(torch.cdist(flat0.float(), flat1.float()))
            # Every per-sample tensor still used below must follow the permutation.
            x_1_cont = x_1_cont[col]
            packets_disc = packets_disc[col]
            lens = lens[col]
            mask = self._loss_mask(lens)
            kpm = mask == 0

        t = torch.rand(B, device=device)
        t_b = t[:, None, None]
        x_t_cont = (1.0 - t_b) * x_0_cont + t_b * x_1_cont
        if self.cfg.sigma > 0:
            std = self.cfg.sigma * torch.sqrt(t * (1.0 - t))[:, None, None]
            x_t_cont = x_t_cont + std * torch.randn_like(x_t_cont)
        target_cont = x_1_cont - x_0_cont

        # Uniform discrete path: keep the true class with probability t.
        u = torch.rand(B, T, self.cfg.n_disc_pkt, device=device)
        keep = u < t_b
        rand_disc = torch.randint(0, self.cfg.n_disc_classes, packets_disc.shape, device=device)
        x_disc_t = torch.where(keep, packets_disc, rand_disc)
        x_t_full = x_t_cont.clone()
        x_t_full[:, 1:, disc] = self._embed_disc(x_disc_t)

        (v_pred, d_logits) = self.velocity(x_t_full, t, key_padding_mask=kpm)

        v_err = (v_pred - target_cont).square()
        # Only packet tokens carry discrete channels.  On the flow token (index 0)
        # the same columns hold ordinary flow features, which must stay supervised.
        v_err[:, 1:, disc] = 0.0
        v_per_token = v_err.mean(dim=-1)
        per_sample = (v_per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
        L_cont = per_sample.mean()

        pkt_real = mask[:, 1:].bool()
        corrupt = (~keep & pkt_real[:, :, None]).float()
        flat_logits = d_logits[:, 1:].reshape(-1, self.cfg.n_disc_classes)
        flat_targets = packets_disc.reshape(-1).long()
        ce = F.cross_entropy(flat_logits, flat_targets, reduction='none').view(B, T, self.cfg.n_disc_pkt)
        L_disc = (ce * corrupt).sum() / corrupt.sum().clamp_min(1.0)

        total = L_cont + self.cfg.lambda_disc * L_disc
        if return_components:
            return {'total': total, 'main': L_cont.detach(), 'aux_disc': L_disc.detach(), 'aux_flow': L_cont.new_zeros(()), 'aux_packet': L_cont.new_zeros(())}
        return total

    @torch.no_grad()
    def trajectory_metrics(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, n_steps: int = 16) -> dict[str, torch.Tensor]:
        """Integrate the learned velocity from t=1 back to t=0 and report terminal norms.

        Uses ``n_steps`` explicit Euler steps ``z <- z - v * dt`` starting from the
        clean tokens.  The packet tokens' discrete channels are pinned to their
        initial embedding throughout.

        Returns:
            Per-sample RMS-normalised norms of the terminal state (packet discrete
            channels and padding excluded): 'terminal_norm' over all real tokens,
            'terminal_flow' over the flow token, 'terminal_packet' over real
            packet tokens.
        """
        z = self.build_tokens(flow, packets_cont, packets_disc)
        mask = self._loss_mask(lens)
        kpm = mask == 0
        B = z.shape[0]
        dt = 1.0 / n_steps
        disc = self._disc_channels()
        disc_embed = z[:, 1:, disc].clone()
        for k in range(n_steps):
            t = torch.full((B,), 1.0 - k * dt, device=z.device)
            (v, _) = self.velocity(z, t, key_padding_mask=kpm)
            # Freeze only the packet tokens' discrete channels; the same columns
            # on the flow token are continuous flow features and must move.
            v[:, 1:, disc] = 0.0
            z = z - v * dt
            z[:, 1:, disc] = disc_embed
        z_cont = z * mask[:, :, None]
        z_cont[:, 1:, disc] = 0.0
        packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
        token_count = (mask.sum(dim=-1) * self.token_dim).clamp_min(1.0)
        terminal = z_cont.reshape(B, -1).norm(dim=-1) / token_count.sqrt()
        terminal_flow = z_cont[:, 0].norm(dim=-1) / math.sqrt(self.token_dim)
        terminal_packet = (z_cont[:, 1:] * mask[:, 1:, None]).reshape(B, -1).norm(dim=-1) / (packet_count * self.token_dim).sqrt()
        return {'terminal_norm': terminal, 'terminal_flow': terminal_flow, 'terminal_packet': terminal_packet}

    @torch.no_grad()
    def disc_nll_score(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, t_eval: float = 0.5) -> dict[str, torch.Tensor]:
        """Per-sample cross-entropy of the discrete head at a fixed time ``t_eval``.

        The head is evaluated on the clean token sequence.  NOTE(review): the
        discrete inputs are the true classes (no corruption at ``t_eval``) --
        confirm this is intended.

        Returns:
            'disc_nll_total': mean over real packets of the CE summed over discrete
            channels; 'disc_nll_ch{idx}': per-channel mean CE over real packets,
            keyed by ``cfg.disc_pkt_idx``.
        """
        (B, T, _) = packets_cont.shape
        device = packets_cont.device
        mask = self._loss_mask(lens)
        kpm = mask == 0
        z = self.build_tokens(flow, packets_cont, packets_disc)
        t = torch.full((B,), float(t_eval), device=device)
        (_, d_logits) = self.velocity(z, t, key_padding_mask=kpm)
        flat_logits = d_logits[:, 1:].reshape(-1, self.cfg.n_disc_classes)
        flat_targets = packets_disc.reshape(-1).long()
        ce = F.cross_entropy(flat_logits, flat_targets, reduction='none').view(B, T, self.cfg.n_disc_pkt)
        pkt_real = mask[:, 1:]
        n_real = pkt_real.sum(dim=-1).clamp_min(1.0)
        per_sample = (ce.sum(dim=-1) * pkt_real).sum(dim=-1) / n_real
        per_ch = (ce * pkt_real[:, :, None]).sum(dim=1) / n_real[:, None]
        out = {'disc_nll_total': per_sample}
        for (c, idx) in enumerate(self.cfg.disc_pkt_idx):
            out[f'disc_nll_ch{idx}'] = per_ch[:, c]
        return out

    def param_count(self) -> int:
        """Total number of scalar parameters."""
        return sum(p.numel() for p in self.parameters())
|