ablation: add Group A (aggregator) + Group B (architecture) infrastructure

Extends MixedCFMConfig with 5 backwards-compatible flags (use_flow_token,
n_packet_tokens, disc_as_cont, cont_as_disc + cont_n_bins) so existing
JANUS-full checkpoints load with 0 missing/unexpected keys.

Adds:
- 60 ablation training configs (5 variants × 4 datasets × 3 seeds)
- scripts/ablation/{generate_configs.py, run_groupB.sh, run_cross_groupB.sh,
  smoke_test.sh} — config generation + GPU drivers
- scripts/aggregate/aggregate_ablation{,_cross,_cross_B}.py — produces
  within-dataset and cross-dataset (3×3) ablation tables with 3-seed mean
  ± 95% t-CI plus optional paired DeLong p-values

README updated with ablation section pointing at
artifacts/ablation/ABLATION_SUMMARY.md.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-08 23:59:27 +08:00
parent 1d8862fbeb
commit a6bcbbd299
72 changed files with 3642 additions and 96 deletions

View File

@@ -0,0 +1,533 @@
"""JANUS ablation aggregator (Groups A + B).
Reads phase1_scores.npz from:
artifacts/route_comparison/janus_<ds>_seed<S>/ (A + JANUS-full anchor)
artifacts/ablation/janus_<ds>_seed<S>_<gid>/ (B variants)
Produces:
artifacts/ablation/ABLATION_TABLE.md final markdown table
artifacts/ablation/ABLATION_TABLE_RAW.json per-cell mean / std / CI / per-seed
artifacts/ablation/ABLATION_DELONG.md paired DeLong p-values vs JANUS-full
Group A operates entirely on existing route_comparison npz files (no GPU).
Group B requires the 60 B-variant runs to have completed.
"""
from __future__ import annotations
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import numpy as np
from sklearn.covariance import OAS
from sklearn.metrics import roc_auc_score
ROOT = Path(__file__).resolve().parents[2]
ROUTE = ROOT / "artifacts" / "route_comparison"
ABL = ROOT / "artifacts" / "ablation"
DATASETS = ["iscxtor2016", "cicids2017", "cicddos2019", "ciciot2023"]
PRETTY = {
"iscxtor2016": "ISCXTor16",
"cicids2017": "CICIDS17",
"cicddos2019": "CICDDoS19",
"ciciot2023": "CICIoT23",
}
SEEDS = [42, 43, 44]
T_975_N3 = 4.302653 # 95% t-CI factor for n=3 (df=2)
CONT_KEYS = ["terminal_norm", "terminal_flow", "terminal_packet"]
DISC_KEYS = ["disc_nll_total", "disc_nll_ch2", "disc_nll_ch3",
"disc_nll_ch4", "disc_nll_ch5", "disc_nll_ch6", "disc_nll_ch7"]
ALL_KEYS = CONT_KEYS + DISC_KEYS # 10-d
# --------------------------------------------------------------------------- #
# I/O #
# --------------------------------------------------------------------------- #
def _load_npz(npz_path: Path):
z = np.load(npz_path, allow_pickle=True)
val = {}
atk = {}
for k in z.files:
if k.startswith("val_") and k != "val_labels":
val[k[4:]] = z[k]
elif k.startswith("atk_") and k != "atk_labels":
atk[k[4:]] = z[k]
return val, atk
def _load_cross_npz(npz_path: Path):
"""Cross npz schema: b_<key> = target benign, a_<key> = target attacks."""
z = np.load(npz_path, allow_pickle=True)
val = {}
atk = {}
for k in z.files:
if k.startswith("b_") and k != "b_labels":
val[k[2:]] = z[k]
elif k.startswith("a_") and k != "a_labels":
atk[k[2:]] = z[k]
return val, atk
def _stack(d: dict, keys: list[str]) -> np.ndarray:
arrs = []
for k in keys:
if k in d:
arrs.append(d[k])
else:
# variant doesn't produce this score (e.g. B2 has no disc, B5 disc untrained)
return None
out = np.stack(arrs, axis=1).astype(np.float64)
return np.nan_to_num(out, nan=0.0, posinf=1e6, neginf=-1e6)
# --------------------------------------------------------------------------- #
# Score functions (Group A definitions) #
# --------------------------------------------------------------------------- #
def _mahal(S, mu, inv_cov):
d = S - mu
return np.einsum("ni,ij,nj->n", d, inv_cov, d)
def _oas_mahal(val_S, atk_S):
mu = val_S.mean(axis=0)
cov = OAS().fit(val_S).covariance_
inv = np.linalg.inv(cov + 1e-9 * np.eye(cov.shape[0]))
return _mahal(val_S, mu, inv), _mahal(atk_S, mu, inv)
def _zscore_agg(val_S, atk_S, mode="mean"):
mu = val_S.mean(axis=0)
sd = val_S.std(axis=0) + 1e-9
zv = (val_S - mu) / sd
za = (atk_S - mu) / sd
if mode == "mean":
return zv.mean(axis=1), za.mean(axis=1)
if mode == "max":
return zv.max(axis=1), za.max(axis=1)
raise ValueError(mode)
def score_a1_terminal_norm(val, atk):
return val["terminal_norm"], atk["terminal_norm"]
def score_a2_disc_total(val, atk):
if "disc_nll_total" not in val:
return None
return val["disc_nll_total"], atk["disc_nll_total"]
def score_a3_oas_term3(val, atk):
Sv = _stack(val, CONT_KEYS)
Sa = _stack(atk, CONT_KEYS)
if Sv is None or Sa is None:
return None
return _oas_mahal(Sv, Sa)
def score_a4_oas_disc7(val, atk):
Sv = _stack(val, DISC_KEYS)
Sa = _stack(atk, DISC_KEYS)
if Sv is None or Sa is None:
return None
return _oas_mahal(Sv, Sa)
def score_a5_oas_all10(val, atk):
Sv = _stack(val, ALL_KEYS)
Sa = _stack(atk, ALL_KEYS)
if Sv is None or Sa is None:
return None
return _oas_mahal(Sv, Sa)
def score_a6_zmean(val, atk):
Sv = _stack(val, ALL_KEYS)
Sa = _stack(atk, ALL_KEYS)
if Sv is None or Sa is None:
return None
return _zscore_agg(Sv, Sa, "mean")
def score_a7_zmax(val, atk):
Sv = _stack(val, ALL_KEYS)
Sa = _stack(atk, ALL_KEYS)
if Sv is None or Sa is None:
return None
return _zscore_agg(Sv, Sa, "max")
def score_oas_disc_all(val, atk):
"""Auto-discover all `disc_nll_*` keys; OAS-Mahal over them. Used by B4."""
keys = sorted(k for k in val.keys() if k.startswith("disc_nll_"))
if not keys:
return None
Sv = _stack(val, keys)
Sa = _stack(atk, keys)
if Sv is None or Sa is None:
return None
return _oas_mahal(Sv, Sa)
def score_oas_all_available(val, atk):
"""OAS-Mahal over all `terminal_*` `disc_nll_*` keys present in the npz.
Used by B1 (no terminal_flow). Handles arbitrary subset of the 10 standard keys.
"""
keys = sorted([k for k in val.keys() if k.startswith("terminal_") or k.startswith("disc_nll_")])
if not keys:
return None
if len(keys) == 1:
return val[keys[0]], atk[keys[0]]
Sv = _stack(val, keys)
Sa = _stack(atk, keys)
if Sv is None or Sa is None:
return None
return _oas_mahal(Sv, Sa)
def score_oas_term_all(val, atk):
"""Auto-discover all `terminal_*` keys; OAS-Mahal. Used by B3 (3 keys) / B1 (2 keys)."""
keys = sorted(k for k in val.keys() if k.startswith("terminal_"))
if not keys:
return None
if len(keys) == 1:
# single scalar: just return raw
return val[keys[0]], atk[keys[0]]
Sv = _stack(val, keys)
Sa = _stack(atk, keys)
if Sv is None or Sa is None:
return None
return _oas_mahal(Sv, Sa)
SCORE_FNS = {
"A1_terminal_norm": score_a1_terminal_norm,
"A2_disc_nll_total": score_a2_disc_total,
"A3_OAS_term3": score_a3_oas_term3,
"A4_OAS_disc7": score_a4_oas_disc7,
"A5_OAS_all10": score_a5_oas_all10,
"A6_zmean_all10": score_a6_zmean,
"A7_zmax_all10": score_a7_zmax,
"OAS_disc_all": score_oas_disc_all,
"OAS_term_all": score_oas_term_all,
"OAS_all_available": score_oas_all_available,
}
# --------------------------------------------------------------------------- #
# Stats #
# --------------------------------------------------------------------------- #
def _auroc(s_v, s_a):
y = np.r_[np.zeros(len(s_v)), np.ones(len(s_a))]
s = np.r_[s_v, s_a]
return float(roc_auc_score(y, s))
def _mean_ci(values: list[float]):
"""3-seed mean ± 95% t-CI (n=3, df=2)."""
a = np.asarray([v for v in values if v is not None and not np.isnan(v)], dtype=float)
if a.size == 0:
return None
if a.size == 1:
return {"mean": float(a[0]), "std": 0.0, "ci": 0.0, "n": 1, "vals": a.tolist()}
se = a.std(ddof=1) / np.sqrt(a.size)
return {
"mean": float(a.mean()),
"std": float(a.std(ddof=1)),
"ci": float(T_975_N3 * se) if a.size == 3 else float(1.96 * se),
"n": int(a.size),
"vals": a.tolist(),
}
def _delong_var(s_v, s_a):
"""Compute DeLong AUROC variance (Sun & Xu 2014, fast O(n log n))."""
n0, n1 = len(s_v), len(s_a)
s = np.concatenate([s_a, s_v]) # positives first
order = np.argsort(s, kind="mergesort")
L = np.empty_like(s)
s_sorted = s[order]
# midrank
i = 0
while i < len(s_sorted):
j = i
while j < len(s_sorted) and s_sorted[j] == s_sorted[i]:
j += 1
L[order[i:j]] = (i + j - 1) / 2.0 + 1
i = j
# ranks split
L_a = L[:n1]
L_v = L[n1:]
# midrank within each class
s_a_order = np.argsort(s_a, kind="mergesort")
L_aa = np.empty(n1)
sa_sorted = s_a[s_a_order]
i = 0
while i < n1:
j = i
while j < n1 and sa_sorted[j] == sa_sorted[i]:
j += 1
L_aa[s_a_order[i:j]] = (i + j - 1) / 2.0 + 1
i = j
s_v_order = np.argsort(s_v, kind="mergesort")
L_vv = np.empty(n0)
sv_sorted = s_v[s_v_order]
i = 0
while i < n0:
j = i
while j < n0 and sv_sorted[j] == sv_sorted[i]:
j += 1
L_vv[s_v_order[i:j]] = (i + j - 1) / 2.0 + 1
i = j
auc = (L_a.sum() / n1 - (n1 + 1) / 2) / n0
V10 = (L_a - L_aa) / n0 # length n1
V01 = 1 - (L_v - L_vv) / n1 # length n0
s10 = V10.var(ddof=1)
s01 = V01.var(ddof=1)
var = s10 / n1 + s01 / n0
return float(auc), float(var), V10, V01
def _delong_paired_p(s_v, s_a, t_v, t_a):
"""Paired DeLong test for two AUROCs on the same data.
Returns (auc1 - auc2, p_value_two_sided).
s_*: candidate scores; t_*: reference (JANUS-full) scores.
Both arrays must align flow-by-flow.
"""
auc1, var1, V10_1, V01_1 = _delong_var(s_v, s_a)
auc2, var2, V10_2, V01_2 = _delong_var(t_v, t_a)
n1, n0 = len(s_a), len(s_v)
cov10 = np.cov(np.stack([V10_1, V10_2]), ddof=1)[0, 1]
cov01 = np.cov(np.stack([V01_1, V01_2]), ddof=1)[0, 1]
cov12 = cov10 / n1 + cov01 / n0
var_diff = var1 + var2 - 2 * cov12
if var_diff <= 0:
return auc1 - auc2, 1.0
z = (auc1 - auc2) / np.sqrt(var_diff)
# two-sided
from scipy.stats import norm
p = 2 * (1 - norm.cdf(abs(z)))
return auc1 - auc2, float(p)
# --------------------------------------------------------------------------- #
# Aggregation entry points #
# --------------------------------------------------------------------------- #
@dataclass
class VariantSpec:
vid: str
label: str
what_removed: str
npz_dir_pattern: str # e.g. "route_comparison/janus_{ds}_seed{seed}" or "ablation/janus_{ds}_seed{seed}_{gid}"
score_fn_id: str # which Group A score to apply on the npz (usually "A5_OAS_all10")
gid: str = "" # for B variants
def _expand_path(spec: VariantSpec, ds: str, seed: int) -> Path:
return ROOT / "artifacts" / spec.npz_dir_pattern.format(ds=ds, seed=seed, gid=spec.gid) / "phase1_scores.npz"
def collect_variant(spec: VariantSpec) -> dict:
rows: dict[str, list[float]] = {ds: [] for ds in DATASETS}
per_seed: dict[str, dict[int, float]] = {ds: {} for ds in DATASETS}
for ds in DATASETS:
for seed in SEEDS:
npz = _expand_path(spec, ds, seed)
if not npz.exists():
continue
val, atk = _load_npz(npz)
fn = SCORE_FNS[spec.score_fn_id]
res = fn(val, atk)
if res is None:
continue
sv, sa = res
auc = _auroc(sv, sa)
rows[ds].append(auc)
per_seed[ds][seed] = auc
summary = {ds: _mean_ci(rows[ds]) for ds in DATASETS}
return {
"vid": spec.vid,
"label": spec.label,
"what_removed": spec.what_removed,
"score_fn_id": spec.score_fn_id,
"gid": spec.gid,
"per_dataset": summary,
"per_seed": per_seed,
}
def collect_delong_pvals(spec: VariantSpec, ref_spec: VariantSpec) -> dict:
"""Paired DeLong test: spec vs ref_spec, on each (ds, seed)."""
out: dict[str, list[dict]] = {ds: [] for ds in DATASETS}
for ds in DATASETS:
for seed in SEEDS:
npz_s = _expand_path(spec, ds, seed)
npz_r = _expand_path(ref_spec, ds, seed)
if not (npz_s.exists() and npz_r.exists()):
continue
val_s, atk_s = _load_npz(npz_s)
val_r, atk_r = _load_npz(npz_r)
fn_s = SCORE_FNS[spec.score_fn_id]
fn_r = SCORE_FNS[ref_spec.score_fn_id]
res_s = fn_s(val_s, atk_s)
res_r = fn_r(val_r, atk_r)
if res_s is None or res_r is None:
continue
sv_s, sa_s = res_s
sv_r, sa_r = res_r
# if shapes differ (e.g. variant evaluated on subset), align by index — they should match seed-for-seed
# in practice for B variants the npz is from the SAME data as JANUS-full at that (ds, seed)
if len(sv_s) != len(sv_r) or len(sa_s) != len(sa_r):
continue
d, p = _delong_paired_p(sv_s, sa_s, sv_r, sa_r)
out[ds].append({"seed": seed, "delta": d, "p": p})
return out
# --------------------------------------------------------------------------- #
# Variant registry #
# --------------------------------------------------------------------------- #
ROUTE_DIR = "route_comparison/janus_{ds}_seed{seed}"
ABL_DIR = "ablation/janus_{ds}_seed{seed}_{gid}"
def _group_a_specs() -> list[VariantSpec]:
base = ROUTE_DIR
return [
VariantSpec("JANUS-full", "JANUS-full (A5)", "", base, "A5_OAS_all10"),
VariantSpec("A1", "A1 terminal_norm", "OAS aggregator + disc head", base, "A1_terminal_norm"),
VariantSpec("A2", "A2 disc_nll_total", "OAS aggregator + CFM head", base, "A2_disc_nll_total"),
VariantSpec("A3", "A3 OAS-Mahal term3", "disc head", base, "A3_OAS_term3"),
VariantSpec("A4", "A4 OAS-Mahal disc7", "CFM head", base, "A4_OAS_disc7"),
VariantSpec("A6", "A6 z-score mean (10-d)", "covariance structure", base, "A6_zmean_all10"),
VariantSpec("A7", "A7 z-score max (10-d)", "weighted aggregation", base, "A7_zmax_all10"),
]
def _group_b_specs() -> list[VariantSpec]:
return [
# B1 has 2 terminal keys (no terminal_flow) + full disc7 → use auto-key OAS (9-d in this case)
VariantSpec("B1", "B1 no FLOW token", "global context", ABL_DIR, "OAS_all_available", gid="b1_noflow"),
# B2 has only terminal_flow (= terminal_norm); single scalar
VariantSpec("B2", "B2 flow-only", "packet sequence", ABL_DIR, "A1_terminal_norm", gid="b2_flowonly"),
# B3 has terminal_norm/flow/packet covering all 9 dims (cont + disc-as-cont); OAS on 3-tuple
VariantSpec("B3", "B3 all-cont", "cont/disc split", ABL_DIR, "A3_OAS_term3", gid="b3_allcont"),
# B4 has 9 disc channels + total; auto-discover keys
VariantSpec("B4", "B4 all-disc", "cont/disc split (rev)", ABL_DIR, "OAS_disc_all", gid="b4_alldisc"),
# B5 has full schema but disc head is untrained noise; use term3 only
VariantSpec("B5", "B5 λ_disc=0", "joint training", ABL_DIR, "A3_OAS_term3", gid="b5_nodisc"),
]
# --------------------------------------------------------------------------- #
# Markdown writer #
# --------------------------------------------------------------------------- #
def _fmt_cell(c: dict | None) -> str:
if c is None:
return ""
if c["n"] == 1:
return f"{100 * c['mean']:.2f}"
return f"{100 * c['mean']:.2f} ± {100 * c['ci']:.2f}"
def write_table(rows: list[dict], path: Path, *, title: str = "JANUS ablation"):
lines = [f"# {title}", ""]
lines.append(f"3-seed mean ± 95% t-CI AUROC (%). Seeds = {SEEDS}.")
lines.append("")
header = ["Variant", "What removed"] + [PRETTY[ds] for ds in DATASETS] + ["Mean"]
lines.append("| " + " | ".join(header) + " |")
lines.append("|" + "|".join("---" for _ in header) + "|")
for r in rows:
cells = [r["label"], r["what_removed"]]
ds_means = []
for ds in DATASETS:
c = r["per_dataset"].get(ds)
cells.append(_fmt_cell(c))
if c is not None:
ds_means.append(c["mean"])
cells.append(f"{100 * np.mean(ds_means):.2f}" if ds_means else "")
lines.append("| " + " | ".join(cells) + " |")
lines.append("")
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text("\n".join(lines))
def write_delong(records: list[dict], path: Path):
lines = ["# Paired DeLong p-values vs JANUS-full",
"",
f"Seeds = {SEEDS}. p reported per (variant, dataset, seed). "
"Holm-Bonferroni-correctable; raw p shown.",
""]
for rec in records:
lines.append(f"## {rec['label']} ({rec['vid']})")
lines.append("")
header = ["Seed"] + [PRETTY[ds] for ds in DATASETS]
lines.append("| " + " | ".join(header) + " |")
lines.append("|" + "|".join("---" for _ in header) + "|")
for seed in SEEDS:
row = [str(seed)]
for ds in DATASETS:
hits = [x for x in rec["delong"][ds] if x["seed"] == seed]
if hits:
h = hits[0]
sign = "+" if h["delta"] >= 0 else ""
row.append(f"Δ={sign}{abs(h['delta']):.4f}, p={h['p']:.3g}")
else:
row.append("")
lines.append("| " + " | ".join(row) + " |")
lines.append("")
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text("\n".join(lines))
# --------------------------------------------------------------------------- #
# Main #
# --------------------------------------------------------------------------- #
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--group", choices=["A", "B", "all"], default="A")
ap.add_argument("--delong", action="store_true",
help="Compute paired DeLong p-values vs JANUS-full (CPU heavy on big eval sets).")
args = ap.parse_args()
ABL.mkdir(parents=True, exist_ok=True)
specs: list[VariantSpec] = []
if args.group in ("A", "all"):
specs.extend(_group_a_specs())
if args.group in ("B", "all"):
specs.extend(_group_b_specs())
rows = []
for spec in specs:
r = collect_variant(spec)
rows.append(r)
n_ok = sum(1 for ds in DATASETS if r["per_dataset"][ds] is not None)
print(f"[ok] {spec.vid:14s} datasets_with_data={n_ok}/{len(DATASETS)}", flush=True)
out_md = ABL / f"ABLATION_TABLE_{args.group}.md"
write_table(rows, out_md, title=f"JANUS ablation (group {args.group})")
out_json = ABL / f"ABLATION_TABLE_{args.group}.json"
out_json.write_text(json.dumps(rows, indent=2, default=lambda o: None))
print(f"[wrote] {out_md}")
print(f"[wrote] {out_json}")
if args.delong:
ref = next(s for s in _group_a_specs() if s.vid == "JANUS-full")
recs = []
for spec in specs:
if spec.vid == "JANUS-full":
continue
d = collect_delong_pvals(spec, ref)
recs.append({"vid": spec.vid, "label": spec.label, "delong": d})
print(f"[delong] {spec.vid}", flush=True)
write_delong(recs, ABL / f"ABLATION_DELONG_{args.group}.md")
print(f"[wrote] {ABL / f'ABLATION_DELONG_{args.group}.md'}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,218 @@
"""Cross-dataset version of the Group-A score-aggregator ablation.
For each (src, tgt, seed) cell we have a phase1-style npz with:
b_<key> target benign val (aggregator fit on this)
a_<key> target attacks
Within-dataset (src == tgt) cells reuse the standard
artifacts/route_comparison/janus_<ds>_seed<S>/phase1_scores.npz
(val_/atk_ prefixes — handled via the same _load_npz path).
We score 7 aggregators (A1..A7) + JANUS-full's deployed A5 across all
3×3 cells × 3 seeds, then summarize with two complementary views:
ABLATION_TABLE_CROSS_summary.md
| Aggregator | Within mean | Cross mean | Cross min (worst cell) |
Shows whether OAS's value lives in cross-dataset robustness.
ABLATION_TABLE_CROSS_full.md
Per-aggregator full 3×3 matrix (each cell = 3-seed mean ± 95% t-CI).
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
from aggregate_ablation import (
SCORE_FNS, T_975_N3, _auroc, _load_npz, _load_cross_npz,
)
ROOT = Path(__file__).resolve().parents[2]
ROUTE = ROOT / "artifacts" / "route_comparison"
CROSS = ROUTE / "cross"
ABL = ROOT / "artifacts" / "ablation"
# 3x3 cross matrix datasets (no ISCXTor16 — different feature space)
CROSS_DATASETS = ["cicids2017", "cicddos2019", "ciciot2023"]
PRETTY = {
"cicids2017": "CICIDS17",
"cicddos2019": "CICDDoS19",
"ciciot2023": "CICIoT23",
}
SEEDS = [42, 43, 44]
AGGREGATORS = [
("JANUS-full (A5)", "A5_OAS_all10", "deployed JANUS"),
("A1 terminal_norm","A1_terminal_norm", "raw scalar (CFM head)"),
("A2 disc_total", "A2_disc_nll_total","raw scalar (disc head)"),
("A3 OAS term3", "A3_OAS_term3", "OAS on 3 cont sub-scores"),
("A4 OAS disc7", "A4_OAS_disc7", "OAS on 7 disc sub-scores"),
("A6 z-score mean", "A6_zmean_all10", "equal-weight z-score sum"),
("A7 z-score max", "A7_zmax_all10", "equal-weight z-score max"),
]
# --------------------------------------------------------------------------- #
def _cell_path(src: str, tgt: str, seed: int) -> Path | None:
"""Return npz path for (src, tgt, seed) cell, or None if missing."""
if src == tgt:
p = ROUTE / f"janus_{src}_seed{seed}" / "phase1_scores.npz"
return p if p.exists() else None
p = CROSS / f"janus_seed{seed}_{src}_to_{tgt}.npz"
return p if p.exists() else None
def _load_cell(src: str, tgt: str, seed: int):
p = _cell_path(src, tgt, seed)
if p is None:
return None, None
if src == tgt:
return _load_npz(p)
return _load_cross_npz(p)
def _score_cell(src: str, tgt: str, seed: int, score_fn_id: str) -> float | None:
val, atk = _load_cell(src, tgt, seed)
if val is None:
return None
fn = SCORE_FNS[score_fn_id]
res = fn(val, atk)
if res is None:
return None
sv, sa = res
return _auroc(sv, sa)
def _seed_means(src: str, tgt: str, score_fn_id: str) -> dict | None:
"""3-seed AUROC for cell (src,tgt). Returns dict with mean/std/ci, or None."""
vals = []
for seed in SEEDS:
v = _score_cell(src, tgt, seed, score_fn_id)
if v is not None and not np.isnan(v):
vals.append(v)
if not vals:
return None
a = np.asarray(vals)
if a.size == 1:
return {"mean": float(a[0]), "std": 0.0, "ci": 0.0, "n": 1, "vals": a.tolist()}
se = a.std(ddof=1) / np.sqrt(a.size)
return {
"mean": float(a.mean()),
"std": float(a.std(ddof=1)),
"ci": float(T_975_N3 * se) if a.size == 3 else float(1.96 * se),
"n": int(a.size),
"vals": a.tolist(),
}
# --------------------------------------------------------------------------- #
def _fmt_cell(c):
if c is None:
return ""
if c["n"] == 1:
return f"{100 * c['mean']:.2f}"
return f"{100 * c['mean']:.2f} ± {100 * c['ci']:.2f}"
def _summary_row(rows_3x3: dict[tuple[str, str], dict | None]) -> tuple[float, float, float, dict | None]:
"""Return (within_mean, cross_mean, cross_worst, worst_cell_summary)."""
within = []
cross = []
worst_v = None
worst_cell = None
for (src, tgt), cell in rows_3x3.items():
if cell is None:
continue
if src == tgt:
within.append(cell["mean"])
else:
cross.append(cell["mean"])
if worst_v is None or cell["mean"] < worst_v:
worst_v = cell["mean"]
worst_cell = (src, tgt, cell)
w = float(np.mean(within)) if within else float("nan")
c = float(np.mean(cross)) if cross else float("nan")
cw = worst_v if worst_v is not None else float("nan")
return w, c, cw, worst_cell
# --------------------------------------------------------------------------- #
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--out-dir", type=Path, default=ABL)
args = ap.parse_args()
args.out_dir.mkdir(parents=True, exist_ok=True)
full = {} # aggregator label -> {(src, tgt) -> cell summary}
for label, fn_id, _why in AGGREGATORS:
rows = {}
for src in CROSS_DATASETS:
for tgt in CROSS_DATASETS:
rows[(src, tgt)] = _seed_means(src, tgt, fn_id)
full[label] = rows
n_ok = sum(1 for v in rows.values() if v is not None)
print(f"[ok] {label:20s} cells={n_ok}/{len(rows)}", flush=True)
# Summary table: within mean, cross mean, cross worst
summary_lines = ["# Cross-dataset Group-A summary",
"",
f"3-seed mean ± 95% t-CI AUROC. Datasets = {CROSS_DATASETS}.",
"Aggregator fit on **target** benign val only.",
"",
"| Aggregator | Within (3 cells, mean) | Cross (6 cells, mean) | Cross worst cell | Within Cross |",
"|---|---|---|---|---|"]
summary_data = {}
for label, fn_id, _why in AGGREGATORS:
rows = full[label]
w, c, cw, worst_cell = _summary_row(rows)
gap = (w - c) * 100 if not np.isnan(w) and not np.isnan(c) else float("nan")
worst_str = ""
if worst_cell is not None:
src, tgt, cell = worst_cell
worst_str = f"{PRETTY[src]}{PRETTY[tgt]}: {_fmt_cell(cell)}"
summary_lines.append(
f"| {label} | {100 * w:.2f} | {100 * c:.2f} | {worst_str} | {gap:+.2f} |"
)
summary_data[label] = {"within_mean": w, "cross_mean": c, "cross_worst": cw, "worst_cell": worst_cell}
summary_path = args.out_dir / "ABLATION_TABLE_CROSS_summary.md"
summary_path.write_text("\n".join(summary_lines) + "\n")
print(f"[wrote] {summary_path}")
# Full per-aggregator 3x3 matrices
full_lines = ["# Cross-dataset Group-A full matrices",
"",
"Per aggregator: 3×3 matrix (rows = source / training, columns = target / test).",
"Each cell = 3-seed mean ± 95% t-CI AUROC (%). Diagonal italic = within-dataset.",
""]
for label, fn_id, why in AGGREGATORS:
full_lines.append(f"## {label} ({why})")
full_lines.append("")
header = ["Source ↓ / Target →"] + [PRETTY[d] for d in CROSS_DATASETS]
full_lines.append("| " + " | ".join(header) + " |")
full_lines.append("|" + "|".join("---" for _ in header) + "|")
for src in CROSS_DATASETS:
row = [f"**{PRETTY[src]}**"]
for tgt in CROSS_DATASETS:
cell = full[label][(src, tgt)]
txt = _fmt_cell(cell)
if src == tgt:
txt = f"_{txt}_"
row.append(txt)
full_lines.append("| " + " | ".join(row) + " |")
full_lines.append("")
full_path = args.out_dir / "ABLATION_TABLE_CROSS_full.md"
full_path.write_text("\n".join(full_lines))
print(f"[wrote] {full_path}")
json_path = args.out_dir / "ABLATION_TABLE_CROSS.json"
json_path.write_text(json.dumps({
"summary": summary_data,
"full": {label: {f"{src}->{tgt}": cell for (src, tgt), cell in rows.items()}
for label, rows in full.items()},
}, indent=2, default=lambda o: None))
print(f"[wrote] {json_path}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,180 @@
"""B-variant cross-dataset aggregation.
Reads:
artifacts/ablation/janus_<ds>_seed<S>_<gid>/phase1_scores.npz (within-dataset)
artifacts/ablation/cross/<gid>__seed<S>_<src>_to_<tgt>.npz (cross-dataset)
For each B-variant we apply the variant-appropriate aggregator (auto-key OAS
fits whatever sub-scores the variant produces). JANUS-full anchor is read from
the production route_comparison/ paths.
Outputs:
ABLATION_CROSS_B_summary.md within mean / cross mean / cross worst per gid
ABLATION_CROSS_B_full.md per-gid 3×3 matrices
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
from aggregate_ablation import (
SCORE_FNS, T_975_N3, _auroc, _load_npz, _load_cross_npz,
)
ROOT = Path(__file__).resolve().parents[2]
ROUTE = ROOT / "artifacts" / "route_comparison"
ROUTE_CROSS = ROUTE / "cross"
ABL = ROOT / "artifacts" / "ablation"
ABL_CROSS = ABL / "cross"
CROSS_DATASETS = ["cicids2017", "cicddos2019", "ciciot2023"]
PRETTY = {
"cicids2017": "CICIDS17",
"cicddos2019": "CICDDoS19",
"ciciot2023": "CICIoT23",
}
SEEDS = [42, 43, 44]
# (gid, label, what_removed, score_fn_id)
B_VARIANTS = [
("janus_full", "JANUS-full", "", "OAS_all_available"),
("b1_noflow", "B1 no FLOW token","global context", "OAS_all_available"),
("b2_flowonly", "B2 flow-only", "packet sequence", "A1_terminal_norm"),
("b3_allcont", "B3 all-cont", "cont/disc split", "OAS_term_all"),
("b4_alldisc", "B4 all-disc", "cont/disc split (rev)", "OAS_disc_all"),
("b5_nodisc", "B5 λ_disc=0", "joint training", "OAS_term_all"),
]
def _within_path(gid: str, ds: str, seed: int) -> Path:
if gid == "janus_full":
return ROUTE / f"janus_{ds}_seed{seed}" / "phase1_scores.npz"
return ABL / f"janus_{ds}_seed{seed}_{gid}" / "phase1_scores.npz"
def _cross_path(gid: str, src: str, tgt: str, seed: int) -> Path:
if gid == "janus_full":
return ROUTE_CROSS / f"janus_seed{seed}_{src}_to_{tgt}.npz"
return ABL_CROSS / f"{gid}__seed{seed}_{src}_to_{tgt}.npz"
def _cell_score(gid: str, src: str, tgt: str, seed: int, fn_id: str):
if src == tgt:
p = _within_path(gid, src, seed)
if not p.exists():
return None
val, atk = _load_npz(p)
else:
p = _cross_path(gid, src, tgt, seed)
if not p.exists():
return None
val, atk = _load_cross_npz(p)
fn = SCORE_FNS[fn_id]
res = fn(val, atk)
if res is None:
return None
sv, sa = res
return _auroc(sv, sa)
def _seed_summary(vals: list[float]):
a = np.asarray([v for v in vals if v is not None and not np.isnan(v)])
if a.size == 0:
return None
if a.size == 1:
return {"mean": float(a[0]), "ci": 0.0, "n": 1}
se = a.std(ddof=1) / np.sqrt(a.size)
return {"mean": float(a.mean()),
"ci": float(T_975_N3 * se) if a.size == 3 else float(1.96 * se),
"n": int(a.size)}
def _fmt(c):
if c is None:
return ""
if c["n"] == 1:
return f"{100 * c['mean']:.2f}"
return f"{100 * c['mean']:.2f} ± {100 * c['ci']:.2f}"
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--out-dir", type=Path, default=ABL)
args = ap.parse_args()
args.out_dir.mkdir(parents=True, exist_ok=True)
full = {}
for gid, label, _why, fn_id in B_VARIANTS:
rows = {}
for src in CROSS_DATASETS:
for tgt in CROSS_DATASETS:
vals = [_cell_score(gid, src, tgt, s, fn_id) for s in SEEDS]
rows[(src, tgt)] = _seed_summary(vals)
full[gid] = (label, rows)
n_ok = sum(1 for v in rows.values() if v is not None)
print(f"[ok] {label:20s} cells={n_ok}/{len(rows)}", flush=True)
# Summary
lines = ["# B-variant cross-dataset summary",
"",
f"3-seed mean ± 95% t-CI AUROC. Datasets = {CROSS_DATASETS}.",
"All B variants share the same aggregator-fit-on-target-benign protocol as JANUS-full.",
"",
"| Variant | What removed | Within (3 cells) | Cross (6 cells) | Cross worst | Within Cross |",
"|---|---|---|---|---|---|"]
for gid, label, why, fn_id in B_VARIANTS:
_, rows = full[gid]
within = [v["mean"] for (s, t), v in rows.items() if s == t and v is not None]
cross = [v["mean"] for (s, t), v in rows.items() if s != t and v is not None]
cross_pairs = [((s, t), v) for (s, t), v in rows.items() if s != t and v is not None]
worst = min(cross_pairs, key=lambda x: x[1]["mean"], default=None)
w = float(np.mean(within)) if within else float("nan")
c = float(np.mean(cross)) if cross else float("nan")
worst_str = ""
if worst is not None:
(s, t), v = worst
worst_str = f"{PRETTY[s]}{PRETTY[t]}: {_fmt(v)}"
gap = (w - c) * 100 if not np.isnan(w) and not np.isnan(c) else float("nan")
lines.append(f"| {label} | {why} | {100 * w:.2f} | {100 * c:.2f} | {worst_str} | {gap:+.2f} |")
summary_path = args.out_dir / "ABLATION_CROSS_B_summary.md"
summary_path.write_text("\n".join(lines) + "\n")
print(f"[wrote] {summary_path}")
# Full per-variant 3x3 matrices
flines = ["# B-variant cross-dataset full matrices",
"",
"Per variant: 3×3 matrix (rows = source, columns = target). Diagonal italic.",
"Each cell = 3-seed mean ± 95% t-CI AUROC (%).",
""]
for gid, label, why, fn_id in B_VARIANTS:
_, rows = full[gid]
flines.append(f"## {label} ({why})")
flines.append("")
header = ["Source ↓ / Target →"] + [PRETTY[d] for d in CROSS_DATASETS]
flines.append("| " + " | ".join(header) + " |")
flines.append("|" + "|".join("---" for _ in header) + "|")
for src in CROSS_DATASETS:
row = [f"**{PRETTY[src]}**"]
for tgt in CROSS_DATASETS:
cell = rows[(src, tgt)]
txt = _fmt(cell)
if src == tgt:
txt = f"_{txt}_"
row.append(txt)
flines.append("| " + " | ".join(row) + " |")
flines.append("")
full_path = args.out_dir / "ABLATION_CROSS_B_full.md"
full_path.write_text("\n".join(flines))
print(f"[wrote] {full_path}")
json_path = args.out_dir / "ABLATION_CROSS_B.json"
json_path.write_text(json.dumps({
gid: {"label": label, "rows": {f"{s}->{t}": v for (s, t), v in rows.items()}}
for gid, (label, rows) in full.items()
}, indent=2, default=lambda o: None))
print(f"[wrote] {json_path}")
if __name__ == "__main__":
main()