def _set_lim(ax, x, y, pad=0.08):
    """Fit the axes limits to the data with a fractional margin.

    Args:
        ax: Axes-like object exposing ``set_xlim`` and ``set_ylim``.
        x: Array of x coordinates that must stay visible.
        y: Array of y coordinates that must stay visible.
        pad: Margin added on each side, as a fraction of the data span.
    """
    xlo, xhi = x.min(), x.max()
    ylo, yhi = y.min(), y.max()
    sx, sy = xhi - xlo, yhi - ylo
    # A zero span (single point / constant coordinate) would yield identical
    # lower and upper limits; fall back to a unit span so the margin is real.
    if sx == 0:
        sx = 1.0
    if sy == 0:
        sy = 1.0
    ax.set_xlim(xlo - pad * sx, xhi + pad * sx)
    ax.set_ylim(ylo - pad * sy, yhi + pad * sy)


def plot_one(npz: Path, dataset: str) -> Path:
    """Render the 3-panel field view for one dataset and save it as PDF + PNG.

    Reads the arrays ``grid_x``, ``grid_y``, ``field_log_norm``, ``field_v_2d``,
    ``benign_t1_2d``, ``benign_t05_2d``, ``benign_t0_2d``, ``reverse_a_2d``,
    ``forward_v_2d`` and ``pca_explained_var`` from *npz*.

    Args:
        npz: Path of the ``.npz`` archive to plot.
        dataset: Display name; used in panel captions and (lower-cased) in the
            output file name.

    Returns:
        Path of the written PDF (a PNG with the same stem is written next to it).
    """
    # np.load on an .npz returns a lazy NpzFile that holds the archive open;
    # pull every array out inside a context manager so the handle is released.
    with np.load(npz) as z:
        GX = z["grid_x"]
        GY = z["grid_y"]
        field_log = z["field_log_norm"]
        field_v = z["field_v_2d"]
        benign_t1 = z["benign_t1_2d"]
        benign_t05 = z["benign_t05_2d"]
        benign_t0 = z["benign_t0_2d"]
        ra = z["reverse_a_2d"]
        fw = z["forward_v_2d"]
        ev = z["pca_explained_var"]

    fig = plt.figure(figsize=(15.5, 5.0), constrained_layout=True)
    gs = fig.add_gridspec(1, 3, width_ratios=[1.05, 1, 1])

    # ========== L: velocity field heatmap + streamplot ==========
    axL = fig.add_subplot(gs[0, 0])
    # Clip the colour range to the 5th-95th percentile so outliers do not
    # wash out the heatmap.
    vmin, vmax = np.percentile(field_log, [5, 95])
    pcm = axL.pcolormesh(GX, GY, field_log, cmap="viridis", shading="auto",
                         vmin=vmin, vmax=vmax, rasterized=True)
    cbar = fig.colorbar(pcm, ax=axL, shrink=0.85, pad=0.02)
    cbar.set_label(r"$\log_{10}\|v(x_t,t{=}0.5)\|$ (full token)", fontsize=8)
    cbar.ax.tick_params(labelsize=7)
    # streamlines: width varies with local speed
    speed = np.linalg.norm(field_v, axis=-1)
    lw = 0.35 + 1.6 * (speed / (speed.max() + 1e-9))
    axL.streamplot(GX, GY, field_v[..., 0], field_v[..., 1],
                   color="white", linewidth=lw, density=1.4, arrowsize=0.7)
    # sparse benign t=0.5 cloud overlay (light, doesn't drown out heatmap)
    n_overlay = min(300, benign_t05.shape[0])
    rng = np.random.default_rng(0)  # fixed seed: same overlay subset every run
    idx_ov = rng.choice(benign_t05.shape[0], n_overlay, replace=False)
    axL.scatter(benign_t05[idx_ov, 0], benign_t05[idx_ov, 1],
                s=3, c="white", alpha=0.55, edgecolors="black",
                linewidths=0.15, rasterized=True, zorder=4)
    axL.set_xlabel(f"PC1 ({100*ev[0]:.1f}%)")
    axL.set_ylabel(f"PC2 ({100*ev[1]:.1f}%)")
    axL.text(0.02, 1.02, f"{dataset} · velocity field at t=0.5",
             transform=axL.transAxes, fontsize=10)

    # ========== M: attack reverse trajectories over benign t=1 cloud ==========
    axM = fig.add_subplot(gs[0, 1])
    axM.scatter(benign_t1[:, 0], benign_t1[:, 1], s=6, c="#a6cee3", alpha=0.55,
                edgecolors="none", label="benign cloud (t=1)", rasterized=True)
    for traj in ra:
        axM.plot(traj[:, 0], traj[:, 1], color="#d7191c", lw=0.55, alpha=0.55)
    # First snapshot is labelled as the t=1 start, last as the t=0 end.
    axM.scatter(ra[:, 0, 0], ra[:, 0, 1], s=14, c="#d7191c", marker="o",
                edgecolors="white", linewidths=0.4, label="attack t=1 (start)", zorder=3)
    axM.scatter(ra[:, -1, 0], ra[:, -1, 1], s=18, c="#d7191c", marker="x",
                linewidths=1.0, label="attack t=0 (end)", zorder=3)
    axM.legend(loc="upper left", bbox_to_anchor=(0.0, -0.12), ncol=3,
               fontsize=7, framealpha=0.85, borderaxespad=0.0)
    _set_lim(axM,
             np.r_[benign_t1[:, 0], ra[..., 0].ravel()],
             np.r_[benign_t1[:, 1], ra[..., 1].ravel()])
    axM.set_xlabel("PC1")
    axM.text(0.02, 1.02, f"{dataset} · attack reverse trajectories t=1→0",
             transform=axM.transAxes, fontsize=10)

    # ========== R: forward generation cloud comparison ==========
    axR = fig.add_subplot(gs[0, 2])
    gen = fw[:, -1, :]  # generated samples (t=1 endpoints)
    axR.scatter(benign_t0[:, 0], benign_t0[:, 1], s=6, c="#888888", alpha=0.40,
                edgecolors="none", label="N(0,I) at t=0", rasterized=True)
    axR.scatter(benign_t1[:, 0], benign_t1[:, 1], s=8, c="#1f78b4", alpha=0.55,
                edgecolors="none", label="benign cloud (t=1)", rasterized=True)
    axR.scatter(gen[:, 0], gen[:, 1], s=12, c="#33a02c", alpha=0.75,
                edgecolors="white", linewidths=0.3,
                label="generated (forward t=0→1)", rasterized=True)
    axR.legend(loc="upper left", bbox_to_anchor=(0.0, -0.12), ncol=3,
               fontsize=7, framealpha=0.85, borderaxespad=0.0)
    _set_lim(axR,
             np.r_[benign_t1[:, 0], benign_t0[:, 0], gen[:, 0]],
             np.r_[benign_t1[:, 1], benign_t0[:, 1], gen[:, 1]])
    axR.set_xlabel("PC1")
    axR.text(0.02, 1.02, f"{dataset} · forward generation vs benign cloud",
             transform=axR.transAxes, fontsize=10)

    out = OUT / f"velocity_field_view_{dataset.lower()}.pdf"
    fig.savefig(out, bbox_inches="tight")
    fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
    plt.close(fig)
    return out


