README: one-line descriptions of each baseline; figures: SVG export + label tweaks
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -32,4 +32,7 @@ Thumbs.db
|
||||
|
||||
*.tmp
|
||||
|
||||
CLAUDE.md
|
||||
CLAUDE.md
|
||||
.gitignore
|
||||
|
||||
drafts/
|
||||
|
||||
15
README.md
15
README.md
@@ -39,6 +39,21 @@ JANUS is fully unsupervised (benign-only training, no attack labels at any stage
|
||||
|
||||
Thresholded F1 metrics for JANUS across all four datasets are in `RESULTS.md` Section D. -->
|
||||
|
||||
### Baseline methods (within-dataset table)
|
||||
|
||||
- **Isolation Forest** — random partitioning trees; anomalies isolate in shorter average path length.
|
||||
- **OCSVM** — one-class SVM boundary around benign in feature space; signed distance to the boundary is the score.
|
||||
- **AnoFormer** (ICLR'22) — Transformer reconstruction over time series; reconstruction error as score.
|
||||
- **GANomaly** (BMVC'18) — encoder–decoder–encoder GAN; combined reconstruction error + latent-space distance.
|
||||
- **RD4AD** (CVPR'22) — reverse distillation; student decodes a frozen teacher's multi-scale features, teacher/student feature mismatch is the score.
|
||||
- **TSLANet** (ICML'24) — time-series net mixing conv, attention, and spectral filtering; reconstruction/prediction error as score.
|
||||
- **ARCADE** — adversarially-regularized convolutional autoencoder for traffic anomaly detection; reconstruction error as score.
|
||||
- **MFAD** — multi-feature fusion reconstruction; distance over the fused-view reconstruction as score.
|
||||
- **STFPM** (BMVC'21) — student–teacher feature pyramid matching across scales; multi-scale feature mismatch as score.
|
||||
- **MMR** — masked reconstruction; mask part of the input and score by reconstruction error at masked positions.
|
||||
- **Shafir NF + Shapley** (arXiv'26) — Normalizing Flow on CICFlowMeter flow statistics with SHAP-selected top-5 features; negative log-likelihood as score.
|
||||
- **ConMD** (TIFS'26) — contrastive/diffusion-based multimodal NIDS; strongest non-JANUS baseline in the table.
|
||||
|
||||
### 3×3 cross-dataset transfer matrix
|
||||
|
||||
Source (rows) trained on 10K benign of source dataset; target (columns) tested on full target benign + **all** target attacks. Aggregator fit on target benign val only — no attack labels at any stage. Diagonal italic = within-dataset.
|
||||
|
||||
@@ -105,11 +105,52 @@ def plot_one(npz: Path, dataset: str) -> Path:
|
||||
|
||||
out = OUT / f"velocity_field_view_{dataset.lower()}.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".svg"), bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def plot_one_overview(npz: Path, dataset: str) -> Path:
|
||||
"""Render a clean single-panel velocity-field SVG for use as the overview-
|
||||
figure component 03 (CFM head). Training-phase visualization only:
|
||||
log-norm heatmap + white streamlines + benign t=0.5 cloud. No attacks,
|
||||
no axes / colorbar / title (the surrounding overview wrapper supplies
|
||||
those). Outputs both SVG and PDF for LaTeX flexibility.
|
||||
"""
|
||||
z = np.load(npz)
|
||||
GX = z["grid_x"]
|
||||
GY = z["grid_y"]
|
||||
field_log = z["field_log_norm"]
|
||||
field_v = z["field_v_2d"]
|
||||
benign_t05 = z["benign_t05_2d"]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(3.0, 2.6), constrained_layout=True)
|
||||
vmin, vmax = np.percentile(field_log, [5, 95])
|
||||
ax.pcolormesh(GX, GY, field_log, cmap="viridis", shading="auto",
|
||||
vmin=vmin, vmax=vmax, rasterized=True)
|
||||
speed = np.linalg.norm(field_v, axis=-1)
|
||||
lw = 0.35 + 1.5 * (speed / (speed.max() + 1e-9))
|
||||
ax.streamplot(GX, GY, field_v[..., 0], field_v[..., 1],
|
||||
color="white", linewidth=lw, density=0.85, arrowsize=0.7)
|
||||
n_overlay = min(200, benign_t05.shape[0])
|
||||
rng = np.random.default_rng(0)
|
||||
idx_ov = rng.choice(benign_t05.shape[0], n_overlay, replace=False)
|
||||
ax.scatter(benign_t05[idx_ov, 0], benign_t05[idx_ov, 1],
|
||||
s=2.5, c="white", alpha=0.55, edgecolors="black",
|
||||
linewidths=0.12, rasterized=True, zorder=4)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
for spine in ax.spines.values():
|
||||
spine.set_visible(False)
|
||||
|
||||
out = OUT / f"velocity_field_overview_{dataset.lower()}.svg"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--datasets", nargs="+",
|
||||
@@ -126,6 +167,8 @@ def main() -> None:
|
||||
continue
|
||||
p = plot_one(npz, pretty.get(ds, ds))
|
||||
print(f"[wrote] {p}")
|
||||
p_ov = plot_one_overview(npz, pretty.get(ds, ds))
|
||||
print(f"[wrote] {p_ov}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -97,6 +97,7 @@ def plot_corr_heatmap() -> Path:
|
||||
cbar.set_label("Pearson ρ on benign val", fontsize=10)
|
||||
out = OUT / "subscore_correlation_benign_val.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".svg"), bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
@@ -211,11 +212,227 @@ def plot_dual_head() -> Path:
|
||||
|
||||
out = OUT / "dual_head_oas_ellipses_top__whitened_pca_bottom.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".svg"), bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def plot_dual_head_overview(dataset: str = "cicddos2019") -> Path:
|
||||
"""Render a clean single-panel OAS-ellipse SVG for use as overview-figure
|
||||
component 06 (Mahalanobis-OAS aggregator).
|
||||
|
||||
Visual: benign as a smooth 2D KDE density blob (blue, 6 filled
|
||||
contour levels), attacks as sparse bright red dots with white halos
|
||||
clearly outside the dense benign region, and 1/2/3-sigma OAS-Mahalanobis
|
||||
ellipses overlaid on top in bold black. The visual story: the
|
||||
aggregator (ellipses) is fit on the dense benign cloud; attacks at
|
||||
inference fall outside the ellipses, which is what makes $d^2$ a
|
||||
useful anomaly score.
|
||||
"""
|
||||
from scipy.stats import gaussian_kde
|
||||
|
||||
val, atk = load_scores(dataset)
|
||||
rng = np.random.default_rng(0)
|
||||
fig, ax = plt.subplots(figsize=(3.0, 2.6), constrained_layout=True)
|
||||
|
||||
x_v, y_v = val[:, 0], val[:, 3]
|
||||
x_a, y_a = atk[:, 0], atk[:, 3]
|
||||
|
||||
nv = min(5000, len(x_v))
|
||||
iv = rng.choice(len(x_v), nv, replace=False)
|
||||
na = min(120, len(x_a)) # fewer, brighter attack dots for visibility
|
||||
ia = rng.choice(len(x_a), na, replace=False)
|
||||
|
||||
# View window: capture 99% of benign + 95% of attack
|
||||
x_lo = min(np.quantile(x_v, 0.005), np.quantile(x_a, 0.05))
|
||||
x_hi = max(np.quantile(x_v, 0.995), np.quantile(x_a, 0.95))
|
||||
y_lo = min(np.quantile(y_v, 0.005), np.quantile(y_a, 0.05))
|
||||
y_hi = max(np.quantile(y_v, 0.995), np.quantile(y_a, 0.95))
|
||||
pad_x = 0.05 * (x_hi - x_lo)
|
||||
pad_y = 0.05 * (y_hi - y_lo)
|
||||
xlim = (x_lo - pad_x, x_hi + pad_x)
|
||||
ylim = (y_lo - pad_y, y_hi + pad_y)
|
||||
|
||||
# Benign 2D KDE density blob
|
||||
kde = gaussian_kde(np.vstack([x_v[iv], y_v[iv]]))
|
||||
xx, yy = np.meshgrid(np.linspace(*xlim, 90), np.linspace(*ylim, 90))
|
||||
grid = np.vstack([xx.ravel(), yy.ravel()])
|
||||
density = kde(grid).reshape(xx.shape)
|
||||
# Drop the lowest-density floor (clip near-zero edge artefacts)
|
||||
floor = np.quantile(density, 0.55)
|
||||
levels = np.linspace(floor, density.max() * 0.97, 6)
|
||||
ax.contourf(xx, yy, density, levels=levels, cmap="Blues", alpha=0.92, zorder=1)
|
||||
|
||||
# Attack scatter (sparse, bright, white halo for crispness)
|
||||
ax.scatter(x_a[ia], y_a[ia], s=11, c="#d7191c",
|
||||
edgecolors="white", linewidth=0.5, alpha=0.95, zorder=3)
|
||||
|
||||
# OAS Mahalanobis ellipses on top: bold black
|
||||
XY_v = val[:, [0, 3]]
|
||||
oas2 = OAS().fit(XY_v)
|
||||
mu2 = XY_v.mean(axis=0)
|
||||
for ns, ls in [(1, "-"), (2, "--"), (3, ":")]:
|
||||
e = _ellipse_from_2x2(
|
||||
mu2, oas2.covariance_, ns,
|
||||
edgecolor="black", facecolor="none", lw=1.3, ls=ls, alpha=0.92,
|
||||
zorder=5,
|
||||
)
|
||||
ax.add_patch(e)
|
||||
|
||||
ax.set_xlim(*xlim)
|
||||
ax.set_ylim(*ylim)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
for spine in ax.spines.values():
|
||||
spine.set_visible(False)
|
||||
|
||||
out = OUT / f"oas_ellipse_overview_{dataset.lower()}.svg"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def _benign_disc_p1(dataset: str) -> np.ndarray:
|
||||
"""Empirical per-channel P(c_i = 1) over all benign packets in the
|
||||
processed dataset. Returns shape [6] for (dir, SYN, FIN, RST, PSH, ACK).
|
||||
"""
|
||||
import pandas as pd
|
||||
data_root = ROOT / "datasets" / dataset / "processed"
|
||||
pkts = np.load(data_root / "packets.npz")
|
||||
packet_tokens = pkts["packet_tokens"] # [N, T_full, 9]
|
||||
packet_lengths = pkts["packet_lengths"] # [N]
|
||||
|
||||
flows_df = pd.read_parquet(data_root / "flows.parquet")
|
||||
label_norm = flows_df["label"].astype(str).str.strip().str.lower()
|
||||
benign_aliases = {"benign", "normal"}
|
||||
benign_mask = label_norm.isin(benign_aliases).values
|
||||
benign_idx = np.where(benign_mask)[0]
|
||||
|
||||
if benign_idx.size == 0:
|
||||
return np.zeros(6, dtype=np.float64)
|
||||
|
||||
T_full = packet_tokens.shape[1]
|
||||
valid = np.arange(T_full)[None, :] < packet_lengths[benign_idx, None] # [Nb, T]
|
||||
# Disc channels live at indices 2..7 of the canonical 9-d packet schema.
|
||||
disc = packet_tokens[benign_idx, :, 2:8].astype(np.float64) # [Nb, T, 6]
|
||||
masked_sum = (disc * valid[..., None]).sum(axis=(0, 1)) # [6]
|
||||
total = float(valid.sum())
|
||||
return masked_sum / max(total, 1.0)
|
||||
|
||||
|
||||
def plot_dfm_head_overview(dataset: str = "cicids2017") -> Path:
|
||||
"""Render a clean single-panel DFM head SVG for use as overview-figure
|
||||
component 04. Six paired bars (P(c=0) light / P(c=1) dark) show the
|
||||
empirical categorical distribution of the six binary packet channels
|
||||
(direction + five TCP flags) on benign packets — the distribution the
|
||||
DFM head is trained to model. Training-phase visualization: benign-only.
|
||||
|
||||
Default uses CICIDS2017 because (a) it ships a flat `packets.npz` that
|
||||
this helper reads directly, and (b) its 51/49 TCP/UDP split exercises
|
||||
the full range of flag distributions in a way UDP-heavy CICDDoS2019
|
||||
or TCP-heavy CICIoT2023 do not.
|
||||
"""
|
||||
p_c1 = _benign_disc_p1(dataset)
|
||||
p_c0 = 1.0 - p_c1
|
||||
channel_labels = ["dir", "SYN", "FIN", "RST", "PSH", "ACK"]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(3.2, 2.4), constrained_layout=True)
|
||||
x = np.arange(6, dtype=float)
|
||||
bar_w = 0.36
|
||||
ax.bar(x - bar_w / 2 - 0.02, p_c0, bar_w,
|
||||
color="#F4C58A", edgecolor="#a85518", linewidth=0.5,
|
||||
label=r"$P(c_i{=}0)$")
|
||||
ax.bar(x + bar_w / 2 + 0.02, p_c1, bar_w,
|
||||
color="#A85518", edgecolor="#5a2f0e", linewidth=0.5,
|
||||
label=r"$P(c_i{=}1)$")
|
||||
|
||||
# Reference line at y=0.5
|
||||
ax.axhline(0.5, color="#888", lw=0.4, ls="--", alpha=0.55, zorder=0)
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(channel_labels, fontsize=7.5)
|
||||
ax.set_ylim(0, 1.08)
|
||||
ax.set_yticks([])
|
||||
ax.tick_params(axis="x", length=0, pad=2)
|
||||
for side in ("top", "right", "left"):
|
||||
ax.spines[side].set_visible(False)
|
||||
ax.spines["bottom"].set_linewidth(0.45)
|
||||
|
||||
ax.legend(
|
||||
loc="upper right", fontsize=6.5, frameon=False,
|
||||
ncol=2, bbox_to_anchor=(1.00, 1.13),
|
||||
handletextpad=0.35, columnspacing=0.8,
|
||||
)
|
||||
|
||||
out = OUT / f"dfm_head_overview_{dataset.lower()}.svg"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def plot_score_family_overview(dataset: str = "cicids2017") -> Path:
|
||||
"""Render a clean single-panel score-family SVG for use as overview-figure
|
||||
component 05 (the 10-d score vector $s(x)$ between the heads and the
|
||||
aggregator). Replaces the schematic 10-cell row with real data: each of
|
||||
the 10 sub-scores becomes one vertical bar whose height is the *attack
|
||||
median z-score* relative to benign val (i.e., how many benign-std units
|
||||
the typical attack is shifted on this score).
|
||||
|
||||
Layout:
|
||||
- 3 cBlue bars on the left (term3, CFM-head scores)
|
||||
- 7 cOrange bars on the right (disc7, DFM-head scores)
|
||||
- Group brackets and small labels above each group.
|
||||
- Score-name x-tick labels below each bar (rotated 30°).
|
||||
- Faint benign reference line at z=0.
|
||||
"""
|
||||
val, atk = load_scores(dataset)
|
||||
mu = val.mean(axis=0)
|
||||
sd = val.std(axis=0) + 1e-9
|
||||
# z-normalised attacks: median over the attack class per score.
|
||||
z_atk = np.median((atk - mu) / sd, axis=0)
|
||||
# Same for benign val (sanity: should be ~0).
|
||||
z_val = np.median((val - mu) / sd, axis=0)
|
||||
|
||||
# CFM head fill = #FFF2CC (drawio yellow), DFM head fill = #D5E8D4 (drawio green).
|
||||
# Use the matching darker shade for the edge so bars are still visible.
|
||||
cfm_fill, cfm_edge = "#FFF2CC", "#D6B656"
|
||||
dfm_fill, dfm_edge = "#D5E8D4", "#82B366"
|
||||
fills = [cfm_fill] * 3 + [dfm_fill] * 7
|
||||
edges = [cfm_edge] * 3 + [dfm_edge] * 7
|
||||
|
||||
fig, ax = plt.subplots(figsize=(3.6, 2.0), constrained_layout=True)
|
||||
x = np.arange(10, dtype=float)
|
||||
bar_w = 0.72
|
||||
ax.bar(x, z_atk, bar_w, color=fills,
|
||||
edgecolor=edges, linewidth=0.9, zorder=3)
|
||||
|
||||
# Faint benign reference line at z=0.
|
||||
ax.axhline(0.0, color="#888", lw=0.5, ls="--", alpha=0.7, zorder=1)
|
||||
|
||||
# No x-tick labels, no top bracket annotations: clean bars only.
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.tick_params(axis="x", length=0, pad=2)
|
||||
ax.set_xlim(-0.6, 9.6)
|
||||
for side in ("top", "right", "left"):
|
||||
ax.spines[side].set_visible(False)
|
||||
ax.spines["bottom"].set_linewidth(0.45)
|
||||
|
||||
# Y-limits: keep small headroom but no need for bracket clearance now.
|
||||
hi = float(max(z_atk.max() * 1.10, 1.0))
|
||||
lo = float(min(z_atk.min() * 1.10, -0.05))
|
||||
ax.set_ylim(lo, hi)
|
||||
|
||||
out = OUT / f"score_family_overview_{dataset.lower()}.svg"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def plot_score_hist() -> Path:
|
||||
fig, axes = plt.subplots(4, 4, figsize=(16, 12), constrained_layout=True)
|
||||
for col, ds in enumerate(DATASETS):
|
||||
@@ -265,6 +482,7 @@ def plot_score_hist() -> Path:
|
||||
axes[0, 3].legend(loc="upper right", fontsize=8, framealpha=0.85)
|
||||
out = OUT / "score_distributions_raw__termOAS__discOAS__allOAS.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".svg"), bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
@@ -302,9 +520,131 @@ def _hist_panel(ax, sv, sa, log_x: bool = False):
|
||||
ax.set_yticks([])
|
||||
|
||||
|
||||
def _load_cross(src: str, tgt: str, seeds=(42, 43, 44)) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Load 10-d score vectors for the (src→tgt) cross-domain pair, pooled
|
||||
across seeds. b_* are benign-val from the source training domain;
|
||||
a_* are attacks from the target test domain.
|
||||
"""
|
||||
val_pool, atk_pool = [], []
|
||||
for s in seeds:
|
||||
npz = ROOT / "artifacts" / "route_comparison" / "cross" / f"janus_seed{s}_{src}_to_{tgt}.npz"
|
||||
z = np.load(npz, allow_pickle=True)
|
||||
bv = np.stack([z[f"b_{k}"] for k in SCORE_KEYS], axis=1)
|
||||
av = np.stack([z[f"a_{k}"] for k in SCORE_KEYS], axis=1)
|
||||
val_pool.append(bv)
|
||||
atk_pool.append(av)
|
||||
val = np.concatenate(val_pool, axis=0)
|
||||
atk = np.concatenate(atk_pool, axis=0)
|
||||
val = np.nan_to_num(val, nan=0.0, posinf=1e6, neginf=-1e6).astype(np.float64)
|
||||
atk = np.nan_to_num(atk, nan=0.0, posinf=1e6, neginf=-1e6).astype(np.float64)
|
||||
return val, atk
|
||||
|
||||
|
||||
def plot_collapse_diagnosis_overview(src: str = "cicddos2019",
|
||||
tgt: str = "cicids2017") -> Path:
|
||||
"""Render a compact two-panel SVG that visualises the source-likeness
|
||||
collapse (left: raw 1D terminal_norm under cross-dataset shift) and the
|
||||
Mahalanobis-OAS cure (right: aggregated d^2). Real cross-domain scores
|
||||
pooled across seeds 42-44.
|
||||
|
||||
Used as the overview figure component that bridges Stage 4 (score family)
|
||||
and Stage 5 (Mahal-OAS aggregator), supplying the diagnostic backbone of
|
||||
contribution C2 directly inside the architecture sketch.
|
||||
"""
|
||||
val, atk = _load_cross(src, tgt)
|
||||
y = np.r_[np.zeros(len(val)), np.ones(len(atk))]
|
||||
|
||||
# --- left panel: raw terminal_norm (1D NLL-style score) ---
|
||||
sv_raw = val[:, SCORE_KEYS.index("terminal_norm")]
|
||||
sa_raw = atk[:, SCORE_KEYS.index("terminal_norm")]
|
||||
auc_raw = roc_auc_score(y, np.r_[sv_raw, sa_raw])
|
||||
|
||||
# --- right panel: Mahal-OAS d^2 over the full 10-d score family ---
|
||||
mu, inv_cov, *_ = fit_oas(val)
|
||||
sv_mah = mahal(val, mu, inv_cov)
|
||||
sa_mah = mahal(atk, mu, inv_cov)
|
||||
auc_mah = roc_auc_score(y, np.r_[sv_mah, sa_mah])
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(5.8, 1.95), constrained_layout=False,
|
||||
gridspec_kw=dict(wspace=0.22))
|
||||
# No bottom-legend reservation; legend moves into the upper-left panel.
|
||||
fig.subplots_adjust(left=0.05, right=0.99, top=0.80, bottom=0.14)
|
||||
|
||||
def _kde_panel(ax, sv, sa, auc, log_x: bool, label_top: str):
|
||||
s = np.r_[sv, sa]
|
||||
if log_x:
|
||||
eps = max(1e-3, np.quantile(s[s > 0], 0.005) * 0.5) if (s > 0).any() else 1e-3
|
||||
sv_p = np.maximum(sv, eps)
|
||||
sa_p = np.maximum(sa, eps)
|
||||
lo = np.quantile(np.r_[sv_p, sa_p], 0.005)
|
||||
hi = np.quantile(np.r_[sv_p, sa_p], 0.999)
|
||||
bins = np.geomspace(max(lo, eps), hi, 60)
|
||||
mask_v = (sv_p >= lo) & (sv_p <= hi)
|
||||
mask_a = (sa_p >= lo) & (sa_p <= hi)
|
||||
sv_p, sa_p = sv_p[mask_v], sa_p[mask_a]
|
||||
ax.set_xscale("log")
|
||||
else:
|
||||
lo, hi = np.quantile(s, [0.005, 0.995])
|
||||
bins = np.linspace(lo, hi, 60)
|
||||
mask_v = (sv >= lo) & (sv <= hi)
|
||||
mask_a = (sa >= lo) & (sa <= hi)
|
||||
sv_p, sa_p = sv[mask_v], sa[mask_a]
|
||||
# Use raw count weighting so each class integrates to 1
|
||||
# (avoids leftover-mass spikes at clip edges).
|
||||
w_v = np.full_like(sv_p, 1.0 / max(len(sv_p), 1))
|
||||
w_a = np.full_like(sa_p, 1.0 / max(len(sa_p), 1))
|
||||
ax.hist(sv_p, bins=bins, color="#2c7fb8", alpha=0.65,
|
||||
weights=w_v, edgecolor="none")
|
||||
ax.hist(sa_p, bins=bins, color="#d7191c", alpha=0.65,
|
||||
weights=w_a, edgecolor="none")
|
||||
ax.set_yticks([])
|
||||
ax.tick_params(axis="x", labelsize=6.5, length=2, pad=1.5)
|
||||
for side in ("top", "right", "left"):
|
||||
ax.spines[side].set_visible(False)
|
||||
ax.spines["bottom"].set_linewidth(0.45)
|
||||
# State word top-center (one or two words describing the panel state).
|
||||
ax.text(
|
||||
0.5, 1.08, label_top,
|
||||
transform=ax.transAxes, ha="center", va="bottom", fontsize=9.0,
|
||||
color=("#a02a2a" if auc < 0.75 else "#1f6f3a"),
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
_kde_panel(axes[0], sv_raw, sa_raw, auc_raw, log_x=False,
|
||||
label_top="collapse")
|
||||
_kde_panel(axes[1], sv_mah, sa_mah, auc_mah, log_x=True,
|
||||
label_top="separated")
|
||||
|
||||
# Centred arrow between panels — fig-coordinates so it sits between axes.
|
||||
fig.text(0.5125, 0.48, r"$\Rightarrow$", ha="center", va="center",
|
||||
fontsize=18, color="#444")
|
||||
# Compact vertical legend in the upper-right of the LEFT panel.
|
||||
from matplotlib.patches import Patch
|
||||
legend_handles = [
|
||||
Patch(facecolor="#2c7fb8", alpha=0.65, label="benign-val"),
|
||||
Patch(facecolor="#d7191c", alpha=0.65, label="attack"),
|
||||
]
|
||||
axes[0].legend(
|
||||
handles=legend_handles,
|
||||
loc="upper right", ncol=1,
|
||||
fontsize=6.5, frameon=False,
|
||||
handlelength=0.9, handleheight=0.7,
|
||||
handletextpad=0.35, labelspacing=0.30,
|
||||
borderaxespad=0.4,
|
||||
)
|
||||
|
||||
out = OUT / f"collapse_diagnosis_overview_{src}_to_{tgt}.svg"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--which", choices=["all", "corr", "dual", "hist"], default="all")
|
||||
parser.add_argument("--which",
|
||||
choices=["all", "corr", "dual", "hist", "diag", "score"],
|
||||
default="all")
|
||||
args = parser.parse_args()
|
||||
OUT.mkdir(parents=True, exist_ok=True)
|
||||
mpl.rcParams.update({
|
||||
@@ -320,9 +660,19 @@ def main() -> None:
|
||||
if args.which in ("all", "dual"):
|
||||
p = plot_dual_head()
|
||||
print(f"[wrote] {p}")
|
||||
p_ov = plot_dual_head_overview()
|
||||
print(f"[wrote] {p_ov}")
|
||||
p_dfm = plot_dfm_head_overview()
|
||||
print(f"[wrote] {p_dfm}")
|
||||
if args.which in ("all", "hist"):
|
||||
p = plot_score_hist()
|
||||
print(f"[wrote] {p}")
|
||||
if args.which in ("all", "diag"):
|
||||
p = plot_collapse_diagnosis_overview()
|
||||
print(f"[wrote] {p}")
|
||||
if args.which in ("all", "score"):
|
||||
p = plot_score_family_overview()
|
||||
print(f"[wrote] {p}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -82,16 +82,17 @@ def plot_trajectory(npz_paths: dict[str, Path]) -> Path:
|
||||
)
|
||||
out = OUT / "fig4_trajectory_pca.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".svg"), bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def plot_velocity_norm(npz_paths: dict[str, Path]) -> Path:
|
||||
fig, axes = plt.subplots(1, len(npz_paths), figsize=(6.5 * len(npz_paths), 5.6), constrained_layout=True)
|
||||
fig, axes = plt.subplots(1, len(npz_paths), figsize=(6.5 * len(npz_paths), 5.6 * 2 / 3), constrained_layout=True)
|
||||
if len(npz_paths) == 1:
|
||||
axes = [axes]
|
||||
for ax, (ds, npz) in zip(axes, npz_paths.items()):
|
||||
for i, (ax, (ds, npz)) in enumerate(zip(axes, npz_paths.items())):
|
||||
z = np.load(npz)
|
||||
vn_v = z["vnorm_v"] # [n, n_steps]
|
||||
vn_a = z["vnorm_a"]
|
||||
@@ -105,13 +106,15 @@ def plot_velocity_norm(npz_paths: dict[str, Path]) -> Path:
|
||||
ax.fill_between(t_steps, m_v - s_v, m_v + s_v, color="#2c7fb8", alpha=0.18)
|
||||
ax.plot(t_steps, m_a, color="#d7191c", lw=1.6, label="attack mean")
|
||||
ax.fill_between(t_steps, m_a - s_a, m_a + s_a, color="#d7191c", alpha=0.18)
|
||||
ax.set_xlabel("CFM time t (1 = data → 0 = source)")
|
||||
ax.set_ylabel("‖v(x_t, t)‖ per real token (mean over flow)")
|
||||
ax.set_xlabel("CFM time t")
|
||||
if i == 0:
|
||||
ax.set_ylabel(r"Per-token CFM velocity magnitude $\|v_\theta(x_t, t)\|_2$")
|
||||
ax.text(0.02, 1.02, PRETTY[ds], transform=ax.transAxes, fontsize=11)
|
||||
ax.invert_xaxis() # so left is t=1 (data), right is t=0 (source)
|
||||
ax.legend(fontsize=8, loc="upper left", framealpha=0.85)
|
||||
out = OUT / "velocity_norm_vs_t_benign_vs_attack.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".svg"), bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user