"""Train a UnifiedTokenCFM model from a YAML config and save a checkpoint.

The script loads the data described by the config, fits the model on the
training split, periodically reports AUROC of validation vs. attack flows,
and writes ``config.yaml``, ``model.pt`` and ``history.json`` to ``save_dir``.
"""
from __future__ import annotations

import argparse
import json
import time
from dataclasses import asdict
from pathlib import Path
from typing import Any

import numpy as np
import torch
import yaml
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, TensorDataset

from data import UnifiedData, load_unified_data, subsample_train
from model import UnifiedCFMConfig, UnifiedTokenCFM


def _device(dev_arg: str) -> torch.device:
    """Resolve a device string; ``'auto'`` picks CUDA when available, else CPU."""
    if dev_arg == 'auto':
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device(dev_arg)


def _batch_score(
    model: UnifiedTokenCFM,
    flow_np: np.ndarray,
    packet_np: np.ndarray,
    len_np: np.ndarray,
    device: torch.device,
    *,
    batch_size: int,
    n_steps: int,
) -> dict[str, np.ndarray]:
    """Score the given arrays in mini-batches and collect per-key outputs.

    For every batch the outputs of ``model.trajectory_metrics`` and
    ``model.velocity_score`` are merged (velocity keys win on collision) and
    moved to CPU. Per-key results are concatenated along axis 0.

    Returns an empty dict when ``flow_np`` is empty. The model is left in
    eval mode; callers that keep training must call ``model.train()`` again.
    """
    out: dict[str, list[np.ndarray]] = {}
    model.eval()
    # NOTE(review): scoring is not wrapped in torch.no_grad() because the model
    # methods are not visible here and may rely on autograd -- confirm before
    # adding it.
    for start in range(0, len(flow_np), batch_size):
        sl = slice(start, start + batch_size)
        flow = torch.from_numpy(flow_np[sl]).float().to(device)
        packets = torch.from_numpy(packet_np[sl]).float().to(device)
        lens = torch.from_numpy(len_np[sl]).long().to(device)
        metrics = model.trajectory_metrics(flow, packets, lens, n_steps=n_steps)
        metrics.update(model.velocity_score(flow, packets, lens))
        for k, v in metrics.items():
            out.setdefault(k, []).append(v.detach().cpu().numpy())
    return {k: np.concatenate(v, axis=0) for k, v in out.items()}


def _quick_eval(
    model: UnifiedTokenCFM,
    data: UnifiedData,
    device: torch.device,
    cfg: dict[str, Any],
) -> dict[str, float]:
    """Compute one AUROC per score key on a subsample of val and attack flows.

    Validation flows are labelled 0 and attack flows 1. At most
    ``cfg['eval_n']`` (default 2000) rows are drawn from each split without
    replacement. Returns ``{'auroc_<key>': value}``, or an empty dict when
    either split has no rows (AUROC is undefined with a single class).
    """
    n_eval = int(cfg.get('eval_n', 2000))
    batch_size = int(cfg.get('eval_batch_size', 512))
    n_steps = int(cfg.get('eval_n_steps', 8))
    # Fixed seed: every call scores the same subset, so epochs are comparable.
    rng = np.random.default_rng(0)

    def pick(n: int) -> np.ndarray:
        m = min(n_eval, n)
        return rng.choice(n, m, replace=False)

    vi = pick(len(data.val_flow))
    ai = pick(len(data.attack_flow))
    if len(vi) == 0 or len(ai) == 0:
        return {}

    v = _batch_score(
        model, data.val_flow[vi], data.val_packets[vi], data.val_len[vi],
        device, batch_size=batch_size, n_steps=n_steps,
    )
    a = _batch_score(
        model, data.attack_flow[ai], data.attack_packets[ai], data.attack_len[ai],
        device, batch_size=batch_size, n_steps=n_steps,
    )
    y = np.concatenate([np.zeros(len(vi)), np.ones(len(ai))])
    result: dict[str, float] = {}
    for key in sorted(v.keys()):
        s = np.concatenate([v[key], a[key]])
        # roc_auc_score rejects NaN/inf; clamp them to large finite values.
        s = np.nan_to_num(s, nan=0.0, posinf=1000000000000.0, neginf=-1000000000000.0)
        result[f'auroc_{key}'] = float(roc_auc_score(y, s))
    return result