def main() -> None:
    """Plot the field view for every requested dataset whose npz exists.

    Datasets come from ``--datasets``; a dataset whose ``field_<name>.npz`` is
    missing under ``OUT`` is reported and skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--datasets", nargs="+",
                        default=["cicids2017", "cicddos2019", "iscxtor2016", "ciciot2023"])
    args = parser.parse_args()
    OUT.mkdir(parents=True, exist_ok=True)
    # fonttype 42 embeds TrueType fonts in the PDF/PS output.
    mpl.rcParams.update({"font.size": 9, "pdf.fonttype": 42, "ps.fonttype": 42})
    pretty = {"cicids2017": "CICIDS2017", "cicddos2019": "CICDDoS2019",
              "iscxtor2016": "ISCXTor2016", "ciciot2023": "CICIoT2023"}
    for ds in args.datasets:
        npz = OUT / f"field_{ds}.npz"
        if not npz.exists():
            print(f"[skip] missing {npz}")
            continue
        # Unknown dataset names fall back to the raw CLI string.
        p = plot_one(npz, pretty.get(ds, ds))
        print(f"[wrote] {p}")


if __name__ == "__main__":
    main()
DATASETS = ["iscxtor2016", "cicids2017", "cicddos2019", "ciciot2023"]
PRETTY = {
    "iscxtor2016": "ISCXTor2016",
    "cicids2017": "CICIDS2017",
    "cicddos2019": "CICDDoS2019",
    "ciciot2023": "CICIoT2023",
}

# Column order of the score matrices returned by load_scores().
SCORE_KEYS = [
    "terminal_norm", "terminal_flow", "terminal_packet",
    "disc_nll_total",
    "disc_nll_ch2", "disc_nll_ch3", "disc_nll_ch4",
    "disc_nll_ch5", "disc_nll_ch6", "disc_nll_ch7",
]
# Tick labels, one per SCORE_KEYS entry. (An earlier mathtext variant of this
# list was immediately overwritten and has been removed as dead code.)
SCORE_LABELS = [
    "term_norm", "term_flow", "term_pkt",
    "disc_total", "disc_ch2", "disc_ch3", "disc_ch4",
    "disc_ch5", "disc_ch6", "disc_ch7",
]


def load_scores(dataset: str, seed: int = 42) -> tuple[np.ndarray, np.ndarray]:
    """Load the benign-validation and attack score matrices for one dataset.

    Reads ``RUNS / "janus_<dataset>_seed<seed>" / "phase1_scores.npz"`` and
    stacks the ``val_<key>`` / ``atk_<key>`` arrays in ``SCORE_KEYS`` order.

    Args:
        dataset: Dataset name used in the run-directory name.
        seed: Seed used in the run-directory name.

    Returns:
        ``(val_S, atk_S)``, float64 arrays with one column per ``SCORE_KEYS``
        entry. NaN is replaced by 0 and +/-inf by +/-1e6.
    """
    npz = RUNS / f"janus_{dataset}_seed{seed}" / "phase1_scores.npz"
    # NOTE(review): allow_pickle=True executes pickled objects from the archive;
    # only safe because the file is a locally produced artifact. Confirm whether
    # the archive really contains object arrays before relaxing/removing it.
    with np.load(npz, allow_pickle=True) as z:
        val = np.stack([z[f"val_{k}"] for k in SCORE_KEYS], axis=1)
        atk = np.stack([z[f"atk_{k}"] for k in SCORE_KEYS], axis=1)
    val = np.nan_to_num(val, nan=0.0, posinf=1e6, neginf=-1e6).astype(np.float64)
    atk = np.nan_to_num(atk, nan=0.0, posinf=1e6, neginf=-1e6).astype(np.float64)
    return val, atk


def fit_oas(val_S: np.ndarray):
    """Fit an OAS covariance on benign validation scores.

    Args:
        val_S: Score matrix, one row per sample.

    Returns:
        ``(mu, inv_cov, cov, Linv)`` where ``mu`` is the column mean, ``cov``
        the OAS covariance, ``inv_cov`` the inverse of the ridge-regularised
        covariance and ``Linv`` the inverse of its Cholesky factor, so that
        ``(x - mu) @ Linv.T`` whitens ``x``.
    """
    mu = val_S.mean(axis=0)
    cov = OAS().fit(val_S).covariance_
    # Tiny ridge keeps the inverse / Cholesky numerically well defined.
    cov_reg = cov + 1e-9 * np.eye(cov.shape[0])
    inv_cov = np.linalg.inv(cov_reg)
    # whitening: x_w = L^{-1} (x - mu) where cov = L L^T (Cholesky)
    L = np.linalg.cholesky(cov_reg)
    Linv = np.linalg.solve(L, np.eye(L.shape[0]))
    return mu, inv_cov, cov, Linv


def mahal(S: np.ndarray, mu: np.ndarray, inv_cov: np.ndarray) -> np.ndarray:
    """Return the squared Mahalanobis distance of every row of *S* from *mu*."""
    d = S - mu
    return np.einsum("ni,ij,nj->n", d, inv_cov, d)


def plot_corr_heatmap() -> Path:
    """Plot the sub-score Pearson correlation heatmap for every dataset.

    Returns:
        Path of the written PDF (a PNG with the same stem is written too).
    """
    fig, axes = plt.subplots(1, 4, figsize=(18, 4.6), constrained_layout=True)
    for ax, ds in zip(axes, DATASETS):
        val, _ = load_scores(ds)
        # Pearson correlation on benign val; mask diagonals to free visual budget
        C = np.corrcoef(val, rowvar=False)
        np.fill_diagonal(C, np.nan)
        im = ax.imshow(C, vmin=-1, vmax=1, cmap="RdBu_r")
        ax.set_xticks(range(len(SCORE_LABELS)))
        ax.set_yticks(range(len(SCORE_LABELS)))
        ax.set_xticklabels(SCORE_LABELS, rotation=60, ha="right", fontsize=7)
        ax.set_yticklabels(SCORE_LABELS, fontsize=7)
        # Mean |rho| over the non-NaN (i.e. off-diagonal, defined) entries.
        off = C[~np.isnan(C)]
        ax.text(
            0.02, 1.06, f"{PRETTY[ds]}  ⟨|ρ|⟩={np.abs(off).mean():.2f}",
            transform=ax.transAxes, fontsize=10,
        )
    cbar = fig.colorbar(im, ax=axes, shrink=0.85, location="right", pad=0.01)
    cbar.set_label("Pearson ρ on benign val", fontsize=10)
    out = OUT / "subscore_correlation_benign_val.pdf"
    fig.savefig(out, bbox_inches="tight")
    fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
    plt.close(fig)
    return out


def _ellipse_from_2x2(mu2, cov2, n_sigma, **kw):
    """Build the ``n_sigma`` covariance ellipse of a 2x2 covariance matrix.

    The major axis follows the eigenvector of the largest eigenvalue; extra
    keyword arguments are forwarded to ``Ellipse``.
    """
    vals, vecs = np.linalg.eigh(cov2)
    order = vals.argsort()[::-1]  # largest eigenvalue first
    vals, vecs = vals[order], vecs[:, order]
    angle = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    w, h = 2 * n_sigma * np.sqrt(vals)
    return Ellipse(xy=mu2, width=w, height=h, angle=angle, **kw)


def plot_dual_head() -> Path:
    """Plot raw two-score scatters (top row) and a whitened 2D view (bottom row).

    Top row: ``terminal_norm`` vs ``disc_nll_total`` with 1/2/3-sigma OAS
    ellipses fitted on benign validation. Bottom row: all scores whitened with
    the benign OAS fit and projected on the attack mean-shift direction plus
    one orthogonal direction, annotated with the Mahalanobis AUROC.

    Returns:
        Path of the written PDF (a PNG with the same stem is written too).
    """
    fig = plt.figure(figsize=(16, 8.5), constrained_layout=True)
    gs = fig.add_gridspec(2, 4)
    rng = np.random.default_rng(0)
    i_norm = SCORE_KEYS.index("terminal_norm")
    i_disc = SCORE_KEYS.index("disc_nll_total")

    for col, ds in enumerate(DATASETS):
        val, atk = load_scores(ds)
        # raw two-axes scatter: terminal_norm vs disc_nll_total
        x_v, y_v = val[:, i_norm], val[:, i_disc]
        x_a, y_a = atk[:, i_norm], atk[:, i_disc]
        # subsample for legibility
        nv = min(3000, len(x_v))
        na = min(3000, len(x_a))
        iv = rng.choice(len(x_v), nv, replace=False)
        ia = rng.choice(len(x_a), na, replace=False)

        ax = fig.add_subplot(gs[0, col])
        ax.scatter(x_v[iv], y_v[iv], s=3, alpha=0.25, c="#2c7fb8", label="benign", rasterized=True)
        ax.scatter(x_a[ia], y_a[ia], s=3, alpha=0.18, c="#d7191c", label="attack", rasterized=True)
        # 2D OAS on these two cols only
        XY_v = val[:, [i_norm, i_disc]]
        oas2 = OAS().fit(XY_v)
        mu2 = XY_v.mean(axis=0)
        for ns, ls in [(1, "-"), (2, "--"), (3, ":")]:
            e = _ellipse_from_2x2(
                mu2, oas2.covariance_, ns,
                edgecolor="black", facecolor="none", lw=1.1, ls=ls, alpha=0.85,
            )
            ax.add_patch(e)
        ax.set_xlabel(r"$t_{\mathrm{norm}}$ (continuous head)")
        if col == 0:
            ax.set_ylabel(r"$\mathcal{L}_{\mathrm{disc}}$ (discrete head)")
        ax.text(0.02, 1.03, PRETTY[ds], transform=ax.transAxes, fontsize=11)
        if col == 0:
            ax.legend(loc="upper right", fontsize=8, framealpha=0.85)
        # zoom to the central 99% of both classes, plus a 5% margin
        x_lo = min(np.quantile(x_v, 0.005), np.quantile(x_a, 0.005))
        x_hi = max(np.quantile(x_v, 0.995), np.quantile(x_a, 0.995))
        y_lo = min(np.quantile(y_v, 0.005), np.quantile(y_a, 0.005))
        y_hi = max(np.quantile(y_v, 0.995), np.quantile(y_a, 0.995))
        ax.set_xlim(x_lo - 0.05 * (x_hi - x_lo), x_hi + 0.05 * (x_hi - x_lo))
        ax.set_ylim(y_lo - 0.05 * (y_hi - y_lo), y_hi + 0.05 * (y_hi - y_lo))

        # whitened panel
        ax2 = fig.add_subplot(gs[1, col])
        mu, inv_cov, _, Linv = fit_oas(val)
        Wv = (val - mu) @ Linv.T
        Wa = (atk - mu) @ Linv.T
        # Axis 1: direction of the (attack - benign) mean shift in whitened space.
        delta = Wa.mean(axis=0) - Wv.mean(axis=0)
        delta_norm = np.linalg.norm(delta) + 1e-12
        u1 = delta / delta_norm
        # Axis 2: top singular direction of the attack cloud after removing its
        # u1 component, re-orthogonalised against u1 and normalised.
        Wa_res = Wa - Wa @ u1[:, None] * u1[None, :]
        _, _, Vt = np.linalg.svd(Wa_res - Wa_res.mean(axis=0), full_matrices=False)
        u2 = Vt[0]
        u2 = u2 - (u2 @ u1) * u1
        u2 /= np.linalg.norm(u2) + 1e-12
        Wv2 = np.c_[Wv @ u1, Wv @ u2]
        Wa2 = np.c_[Wa @ u1, Wa @ u2]
        nv2 = min(3000, len(Wv2))
        na2 = min(3000, len(Wa2))
        iv2 = rng.choice(len(Wv2), nv2, replace=False)
        ia2 = rng.choice(len(Wa2), na2, replace=False)
        ax2.scatter(Wv2[iv2, 0], Wv2[iv2, 1], s=3, alpha=0.25, c="#2c7fb8", rasterized=True)
        ax2.scatter(Wa2[ia2, 0], Wa2[ia2, 1], s=3, alpha=0.18, c="#d7191c", rasterized=True)
        # benign in whitened space ≈ N(0,I); draw unit-σ Mahalanobis circles
        for ns, ls in [(1, "-"), (2, "--"), (3, ":")]:
            ax2.add_patch(plt.Circle((0, 0), ns, fill=False, edgecolor="black", lw=1.1, ls=ls, alpha=0.85))
        ax2.set_xlabel("whitened PC1 (mean-shift dir)")
        if col == 0:
            ax2.set_ylabel("whitened PC2")
        # symlog axes so benign unit-ball is visible alongside far-field attack
        # linthresh = 3 covers the 3σ Mahalanobis circles linearly
        ax2.set_xscale("symlog", linthresh=3)
        ax2.set_yscale("symlog", linthresh=3)
        # set a generous range that shows the unit circles AND the attack mass
        x_max = max(np.quantile(np.abs(Wa2[:, 0]), 0.995), 5)
        y_max = max(np.quantile(np.abs(Wa2[:, 1]), 0.995), 5)
        ax2.set_xlim(-x_max * 1.1, x_max * 1.1)
        ax2.set_ylim(-y_max * 1.1, y_max * 1.1)
        ax2.axhline(0, color="0.7", lw=0.6, zorder=0)
        ax2.axvline(0, color="0.7", lw=0.6, zorder=0)
        # add Mahalanobis AUROC for reference
        m_v = mahal(val, mu, inv_cov)
        m_a = mahal(atk, mu, inv_cov)
        y = np.r_[np.zeros(len(m_v)), np.ones(len(m_a))]
        s = np.r_[m_v, m_a]
        auc = roc_auc_score(y, s)
        ax2.text(
            0.02, 0.97, f"AUROC(mahal-OAS)={auc:.4f}",
            transform=ax2.transAxes, ha="left", va="top",
            fontsize=9, bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="0.5", alpha=0.9),
        )

    out = OUT / "dual_head_oas_ellipses_top__whitened_pca_bottom.pdf"
    fig.savefig(out, bbox_inches="tight")
    fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
    plt.close(fig)
    return out


def _subset_mahal(val, atk, keys):
    """Squared OAS-Mahalanobis scores of *val* and *atk* on the *keys* columns.

    The mean and OAS covariance are fitted on the benign *val* columns only.
    """
    idx = [SCORE_KEYS.index(k) for k in keys]
    mu_s = val[:, idx].mean(axis=0)
    oas_s = OAS().fit(val[:, idx])
    iv_s = np.linalg.inv(oas_s.covariance_ + 1e-9 * np.eye(len(idx)))
    return mahal(val[:, idx], mu_s, iv_s), mahal(atk[:, idx], mu_s, iv_s)


def plot_score_hist() -> Path:
    """Plot benign-vs-attack histograms of four scores for every dataset.

    Rows: raw ``terminal_norm``; OAS-Mahalanobis over the three terminal
    scores; over the seven discrete-NLL scores; over all ten scores.

    Returns:
        Path of the written PDF (a PNG with the same stem is written too).
    """
    fig, axes = plt.subplots(4, 4, figsize=(16, 12), constrained_layout=True)
    for col, ds in enumerate(DATASETS):
        val, atk = load_scores(ds)
        mu, inv_cov, _, _ = fit_oas(val)

        # Row 0: raw terminal_norm (linear)
        _hist_panel(axes[0, col], val[:, 0], atk[:, 0], log_x=False)

        # Row 1: OAS-Mahal terminal3 (log)
        sv, sa = _subset_mahal(
            val, atk, ["terminal_norm", "terminal_flow", "terminal_packet"])
        _hist_panel(axes[1, col], sv, sa, log_x=True)

        # Row 2: OAS-Mahal disc7 (log)
        sv, sa = _subset_mahal(val, atk, [
            "disc_nll_total", "disc_nll_ch2", "disc_nll_ch3",
            "disc_nll_ch4", "disc_nll_ch5", "disc_nll_ch6", "disc_nll_ch7"])
        _hist_panel(axes[2, col], sv, sa, log_x=True)

        # Row 3: OAS-Mahal all 10 (log)
        _hist_panel(axes[3, col], mahal(val, mu, inv_cov), mahal(atk, mu, inv_cov), log_x=True)

        axes[0, col].text(0.02, 1.04, PRETTY[ds], transform=axes[0, col].transAxes, fontsize=11)

    # row labels
    row_labels = [
        "raw terminal_norm",
        "OAS Mahal: term3 (CFM head)",
        "OAS Mahal: disc7 (discrete head)",
        "OAS Mahal: all 10 (deployed)",
    ]
    for r, lbl in enumerate(row_labels):
        axes[r, 0].set_ylabel(lbl, fontsize=10)
    axes[0, 3].legend(loc="upper right", fontsize=8, framealpha=0.85)
    out = OUT / "score_distributions_raw__termOAS__discOAS__allOAS.pdf"
    fig.savefig(out, bbox_inches="tight")
    fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
    plt.close(fig)
    return out


def _log_hist_inputs(sv, sa, n_edges: int = 80):
    """Return ``(sv_clipped, sa_clipped, bins)`` for a log-x histogram.

    Bin edges are geometrically spaced from the 0.1% quantile of the floored
    scores up to their maximum. Scores are clipped up to the first edge so the
    lowest 0.1% lands in the first bin instead of being dropped, matching how
    the linear-axis branch of ``_hist_panel`` clips into its bin range.
    """
    s = np.r_[sv, sa]
    # Positive floor so zero / negative scores can sit on a log axis.
    eps = max(1e-3, np.quantile(s[s > 0], 0.001) * 0.5) if (s > 0).any() else 1e-3
    sv_p = np.maximum(sv, eps)
    sa_p = np.maximum(sa, eps)
    lo = max(np.quantile(np.r_[sv_p, sa_p], 0.001), eps)
    hi = max(sv_p.max(), sa_p.max())  # show full right tail
    bins = np.geomspace(lo, hi, n_edges)
    return np.maximum(sv_p, lo), np.maximum(sa_p, lo), bins


def _hist_panel(ax, sv, sa, log_x: bool = False):
    """Draw overlaid benign/attack histograms on *ax* and annotate the AUROC.

    Args:
        ax: Target axes.
        sv: Benign scores (label 0 for the AUROC).
        sa: Attack scores (label 1 for the AUROC).
        log_x: Use geometric bins and a log x-axis instead of linear bins.
    """
    y = np.r_[np.zeros(len(sv)), np.ones(len(sa))]
    s = np.r_[sv, sa]
    auc = roc_auc_score(y, s)
    # Use within-class fraction weighting so heights stay comparable when bin
    # widths are uneven (geomspace on log-x) — density=True compresses right-tail
    # mass invisibly because density-per-linear-unit collapses at high x.
    w_v = np.full_like(sv, 1.0 / len(sv))
    w_a = np.full_like(sa, 1.0 / len(sa))
    if log_x:
        sv_p, sa_p, bins = _log_hist_inputs(sv, sa)
        ax.hist(sv_p, bins=bins, color="#2c7fb8", alpha=0.55, label="benign", weights=w_v)
        ax.hist(sa_p, bins=bins, color="#d7191c", alpha=0.55, label="attack", weights=w_a)
        ax.set_xscale("log")
    else:
        lo, hi = np.quantile(s, [0.001, 0.999])
        bins = np.linspace(lo, hi, 80)
        ax.hist(np.clip(sv, lo, hi), bins=bins, color="#2c7fb8", alpha=0.55, label="benign", weights=w_v)
        ax.hist(np.clip(sa, lo, hi), bins=bins, color="#d7191c", alpha=0.55, label="attack", weights=w_a)
    ax.text(
        0.97, 0.95, f"AUROC={auc:.4f}",
        transform=ax.transAxes, ha="right", va="top", fontsize=8,
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="0.5", alpha=0.9),
    )
    ax.set_yticks([])


def main() -> None:
    """Generate the figures selected by ``--which`` (default: all of them)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--which", choices=["all", "corr", "dual", "hist"], default="all")
    args = parser.parse_args()
    OUT.mkdir(parents=True, exist_ok=True)
    mpl.rcParams.update({
        "font.size": 10,
        "axes.titlesize": 11,
        "axes.labelsize": 10,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    })
    if args.which in ("all", "corr"):
        p = plot_corr_heatmap()
        print(f"[wrote] {p}")
    if args.which in ("all", "dual"):
        p = plot_dual_head()
        print(f"[wrote] {p}")
    if args.which in ("all", "hist"):
        p = plot_score_hist()
        print(f"[wrote] {p}")


