from __future__ import annotations import math from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F import importlib.util as _ilu import sys as _sys from pathlib import Path as _Path _UNIFIED_NAME = 'unified_cfm_model' if _UNIFIED_NAME not in _sys.modules: _unified_spec = _ilu.spec_from_file_location(_UNIFIED_NAME, _Path(__file__).resolve().parents[1] / 'Unified_CFM' / 'model.py') _unified = _ilu.module_from_spec(_unified_spec) _sys.modules[_UNIFIED_NAME] = _unified _unified_spec.loader.exec_module(_unified) else: _unified = _sys.modules[_UNIFIED_NAME] AdaLNBlock = _unified.AdaLNBlock SinusoidalTimeEmb = _unified.SinusoidalTimeEmb _sinkhorn_coupling = _unified._sinkhorn_coupling @dataclass class MixedCFMConfig: T: int = 64 flow_dim: int = 20 n_cont_pkt: int = 3 n_disc_pkt: int = 6 cont_pkt_idx: tuple[int, ...] = (0, 1, 8) disc_pkt_idx: tuple[int, ...] = (2, 3, 4, 5, 6, 7) n_disc_classes: int = 2 token_dim: int | None = None d_model: int = 128 n_layers: int = 4 n_heads: int = 4 mlp_ratio: float = 4.0 time_dim: int = 64 sigma: float = 0.1 use_ot: bool = False reference_mode: str | None = None lambda_disc: float = 1.0 disc_path: str = 'uniform' disc_embed_scale: float = 1.0 # ---- B-group ablation flags (defaults preserve JANUS-full behavior) ---- use_flow_token: bool = True # B1: False removes the [FLOW] token n_packet_tokens: int = -1 # B2: 0 removes packet tokens entirely; -1 = use cfg.T disc_as_cont: bool = False # B3: feed 6 disc bits through CFM head as continuous values cont_as_disc: bool = False # B4: quantize 3 cont channels into n_disc_classes bins (mask-pred only) def __post_init__(self) -> None: if len(self.cont_pkt_idx) != self.n_cont_pkt: raise ValueError('cont_pkt_idx length mismatch n_cont_pkt') if len(self.disc_pkt_idx) != self.n_disc_pkt: raise ValueError('disc_pkt_idx length mismatch n_disc_pkt') if self.disc_path != 'uniform': raise NotImplementedError(f'disc_path={self.disc_path}') if self.disc_as_cont and self.cont_as_disc: raise ValueError('disc_as_cont and cont_as_disc are mutually exclusive') class MixedVelocity(nn.Module): def __init__(self, token_dim: int, seq_len: int, n_disc: int, n_classes: int, d_model: int=128, n_layers: int=4, n_heads: int=4, mlp_ratio: float=4.0, time_dim: int=64, reference_mode: str | None=None, has_flow_token: bool=True) -> None: super().__init__() if reference_mode not in (None, 'causal_packets', 'causal_all'): raise ValueError(f'reference_mode={reference_mode!r}') self.token_dim = token_dim self.seq_len = seq_len self.n_disc = n_disc self.n_classes = n_classes self.reference_mode = reference_mode self.has_flow_token = has_flow_token self.input_proj = nn.Linear(token_dim, d_model) self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model)) self.type_emb = nn.Embedding(2, d_model) nn.init.trunc_normal_(self.pos_emb, std=0.02) nn.init.normal_(self.type_emb.weight, std=0.02) self.time_emb = SinusoidalTimeEmb(time_dim) self.cond_mlp = nn.Sequential(nn.Linear(time_dim, d_model), nn.SiLU(), nn.Linear(d_model, d_model)) self.blocks = nn.ModuleList([AdaLNBlock(d_model, n_heads, mlp_ratio, cond_dim=d_model) for _ in range(n_layers)]) self.out_norm = nn.LayerNorm(d_model, elementwise_affine=False) self.head_v = nn.Linear(d_model, token_dim) # head_disc only meaningful when n_disc > 0 out_disc = max(n_disc * n_classes, 1) self.head_disc = nn.Linear(d_model, out_disc) for layer in (self.head_v, self.head_disc): nn.init.zeros_(layer.weight) nn.init.zeros_(layer.bias) type_ids = torch.ones(seq_len, dtype=torch.long) if has_flow_token and seq_len >= 1: type_ids[0] = 0 self.register_buffer('type_ids', type_ids, persistent=False) def _attn_mask(self, L: int, device: torch.device) -> torch.Tensor | None: if self.reference_mode is None: return None if self.reference_mode == 'causal_packets': mask = torch.zeros((L, L), dtype=torch.bool, device=device) offset = 1 if self.has_flow_token else 0 if L > offset: M = L - offset if M > 1: mask[offset:, offset:] = torch.triu(torch.ones(M, M, dtype=torch.bool, device=device), diagonal=1) return mask return torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), diagonal=1) def forward(self, x: torch.Tensor, t: torch.Tensor, key_padding_mask: torch.Tensor | None=None) -> tuple[torch.Tensor, torch.Tensor]: (B, L, _) = x.shape if t.dim() == 0: t = t.expand(B) h = self.input_proj(x) h = h + self.pos_emb[:, :L, :] + self.type_emb(self.type_ids[:L])[None, :, :] cond = self.cond_mlp(self.time_emb(t)) attn_mask = self._attn_mask(L, x.device) for block in self.blocks: h = block(h, cond, key_padding_mask, attn_mask=attn_mask) h = self.out_norm(h) v = self.head_v(h) if self.n_disc > 0: d = self.head_disc(h).view(B, L, self.n_disc, self.n_classes) else: d = h.new_zeros((B, L, 0, self.n_classes)) return (v, d) class MixedTokenCFM(nn.Module): def __init__(self, cfg: MixedCFMConfig) -> None: super().__init__() self.cfg = cfg # Effective packet count (B2: n_packet_tokens=0 → no packets) self.eff_T = cfg.T if cfg.n_packet_tokens < 0 else int(cfg.n_packet_tokens) if not cfg.use_flow_token and self.eff_T == 0: raise ValueError('cannot disable both FLOW token and packet tokens') # Effective per-packet feature split if cfg.disc_as_cont: # B3: 9 cont, 0 disc (CFM head only) self.eff_n_cont = cfg.n_cont_pkt + cfg.n_disc_pkt self.eff_n_disc = 0 elif cfg.cont_as_disc: # B4: 0 cont, 9 disc (mask-pred head only) self.eff_n_cont = 0 self.eff_n_disc = cfg.n_cont_pkt + cfg.n_disc_pkt else: self.eff_n_cont = cfg.n_cont_pkt self.eff_n_disc = cfg.n_disc_pkt cont_size = self.eff_n_cont + self.eff_n_disc # Token layout: [type_flag(1) | flow_dim or cont_size] self.token_dim = cfg.token_dim or 1 + max(cfg.flow_dim, cont_size) if self.token_dim < 1 + max(cfg.flow_dim, cont_size): raise ValueError('token_dim too small') self.seq_len = (1 if cfg.use_flow_token else 0) + self.eff_T self.velocity = MixedVelocity( token_dim=self.token_dim, seq_len=self.seq_len, n_disc=self.eff_n_disc, n_classes=cfg.n_disc_classes, d_model=cfg.d_model, n_layers=cfg.n_layers, n_heads=cfg.n_heads, mlp_ratio=cfg.mlp_ratio, time_dim=cfg.time_dim, reference_mode=cfg.reference_mode, has_flow_token=cfg.use_flow_token, ) # ------------------------------------------------------------------ # # token assembly # # ------------------------------------------------------------------ # def _embed_disc(self, x_disc_int: torch.Tensor) -> torch.Tensor: n = self.cfg.n_disc_classes s = self.cfg.disc_embed_scale if n <= 1: return x_disc_int.float() * 0.0 # Map integers in [0, n-1] to centered floats in [-s/2, +s/2]. # Backwards-compatible with old (x - 0.5)*s formula when n=2. return (x_disc_int.float() / (n - 1) - 0.5) * s def _flow_dim(self) -> int: return self.cfg.flow_dim def build_tokens(self, flow: torch.Tensor, packets_cont: torch.Tensor, x_disc_t_int: torch.Tensor) -> torch.Tensor: """Assemble [B, seq_len, token_dim]. packets_cont: [B, eff_T, eff_n_cont] (may be empty in last dim) x_disc_t_int: [B, eff_T, eff_n_disc] integer ids in [0, n_disc_classes-1] """ B = flow.shape[0] device = flow.device T = self.eff_T z = flow.new_zeros((B, self.seq_len, self.token_dim)) cur = 0 if self.cfg.use_flow_token: z[:, 0, 0] = -1.0 # type flag z[:, 0, 1:1 + self._flow_dim()] = flow cur = 1 if T > 0: z[:, cur:cur + T, 0] = 1.0 # type flag base = 1 if self.eff_n_cont > 0: z[:, cur:cur + T, base:base + self.eff_n_cont] = packets_cont base += self.eff_n_cont if self.eff_n_disc > 0: z[:, cur:cur + T, base:base + self.eff_n_disc] = self._embed_disc(x_disc_t_int) return z def key_padding_mask(self, lens: torch.Tensor) -> torch.Tensor: B = lens.shape[0] device = lens.device T = self.eff_T pieces = [] if self.cfg.use_flow_token: pieces.append(torch.ones(B, 1, dtype=torch.bool, device=device)) if T > 0: idx = torch.arange(T, device=device)[None, :] pieces.append(idx < lens[:, None]) real = torch.cat(pieces, dim=1) if pieces else torch.ones(B, 0, dtype=torch.bool, device=device) return ~real def _loss_mask(self, lens: torch.Tensor) -> torch.Tensor: return (~self.key_padding_mask(lens)).float() # ------------------------------------------------------------------ # # B4 helper: quantize cont -> integer bins # # ------------------------------------------------------------------ # def quantize_cont(self, packets_cont: torch.Tensor, bin_edges: torch.Tensor) -> torch.Tensor: """packets_cont [B, T, n_cont_orig] (already z-scored); bin_edges [n_cont_orig, n_classes-1] returns int64 [B, T, n_cont_orig] in [0, n_classes-1].""" B, T, C = packets_cont.shape out = torch.zeros((B, T, C), dtype=torch.long, device=packets_cont.device) for c in range(C): edges = bin_edges[c] # [n_classes-1] # bucketize: returns 0..n for n edges out[:, :, c] = torch.bucketize(packets_cont[:, :, c].contiguous(), edges) out.clamp_(0, self.cfg.n_disc_classes - 1) return out # ------------------------------------------------------------------ # # Loss # # ------------------------------------------------------------------ # def compute_loss(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, *, return_components: bool=False, cont_bin_edges: torch.Tensor | None=None) -> torch.Tensor | dict[str, torch.Tensor]: cfg = self.cfg B = flow.shape[0] T = self.eff_T device = flow.device # Resolve effective cont/disc tensors per ablation mode if cfg.disc_as_cont: # 9 cont = original 3 cont + 6 disc-as-float disc_as_cont_float = self._embed_disc(packets_disc) if T > 0 else None if T > 0: eff_cont = torch.cat([packets_cont, disc_as_cont_float], dim=-1) if cfg.n_cont_pkt > 0 else disc_as_cont_float else: eff_cont = packets_cont.new_zeros((B, 0, 0)) eff_disc_int = torch.zeros((B, T, 0), dtype=torch.long, device=device) elif cfg.cont_as_disc: # 0 cont, 9 disc: quantize cont via supplied bin_edges if T > 0: if cont_bin_edges is None: raise ValueError('cont_as_disc requires cont_bin_edges') cont_int = self.quantize_cont(packets_cont, cont_bin_edges) eff_disc_int = torch.cat([cont_int, packets_disc.long()], dim=-1) else: eff_disc_int = torch.zeros((B, 0, self.eff_n_disc), dtype=torch.long, device=device) eff_cont = flow.new_zeros((B, T, 0)) else: eff_cont = packets_cont if T > 0 else packets_cont.new_zeros((B, 0, cfg.n_cont_pkt)) eff_disc_int = packets_disc.long() if T > 0 else torch.zeros((B, 0, cfg.n_disc_pkt), dtype=torch.long, device=device) # Build x_1 (data tokens; mask-pred path uses zero ids for disc at packet positions during CFM regression) zero_disc = torch.zeros_like(eff_disc_int) x_1_cont = self.build_tokens(flow, eff_cont, zero_disc) mask = self._loss_mask(lens) kpm = mask == 0 x_0_cont = torch.randn_like(x_1_cont) if cfg.use_ot: flat0 = (x_0_cont * mask[:, :, None]).reshape(B, -1) flat1 = (x_1_cont * mask[:, :, None]).reshape(B, -1) col = _sinkhorn_coupling(torch.cdist(flat0.float(), flat1.float())) x_1_cont = x_1_cont[col] eff_cont = eff_cont[col] if eff_cont.numel() > 0 else eff_cont eff_disc_int = eff_disc_int[col] if eff_disc_int.numel() > 0 else eff_disc_int packets_disc = packets_disc[col] flow = flow[col] lens = lens[col] mask = self._loss_mask(lens) kpm = mask == 0 t = torch.rand(B, device=device) x_t_cont = (1.0 - t[:, None, None]) * x_0_cont + t[:, None, None] * x_1_cont if cfg.sigma > 0: std = cfg.sigma * torch.sqrt(t * (1.0 - t))[:, None, None] x_t_cont = x_t_cont + std * torch.randn_like(x_t_cont) target_cont = x_1_cont - x_0_cont # Disc corruption schedule (mask-pred): keep fraction t of true labels if T > 0 and self.eff_n_disc > 0: u = torch.rand(B, T, self.eff_n_disc, device=device) keep = u < t[:, None, None] rand_disc = torch.randint(0, cfg.n_disc_classes, eff_disc_int.shape, device=device) x_disc_t = torch.where(keep, eff_disc_int, rand_disc) disc_start = (1 if cfg.use_flow_token else 0) + 0 # placeholder; overwritten below # Where in x_t_full do disc embeds go? # Within each packet token: [type(1) | cont(eff_n_cont) | disc(eff_n_disc) | pad...] disc_start_in_token = 1 + self.eff_n_cont cur_offset = 1 if cfg.use_flow_token else 0 x_t_full = x_t_cont.clone() x_t_full[:, cur_offset:cur_offset + T, disc_start_in_token:disc_start_in_token + self.eff_n_disc] = self._embed_disc(x_disc_t) else: x_t_full = x_t_cont x_disc_t = eff_disc_int # unused keep = None (v_pred, d_logits) = self.velocity(x_t_full, t, key_padding_mask=kpm) # CFM regression loss on cont slots (mask out disc slots) v_err = (v_pred - target_cont).square() if T > 0 and self.eff_n_disc > 0: disc_start_in_token = 1 + self.eff_n_cont cur_offset = 1 if cfg.use_flow_token else 0 v_err[:, cur_offset:cur_offset + T, disc_start_in_token:disc_start_in_token + self.eff_n_disc] = 0.0 v_per_token = v_err.mean(dim=-1) per_sample = (v_per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) L_cont = per_sample.mean() # Mask-pred CE on corrupted disc positions if T > 0 and self.eff_n_disc > 0 and keep is not None: cur_offset = 1 if cfg.use_flow_token else 0 pkt_logits = d_logits[:, cur_offset:cur_offset + T] pkt_real = mask[:, cur_offset:cur_offset + T].bool() corrupt = ~keep & pkt_real[:, :, None] flat_logits = pkt_logits.reshape(-1, cfg.n_disc_classes) flat_targets = eff_disc_int.reshape(-1).long() flat_ce = F.cross_entropy(flat_logits, flat_targets, reduction='none') flat_ce = flat_ce.view(B, T, self.eff_n_disc) flat_ce = flat_ce * corrupt.float() denom = corrupt.float().sum().clamp_min(1.0) L_disc = flat_ce.sum() / denom else: L_disc = L_cont.new_zeros(()) total = L_cont + cfg.lambda_disc * L_disc if return_components: return {'total': total, 'main': L_cont.detach(), 'aux_disc': L_disc.detach(), 'aux_flow': L_cont.new_zeros(()), 'aux_packet': L_cont.new_zeros(())} return total # ------------------------------------------------------------------ # # Scoring # # ------------------------------------------------------------------ # @torch.no_grad() def trajectory_metrics(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, n_steps: int=16, cont_bin_edges: torch.Tensor | None=None) -> dict[str, torch.Tensor]: cfg = self.cfg B = flow.shape[0] T = self.eff_T # Build effective cont / disc tensors per ablation mode if cfg.disc_as_cont: disc_float = self._embed_disc(packets_disc) if T > 0 else None if T > 0: eff_cont = torch.cat([packets_cont, disc_float], dim=-1) if cfg.n_cont_pkt > 0 else disc_float else: eff_cont = packets_cont.new_zeros((B, 0, 0)) eff_disc_int = torch.zeros((B, T, 0), dtype=torch.long, device=flow.device) elif cfg.cont_as_disc: if T > 0: if cont_bin_edges is None: raise ValueError('cont_as_disc requires cont_bin_edges at scoring time') cont_int = self.quantize_cont(packets_cont, cont_bin_edges) eff_disc_int = torch.cat([cont_int, packets_disc.long()], dim=-1) else: eff_disc_int = torch.zeros((B, 0, 0), dtype=torch.long, device=flow.device) eff_cont = flow.new_zeros((B, T, 0)) else: eff_cont = packets_cont if T > 0 else packets_cont.new_zeros((B, 0, cfg.n_cont_pkt)) eff_disc_int = packets_disc.long() if T > 0 else torch.zeros((B, 0, cfg.n_disc_pkt), dtype=torch.long, device=flow.device) z = self.build_tokens(flow, eff_cont, eff_disc_int) mask = self._loss_mask(lens) kpm = mask == 0 dt = 1.0 / n_steps # Disc embed slot bounds (within token vector) for "freeze disc during ODE" cur_offset = 1 if cfg.use_flow_token else 0 disc_start_in_token = 1 + self.eff_n_cont disc_end_in_token = disc_start_in_token + self.eff_n_disc if self.eff_n_disc > 0 and T > 0: disc_embed = z[:, cur_offset:cur_offset + T, disc_start_in_token:disc_end_in_token].clone() else: disc_embed = None for k in range(n_steps): t_val = 1.0 - k * dt t = torch.full((B,), t_val, device=z.device) (v, _) = self.velocity(z, t, key_padding_mask=kpm) if self.eff_n_disc > 0 and T > 0: v[:, cur_offset:cur_offset + T, disc_start_in_token:disc_end_in_token] = 0.0 z = z - v * dt if disc_embed is not None: z[:, cur_offset:cur_offset + T, disc_start_in_token:disc_end_in_token] = disc_embed # Compute terminal-norm scores. Zero out the discrete embed slots so they don't pollute. z_real = z * mask[:, :, None] z_cont = z_real.clone() if self.eff_n_disc > 0 and T > 0: z_cont[:, cur_offset:cur_offset + T, disc_start_in_token:disc_end_in_token] = 0.0 full_norm = z_cont.reshape(B, -1).norm(dim=-1) / (mask.sum(dim=-1) * self.token_dim).clamp_min(1.0).sqrt() out = {'terminal_norm': full_norm} if cfg.use_flow_token: out['terminal_flow'] = z_cont[:, 0].norm(dim=-1) / math.sqrt(self.token_dim) if T > 0: packet_count = mask[:, cur_offset:cur_offset + T].sum(dim=-1).clamp_min(1.0) out['terminal_packet'] = (z_cont[:, cur_offset:cur_offset + T] * mask[:, cur_offset:cur_offset + T, None]).reshape(B, -1).norm(dim=-1) / (packet_count * self.token_dim).sqrt() return out @torch.no_grad() def disc_nll_score(self, flow: torch.Tensor, packets_cont: torch.Tensor, packets_disc: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5, cont_bin_edges: torch.Tensor | None=None) -> dict[str, torch.Tensor]: cfg = self.cfg B = flow.shape[0] T = self.eff_T device = flow.device if T == 0 or self.eff_n_disc == 0: return {} # no disc head to score # Build effective disc int per mode if cfg.cont_as_disc: if cont_bin_edges is None: raise ValueError('cont_as_disc requires cont_bin_edges at scoring time') cont_int = self.quantize_cont(packets_cont, cont_bin_edges) eff_disc_int = torch.cat([cont_int, packets_disc.long()], dim=-1) eff_cont = flow.new_zeros((B, T, 0)) ch_idx_list = list(cfg.cont_pkt_idx) + list(cfg.disc_pkt_idx) else: eff_disc_int = packets_disc.long() eff_cont = packets_cont ch_idx_list = list(cfg.disc_pkt_idx) mask = self._loss_mask(lens) kpm = mask == 0 z = self.build_tokens(flow, eff_cont, eff_disc_int) t = torch.full((B,), float(t_eval), device=device) (_, d_logits) = self.velocity(z, t, key_padding_mask=kpm) cur_offset = 1 if cfg.use_flow_token else 0 pkt_logits = d_logits[:, cur_offset:cur_offset + T] flat_logits = pkt_logits.reshape(-1, cfg.n_disc_classes) flat_targets = eff_disc_int.reshape(-1).long() ce = F.cross_entropy(flat_logits, flat_targets, reduction='none') ce = ce.view(B, T, self.eff_n_disc) pkt_real = mask[:, cur_offset:cur_offset + T].bool().float() per_sample = (ce.sum(dim=-1) * pkt_real).sum(dim=-1) / pkt_real.sum(dim=-1).clamp_min(1.0) per_ch = (ce * pkt_real[:, :, None]).sum(dim=1) / pkt_real.sum(dim=1).clamp_min(1.0)[:, None] out = {'disc_nll_total': per_sample} for c, idx in enumerate(ch_idx_list): out[f'disc_nll_ch{idx}'] = per_ch[:, c] return out def param_count(self) -> int: return sum((p.numel() for p in self.parameters()))