diff --git a/.gitignore b/.gitignore index f1e1e67..aa69094 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,7 @@ Thumbs.db *.tmp -CLAUDE.md \ No newline at end of file +CLAUDE.md +.gitignore + +drafts/ diff --git a/README.md b/README.md index db8cec7..f50eda6 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,21 @@ JANUS is fully unsupervised (benign-only training, no attack labels at any stage Thresholded F1 metrics for JANUS across all four datasets are in `RESULTS.md` Section D. --> +### Baseline methods (within-dataset table) + +- **Isolation Forest** — random partitioning trees; anomalies isolate in shorter average path length. +- **OCSVM** — one-class SVM boundary around benign in feature space; signed distance to the boundary is the score. +- **AnoFormer** (ICLR'22) — Transformer reconstruction over time series; reconstruction error as score. +- **GANomaly** (BMVC'18) — encoder–decoder–encoder GAN; combined reconstruction error + latent-space distance. +- **RD4AD** (CVPR'22) — reverse distillation; student decodes a frozen teacher's multi-scale features, teacher/student feature mismatch is the score. +- **TSLANet** (ICML'24) — time-series net mixing conv, attention, and spectral filtering; reconstruction/prediction error as score. +- **ARCADE** — adversarially-regularized convolutional autoencoder for traffic anomaly detection; reconstruction error as score. +- **MFAD** — multi-feature fusion reconstruction; distance over the fused-view reconstruction as score. +- **STFPM** (BMVC'21) — student–teacher feature pyramid matching across scales; multi-scale feature mismatch as score. +- **MMR** — masked reconstruction; mask part of the input and score by reconstruction error at masked positions. +- **Shafir NF + Shapley** (arXiv'26) — Normalizing Flow on CICFlowMeter flow statistics with SHAP-selected top-5 features; negative log-likelihood as score. +- **ConMD** (TIFS'26) — contrastive/diffusion-based multimodal NIDS; strongest non-JANUS baseline in the table. + ### 3×3 cross-dataset transfer matrix Source (rows) trained on 10K benign of source dataset; target (columns) tested on full target benign + **all** target attacks. Aggregator fit on target benign val only — no attack labels at any stage. Diagonal italic = within-dataset. diff --git a/scripts/figures/plot_field_view.py b/scripts/figures/plot_field_view.py index 44f92cc..13533b1 100644 --- a/scripts/figures/plot_field_view.py +++ b/scripts/figures/plot_field_view.py @@ -105,11 +105,52 @@ def plot_one(npz: Path, dataset: str) -> Path: out = OUT / f"velocity_field_view_{dataset.lower()}.pdf" fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".svg"), bbox_inches="tight") fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160) plt.close(fig) return out +def plot_one_overview(npz: Path, dataset: str) -> Path: + """Render a clean single-panel velocity-field SVG for use as the overview- + figure component 03 (CFM head). Training-phase visualization only: + log-norm heatmap + white streamlines + benign t=0.5 cloud. No attacks, + no axes / colorbar / title (the surrounding overview wrapper supplies + those). Outputs both SVG and PDF for LaTeX flexibility. + """ + z = np.load(npz) + GX = z["grid_x"] + GY = z["grid_y"] + field_log = z["field_log_norm"] + field_v = z["field_v_2d"] + benign_t05 = z["benign_t05_2d"] + + fig, ax = plt.subplots(figsize=(3.0, 2.6), constrained_layout=True) + vmin, vmax = np.percentile(field_log, [5, 95]) + ax.pcolormesh(GX, GY, field_log, cmap="viridis", shading="auto", + vmin=vmin, vmax=vmax, rasterized=True) + speed = np.linalg.norm(field_v, axis=-1) + lw = 0.35 + 1.5 * (speed / (speed.max() + 1e-9)) + ax.streamplot(GX, GY, field_v[..., 0], field_v[..., 1], + color="white", linewidth=lw, density=0.85, arrowsize=0.7) + n_overlay = min(200, benign_t05.shape[0]) + rng = np.random.default_rng(0) + idx_ov = rng.choice(benign_t05.shape[0], n_overlay, replace=False) + ax.scatter(benign_t05[idx_ov, 0], benign_t05[idx_ov, 1], + s=2.5, c="white", alpha=0.55, edgecolors="black", + linewidths=0.12, rasterized=True, zorder=4) + ax.set_xticks([]) + ax.set_yticks([]) + for spine in ax.spines.values(): + spine.set_visible(False) + + out = OUT / f"velocity_field_overview_{dataset.lower()}.svg" + fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight") + plt.close(fig) + return out + + def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--datasets", nargs="+", @@ -126,6 +167,8 @@ def main() -> None: continue p = plot_one(npz, pretty.get(ds, ds)) print(f"[wrote] {p}") + p_ov = plot_one_overview(npz, pretty.get(ds, ds)) + print(f"[wrote] {p_ov}") if __name__ == "__main__": diff --git a/scripts/figures/plot_mechanism.py b/scripts/figures/plot_mechanism.py index 8f17f6a..f14fe87 100644 --- a/scripts/figures/plot_mechanism.py +++ b/scripts/figures/plot_mechanism.py @@ -97,6 +97,7 @@ def plot_corr_heatmap() -> Path: cbar.set_label("Pearson ρ on benign val", fontsize=10) out = OUT / "subscore_correlation_benign_val.pdf" fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".svg"), bbox_inches="tight") fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160) plt.close(fig) return out @@ -211,11 +212,227 @@ def plot_dual_head() -> Path: out = OUT / "dual_head_oas_ellipses_top__whitened_pca_bottom.pdf" fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".svg"), bbox_inches="tight") fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160) plt.close(fig) return out +def plot_dual_head_overview(dataset: str = "cicddos2019") -> Path: + """Render a clean single-panel OAS-ellipse SVG for use as overview-figure + component 06 (Mahalanobis-OAS aggregator). + + Visual: benign as a smooth 2D KDE density blob (blue, 6 filled + contour levels), attacks as sparse bright red dots with white halos + clearly outside the dense benign region, and 1/2/3-sigma OAS-Mahalanobis + ellipses overlaid on top in bold black. The visual story: the + aggregator (ellipses) is fit on the dense benign cloud; attacks at + inference fall outside the ellipses, which is what makes $d^2$ a + useful anomaly score. + """ + from scipy.stats import gaussian_kde + + val, atk = load_scores(dataset) + rng = np.random.default_rng(0) + fig, ax = plt.subplots(figsize=(3.0, 2.6), constrained_layout=True) + + x_v, y_v = val[:, 0], val[:, 3] + x_a, y_a = atk[:, 0], atk[:, 3] + + nv = min(5000, len(x_v)) + iv = rng.choice(len(x_v), nv, replace=False) + na = min(120, len(x_a)) # fewer, brighter attack dots for visibility + ia = rng.choice(len(x_a), na, replace=False) + + # View window: capture 99% of benign + 95% of attack + x_lo = min(np.quantile(x_v, 0.005), np.quantile(x_a, 0.05)) + x_hi = max(np.quantile(x_v, 0.995), np.quantile(x_a, 0.95)) + y_lo = min(np.quantile(y_v, 0.005), np.quantile(y_a, 0.05)) + y_hi = max(np.quantile(y_v, 0.995), np.quantile(y_a, 0.95)) + pad_x = 0.05 * (x_hi - x_lo) + pad_y = 0.05 * (y_hi - y_lo) + xlim = (x_lo - pad_x, x_hi + pad_x) + ylim = (y_lo - pad_y, y_hi + pad_y) + + # Benign 2D KDE density blob + kde = gaussian_kde(np.vstack([x_v[iv], y_v[iv]])) + xx, yy = np.meshgrid(np.linspace(*xlim, 90), np.linspace(*ylim, 90)) + grid = np.vstack([xx.ravel(), yy.ravel()]) + density = kde(grid).reshape(xx.shape) + # Drop the lowest-density floor (clip near-zero edge artefacts) + floor = np.quantile(density, 0.55) + levels = np.linspace(floor, density.max() * 0.97, 6) + ax.contourf(xx, yy, density, levels=levels, cmap="Blues", alpha=0.92, zorder=1) + + # Attack scatter (sparse, bright, white halo for crispness) + ax.scatter(x_a[ia], y_a[ia], s=11, c="#d7191c", + edgecolors="white", linewidth=0.5, alpha=0.95, zorder=3) + + # OAS Mahalanobis ellipses on top: bold black + XY_v = val[:, [0, 3]] + oas2 = OAS().fit(XY_v) + mu2 = XY_v.mean(axis=0) + for ns, ls in [(1, "-"), (2, "--"), (3, ":")]: + e = _ellipse_from_2x2( + mu2, oas2.covariance_, ns, + edgecolor="black", facecolor="none", lw=1.3, ls=ls, alpha=0.92, + zorder=5, + ) + ax.add_patch(e) + + ax.set_xlim(*xlim) + ax.set_ylim(*ylim) + ax.set_xticks([]) + ax.set_yticks([]) + for spine in ax.spines.values(): + spine.set_visible(False) + + out = OUT / f"oas_ellipse_overview_{dataset.lower()}.svg" + fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight") + plt.close(fig) + return out + + +def _benign_disc_p1(dataset: str) -> np.ndarray: + """Empirical per-channel P(c_i = 1) over all benign packets in the + processed dataset. Returns shape [6] for (dir, SYN, FIN, RST, PSH, ACK). + """ + import pandas as pd + data_root = ROOT / "datasets" / dataset / "processed" + pkts = np.load(data_root / "packets.npz") + packet_tokens = pkts["packet_tokens"] # [N, T_full, 9] + packet_lengths = pkts["packet_lengths"] # [N] + + flows_df = pd.read_parquet(data_root / "flows.parquet") + label_norm = flows_df["label"].astype(str).str.strip().str.lower() + benign_aliases = {"benign", "normal"} + benign_mask = label_norm.isin(benign_aliases).values + benign_idx = np.where(benign_mask)[0] + + if benign_idx.size == 0: + return np.zeros(6, dtype=np.float64) + + T_full = packet_tokens.shape[1] + valid = np.arange(T_full)[None, :] < packet_lengths[benign_idx, None] # [Nb, T] + # Disc channels live at indices 2..7 of the canonical 9-d packet schema. + disc = packet_tokens[benign_idx, :, 2:8].astype(np.float64) # [Nb, T, 6] + masked_sum = (disc * valid[..., None]).sum(axis=(0, 1)) # [6] + total = float(valid.sum()) + return masked_sum / max(total, 1.0) + + +def plot_dfm_head_overview(dataset: str = "cicids2017") -> Path: + """Render a clean single-panel DFM head SVG for use as overview-figure + component 04. Six paired bars (P(c=0) light / P(c=1) dark) show the + empirical categorical distribution of the six binary packet channels + (direction + five TCP flags) on benign packets — the distribution the + DFM head is trained to model. Training-phase visualization: benign-only. + + Default uses CICIDS2017 because (a) it ships a flat `packets.npz` that + this helper reads directly, and (b) its 51/49 TCP/UDP split exercises + the full range of flag distributions in a way UDP-heavy CICDDoS2019 + or TCP-heavy CICIoT2023 do not. + """ + p_c1 = _benign_disc_p1(dataset) + p_c0 = 1.0 - p_c1 + channel_labels = ["dir", "SYN", "FIN", "RST", "PSH", "ACK"] + + fig, ax = plt.subplots(figsize=(3.2, 2.4), constrained_layout=True) + x = np.arange(6, dtype=float) + bar_w = 0.36 + ax.bar(x - bar_w / 2 - 0.02, p_c0, bar_w, + color="#F4C58A", edgecolor="#a85518", linewidth=0.5, + label=r"$P(c_i{=}0)$") + ax.bar(x + bar_w / 2 + 0.02, p_c1, bar_w, + color="#A85518", edgecolor="#5a2f0e", linewidth=0.5, + label=r"$P(c_i{=}1)$") + + # Reference line at y=0.5 + ax.axhline(0.5, color="#888", lw=0.4, ls="--", alpha=0.55, zorder=0) + + ax.set_xticks(x) + ax.set_xticklabels(channel_labels, fontsize=7.5) + ax.set_ylim(0, 1.08) + ax.set_yticks([]) + ax.tick_params(axis="x", length=0, pad=2) + for side in ("top", "right", "left"): + ax.spines[side].set_visible(False) + ax.spines["bottom"].set_linewidth(0.45) + + ax.legend( + loc="upper right", fontsize=6.5, frameon=False, + ncol=2, bbox_to_anchor=(1.00, 1.13), + handletextpad=0.35, columnspacing=0.8, + ) + + out = OUT / f"dfm_head_overview_{dataset.lower()}.svg" + fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight") + plt.close(fig) + return out + + +def plot_score_family_overview(dataset: str = "cicids2017") -> Path: + """Render a clean single-panel score-family SVG for use as overview-figure + component 05 (the 10-d score vector $s(x)$ between the heads and the + aggregator). Replaces the schematic 10-cell row with real data: each of + the 10 sub-scores becomes one vertical bar whose height is the *attack + median z-score* relative to benign val (i.e., how many benign-std units + the typical attack is shifted on this score). + + Layout: + - 3 cBlue bars on the left (term3, CFM-head scores) + - 7 cOrange bars on the right (disc7, DFM-head scores) + - Group brackets and small labels above each group. + - Score-name x-tick labels below each bar (rotated 30°). + - Faint benign reference line at z=0. + """ + val, atk = load_scores(dataset) + mu = val.mean(axis=0) + sd = val.std(axis=0) + 1e-9 + # z-normalised attacks: median over the attack class per score. + z_atk = np.median((atk - mu) / sd, axis=0) + # Same for benign val (sanity: should be ~0). + z_val = np.median((val - mu) / sd, axis=0) + + # CFM head fill = #FFF2CC (drawio yellow), DFM head fill = #D5E8D4 (drawio green). + # Use the matching darker shade for the edge so bars are still visible. + cfm_fill, cfm_edge = "#FFF2CC", "#D6B656" + dfm_fill, dfm_edge = "#D5E8D4", "#82B366" + fills = [cfm_fill] * 3 + [dfm_fill] * 7 + edges = [cfm_edge] * 3 + [dfm_edge] * 7 + + fig, ax = plt.subplots(figsize=(3.6, 2.0), constrained_layout=True) + x = np.arange(10, dtype=float) + bar_w = 0.72 + ax.bar(x, z_atk, bar_w, color=fills, + edgecolor=edges, linewidth=0.9, zorder=3) + + # Faint benign reference line at z=0. + ax.axhline(0.0, color="#888", lw=0.5, ls="--", alpha=0.7, zorder=1) + + # No x-tick labels, no top bracket annotations: clean bars only. + ax.set_xticks([]) + ax.set_yticks([]) + ax.tick_params(axis="x", length=0, pad=2) + ax.set_xlim(-0.6, 9.6) + for side in ("top", "right", "left"): + ax.spines[side].set_visible(False) + ax.spines["bottom"].set_linewidth(0.45) + + # Y-limits: keep small headroom but no need for bracket clearance now. + hi = float(max(z_atk.max() * 1.10, 1.0)) + lo = float(min(z_atk.min() * 1.10, -0.05)) + ax.set_ylim(lo, hi) + + out = OUT / f"score_family_overview_{dataset.lower()}.svg" + fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight") + plt.close(fig) + return out + + def plot_score_hist() -> Path: fig, axes = plt.subplots(4, 4, figsize=(16, 12), constrained_layout=True) for col, ds in enumerate(DATASETS): @@ -265,6 +482,7 @@ def plot_score_hist() -> Path: axes[0, 3].legend(loc="upper right", fontsize=8, framealpha=0.85) out = OUT / "score_distributions_raw__termOAS__discOAS__allOAS.pdf" fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".svg"), bbox_inches="tight") fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160) plt.close(fig) return out @@ -302,9 +520,131 @@ def _hist_panel(ax, sv, sa, log_x: bool = False): ax.set_yticks([]) +def _load_cross(src: str, tgt: str, seeds=(42, 43, 44)) -> tuple[np.ndarray, np.ndarray]: + """Load 10-d score vectors for the (src→tgt) cross-domain pair, pooled + across seeds. b_* are benign-val from the source training domain; + a_* are attacks from the target test domain. + """ + val_pool, atk_pool = [], [] + for s in seeds: + npz = ROOT / "artifacts" / "route_comparison" / "cross" / f"janus_seed{s}_{src}_to_{tgt}.npz" + z = np.load(npz, allow_pickle=True) + bv = np.stack([z[f"b_{k}"] for k in SCORE_KEYS], axis=1) + av = np.stack([z[f"a_{k}"] for k in SCORE_KEYS], axis=1) + val_pool.append(bv) + atk_pool.append(av) + val = np.concatenate(val_pool, axis=0) + atk = np.concatenate(atk_pool, axis=0) + val = np.nan_to_num(val, nan=0.0, posinf=1e6, neginf=-1e6).astype(np.float64) + atk = np.nan_to_num(atk, nan=0.0, posinf=1e6, neginf=-1e6).astype(np.float64) + return val, atk + + +def plot_collapse_diagnosis_overview(src: str = "cicddos2019", + tgt: str = "cicids2017") -> Path: + """Render a compact two-panel SVG that visualises the source-likeness + collapse (left: raw 1D terminal_norm under cross-dataset shift) and the + Mahalanobis-OAS cure (right: aggregated d^2). Real cross-domain scores + pooled across seeds 42-44. + + Used as the overview figure component that bridges Stage 4 (score family) + and Stage 5 (Mahal-OAS aggregator), supplying the diagnostic backbone of + contribution C2 directly inside the architecture sketch. + """ + val, atk = _load_cross(src, tgt) + y = np.r_[np.zeros(len(val)), np.ones(len(atk))] + + # --- left panel: raw terminal_norm (1D NLL-style score) --- + sv_raw = val[:, SCORE_KEYS.index("terminal_norm")] + sa_raw = atk[:, SCORE_KEYS.index("terminal_norm")] + auc_raw = roc_auc_score(y, np.r_[sv_raw, sa_raw]) + + # --- right panel: Mahal-OAS d^2 over the full 10-d score family --- + mu, inv_cov, *_ = fit_oas(val) + sv_mah = mahal(val, mu, inv_cov) + sa_mah = mahal(atk, mu, inv_cov) + auc_mah = roc_auc_score(y, np.r_[sv_mah, sa_mah]) + + fig, axes = plt.subplots(1, 2, figsize=(5.8, 1.95), constrained_layout=False, + gridspec_kw=dict(wspace=0.22)) + # No bottom-legend reservation; legend moves into the upper-left panel. + fig.subplots_adjust(left=0.05, right=0.99, top=0.80, bottom=0.14) + + def _kde_panel(ax, sv, sa, auc, log_x: bool, label_top: str): + s = np.r_[sv, sa] + if log_x: + eps = max(1e-3, np.quantile(s[s > 0], 0.005) * 0.5) if (s > 0).any() else 1e-3 + sv_p = np.maximum(sv, eps) + sa_p = np.maximum(sa, eps) + lo = np.quantile(np.r_[sv_p, sa_p], 0.005) + hi = np.quantile(np.r_[sv_p, sa_p], 0.999) + bins = np.geomspace(max(lo, eps), hi, 60) + mask_v = (sv_p >= lo) & (sv_p <= hi) + mask_a = (sa_p >= lo) & (sa_p <= hi) + sv_p, sa_p = sv_p[mask_v], sa_p[mask_a] + ax.set_xscale("log") + else: + lo, hi = np.quantile(s, [0.005, 0.995]) + bins = np.linspace(lo, hi, 60) + mask_v = (sv >= lo) & (sv <= hi) + mask_a = (sa >= lo) & (sa <= hi) + sv_p, sa_p = sv[mask_v], sa[mask_a] + # Use raw count weighting so each class integrates to 1 + # (avoids leftover-mass spikes at clip edges). + w_v = np.full_like(sv_p, 1.0 / max(len(sv_p), 1)) + w_a = np.full_like(sa_p, 1.0 / max(len(sa_p), 1)) + ax.hist(sv_p, bins=bins, color="#2c7fb8", alpha=0.65, + weights=w_v, edgecolor="none") + ax.hist(sa_p, bins=bins, color="#d7191c", alpha=0.65, + weights=w_a, edgecolor="none") + ax.set_yticks([]) + ax.tick_params(axis="x", labelsize=6.5, length=2, pad=1.5) + for side in ("top", "right", "left"): + ax.spines[side].set_visible(False) + ax.spines["bottom"].set_linewidth(0.45) + # State word top-center (one or two words describing the panel state). + ax.text( + 0.5, 1.08, label_top, + transform=ax.transAxes, ha="center", va="bottom", fontsize=9.0, + color=("#a02a2a" if auc < 0.75 else "#1f6f3a"), + fontweight="bold", + ) + + _kde_panel(axes[0], sv_raw, sa_raw, auc_raw, log_x=False, + label_top="collapse") + _kde_panel(axes[1], sv_mah, sa_mah, auc_mah, log_x=True, + label_top="separated") + + # Centred arrow between panels — fig-coordinates so it sits between axes. + fig.text(0.5125, 0.48, r"$\Rightarrow$", ha="center", va="center", + fontsize=18, color="#444") + # Compact vertical legend in the upper-right of the LEFT panel. + from matplotlib.patches import Patch + legend_handles = [ + Patch(facecolor="#2c7fb8", alpha=0.65, label="benign-val"), + Patch(facecolor="#d7191c", alpha=0.65, label="attack"), + ] + axes[0].legend( + handles=legend_handles, + loc="upper right", ncol=1, + fontsize=6.5, frameon=False, + handlelength=0.9, handleheight=0.7, + handletextpad=0.35, labelspacing=0.30, + borderaxespad=0.4, + ) + + out = OUT / f"collapse_diagnosis_overview_{src}_to_{tgt}.svg" + fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight") + plt.close(fig) + return out + + def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("--which", choices=["all", "corr", "dual", "hist"], default="all") + parser.add_argument("--which", + choices=["all", "corr", "dual", "hist", "diag", "score"], + default="all") args = parser.parse_args() OUT.mkdir(parents=True, exist_ok=True) mpl.rcParams.update({ @@ -320,9 +660,19 @@ def main() -> None: if args.which in ("all", "dual"): p = plot_dual_head() print(f"[wrote] {p}") + p_ov = plot_dual_head_overview() + print(f"[wrote] {p_ov}") + p_dfm = plot_dfm_head_overview() + print(f"[wrote] {p_dfm}") if args.which in ("all", "hist"): p = plot_score_hist() print(f"[wrote] {p}") + if args.which in ("all", "diag"): + p = plot_collapse_diagnosis_overview() + print(f"[wrote] {p}") + if args.which in ("all", "score"): + p = plot_score_family_overview() + print(f"[wrote] {p}") if __name__ == "__main__": diff --git a/scripts/figures/plot_trajectory.py b/scripts/figures/plot_trajectory.py index bd956b8..da5ba94 100644 --- a/scripts/figures/plot_trajectory.py +++ b/scripts/figures/plot_trajectory.py @@ -82,16 +82,17 @@ def plot_trajectory(npz_paths: dict[str, Path]) -> Path: ) out = OUT / "fig4_trajectory_pca.pdf" fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".svg"), bbox_inches="tight") fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160) plt.close(fig) return out def plot_velocity_norm(npz_paths: dict[str, Path]) -> Path: - fig, axes = plt.subplots(1, len(npz_paths), figsize=(6.5 * len(npz_paths), 5.6), constrained_layout=True) + fig, axes = plt.subplots(1, len(npz_paths), figsize=(6.5 * len(npz_paths), 5.6 * 2 / 3), constrained_layout=True) if len(npz_paths) == 1: axes = [axes] - for ax, (ds, npz) in zip(axes, npz_paths.items()): + for i, (ax, (ds, npz)) in enumerate(zip(axes, npz_paths.items())): z = np.load(npz) vn_v = z["vnorm_v"] # [n, n_steps] vn_a = z["vnorm_a"] @@ -105,13 +106,15 @@ def plot_velocity_norm(npz_paths: dict[str, Path]) -> Path: ax.fill_between(t_steps, m_v - s_v, m_v + s_v, color="#2c7fb8", alpha=0.18) ax.plot(t_steps, m_a, color="#d7191c", lw=1.6, label="attack mean") ax.fill_between(t_steps, m_a - s_a, m_a + s_a, color="#d7191c", alpha=0.18) - ax.set_xlabel("CFM time t (1 = data → 0 = source)") - ax.set_ylabel("‖v(x_t, t)‖ per real token (mean over flow)") + ax.set_xlabel("CFM time t") + if i == 0: + ax.set_ylabel(r"Per-token CFM velocity magnitude $\|v_\theta(x_t, t)\|_2$") ax.text(0.02, 1.02, PRETTY[ds], transform=ax.transAxes, fontsize=11) ax.invert_xaxis() # so left is t=1 (data), right is t=0 (source) ax.legend(fontsize=8, loc="upper left", framealpha=0.85) out = OUT / "velocity_norm_vs_t_benign_vs_attack.pdf" fig.savefig(out, bbox_inches="tight") + fig.savefig(out.with_suffix(".svg"), bbox_inches="tight") fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160) plt.close(fig) return out