if __name__ == "__main__":
    main()
PRETTY = {
    "cicids2017": "CICIDS2017",
    "cicddos2019": "CICDDoS2019",
    "iscxtor2016": "ISCXTor2016",
    "ciciot2023": "CICIoT2023",
}


def _pca_fit_project(benign_t1: np.ndarray, benign_traj: np.ndarray, attack_traj: np.ndarray):
    """Fit a 2-component PCA on *benign_t1* and project both trajectory sets.

    Args:
        benign_t1: ``[n, D]`` anchor points the PCA is fitted on.
        benign_traj: ``[n, S, D]`` benign trajectories.
        attack_traj: ``[n, S, D]`` attack trajectories.

    Returns:
        ``(pca, benign_2d, attack_2d)`` with the trajectories reshaped to
        ``[n, S, 2]``.
    """
    pca = PCA(n_components=2).fit(benign_t1)

    def project(traj):
        # Flatten (sample, step) so PCA sees a 2D matrix, then restore the axes.
        n, steps, dim = traj.shape[0], traj.shape[1], traj.shape[-1]
        return pca.transform(traj.reshape(-1, dim)).reshape(n, steps, 2)

    return pca, project(benign_traj), project(attack_traj)


def _draw_traj(ax, traj_v, traj_a, title, n_show=80):
    """Draw a subsample of benign and attack 2D trajectories on *ax*.

    Args:
        ax: Target axes.
        traj_v: ``[n, S, 2]`` benign trajectories.
        traj_a: ``[n, S, 2]`` attack trajectories.
        title: Axes title.
        n_show: Maximum number of trajectories drawn per class.
    """
    rng = np.random.default_rng(0)  # fixed seed: same subsample every run
    iv = rng.choice(traj_v.shape[0], min(n_show, traj_v.shape[0]), replace=False)
    ia = rng.choice(traj_a.shape[0], min(n_show, traj_a.shape[0]), replace=False)
    # trajectories: thin alpha lines
    for i in iv:
        ax.plot(traj_v[i, :, 0], traj_v[i, :, 1], color="#2c7fb8", alpha=0.18, lw=0.6, zorder=1)
    for i in ia:
        ax.plot(traj_a[i, :, 0], traj_a[i, :, 1], color="#d7191c", alpha=0.18, lw=0.6, zorder=1)
    # endpoints: first snapshot drawn as "o", last snapshot as "x"
    ax.scatter(traj_v[iv, 0, 0], traj_v[iv, 0, 1], s=14, c="#2c7fb8", marker="o",
               edgecolors="white", linewidths=0.4, zorder=3, label="benign t=1 (data)")
    ax.scatter(traj_v[iv, -1, 0], traj_v[iv, -1, 1], s=14, c="#2c7fb8", marker="x",
               linewidths=0.9, alpha=0.85, zorder=3, label="benign t=0 (post-flow)")
    ax.scatter(traj_a[ia, 0, 0], traj_a[ia, 0, 1], s=14, c="#d7191c", marker="o",
               edgecolors="white", linewidths=0.4, zorder=3, label="attack t=1 (data)")
    ax.scatter(traj_a[ia, -1, 0], traj_a[ia, -1, 1], s=14, c="#d7191c", marker="x",
               linewidths=0.9, alpha=0.85, zorder=3, label="attack t=0 (post-flow)")
    # unit circle (target N(0,I) for benign post-flow if flow learned correctly)
    theta = np.linspace(0, 2 * np.pi, 120)
    for ns, ls in [(1, "-"), (2, "--")]:
        ax.plot(ns * np.cos(theta), ns * np.sin(theta), color="black", lw=0.8, ls=ls, alpha=0.5, zorder=2)
    ax.axhline(0, color="0.85", lw=0.5, zorder=0)
    ax.axvline(0, color="0.85", lw=0.5, zorder=0)
    ax.set_title(title, fontsize=11)
    ax.set_aspect("equal", adjustable="datalim")


