ablation: add Group A (aggregator) + Group B (architecture) infrastructure
Extends MixedCFMConfig with 5 backwards-compatible flags (use_flow_token,
n_packet_tokens, disc_as_cont, cont_as_disc + cont_n_bins) so existing
JANUS-full checkpoints load with 0 missing/unexpected keys.
Adds:
- 60 ablation training configs (5 variants × 4 datasets × 3 seeds)
- scripts/ablation/{generate_configs.py, run_groupB.sh, run_cross_groupB.sh,
smoke_test.sh} — config generation + GPU drivers
- scripts/aggregate/aggregate_ablation{,_cross,_cross_B}.py — produces
within-dataset and cross-dataset (3×3) ablation tables with 3-seed mean
± 95% t-CI plus optional paired DeLong p-values
README updated with ablation section pointing at
artifacts/ablation/ABLATION_SUMMARY.md.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
56
scripts/ablation/generate_configs.py
Normal file
56
scripts/ablation/generate_configs.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Generate 60 B-group ablation configs from existing 12 base configs.
|
||||
|
||||
Reads:
|
||||
Mixed_CFM/configs/<ds>_seed<S>.yaml (4 datasets × 3 seeds = 12 base)
|
||||
|
||||
Writes:
|
||||
Mixed_CFM/configs/ablation/<gid>/<ds>_seed<S>.yaml (5 variants × 12 = 60)
|
||||
|
||||
Each variant overrides save_dir → artifacts/ablation/janus_<ds>_seed<S>_<gid>/
|
||||
plus the variant-specific flags. CICIoT2023 base is `ciciot2023_seed42.yaml`
|
||||
(NOT `ciciot2023_route_c_seed42.yaml`, which is a different score-router config).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
BASE_DIR = ROOT / "Mixed_CFM" / "configs"
|
||||
OUT_DIR = ROOT / "Mixed_CFM" / "configs" / "ablation"
|
||||
|
||||
DATASETS = ["iscxtor2016", "cicids2017", "cicddos2019", "ciciot2023"]
|
||||
SEEDS = [42, 43, 44]
|
||||
|
||||
VARIANTS = {
|
||||
"b1_noflow": {"use_flow_token": False},
|
||||
"b2_flowonly": {"n_packet_tokens": 0, "lambda_disc": 0.0},
|
||||
"b3_allcont": {"disc_as_cont": True, "lambda_disc": 0.0},
|
||||
"b4_alldisc": {"cont_as_disc": True, "n_disc_classes": 8},
|
||||
"b5_nodisc": {"lambda_disc": 0.0},
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
    """Write one ablation YAML per (variant, dataset, seed) combination.

    For every base config ``BASE_DIR/<ds>_seed<S>.yaml`` that exists, each entry
    of ``VARIANTS`` produces ``OUT_DIR/<gid>/<ds>_seed<S>.yaml``: a copy of the
    base config with ``save_dir`` redirected to
    ``ROOT/artifacts/ablation/janus_<ds>_seed<S>_<gid>`` and the variant's
    overrides applied on top. Missing base configs are reported and skipped.

    Raises:
        TypeError: if a base config does not parse to a YAML mapping.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    # Create every variant directory up front, even if no base config is found.
    for gid in VARIANTS:
        (OUT_DIR / gid).mkdir(parents=True, exist_ok=True)

    n_written = 0
    for ds in DATASETS:
        for seed in SEEDS:
            base_path = BASE_DIR / f"{ds}_seed{seed}.yaml"
            if not base_path.exists():
                print(f"[miss] {base_path}")
                continue
            # Explicit encoding: do not depend on the locale of the machine.
            base_cfg = yaml.safe_load(base_path.read_text(encoding="utf-8"))
            if not isinstance(base_cfg, dict):
                raise TypeError(
                    f"{base_path} must contain a YAML mapping, "
                    f"got {type(base_cfg).__name__}"
                )
            for gid, overrides in VARIANTS.items():
                save_dir = ROOT / "artifacts" / "ablation" / f"janus_{ds}_seed{seed}_{gid}"
                # Later entries win: base < save_dir < variant overrides.
                # Key order of the base config is preserved (sort_keys=False below).
                cfg = {**base_cfg, "save_dir": str(save_dir), **overrides}
                out = OUT_DIR / gid / f"{ds}_seed{seed}.yaml"
                out.write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")
                n_written += 1
    print(f"[wrote] {n_written} config files under {OUT_DIR}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
66
scripts/ablation/run_cross_groupB.sh
Executable file
66
scripts/ablation/run_cross_groupB.sh
Executable file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env bash
|
||||
# Cross-dataset evaluation for B-group ablation models.
|
||||
# 5 variants × 6 off-diagonal directions × 3 seeds = 90 cross evals.
|
||||
#
|
||||
# Each B-variant model dir is artifacts/ablation/janus_<ds>_seed<S>_<gid>/.
|
||||
# We only cross within the 3-dataset matrix (cicids2017, cicddos2019, ciciot2023);
|
||||
# ISCXTor16 has different feature space for cross.
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/ablation/run_cross_groupB.sh # all 90
|
||||
# bash scripts/ablation/run_cross_groupB.sh b1_noflow b3_allcont
|
||||
set -euo pipefail

# Repo root: honour an exported ROOT, otherwise derive it from this script's
# location (scripts/ablation/ -> two levels up), like the sibling drivers do.
ROOT="${ROOT:-$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)}"
EVAL="${ROOT}/Mixed_CFM/eval_cross.py"
OUT_DIR="${ROOT}/artifacts/ablation/cross"
mkdir -p "${OUT_DIR}"

# Per-target-dataset inputs handed to eval_cross.py.
declare -A STORE FLOWS FEATS
STORE[cicids2017]="${ROOT}/datasets/cicids2017/processed/full_store"
FLOWS[cicids2017]="${ROOT}/datasets/cicids2017/processed/flows.parquet"
FEATS[cicids2017]="${ROOT}/datasets/cicids2017/processed/flow_features.parquet"
STORE[cicddos2019]="${ROOT}/datasets/cicddos2019/processed/full_store"
FLOWS[cicddos2019]="${ROOT}/datasets/cicddos2019/processed/flows.parquet"
FEATS[cicddos2019]="${ROOT}/datasets/cicddos2019/processed/flow_features.parquet"
STORE[ciciot2023]="${ROOT}/datasets/ciciot2023/processed/full_store"
# NOTE(review): unlike the other two datasets, flows.parquet is read from
# inside full_store/ here — confirm this is intentional.
FLOWS[ciciot2023]="${ROOT}/datasets/ciciot2023/processed/full_store/flows.parquet"
FEATS[ciciot2023]="${ROOT}/datasets/ciciot2023/processed/flow_features.parquet"

ALL_GIDS=(b1_noflow b2_flowonly b3_allcont b4_alldisc b5_nodisc)
DATASETS=(cicids2017 cicddos2019 ciciot2023)
SEEDS=(42 43 44)
GPU="${GPU:-0}"

# Positional arguments select a subset of variants; default is all of them.
if [[ $# -gt 0 ]]; then
  GIDS=("$@")
else
  GIDS=("${ALL_GIDS[@]}")
fi
|
||||
|
||||
# run_one GID SRC TGT SEED
# Evaluate the SRC-trained variant model on the TGT dataset for one seed.
# Skips when the result JSON already exists or the model checkpoint is missing.
# stdout/stderr of the evaluation go to a .log file next to the .json result.
run_one() {
  local gid="$1" src="$2" tgt="$3" seed="$4"
  local md="${ROOT}/artifacts/ablation/janus_${src}_seed${seed}_${gid}"
  local out="${OUT_DIR}/${gid}__seed${seed}_${src}_to_${tgt}.json"
  local log="${OUT_DIR}/${gid}__seed${seed}_${src}_to_${tgt}.log"
  if [[ -f "${out}" ]]; then
    echo "[skip] $gid ${src}→${tgt} seed${seed}"
    return
  fi
  if [[ ! -f "${md}/model.pt" ]]; then
    echo "[missing model] ${md}/model.pt"
    return
  fi
  echo "[gpu${GPU}] $(date +%H:%M:%S) $gid ${src} → ${tgt} seed${seed}"
  # All expansions are quoted so paths containing spaces or glob characters survive.
  cd "${ROOT}/Mixed_CFM"
  CUDA_VISIBLE_DEVICES="${GPU}" uv run --no-sync python -u "${EVAL}" \
    --model-dir "${md}" \
    --target-store "${STORE[$tgt]}" --target-flows "${FLOWS[$tgt]}" --target-flow-features "${FEATS[$tgt]}" \
    --benign-label normal --n-benign 10000 --n-attack 1000000 \
    --out "${out}" --seed "${seed}" --T 64 --batch-size 512 --n-steps 16 \
    > "${log}" 2>&1
}
|
||||
|
||||
# Sweep every selected variant over all ordered (src, tgt) dataset pairs and seeds.
for gid in "${GIDS[@]}"; do
  for src in "${DATASETS[@]}"; do
    for tgt in "${DATASETS[@]}"; do
      # Only off-diagonal (src != tgt) directions are evaluated here.
      if [[ "$src" == "$tgt" ]]; then
        continue
      fi
      for seed in "${SEEDS[@]}"; do
        run_one "$gid" "$src" "$tgt" "$seed"
      done
    done
  done
done
echo "[done] cross evals complete"
|
||||
76
scripts/ablation/run_groupB.sh
Executable file
76
scripts/ablation/run_groupB.sh
Executable file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env bash
|
||||
# Run all 60 B-group ablation training + phase1-eval runs.
|
||||
#
|
||||
# Splits work across two GPUs round-robin (set GPUS env to override).
|
||||
# Logs per-run go to artifacts/ablation/<save_dir>/{train,phase1}.log.
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/ablation/run_groupB.sh # all 60 runs
|
||||
# bash scripts/ablation/run_groupB.sh b1_noflow b5_nodisc # subset of groups
|
||||
# GPUS=0 bash scripts/ablation/run_groupB.sh # single-GPU serial
|
||||
set -euo pipefail
# Work from the repo root (this script lives in scripts/ablation/).
cd "$(dirname "$0")/../.."

ALL_GIDS=(b1_noflow b2_flowonly b3_allcont b4_alldisc b5_nodisc)
DATASETS=(iscxtor2016 cicids2017 cicddos2019 ciciot2023)
SEEDS=(42 43 44)

# Comma-separated GPU ids -> array.
GPUS="${GPUS:-0,1}"
IFS=',' read -ra GPU_ARR <<< "$GPUS"
N_GPU=${#GPU_ARR[@]}

# Positional arguments restrict the sweep to the named variant groups.
GIDS=("${ALL_GIDS[@]}")
if [[ $# -gt 0 ]]; then
  GIDS=("$@")
fi

# Build the full run list as "gid|ds|seed" specs.
runs=()
for gid in "${GIDS[@]}"; do
  for ds in "${DATASETS[@]}"; do
    for seed in "${SEEDS[@]}"; do
      runs+=("${gid}|${ds}|${seed}")
    done
  done
done

n_runs=${#runs[@]}
echo "[plan] ${n_runs} runs across GPUs ${GPUS} (gids=${GIDS[*]})"
|
||||
|
||||
# run_one SPEC GPU_ID
# SPEC is "gid|ds|seed". Trains the variant, then runs the phase-1 evaluation,
# both pinned to GPU_ID. Logs go to <save_dir>/{train,phase1}.log, where
# save_dir is read from the variant's YAML config.
run_one() {
  local spec="$1" gpu_id="$2"
  # Keep the parsed fields local so they do not clobber the caller's variables.
  local gid ds seed
  IFS='|' read -r gid ds seed <<< "$spec"
  local cfg="Mixed_CFM/configs/ablation/${gid}/${ds}_seed${seed}.yaml"
  local save_dir
  # Pass the path via argv rather than splicing it into the Python source,
  # so quotes or backslashes in the path cannot break the one-liner.
  save_dir=$(uv run --no-sync python -c \
    "import sys, yaml; print(yaml.safe_load(open(sys.argv[1]))['save_dir'])" "$cfg")
  mkdir -p "$save_dir"
  echo "[gpu${gpu_id}] $(date +%H:%M:%S) START $gid $ds seed${seed}"
  CUDA_VISIBLE_DEVICES="$gpu_id" uv run --no-sync python Mixed_CFM/train.py \
    --config "$cfg" >"$save_dir/train.log" 2>&1
  CUDA_VISIBLE_DEVICES="$gpu_id" uv run --no-sync python Mixed_CFM/eval_phase1.py \
    --model-dir "$save_dir" --out-dir "$save_dir" \
    --batch-size 256 --n-steps 16 \
    --n-val-cap 30000 --n-atk-cap 30000 >"$save_dir/phase1.log" 2>&1
  echo "[gpu${gpu_id}] $(date +%H:%M:%S) DONE $gid $ds seed${seed}"
}
|
||||
|
||||
# Round-robin assignment: run i goes to GPU (i mod N_GPU), at most N_GPU jobs
# in flight. With a single GPU this degenerates to one job at a time.
pids=()
n_failed=0

# Wait for every job of the current batch, counting failures instead of hiding them.
wait_batch() {
  local pid
  # Guard on the length: expanding an empty array trips `set -u` on bash < 4.4.
  if (( ${#pids[@]} > 0 )); then
    for pid in "${pids[@]}"; do
      if ! wait "$pid"; then
        n_failed=$((n_failed + 1))
      fi
    done
  fi
  pids=()
}

for i in "${!runs[@]}"; do
  spec="${runs[$i]}"
  gpu_id="${GPU_ARR[$((i % N_GPU))]}"
  # Always launch as a background job: inside it `set -e` still stops run_one
  # at its first failing step (calling run_one in an `if`/`||` context would
  # disable errexit within the function), while the sweep itself carries on.
  run_one "$spec" "$gpu_id" &
  pids+=($!)
  # Cap concurrency at N_GPU
  if (( (i + 1) % N_GPU == 0 )); then
    wait_batch
  fi
done
wait_batch

if (( n_failed > 0 )); then
  echo "[done] ${n_runs} runs finished, ${n_failed} FAILED (check the per-run train.log / phase1.log)" >&2
  exit 1
fi
echo "[done] all ${n_runs} runs complete"
|
||||
39
scripts/ablation/smoke_test.sh
Executable file
39
scripts/ablation/smoke_test.sh
Executable file
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env bash
# Smoke-test all 5 B-group variants on cicids2017 seed42 with reduced epochs
# and tiny train set, on CPU (so VLLM workers on the GPUs are not disturbed).
#
# After: each artifacts/ablation_smoke/<gid>/ should contain model.pt
# + phase1_scores.npz with the variant-specific score keys.
set -euo pipefail
cd "$(dirname "$0")/../.."

GIDS=(b1_noflow b2_flowonly b3_allcont b4_alldisc b5_nodisc)
DS=cicids2017
SEED=42
# Smoke artifacts live under the repo root we just cd'ed into (absolute path,
# since it is also written into the training config); override via SMOKE_ROOT.
SMOKE_ROOT="${SMOKE_ROOT:-$(pwd)/artifacts/ablation_smoke}"

for gid in "${GIDS[@]}"; do
  cfg="Mixed_CFM/configs/ablation/${gid}/${DS}_seed${SEED}.yaml"
  smoke_dir="${SMOKE_ROOT}/${gid}"
  echo "=================================================="
  echo "[smoke] $gid"
  echo "=================================================="
  uv run --no-sync python Mixed_CFM/train.py \
    --config "$cfg" \
    --override "device=cpu" "epochs=2" "n_train=500" "eval_n=200" "eval_every=2" \
    "save_dir=${smoke_dir}" 2>&1 | tail -8
  uv run --no-sync python Mixed_CFM/eval_phase1.py \
    --model-dir "${smoke_dir}" \
    --out-dir "${smoke_dir}" \
    --device cpu --batch-size 64 --n-steps 4 \
    --n-val-cap 200 --n-atk-cap 200 2>&1 | tail -4
  echo
done

echo "=== Smoke summary ==="
for gid in "${GIDS[@]}"; do
  npz="${SMOKE_ROOT}/${gid}/phase1_scores.npz"
  if [[ -f "$npz" ]]; then
    # The path is passed via argv so it is never interpreted as Python source.
    keys=$(uv run --no-sync python -c "import sys, numpy as np; z=np.load(sys.argv[1], allow_pickle=True); print(','.join(sorted(k for k in z.files if k.startswith(('val_terminal','val_disc')))))" "$npz")
    echo "$gid: $keys"
  else
    echo "$gid: MISSING"
  fi
done
|
||||
533
scripts/aggregate/aggregate_ablation.py
Normal file
533
scripts/aggregate/aggregate_ablation.py
Normal file
@@ -0,0 +1,533 @@
|
||||
"""JANUS ablation aggregator (Groups A + B).
|
||||
|
||||
Reads phase1_scores.npz from:
|
||||
artifacts/route_comparison/janus_<ds>_seed<S>/ (A + JANUS-full anchor)
|
||||
artifacts/ablation/janus_<ds>_seed<S>_<gid>/ (B variants)
|
||||
|
||||
Produces:
|
||||
  artifacts/ablation/ABLATION_TABLE_<group>.md     final markdown table
  artifacts/ablation/ABLATION_TABLE_<group>.json   per-cell mean / std / CI / per-seed
  artifacts/ablation/ABLATION_DELONG_<group>.md    paired DeLong p-values vs JANUS-full (only with --delong)
where <group> is the value of --group (A, B or all).
|
||||
|
||||
Group A operates entirely on existing route_comparison npz files (no GPU).
|
||||
Group B requires the 60 B-variant runs to have completed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
from sklearn.covariance import OAS
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
ROUTE = ROOT / "artifacts" / "route_comparison"
|
||||
ABL = ROOT / "artifacts" / "ablation"
|
||||
|
||||
DATASETS = ["iscxtor2016", "cicids2017", "cicddos2019", "ciciot2023"]
|
||||
PRETTY = {
|
||||
"iscxtor2016": "ISCXTor16",
|
||||
"cicids2017": "CICIDS17",
|
||||
"cicddos2019": "CICDDoS19",
|
||||
"ciciot2023": "CICIoT23",
|
||||
}
|
||||
SEEDS = [42, 43, 44]
|
||||
T_975_N3 = 4.302653 # 95% t-CI factor for n=3 (df=2)
|
||||
|
||||
CONT_KEYS = ["terminal_norm", "terminal_flow", "terminal_packet"]
|
||||
DISC_KEYS = ["disc_nll_total", "disc_nll_ch2", "disc_nll_ch3",
|
||||
"disc_nll_ch4", "disc_nll_ch5", "disc_nll_ch6", "disc_nll_ch7"]
|
||||
ALL_KEYS = CONT_KEYS + DISC_KEYS # 10-d
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# I/O #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _load_npz(npz_path: Path):
|
||||
z = np.load(npz_path, allow_pickle=True)
|
||||
val = {}
|
||||
atk = {}
|
||||
for k in z.files:
|
||||
if k.startswith("val_") and k != "val_labels":
|
||||
val[k[4:]] = z[k]
|
||||
elif k.startswith("atk_") and k != "atk_labels":
|
||||
atk[k[4:]] = z[k]
|
||||
return val, atk
|
||||
|
||||
|
||||
def _load_cross_npz(npz_path: Path):
|
||||
"""Cross npz schema: b_<key> = target benign, a_<key> = target attacks."""
|
||||
z = np.load(npz_path, allow_pickle=True)
|
||||
val = {}
|
||||
atk = {}
|
||||
for k in z.files:
|
||||
if k.startswith("b_") and k != "b_labels":
|
||||
val[k[2:]] = z[k]
|
||||
elif k.startswith("a_") and k != "a_labels":
|
||||
atk[k[2:]] = z[k]
|
||||
return val, atk
|
||||
|
||||
|
||||
def _stack(d: dict, keys: list[str]) -> np.ndarray:
|
||||
arrs = []
|
||||
for k in keys:
|
||||
if k in d:
|
||||
arrs.append(d[k])
|
||||
else:
|
||||
# variant doesn't produce this score (e.g. B2 has no disc, B5 disc untrained)
|
||||
return None
|
||||
out = np.stack(arrs, axis=1).astype(np.float64)
|
||||
return np.nan_to_num(out, nan=0.0, posinf=1e6, neginf=-1e6)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Score functions (Group A definitions) #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _mahal(S, mu, inv_cov):
|
||||
d = S - mu
|
||||
return np.einsum("ni,ij,nj->n", d, inv_cov, d)
|
||||
|
||||
|
||||
def _oas_mahal(val_S, atk_S):
    """Score both splits by squared Mahalanobis distance to the ``val_S`` distribution.

    The centre is the column mean of ``val_S`` and the covariance is the OAS
    estimate fitted on ``val_S``; a 1e-9 ridge is added to the diagonal before
    inversion. Returns ``(val_scores, atk_scores)``.
    """
    centre = val_S.mean(axis=0)
    cov = OAS().fit(val_S).covariance_
    ridge = 1e-9 * np.eye(cov.shape[0])
    precision = np.linalg.inv(cov + ridge)
    val_scores = _mahal(val_S, centre, precision)
    atk_scores = _mahal(atk_S, centre, precision)
    return val_scores, atk_scores
|
||||
|
||||
|
||||
def _zscore_agg(val_S, atk_S, mode="mean"):
|
||||
mu = val_S.mean(axis=0)
|
||||
sd = val_S.std(axis=0) + 1e-9
|
||||
zv = (val_S - mu) / sd
|
||||
za = (atk_S - mu) / sd
|
||||
if mode == "mean":
|
||||
return zv.mean(axis=1), za.mean(axis=1)
|
||||
if mode == "max":
|
||||
return zv.max(axis=1), za.max(axis=1)
|
||||
raise ValueError(mode)
|
||||
|
||||
|
||||
def score_a1_terminal_norm(val, atk):
    """A1: use the raw ``terminal_norm`` scalar as the anomaly score.

    Returns ``(val_scores, atk_scores)``, or ``None`` when either split lacks
    the key, matching how the other score functions report a missing score.
    """
    if "terminal_norm" not in val or "terminal_norm" not in atk:
        return None
    return val["terminal_norm"], atk["terminal_norm"]
|
||||
|
||||
|
||||
def score_a2_disc_total(val, atk):
    """A2: use the raw ``disc_nll_total`` scalar as the anomaly score.

    Returns ``(val_scores, atk_scores)``, or ``None`` when either split lacks
    the key.
    """
    if "disc_nll_total" not in val or "disc_nll_total" not in atk:
        return None
    return val["disc_nll_total"], atk["disc_nll_total"]
|
||||
|
||||
|
||||
def _score_over_keys(val, atk, keys, aggregate):
    """Stack ``keys`` from both splits and apply ``aggregate(Sv, Sa)``.

    Returns ``None`` when either split is missing any of the keys.
    """
    Sv = _stack(val, keys)
    Sa = _stack(atk, keys)
    if Sv is None or Sa is None:
        return None
    return aggregate(Sv, Sa)


def score_a3_oas_term3(val, atk):
    """A3: OAS-Mahalanobis over the three ``CONT_KEYS`` sub-scores."""
    return _score_over_keys(val, atk, CONT_KEYS, _oas_mahal)


def score_a4_oas_disc7(val, atk):
    """A4: OAS-Mahalanobis over the seven ``DISC_KEYS`` sub-scores."""
    return _score_over_keys(val, atk, DISC_KEYS, _oas_mahal)


def score_a5_oas_all10(val, atk):
    """A5: OAS-Mahalanobis over all ten ``ALL_KEYS`` sub-scores."""
    return _score_over_keys(val, atk, ALL_KEYS, _oas_mahal)


def score_a6_zmean(val, atk):
    """A6: per-row mean of the z-scored ``ALL_KEYS`` sub-scores."""
    return _score_over_keys(val, atk, ALL_KEYS, lambda Sv, Sa: _zscore_agg(Sv, Sa, "mean"))


def score_a7_zmax(val, atk):
    """A7: per-row max of the z-scored ``ALL_KEYS`` sub-scores."""
    return _score_over_keys(val, atk, ALL_KEYS, lambda Sv, Sa: _zscore_agg(Sv, Sa, "max"))
|
||||
|
||||
|
||||
def _oas_over_prefixes(val, atk, prefixes, *, raw_if_single):
    """OAS-Mahalanobis over every key of ``val`` that starts with one of ``prefixes``.

    Keys are discovered from ``val`` only and used in sorted order. Returns
    ``None`` when no key matches or when ``atk`` lacks a discovered key. With
    ``raw_if_single`` a lone matching key is returned as raw scores instead of
    being passed through OAS.
    """
    keys = sorted(k for k in val if k.startswith(prefixes))
    if not keys:
        return None
    if raw_if_single and len(keys) == 1:
        only = keys[0]
        # Same "missing score -> None" contract as the stacked path below.
        if only not in atk:
            return None
        return val[only], atk[only]
    Sv = _stack(val, keys)
    Sa = _stack(atk, keys)
    if Sv is None or Sa is None:
        return None
    return _oas_mahal(Sv, Sa)


def score_oas_disc_all(val, atk):
    """Auto-discover all `disc_nll_*` keys; OAS-Mahal over them. Used by B4."""
    return _oas_over_prefixes(val, atk, ("disc_nll_",), raw_if_single=False)


def score_oas_all_available(val, atk):
    """OAS-Mahal over all `terminal_*` ∪ `disc_nll_*` keys present in the npz.

    Used by B1 (no terminal_flow). Handles arbitrary subset of the 10 standard keys.
    A single available key is returned as raw scores.
    """
    return _oas_over_prefixes(val, atk, ("terminal_", "disc_nll_"), raw_if_single=True)


def score_oas_term_all(val, atk):
    """Auto-discover all `terminal_*` keys; OAS-Mahal. Used by B3 (3 keys) / B1 (2 keys).

    A single available key is returned as raw scores.
    """
    return _oas_over_prefixes(val, atk, ("terminal_",), raw_if_single=True)
|
||||
|
||||
|
||||
SCORE_FNS = {
|
||||
"A1_terminal_norm": score_a1_terminal_norm,
|
||||
"A2_disc_nll_total": score_a2_disc_total,
|
||||
"A3_OAS_term3": score_a3_oas_term3,
|
||||
"A4_OAS_disc7": score_a4_oas_disc7,
|
||||
"A5_OAS_all10": score_a5_oas_all10,
|
||||
"A6_zmean_all10": score_a6_zmean,
|
||||
"A7_zmax_all10": score_a7_zmax,
|
||||
"OAS_disc_all": score_oas_disc_all,
|
||||
"OAS_term_all": score_oas_term_all,
|
||||
"OAS_all_available": score_oas_all_available,
|
||||
}
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Stats #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _auroc(s_v, s_a):
    """AUROC with ``s_v`` scores labelled 0 and ``s_a`` scores labelled 1."""
    labels = np.r_[np.zeros(len(s_v)), np.ones(len(s_a))]
    scores = np.r_[s_v, s_a]
    auc = roc_auc_score(labels, scores)
    return float(auc)
|
||||
|
||||
|
||||
def _mean_ci(values: list[float]):
|
||||
"""3-seed mean ± 95% t-CI (n=3, df=2)."""
|
||||
a = np.asarray([v for v in values if v is not None and not np.isnan(v)], dtype=float)
|
||||
if a.size == 0:
|
||||
return None
|
||||
if a.size == 1:
|
||||
return {"mean": float(a[0]), "std": 0.0, "ci": 0.0, "n": 1, "vals": a.tolist()}
|
||||
se = a.std(ddof=1) / np.sqrt(a.size)
|
||||
return {
|
||||
"mean": float(a.mean()),
|
||||
"std": float(a.std(ddof=1)),
|
||||
"ci": float(T_975_N3 * se) if a.size == 3 else float(1.96 * se),
|
||||
"n": int(a.size),
|
||||
"vals": a.tolist(),
|
||||
}
|
||||
|
||||
|
||||
def _delong_var(s_v, s_a):
|
||||
"""Compute DeLong AUROC variance (Sun & Xu 2014, fast O(n log n))."""
|
||||
n0, n1 = len(s_v), len(s_a)
|
||||
s = np.concatenate([s_a, s_v]) # positives first
|
||||
order = np.argsort(s, kind="mergesort")
|
||||
L = np.empty_like(s)
|
||||
s_sorted = s[order]
|
||||
# midrank
|
||||
i = 0
|
||||
while i < len(s_sorted):
|
||||
j = i
|
||||
while j < len(s_sorted) and s_sorted[j] == s_sorted[i]:
|
||||
j += 1
|
||||
L[order[i:j]] = (i + j - 1) / 2.0 + 1
|
||||
i = j
|
||||
# ranks split
|
||||
L_a = L[:n1]
|
||||
L_v = L[n1:]
|
||||
# midrank within each class
|
||||
s_a_order = np.argsort(s_a, kind="mergesort")
|
||||
L_aa = np.empty(n1)
|
||||
sa_sorted = s_a[s_a_order]
|
||||
i = 0
|
||||
while i < n1:
|
||||
j = i
|
||||
while j < n1 and sa_sorted[j] == sa_sorted[i]:
|
||||
j += 1
|
||||
L_aa[s_a_order[i:j]] = (i + j - 1) / 2.0 + 1
|
||||
i = j
|
||||
s_v_order = np.argsort(s_v, kind="mergesort")
|
||||
L_vv = np.empty(n0)
|
||||
sv_sorted = s_v[s_v_order]
|
||||
i = 0
|
||||
while i < n0:
|
||||
j = i
|
||||
while j < n0 and sv_sorted[j] == sv_sorted[i]:
|
||||
j += 1
|
||||
L_vv[s_v_order[i:j]] = (i + j - 1) / 2.0 + 1
|
||||
i = j
|
||||
auc = (L_a.sum() / n1 - (n1 + 1) / 2) / n0
|
||||
V10 = (L_a - L_aa) / n0 # length n1
|
||||
V01 = 1 - (L_v - L_vv) / n1 # length n0
|
||||
s10 = V10.var(ddof=1)
|
||||
s01 = V01.var(ddof=1)
|
||||
var = s10 / n1 + s01 / n0
|
||||
return float(auc), float(var), V10, V01
|
||||
|
||||
|
||||
def _delong_paired_p(s_v, s_a, t_v, t_a):
|
||||
"""Paired DeLong test for two AUROCs on the same data.
|
||||
|
||||
Returns (auc1 - auc2, p_value_two_sided).
|
||||
s_*: candidate scores; t_*: reference (JANUS-full) scores.
|
||||
Both arrays must align flow-by-flow.
|
||||
"""
|
||||
auc1, var1, V10_1, V01_1 = _delong_var(s_v, s_a)
|
||||
auc2, var2, V10_2, V01_2 = _delong_var(t_v, t_a)
|
||||
n1, n0 = len(s_a), len(s_v)
|
||||
cov10 = np.cov(np.stack([V10_1, V10_2]), ddof=1)[0, 1]
|
||||
cov01 = np.cov(np.stack([V01_1, V01_2]), ddof=1)[0, 1]
|
||||
cov12 = cov10 / n1 + cov01 / n0
|
||||
var_diff = var1 + var2 - 2 * cov12
|
||||
if var_diff <= 0:
|
||||
return auc1 - auc2, 1.0
|
||||
z = (auc1 - auc2) / np.sqrt(var_diff)
|
||||
# two-sided
|
||||
from scipy.stats import norm
|
||||
p = 2 * (1 - norm.cdf(abs(z)))
|
||||
return auc1 - auc2, float(p)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Aggregation entry points #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@dataclass
class VariantSpec:
    """One ablation-table row: where its phase-1 scores live and how to score them."""

    vid: str  # short variant id, reported as "vid" in the outputs
    label: str  # row label shown in the markdown tables
    what_removed: str  # text for the "What removed" column
    npz_dir_pattern: str  # e.g. "route_comparison/janus_{ds}_seed{seed}" or "ablation/janus_{ds}_seed{seed}_{gid}"
    score_fn_id: str  # which Group A score to apply on the npz (usually "A5_OAS_all10")
    gid: str = ""  # for B variants
|
||||
|
||||
|
||||
def _expand_path(spec: VariantSpec, ds: str, seed: int) -> Path:
    """Path of the ``phase1_scores.npz`` for this variant, dataset and seed."""
    run_dir = spec.npz_dir_pattern.format(ds=ds, seed=seed, gid=spec.gid)
    return ROOT / "artifacts" / run_dir / "phase1_scores.npz"
|
||||
|
||||
|
||||
def collect_variant(spec: VariantSpec) -> dict:
    """Compute the per-seed AUROC of one variant on every dataset and summarise it.

    (dataset, seed) cells whose npz is missing, or whose score function returns
    ``None``, are left out. The result carries the spec's identifying fields
    plus ``per_dataset`` (the ``_mean_ci`` summary per dataset) and
    ``per_seed`` (AUROC keyed by dataset, then seed).
    """
    aucs: dict[str, list[float]] = {}
    seed_aucs: dict[str, dict[int, float]] = {}
    for ds in DATASETS:
        aucs[ds] = []
        seed_aucs[ds] = {}

    for ds in DATASETS:
        for seed in SEEDS:
            npz = _expand_path(spec, ds, seed)
            if not npz.exists():
                continue
            val, atk = _load_npz(npz)
            scored = SCORE_FNS[spec.score_fn_id](val, atk)
            if scored is None:
                continue
            val_scores, atk_scores = scored
            auc = _auroc(val_scores, atk_scores)
            aucs[ds].append(auc)
            seed_aucs[ds][seed] = auc

    return {
        "vid": spec.vid,
        "label": spec.label,
        "what_removed": spec.what_removed,
        "score_fn_id": spec.score_fn_id,
        "gid": spec.gid,
        "per_dataset": {ds: _mean_ci(aucs[ds]) for ds in DATASETS},
        "per_seed": seed_aucs,
    }
|
||||
|
||||
|
||||
def collect_delong_pvals(spec: VariantSpec, ref_spec: VariantSpec) -> dict:
    """Paired DeLong test: spec vs ref_spec, on each (ds, seed).

    Returns, per dataset, a list of ``{"seed", "delta", "p"}`` records. A cell
    is skipped when either npz is missing, either score function returns
    ``None``, or the two score sets differ in length.
    """
    results: dict[str, list[dict]] = {ds: [] for ds in DATASETS}
    for ds in DATASETS:
        for seed in SEEDS:
            cand_npz = _expand_path(spec, ds, seed)
            ref_npz = _expand_path(ref_spec, ds, seed)
            if not cand_npz.exists() or not ref_npz.exists():
                continue
            cand_val, cand_atk = _load_npz(cand_npz)
            ref_val, ref_atk = _load_npz(ref_npz)
            cand_fn = SCORE_FNS[spec.score_fn_id]
            ref_fn = SCORE_FNS[ref_spec.score_fn_id]
            cand = cand_fn(cand_val, cand_atk)
            ref = ref_fn(ref_val, ref_atk)
            if cand is None or ref is None:
                continue
            cand_v, cand_a = cand
            ref_v, ref_a = ref
            # The paired test needs both score sets aligned flow-by-flow. Only
            # the lengths are checked here: cells with mismatching lengths
            # (e.g. variant evaluated on a subset) are skipped, not re-aligned.
            # In practice for B variants the npz is from the SAME data as
            # JANUS-full at that (ds, seed).
            same_shape = len(cand_v) == len(ref_v) and len(cand_a) == len(ref_a)
            if not same_shape:
                continue
            delta, p_value = _delong_paired_p(cand_v, cand_a, ref_v, ref_a)
            results[ds].append({"seed": seed, "delta": delta, "p": p_value})
    return results
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Variant registry #
|
||||
# --------------------------------------------------------------------------- #
|
||||
ROUTE_DIR = "route_comparison/janus_{ds}_seed{seed}"
|
||||
ABL_DIR = "ablation/janus_{ds}_seed{seed}_{gid}"
|
||||
|
||||
|
||||
def _group_a_specs() -> list[VariantSpec]:
    """Group A rows: different aggregators applied to the route_comparison npz files."""
    # (vid, label, what_removed, score_fn_id); every row reads from ROUTE_DIR.
    table = [
        ("JANUS-full", "JANUS-full (A5)", "—", "A5_OAS_all10"),
        ("A1", "A1 terminal_norm", "OAS aggregator + disc head", "A1_terminal_norm"),
        ("A2", "A2 disc_nll_total", "OAS aggregator + CFM head", "A2_disc_nll_total"),
        ("A3", "A3 OAS-Mahal term3", "disc head", "A3_OAS_term3"),
        ("A4", "A4 OAS-Mahal disc7", "CFM head", "A4_OAS_disc7"),
        ("A6", "A6 z-score mean (10-d)", "covariance structure", "A6_zmean_all10"),
        ("A7", "A7 z-score max (10-d)", "weighted aggregation", "A7_zmax_all10"),
    ]
    return [
        VariantSpec(vid, label, removed, ROUTE_DIR, fn_id)
        for vid, label, removed, fn_id in table
    ]
|
||||
|
||||
|
||||
def _group_b_specs() -> list[VariantSpec]:
    """Group B rows: architecture variants read from the ablation run directories."""
    # (vid, label, what_removed, score_fn_id, gid); every row reads from ABL_DIR.
    table = [
        # B1 has 2 terminal keys (no terminal_flow) + full disc7 → use auto-key OAS (9-d in this case)
        ("B1", "B1 no FLOW token", "global context", "OAS_all_available", "b1_noflow"),
        # B2 has only terminal_flow (= terminal_norm); single scalar
        ("B2", "B2 flow-only", "packet sequence", "A1_terminal_norm", "b2_flowonly"),
        # B3 has terminal_norm/flow/packet covering all 9 dims (cont + disc-as-cont); OAS on 3-tuple
        ("B3", "B3 all-cont", "cont/disc split", "A3_OAS_term3", "b3_allcont"),
        # B4 has 9 disc channels + total; auto-discover keys
        ("B4", "B4 all-disc", "cont/disc split (rev)", "OAS_disc_all", "b4_alldisc"),
        # B5 has full schema but disc head is untrained noise; use term3 only
        ("B5", "B5 λ_disc=0", "joint training", "A3_OAS_term3", "b5_nodisc"),
    ]
    return [
        VariantSpec(vid, label, removed, ABL_DIR, fn_id, gid=gid)
        for vid, label, removed, fn_id, gid in table
    ]
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Markdown writer #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _fmt_cell(c: dict | None) -> str:
|
||||
if c is None:
|
||||
return "—"
|
||||
if c["n"] == 1:
|
||||
return f"{100 * c['mean']:.2f}"
|
||||
return f"{100 * c['mean']:.2f} ± {100 * c['ci']:.2f}"
|
||||
|
||||
|
||||
def write_table(rows: list[dict], path: Path, *, title: str = "JANUS ablation"):
    """Write the ablation results as a markdown table.

    Args:
        rows: ``collect_variant`` results, one table row each.
        path: output file; parent directories are created if needed.
        title: heading placed at the top of the document.

    Each dataset column shows ``_fmt_cell`` of that dataset's summary; the
    final "Mean" column averages the per-dataset means that are available.
    """
    header = ["Variant", "What removed", *(PRETTY[ds] for ds in DATASETS), "Mean"]
    lines = [
        f"# {title}",
        "",
        f"3-seed mean ± 95% t-CI AUROC (%). Seeds = {SEEDS}.",
        "",
        "| " + " | ".join(header) + " |",
        "|" + "|".join("---" for _ in header) + "|",
    ]
    for r in rows:
        per_ds = [r["per_dataset"].get(ds) for ds in DATASETS]
        ds_means = [c["mean"] for c in per_ds if c is not None]
        mean_cell = f"{100 * np.mean(ds_means):.2f}" if ds_means else "—"
        cells = [r["label"], r["what_removed"], *(_fmt_cell(c) for c in per_ds), mean_cell]
        lines.append("| " + " | ".join(cells) + " |")
    lines.append("")
    path.parent.mkdir(parents=True, exist_ok=True)
    # The table contains non-ASCII characters ("±", "—"): write UTF-8
    # explicitly instead of relying on the locale's default encoding.
    path.write_text("\n".join(lines), encoding="utf-8")
|
||||
|
||||
|
||||
def write_delong(records: list[dict], path: Path):
    """Write the paired DeLong results as one markdown table per variant.

    Args:
        records: dicts with ``vid``, ``label`` and ``delong`` (the
            ``collect_delong_pvals`` output for that variant).
        path: output file; parent directories are created if needed.

    Each table has one row per seed and one column per dataset; a cell shows
    the signed AUROC difference and the raw p-value, or "—" when that
    (dataset, seed) has no result.
    """
    lines = [
        "# Paired DeLong p-values vs JANUS-full",
        "",
        f"Seeds = {SEEDS}. p reported per (variant, dataset, seed). "
        "Holm-Bonferroni-correctable; raw p shown.",
        "",
    ]
    header = ["Seed"] + [PRETTY[ds] for ds in DATASETS]
    for rec in records:
        lines.append(f"## {rec['label']} ({rec['vid']})")
        lines.append("")
        lines.append("| " + " | ".join(header) + " |")
        lines.append("|" + "|".join("---" for _ in header) + "|")
        for seed in SEEDS:
            row = [str(seed)]
            for ds in DATASETS:
                # First record for this seed, if any.
                hit = next((x for x in rec["delong"][ds] if x["seed"] == seed), None)
                if hit is None:
                    row.append("—")
                    continue
                sign = "+" if hit["delta"] >= 0 else "−"
                row.append(f"Δ={sign}{abs(hit['delta']):.4f}, p={hit['p']:.3g}")
            lines.append("| " + " | ".join(row) + " |")
        lines.append("")
    path.parent.mkdir(parents=True, exist_ok=True)
    # The report contains non-ASCII characters ("Δ", "−", "—"): write UTF-8
    # explicitly instead of relying on the locale's default encoding.
    path.write_text("\n".join(lines), encoding="utf-8")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Main #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def main() -> None:
    """CLI entry point: aggregate the selected ablation group(s) and write the reports.

    Writes ``ABLATION_TABLE_<group>.md`` and ``ABLATION_TABLE_<group>.json``
    under ``ABL``; with ``--delong`` also ``ABLATION_DELONG_<group>.md``, which
    compares every selected variant except JANUS-full against JANUS-full.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--group", choices=["A", "B", "all"], default="A")
    parser.add_argument(
        "--delong",
        action="store_true",
        help="Compute paired DeLong p-values vs JANUS-full (CPU heavy on big eval sets).",
    )
    args = parser.parse_args()

    ABL.mkdir(parents=True, exist_ok=True)
    specs: list[VariantSpec] = []
    if args.group != "B":
        specs += _group_a_specs()
    if args.group != "A":
        specs += _group_b_specs()

    rows = []
    for spec in specs:
        row = collect_variant(spec)
        rows.append(row)
        n_ok = len([ds for ds in DATASETS if row["per_dataset"][ds] is not None])
        print(f"[ok] {spec.vid:14s} datasets_with_data={n_ok}/{len(DATASETS)}", flush=True)

    out_md = ABL / f"ABLATION_TABLE_{args.group}.md"
    out_json = ABL / f"ABLATION_TABLE_{args.group}.json"
    write_table(rows, out_md, title=f"JANUS ablation (group {args.group})")
    # NOTE: `default` turns any non-JSON-serialisable object into null.
    out_json.write_text(json.dumps(rows, indent=2, default=lambda o: None))
    print(f"[wrote] {out_md}")
    print(f"[wrote] {out_json}")

    if not args.delong:
        return

    # The reference is always JANUS-full from group A, whichever group was selected.
    ref = next(s for s in _group_a_specs() if s.vid == "JANUS-full")
    recs = []
    for spec in specs:
        if spec.vid == "JANUS-full":
            continue
        pvals = collect_delong_pvals(spec, ref)
        recs.append({"vid": spec.vid, "label": spec.label, "delong": pvals})
        print(f"[delong] {spec.vid}", flush=True)
    delong_md = ABL / f"ABLATION_DELONG_{args.group}.md"
    write_delong(recs, delong_md)
    print(f"[wrote] {delong_md}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
218
scripts/aggregate/aggregate_ablation_cross.py
Normal file
218
scripts/aggregate/aggregate_ablation_cross.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Cross-dataset version of the Group-A score-aggregator ablation.
|
||||
|
||||
For each (src, tgt, seed) cell we have a phase1-style npz with:
|
||||
b_<key> target benign val (aggregator fit on this)
|
||||
a_<key> target attacks
|
||||
|
||||
Within-dataset (src == tgt) cells reuse the standard
|
||||
artifacts/route_comparison/janus_<ds>_seed<S>/phase1_scores.npz
|
||||
(val_/atk_ prefixes — handled via the same _load_npz path).
|
||||
|
||||
We score 7 aggregators (A1..A7; A5 is JANUS-full's deployed aggregator) across all
|
||||
3×3 cells × 3 seeds, then summarize with two complementary views:
|
||||
|
||||
ABLATION_TABLE_CROSS_summary.md
|
||||
| Aggregator | Within mean | Cross mean | Cross min (worst cell) |
|
||||
Shows whether OAS's value lives in cross-dataset robustness.
|
||||
|
||||
ABLATION_TABLE_CROSS_full.md
|
||||
Per-aggregator full 3×3 matrix (each cell = 3-seed mean ± 95% t-CI).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from aggregate_ablation import (
|
||||
SCORE_FNS, T_975_N3, _auroc, _load_npz, _load_cross_npz,
|
||||
)
|
||||
|
||||
# Directory two levels above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[2]
# Within-dataset cells are read from per-run folders under ROUTE.
ROUTE = ROOT / "artifacts" / "route_comparison"
# Cross-dataset (src != tgt) cells are read from npz files under CROSS.
CROSS = ROUTE / "cross"
# Default output directory for the generated tables.
ABL = ROOT / "artifacts" / "ablation"

# 3x3 cross matrix datasets (no ISCXTor16 — different feature space)
CROSS_DATASETS = ["cicids2017", "cicddos2019", "ciciot2023"]
# Dataset id -> display name used in the markdown tables.
PRETTY = {
    "cicids2017": "CICIDS17",
    "cicddos2019": "CICDDoS19",
    "ciciot2023": "CICIoT23",
}
# Seeds averaged within every (src, tgt) cell.
SEEDS = [42, 43, 44]

# One entry per table row: (display label, key into SCORE_FNS, short description).
AGGREGATORS = [
    ("JANUS-full (A5)", "A5_OAS_all10", "deployed JANUS"),
    ("A1 terminal_norm","A1_terminal_norm", "raw scalar (CFM head)"),
    ("A2 disc_total", "A2_disc_nll_total","raw scalar (disc head)"),
    ("A3 OAS term3", "A3_OAS_term3", "OAS on 3 cont sub-scores"),
    ("A4 OAS disc7", "A4_OAS_disc7", "OAS on 7 disc sub-scores"),
    ("A6 z-score mean", "A6_zmean_all10", "equal-weight z-score sum"),
    ("A7 z-score max", "A7_zmax_all10", "equal-weight z-score max"),
]
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _cell_path(src: str, tgt: str, seed: int) -> Path | None:
    """Locate the score archive for one (src, tgt, seed) cell.

    Diagonal cells (``src == tgt``) live in the per-run folder under
    ``ROUTE``; off-diagonal cells are flat npz files under ``CROSS``.
    Returns ``None`` when the expected file does not exist.
    """
    if src != tgt:
        candidate = CROSS / f"janus_seed{seed}_{src}_to_{tgt}.npz"
    else:
        candidate = ROUTE / f"janus_{src}_seed{seed}" / "phase1_scores.npz"
    if candidate.exists():
        return candidate
    return None
|
||||
|
||||
|
||||
def _load_cell(src: str, tgt: str, seed: int):
    """Load the score arrays for one cell.

    Returns ``(None, None)`` when the cell's npz file is missing; otherwise
    returns whatever the matching loader yields (``_load_npz`` for diagonal
    cells, ``_load_cross_npz`` for off-diagonal ones).
    """
    path = _cell_path(src, tgt, seed)
    if path is None:
        return None, None
    loader = _load_npz if src == tgt else _load_cross_npz
    return loader(path)
|
||||
|
||||
|
||||
def _score_cell(src: str, tgt: str, seed: int, score_fn_id: str) -> float | None:
    """AUROC of aggregator ``score_fn_id`` on a single (src, tgt, seed) cell.

    Returns ``None`` when the cell's data is unavailable or when the scoring
    function itself returns ``None``.
    """
    benign, attacks = _load_cell(src, tgt, seed)
    if benign is None:
        return None
    scored = SCORE_FNS[score_fn_id](benign, attacks)
    if scored is None:
        return None
    benign_scores, attack_scores = scored
    return _auroc(benign_scores, attack_scores)
|
||||
|
||||
|
||||
def _seed_means(src: str, tgt: str, score_fn_id: str) -> dict | None:
    """Summarise the per-seed AUROC of one aggregator on cell (src, tgt).

    Seeds whose score is ``None`` or NaN are dropped.  Returns ``None`` when
    no seed produced a usable score; otherwise a dict with

    * ``mean`` – mean AUROC over the usable seeds,
    * ``std``  – sample standard deviation (``ddof=1``; 0.0 for one seed),
    * ``ci``   – half-width of the 95% confidence interval (0.0 for one seed),
    * ``n``    – number of usable seeds,
    * ``vals`` – the individual AUROC values.
    """
    per_seed = (_score_cell(src, tgt, seed, score_fn_id) for seed in SEEDS)
    vals = [v for v in per_seed if v is not None and not np.isnan(v)]
    if not vals:
        return None

    a = np.asarray(vals)
    if a.size == 1:
        return {"mean": float(a[0]), "std": 0.0, "ci": 0.0, "n": 1, "vals": a.tolist()}

    std = a.std(ddof=1)
    se = std / np.sqrt(a.size)
    # Two-sided 95% critical value.  With only two usable seeds the Student-t
    # quantile (df = 1) is 12.706, not the normal 1.96: using 1.96 would make
    # the reported "t-CI" roughly 6.5x too narrow.
    # NOTE(review): T_975_N3 is assumed to be the df = 2 quantile (~4.303).
    if a.size == 3:
        crit = T_975_N3
    elif a.size == 2:
        crit = 12.706204736
    else:
        crit = 1.96  # large-sample fallback; unreachable with three seeds
    return {
        "mean": float(a.mean()),
        "std": float(std),
        "ci": float(crit * se),
        "n": int(a.size),
        "vals": a.tolist(),
    }
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _fmt_cell(c):
|
||||
if c is None:
|
||||
return "—"
|
||||
if c["n"] == 1:
|
||||
return f"{100 * c['mean']:.2f}"
|
||||
return f"{100 * c['mean']:.2f} ± {100 * c['ci']:.2f}"
|
||||
|
||||
|
||||
def _summary_row(rows_3x3: dict[tuple[str, str], dict | None]) -> tuple[float, float, float, dict | None]:
|
||||
"""Return (within_mean, cross_mean, cross_worst, worst_cell_summary)."""
|
||||
within = []
|
||||
cross = []
|
||||
worst_v = None
|
||||
worst_cell = None
|
||||
for (src, tgt), cell in rows_3x3.items():
|
||||
if cell is None:
|
||||
continue
|
||||
if src == tgt:
|
||||
within.append(cell["mean"])
|
||||
else:
|
||||
cross.append(cell["mean"])
|
||||
if worst_v is None or cell["mean"] < worst_v:
|
||||
worst_v = cell["mean"]
|
||||
worst_cell = (src, tgt, cell)
|
||||
w = float(np.mean(within)) if within else float("nan")
|
||||
c = float(np.mean(cross)) if cross else float("nan")
|
||||
cw = worst_v if worst_v is not None else float("nan")
|
||||
return w, c, cw, worst_cell
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
def main() -> None:
    """Score every aggregator on the cross-dataset matrix and write the reports.

    For each entry of ``AGGREGATORS`` the per-seed summary of every
    (src, tgt) pair in ``CROSS_DATASETS`` is computed with ``_seed_means``.
    Three files are then written to ``--out-dir`` (default ``ABL``):

    * ``ABLATION_TABLE_CROSS_summary.md`` – within mean, cross mean, worst
      cross cell and the within-minus-cross gap per aggregator,
    * ``ABLATION_TABLE_CROSS_full.md`` – one full matrix per aggregator,
    * ``ABLATION_TABLE_CROSS.json`` – the same data in machine-readable form.

    All files are written as UTF-8: the tables contain characters such as
    ``→`` and ``±`` that cannot be encoded in every locale's default codec.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out-dir", type=Path, default=ABL)
    args = ap.parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    full = {}  # aggregator label -> {(src, tgt) -> cell summary}
    for label, fn_id, _why in AGGREGATORS:
        rows = {
            (src, tgt): _seed_means(src, tgt, fn_id)
            for src in CROSS_DATASETS
            for tgt in CROSS_DATASETS
        }
        full[label] = rows
        n_ok = sum(1 for v in rows.values() if v is not None)
        print(f"[ok] {label:20s} cells={n_ok}/{len(rows)}", flush=True)

    # Summary table: within mean, cross mean, cross worst
    summary_lines = ["# Cross-dataset Group-A summary",
                     "",
                     f"3-seed mean ± 95% t-CI AUROC. Datasets = {CROSS_DATASETS}.",
                     "Aggregator fit on **target** benign val only.",
                     "",
                     "| Aggregator | Within (3 cells, mean) | Cross (6 cells, mean) | Cross worst cell | Within − Cross |",
                     "|---|---|---|---|---|"]
    summary_data = {}
    for label, _fn_id, _why in AGGREGATORS:
        w, c, cw, worst_cell = _summary_row(full[label])
        # Gap in percentage points; NaN when either side has no data.
        gap = (w - c) * 100 if not np.isnan(w) and not np.isnan(c) else float("nan")
        worst_str = "—"
        if worst_cell is not None:
            src, tgt, cell = worst_cell
            worst_str = f"{PRETTY[src]}→{PRETTY[tgt]}: {_fmt_cell(cell)}"
        summary_lines.append(
            f"| {label} | {100 * w:.2f} | {100 * c:.2f} | {worst_str} | {gap:+.2f} |"
        )
        summary_data[label] = {"within_mean": w, "cross_mean": c, "cross_worst": cw, "worst_cell": worst_cell}
    summary_path = args.out_dir / "ABLATION_TABLE_CROSS_summary.md"
    summary_path.write_text("\n".join(summary_lines) + "\n", encoding="utf-8")
    print(f"[wrote] {summary_path}")

    # Full per-aggregator 3x3 matrices
    full_lines = ["# Cross-dataset Group-A full matrices",
                  "",
                  "Per aggregator: 3×3 matrix (rows = source / training, columns = target / test).",
                  "Each cell = 3-seed mean ± 95% t-CI AUROC (%). Diagonal italic = within-dataset.",
                  ""]
    for label, _fn_id, why in AGGREGATORS:
        full_lines.append(f"## {label} ({why})")
        full_lines.append("")
        header = ["Source ↓ / Target →"] + [PRETTY[d] for d in CROSS_DATASETS]
        full_lines.append("| " + " | ".join(header) + " |")
        full_lines.append("|" + "|".join("---" for _ in header) + "|")
        for src in CROSS_DATASETS:
            row = [f"**{PRETTY[src]}**"]
            for tgt in CROSS_DATASETS:
                txt = _fmt_cell(full[label][(src, tgt)])
                if src == tgt:
                    txt = f"_{txt}_"  # italicise within-dataset cells
                row.append(txt)
            full_lines.append("| " + " | ".join(row) + " |")
        full_lines.append("")
    full_path = args.out_dir / "ABLATION_TABLE_CROSS_full.md"
    full_path.write_text("\n".join(full_lines), encoding="utf-8")
    print(f"[wrote] {full_path}")

    json_path = args.out_dir / "ABLATION_TABLE_CROSS.json"
    # Tuple keys are flattened to "src->tgt" strings because JSON object keys
    # must be strings; ``default`` maps anything unserialisable to null.
    json_path.write_text(json.dumps({
        "summary": summary_data,
        "full": {label: {f"{src}->{tgt}": cell for (src, tgt), cell in rows.items()}
                 for label, rows in full.items()},
    }, indent=2, default=lambda o: None), encoding="utf-8")
    print(f"[wrote] {json_path}")


if __name__ == "__main__":
    main()
|
||||
180
scripts/aggregate/aggregate_ablation_cross_B.py
Normal file
180
scripts/aggregate/aggregate_ablation_cross_B.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""B-variant cross-dataset aggregation.
|
||||
|
||||
Reads:
|
||||
artifacts/ablation/janus_<ds>_seed<S>_<gid>/phase1_scores.npz (within-dataset)
|
||||
artifacts/ablation/cross/<gid>__seed<S>_<src>_to_<tgt>.npz (cross-dataset)
|
||||
|
||||
For each B-variant we apply the variant-appropriate aggregator (auto-key OAS
|
||||
fits whatever sub-scores the variant produces). JANUS-full anchor is read from
|
||||
the production route_comparison/ paths.
|
||||
|
||||
Outputs:
|
||||
ABLATION_CROSS_B_summary.md within mean / cross mean / cross worst per gid
|
||||
ABLATION_CROSS_B_full.md per-gid 3×3 matrices
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from aggregate_ablation import (
|
||||
SCORE_FNS, T_975_N3, _auroc, _load_npz, _load_cross_npz,
|
||||
)
|
||||
|
||||
# Directory two levels above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[2]
# JANUS-full (anchor) results: per-run folders under ROUTE, cross npz files under ROUTE_CROSS.
ROUTE = ROOT / "artifacts" / "route_comparison"
ROUTE_CROSS = ROUTE / "cross"
# B-variant results: per-run folders under ABL, cross npz files under ABL_CROSS.
# ABL is also the default output directory.
ABL = ROOT / "artifacts" / "ablation"
ABL_CROSS = ABL / "cross"

# Datasets forming the rows (source) and columns (target) of each matrix.
CROSS_DATASETS = ["cicids2017", "cicddos2019", "ciciot2023"]
# Dataset id -> display name used in the markdown tables.
PRETTY = {
    "cicids2017": "CICIDS17",
    "cicddos2019": "CICDDoS19",
    "ciciot2023": "CICIoT23",
}
# Seeds averaged within every (src, tgt) cell.
SEEDS = [42, 43, 44]

# (gid, label, what_removed, score_fn_id)
# gid selects the artifact paths, score_fn_id is the key looked up in SCORE_FNS.
B_VARIANTS = [
    ("janus_full", "JANUS-full", "—", "OAS_all_available"),
    ("b1_noflow", "B1 no FLOW token","global context", "OAS_all_available"),
    ("b2_flowonly", "B2 flow-only", "packet sequence", "A1_terminal_norm"),
    ("b3_allcont", "B3 all-cont", "cont/disc split", "OAS_term_all"),
    ("b4_alldisc", "B4 all-disc", "cont/disc split (rev)", "OAS_disc_all"),
    ("b5_nodisc", "B5 λ_disc=0", "joint training", "OAS_term_all"),
]
|
||||
|
||||
|
||||
def _within_path(gid: str, ds: str, seed: int) -> Path:
    """Path of the within-dataset ``phase1_scores.npz`` for one variant/seed.

    The ``janus_full`` anchor is read from ``ROUTE``; every other variant is
    read from ``ABL`` with the ``gid`` appended to the run-folder name.
    """
    if gid == "janus_full":
        run_dir = ROUTE / f"janus_{ds}_seed{seed}"
    else:
        run_dir = ABL / f"janus_{ds}_seed{seed}_{gid}"
    return run_dir / "phase1_scores.npz"
|
||||
|
||||
|
||||
def _cross_path(gid: str, src: str, tgt: str, seed: int) -> Path:
    """Path of the cross-dataset npz file for one variant, pair and seed.

    The ``janus_full`` anchor uses the ``janus_seed…`` naming under
    ``ROUTE_CROSS``; other variants use ``<gid>__seed…`` under ``ABL_CROSS``.
    """
    is_anchor = gid == "janus_full"
    folder = ROUTE_CROSS if is_anchor else ABL_CROSS
    prefix = "janus_" if is_anchor else f"{gid}__"
    return folder / f"{prefix}seed{seed}_{src}_to_{tgt}.npz"
|
||||
|
||||
|
||||
def _cell_score(gid: str, src: str, tgt: str, seed: int, fn_id: str):
    """AUROC of variant ``gid`` on one (src, tgt, seed) cell, or ``None``.

    ``None`` is returned when the cell's npz file does not exist or when the
    scoring function ``SCORE_FNS[fn_id]`` returns ``None``.
    """
    within = src == tgt
    if within:
        path = _within_path(gid, src, seed)
    else:
        path = _cross_path(gid, src, tgt, seed)
    if not path.exists():
        return None

    # Within- and cross-dataset archives are read by different loaders.
    loader = _load_npz if within else _load_cross_npz
    val, atk = loader(path)

    scored = SCORE_FNS[fn_id](val, atk)
    if scored is None:
        return None
    benign_scores, attack_scores = scored
    return _auroc(benign_scores, attack_scores)
|
||||
|
||||
|
||||
def _seed_summary(vals: list[float]):
|
||||
a = np.asarray([v for v in vals if v is not None and not np.isnan(v)])
|
||||
if a.size == 0:
|
||||
return None
|
||||
if a.size == 1:
|
||||
return {"mean": float(a[0]), "ci": 0.0, "n": 1}
|
||||
se = a.std(ddof=1) / np.sqrt(a.size)
|
||||
return {"mean": float(a.mean()),
|
||||
"ci": float(T_975_N3 * se) if a.size == 3 else float(1.96 * se),
|
||||
"n": int(a.size)}
|
||||
|
||||
|
||||
def _fmt(c):
|
||||
if c is None:
|
||||
return "—"
|
||||
if c["n"] == 1:
|
||||
return f"{100 * c['mean']:.2f}"
|
||||
return f"{100 * c['mean']:.2f} ± {100 * c['ci']:.2f}"
|
||||
|
||||
|
||||
def main() -> None:
    """Score every B variant on the cross-dataset matrix and write the reports.

    For each entry of ``B_VARIANTS`` the per-seed AUROC of every (src, tgt)
    pair in ``CROSS_DATASETS`` is computed with ``_cell_score`` and condensed
    with ``_seed_summary``.  Three files are written to ``--out-dir``
    (default ``ABL``):

    * ``ABLATION_CROSS_B_summary.md`` – within mean, cross mean, worst cross
      cell and the within-minus-cross gap per variant,
    * ``ABLATION_CROSS_B_full.md`` – one full matrix per variant,
    * ``ABLATION_CROSS_B.json`` – the per-cell summaries.

    All files are written as UTF-8: the tables contain characters such as
    ``→`` and ``±`` that cannot be encoded in every locale's default codec.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out-dir", type=Path, default=ABL)
    args = ap.parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    full = {}  # gid -> (label, {(src, tgt) -> cell summary})
    for gid, label, _why, fn_id in B_VARIANTS:
        rows = {}
        for src in CROSS_DATASETS:
            for tgt in CROSS_DATASETS:
                vals = [_cell_score(gid, src, tgt, s, fn_id) for s in SEEDS]
                rows[(src, tgt)] = _seed_summary(vals)
        full[gid] = (label, rows)
        n_ok = sum(1 for v in rows.values() if v is not None)
        print(f"[ok] {label:20s} cells={n_ok}/{len(rows)}", flush=True)

    # Summary
    lines = ["# B-variant cross-dataset summary",
             "",
             f"3-seed mean ± 95% t-CI AUROC. Datasets = {CROSS_DATASETS}.",
             "All B variants share the same aggregator-fit-on-target-benign protocol as JANUS-full.",
             "",
             "| Variant | What removed | Within (3 cells) | Cross (6 cells) | Cross worst | Within − Cross |",
             "|---|---|---|---|---|---|"]
    for gid, label, why, _fn_id in B_VARIANTS:
        _, rows = full[gid]
        within = [v["mean"] for (s, t), v in rows.items() if s == t and v is not None]
        cross_pairs = [((s, t), v) for (s, t), v in rows.items() if s != t and v is not None]
        cross = [v["mean"] for _, v in cross_pairs]
        # Lowest-mean off-diagonal cell; None when no cross data exists.
        worst = min(cross_pairs, key=lambda x: x[1]["mean"], default=None)
        w = float(np.mean(within)) if within else float("nan")
        c = float(np.mean(cross)) if cross else float("nan")
        worst_str = "—"
        if worst is not None:
            (s, t), v = worst
            worst_str = f"{PRETTY[s]}→{PRETTY[t]}: {_fmt(v)}"
        # Gap in percentage points; NaN when either side has no data.
        gap = (w - c) * 100 if not np.isnan(w) and not np.isnan(c) else float("nan")
        lines.append(f"| {label} | {why} | {100 * w:.2f} | {100 * c:.2f} | {worst_str} | {gap:+.2f} |")
    summary_path = args.out_dir / "ABLATION_CROSS_B_summary.md"
    summary_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    print(f"[wrote] {summary_path}")

    # Full per-variant 3x3 matrices
    flines = ["# B-variant cross-dataset full matrices",
              "",
              "Per variant: 3×3 matrix (rows = source, columns = target). Diagonal italic.",
              "Each cell = 3-seed mean ± 95% t-CI AUROC (%).",
              ""]
    for gid, label, why, _fn_id in B_VARIANTS:
        _, rows = full[gid]
        flines.append(f"## {label} ({why})")
        flines.append("")
        header = ["Source ↓ / Target →"] + [PRETTY[d] for d in CROSS_DATASETS]
        flines.append("| " + " | ".join(header) + " |")
        flines.append("|" + "|".join("---" for _ in header) + "|")
        for src in CROSS_DATASETS:
            row = [f"**{PRETTY[src]}**"]
            for tgt in CROSS_DATASETS:
                txt = _fmt(rows[(src, tgt)])
                if src == tgt:
                    txt = f"_{txt}_"  # italicise within-dataset cells
                row.append(txt)
            flines.append("| " + " | ".join(row) + " |")
        flines.append("")
    full_path = args.out_dir / "ABLATION_CROSS_B_full.md"
    full_path.write_text("\n".join(flines), encoding="utf-8")
    print(f"[wrote] {full_path}")

    json_path = args.out_dir / "ABLATION_CROSS_B.json"
    # Tuple keys are flattened to "src->tgt" strings because JSON object keys
    # must be strings; ``default`` maps anything unserialisable to null.
    json_path.write_text(json.dumps({
        gid: {"label": label, "rows": {f"{s}->{t}": v for (s, t), v in rows.items()}}
        for gid, (label, rows) in full.items()
    }, indent=2, default=lambda o: None), encoding="utf-8")
    print(f"[wrote] {json_path}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user