figures: add JANUS mechanism figure scripts (trajectory + field view + score hist)
scripts/figures/ contains the per-dataset figure generators used to render the JANUS mechanism figures (reverse-flow trajectory PCA, t=0.5 velocity field view with sparse benign overlay, score-distribution histograms with within-class fraction weighting). Outputs go to artifacts/janus_mechanism_figures_<date>/ (gitignored under artifacts/). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
132
scripts/figures/plot_field_view.py
Normal file
132
scripts/figures/plot_field_view.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Render Unified-style 3-panel field view per dataset from run_field_view.py output.
|
||||
|
||||
Panels (no titles; semantic info encoded in filename):
|
||||
L: velocity field at t=0.5 (heatmap of log10‖v‖ + streamlines)
|
||||
M: attack reverse trajectories t=1 → t=0 (lines + endpoints over benign t=1 cloud)
|
||||
R: forward generation cloud comparison (benign t=1 / N(0,I) / generated overlays)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mpl
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
OUT = ROOT / "artifacts" / "janus_mechanism_figures_2026_05_08"
|
||||
|
||||
|
||||
def _set_lim(ax, x, y, pad=0.08):
|
||||
xlo, xhi = x.min(), x.max()
|
||||
ylo, yhi = y.min(), y.max()
|
||||
sx, sy = xhi - xlo, yhi - ylo
|
||||
ax.set_xlim(xlo - pad * sx, xhi + pad * sx)
|
||||
ax.set_ylim(ylo - pad * sy, yhi + pad * sy)
|
||||
|
||||
|
||||
def plot_one(npz: Path, dataset: str) -> Path:
|
||||
z = np.load(npz)
|
||||
GX = z["grid_x"]
|
||||
GY = z["grid_y"]
|
||||
field_log = z["field_log_norm"]
|
||||
field_v = z["field_v_2d"]
|
||||
benign_t1 = z["benign_t1_2d"]
|
||||
benign_t05 = z["benign_t05_2d"]
|
||||
benign_t0 = z["benign_t0_2d"]
|
||||
ra = z["reverse_a_2d"]
|
||||
fw = z["forward_v_2d"]
|
||||
ev = z["pca_explained_var"]
|
||||
|
||||
fig = plt.figure(figsize=(15.5, 5.0), constrained_layout=True)
|
||||
gs = fig.add_gridspec(1, 3, width_ratios=[1.05, 1, 1])
|
||||
|
||||
# ========== L: velocity field heatmap + streamplot ==========
|
||||
axL = fig.add_subplot(gs[0, 0])
|
||||
vmin, vmax = np.percentile(field_log, [5, 95])
|
||||
pcm = axL.pcolormesh(GX, GY, field_log, cmap="viridis", shading="auto",
|
||||
vmin=vmin, vmax=vmax, rasterized=True)
|
||||
cbar = fig.colorbar(pcm, ax=axL, shrink=0.85, pad=0.02)
|
||||
cbar.set_label(r"$\log_{10}\|v(x_t,t{=}0.5)\|$ (full token)", fontsize=8)
|
||||
cbar.ax.tick_params(labelsize=7)
|
||||
# streamlines: width varies with local speed
|
||||
speed = np.linalg.norm(field_v, axis=-1)
|
||||
lw = 0.35 + 1.6 * (speed / (speed.max() + 1e-9))
|
||||
axL.streamplot(GX, GY, field_v[..., 0], field_v[..., 1],
|
||||
color="white", linewidth=lw, density=1.4, arrowsize=0.7)
|
||||
# sparse benign t=0.5 cloud overlay (light, doesn't drown out heatmap)
|
||||
n_overlay = min(300, benign_t05.shape[0])
|
||||
rng = np.random.default_rng(0)
|
||||
idx_ov = rng.choice(benign_t05.shape[0], n_overlay, replace=False)
|
||||
axL.scatter(benign_t05[idx_ov, 0], benign_t05[idx_ov, 1],
|
||||
s=3, c="white", alpha=0.55, edgecolors="black",
|
||||
linewidths=0.15, rasterized=True, zorder=4)
|
||||
axL.set_xlabel(f"PC1 ({100*ev[0]:.1f}%)")
|
||||
axL.set_ylabel(f"PC2 ({100*ev[1]:.1f}%)")
|
||||
axL.text(0.02, 1.02, f"{dataset} · velocity field at t=0.5",
|
||||
transform=axL.transAxes, fontsize=10)
|
||||
|
||||
# ========== M: attack reverse trajectories over benign t=1 cloud ==========
|
||||
axM = fig.add_subplot(gs[0, 1])
|
||||
axM.scatter(benign_t1[:, 0], benign_t1[:, 1], s=6, c="#a6cee3", alpha=0.55,
|
||||
edgecolors="none", label="benign cloud (t=1)", rasterized=True)
|
||||
for i in range(ra.shape[0]):
|
||||
axM.plot(ra[i, :, 0], ra[i, :, 1], color="#d7191c", lw=0.55, alpha=0.55)
|
||||
axM.scatter(ra[:, 0, 0], ra[:, 0, 1], s=14, c="#d7191c", marker="o",
|
||||
edgecolors="white", linewidths=0.4, label="attack t=1 (start)", zorder=3)
|
||||
axM.scatter(ra[:, -1, 0], ra[:, -1, 1], s=18, c="#d7191c", marker="x",
|
||||
linewidths=1.0, label="attack t=0 (end)", zorder=3)
|
||||
axM.legend(loc="upper left", bbox_to_anchor=(0.0, -0.12), ncol=3,
|
||||
fontsize=7, framealpha=0.85, borderaxespad=0.0)
|
||||
_set_lim(axM,
|
||||
np.r_[benign_t1[:, 0], ra[..., 0].ravel()],
|
||||
np.r_[benign_t1[:, 1], ra[..., 1].ravel()])
|
||||
axM.set_xlabel("PC1")
|
||||
axM.text(0.02, 1.02, f"{dataset} · attack reverse trajectories t=1→0",
|
||||
transform=axM.transAxes, fontsize=10)
|
||||
|
||||
# ========== R: forward generation cloud comparison ==========
|
||||
axR = fig.add_subplot(gs[0, 2])
|
||||
gen = fw[:, -1, :] # generated samples (t=1 endpoints)
|
||||
axR.scatter(benign_t0[:, 0], benign_t0[:, 1], s=6, c="#888888", alpha=0.40,
|
||||
edgecolors="none", label="N(0,I) at t=0", rasterized=True)
|
||||
axR.scatter(benign_t1[:, 0], benign_t1[:, 1], s=8, c="#1f78b4", alpha=0.55,
|
||||
edgecolors="none", label="benign cloud (t=1)", rasterized=True)
|
||||
axR.scatter(gen[:, 0], gen[:, 1], s=12, c="#33a02c", alpha=0.75,
|
||||
edgecolors="white", linewidths=0.3,
|
||||
label="generated (forward t=0→1)", rasterized=True)
|
||||
axR.legend(loc="upper left", bbox_to_anchor=(0.0, -0.12), ncol=3,
|
||||
fontsize=7, framealpha=0.85, borderaxespad=0.0)
|
||||
_set_lim(axR,
|
||||
np.r_[benign_t1[:, 0], benign_t0[:, 0], gen[:, 0]],
|
||||
np.r_[benign_t1[:, 1], benign_t0[:, 1], gen[:, 1]])
|
||||
axR.set_xlabel("PC1")
|
||||
axR.text(0.02, 1.02, f"{dataset} · forward generation vs benign cloud",
|
||||
transform=axR.transAxes, fontsize=10)
|
||||
|
||||
out = OUT / f"velocity_field_view_{dataset.lower()}.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--datasets", nargs="+",
|
||||
default=["cicids2017", "cicddos2019", "iscxtor2016", "ciciot2023"])
|
||||
args = parser.parse_args()
|
||||
OUT.mkdir(parents=True, exist_ok=True)
|
||||
mpl.rcParams.update({"font.size": 9, "pdf.fonttype": 42, "ps.fonttype": 42})
|
||||
pretty = {"cicids2017": "CICIDS2017", "cicddos2019": "CICDDoS2019",
|
||||
"iscxtor2016": "ISCXTor2016", "ciciot2023": "CICIoT2023"}
|
||||
for ds in args.datasets:
|
||||
npz = OUT / f"field_{ds}.npz"
|
||||
if not npz.exists():
|
||||
print(f"[skip] missing {npz}")
|
||||
continue
|
||||
p = plot_one(npz, pretty.get(ds, ds))
|
||||
print(f"[wrote] {p}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
329
scripts/figures/plot_mechanism.py
Normal file
329
scripts/figures/plot_mechanism.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""Mechanism-level figures for JANUS / Mixed_CFM.
|
||||
|
||||
Generates:
|
||||
fig6_score_corr.pdf — 10x10 sub-score correlation per dataset (benign val)
|
||||
fig1_dual_head.pdf — (terminal_norm, disc_nll_total) + OAS ellipses + whitened PCA
|
||||
fig3_score_hist.pdf — raw vs OAS-aggregated score distributions across datasets
|
||||
|
||||
Inputs: artifacts/route_comparison/janus_<dataset>_seed42/phase1_scores.npz
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mpl
|
||||
from matplotlib.patches import Ellipse
|
||||
from sklearn.covariance import OAS
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
RUNS = ROOT / "artifacts" / "route_comparison"
|
||||
OUT = ROOT / "artifacts" / "janus_mechanism_figures_2026_05_08"
|
||||
|
||||
DATASETS = ["iscxtor2016", "cicids2017", "cicddos2019", "ciciot2023"]
|
||||
PRETTY = {
|
||||
"iscxtor2016": "ISCXTor2016",
|
||||
"cicids2017": "CICIDS2017",
|
||||
"cicddos2019": "CICDDoS2019",
|
||||
"ciciot2023": "CICIoT2023",
|
||||
}
|
||||
|
||||
SCORE_KEYS = [
|
||||
"terminal_norm", "terminal_flow", "terminal_packet",
|
||||
"disc_nll_total",
|
||||
"disc_nll_ch2", "disc_nll_ch3", "disc_nll_ch4",
|
||||
"disc_nll_ch5", "disc_nll_ch6", "disc_nll_ch7",
|
||||
]
|
||||
SCORE_LABELS = [
|
||||
r"$\|\!|\,t_{\mathrm{norm}}$", r"$t_{\mathrm{flow}}$", r"$t_{\mathrm{pkt}}$",
|
||||
r"$\mathcal{L}_{\mathrm{disc}}$",
|
||||
"ch2", "ch3", "ch4", "ch5", "ch6", "ch7",
|
||||
]
|
||||
SCORE_LABELS = [
|
||||
"term_norm", "term_flow", "term_pkt",
|
||||
"disc_total", "disc_ch2", "disc_ch3", "disc_ch4",
|
||||
"disc_ch5", "disc_ch6", "disc_ch7",
|
||||
]
|
||||
|
||||
|
||||
def load_scores(dataset: str, seed: int = 42) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Return (val_S, atk_S) of shape (n, 10)."""
|
||||
npz = RUNS / f"janus_{dataset}_seed{seed}" / "phase1_scores.npz"
|
||||
z = np.load(npz, allow_pickle=True)
|
||||
val = np.stack([z[f"val_{k}"] for k in SCORE_KEYS], axis=1)
|
||||
atk = np.stack([z[f"atk_{k}"] for k in SCORE_KEYS], axis=1)
|
||||
val = np.nan_to_num(val, nan=0.0, posinf=1e6, neginf=-1e6).astype(np.float64)
|
||||
atk = np.nan_to_num(atk, nan=0.0, posinf=1e6, neginf=-1e6).astype(np.float64)
|
||||
return val, atk
|
||||
|
||||
|
||||
def fit_oas(val_S: np.ndarray):
|
||||
"""Fit OAS on benign val. Return (mu, inv_cov, cov, transform) where transform whitens."""
|
||||
mu = val_S.mean(axis=0)
|
||||
oas = OAS().fit(val_S)
|
||||
cov = oas.covariance_
|
||||
inv_cov = np.linalg.inv(cov + 1e-9 * np.eye(cov.shape[0]))
|
||||
# whitening: x_w = L^{-1} (x - mu) where cov = L L^T (Cholesky)
|
||||
L = np.linalg.cholesky(cov + 1e-9 * np.eye(cov.shape[0]))
|
||||
Linv = np.linalg.solve(L, np.eye(L.shape[0]))
|
||||
return mu, inv_cov, cov, Linv
|
||||
|
||||
|
||||
def mahal(S: np.ndarray, mu: np.ndarray, inv_cov: np.ndarray) -> np.ndarray:
|
||||
d = S - mu
|
||||
return np.einsum("ni,ij,nj->n", d, inv_cov, d)
|
||||
|
||||
|
||||
def plot_corr_heatmap() -> Path:
|
||||
fig, axes = plt.subplots(1, 4, figsize=(18, 4.6), constrained_layout=True)
|
||||
for ax, ds in zip(axes, DATASETS):
|
||||
val, _ = load_scores(ds)
|
||||
# Pearson correlation on benign val; mask diagonals to free visual budget
|
||||
C = np.corrcoef(val, rowvar=False)
|
||||
np.fill_diagonal(C, np.nan)
|
||||
im = ax.imshow(C, vmin=-1, vmax=1, cmap="RdBu_r")
|
||||
ax.set_xticks(range(len(SCORE_LABELS)))
|
||||
ax.set_yticks(range(len(SCORE_LABELS)))
|
||||
ax.set_xticklabels(SCORE_LABELS, rotation=60, ha="right", fontsize=7)
|
||||
ax.set_yticklabels(SCORE_LABELS, fontsize=7)
|
||||
K = len(SCORE_LABELS)
|
||||
off = C[~np.isnan(C)]
|
||||
ax.text(
|
||||
0.02, 1.06, f"{PRETTY[ds]} ⟨|ρ|⟩={np.abs(off).mean():.2f}",
|
||||
transform=ax.transAxes, fontsize=10,
|
||||
)
|
||||
cbar = fig.colorbar(im, ax=axes, shrink=0.85, location="right", pad=0.01)
|
||||
cbar.set_label("Pearson ρ on benign val", fontsize=10)
|
||||
out = OUT / "subscore_correlation_benign_val.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def _ellipse_from_2x2(mu2, cov2, n_sigma, **kw):
|
||||
vals, vecs = np.linalg.eigh(cov2)
|
||||
order = vals.argsort()[::-1]
|
||||
vals, vecs = vals[order], vecs[:, order]
|
||||
angle = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
|
||||
w, h = 2 * n_sigma * np.sqrt(vals)
|
||||
return Ellipse(xy=mu2, width=w, height=h, angle=angle, **kw)
|
||||
|
||||
|
||||
def plot_dual_head() -> Path:
|
||||
fig = plt.figure(figsize=(16, 8.5), constrained_layout=True)
|
||||
gs = fig.add_gridspec(2, 4)
|
||||
rng = np.random.default_rng(0)
|
||||
|
||||
for col, ds in enumerate(DATASETS):
|
||||
val, atk = load_scores(ds)
|
||||
# raw two-axes scatter: terminal_norm (idx 0) vs disc_nll_total (idx 3)
|
||||
x_v, y_v = val[:, 0], val[:, 3]
|
||||
x_a, y_a = atk[:, 0], atk[:, 3]
|
||||
# subsample for legibility
|
||||
nv = min(3000, len(x_v))
|
||||
na = min(3000, len(x_a))
|
||||
iv = rng.choice(len(x_v), nv, replace=False)
|
||||
ia = rng.choice(len(x_a), na, replace=False)
|
||||
|
||||
ax = fig.add_subplot(gs[0, col])
|
||||
ax.scatter(x_v[iv], y_v[iv], s=3, alpha=0.25, c="#2c7fb8", label="benign", rasterized=True)
|
||||
ax.scatter(x_a[ia], y_a[ia], s=3, alpha=0.18, c="#d7191c", label="attack", rasterized=True)
|
||||
# 2D OAS on these two cols only
|
||||
XY_v = val[:, [0, 3]]
|
||||
oas2 = OAS().fit(XY_v)
|
||||
mu2 = XY_v.mean(axis=0)
|
||||
for ns, ls in [(1, "-"), (2, "--"), (3, ":")]:
|
||||
e = _ellipse_from_2x2(
|
||||
mu2, oas2.covariance_, ns,
|
||||
edgecolor="black", facecolor="none", lw=1.1, ls=ls, alpha=0.85,
|
||||
)
|
||||
ax.add_patch(e)
|
||||
ax.set_xlabel(r"$t_{\mathrm{norm}}$ (continuous head)")
|
||||
if col == 0:
|
||||
ax.set_ylabel(r"$\mathcal{L}_{\mathrm{disc}}$ (discrete head)")
|
||||
ax.text(0.02, 1.03, PRETTY[ds], transform=ax.transAxes, fontsize=11)
|
||||
if col == 0:
|
||||
ax.legend(loc="upper right", fontsize=8, framealpha=0.85)
|
||||
# zoom to capture benign body + part of attack mass; use 99.5% of attack
|
||||
x_lo = min(np.quantile(x_v, 0.005), np.quantile(x_a, 0.005))
|
||||
x_hi = max(np.quantile(x_v, 0.995), np.quantile(x_a, 0.995))
|
||||
y_lo = min(np.quantile(y_v, 0.005), np.quantile(y_a, 0.005))
|
||||
y_hi = max(np.quantile(y_v, 0.995), np.quantile(y_a, 0.995))
|
||||
ax.set_xlim(x_lo - 0.05 * (x_hi - x_lo), x_hi + 0.05 * (x_hi - x_lo))
|
||||
ax.set_ylim(y_lo - 0.05 * (y_hi - y_lo), y_hi + 0.05 * (y_hi - y_lo))
|
||||
|
||||
# whitened-PCA panel
|
||||
ax2 = fig.add_subplot(gs[1, col])
|
||||
mu, inv_cov, cov, Linv = fit_oas(val)
|
||||
Wv = (val - mu) @ Linv.T
|
||||
Wa = (atk - mu) @ Linv.T
|
||||
# PCA on benign whitened (which is ~ identity covariance, but we still pick top-2 PCs
|
||||
# of the joint val+atk to maximize visual separation)
|
||||
# Use SVD on val_w to get axes; benign should be ~isotropic, so PCA will essentially
|
||||
# rotate; instead, use direction of maximum (atk - val) mean shift as PC1.
|
||||
delta = Wa.mean(axis=0) - Wv.mean(axis=0)
|
||||
delta_norm = np.linalg.norm(delta) + 1e-12
|
||||
u1 = delta / delta_norm
|
||||
# u2: top PC of attack-whitened residual orthogonal to u1
|
||||
Wa_res = Wa - Wa @ u1[:, None] * u1[None, :]
|
||||
_, _, Vt = np.linalg.svd(Wa_res - Wa_res.mean(axis=0), full_matrices=False)
|
||||
u2 = Vt[0]
|
||||
u2 = u2 - (u2 @ u1) * u1
|
||||
u2 /= np.linalg.norm(u2) + 1e-12
|
||||
Wv2 = np.c_[Wv @ u1, Wv @ u2]
|
||||
Wa2 = np.c_[Wa @ u1, Wa @ u2]
|
||||
nv2 = min(3000, len(Wv2))
|
||||
na2 = min(3000, len(Wa2))
|
||||
iv2 = rng.choice(len(Wv2), nv2, replace=False)
|
||||
ia2 = rng.choice(len(Wa2), na2, replace=False)
|
||||
ax2.scatter(Wv2[iv2, 0], Wv2[iv2, 1], s=3, alpha=0.25, c="#2c7fb8", rasterized=True)
|
||||
ax2.scatter(Wa2[ia2, 0], Wa2[ia2, 1], s=3, alpha=0.18, c="#d7191c", rasterized=True)
|
||||
# benign in whitened space ≈ N(0,I); draw unit-σ Mahalanobis circles
|
||||
for ns, ls in [(1, "-"), (2, "--"), (3, ":")]:
|
||||
ax2.add_patch(plt.Circle((0, 0), ns, fill=False, edgecolor="black", lw=1.1, ls=ls, alpha=0.85))
|
||||
ax2.set_xlabel("whitened PC1 (mean-shift dir)")
|
||||
if col == 0:
|
||||
ax2.set_ylabel("whitened PC2")
|
||||
# symlog axes so benign unit-ball is visible alongside far-field attack
|
||||
# linthresh = 3 covers the 3σ Mahalanobis circles linearly
|
||||
ax2.set_xscale("symlog", linthresh=3)
|
||||
ax2.set_yscale("symlog", linthresh=3)
|
||||
# set a generous range that shows the unit circles AND the attack mass
|
||||
x_max = max(np.quantile(np.abs(Wa2[:, 0]), 0.995), 5)
|
||||
y_max = max(np.quantile(np.abs(Wa2[:, 1]), 0.995), 5)
|
||||
ax2.set_xlim(-x_max * 1.1, x_max * 1.1)
|
||||
ax2.set_ylim(-y_max * 1.1, y_max * 1.1)
|
||||
ax2.axhline(0, color="0.7", lw=0.6, zorder=0)
|
||||
ax2.axvline(0, color="0.7", lw=0.6, zorder=0)
|
||||
# add Mahalanobis AUROC for reference
|
||||
m_v = mahal(val, mu, inv_cov)
|
||||
m_a = mahal(atk, mu, inv_cov)
|
||||
y = np.r_[np.zeros(len(m_v)), np.ones(len(m_a))]
|
||||
s = np.r_[m_v, m_a]
|
||||
auc = roc_auc_score(y, s)
|
||||
ax2.text(
|
||||
0.02, 0.97, f"AUROC(mahal-OAS)={auc:.4f}",
|
||||
transform=ax2.transAxes, ha="left", va="top",
|
||||
fontsize=9, bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="0.5", alpha=0.9),
|
||||
)
|
||||
|
||||
out = OUT / "dual_head_oas_ellipses_top__whitened_pca_bottom.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def plot_score_hist() -> Path:
|
||||
fig, axes = plt.subplots(4, 4, figsize=(16, 12), constrained_layout=True)
|
||||
for col, ds in enumerate(DATASETS):
|
||||
val, atk = load_scores(ds)
|
||||
mu, inv_cov, _, _ = fit_oas(val)
|
||||
|
||||
# Row 0: raw terminal_norm (linear)
|
||||
sv, sa = val[:, 0], atk[:, 0]
|
||||
_hist_panel(axes[0, col], sv, sa, log_x=False)
|
||||
|
||||
# Row 1: OAS-Mahal terminal3 (log)
|
||||
idx_t3 = [SCORE_KEYS.index(k) for k in ["terminal_norm", "terminal_flow", "terminal_packet"]]
|
||||
mu_s = val[:, idx_t3].mean(axis=0)
|
||||
oas_s = OAS().fit(val[:, idx_t3])
|
||||
iv_s = np.linalg.inv(oas_s.covariance_ + 1e-9 * np.eye(len(idx_t3)))
|
||||
sv = mahal(val[:, idx_t3], mu_s, iv_s)
|
||||
sa = mahal(atk[:, idx_t3], mu_s, iv_s)
|
||||
_hist_panel(axes[1, col], sv, sa, log_x=True)
|
||||
|
||||
# Row 2: OAS-Mahal disc7 (log)
|
||||
idx_d7 = [SCORE_KEYS.index(k) for k in [
|
||||
"disc_nll_total", "disc_nll_ch2", "disc_nll_ch3",
|
||||
"disc_nll_ch4", "disc_nll_ch5", "disc_nll_ch6", "disc_nll_ch7"]]
|
||||
mu_s = val[:, idx_d7].mean(axis=0)
|
||||
oas_s = OAS().fit(val[:, idx_d7])
|
||||
iv_s = np.linalg.inv(oas_s.covariance_ + 1e-9 * np.eye(len(idx_d7)))
|
||||
sv = mahal(val[:, idx_d7], mu_s, iv_s)
|
||||
sa = mahal(atk[:, idx_d7], mu_s, iv_s)
|
||||
_hist_panel(axes[2, col], sv, sa, log_x=True)
|
||||
|
||||
# Row 3: OAS-Mahal all 10 (log)
|
||||
sv = mahal(val, mu, inv_cov)
|
||||
sa = mahal(atk, mu, inv_cov)
|
||||
_hist_panel(axes[3, col], sv, sa, log_x=True)
|
||||
|
||||
axes[0, col].text(0.02, 1.04, PRETTY[ds], transform=axes[0, col].transAxes, fontsize=11)
|
||||
|
||||
# row labels
|
||||
row_labels = [
|
||||
"raw terminal_norm",
|
||||
"OAS Mahal: term3 (CFM head)",
|
||||
"OAS Mahal: disc7 (discrete head)",
|
||||
"OAS Mahal: all 10 (deployed)",
|
||||
]
|
||||
for r, lbl in enumerate(row_labels):
|
||||
axes[r, 0].set_ylabel(lbl, fontsize=10)
|
||||
axes[0, 3].legend(loc="upper right", fontsize=8, framealpha=0.85)
|
||||
out = OUT / "score_distributions_raw__termOAS__discOAS__allOAS.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def _hist_panel(ax, sv, sa, log_x: bool = False):
|
||||
y = np.r_[np.zeros(len(sv)), np.ones(len(sa))]
|
||||
s = np.r_[sv, sa]
|
||||
auc = roc_auc_score(y, s)
|
||||
# Use within-class fraction weighting so heights stay comparable when bin
|
||||
# widths are uneven (geomspace on log-x) — density=True compresses right-tail
|
||||
# mass invisibly because density-per-linear-unit collapses at high x.
|
||||
w_v = np.full_like(sv, 1.0 / len(sv))
|
||||
w_a = np.full_like(sa, 1.0 / len(sa))
|
||||
if log_x:
|
||||
eps = max(1e-3, np.quantile(s[s > 0], 0.001) * 0.5) if (s > 0).any() else 1e-3
|
||||
sv_p = np.maximum(sv, eps)
|
||||
sa_p = np.maximum(sa, eps)
|
||||
lo = np.quantile(np.r_[sv_p, sa_p], 0.001)
|
||||
hi = max(sv_p.max(), sa_p.max()) # show full right tail
|
||||
bins = np.geomspace(max(lo, eps), hi, 80)
|
||||
ax.hist(sv_p, bins=bins, color="#2c7fb8", alpha=0.55, label="benign", weights=w_v)
|
||||
ax.hist(sa_p, bins=bins, color="#d7191c", alpha=0.55, label="attack", weights=w_a)
|
||||
ax.set_xscale("log")
|
||||
else:
|
||||
lo, hi = np.quantile(s, [0.001, 0.999])
|
||||
bins = np.linspace(lo, hi, 80)
|
||||
ax.hist(np.clip(sv, lo, hi), bins=bins, color="#2c7fb8", alpha=0.55, label="benign", weights=w_v)
|
||||
ax.hist(np.clip(sa, lo, hi), bins=bins, color="#d7191c", alpha=0.55, label="attack", weights=w_a)
|
||||
ax.text(
|
||||
0.97, 0.95, f"AUROC={auc:.4f}",
|
||||
transform=ax.transAxes, ha="right", va="top", fontsize=8,
|
||||
bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="0.5", alpha=0.9),
|
||||
)
|
||||
ax.set_yticks([])
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--which", choices=["all", "corr", "dual", "hist"], default="all")
|
||||
args = parser.parse_args()
|
||||
OUT.mkdir(parents=True, exist_ok=True)
|
||||
mpl.rcParams.update({
|
||||
"font.size": 10,
|
||||
"axes.titlesize": 11,
|
||||
"axes.labelsize": 10,
|
||||
"pdf.fonttype": 42,
|
||||
"ps.fonttype": 42,
|
||||
})
|
||||
if args.which in ("all", "corr"):
|
||||
p = plot_corr_heatmap()
|
||||
print(f"[wrote] {p}")
|
||||
if args.which in ("all", "dual"):
|
||||
p = plot_dual_head()
|
||||
print(f"[wrote] {p}")
|
||||
if args.which in ("all", "hist"):
|
||||
p = plot_score_hist()
|
||||
print(f"[wrote] {p}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
135
scripts/figures/plot_trajectory.py
Normal file
135
scripts/figures/plot_trajectory.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Plots 4 (CFM trajectory in 2D PCA) and 5 (velocity-norm vs t).
|
||||
|
||||
Reads npz produced by run_trajectory_inference.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mpl
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
OUT = ROOT / "artifacts" / "janus_mechanism_figures_2026_05_08"
|
||||
|
||||
PRETTY = {
|
||||
"cicids2017": "CICIDS2017",
|
||||
"cicddos2019": "CICDDoS2019",
|
||||
"iscxtor2016": "ISCXTor2016",
|
||||
"ciciot2023": "CICIoT2023",
|
||||
}
|
||||
|
||||
|
||||
def _pca_fit_project(benign_t1: np.ndarray, benign_traj: np.ndarray, attack_traj: np.ndarray):
|
||||
"""benign_t1 [n, D] is data anchor; trajectories [n, S, D] each."""
|
||||
pca = PCA(n_components=2).fit(benign_t1)
|
||||
bt = pca.transform(benign_traj.reshape(-1, benign_traj.shape[-1])).reshape(benign_traj.shape[0], benign_traj.shape[1], 2)
|
||||
at = pca.transform(attack_traj.reshape(-1, attack_traj.shape[-1])).reshape(attack_traj.shape[0], attack_traj.shape[1], 2)
|
||||
return pca, bt, at
|
||||
|
||||
|
||||
def _draw_traj(ax, traj_v, traj_a, title, n_show=80):
|
||||
rng = np.random.default_rng(0)
|
||||
iv = rng.choice(traj_v.shape[0], min(n_show, traj_v.shape[0]), replace=False)
|
||||
ia = rng.choice(traj_a.shape[0], min(n_show, traj_a.shape[0]), replace=False)
|
||||
# trajectories: thin alpha lines
|
||||
for i in iv:
|
||||
ax.plot(traj_v[i, :, 0], traj_v[i, :, 1], color="#2c7fb8", alpha=0.18, lw=0.6, zorder=1)
|
||||
for i in ia:
|
||||
ax.plot(traj_a[i, :, 0], traj_a[i, :, 1], color="#d7191c", alpha=0.18, lw=0.6, zorder=1)
|
||||
# endpoints
|
||||
ax.scatter(traj_v[iv, 0, 0], traj_v[iv, 0, 1], s=14, c="#2c7fb8", marker="o",
|
||||
edgecolors="white", linewidths=0.4, zorder=3, label="benign t=1 (data)")
|
||||
ax.scatter(traj_v[iv, -1, 0], traj_v[iv, -1, 1], s=14, c="#2c7fb8", marker="x",
|
||||
linewidths=0.9, alpha=0.85, zorder=3, label="benign t=0 (post-flow)")
|
||||
ax.scatter(traj_a[ia, 0, 0], traj_a[ia, 0, 1], s=14, c="#d7191c", marker="o",
|
||||
edgecolors="white", linewidths=0.4, zorder=3, label="attack t=1 (data)")
|
||||
ax.scatter(traj_a[ia, -1, 0], traj_a[ia, -1, 1], s=14, c="#d7191c", marker="x",
|
||||
linewidths=0.9, alpha=0.85, zorder=3, label="attack t=0 (post-flow)")
|
||||
# unit circle (target N(0,I) for benign post-flow if flow learned correctly)
|
||||
theta = np.linspace(0, 2 * np.pi, 120)
|
||||
for ns, ls in [(1, "-"), (2, "--")]:
|
||||
ax.plot(ns * np.cos(theta), ns * np.sin(theta), color="black", lw=0.8, ls=ls, alpha=0.5, zorder=2)
|
||||
ax.axhline(0, color="0.85", lw=0.5, zorder=0)
|
||||
ax.axvline(0, color="0.85", lw=0.5, zorder=0)
|
||||
ax.set_title(title, fontsize=11)
|
||||
ax.set_aspect("equal", adjustable="datalim")
|
||||
|
||||
|
||||
def plot_trajectory(npz_paths: dict[str, Path]) -> Path:
|
||||
fig, axes = plt.subplots(2, len(npz_paths), figsize=(7.5 * len(npz_paths), 12), constrained_layout=True)
|
||||
if len(npz_paths) == 1:
|
||||
axes = axes[:, None]
|
||||
for col, (ds, npz) in enumerate(npz_paths.items()):
|
||||
z = np.load(npz)
|
||||
# FLOW-token trajectory
|
||||
ftv = z["z_traj_flow_v"] # [n, S, D]
|
||||
fta = z["z_traj_flow_a"]
|
||||
_, bt, at = _pca_fit_project(ftv[:, 0], ftv, fta)
|
||||
_draw_traj(axes[0, col], bt, at, f"{PRETTY[ds]} — FLOW token (PC1 vs PC2 of benign t=1)")
|
||||
# mean-packet trajectory
|
||||
ptv = z["z_traj_pkt_v"]
|
||||
pta = z["z_traj_pkt_a"]
|
||||
_, bt2, at2 = _pca_fit_project(ptv[:, 0], ptv, pta)
|
||||
_draw_traj(axes[1, col], bt2, at2, f"{PRETTY[ds]} — mean packet token")
|
||||
|
||||
axes[0, 0].legend(loc="upper right", fontsize=7, framealpha=0.85)
|
||||
fig.suptitle(
|
||||
"Reverse CFM flow (t=1 → t=0): benign collapses toward learned source (≈ N(0,I) inside dashed circles); "
|
||||
"attack endpoints land off-distribution",
|
||||
fontsize=12,
|
||||
)
|
||||
out = OUT / "fig4_trajectory_pca.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def plot_velocity_norm(npz_paths: dict[str, Path]) -> Path:
|
||||
fig, axes = plt.subplots(1, len(npz_paths), figsize=(6.5 * len(npz_paths), 5.6), constrained_layout=True)
|
||||
if len(npz_paths) == 1:
|
||||
axes = [axes]
|
||||
for ax, (ds, npz) in zip(axes, npz_paths.items()):
|
||||
z = np.load(npz)
|
||||
vn_v = z["vnorm_v"] # [n, n_steps]
|
||||
vn_a = z["vnorm_a"]
|
||||
# t at integration step k corresponds to t = 1 - k*dt, k=0..n_steps-1
|
||||
n_steps = vn_v.shape[1]
|
||||
t_steps = 1.0 - np.arange(n_steps) / n_steps
|
||||
# mean ± std band
|
||||
m_v, s_v = vn_v.mean(0), vn_v.std(0)
|
||||
m_a, s_a = vn_a.mean(0), vn_a.std(0)
|
||||
ax.plot(t_steps, m_v, color="#2c7fb8", lw=1.6, label="benign mean")
|
||||
ax.fill_between(t_steps, m_v - s_v, m_v + s_v, color="#2c7fb8", alpha=0.18)
|
||||
ax.plot(t_steps, m_a, color="#d7191c", lw=1.6, label="attack mean")
|
||||
ax.fill_between(t_steps, m_a - s_a, m_a + s_a, color="#d7191c", alpha=0.18)
|
||||
ax.set_xlabel("CFM time t (1 = data → 0 = source)")
|
||||
ax.set_ylabel("‖v(x_t, t)‖ per real token (mean over flow)")
|
||||
ax.text(0.02, 1.02, PRETTY[ds], transform=ax.transAxes, fontsize=11)
|
||||
ax.invert_xaxis() # so left is t=1 (data), right is t=0 (source)
|
||||
ax.legend(fontsize=8, loc="upper left", framealpha=0.85)
|
||||
out = OUT / "velocity_norm_vs_t_benign_vs_attack.pdf"
|
||||
fig.savefig(out, bbox_inches="tight")
|
||||
fig.savefig(out.with_suffix(".png"), bbox_inches="tight", dpi=160)
|
||||
plt.close(fig)
|
||||
return out
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--datasets", nargs="+",
|
||||
default=["cicids2017", "cicddos2019", "iscxtor2016", "ciciot2023"])
|
||||
args = parser.parse_args()
|
||||
OUT.mkdir(parents=True, exist_ok=True)
|
||||
mpl.rcParams.update({"font.size": 10, "pdf.fonttype": 42, "ps.fonttype": 42})
|
||||
npz_paths = {ds: OUT / f"traj_{ds}.npz" for ds in args.datasets}
|
||||
p4 = plot_trajectory(npz_paths)
|
||||
print(f"[wrote] {p4}")
|
||||
p5 = plot_velocity_norm(npz_paths)
|
||||
print(f"[wrote] {p5}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
266
scripts/figures/run_field_view.py
Normal file
266
scripts/figures/run_field_view.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Generate Unified-style 3-panel field-view data for Mixed_CFM:
|
||||
(1) velocity field at t=0.5 on benign-FLOW-token PCA grid (with streamlines)
|
||||
(2) reverse-flow trajectories t=1 → t=0
|
||||
(3) forward-flow trajectories t=0 → t=1 (sample from N(0,I) noise back to data)
|
||||
|
||||
All operations run in token space; the visual projection is 2D PCA fit on benign
|
||||
FLOW token at t=0.5 (so interpolation states align with the source-side flow).
|
||||
|
||||
Output npz keys:
|
||||
pca_components [2, token_dim] PCA basis (benign FLOW @ t=0.5)
|
||||
pca_mean [token_dim]
|
||||
pca_explained_var [2]
|
||||
benign_t1_2d [N, 2] benign FLOW token at t=1 in PCA coords
|
||||
benign_t05_2d [N, 2] benign FLOW token at t=0.5 (the PCA fit basis)
|
||||
benign_t0_2d [N, 2] benign FLOW token at t=0 (random N(0,I))
|
||||
reverse_v_2d [Nrv, S+1, 2] benign reverse trajectory in PCA coords
|
||||
reverse_a_2d [Nra, S+1, 2] attack reverse trajectory
|
||||
forward_v_2d [Nfv, S+1, 2] forward trajectory from noise → benign-cond-template
|
||||
grid_x, grid_y [G, G] field grid coords in PCA space
|
||||
field_v_2d [G, G, 2] velocity vectors in PCA coords at t=0.5
|
||||
field_log_norm [G, G] log10 ‖v(x_t,t)‖ at full token scale
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
sys.path.insert(0, str(ROOT / "Mixed_CFM"))
|
||||
from data import load_mixed_data # noqa: E402
|
||||
from model import MixedCFMConfig, MixedTokenCFM # noqa: E402
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def integrate_reverse(model: MixedTokenCFM, z1: torch.Tensor, lens: torch.Tensor, *, n_steps: int) -> torch.Tensor:
|
||||
"""z1 already at t=1 (data). Returns snapshots [B, n_steps+1, L, D] from t=1 → t=0."""
|
||||
z = z1.clone()
|
||||
mask = model._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
B = z.shape[0]
|
||||
dt = 1.0 / n_steps
|
||||
cfg = model.cfg
|
||||
disc_start = 1 + cfg.n_cont_pkt
|
||||
disc_end = disc_start + cfg.n_disc_pkt
|
||||
disc_embed = z[:, 1:, disc_start:disc_end].clone()
|
||||
snaps = [z.clone()]
|
||||
for k in range(n_steps):
|
||||
t_val = 1.0 - k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
v, _ = model.velocity(z, t, key_padding_mask=kpm)
|
||||
v[:, :, disc_start:disc_end] = 0.0
|
||||
z = z - v * dt
|
||||
z[:, 1:, disc_start:disc_end] = disc_embed
|
||||
snaps.append(z.clone())
|
||||
return torch.stack(snaps, dim=1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def integrate_forward(model: MixedTokenCFM, z0: torch.Tensor, lens: torch.Tensor, disc_embed: torch.Tensor, *, n_steps: int) -> torch.Tensor:
|
||||
"""Forward Euler from t=0 (Gaussian noise) → t=1 (generated). Holds disc embed fixed."""
|
||||
z = z0.clone()
|
||||
mask = model._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
B = z.shape[0]
|
||||
dt = 1.0 / n_steps
|
||||
cfg = model.cfg
|
||||
disc_start = 1 + cfg.n_cont_pkt
|
||||
disc_end = disc_start + cfg.n_disc_pkt
|
||||
snaps = [z.clone()]
|
||||
for k in range(n_steps):
|
||||
t_val = k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
z[:, 1:, disc_start:disc_end] = disc_embed
|
||||
v, _ = model.velocity(z, t, key_padding_mask=kpm)
|
||||
v[:, :, disc_start:disc_end] = 0.0
|
||||
z = z + v * dt
|
||||
snaps.append(z.clone())
|
||||
return torch.stack(snaps, dim=1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def velocity_at(model: MixedTokenCFM, z: torch.Tensor, lens: torch.Tensor, t_val: float) -> torch.Tensor:
|
||||
"""Return velocity v at given t. Mirrors signature of model.velocity but masks disc-channel rows in v."""
|
||||
mask = model._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
B = z.shape[0]
|
||||
t = torch.full((B,), float(t_val), device=z.device)
|
||||
v, _ = model.velocity(z, t, key_padding_mask=kpm)
|
||||
cfg = model.cfg
|
||||
disc_start = 1 + cfg.n_cont_pkt
|
||||
disc_end = disc_start + cfg.n_disc_pkt
|
||||
v[:, :, disc_start:disc_end] = 0.0
|
||||
return v
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--model-dir", type=Path, required=True)
|
||||
p.add_argument("--out", type=Path, required=True)
|
||||
p.add_argument("--n-pca-benign", type=int, default=1500)
|
||||
p.add_argument("--n-reverse-benign", type=int, default=30)
|
||||
p.add_argument("--n-reverse-attack", type=int, default=30)
|
||||
p.add_argument("--n-forward", type=int, default=200)
|
||||
p.add_argument("--grid", type=int, default=40)
|
||||
p.add_argument("--grid-templates", type=int, default=8, help="Number of benign templates to average velocity over per grid point")
|
||||
p.add_argument("--n-steps", type=int, default=32)
|
||||
p.add_argument("--device", type=str, default="auto")
|
||||
p.add_argument("--batch-size", type=int, default=128)
|
||||
args = p.parse_args()
|
||||
|
||||
device = torch.device("cuda" if (args.device == "auto" and torch.cuda.is_available()) else (args.device if args.device != "auto" else "cpu"))
|
||||
cfg = yaml.safe_load((args.model_dir / "config.yaml").read_text())
|
||||
ckpt = torch.load(args.model_dir / "model.pt", map_location="cpu", weights_only=False)
|
||||
model_cfg = MixedCFMConfig(**ckpt["model_cfg"])
|
||||
model = MixedTokenCFM(model_cfg).to(device)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
data = load_mixed_data(
|
||||
packets_npz=Path(cfg["packets_npz"]) if cfg.get("packets_npz") else None,
|
||||
source_store=Path(cfg["source_store"]) if cfg.get("source_store") else None,
|
||||
flows_parquet=Path(cfg["flows_parquet"]),
|
||||
flow_features_path=Path(cfg["flow_features_path"]),
|
||||
flow_features_align=str(cfg.get("flow_features_align", "auto")),
|
||||
T=int(cfg["T"]),
|
||||
split_seed=int(cfg.get("data_seed", cfg.get("seed", 42))),
|
||||
train_ratio=float(cfg.get("train_ratio", 0.8)),
|
||||
benign_label=str(cfg.get("benign_label", "normal")),
|
||||
min_len=int(cfg.get("min_len", 2)),
|
||||
attack_cap=int(cfg["attack_cap"]) if cfg.get("attack_cap") else None,
|
||||
val_cap=int(cfg["val_cap"]) if cfg.get("val_cap") else None,
|
||||
)
|
||||
print(f"[data] val={len(data.val_flow):,} attack={len(data.attack_flow):,}")
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
nv = min(args.n_pca_benign, len(data.val_flow))
|
||||
iv_pca = np.sort(rng.choice(len(data.val_flow), nv, replace=False))
|
||||
irv = np.sort(rng.choice(len(data.val_flow), args.n_reverse_benign, replace=False))
|
||||
ira = np.sort(rng.choice(len(data.attack_flow), args.n_reverse_attack, replace=False))
|
||||
ifw_template = np.sort(rng.choice(len(data.val_flow), args.n_forward, replace=False))
|
||||
|
||||
def build_z(flow, cont, disc):
|
||||
flow_t = torch.from_numpy(flow).float().to(device)
|
||||
cont_t = torch.from_numpy(cont).float().to(device)
|
||||
disc_t = torch.from_numpy(disc).long().to(device)
|
||||
return model.build_tokens(flow_t, cont_t, disc_t)
|
||||
|
||||
# ==== Step 1: build z1 for benign PCA pool, get FLOW token at t=0.5 ====
|
||||
z1_pca = build_z(data.val_flow[iv_pca], data.val_cont[iv_pca], data.val_disc[iv_pca]) # [N, L, D]
|
||||
lens_pca = torch.from_numpy(data.val_len[iv_pca]).long().to(device)
|
||||
flow_t1 = z1_pca[:, 0, :].cpu().numpy() # [N, D] — FLOW token at t=1
|
||||
sigma = float(model_cfg.sigma)
|
||||
z0_for_pca = torch.randn_like(z1_pca)
|
||||
t_val = 0.5
|
||||
z_t05 = (1 - t_val) * z0_for_pca + t_val * z1_pca
|
||||
if sigma > 0:
|
||||
std = sigma * np.sqrt(t_val * (1 - t_val))
|
||||
z_t05 = z_t05 + std * torch.randn_like(z_t05)
|
||||
flow_t05 = z_t05[:, 0, :].cpu().numpy() # [N, D]
|
||||
flow_t0 = z0_for_pca[:, 0, :].cpu().numpy()
|
||||
|
||||
pca = PCA(n_components=2).fit(flow_t05)
|
||||
print(f"[pca] explained var on benign FLOW @ t=0.5: {pca.explained_variance_ratio_}")
|
||||
|
||||
benign_t1_2d = pca.transform(flow_t1)
|
||||
benign_t05_2d = pca.transform(flow_t05)
|
||||
benign_t0_2d = pca.transform(flow_t0)
|
||||
|
||||
# ==== Step 2: reverse trajectories ====
|
||||
print("[run] reverse benign")
|
||||
z1_rv = build_z(data.val_flow[irv], data.val_cont[irv], data.val_disc[irv])
|
||||
lens_rv = torch.from_numpy(data.val_len[irv]).long().to(device)
|
||||
snaps_rv = integrate_reverse(model, z1_rv, lens_rv, n_steps=args.n_steps) # [B, S+1, L, D]
|
||||
rv_2d = pca.transform(snaps_rv[:, :, 0, :].reshape(-1, snaps_rv.shape[-1]).cpu().numpy()).reshape(snaps_rv.shape[0], snaps_rv.shape[1], 2)
|
||||
|
||||
print("[run] reverse attack")
|
||||
z1_ra = build_z(data.attack_flow[ira], data.attack_cont[ira], data.attack_disc[ira])
|
||||
lens_ra = torch.from_numpy(data.attack_len[ira]).long().to(device)
|
||||
snaps_ra = integrate_reverse(model, z1_ra, lens_ra, n_steps=args.n_steps)
|
||||
ra_2d = pca.transform(snaps_ra[:, :, 0, :].reshape(-1, snaps_ra.shape[-1]).cpu().numpy()).reshape(snaps_ra.shape[0], snaps_ra.shape[1], 2)
|
||||
|
||||
# ==== Step 3: forward trajectories (sample noise, integrate to t=1, hold disc embed from a benign template) ====
|
||||
print("[run] forward from noise")
|
||||
z1_fw = build_z(data.val_flow[ifw_template], data.val_cont[ifw_template], data.val_disc[ifw_template])
|
||||
lens_fw = torch.from_numpy(data.val_len[ifw_template]).long().to(device)
|
||||
cfg_m = model.cfg
|
||||
disc_start = 1 + cfg_m.n_cont_pkt
|
||||
disc_end = disc_start + cfg_m.n_disc_pkt
|
||||
disc_embed_fw = z1_fw[:, 1:, disc_start:disc_end].clone()
|
||||
z0_fw = torch.randn_like(z1_fw)
|
||||
snaps_fw = integrate_forward(model, z0_fw, lens_fw, disc_embed_fw, n_steps=args.n_steps)
|
||||
fw_2d = pca.transform(snaps_fw[:, :, 0, :].reshape(-1, snaps_fw.shape[-1]).cpu().numpy()).reshape(snaps_fw.shape[0], snaps_fw.shape[1], 2)
|
||||
|
||||
# ==== Step 4: velocity field on PCA grid at t=0.5 ====
|
||||
print("[run] velocity field grid")
|
||||
pad = 0.3
|
||||
x_min, x_max = benign_t05_2d[:, 0].min(), benign_t05_2d[:, 0].max()
|
||||
y_min, y_max = benign_t05_2d[:, 1].min(), benign_t05_2d[:, 1].max()
|
||||
sx = x_max - x_min
|
||||
sy = y_max - y_min
|
||||
x_lo, x_hi = x_min - pad * sx, x_max + pad * sx
|
||||
y_lo, y_hi = y_min - pad * sy, y_max + pad * sy
|
||||
gx = np.linspace(x_lo, x_hi, args.grid)
|
||||
gy = np.linspace(y_lo, y_hi, args.grid)
|
||||
GX, GY = np.meshgrid(gx, gy)
|
||||
grid_2d = np.stack([GX.ravel(), GY.ravel()], axis=1) # [G^2, 2]
|
||||
grid_full = pca.inverse_transform(grid_2d) # [G^2, D]
|
||||
grid_full_t = torch.from_numpy(grid_full).float().to(device)
|
||||
|
||||
# For each grid point, replace FLOW token at t=0.5 of K random benign templates;
|
||||
# average velocity at FLOW position over templates.
|
||||
K = args.grid_templates
|
||||
# Sample K template z_t05 from PCA pool
|
||||
template_idx = rng.choice(len(z1_pca), K, replace=False)
|
||||
z1_tpl = z1_pca[template_idx] # [K, L, D]
|
||||
lens_tpl = lens_pca[template_idx]
|
||||
z0_tpl = torch.randn_like(z1_tpl)
|
||||
z_t05_tpl = (1 - t_val) * z0_tpl + t_val * z1_tpl
|
||||
if sigma > 0:
|
||||
z_t05_tpl = z_t05_tpl + (sigma * np.sqrt(t_val * (1 - t_val))) * torch.randn_like(z_t05_tpl)
|
||||
|
||||
G2 = grid_full_t.shape[0]
|
||||
v_grid_full = torch.zeros((G2, grid_full_t.shape[1]), device=device)
|
||||
bs = args.batch_size
|
||||
for k in range(K):
|
||||
# build a [G^2, L, D] tensor where token 0 = grid point, tokens 1: = template's z_t05 packets
|
||||
tpl_packets = z_t05_tpl[k:k + 1, 1:, :].expand(G2, -1, -1).contiguous()
|
||||
z_grid = torch.cat([grid_full_t.unsqueeze(1), tpl_packets], dim=1) # [G^2, L, D]
|
||||
lens_grid = lens_tpl[k:k + 1].expand(G2).contiguous()
|
||||
# batched velocity eval
|
||||
v_chunks = []
|
||||
for s in range(0, G2, bs):
|
||||
v_chunks.append(velocity_at(model, z_grid[s:s + bs], lens_grid[s:s + bs], t_val=t_val)[:, 0, :])
|
||||
v_grid_full = v_grid_full + torch.cat(v_chunks, dim=0)
|
||||
v_grid_full = v_grid_full / K
|
||||
v_grid_np = v_grid_full.cpu().numpy()
|
||||
v_grid_2d = (v_grid_np - 0) @ pca.components_.T # project [G^2, D] onto 2 PCA basis vectors → [G^2, 2]
|
||||
v_grid_2d = v_grid_2d.reshape(args.grid, args.grid, 2)
|
||||
log_norm_full = np.log10(np.linalg.norm(v_grid_np, axis=-1) + 1e-9).reshape(args.grid, args.grid)
|
||||
|
||||
args.out.parent.mkdir(parents=True, exist_ok=True)
|
||||
np.savez(
|
||||
args.out,
|
||||
pca_components=pca.components_.astype(np.float32),
|
||||
pca_mean=pca.mean_.astype(np.float32),
|
||||
pca_explained_var=pca.explained_variance_ratio_.astype(np.float32),
|
||||
benign_t1_2d=benign_t1_2d.astype(np.float32),
|
||||
benign_t05_2d=benign_t05_2d.astype(np.float32),
|
||||
benign_t0_2d=benign_t0_2d.astype(np.float32),
|
||||
reverse_v_2d=rv_2d.astype(np.float32),
|
||||
reverse_a_2d=ra_2d.astype(np.float32),
|
||||
forward_v_2d=fw_2d.astype(np.float32),
|
||||
grid_x=GX.astype(np.float32),
|
||||
grid_y=GY.astype(np.float32),
|
||||
field_v_2d=v_grid_2d.astype(np.float32),
|
||||
field_log_norm=log_norm_full.astype(np.float32),
|
||||
)
|
||||
print(f"[wrote] {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
157
scripts/figures/run_trajectory_inference.py
Normal file
157
scripts/figures/run_trajectory_inference.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Reverse-flow ODE integration with per-step snapshots, for trajectory + velocity-norm plots.
|
||||
|
||||
For a chosen (dataset, seed), loads a small balanced sample (n_per_class benign+attack)
|
||||
and integrates the velocity field from t=1 (data) to t=0 (Gaussian) using the same Euler
|
||||
scheme as `MixedTokenCFM.trajectory_metrics`, but saves z snapshots and ‖v‖ at each step.
|
||||
|
||||
Outputs npz with keys:
|
||||
z_traj_flow_v / z_traj_flow_a [n, n_steps+1, token_dim] FLOW-token trajectories
|
||||
z_traj_pkt_v / z_traj_pkt_a [n, n_steps+1, token_dim] masked-mean packet-token trajectories
|
||||
vnorm_v / vnorm_a [n, n_steps] per-step velocity norm (real tokens only)
|
||||
t_grid [n_steps+1] t values; t_grid[0] = 1.0, decreasing
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
sys.path.insert(0, str(ROOT / "Mixed_CFM"))
|
||||
from data import load_mixed_data # noqa: E402
|
||||
from model import MixedCFMConfig, MixedTokenCFM # noqa: E402
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def run_reverse_flow(model: MixedTokenCFM, flow, cont, disc, lens, *, n_steps: int):
|
||||
z = model.build_tokens(flow, cont, disc) # at t=1
|
||||
mask = model._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
B = z.shape[0]
|
||||
dt = 1.0 / n_steps
|
||||
cfg = model.cfg
|
||||
disc_start = 1 + cfg.n_cont_pkt
|
||||
disc_end = disc_start + cfg.n_disc_pkt
|
||||
disc_embed = z[:, 1:, disc_start:disc_end].clone()
|
||||
|
||||
snaps = [z.clone()]
|
||||
vnorms = []
|
||||
for k in range(n_steps):
|
||||
t_val = 1.0 - k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
v, _ = model.velocity(z, t, key_padding_mask=kpm)
|
||||
v[:, :, disc_start:disc_end] = 0.0
|
||||
# per-sample velocity norm averaged over real tokens
|
||||
v_norm_per_tok = v.norm(dim=-1) # [B, L]
|
||||
per_sample = (v_norm_per_tok * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
vnorms.append(per_sample.cpu().numpy())
|
||||
z = z - v * dt
|
||||
z[:, 1:, disc_start:disc_end] = disc_embed
|
||||
snaps.append(z.clone())
|
||||
return snaps, np.stack(vnorms, axis=1), mask.cpu().numpy()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--model-dir", type=Path, required=True)
|
||||
p.add_argument("--out", type=Path, required=True)
|
||||
p.add_argument("--n-per-class", type=int, default=200)
|
||||
p.add_argument("--n-steps", type=int, default=32)
|
||||
p.add_argument("--device", type=str, default="auto")
|
||||
p.add_argument("--batch-size", type=int, default=128)
|
||||
args = p.parse_args()
|
||||
|
||||
device = torch.device("cuda" if (args.device == "auto" and torch.cuda.is_available()) else (args.device if args.device != "auto" else "cpu"))
|
||||
cfg = yaml.safe_load((args.model_dir / "config.yaml").read_text())
|
||||
ckpt = torch.load(args.model_dir / "model.pt", map_location="cpu", weights_only=False)
|
||||
model_cfg = MixedCFMConfig(**ckpt["model_cfg"])
|
||||
model = MixedTokenCFM(model_cfg).to(device)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
data = load_mixed_data(
|
||||
packets_npz=Path(cfg["packets_npz"]) if cfg.get("packets_npz") else None,
|
||||
source_store=Path(cfg["source_store"]) if cfg.get("source_store") else None,
|
||||
flows_parquet=Path(cfg["flows_parquet"]),
|
||||
flow_features_path=Path(cfg["flow_features_path"]),
|
||||
flow_features_align=str(cfg.get("flow_features_align", "auto")),
|
||||
T=int(cfg["T"]),
|
||||
split_seed=int(cfg.get("data_seed", cfg.get("seed", 42))),
|
||||
train_ratio=float(cfg.get("train_ratio", 0.8)),
|
||||
benign_label=str(cfg.get("benign_label", "normal")),
|
||||
min_len=int(cfg.get("min_len", 2)),
|
||||
attack_cap=int(cfg["attack_cap"]) if cfg.get("attack_cap") else None,
|
||||
val_cap=int(cfg["val_cap"]) if cfg.get("val_cap") else None,
|
||||
)
|
||||
print(f"[data] val={len(data.val_flow):,} attack={len(data.attack_flow):,}")
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
nv = min(args.n_per_class, len(data.val_flow))
|
||||
na = min(args.n_per_class, len(data.attack_flow))
|
||||
iv = np.sort(rng.choice(len(data.val_flow), nv, replace=False))
|
||||
ia = np.sort(rng.choice(len(data.attack_flow), na, replace=False))
|
||||
|
||||
def to_t(flow, cont, disc, lens):
|
||||
return (
|
||||
torch.from_numpy(flow).float().to(device),
|
||||
torch.from_numpy(cont).float().to(device),
|
||||
torch.from_numpy(disc).long().to(device),
|
||||
torch.from_numpy(lens).long().to(device),
|
||||
)
|
||||
|
||||
def collect(flows_np, conts_np, discs_np, lens_np):
|
||||
snaps_all, vnorms_all, mask_all = [], [], []
|
||||
for start in range(0, len(flows_np), args.batch_size):
|
||||
sl = slice(start, start + args.batch_size)
|
||||
flow, cont, disc, lens = to_t(flows_np[sl], conts_np[sl], discs_np[sl], lens_np[sl])
|
||||
snaps, vn, mask = run_reverse_flow(model, flow, cont, disc, lens, n_steps=args.n_steps)
|
||||
snaps_all.append(torch.stack(snaps, dim=1).cpu().numpy()) # [b, n_steps+1, L, D]
|
||||
vnorms_all.append(vn)
|
||||
mask_all.append(mask)
|
||||
print(f" [batch] {min(start + args.batch_size, len(flows_np))}/{len(flows_np)}", flush=True)
|
||||
return (np.concatenate(snaps_all, axis=0), np.concatenate(vnorms_all, axis=0), np.concatenate(mask_all, axis=0))
|
||||
|
||||
print("[run] benign val")
|
||||
t0 = time.time()
|
||||
snaps_v, vn_v, mask_v = collect(data.val_flow[iv], data.val_cont[iv], data.val_disc[iv], data.val_len[iv])
|
||||
print(f" done {time.time() - t0:.1f}s snaps={snaps_v.shape}")
|
||||
|
||||
print("[run] attack")
|
||||
t0 = time.time()
|
||||
snaps_a, vn_a, mask_a = collect(data.attack_flow[ia], data.attack_cont[ia], data.attack_disc[ia], data.attack_len[ia])
|
||||
print(f" done {time.time() - t0:.1f}s snaps={snaps_a.shape}")
|
||||
|
||||
# extract FLOW token trajectory and packet-mean trajectory
|
||||
def flow_and_pkt_traj(snaps, mask):
|
||||
# snaps [n, S, L, D], mask [n, L] (L = T+1, includes flow token at idx 0)
|
||||
flow_tok = snaps[:, :, 0, :] # [n, S, D]
|
||||
pkt_mask = mask[:, 1:][:, None, :, None].astype(np.float32)
|
||||
pkt_count = pkt_mask.sum(axis=2).clip(1.0)
|
||||
pkt_mean = (snaps[:, :, 1:, :] * pkt_mask).sum(axis=2) / pkt_count # [n, S, D]
|
||||
return flow_tok, pkt_mean
|
||||
|
||||
flow_tok_v, pkt_mean_v = flow_and_pkt_traj(snaps_v, mask_v)
|
||||
flow_tok_a, pkt_mean_a = flow_and_pkt_traj(snaps_a, mask_a)
|
||||
|
||||
n_steps = args.n_steps
|
||||
t_grid = np.array([1.0 - k * (1.0 / n_steps) for k in range(n_steps + 1)])
|
||||
|
||||
args.out.parent.mkdir(parents=True, exist_ok=True)
|
||||
np.savez(
|
||||
args.out,
|
||||
z_traj_flow_v=flow_tok_v.astype(np.float32),
|
||||
z_traj_flow_a=flow_tok_a.astype(np.float32),
|
||||
z_traj_pkt_v=pkt_mean_v.astype(np.float32),
|
||||
z_traj_pkt_a=pkt_mean_a.astype(np.float32),
|
||||
vnorm_v=vn_v.astype(np.float32),
|
||||
vnorm_a=vn_a.astype(np.float32),
|
||||
t_grid=t_grid.astype(np.float32),
|
||||
)
|
||||
print(f"[wrote] {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user