def plot_trajectory(npz_paths: dict[str, Path]) -> Path:
    """Plot PCA-projected reverse trajectories, one column per dataset.

    Row 0 uses ``z_traj_flow_v`` / ``z_traj_flow_a``, row 1 uses
    ``z_traj_pkt_v`` / ``z_traj_pkt_a``; each PCA is fitted on the first
    snapshot of the benign trajectories.

    Args:
        npz_paths: Mapping of dataset name to its trajectory ``.npz`` file.

    Returns:
        Path of the written PDF (a PNG with the same stem is written too).
    """
    # squeeze=False keeps axes 2D even when only one dataset is plotted.
    fig, axes = plt.subplots(2, len(npz_paths), figsize=(7.5 * len(npz_paths), 12),
                             constrained_layout=True, squeeze=False)
    for col, (ds, npz) in enumerate(npz_paths.items()):
        # Unknown dataset names fall back to the raw key instead of KeyError.
        label = PRETTY.get(ds, ds)
        # Read the arrays inside a context manager so the archive is closed.
        with np.load(npz) as z:
            ftv = z["z_traj_flow_v"]  # [n, S, D]
            fta = z["z_traj_flow_a"]
            ptv = z["z_traj_pkt_v"]
            pta = z["z_traj_pkt_a"]
        # FLOW-token trajectory
        _, bt, at = _pca_fit_project(ftv[:, 0], ftv, fta)
        _draw_traj(axes[0, col], bt, at, f"{label} — FLOW token (PC1 vs PC2 of benign t=1)")
        # mean-packet trajectory
        _, bt2, at2 = _pca_fit_project(ptv[:, 0], ptv, pta)
        _draw_traj(axes[1, col], bt2, at2, f"{label} — mean packet token")

    axes[0, 0].legend(loc="upper right", fontsize=7, framealpha=0.85)
    fig.suptitle(
        "Reverse CFM flow (t=1 → t=0): benign collapses toward learned source (≈ N(0,I) inside dashed circles); "
        "attack endpoints land off-distribution",
        fontsize=12,
    )
    out = OUT / "fig4_trajectory_pca.pdf"
    fig.savefig(out, bbox_inches="tight")
    fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
    plt.close(fig)
    return out


