Add baseline reproduction: Shafir NF 2-NF ensemble (17/18 cells), ConMD Table I citation, JANUS thresholded F1 across 4 datasets

This commit is contained in:
2026-05-08 11:47:27 +08:00
parent dc22e20616
commit c33efc290a
4 changed files with 498 additions and 27 deletions

View File

@@ -0,0 +1,171 @@
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
from sklearn.covariance import OAS
from sklearn.metrics import roc_auc_score
ROOT = Path(__file__).resolve().parents[2] / "artifacts" / "route_comparison"
DATASETS = ["cicids2017", "cicddos2019", "ciciot2023", "iscxtor2016"]
SEEDS = (42, 43, 44)
RNG_SPLIT = 12345
def metrics_at_tau(d2_b, d2_a, tau):
tp = int((d2_a >= tau).sum())
fn = int((d2_a < tau).sum())
fp = int((d2_b >= tau).sum())
tn = int((d2_b < tau).sum())
prec = tp / max(tp + fp, 1)
rec = tp / max(tp + fn, 1)
f1 = 2 * prec * rec / max(prec + rec, 1e-9)
fpr = fp / max(fp + tn, 1)
return {"f1": f1, "prec": prec, "rec": rec, "fpr": fpr}
def evaluate_seed(npz_path: Path) -> dict:
z = np.load(npz_path, allow_pickle=True)
keys = sorted(k.replace("val_", "") for k in z.files if k.startswith("val_") and not k.endswith("labels"))
val_S = np.stack([z[f"val_{k}"] for k in keys], axis=1)
atk_S = np.stack([z[f"atk_{k}"] for k in keys], axis=1)
val_S = np.nan_to_num(val_S, nan=0.0, posinf=1e6, neginf=-1e6)
atk_S = np.nan_to_num(atk_S, nan=0.0, posinf=1e6, neginf=-1e6)
K = val_S.shape[1]
rng = np.random.default_rng(RNG_SPLIT)
idx = rng.permutation(len(val_S))
half = len(idx) // 2
val_A = val_S[idx[:half]]
val_B = val_S[idx[half:]]
mu = val_A.mean(axis=0)
oas = OAS().fit(val_A)
inv_cov = np.linalg.inv(oas.covariance_ + 1e-9 * np.eye(K))
def d2(S):
d = S - mu
return np.einsum("ni,ij,nj->n", d, inv_cov, d)
d2_A = d2(val_A)
d2_B = d2(val_B)
d2_atk = d2(atk_S)
auroc = float(roc_auc_score(np.r_[np.zeros(len(d2_B)), np.ones(len(d2_atk))], np.r_[d2_B, d2_atk]))
out = {"AUROC": auroc, "n_val": len(val_S), "n_atk": len(atk_S)}
for pct, name in [(95, "P95"), (99, "P99")]:
tau = float(np.percentile(d2_A, pct))
m = metrics_at_tau(d2_B, d2_atk, tau)
out[f"F1@{name}"] = m["f1"]
out[f"Prec@{name}"] = m["prec"]
out[f"Recall@{name}"] = m["rec"]
out[f"FPR@{name}"] = m["fpr"]
return out
def aggregate(dataset: str) -> dict:
rows = []
for s in SEEDS:
npz = ROOT / f"janus_{dataset}_seed{s}/phase1_scores.npz"
if not npz.exists():
print(f"[skip] {dataset} seed{s} — npz missing")
continue
rows.append(evaluate_seed(npz))
if not rows:
return {}
metrics = [k for k in rows[0] if k not in ("n_val", "n_atk")]
out = {"n_val": rows[0]["n_val"], "n_atk": rows[0]["n_atk"], "n_seeds": len(rows)}
for m in metrics:
a = np.array([r[m] for r in rows])
out[m] = {"mean": float(a.mean()), "std": float(a.std()), "per_seed": [float(x) for x in a]}
return out
SUPERVISED_REF = {
"cicddos2019": {"method": "TIPSO-GAN (supervised, single seed)", "AUROC": 0.9999, "F1": 0.9989, "source": "baselines/TIPSO-GAN/ndss_tipso_artifact/artifacts/perf_summary_cicddos2019.json"},
}
def render_md(by_ds: dict) -> str:
lines = []
lines.append("# JANUS Thresholded Metrics — Mahalanobis-OAS, 3-seed mean ± std")
lines.append("")
lines.append("Computed post-hoc from `janus_<ds>_seed{42,43,44}/phase1_scores.npz` — no retraining.")
lines.append("")
lines.append("## Protocol")
lines.append("")
lines.append("- Aggregator: **Mahalanobis-OAS** distance over the 10-d JANUS raw score vector")
lines.append(f"- (μ, Σ) fit on **benign val half A** (random split seed={RNG_SPLIT}); F1/Precision/Recall/FPR measured on **benign val half B + ALL attacks**")
lines.append("- AUROC measured on (half B + attacks)")
lines.append("- Thresholds: τ95 = 95th percentile of d² on half A; τ99 = 99th percentile")
lines.append("")
lines.append("## Headline (4 datasets × 3 seeds)")
lines.append("")
lines.append("| Dataset | n_val | n_atk | AUROC | F1@P95 | Prec@P95 | Recall@P95 | FPR@P95 | F1@P99 | TPR@P99 |")
lines.append("|---|---|---|---|---|---|---|---|---|---|")
for ds in DATASETS:
if ds not in by_ds or not by_ds[ds]:
lines.append(f"| {ds} | — | — | — | — | — | — | — | — | — |")
continue
d = by_ds[ds]
def cell(k):
v = d[k]
return f"{v['mean']:.4f} ± {v['std']:.4f}"
lines.append(
f"| {ds} | {d['n_val']} | {d['n_atk']} | {cell('AUROC')} | "
f"{cell('F1@P95')} | {cell('Prec@P95')} | {cell('Recall@P95')} | {cell('FPR@P95')} | "
f"{cell('F1@P99')} | {cell('Recall@P99')} |"
)
lines.append("")
if any(ds in SUPERVISED_REF and ds in by_ds and by_ds[ds] for ds in DATASETS):
lines.append("## Supervised SOTA reference (cell-by-cell)")
lines.append("")
lines.append("Single-seed published numbers from supervised methods, where available, for context. The protocols are not directly comparable (supervised uses attack labels at training); this is meant to show the ceiling, not for head-to-head SOTA claim.")
lines.append("")
lines.append("| Dataset | Supervised method | Sup AUROC | Sup F1 | JANUS AUROC | JANUS F1@P95 | Δ AUROC | Δ F1 |")
lines.append("|---|---|---|---|---|---|---|---|")
for ds in DATASETS:
if ds not in SUPERVISED_REF or ds not in by_ds or not by_ds[ds]:
continue
ref = SUPERVISED_REF[ds]
d = by_ds[ds]
lines.append(
f"| {ds} | {ref['method']} | {ref['AUROC']:.4f} | {ref['F1']:.4f} | "
f"{d['AUROC']['mean']:.4f} ± {d['AUROC']['std']:.4f} | "
f"{d['F1@P95']['mean']:.4f} ± {d['F1@P95']['std']:.4f} | "
f"{d['AUROC']['mean'] - ref['AUROC']:+.4f} | "
f"{d['F1@P95']['mean'] - ref['F1']:+.4f} |"
)
lines.append("")
for ds in DATASETS:
if ds not in by_ds or not by_ds[ds]:
continue
d = by_ds[ds]
lines.append(f"## {ds}")
lines.append("")
lines.append(f"n_val={d['n_val']}, n_atk={d['n_atk']}, n_seeds={d['n_seeds']}")
lines.append("")
lines.append("| Metric | seed42 | seed43 | seed44 | mean ± std |")
lines.append("|---|---|---|---|---|")
for m in ["AUROC", "F1@P95", "Prec@P95", "Recall@P95", "FPR@P95", "F1@P99", "Prec@P99", "Recall@P99", "FPR@P99"]:
v = d[m]
ps = v["per_seed"]
lines.append(f"| {m} | {ps[0]:.4f} | {ps[1]:.4f} | {ps[2]:.4f} | {v['mean']:.4f} ± {v['std']:.4f} |")
lines.append("")
return "\n".join(lines) + "\n"
def main():
p = argparse.ArgumentParser()
p.add_argument("--datasets", nargs="*", default=DATASETS)
p.add_argument("--out", type=Path, default=ROOT / "THRESHOLDED.md")
args = p.parse_args()
by_ds = {ds: aggregate(ds) for ds in args.datasets}
md = render_md(by_ds)
args.out.write_text(md)
print(md)
print(f"\n[wrote] {args.out}")
if __name__ == "__main__":
main()