from __future__ import annotations import argparse import json import sys as _sys import time from dataclasses import asdict from pathlib import Path from pathlib import Path as _Path from typing import Any import numpy as np import torch import yaml from sklearn.metrics import roc_auc_score from torch.utils.data import DataLoader, TensorDataset _sys.path.insert(0, str(_Path(__file__).resolve().parent)) from data import MixedData, load_mixed_data, subsample_train from model import MixedCFMConfig, MixedTokenCFM def _device(arg: str) -> torch.device: if arg == 'auto': return torch.device('cuda' if torch.cuda.is_available() else 'cpu') return torch.device(arg) def _batch_score(model: MixedTokenCFM, flow_np: np.ndarray, cont_np: np.ndarray, disc_np: np.ndarray, len_np: np.ndarray, device: torch.device, *, batch_size: int, n_steps: int, cont_bin_edges: torch.Tensor | None = None) -> dict[str, np.ndarray]: out: dict[str, list[np.ndarray]] = {} model.eval() for start in range(0, len(flow_np), batch_size): sl = slice(start, start + batch_size) flow = torch.from_numpy(flow_np[sl]).float().to(device) cont = torch.from_numpy(cont_np[sl]).float().to(device) disc = torch.from_numpy(disc_np[sl]).long().to(device) lens = torch.from_numpy(len_np[sl]).long().to(device) m = model.trajectory_metrics(flow, cont, disc, lens, n_steps=n_steps, cont_bin_edges=cont_bin_edges) d = model.disc_nll_score(flow, cont, disc, lens, cont_bin_edges=cont_bin_edges) for src in (m, d): for (k, v) in src.items(): out.setdefault(k, []).append(v.detach().cpu().numpy()) return {k: np.concatenate(v, axis=0) for (k, v) in out.items()} def _quick_eval(model: MixedTokenCFM, data: MixedData, device: torch.device, cfg: dict[str, Any], cont_bin_edges: torch.Tensor | None = None) -> dict[str, float]: n_eval = int(cfg.get('eval_n', 2000)) rng = np.random.default_rng(0) def pick(n: int) -> np.ndarray: m = min(n_eval, n) return rng.choice(n, m, replace=False) vi = pick(len(data.val_flow)) ai = pick(len(data.attack_flow)) v = _batch_score(model, data.val_flow[vi], data.val_cont[vi], data.val_disc[vi], data.val_len[vi], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)), cont_bin_edges=cont_bin_edges) a = _batch_score(model, data.attack_flow[ai], data.attack_cont[ai], data.attack_disc[ai], data.attack_len[ai], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)), cont_bin_edges=cont_bin_edges) y = np.concatenate([np.zeros(len(vi)), np.ones(len(ai))]) out: dict[str, float] = {} for k in sorted(v.keys()): s = np.concatenate([v[k], a[k]]) s = np.nan_to_num(s, nan=0.0, posinf=1000000000000.0, neginf=-1000000000000.0) out[f'auroc_{k}'] = float(roc_auc_score(y, s)) return out def train(cfg: dict[str, Any]) -> Path: device = _device(str(cfg.get('device', 'auto'))) save_dir = Path(cfg['save_dir']) save_dir.mkdir(parents=True, exist_ok=True) with open(save_dir / 'config.yaml', 'w') as f: yaml.safe_dump(cfg, f) seed = int(cfg.get('seed', 42)) data_seed = int(cfg.get('data_seed', seed)) torch.manual_seed(seed) np.random.seed(seed) print(f'Device: {device} seed=model:{seed}/data:{data_seed}') data = load_mixed_data(packets_npz=Path(cfg['packets_npz']) if cfg.get('packets_npz') else None, source_store=Path(cfg['source_store']) if cfg.get('source_store') else None, flows_parquet=Path(cfg['flows_parquet']), flow_features_path=Path(cfg['flow_features_path']), flow_feature_columns=cfg.get('flow_feature_columns'), flow_features_align=str(cfg.get('flow_features_align', 'auto')), T=int(cfg['T']), split_seed=data_seed, train_ratio=float(cfg.get('train_ratio', 0.8)), benign_label=str(cfg.get('benign_label', 'normal')), min_len=int(cfg.get('min_len', 2)), attack_cap=int(cfg['attack_cap']) if cfg.get('attack_cap') else None, val_cap=int(cfg['val_cap']) if cfg.get('val_cap') else None) print(f'[data] T={data.T} cont={data.n_cont} disc={data.n_disc} flow={data.flow_dim} train={len(data.train_flow):,} val={len(data.val_flow):,} attack={len(data.attack_flow):,}') (tr_f, tr_c, tr_d, tr_l) = subsample_train(data, int(cfg.get('n_train', 0)), data_seed) ds = TensorDataset(torch.from_numpy(tr_f).float(), torch.from_numpy(tr_c).float(), torch.from_numpy(tr_d).long(), torch.from_numpy(tr_l).long()) loader = DataLoader(ds, batch_size=int(cfg['batch_size']), shuffle=True, drop_last=True, num_workers=int(cfg.get('num_workers', 0)), pin_memory=device.type == 'cuda') print(f'[data] training on {len(ds):,} flows') n_disc_classes = int(cfg.get('n_disc_classes', 2)) model_cfg = MixedCFMConfig( T=data.T, flow_dim=data.flow_dim, token_dim=cfg.get('token_dim'), d_model=int(cfg['d_model']), n_layers=int(cfg['n_layers']), n_heads=int(cfg['n_heads']), mlp_ratio=float(cfg.get('mlp_ratio', 4.0)), time_dim=int(cfg.get('time_dim', 64)), sigma=float(cfg.get('sigma', 0.1)), use_ot=bool(cfg.get('use_ot', False)), reference_mode=cfg.get('reference_mode'), lambda_disc=float(cfg.get('lambda_disc', 1.0)), n_disc_classes=n_disc_classes, # B-group ablation flags use_flow_token=bool(cfg.get('use_flow_token', True)), n_packet_tokens=int(cfg.get('n_packet_tokens', -1)), disc_as_cont=bool(cfg.get('disc_as_cont', False)), cont_as_disc=bool(cfg.get('cont_as_disc', False)), ) model = MixedTokenCFM(model_cfg).to(device) # B4: compute bin edges from benign train cont (z-scored, masked) for cont_as_disc quantization cont_bin_edges = None if model_cfg.cont_as_disc: n_bins = n_disc_classes n_cont_orig = model_cfg.n_cont_pkt # gather real cont samples per channel (mask padding) masks = np.arange(data.train_cont.shape[1])[None, :] < data.train_len[:, None] edges = np.zeros((n_cont_orig, n_bins - 1), dtype=np.float32) for c in range(n_cont_orig): vals = data.train_cont[..., c][masks] qs = np.linspace(0, 1, n_bins + 1)[1:-1] # interior quantiles edges[c] = np.quantile(vals, qs).astype(np.float32) cont_bin_edges = torch.from_numpy(edges).to(device) print(f'[B4] cont_bin_edges shape={tuple(edges.shape)} (n_bins={n_bins})') print(f'[model] params={model.param_count():,} token_dim={model.token_dim} sigma={model_cfg.sigma} use_ot={model_cfg.use_ot} lambda_disc={model_cfg.lambda_disc} use_flow_token={model_cfg.use_flow_token} n_packet_tokens={model_cfg.n_packet_tokens} disc_as_cont={model_cfg.disc_as_cont} cont_as_disc={model_cfg.cont_as_disc}') opt = torch.optim.AdamW(model.parameters(), lr=float(cfg['lr']), weight_decay=float(cfg.get('weight_decay', 0.01))) total_steps = max(1, int(cfg['epochs']) * len(loader)) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps) history: dict[str, list[Any]] = {'epoch': [], 'loss': [], 'eval': []} for epoch in range(1, int(cfg['epochs']) + 1): model.train() losses: list[float] = [] ldisc_sum = 0.0 n_batches = 0 t0 = time.time() for (flow, cont, disc, lens) in loader: flow = flow.to(device, non_blocking=True) cont = cont.to(device, non_blocking=True) disc = disc.to(device, non_blocking=True) lens = lens.to(device, non_blocking=True) comp = model.compute_loss(flow, cont, disc, lens, return_components=True, cont_bin_edges=cont_bin_edges) loss = comp['total'] ldisc_sum += float(comp['aux_disc'].item()) opt.zero_grad(set_to_none=True) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get('grad_clip', 1.0))) opt.step() sched.step() losses.append(float(loss.item())) n_batches += 1 mean_loss = float(np.mean(losses)) if losses else float('nan') eval_metrics: dict[str, float] | None = None if epoch % int(cfg.get('eval_every', 5)) == 0 or epoch == int(cfg['epochs']): eval_metrics = _quick_eval(model, data, device, cfg, cont_bin_edges=cont_bin_edges) history['epoch'].append(epoch) history['loss'].append(mean_loss) history['eval'].append(eval_metrics) elapsed = time.time() - t0 tail = '' if eval_metrics: t = eval_metrics.get('auroc_terminal_norm', float('nan')) n = eval_metrics.get('auroc_disc_nll_total', float('nan')) tail = f' auroc_term={t:.3f} auroc_disc={n:.3f}' if n_batches: tail += f' L_disc={ldisc_sum / n_batches:.4f}' print(f"[epoch {epoch:>3d}/{cfg['epochs']:<3d}] ({elapsed:.1f}s) loss={mean_loss:.4f}{tail}") if not np.isfinite(mean_loss): raise RuntimeError(f'non-finite loss at epoch {epoch}') payload = {'model_state_dict': model.state_dict(), 'model_cfg': asdict(model_cfg), 'cont_mean': data.cont_mean, 'cont_std': data.cont_std, 'flow_mean': data.flow_mean, 'flow_std': data.flow_std, 'flow_feature_names': np.asarray(data.flow_feature_names), 'packet_feature_names': np.asarray(data.packet_feature_names)} if cont_bin_edges is not None: payload['cont_bin_edges'] = cont_bin_edges.detach().cpu().numpy() torch.save(payload, save_dir / 'model.pt') with open(save_dir / 'history.json', 'w') as f: json.dump(history, f, indent=2, default=str) print(f"[saved] {save_dir / 'model.pt'}") return save_dir def main() -> None: p = argparse.ArgumentParser(description=__doc__) p.add_argument('--config', type=Path, required=True) p.add_argument('--override', type=str, nargs='*', default=[]) args = p.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) for ov in args.override: (k, v) = ov.split('=', 1) cfg[k] = yaml.safe_load(v) train(cfg) if __name__ == '__main__': main()