def plot_velocity_norm(npz_paths: dict[str, Path]) -> Path:
    """Plot mean ± std of the velocity norm against CFM time for each dataset.

    Reads ``vnorm_v`` (benign) and ``vnorm_a`` (attack), each ``[n, n_steps]``.

    Args:
        npz_paths: Mapping of dataset name to its trajectory ``.npz`` file.

    Returns:
        Path of the written PDF (a PNG with the same stem is written too).
    """
    fig, axes = plt.subplots(1, len(npz_paths), figsize=(6.5 * len(npz_paths), 5.6),
                             constrained_layout=True, squeeze=False)
    for ax, (ds, npz) in zip(axes[0], npz_paths.items()):
        with np.load(npz) as z:
            vn_v = z["vnorm_v"]  # [n, n_steps]
            vn_a = z["vnorm_a"]
        # t at integration step k corresponds to t = 1 - k*dt, k=0..n_steps-1
        n_steps = vn_v.shape[1]
        t_steps = 1.0 - np.arange(n_steps) / n_steps
        # mean ± std band
        m_v, s_v = vn_v.mean(0), vn_v.std(0)
        m_a, s_a = vn_a.mean(0), vn_a.std(0)
        ax.plot(t_steps, m_v, color="#2c7fb8", lw=1.6, label="benign mean")
        ax.fill_between(t_steps, m_v - s_v, m_v + s_v, color="#2c7fb8", alpha=0.18)
        ax.plot(t_steps, m_a, color="#d7191c", lw=1.6, label="attack mean")
        ax.fill_between(t_steps, m_a - s_a, m_a + s_a, color="#d7191c", alpha=0.18)
        ax.set_xlabel("CFM time t (1 = data → 0 = source)")
        ax.set_ylabel("‖v(x_t, t)‖ per real token (mean over flow)")
        ax.text(0.02, 1.02, PRETTY.get(ds, ds), transform=ax.transAxes, fontsize=11)
        ax.invert_xaxis()  # so left is t=1 (data), right is t=0 (source)
        ax.legend(fontsize=8, loc="upper left", framealpha=0.85)
    out = OUT / "velocity_norm_vs_t_benign_vs_attack.pdf"
    fig.savefig(out, bbox_inches="tight")
    fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
    plt.close(fig)
    return out


def _existing_npz_paths(datasets, out_dir: Path) -> dict[str, Path]:
    """Map each dataset to ``out_dir / "traj_<ds>.npz"``, skipping missing files.

    Every missing file is reported with a ``[skip]`` line on stdout.
    """
    found: dict[str, Path] = {}
    for ds in datasets:
        npz = out_dir / f"traj_{ds}.npz"
        if npz.exists():
            found[ds] = npz
        else:
            print(f"[skip] missing {npz}")
    return found