def train(cfg: dict[str, Any]) -> Path:
    """Train the model described by ``cfg`` and write artifacts to ``save_dir``.

    Required keys read here: ``save_dir``, ``flows_parquet``, ``T``,
    ``batch_size``, ``d_model``, ``n_layers``, ``n_heads``, ``lr``, ``epochs``.
    All other keys are optional and fall back to the defaults in the code.

    Evaluation runs every ``eval_every`` epochs (default 5) and always on the
    final epoch; ``eval_every: 0`` means "final epoch only".

    Returns:
        The directory containing ``config.yaml``, ``model.pt`` and
        ``history.json``.

    Raises:
        RuntimeError: if an epoch yields no batches, or the mean epoch loss
            is not finite.
    """
    device = _device(str(cfg.get('device', 'auto')))
    save_dir = Path(cfg['save_dir'])
    save_dir.mkdir(parents=True, exist_ok=True)
    with open(save_dir / 'config.yaml', 'w', encoding='utf-8') as f:
        yaml.safe_dump(cfg, f)

    seed = int(cfg.get('seed', 42))
    data_seed = int(cfg.get('data_seed', seed))
    torch.manual_seed(seed)
    np.random.seed(seed)
    print(f'Device: {device}')
    print(f'[seed] model={seed} data={data_seed}')

    feature_columns = cfg.get('flow_feature_columns')
    data = load_unified_data(
        packets_npz=Path(cfg['packets_npz']) if cfg.get('packets_npz') else None,
        source_store=Path(cfg['source_store']) if cfg.get('source_store') else None,
        flows_parquet=Path(cfg['flows_parquet']),
        flow_features_path=Path(cfg['flow_features_path']) if cfg.get('flow_features_path') else None,
        flow_feature_columns=feature_columns,
        flow_features_align=str(cfg.get('flow_features_align', 'auto')),
        T=int(cfg['T']),
        split_seed=data_seed,
        train_ratio=float(cfg.get('train_ratio', 0.8)),
        benign_label=str(cfg.get('benign_label', 'normal')),
        min_len=int(cfg.get('min_len', 2)),
        packet_preprocess=str(cfg.get('packet_preprocess', 'mixed_dequant')),
        attack_cap=int(cfg['attack_cap']) if cfg.get('attack_cap') else None,
        val_cap=int(cfg['val_cap']) if cfg.get('val_cap') else None,
    )
    print(
        f'[data] T={data.T} packet_D={data.packet_dim} flow_D={data.flow_dim} '
        f'train={len(data.train_flow):,} val={len(data.val_flow):,} '
        f'attack={len(data.attack_flow):,}'
    )

    tr_f, tr_p, tr_l = subsample_train(data, int(cfg.get('n_train', 0)), data_seed)
    ds = TensorDataset(
        torch.from_numpy(tr_f).float(),
        torch.from_numpy(tr_p).float(),
        torch.from_numpy(tr_l).long(),
    )
    loader = DataLoader(
        ds,
        batch_size=int(cfg['batch_size']),
        shuffle=True,
        drop_last=True,
        num_workers=int(cfg.get('num_workers', 0)),
        pin_memory=device.type == 'cuda',
    )
    print(f'[data] using {len(ds):,} benign training flows')

    model_cfg = UnifiedCFMConfig(
        T=data.T,
        packet_dim=data.packet_dim,
        flow_dim=data.flow_dim,
        token_dim=cfg.get('token_dim'),
        d_model=int(cfg['d_model']),
        n_layers=int(cfg['n_layers']),
        n_heads=int(cfg['n_heads']),
        mlp_ratio=float(cfg.get('mlp_ratio', 4.0)),
        time_dim=int(cfg.get('time_dim', 64)),
        sigma=float(cfg.get('sigma', 0.1)),
        use_ot=bool(cfg.get('use_ot', False)),
        reference_mode=cfg.get('reference_mode'),
    )
    model = UnifiedTokenCFM(model_cfg).to(device)
    print(f'[model] params={model.param_count():,} token_dim={model.token_dim} seq_len={model.seq_len} sigma={model_cfg.sigma} use_ot={model_cfg.use_ot} reference_mode={model_cfg.reference_mode}')

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg['lr']),
        weight_decay=float(cfg.get('weight_decay', 0.01)),
    )
    # Coerce once: the value is used for range(), the scheduler horizon and an
    # integer format spec, and YAML may deliver it as e.g. 10.0.
    epochs = int(cfg['epochs'])
    total_steps = max(1, epochs * len(loader))
    # The scheduler is stepped once per batch, hence T_max in optimizer steps.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)

    history: dict[str, list[Any]] = {'epoch': [], 'loss': [], 'eval': []}
    lambda_flow = float(cfg.get('lambda_flow', 0.0))
    lambda_packet = float(cfg.get('lambda_packet', 0.0))
    packet_mask_ratio = float(cfg.get('packet_mask_ratio', 0.5))
    aux_enabled = lambda_flow > 0.0 or lambda_packet > 0.0
    if aux_enabled:
        print(f'[loss] λ_flow={lambda_flow} λ_packet={lambda_packet} packet_mask_ratio={packet_mask_ratio}')

    eval_every = int(cfg.get('eval_every', 5))
    grad_clip = float(cfg.get('grad_clip', 1.0))

    for epoch in range(1, epochs + 1):
        model.train()
        losses: list[float] = []
        aux_flow_sum = 0.0
        aux_packet_sum = 0.0
        n_steps_this_epoch = 0
        t0 = time.time()
        for flow, packets, lens in loader:
            flow = flow.to(device, non_blocking=True)
            packets = packets.to(device, non_blocking=True)
            lens = lens.to(device, non_blocking=True)
            if aux_enabled:
                comp = model.compute_loss(
                    flow, packets, lens,
                    lambda_flow=lambda_flow,
                    lambda_packet=lambda_packet,
                    packet_mask_ratio=packet_mask_ratio,
                    return_components=True,
                )
                loss = comp['total']
                aux_flow_sum += float(comp['aux_flow'].item())
                aux_packet_sum += float(comp['aux_packet'].item())
            else:
                loss = model.compute_loss(flow, packets, lens)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()
            sched.step()
            losses.append(float(loss.item()))
            n_steps_this_epoch += 1

        if n_steps_this_epoch == 0:
            # drop_last=True discards a short final batch, so a training set
            # smaller than batch_size yields nothing. Fail with the real cause
            # instead of a NaN mean loss after a wasted evaluation.
            raise RuntimeError(
                f'no training batches at epoch {epoch}: {len(ds)} training flows '
                f"with batch_size={cfg['batch_size']} and drop_last=True"
            )
        mean_loss = float(np.mean(losses))

        eval_metrics: dict[str, float] | None = None
        # eval_every == 0 disables periodic evaluation (avoids modulo by zero);
        # the final epoch is always evaluated.
        periodic = eval_every != 0 and epoch % eval_every == 0
        if periodic or epoch == epochs:
            eval_metrics = _quick_eval(model, data, device, cfg)
        history['epoch'].append(epoch)
        history['loss'].append(mean_loss)
        history['eval'].append(eval_metrics)

        # Elapsed time deliberately includes the evaluation pass.
        elapsed = time.time() - t0
        terminal = ''
        if eval_metrics and 'auroc_terminal_norm' in eval_metrics:
            terminal = f" auroc_terminal={eval_metrics['auroc_terminal_norm']:.3f}"
        if aux_enabled:
            terminal += f' aux_flow={aux_flow_sum / n_steps_this_epoch:.4f} aux_pkt={aux_packet_sum / n_steps_this_epoch:.4f}'
        print(f"[epoch {epoch:>3d}/{epochs:<3d}] ({elapsed:.1f}s) loss={mean_loss:.4f}{terminal}")
        if not np.isfinite(mean_loss):
            raise RuntimeError(f'non-finite loss at epoch {epoch}')

    # Normalisation statistics and feature names are stored next to the
    # weights so the checkpoint carries them along with the model.
    payload = {
        'model_state_dict': model.state_dict(),
        'model_cfg': asdict(model_cfg),
        'packet_mean': data.packet_mean,
        'packet_std': data.packet_std,
        'flow_mean': data.flow_mean,
        'flow_std': data.flow_std,
        'packet_preprocess': data.packet_preprocess,
        'flow_feature_names': np.asarray(data.flow_feature_names),
        'packet_feature_names': np.asarray(data.packet_feature_names),
    }
    torch.save(payload, save_dir / 'model.pt')
    with open(save_dir / 'history.json', 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2, default=str)
    print(f"[saved] {save_dir / 'model.pt'}")
    return save_dir


def main() -> None:
    """Parse CLI arguments, apply ``key=value`` overrides and run training.

    Override values are parsed as YAML scalars, so ``--override lr=1e-3``
    and ``--override use_ot=true`` produce typed values rather than strings.
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('--config', type=Path, required=True)
    p.add_argument('--override', type=str, nargs='*', default=[])
    args = p.parse_args()

    with open(args.config, encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    if cfg is None:
        # An empty YAML document loads as None; overrides may still fill it in.
        cfg = {}
    if not isinstance(cfg, dict):
        p.error(f'config {args.config} must contain a YAML mapping at the top level')

    for override in args.override:
        key, sep, value = override.partition('=')
        if not sep:
            p.error(f'--override expects key=value, got {override!r}')
        cfg[key] = yaml.safe_load(value)
    train(cfg)


if __name__ == '__main__':
    main()