ablation: add Group A (aggregator) + Group B (architecture) infrastructure
Extends MixedCFMConfig with 5 backwards-compatible flags (use_flow_token,
n_packet_tokens, disc_as_cont, cont_as_disc + cont_n_bins) so existing
JANUS-full checkpoints load with 0 missing/unexpected keys.
Adds:
- 60 ablation training configs (5 variants × 4 datasets × 3 seeds)
- scripts/ablation/{generate_configs.py, run_groupB.sh, run_cross_groupB.sh,
smoke_test.sh} — config generation + GPU drivers
- scripts/aggregate/aggregate_ablation{,_cross,_cross_B}.py — produce
within-dataset and cross-dataset (3×3) ablation tables reporting the 3-seed mean
± 95% confidence interval (Student's t), plus optional paired DeLong p-values
The README is updated with an ablation section pointing at
artifacts/ablation/ABLATION_SUMMARY.md.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -21,7 +21,7 @@ def _device(arg: str) -> torch.device:
|
||||
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
return torch.device(arg)
|
||||
|
||||
def _batch_score(model: MixedTokenCFM, flow_np: np.ndarray, cont_np: np.ndarray, disc_np: np.ndarray, len_np: np.ndarray, device: torch.device, *, batch_size: int, n_steps: int) -> dict[str, np.ndarray]:
|
||||
def _batch_score(model: MixedTokenCFM, flow_np: np.ndarray, cont_np: np.ndarray, disc_np: np.ndarray, len_np: np.ndarray, device: torch.device, *, batch_size: int, n_steps: int, cont_bin_edges: torch.Tensor | None = None) -> dict[str, np.ndarray]:
|
||||
out: dict[str, list[np.ndarray]] = {}
|
||||
model.eval()
|
||||
for start in range(0, len(flow_np), batch_size):
|
||||
@@ -30,14 +30,14 @@ def _batch_score(model: MixedTokenCFM, flow_np: np.ndarray, cont_np: np.ndarray,
|
||||
cont = torch.from_numpy(cont_np[sl]).float().to(device)
|
||||
disc = torch.from_numpy(disc_np[sl]).long().to(device)
|
||||
lens = torch.from_numpy(len_np[sl]).long().to(device)
|
||||
m = model.trajectory_metrics(flow, cont, disc, lens, n_steps=n_steps)
|
||||
d = model.disc_nll_score(flow, cont, disc, lens)
|
||||
m = model.trajectory_metrics(flow, cont, disc, lens, n_steps=n_steps, cont_bin_edges=cont_bin_edges)
|
||||
d = model.disc_nll_score(flow, cont, disc, lens, cont_bin_edges=cont_bin_edges)
|
||||
for src in (m, d):
|
||||
for (k, v) in src.items():
|
||||
out.setdefault(k, []).append(v.detach().cpu().numpy())
|
||||
return {k: np.concatenate(v, axis=0) for (k, v) in out.items()}
|
||||
|
||||
def _quick_eval(model: MixedTokenCFM, data: MixedData, device: torch.device, cfg: dict[str, Any]) -> dict[str, float]:
|
||||
def _quick_eval(model: MixedTokenCFM, data: MixedData, device: torch.device, cfg: dict[str, Any], cont_bin_edges: torch.Tensor | None = None) -> dict[str, float]:
|
||||
n_eval = int(cfg.get('eval_n', 2000))
|
||||
rng = np.random.default_rng(0)
|
||||
|
||||
@@ -46,8 +46,8 @@ def _quick_eval(model: MixedTokenCFM, data: MixedData, device: torch.device, cfg
|
||||
return rng.choice(n, m, replace=False)
|
||||
vi = pick(len(data.val_flow))
|
||||
ai = pick(len(data.attack_flow))
|
||||
v = _batch_score(model, data.val_flow[vi], data.val_cont[vi], data.val_disc[vi], data.val_len[vi], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)))
|
||||
a = _batch_score(model, data.attack_flow[ai], data.attack_cont[ai], data.attack_disc[ai], data.attack_len[ai], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)))
|
||||
v = _batch_score(model, data.val_flow[vi], data.val_cont[vi], data.val_disc[vi], data.val_len[vi], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)), cont_bin_edges=cont_bin_edges)
|
||||
a = _batch_score(model, data.attack_flow[ai], data.attack_cont[ai], data.attack_disc[ai], data.attack_len[ai], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)), cont_bin_edges=cont_bin_edges)
|
||||
y = np.concatenate([np.zeros(len(vi)), np.ones(len(ai))])
|
||||
out: dict[str, float] = {}
|
||||
for k in sorted(v.keys()):
|
||||
@@ -73,9 +73,36 @@ def train(cfg: dict[str, Any]) -> Path:
|
||||
ds = TensorDataset(torch.from_numpy(tr_f).float(), torch.from_numpy(tr_c).float(), torch.from_numpy(tr_d).long(), torch.from_numpy(tr_l).long())
|
||||
loader = DataLoader(ds, batch_size=int(cfg['batch_size']), shuffle=True, drop_last=True, num_workers=int(cfg.get('num_workers', 0)), pin_memory=device.type == 'cuda')
|
||||
print(f'[data] training on {len(ds):,} flows')
|
||||
model_cfg = MixedCFMConfig(T=data.T, flow_dim=data.flow_dim, token_dim=cfg.get('token_dim'), d_model=int(cfg['d_model']), n_layers=int(cfg['n_layers']), n_heads=int(cfg['n_heads']), mlp_ratio=float(cfg.get('mlp_ratio', 4.0)), time_dim=int(cfg.get('time_dim', 64)), sigma=float(cfg.get('sigma', 0.1)), use_ot=bool(cfg.get('use_ot', False)), reference_mode=cfg.get('reference_mode'), lambda_disc=float(cfg.get('lambda_disc', 1.0)))
|
||||
n_disc_classes = int(cfg.get('n_disc_classes', 2))
|
||||
model_cfg = MixedCFMConfig(
|
||||
T=data.T, flow_dim=data.flow_dim, token_dim=cfg.get('token_dim'),
|
||||
d_model=int(cfg['d_model']), n_layers=int(cfg['n_layers']), n_heads=int(cfg['n_heads']),
|
||||
mlp_ratio=float(cfg.get('mlp_ratio', 4.0)), time_dim=int(cfg.get('time_dim', 64)),
|
||||
sigma=float(cfg.get('sigma', 0.1)), use_ot=bool(cfg.get('use_ot', False)),
|
||||
reference_mode=cfg.get('reference_mode'), lambda_disc=float(cfg.get('lambda_disc', 1.0)),
|
||||
n_disc_classes=n_disc_classes,
|
||||
# B-group ablation flags
|
||||
use_flow_token=bool(cfg.get('use_flow_token', True)),
|
||||
n_packet_tokens=int(cfg.get('n_packet_tokens', -1)),
|
||||
disc_as_cont=bool(cfg.get('disc_as_cont', False)),
|
||||
cont_as_disc=bool(cfg.get('cont_as_disc', False)),
|
||||
)
|
||||
model = MixedTokenCFM(model_cfg).to(device)
|
||||
print(f'[model] params={model.param_count():,} token_dim={model.token_dim} sigma={model_cfg.sigma} use_ot={model_cfg.use_ot} lambda_disc={model_cfg.lambda_disc}')
|
||||
# B4: compute bin edges from benign train cont (z-scored, masked) for cont_as_disc quantization
|
||||
cont_bin_edges = None
|
||||
if model_cfg.cont_as_disc:
|
||||
n_bins = n_disc_classes
|
||||
n_cont_orig = model_cfg.n_cont_pkt
|
||||
# gather real cont samples per channel (mask padding)
|
||||
masks = np.arange(data.train_cont.shape[1])[None, :] < data.train_len[:, None]
|
||||
edges = np.zeros((n_cont_orig, n_bins - 1), dtype=np.float32)
|
||||
for c in range(n_cont_orig):
|
||||
vals = data.train_cont[..., c][masks]
|
||||
qs = np.linspace(0, 1, n_bins + 1)[1:-1] # interior quantiles
|
||||
edges[c] = np.quantile(vals, qs).astype(np.float32)
|
||||
cont_bin_edges = torch.from_numpy(edges).to(device)
|
||||
print(f'[B4] cont_bin_edges shape={tuple(edges.shape)} (n_bins={n_bins})')
|
||||
print(f'[model] params={model.param_count():,} token_dim={model.token_dim} sigma={model_cfg.sigma} use_ot={model_cfg.use_ot} lambda_disc={model_cfg.lambda_disc} use_flow_token={model_cfg.use_flow_token} n_packet_tokens={model_cfg.n_packet_tokens} disc_as_cont={model_cfg.disc_as_cont} cont_as_disc={model_cfg.cont_as_disc}')
|
||||
opt = torch.optim.AdamW(model.parameters(), lr=float(cfg['lr']), weight_decay=float(cfg.get('weight_decay', 0.01)))
|
||||
total_steps = max(1, int(cfg['epochs']) * len(loader))
|
||||
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)
|
||||
@@ -91,7 +118,7 @@ def train(cfg: dict[str, Any]) -> Path:
|
||||
cont = cont.to(device, non_blocking=True)
|
||||
disc = disc.to(device, non_blocking=True)
|
||||
lens = lens.to(device, non_blocking=True)
|
||||
comp = model.compute_loss(flow, cont, disc, lens, return_components=True)
|
||||
comp = model.compute_loss(flow, cont, disc, lens, return_components=True, cont_bin_edges=cont_bin_edges)
|
||||
loss = comp['total']
|
||||
ldisc_sum += float(comp['aux_disc'].item())
|
||||
opt.zero_grad(set_to_none=True)
|
||||
@@ -104,7 +131,7 @@ def train(cfg: dict[str, Any]) -> Path:
|
||||
mean_loss = float(np.mean(losses)) if losses else float('nan')
|
||||
eval_metrics: dict[str, float] | None = None
|
||||
if epoch % int(cfg.get('eval_every', 5)) == 0 or epoch == int(cfg['epochs']):
|
||||
eval_metrics = _quick_eval(model, data, device, cfg)
|
||||
eval_metrics = _quick_eval(model, data, device, cfg, cont_bin_edges=cont_bin_edges)
|
||||
history['epoch'].append(epoch)
|
||||
history['loss'].append(mean_loss)
|
||||
history['eval'].append(eval_metrics)
|
||||
@@ -120,6 +147,8 @@ def train(cfg: dict[str, Any]) -> Path:
|
||||
if not np.isfinite(mean_loss):
|
||||
raise RuntimeError(f'non-finite loss at epoch {epoch}')
|
||||
payload = {'model_state_dict': model.state_dict(), 'model_cfg': asdict(model_cfg), 'cont_mean': data.cont_mean, 'cont_std': data.cont_std, 'flow_mean': data.flow_mean, 'flow_std': data.flow_std, 'flow_feature_names': np.asarray(data.flow_feature_names), 'packet_feature_names': np.asarray(data.packet_feature_names)}
|
||||
if cont_bin_edges is not None:
|
||||
payload['cont_bin_edges'] = cont_bin_edges.detach().cpu().numpy()
|
||||
torch.save(payload, save_dir / 'model.pt')
|
||||
with open(save_dir / 'history.json', 'w') as f:
|
||||
json.dump(history, f, indent=2, default=str)
|
||||
|
||||
Reference in New Issue
Block a user