def main() -> None:
    """Plot the trajectory and velocity-norm figures for ``--datasets``.

    Datasets whose ``traj_<name>.npz`` is missing under ``OUT`` are skipped;
    nothing is plotted when none of the inputs exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--datasets", nargs="+",
                        default=["cicids2017", "cicddos2019", "iscxtor2016", "ciciot2023"])
    args = parser.parse_args()
    OUT.mkdir(parents=True, exist_ok=True)
    mpl.rcParams.update({"font.size": 10, "pdf.fonttype": 42, "ps.fonttype": 42})
    npz_paths = _existing_npz_paths(args.datasets, OUT)
    if not npz_paths:
        # A zero-column subplot grid cannot be built; bail out cleanly instead.
        print("[skip] no trajectory npz files found; nothing to plot")
        return
    p4 = plot_trajectory(npz_paths)
    print(f"[wrote] {p4}")
    p5 = plot_velocity_norm(npz_paths)
    print(f"[wrote] {p5}")


if __name__ == "__main__":
    main()
def _check_n_steps(n_steps: int) -> None:
    """Reject step counts that cannot define an Euler step size."""
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")


def _padding_and_disc_range(model, lens):
    """Return ``(key_padding_mask, disc_start, disc_end)`` for *model* / *lens*.

    The padding mask is True where ``model._loss_mask(lens)`` is zero; the
    discrete-embedding channels are the last-axis slice
    ``[1 + n_cont_pkt, 1 + n_cont_pkt + n_disc_pkt)`` from ``model.cfg``.
    """
    kpm = model._loss_mask(lens) == 0
    cfg = model.cfg
    disc_start = 1 + cfg.n_cont_pkt
    disc_end = disc_start + cfg.n_disc_pkt
    return kpm, disc_start, disc_end


@torch.no_grad()
def integrate_reverse(model: MixedTokenCFM, z1: torch.Tensor, lens: torch.Tensor, *, n_steps: int) -> torch.Tensor:
    """Euler-integrate the model velocity backwards from t=1 towards t=0.

    z1 is already at t=1 (data). The discrete-embedding channels of every token
    after the first are held at their initial value throughout.

    Args:
        model: Model providing ``_loss_mask``, ``cfg`` and ``velocity``.
        z1: Token tensor at t=1.
        lens: Per-sample lengths passed to ``model._loss_mask``.
        n_steps: Number of Euler steps (step size ``1 / n_steps``).

    Returns:
        Snapshots [B, n_steps+1, L, D] from t=1 → t=0, starting with a copy of z1.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """
    _check_n_steps(n_steps)
    z = z1.clone()
    kpm, disc_start, disc_end = _padding_and_disc_range(model, lens)
    B = z.shape[0]
    dt = 1.0 / n_steps
    # Remember the discrete embedding so it can be re-imposed after each step.
    disc_embed = z[:, 1:, disc_start:disc_end].clone()
    snaps = [z.clone()]
    for k in range(n_steps):
        t = torch.full((B,), 1.0 - k * dt, device=z.device)
        v, _ = model.velocity(z, t, key_padding_mask=kpm)
        v[:, :, disc_start:disc_end] = 0.0  # discrete channels do not flow
        z = z - v * dt
        z[:, 1:, disc_start:disc_end] = disc_embed
        snaps.append(z.clone())
    return torch.stack(snaps, dim=1)


@torch.no_grad()
def integrate_forward(model: MixedTokenCFM, z0: torch.Tensor, lens: torch.Tensor, disc_embed: torch.Tensor, *, n_steps: int) -> torch.Tensor:
    """Forward Euler from t=0 (Gaussian noise) → t=1 (generated).

    The discrete-embedding channels of every token after the first are
    overwritten with *disc_embed* before each velocity evaluation.

    Args:
        model: Model providing ``_loss_mask``, ``cfg`` and ``velocity``.
        z0: Starting token tensor at t=0.
        lens: Per-sample lengths passed to ``model._loss_mask``.
        disc_embed: Values written into ``z[:, 1:, disc_start:disc_end]``.
        n_steps: Number of Euler steps (step size ``1 / n_steps``).

    Returns:
        Snapshots stacked along dim 1 (``n_steps + 1`` of them), starting with
        a copy of z0.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """
    _check_n_steps(n_steps)
    z = z0.clone()
    kpm, disc_start, disc_end = _padding_and_disc_range(model, lens)
    B = z.shape[0]
    dt = 1.0 / n_steps
    # NOTE(review): the first snapshot is taken before disc_embed is imposed,
    # so its discrete channels still hold the raw z0 values — confirm whether
    # consumers of the packet tokens expect that.
    snaps = [z.clone()]
    for k in range(n_steps):
        t = torch.full((B,), k * dt, device=z.device)
        z[:, 1:, disc_start:disc_end] = disc_embed
        v, _ = model.velocity(z, t, key_padding_mask=kpm)
        v[:, :, disc_start:disc_end] = 0.0  # discrete channels do not flow
        z = z + v * dt
        snaps.append(z.clone())
    return torch.stack(snaps, dim=1)


@torch.no_grad()
def velocity_at(model: MixedTokenCFM, z: torch.Tensor, lens: torch.Tensor, t_val: float) -> torch.Tensor:
    """Return the model velocity at time *t_val* with discrete channels zeroed.

    Args:
        model: Model providing ``_loss_mask``, ``cfg`` and ``velocity``.
        z: Token tensor the velocity is evaluated at.
        lens: Per-sample lengths passed to ``model._loss_mask``.
        t_val: Time value, broadcast to one entry per batch element.

    Returns:
        The first output of ``model.velocity`` with the discrete-embedding
        channel slice (last axis) set to zero for every token.
    """
    kpm, disc_start, disc_end = _padding_and_disc_range(model, lens)
    t = torch.full((z.shape[0],), float(t_val), device=z.device)
    v, _ = model.velocity(z, t, key_padding_mask=kpm)
    v[:, :, disc_start:disc_end] = 0.0
    return v
split_seed=int(cfg.get("data_seed", cfg.get("seed", 42))), + train_ratio=float(cfg.get("train_ratio", 0.8)), + benign_label=str(cfg.get("benign_label", "normal")), + min_len=int(cfg.get("min_len", 2)), + attack_cap=int(cfg["attack_cap"]) if cfg.get("attack_cap") else None, + val_cap=int(cfg["val_cap"]) if cfg.get("val_cap") else None, + ) + print(f"[data] val={len(data.val_flow):,} attack={len(data.attack_flow):,}") + + rng = np.random.default_rng(0) + nv = min(args.n_pca_benign, len(data.val_flow)) + iv_pca = np.sort(rng.choice(len(data.val_flow), nv, replace=False)) + irv = np.sort(rng.choice(len(data.val_flow), args.n_reverse_benign, replace=False)) + ira = np.sort(rng.choice(len(data.attack_flow), args.n_reverse_attack, replace=False)) + ifw_template = np.sort(rng.choice(len(data.val_flow), args.n_forward, replace=False)) + + def build_z(flow, cont, disc): + flow_t = torch.from_numpy(flow).float().to(device) + cont_t = torch.from_numpy(cont).float().to(device) + disc_t = torch.from_numpy(disc).long().to(device) + return model.build_tokens(flow_t, cont_t, disc_t) + + # ==== Step 1: build z1 for benign PCA pool, get FLOW token at t=0.5 ==== + z1_pca = build_z(data.val_flow[iv_pca], data.val_cont[iv_pca], data.val_disc[iv_pca]) # [N, L, D] + lens_pca = torch.from_numpy(data.val_len[iv_pca]).long().to(device) + flow_t1 = z1_pca[:, 0, :].cpu().numpy() # [N, D] — FLOW token at t=1 + sigma = float(model_cfg.sigma) + z0_for_pca = torch.randn_like(z1_pca) + t_val = 0.5 + z_t05 = (1 - t_val) * z0_for_pca + t_val * z1_pca + if sigma > 0: + std = sigma * np.sqrt(t_val * (1 - t_val)) + z_t05 = z_t05 + std * torch.randn_like(z_t05) + flow_t05 = z_t05[:, 0, :].cpu().numpy() # [N, D] + flow_t0 = z0_for_pca[:, 0, :].cpu().numpy() + + pca = PCA(n_components=2).fit(flow_t05) + print(f"[pca] explained var on benign FLOW @ t=0.5: {pca.explained_variance_ratio_}") + + benign_t1_2d = pca.transform(flow_t1) + benign_t05_2d = pca.transform(flow_t05) + benign_t0_2d = 
pca.transform(flow_t0) + + # ==== Step 2: reverse trajectories ==== + print("[run] reverse benign") + z1_rv = build_z(data.val_flow[irv], data.val_cont[irv], data.val_disc[irv]) + lens_rv = torch.from_numpy(data.val_len[irv]).long().to(device) + snaps_rv = integrate_reverse(model, z1_rv, lens_rv, n_steps=args.n_steps) # [B, S+1, L, D] + rv_2d = pca.transform(snaps_rv[:, :, 0, :].reshape(-1, snaps_rv.shape[-1]).cpu().numpy()).reshape(snaps_rv.shape[0], snaps_rv.shape[1], 2) + + print("[run] reverse attack") + z1_ra = build_z(data.attack_flow[ira], data.attack_cont[ira], data.attack_disc[ira]) + lens_ra = torch.from_numpy(data.attack_len[ira]).long().to(device) + snaps_ra = integrate_reverse(model, z1_ra, lens_ra, n_steps=args.n_steps) + ra_2d = pca.transform(snaps_ra[:, :, 0, :].reshape(-1, snaps_ra.shape[-1]).cpu().numpy()).reshape(snaps_ra.shape[0], snaps_ra.shape[1], 2) + + # ==== Step 3: forward trajectories (sample noise, integrate to t=1, hold disc embed from a benign template) ==== + print("[run] forward from noise") + z1_fw = build_z(data.val_flow[ifw_template], data.val_cont[ifw_template], data.val_disc[ifw_template]) + lens_fw = torch.from_numpy(data.val_len[ifw_template]).long().to(device) + cfg_m = model.cfg + disc_start = 1 + cfg_m.n_cont_pkt + disc_end = disc_start + cfg_m.n_disc_pkt + disc_embed_fw = z1_fw[:, 1:, disc_start:disc_end].clone() + z0_fw = torch.randn_like(z1_fw) + snaps_fw = integrate_forward(model, z0_fw, lens_fw, disc_embed_fw, n_steps=args.n_steps) + fw_2d = pca.transform(snaps_fw[:, :, 0, :].reshape(-1, snaps_fw.shape[-1]).cpu().numpy()).reshape(snaps_fw.shape[0], snaps_fw.shape[1], 2) + + # ==== Step 4: velocity field on PCA grid at t=0.5 ==== + print("[run] velocity field grid") + pad = 0.3 + x_min, x_max = benign_t05_2d[:, 0].min(), benign_t05_2d[:, 0].max() + y_min, y_max = benign_t05_2d[:, 1].min(), benign_t05_2d[:, 1].max() + sx = x_max - x_min + sy = y_max - y_min + x_lo, x_hi = x_min - pad * sx, x_max + pad * sx + y_lo, y_hi = 
y_min - pad * sy, y_max + pad * sy + gx = np.linspace(x_lo, x_hi, args.grid) + gy = np.linspace(y_lo, y_hi, args.grid) + GX, GY = np.meshgrid(gx, gy) + grid_2d = np.stack([GX.ravel(), GY.ravel()], axis=1) # [G^2, 2] + grid_full = pca.inverse_transform(grid_2d) # [G^2, D] + grid_full_t = torch.from_numpy(grid_full).float().to(device) + + # For each grid point, replace FLOW token at t=0.5 of K random benign templates; + # average velocity at FLOW position over templates. + K = args.grid_templates + # Sample K template z_t05 from PCA pool + template_idx = rng.choice(len(z1_pca), K, replace=False) + z1_tpl = z1_pca[template_idx] # [K, L, D] + lens_tpl = lens_pca[template_idx] + z0_tpl = torch.randn_like(z1_tpl) + z_t05_tpl = (1 - t_val) * z0_tpl + t_val * z1_tpl + if sigma > 0: + z_t05_tpl = z_t05_tpl + (sigma * np.sqrt(t_val * (1 - t_val))) * torch.randn_like(z_t05_tpl) + + G2 = grid_full_t.shape[0] + v_grid_full = torch.zeros((G2, grid_full_t.shape[1]), device=device) + bs = args.batch_size + for k in range(K): + # build a [G^2, L, D] tensor where token 0 = grid point, tokens 1: = template's z_t05 packets + tpl_packets = z_t05_tpl[k:k + 1, 1:, :].expand(G2, -1, -1).contiguous() + z_grid = torch.cat([grid_full_t.unsqueeze(1), tpl_packets], dim=1) # [G^2, L, D] + lens_grid = lens_tpl[k:k + 1].expand(G2).contiguous() + # batched velocity eval + v_chunks = [] + for s in range(0, G2, bs): + v_chunks.append(velocity_at(model, z_grid[s:s + bs], lens_grid[s:s + bs], t_val=t_val)[:, 0, :]) + v_grid_full = v_grid_full + torch.cat(v_chunks, dim=0) + v_grid_full = v_grid_full / K + v_grid_np = v_grid_full.cpu().numpy() + v_grid_2d = (v_grid_np - 0) @ pca.components_.T # project [G^2, D] onto 2 PCA basis vectors → [G^2, 2] + v_grid_2d = v_grid_2d.reshape(args.grid, args.grid, 2) + log_norm_full = np.log10(np.linalg.norm(v_grid_np, axis=-1) + 1e-9).reshape(args.grid, args.grid) + + args.out.parent.mkdir(parents=True, exist_ok=True) + np.savez( + args.out, + 
pca_components=pca.components_.astype(np.float32), + pca_mean=pca.mean_.astype(np.float32), + pca_explained_var=pca.explained_variance_ratio_.astype(np.float32), + benign_t1_2d=benign_t1_2d.astype(np.float32), + benign_t05_2d=benign_t05_2d.astype(np.float32), + benign_t0_2d=benign_t0_2d.astype(np.float32), + reverse_v_2d=rv_2d.astype(np.float32), + reverse_a_2d=ra_2d.astype(np.float32), + forward_v_2d=fw_2d.astype(np.float32), + grid_x=GX.astype(np.float32), + grid_y=GY.astype(np.float32), + field_v_2d=v_grid_2d.astype(np.float32), + field_log_norm=log_norm_full.astype(np.float32), + ) + print(f"[wrote] {args.out}") + + +if __name__ == "__main__": + main() diff --git a/scripts/figures/run_trajectory_inference.py b/scripts/figures/run_trajectory_inference.py new file mode 100644 index 0000000..89efcef --- /dev/null +++ b/scripts/figures/run_trajectory_inference.py @@ -0,0 +1,157 @@ +"""Reverse-flow ODE integration with per-step snapshots, for trajectory + velocity-norm plots. + +For a chosen (dataset, seed), loads a small balanced sample (n_per_class benign+attack) +and integrates the velocity field from t=1 (data) to t=0 (Gaussian) using the same Euler +scheme as `MixedTokenCFM.trajectory_metrics`, but saves z snapshots and ‖v‖ at each step. 
@torch.no_grad()
def run_reverse_flow(model: MixedTokenCFM, flow, cont, disc, lens, *, n_steps: int):
    """Integrate the velocity field backwards from t=1 with explicit Euler steps.

    Tokens are built from ``(flow, cont, disc)`` and stepped ``n_steps`` times
    with ``dt = 1 / n_steps``; step ``k`` evaluates the velocity at
    ``t = 1 - k * dt``. The discrete-channel columns
    ``[1 + n_cont_pkt, 1 + n_cont_pkt + n_disc_pkt)`` get zero velocity, and
    for tokens 1: they are reset to their initial values after every step.

    Returns:
        A 3-tuple ``(snapshots, vnorms, mask)``:
        ``snapshots`` is a list of ``n_steps + 1`` token tensors (initial state
        first); ``vnorms`` is a NumPy array ``[B, n_steps]`` holding, per step,
        the per-token velocity norm averaged under ``model._loss_mask(lens)``;
        ``mask`` is that mask as a NumPy array.
    """
    tokens = model.build_tokens(flow, cont, disc)  # state at t=1
    real_mask = model._loss_mask(lens)
    padding = real_mask == 0
    batch = tokens.shape[0]
    step = 1.0 / n_steps
    model_cfg = model.cfg
    first_disc = 1 + model_cfg.n_cont_pkt
    disc_cols = slice(first_disc, first_disc + model_cfg.n_disc_pkt)
    # Remember the packet tokens' discrete embeddings so they can be pinned.
    frozen_disc = tokens[:, 1:, disc_cols].clone()

    snapshots = [tokens.clone()]
    norm_history = []
    for i in range(n_steps):
        t_now = torch.full((batch,), 1.0 - i * step, device=tokens.device)
        vel, _ = model.velocity(tokens, t_now, key_padding_mask=padding)
        vel[:, :, disc_cols] = 0.0
        # Mean velocity norm per sample; the denominator is floored at 1 so an
        # all-zero mask row does not divide by zero.
        tok_norm = vel.norm(dim=-1)  # [B, L]
        masked_total = (tok_norm * real_mask).sum(dim=-1)
        mean_norm = masked_total / real_mask.sum(dim=-1).clamp_min(1.0)
        norm_history.append(mean_norm.cpu().numpy())
        tokens = tokens - vel * step
        tokens[:, 1:, disc_cols] = frozen_disc
        snapshots.append(tokens.clone())
    return snapshots, np.stack(norm_history, axis=1), real_mask.cpu().numpy()


def main() -> None:
    """Run the reverse flow on sampled benign and attack flows and save trajectories.

    Loads the model and data described by ``--model-dir``, samples up to
    ``--n-per-class`` flows from the validation set and from the attack set
    (NumPy generator seeded with 0), integrates them in batches with
    ``run_reverse_flow``, and writes FLOW-token trajectories, masked-mean
    packet-token trajectories, per-step velocity norms and the time grid to
    ``--out`` as an ``.npz``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir", type=Path, required=True)
    parser.add_argument("--out", type=Path, required=True)
    parser.add_argument("--n-per-class", type=int, default=200)
    parser.add_argument("--n-steps", type=int, default=32)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--batch-size", type=int, default=128)
    args = parser.parse_args()

    # Explicit device names pass straight through; "auto" prefers CUDA.
    if args.device != "auto":
        device_name = args.device
    elif torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    device = torch.device(device_name)

    model_dir = args.model_dir
    cfg = yaml.safe_load((model_dir / "config.yaml").read_text())
    ckpt = torch.load(model_dir / "model.pt", map_location="cpu", weights_only=False)
    model_cfg = MixedCFMConfig(**ckpt["model_cfg"])
    model = MixedTokenCFM(model_cfg).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    def optional(key, cast):
        """Return ``cast(cfg[key])`` when the key holds a truthy value, else None."""
        return cast(cfg[key]) if cfg.get(key) else None

    data = load_mixed_data(
        packets_npz=optional("packets_npz", Path),
        source_store=optional("source_store", Path),
        flows_parquet=Path(cfg["flows_parquet"]),
        flow_features_path=Path(cfg["flow_features_path"]),
        flow_features_align=str(cfg.get("flow_features_align", "auto")),
        T=int(cfg["T"]),
        split_seed=int(cfg.get("data_seed", cfg.get("seed", 42))),
        train_ratio=float(cfg.get("train_ratio", 0.8)),
        benign_label=str(cfg.get("benign_label", "normal")),
        min_len=int(cfg.get("min_len", 2)),
        attack_cap=optional("attack_cap", int),
        val_cap=optional("val_cap", int),
    )
    print(f"[data] val={len(data.val_flow):,} attack={len(data.attack_flow):,}")

    # Benign indices are drawn before attack indices from the same generator.
    rng = np.random.default_rng(0)
    total_val = len(data.val_flow)
    total_attack = len(data.attack_flow)
    nv = min(args.n_per_class, total_val)
    na = min(args.n_per_class, total_attack)
    iv = np.sort(rng.choice(total_val, nv, replace=False))
    ia = np.sort(rng.choice(total_attack, na, replace=False))

    def to_t(flow, cont, disc, lens):
        """Convert one NumPy batch to tensors on ``device``."""
        flow_t = torch.from_numpy(flow).float().to(device)
        cont_t = torch.from_numpy(cont).float().to(device)
        disc_t = torch.from_numpy(disc).long().to(device)
        lens_t = torch.from_numpy(lens).long().to(device)
        return flow_t, cont_t, disc_t, lens_t

    def collect(flows_np, conts_np, discs_np, lens_np):
        """Integrate all samples batch by batch and concatenate the results."""
        total = len(flows_np)
        snap_parts, vnorm_parts, mask_parts = [], [], []
        for start in range(0, total, args.batch_size):
            stop = start + args.batch_size
            batch = to_t(flows_np[start:stop], conts_np[start:stop], discs_np[start:stop], lens_np[start:stop])
            snaps, vn, mask = run_reverse_flow(model, *batch, n_steps=args.n_steps)
            snap_parts.append(torch.stack(snaps, dim=1).cpu().numpy())  # [b, n_steps+1, L, D]
            vnorm_parts.append(vn)
            mask_parts.append(mask)
            print(f" [batch] {min(stop, total)}/{total}", flush=True)
        return (
            np.concatenate(snap_parts, axis=0),
            np.concatenate(vnorm_parts, axis=0),
            np.concatenate(mask_parts, axis=0),
        )

    print("[run] benign val")
    t0 = time.time()
    snaps_v, vn_v, mask_v = collect(data.val_flow[iv], data.val_cont[iv], data.val_disc[iv], data.val_len[iv])
    print(f" done {time.time() - t0:.1f}s snaps={snaps_v.shape}")

    print("[run] attack")
    t0 = time.time()
    snaps_a, vn_a, mask_a = collect(data.attack_flow[ia], data.attack_cont[ia], data.attack_disc[ia], data.attack_len[ia])
    print(f" done {time.time() - t0:.1f}s snaps={snaps_a.shape}")

    def flow_and_pkt_traj(snaps, mask):
        """Split ``snaps`` [n, S, L, D] into the FLOW-token and masked-mean packet tracks."""
        # Token 0 is the FLOW token; tokens 1: are packets weighted by ``mask``.
        weights = mask[:, None, 1:, None].astype(np.float32)  # [n, 1, L-1, 1]
        denom = weights.sum(axis=2).clip(min=1.0)
        weighted_sum = (snaps[:, :, 1:, :] * weights).sum(axis=2)
        return snaps[:, :, 0, :], weighted_sum / denom  # each [n, S, D]

    flow_tok_v, pkt_mean_v = flow_and_pkt_traj(snaps_v, mask_v)
    flow_tok_a, pkt_mean_a = flow_and_pkt_traj(snaps_a, mask_a)

    # Time stamps of the snapshots: 1.0 first, then decreasing by 1/n_steps.
    n_steps = args.n_steps
    t_grid = 1.0 - np.arange(n_steps + 1) * (1.0 / n_steps)

    args.out.parent.mkdir(parents=True, exist_ok=True)
    arrays = {
        "z_traj_flow_v": flow_tok_v,
        "z_traj_flow_a": flow_tok_a,
        "z_traj_pkt_v": pkt_mean_v,
        "z_traj_pkt_a": pkt_mean_a,
        "vnorm_v": vn_v,
        "vnorm_a": vn_a,
        "t_grid": t_grid,
    }
    np.savez(args.out, **{key: value.astype(np.float32) for key, value in arrays.items()})
    print(f"[wrote] {args.out}")


if __name__ == "__main__":
    main()