Initial commit: code, paper, small artifacts

This commit is contained in:
2026-05-07 20:47:30 +08:00
commit fae2db8cff
322 changed files with 33159 additions and 0 deletions

133
Unified_CFM/README.md Normal file
View File

@@ -0,0 +1,133 @@
# Unified_CFM
A single multi-scale OT-CFM over one token sequence per flow:
```text
[FLOW_TOKEN, PACKET_1, ..., PACKET_T]
```
This is **not** a Flow-CFM + Packet-CFM ensemble. Flow-level and packet-level
signals interact inside one Transformer velocity field, and a Phase 2
masked-prediction consistency loss explicitly trains the cross-modal
dependency.
This is the **current SOTA model** in the repo (within-dataset SOTA on
ISCXTor2016 / CICIDS2017 / CICDDoS2019; near-SOTA cross-dataset).
## Model
`UnifiedTokenCFM` uses fixed tokenization to avoid latent-collapse shortcuts:
```text
flow token: [type=-1, normalized 20-d canonical flow features, zero pad]
packet token: [type=+1, normalized 9-d packet features, zero pad]
```
Velocity field: 4-layer AdaLN-Zero Transformer (`d_model=128, n_heads=4`),
sinusoidal time embedding (`time_dim=64`). Total ≈ 1.23M parameters.
Loss with Phase 2 consistency:
```
L = L_main + λ_flow · L_mask_flow + λ_packet · L_mask_packet
L_main: standard OT-CFM velocity regression with σ-band noise +
Sinkhorn OT coupling.
L_mask_flow: zero out the flow token's input at x_t; predict v[flow]
from packet context only.
L_mask_packet: zero out a random 50% of real packet tokens at x_t;
predict their velocities from flow + remaining packets.
```
Best hyperparameters from the σ × λ sweeps:
```
lambda_flow = lambda_packet = 0.3
packet_mask_ratio = 0.5
sigma = 0.6 # cross-dataset best; σ=0.1 marginally better for some within
use_ot = True
```
## Scores
The model exposes three classes of scores at inference:
```text
# primary
terminal_norm
# decomposed (analysis only)
terminal_flow terminal_packet
arc_length kinetic_energy kinetic_flow kinetic_packet
velocity_total velocity_flow velocity_packet
# Phase 1 diagnostics
curvature_total curvature_flow curvature_packet # ∫ ||dv/dt||² dt
kappa2_speed2norm_packet_{mean,median,trimmed10_mean} # packet curvature / speed²
jacobian_total jacobian_flow jacobian_packet # Hutchinson VJP estimate of ||∂v/∂x||_F²
velocity_*_t{01..10} # 18 time-profile scores
# Phase 2 cross-modal consistency
flow_consistency packet_consistency consistency_total
```
`terminal_norm` is the paper's primary score. The decomposed and diagnostic
scores serve **per-attack-family analysis** — they are NOT competing
SOTA claims. Multi-seed std on `terminal_norm` is ≤ 0.005 across all our
runs.
The Phase 2 consistency scores have a notable property: they are
**discriminative only when the model is trained with the consistency loss**.
On a baseline model `flow_consistency` is roughly random (0.57 on
CICIDS2017); after Phase 2 training it lifts to 0.88. On SSH-Patator,
where standard density scores struggle (`terminal_norm` 0.64), Phase 2
`flow_consistency` reaches 0.94.
## Train
```bash
# baseline (no consistency loss)
uv run python Unified_CFM/train.py --config Unified_CFM/configs/cicids2017_baseline.yaml
# Phase 2 with consistency loss (λ=0.1, σ=0.1)
uv run python Unified_CFM/train.py --config Unified_CFM/configs/cicids2017_consistency.yaml
# σ × λ sweeps and multi-seed orchestrators live in
# artifacts/verify_2026_04_24/run_*.sh
```
The intended setup is to use the workspace-canonical 20-d packet-derived
flow feature file:
```yaml
flow_features_path: datasets/cicids2017/processed/flow_features.parquet
flow_features_align: auto
```
`flow_features.parquet` is row-aligned with the Packet_CFM artifacts via
`flow_id`. With `flow_features_align: auto`, the loader uses direct
row/`flow_id` alignment when possible; scan alignment remains only for
legacy full CSV-derived caches.
For large datasets where a monolithic `packets.npz` would exceed memory,
the loader supports the sharded backend:
```yaml
source_store: datasets/cicddos2019/processed/full_store
val_cap: 20000
attack_cap: 20000
```
If `flow_features_path` is empty, the loader derives compact 16-d flow-level
statistics from the packet sequence. That fallback is for debugging only;
new runs should use the canonical 20-d file generated by
`scripts/generate_flow_features.py`.
## Evaluation
`artifacts/verify_2026_04_24/eval_phase1_unified.py` runs Phase 1 + Phase 2
score battery on a trained checkpoint, with per-attack-class AUROC.
`artifacts/verify_2026_04_24/eval_phase2_cross_cicddos2019.py` runs
cross-dataset CICIDS2017→CICDDoS2019 evaluation under the standard
10k benign + 10k stratified attack protocol.

1
Unified_CFM/__init__.py Normal file
View File

@@ -0,0 +1 @@
pass

View File

@@ -0,0 +1,45 @@
save_dir: /home/chy/JANUS/artifacts/phaseC_reference_2026_04_25/cicddos2019_ref_blockdiag_seed42
source_store: /home/chy/JANUS/datasets/cicddos2019/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/cicddos2019/processed/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/cicddos2019/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
val_cap: 20000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
reference_mode: block_diagonal
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.0
lambda_packet: 0.0
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,45 @@
save_dir: /home/chy/JANUS/artifacts/phaseC_reference_2026_04_25/cicddos2019_ref_independent_seed42
source_store: /home/chy/JANUS/datasets/cicddos2019/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/cicddos2019/processed/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/cicddos2019/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
val_cap: 20000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
reference_mode: independent_token
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 10000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.0
lambda_packet: 0.0
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,41 @@
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_cicddos2019_within_2026_04_25
source_store: /home/chy/JANUS/datasets/cicddos2019/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/cicddos2019/processed/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/cicddos2019/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
val_cap: 20000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 10000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
device: auto

View File

@@ -0,0 +1,43 @@
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_cicddos2019_within_consistency_2026_04_25
source_store: /home/chy/JANUS/datasets/cicddos2019/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/cicddos2019/processed/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/cicddos2019/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
val_cap: 20000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 10000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.1
lambda_packet: 0.1
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,38 @@
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_cicids2017_canonical_2026_04_24
packets_npz: /home/chy/JANUS/datasets/cicids2017/processed/packets.npz
flows_parquet: /home/chy/JANUS/datasets/cicids2017/processed/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/cicids2017/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 2
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
device: auto

View File

@@ -0,0 +1,43 @@
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_cicids2017_consistency_2026_04_25
packets_npz: /home/chy/JANUS/datasets/cicids2017/processed/packets.npz
flows_parquet: /home/chy/JANUS/datasets/cicids2017/processed/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/cicids2017/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 2
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.1
lambda_packet: 0.1
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,43 @@
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_ciciot2023_2026_04_29
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
val_cap: 10000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,45 @@
save_dir: /home/chy/JANUS/artifacts/route_comparison/baseline_ciciot2023_seed42
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
val_cap: 10000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
reference_mode:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,45 @@
save_dir: /home/chy/JANUS/artifacts/route_comparison/baseline_ciciot2023_seed43
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 43
data_seed: 43
train_ratio: 0.8
benign_label: normal
val_cap: 10000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
reference_mode:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,45 @@
save_dir: /home/chy/JANUS/artifacts/route_comparison/baseline_ciciot2023_seed44
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 44
data_seed: 44
train_ratio: 0.8
benign_label: normal
val_cap: 10000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
reference_mode:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,45 @@
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_a_causal_ciciot2023_seed42
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
val_cap: 10000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
reference_mode: causal_packets
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,45 @@
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_a_causal_ciciot2023_seed43
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 43
data_seed: 43
train_ratio: 0.8
benign_label: normal
val_cap: 10000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
reference_mode: causal_packets
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,45 @@
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_a_causal_ciciot2023_seed44
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 44
data_seed: 44
train_ratio: 0.8
benign_label: normal
val_cap: 10000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
reference_mode: causal_packets
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,44 @@
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_b_spectral_ciciot2023_seed42
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features_spectral.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
val_cap: 10000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,44 @@
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_b_spectral_ciciot2023_seed43
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features_spectral.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 43
data_seed: 43
train_ratio: 0.8
benign_label: normal
val_cap: 10000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,44 @@
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_b_spectral_ciciot2023_seed44
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features_spectral.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 44
data_seed: 44
train_ratio: 0.8
benign_label: normal
val_cap: 10000
attack_cap: 20000
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,45 @@
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_ciciot2023_shafir5_2026_04_29
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features_shafir5.parquet
flow_feature_columns: ["HTTPS", "Protocol_Type", "Magnitude", "Variance", "fin_count"]
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: normal
val_cap: 10000
flow_dim: 5
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 0
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.3
lambda_packet: 0.3
packet_mask_ratio: 0.5
device: auto

View File

@@ -0,0 +1,39 @@
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_iscxtor2016_2026_04_25
packets_npz: /home/chy/JANUS/datasets/iscxtor2016/processed/packets.npz
flows_parquet: /home/chy/JANUS/datasets/iscxtor2016/processed/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/iscxtor2016/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: nontor
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 2
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
device: auto

View File

@@ -0,0 +1,41 @@
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_iscxtor2016_consistency_2026_04_25
packets_npz: /home/chy/JANUS/datasets/iscxtor2016/processed/packets.npz
flows_parquet: /home/chy/JANUS/datasets/iscxtor2016/processed/flows.parquet
flow_features_path: /home/chy/JANUS/datasets/iscxtor2016/processed/flow_features.parquet
flow_features_align: auto
T: 64
n_train: 10000
min_len: 2
packet_preprocess: mixed_dequant
seed: 42
data_seed: 42
train_ratio: 0.8
benign_label: nontor
d_model: 128
n_layers: 4
n_heads: 4
mlp_ratio: 4.0
time_dim: 64
token_dim:
batch_size: 256
num_workers: 2
epochs: 50
lr: 3.0e-4
weight_decay: 0.01
grad_clip: 1.0
eval_every: 10
eval_n: 20000
eval_batch_size: 512
eval_n_steps: 8
sigma: 0.1
use_ot: true
lambda_flow: 0.1
lambda_packet: 0.1
packet_mask_ratio: 0.5
device: auto

275
Unified_CFM/data.py Normal file
View File

@@ -0,0 +1,275 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
import sys as _sys
from pathlib import Path as _Path
_sys.path.insert(0, str(_Path(__file__).resolve().parents[1]))
from common.data_contract import PACKET_FEATURE_NAMES, PACKET_CONTINUOUS_CHANNEL_IDX as CONTINUOUS_CHANNEL_IDX, PACKET_BINARY_CHANNEL_IDX as BINARY_CHANNEL_IDX, canonical_5tuple as _canonical_key, fit_packet_stats as _fit_packet_stats, zscore as _zscore, apply_mixed_dequant as _apply_mixed_dequant
DEFAULT_FLOW_META_COLUMNS = {'flow_id', 'label', 'day', 'service', 'src_ip', 'dst_ip', 'src_port', 'dst_port', 'protocol', 'timestamp', 'start_ts', 'n_pkts'}
DERIVED_FLOW_FEATURE_NAMES = ('log_len', 'fwd_frac', 'bwd_frac', 'log_size_mean', 'log_size_std', 'log_size_min', 'log_size_max', 'log_dt_mean', 'log_dt_std', 'log_dt_max', 'syn_frac', 'fin_frac', 'rst_frac', 'psh_frac', 'ack_frac', 'log_win_mean')
@dataclass
class UnifiedData:
train_flow: np.ndarray
val_flow: np.ndarray
attack_flow: np.ndarray
train_packets: np.ndarray
val_packets: np.ndarray
attack_packets: np.ndarray
train_len: np.ndarray
val_len: np.ndarray
attack_len: np.ndarray
attack_labels: np.ndarray
packet_mean: np.ndarray
packet_std: np.ndarray
flow_mean: np.ndarray
flow_std: np.ndarray
packet_preprocess: str
flow_feature_names: tuple[str, ...]
packet_feature_names: tuple[str, ...] = PACKET_FEATURE_NAMES
@property
def T(self) -> int:
return int(self.train_packets.shape[1])
@property
def packet_dim(self) -> int:
return int(self.train_packets.shape[2])
@property
def flow_dim(self) -> int:
return int(self.train_flow.shape[1])
def _preprocess_packets(train_x: np.ndarray, val_x: np.ndarray, attack_x: np.ndarray, train_l: np.ndarray, val_l: np.ndarray, attack_l: np.ndarray, preprocess: str, seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
if preprocess not in ('zscore', 'mixed_dequant'):
raise ValueError("packet_preprocess must be 'zscore' or 'mixed_dequant'")
(mean, std) = _fit_packet_stats(train_x, train_l)
def prep(x: np.ndarray, l: np.ndarray, tag: str) -> np.ndarray:
if preprocess == 'zscore':
z = _zscore(x, mean, std)
mask = np.arange(x.shape[1])[None, :] < l[:, None]
return (z * mask[:, :, None]).astype(np.float32)
return _apply_mixed_dequant(x, l, mean, std, split_tag=tag, seed=seed)
return (prep(train_x, train_l, 'train'), prep(val_x, val_l, 'val'), prep(attack_x, attack_l, 'attack'), mean, std)
def _derive_flow_features(tokens: np.ndarray, lens: np.ndarray) -> np.ndarray:
(N, T, _) = tokens.shape
out = np.zeros((N, len(DERIVED_FLOW_FEATURE_NAMES)), dtype=np.float32)
for i in range(N):
n = int(max(lens[i], 1))
x = tokens[i, :n]
direction = x[:, 2]
size = x[:, 0]
dt = x[:, 1]
win = x[:, 8]
out[i, 0] = np.log1p(n)
out[i, 1] = np.mean(direction < 0.5)
out[i, 2] = np.mean(direction >= 0.5)
out[i, 3] = size.mean()
out[i, 4] = size.std()
out[i, 5] = size.min()
out[i, 6] = size.max()
out[i, 7] = dt.mean()
out[i, 8] = dt.std()
out[i, 9] = dt.max()
out[i, 10] = x[:, 3].mean()
out[i, 11] = x[:, 4].mean()
out[i, 12] = x[:, 5].mean()
out[i, 13] = x[:, 6].mean()
out[i, 14] = x[:, 7].mean()
out[i, 15] = win.mean()
return out
def _read_flow_features(path: Path, *, expected_rows: int, feature_columns: Optional[list[str]]=None) -> tuple[np.ndarray, tuple[str, ...], np.ndarray | None]:
path = Path(path)
if path.suffix == '.npz':
data = np.load(path, allow_pickle=True)
x = data['features'].astype(np.float32)
raw_names = data['feature_names'] if 'feature_names' in data.files else np.arange(x.shape[1])
names = tuple((str(v) for v in raw_names))
flow_id = data['flow_id'] if 'flow_id' in data.files else None
elif path.suffix in ('.parquet', '.pq'):
df = pd.read_parquet(path)
flow_id = df['flow_id'].to_numpy() if 'flow_id' in df.columns else None
if feature_columns:
cols = feature_columns
else:
cols = [c for c in df.columns if c not in DEFAULT_FLOW_META_COLUMNS and pd.api.types.is_numeric_dtype(df[c])]
if not cols:
raise ValueError(f'no numeric flow feature columns found in {path}')
x = df[cols].to_numpy(dtype=np.float32)
names = tuple(cols)
else:
raise ValueError(f'unsupported flow feature file: {path}')
if len(x) != expected_rows:
raise ValueError(f'flow feature row count {len(x):,} != packet row count {expected_rows:,}')
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
return (x, names, flow_id)
def _feature_columns_from_df(df: pd.DataFrame, requested: Optional[list[str]]) -> list[str]:
if requested:
return requested
return [c for c in df.columns if c not in DEFAULT_FLOW_META_COLUMNS and pd.api.types.is_numeric_dtype(df[c])]
def _align_flow_features_by_scan(feature_df: pd.DataFrame, packet_flows: pd.DataFrame, *, feature_columns: list[str]) -> tuple[np.ndarray, tuple[str, ...]]:
required = ['label', 'src_ip', 'src_port', 'dst_ip', 'dst_port', 'protocol']
missing_feature = [c for c in required if c not in feature_df.columns]
missing_packet = [c for c in required if c not in packet_flows.columns]
if missing_feature or missing_packet:
raise ValueError(f'scan alignment requires label + 5-tuple metadata. missing in feature_df={missing_feature}, packet_flows={missing_packet}')
packet_keys = [(str(lbl), _canonical_key(src, sp, dst, dp, proto)) for (lbl, src, sp, dst, dp, proto) in zip(packet_flows['label'].to_numpy(), packet_flows['src_ip'].to_numpy(), packet_flows['src_port'].to_numpy(), packet_flows['dst_ip'].to_numpy(), packet_flows['dst_port'].to_numpy(), packet_flows['protocol'].to_numpy())]
labels = feature_df['label'].to_numpy()
src_ip = feature_df['src_ip'].to_numpy()
src_port = feature_df['src_port'].to_numpy()
dst_ip = feature_df['dst_ip'].to_numpy()
dst_port = feature_df['dst_port'].to_numpy()
protocol = feature_df['protocol'].to_numpy()
matched: list[int] = []
j = 0
n_csv = len(feature_df)
for (i, target) in enumerate(packet_keys):
while j < n_csv:
cand = (str(labels[j]), _canonical_key(src_ip[j], src_port[j], dst_ip[j], dst_port[j], protocol[j]))
j += 1
if cand == target:
matched.append(j - 1)
break
else:
raise ValueError(f'failed to align packet flow row {i:,}/{len(packet_keys):,}; the CSV cache may not be the same one used for packet extraction')
print(f'[data] scan-aligned CSV flow features: matched={len(matched):,} from csv_rows={n_csv:,} skipped={matched[-1] + 1 - len(matched):,}')
x = feature_df.iloc[matched][feature_columns].to_numpy(dtype=np.float32)
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
return (x, tuple(feature_columns))
def _read_aligned_flow_features(path: Path, packet_flows: pd.DataFrame, *, feature_columns: Optional[list[str]]=None, align: str='auto') -> tuple[np.ndarray, tuple[str, ...]]:
path = Path(path)
if align not in ('auto', 'row', 'scan'):
raise ValueError("flow_features_align must be 'auto', 'row', or 'scan'")
if path.suffix == '.npz':
(x, names, flow_id) = _read_flow_features(path, expected_rows=len(packet_flows), feature_columns=feature_columns)
packet_id = packet_flows['flow_id'].to_numpy() if 'flow_id' in packet_flows else None
if flow_id is not None and packet_id is not None and (not np.array_equal(flow_id, packet_id)):
raise ValueError('NPZ flow_id does not align with Packet_CFM flows')
return (x, names)
if path.suffix not in ('.parquet', '.pq'):
raise ValueError(f'unsupported flow feature file: {path}')
feature_df = pd.read_parquet(path)
cols = _feature_columns_from_df(feature_df, feature_columns)
if not cols:
raise ValueError(f'no numeric flow feature columns found in {path}')
packet_id = packet_flows['flow_id'].to_numpy() if 'flow_id' in packet_flows else None
if len(feature_df) == len(packet_flows):
feature_id = feature_df['flow_id'].to_numpy() if 'flow_id' in feature_df.columns else None
if feature_id is None or packet_id is None or np.array_equal(feature_id, packet_id):
x = feature_df[cols].to_numpy(dtype=np.float32)
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
return (x, tuple(cols))
if align == 'row':
raise ValueError("flow_id mismatch with flow_features_align='row'")
if align == 'row':
raise ValueError(f'row alignment requested but feature rows={len(feature_df):,} packet rows={len(packet_flows):,}')
return _align_flow_features_by_scan(feature_df, packet_flows, feature_columns=cols)
def _preprocess_flow(train: np.ndarray, val: np.ndarray, attack: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
mean = train.mean(axis=0).astype(np.float32)
std = train.std(axis=0).astype(np.float32)
return (_zscore(train, mean, std), _zscore(val, mean, std), _zscore(attack, mean, std), mean, std)
def load_unified_data(*, packets_npz: Path | None=None, source_store: Path | None=None, flows_parquet: Path, flow_features_path: Path | None=None, flow_feature_columns: Optional[list[str]]=None, flow_features_align: str='auto', T: int=128, split_seed: int=42, train_ratio: float=0.8, benign_label: str='normal', min_len: int=2, packet_preprocess: str='mixed_dequant', attack_cap: int | None=None, val_cap: int | None=None) -> UnifiedData:
if (packets_npz is None) == (source_store is None):
raise ValueError('pass exactly one of packets_npz or source_store')
flows_parquet = Path(flows_parquet)
print(f'[data] flows={flows_parquet} packets_source={(packets_npz if packets_npz else source_store)}')
flow_cols = ['flow_id', 'label']
if flow_features_path is not None:
flow_cols += ['src_ip', 'src_port', 'dst_ip', 'dst_port', 'protocol']
flows = pd.read_parquet(flows_parquet, columns=flow_cols)
labels_full = flows['label'].to_numpy().astype(str)
flow_id = flows['flow_id'].to_numpy()
tokens_full: np.ndarray | None = None
store = None
if packets_npz is not None:
pz = np.load(Path(packets_npz))
tokens_full = pz['packet_tokens'].astype(np.float32)
lens_full = pz['packet_lengths'].astype(np.int32)
packet_flow_id = pz['flow_id'] if 'flow_id' in pz.files else None
if T > tokens_full.shape[1]:
raise ValueError(f'requested T={T} > stored T_full={tokens_full.shape[1]}')
tokens_full = tokens_full[:, :T].copy()
lens_full = np.minimum(lens_full, T).astype(np.int32)
if packet_flow_id is not None and (not np.array_equal(packet_flow_id, flow_id)):
raise ValueError('packets_npz and flows_parquet are not row-aligned by flow_id')
else:
if flow_features_path is None:
raise ValueError('source_store path requires flow_features_path (derived features need tokens in memory)')
from common.packet_store import PacketShardStore
store = PacketShardStore.open(Path(source_store))
store_flow_id = store.read_flows(columns=['flow_id'])['flow_id'].to_numpy()
if not np.array_equal(store_flow_id, flow_id):
raise ValueError('source_store and flows_parquet are not row-aligned by flow_id')
lens_full = np.minimum(store.manifest['packet_length'].to_numpy(dtype=np.int32), T)
if flow_features_path is None:
assert tokens_full is not None
flow_features = _derive_flow_features(tokens_full, lens_full)
flow_names = DERIVED_FLOW_FEATURE_NAMES
print(f'[data] using derived flow features D={flow_features.shape[1]}')
else:
(flow_features, flow_names) = _read_aligned_flow_features(Path(flow_features_path), flows, feature_columns=flow_feature_columns, align=flow_features_align)
print(f'[data] using external flow features D={flow_features.shape[1]}')
keep = lens_full >= min_len
labels = labels_full[keep]
flow_features = flow_features[keep]
lens = lens_full[keep]
global_idx = np.flatnonzero(keep).astype(np.int64)
if tokens_full is not None:
materialized_tokens = tokens_full[keep]
else:
materialized_tokens = None
print(f'[data] rows total={len(keep):,} keep len>={min_len}: {keep.sum():,}')
benign_local = np.where(labels == benign_label)[0]
attack_local = np.where(labels != benign_label)[0]
rng = np.random.default_rng(split_seed)
rng.shuffle(benign_local)
n_train = int(len(benign_local) * train_ratio)
train_local = benign_local[:n_train]
val_local = benign_local[n_train:]
if val_cap is not None and len(val_local) > val_cap:
val_local = np.sort(rng.choice(val_local, size=val_cap, replace=False))
if attack_cap is not None and len(attack_local) > attack_cap:
attack_local = np.sort(rng.choice(attack_local, size=attack_cap, replace=False))
print(f'[data] benign={len(benign_local):,} attack={len(attack_local):,} -> train={len(train_local):,} val={len(val_local):,}')
def _materialize(local_indices: np.ndarray) -> np.ndarray:
if materialized_tokens is not None:
return materialized_tokens[local_indices].astype(np.float32, copy=False)
assert store is not None
g = global_idx[local_indices]
(tok, _) = store.read_packets(g.astype(np.int64), T=T)
return tok.astype(np.float32, copy=False)
tr_p_raw = _materialize(train_local)
va_p_raw = _materialize(val_local)
at_p_raw = _materialize(attack_local)
tr_l = lens[train_local]
va_l = lens[val_local]
at_l = lens[attack_local]
tr_f_raw = flow_features[train_local]
va_f_raw = flow_features[val_local]
at_f_raw = flow_features[attack_local]
train_idx = train_local
val_idx = val_local
attack_idx = attack_local
(tr_p, va_p, at_p, p_mean, p_std) = _preprocess_packets(tr_p_raw, va_p_raw, at_p_raw, tr_l, va_l, at_l, preprocess=packet_preprocess, seed=split_seed)
(tr_f, va_f, at_f, f_mean, f_std) = _preprocess_flow(tr_f_raw, va_f_raw, at_f_raw)
return UnifiedData(train_flow=tr_f, val_flow=va_f, attack_flow=at_f, train_packets=tr_p, val_packets=va_p, attack_packets=at_p, train_len=tr_l, val_len=va_l, attack_len=at_l, attack_labels=labels[attack_idx], packet_mean=p_mean, packet_std=p_std, flow_mean=f_mean, flow_std=f_std, packet_preprocess=packet_preprocess, flow_feature_names=tuple(flow_names))
def subsample_train(data: UnifiedData, n_train: int, seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
if n_train <= 0 or n_train >= len(data.train_flow):
return (data.train_flow, data.train_packets, data.train_len)
rng = np.random.default_rng(seed)
idx = rng.choice(len(data.train_flow), n_train, replace=False)
idx.sort()
return (data.train_flow[idx], data.train_packets[idx], data.train_len[idx])

588
Unified_CFM/model.py Normal file
View File

@@ -0,0 +1,588 @@
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torchdiffeq import odeint
@torch.no_grad()
def _sinkhorn_coupling(C: torch.Tensor, reg: float=0.05, n_iter: int=20) -> torch.Tensor:
C = C.float()
log_k = -C / reg
B = C.shape[0]
log_u = torch.zeros(B, device=C.device)
log_v = torch.zeros(B, device=C.device)
for _ in range(n_iter):
log_v = -torch.logsumexp(log_k + log_u.unsqueeze(1), dim=0)
log_u = -torch.logsumexp(log_k + log_v.unsqueeze(0), dim=1)
log_p = log_u.unsqueeze(1) + log_k + log_v.unsqueeze(0)
return log_p.argmax(dim=1)
class SinusoidalTimeEmb(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
if dim % 2 != 0:
raise ValueError('time embedding dimension must be even')
self.dim = dim
def forward(self, t: torch.Tensor) -> torch.Tensor:
half = self.dim // 2
freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device, dtype=t.dtype) / max(half - 1, 1))
args = t[:, None] * freqs[None, :]
return torch.cat([args.sin(), args.cos()], dim=-1)
class AdaLNBlock(nn.Module):
def __init__(self, d_model: int, n_heads: int, mlp_ratio: float, cond_dim: int) -> None:
super().__init__()
self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
hidden = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))
self.cond_proj = nn.Linear(cond_dim, 6 * d_model)
nn.init.zeros_(self.cond_proj.weight)
nn.init.zeros_(self.cond_proj.bias)
@staticmethod
def _modulate(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
return x * (1.0 + gamma[:, None, :]) + beta[:, None, :]
def forward(self, x: torch.Tensor, cond: torch.Tensor, key_padding_mask: torch.Tensor | None, attn_mask: torch.Tensor | None=None) -> torch.Tensor:
(g1, b1, a1, g2, b2, a2) = self.cond_proj(cond).chunk(6, dim=-1)
h = self._modulate(self.norm1(x), g1, b1)
(attn_out, _) = self.attn(h, h, h, key_padding_mask=key_padding_mask, attn_mask=attn_mask, need_weights=False)
x = x + a1[:, None, :] * attn_out
h = self._modulate(self.norm2(x), g2, b2)
return x + a2[:, None, :] * self.mlp(h)
class UnifiedVelocity(nn.Module):
def __init__(self, token_dim: int, seq_len: int, d_model: int=128, n_layers: int=4, n_heads: int=4, mlp_ratio: float=4.0, time_dim: int=64, reference_mode: str | None=None) -> None:
super().__init__()
if reference_mode not in (None, 'independent_token', 'block_diagonal', 'causal_packets', 'causal_all'):
raise ValueError(f'unknown reference_mode={reference_mode!r}')
self.token_dim = token_dim
self.seq_len = seq_len
self.reference_mode = reference_mode
self.input_proj = nn.Linear(token_dim, d_model)
self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
self.type_emb = nn.Embedding(2, d_model)
nn.init.trunc_normal_(self.pos_emb, std=0.02)
nn.init.normal_(self.type_emb.weight, std=0.02)
self.time_emb = SinusoidalTimeEmb(time_dim)
self.cond_mlp = nn.Sequential(nn.Linear(time_dim, d_model), nn.SiLU(), nn.Linear(d_model, d_model))
self.blocks = nn.ModuleList([AdaLNBlock(d_model, n_heads, mlp_ratio, cond_dim=d_model) for _ in range(n_layers)])
self.out_norm = nn.LayerNorm(d_model, elementwise_affine=False)
self.out = nn.Linear(d_model, token_dim)
nn.init.zeros_(self.out.weight)
nn.init.zeros_(self.out.bias)
type_ids = torch.ones(seq_len, dtype=torch.long)
type_ids[0] = 0
self.register_buffer('type_ids', type_ids, persistent=False)
def forward(self, x: torch.Tensor, t: torch.Tensor, key_padding_mask: torch.Tensor | None=None, attn_mask_override: torch.Tensor | None=None) -> torch.Tensor:
(B, L, _) = x.shape
if L > self.seq_len:
raise ValueError(f'sequence length {L} exceeds configured {self.seq_len}')
if t.dim() == 0:
t = t.expand(B)
h = self.input_proj(x)
h = h + self.pos_emb[:, :L, :]
h = h + self.type_emb(self.type_ids[:L])[None, :, :]
cond = self.cond_mlp(self.time_emb(t))
if attn_mask_override is not None:
attn_mask = attn_mask_override
else:
attn_mask = self._reference_attn_mask(L, x.device)
for block in self.blocks:
h = block(h, cond, key_padding_mask, attn_mask=attn_mask)
return self.out(self.out_norm(h))
def _reference_attn_mask(self, L: int, device: torch.device) -> torch.Tensor | None:
if self.reference_mode is None:
return None
if self.reference_mode == 'independent_token':
return ~torch.eye(L, dtype=torch.bool, device=device)
if self.reference_mode == 'block_diagonal':
mask = torch.ones((L, L), dtype=torch.bool, device=device)
mask[0, 0] = False
if L > 1:
mask[1:, 1:] = False
return mask
if self.reference_mode == 'causal_packets':
mask = torch.zeros((L, L), dtype=torch.bool, device=device)
if L > 1:
packet_causal = torch.triu(torch.ones(L - 1, L - 1, dtype=torch.bool, device=device), diagonal=1)
mask[1:, 1:] = packet_causal
return mask
if self.reference_mode == 'causal_all':
return torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), diagonal=1)
raise AssertionError(self.reference_mode)
@dataclass
class UnifiedCFMConfig:
T: int = 128
packet_dim: int = 9
flow_dim: int = 16
token_dim: int | None = None
d_model: int = 128
n_layers: int = 4
n_heads: int = 4
mlp_ratio: float = 4.0
time_dim: int = 64
sigma: float = 0.1
use_ot: bool = False
reference_mode: str | None = None
class UnifiedTokenCFM(nn.Module):
def __init__(self, cfg: UnifiedCFMConfig) -> None:
super().__init__()
self.cfg = cfg
self.token_dim = cfg.token_dim or 1 + max(cfg.flow_dim, cfg.packet_dim)
if self.token_dim < 1 + max(cfg.flow_dim, cfg.packet_dim):
raise ValueError('token_dim is too small for flow_dim/packet_dim')
self.seq_len = cfg.T + 1
self.velocity = UnifiedVelocity(token_dim=self.token_dim, seq_len=self.seq_len, d_model=cfg.d_model, n_layers=cfg.n_layers, n_heads=cfg.n_heads, mlp_ratio=cfg.mlp_ratio, time_dim=cfg.time_dim, reference_mode=cfg.reference_mode)
def build_tokens(self, flow: torch.Tensor, packets: torch.Tensor) -> torch.Tensor:
(B, T, Dp) = packets.shape
if T != self.cfg.T:
raise ValueError(f'packet T={T} but config T={self.cfg.T}')
if Dp != self.cfg.packet_dim:
raise ValueError(f'packet_dim={Dp} but config packet_dim={self.cfg.packet_dim}')
if flow.shape[-1] != self.cfg.flow_dim:
raise ValueError(f'flow_dim={flow.shape[-1]} but config flow_dim={self.cfg.flow_dim}')
z = packets.new_zeros((B, T + 1, self.token_dim))
z[:, 0, 0] = -1.0
z[:, 0, 1:1 + self.cfg.flow_dim] = flow
z[:, 1:, 0] = 1.0
z[:, 1:, 1:1 + self.cfg.packet_dim] = packets
return z
def key_padding_mask(self, lens: torch.Tensor) -> torch.Tensor:
B = lens.shape[0]
idx = torch.arange(self.cfg.T, device=lens.device)[None, :]
packet_real = idx < lens[:, None]
real = torch.cat([torch.ones(B, 1, dtype=torch.bool, device=lens.device), packet_real], dim=1)
return ~real
def _loss_mask(self, lens: torch.Tensor) -> torch.Tensor:
return (~self.key_padding_mask(lens)).float()
@staticmethod
def _masked_trimmed_mean(values: torch.Tensor, mask: torch.Tensor, trim_frac: float=0.1) -> torch.Tensor:
out = values.new_zeros(values.shape[0])
for i in range(values.shape[0]):
v = values[i][mask[i] > 0]
if v.numel() == 0:
continue
if v.numel() < 5:
out[i] = v.mean()
continue
v_sorted = torch.sort(v).values
lo = int(trim_frac * v_sorted.numel())
hi = int((1.0 - trim_frac) * v_sorted.numel())
if hi <= lo:
out[i] = v_sorted.mean()
else:
out[i] = v_sorted[lo:hi].mean()
return out
@staticmethod
def _masked_median(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
out = values.new_zeros(values.shape[0])
for i in range(values.shape[0]):
v = values[i][mask[i] > 0]
if v.numel() == 0:
continue
v_sorted = torch.sort(v).values
mid = v_sorted.numel() // 2
if v_sorted.numel() % 2:
out[i] = v_sorted[mid]
else:
out[i] = 0.5 * (v_sorted[mid - 1] + v_sorted[mid])
return out
def compute_loss(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, *, lambda_flow: float=0.0, lambda_packet: float=0.0, packet_mask_ratio: float=0.5, return_components: bool=False) -> torch.Tensor | dict[str, torch.Tensor]:
x1 = self.build_tokens(flow, packets)
B = x1.shape[0]
x0 = torch.randn_like(x1)
mask = self._loss_mask(lens)
kpm = mask == 0
if self.cfg.use_ot:
flat0 = (x0 * mask[:, :, None]).reshape(B, -1)
flat1 = (x1 * mask[:, :, None]).reshape(B, -1)
col = _sinkhorn_coupling(torch.cdist(flat0.float(), flat1.float()))
x1 = x1[col]
flow = flow[col]
packets = packets[col]
lens = lens[col]
mask = self._loss_mask(lens)
kpm = mask == 0
t = torch.rand(B, device=x1.device)
x_t = (1.0 - t[:, None, None]) * x0 + t[:, None, None] * x1
if self.cfg.sigma > 0:
std = self.cfg.sigma * torch.sqrt(t * (1.0 - t))[:, None, None]
x_t = x_t + std * torch.randn_like(x_t)
target = x1 - x0
pred = self.velocity(x_t, t, key_padding_mask=kpm)
sq = (pred - target).square().mean(dim=-1)
per_sample = (sq * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
main_loss = per_sample.mean()
aux_flow_loss = x1.new_zeros(())
aux_packet_loss = x1.new_zeros(())
if lambda_flow > 0.0:
x_t_mf = x_t.clone()
x_t_mf[:, 0, :] = 0.0
pred_mf = self.velocity(x_t_mf, t, key_padding_mask=kpm)
err = (pred_mf[:, 0] - target[:, 0]).square().mean(dim=-1)
aux_flow_loss = err.mean()
if lambda_packet > 0.0:
packet_real = mask[:, 1:] > 0
rand_draw = torch.rand(packet_real.shape, device=x1.device)
mask_pkt = (rand_draw < packet_mask_ratio) & packet_real
pkt_mask_full = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x1.device), mask_pkt], dim=1)
x_t_mp = x_t.clone()
x_t_mp[pkt_mask_full] = 0.0
pred_mp = self.velocity(x_t_mp, t, key_padding_mask=kpm)
sq_mp = (pred_mp - target).square().mean(dim=-1)
mask_f = pkt_mask_full.float()
denom = mask_f.sum(dim=-1).clamp_min(1.0)
aux_packet_loss = ((sq_mp * mask_f).sum(dim=-1) / denom).mean()
total = main_loss + lambda_flow * aux_flow_loss + lambda_packet * aux_packet_loss
if return_components:
return {'total': total, 'main': main_loss.detach(), 'aux_flow': aux_flow_loss.detach(), 'aux_packet': aux_packet_loss.detach()}
return total
@torch.no_grad()
def velocity_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.5, 0.75, 1.0)) -> dict[str, torch.Tensor]:
x = self.build_tokens(flow, packets)
mask = self._loss_mask(lens)
kpm = mask == 0
total = torch.zeros(x.shape[0], device=x.device)
flow_s = torch.zeros_like(total)
packet_s = torch.zeros_like(total)
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
for t_val in t_eval:
t = torch.full((x.shape[0],), float(t_val), device=x.device)
v = self.velocity(x, t, key_padding_mask=kpm)
e = v.square().mean(dim=-1)
total = total + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
flow_s = flow_s + e[:, 0]
packet_s = packet_s + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
denom = float(len(t_eval))
return {'velocity_total': total / denom, 'velocity_flow': flow_s / denom, 'velocity_packet': packet_s / denom}
@torch.no_grad()
def trajectory_metrics(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16) -> dict[str, torch.Tensor]:
z = self.build_tokens(flow, packets)
mask = self._loss_mask(lens)
kpm = mask == 0
B = z.shape[0]
dt = 1.0 / n_steps
total_arc = torch.zeros(B, device=z.device)
total_ke = torch.zeros(B, device=z.device)
flow_ke = torch.zeros(B, device=z.device)
packet_ke = torch.zeros(B, device=z.device)
total_curv = torch.zeros(B, device=z.device)
flow_curv = torch.zeros(B, device=z.device)
packet_curv = torch.zeros(B, device=z.device)
packet_kappa2_speed2 = torch.zeros(B, max(z.shape[1] - 1, 0), device=z.device)
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
v_prev = None
v_prev_norm = None
for k in range(n_steps):
t_val = 1.0 - k * dt
t = torch.full((B,), t_val, device=z.device)
v = self.velocity(z, t, key_padding_mask=kpm)
e = v.square().mean(dim=-1)
v_norm = v.square().sum(dim=-1).clamp_min(1e-12).sqrt()
total_ke = total_ke + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) * dt
flow_ke = flow_ke + e[:, 0] * dt
packet_ke = packet_ke + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count * dt
if v_prev is not None:
dv = v - v_prev
dve = dv.square().mean(dim=-1)
total_curv = total_curv + (dve * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
flow_curv = flow_curv + dve[:, 0]
packet_curv = packet_curv + (dve[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
dv2_sum = dv[:, 1:].square().sum(dim=-1)
assert v_prev_norm is not None
v_avg = 0.5 * (v_norm[:, 1:] + v_prev_norm[:, 1:])
packet_kappa2_speed2 = packet_kappa2_speed2 + dv2_sum / v_avg.square().clamp_min(1e-06)
v_prev = v
v_prev_norm = v_norm
z_new = z - v * dt
dz = (z_new - z) * mask[:, :, None]
total_arc = total_arc + dz.reshape(B, -1).norm(dim=-1) / mask.sum(dim=-1).sqrt()
z = z_new
z_masked = z * mask[:, :, None]
terminal = z_masked.reshape(B, -1).norm(dim=-1) / (mask.sum(dim=-1) * self.token_dim).clamp_min(1.0).sqrt()
terminal_flow = z[:, 0].norm(dim=-1) / math.sqrt(self.token_dim)
terminal_packet = (z[:, 1:] * mask[:, 1:, None]).reshape(B, -1).norm(dim=-1) / (packet_count * self.token_dim).sqrt()
packet_mask = mask[:, 1:]
kappa2_speed2_mean = (packet_kappa2_speed2 * packet_mask).sum(dim=-1) / packet_count
kappa2_speed2_median = self._masked_median(packet_kappa2_speed2, packet_mask)
kappa2_speed2_trimmed = self._masked_trimmed_mean(packet_kappa2_speed2, packet_mask)
return {'terminal_norm': terminal, 'terminal_flow': terminal_flow, 'terminal_packet': terminal_packet, 'arc_length': total_arc, 'kinetic_energy': total_ke, 'kinetic_flow': flow_ke, 'kinetic_packet': packet_ke, 'curvature_total': total_curv, 'curvature_flow': flow_curv, 'curvature_packet': packet_curv, 'kappa2_speed2norm_packet_mean': kappa2_speed2_mean, 'kappa2_speed2norm_packet_median': kappa2_speed2_median, 'kappa2_speed2norm_packet_trimmed10_mean': kappa2_speed2_trimmed}
@torch.no_grad()
def score_profile_vt(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.1, 0.3, 0.5, 0.7, 0.9, 1.0)) -> dict[str, torch.Tensor]:
x = self.build_tokens(flow, packets)
mask = self._loss_mask(lens)
kpm = mask == 0
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
out: dict[str, torch.Tensor] = {}
for t_val in t_eval:
t = torch.full((x.shape[0],), float(t_val), device=x.device)
v = self.velocity(x, t, key_padding_mask=kpm)
e = v.square().mean(dim=-1)
tag = f't{int(round(t_val * 10)):02d}'
out[f'velocity_total_{tag}'] = (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
out[f'velocity_flow_{tag}'] = e[:, 0]
out[f'velocity_packet_{tag}'] = (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
return out
@torch.no_grad()
def consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5) -> dict[str, torch.Tensor]:
x = self.build_tokens(flow, packets)
mask = self._loss_mask(lens)
kpm = mask == 0
B = x.shape[0]
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
t = torch.full((B,), float(t_eval), device=x.device)
v_full = self.velocity(x, t, key_padding_mask=kpm)
x_mf = x.clone()
x_mf[:, 0, :] = 0.0
v_mf = self.velocity(x_mf, t, key_padding_mask=kpm)
flow_cons = (v_full[:, 0] - v_mf[:, 0]).square().mean(dim=-1)
x_mp = x.clone()
pkt_mask_full = mask[:, 1:] > 0
idx_pkt_mask = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x.device), pkt_mask_full], dim=1)
x_mp[idx_pkt_mask] = 0.0
v_mp = self.velocity(x_mp, t, key_padding_mask=kpm)
diff = (v_full - v_mp).square().mean(dim=-1)
packet_cons = (diff[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
return {'flow_consistency': flow_cons, 'packet_consistency': packet_cons, 'consistency_total': flow_cons + packet_cons}
def jacobian_hutchinson(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.5,), n_eps: int=4, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]:
x = self.build_tokens(flow, packets)
mask = self._loss_mask(lens)
kpm = mask == 0
B = x.shape[0]
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
total = torch.zeros(B, device=x.device)
flow_j = torch.zeros(B, device=x.device)
packet_j = torch.zeros(B, device=x.device)
n_draws = n_eps * len(t_eval)
for t_val in t_eval:
t_current = torch.full((B,), float(t_val), device=x.device)
for _ in range(n_eps):
x_req = x.detach().clone().requires_grad_(True)
v = self.velocity(x_req, t_current, key_padding_mask=kpm)
eps = torch.randn(v.shape, device=v.device, generator=generator)
(g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False)
e = g.square().mean(dim=-1)
total = total + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
flow_j = flow_j + e[:, 0]
packet_j = packet_j + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
return {'jacobian_total': (total / n_draws).detach(), 'jacobian_flow': (flow_j / n_draws).detach(), 'jacobian_packet': (packet_j / n_draws).detach()}
@torch.no_grad()
def pna_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16, flow_masked: bool=False) -> dict[str, torch.Tensor]:
eps_v2 = 1e-06
dt = 1.0 / n_steps
z = self.build_tokens(flow, packets)
if flow_masked:
z = z.clone()
z[:, 0, :] = 0.0
mask = self._loss_mask(lens)
kpm = mask == 0
(B, L, _) = z.shape
pna = torch.zeros(B, L, device=z.device)
v_prev: torch.Tensor | None = None
v_norm_prev: torch.Tensor | None = None
for k in range(n_steps):
t_val = 1.0 - k * dt
t = torch.full((B,), t_val, device=z.device)
v = self.velocity(z, t, key_padding_mask=kpm)
v_norm = (v.square().sum(dim=-1) + 1e-12).sqrt()
if v_prev is not None:
dv2 = (v - v_prev).square().sum(dim=-1)
v_avg2 = (0.5 * (v_norm + v_norm_prev)).square().clamp_min(eps_v2)
pna = pna + dv2 / v_avg2
v_prev = v
v_norm_prev = v_norm
z = z - v * dt
if flow_masked:
z[:, 0, :] = 0.0
flow_pna = pna[:, 0]
packet_pna = pna[:, 1:]
packet_mask = mask[:, 1:]
packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
pna_median = self._masked_median(packet_pna, packet_mask)
pna_mean = (packet_pna * packet_mask).sum(dim=-1) / packet_count
masked_for_max = packet_pna.masked_fill(packet_mask == 0, float('-inf'))
pna_max = masked_for_max.max(dim=-1).values
pna_trimmed = self._masked_trimmed_mean(packet_pna, packet_mask)
return {'pna_packet_median': pna_median, 'pna_packet_mean': pna_mean, 'pna_packet_max': pna_max, 'pna_packet_trimmed10_mean': pna_trimmed, 'pna_flow': flow_pna}
@torch.no_grad()
def causal_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5) -> dict[str, torch.Tensor]:
x = self.build_tokens(flow, packets)
mask = self._loss_mask(lens)
kpm = mask == 0
(B, L, _) = x.shape
t = torch.full((B,), float(t_eval), device=x.device)
v_full = self.velocity(x, t, key_padding_mask=kpm)
causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), diagonal=1)
v_causal = self.velocity(x, t, key_padding_mask=kpm, attn_mask_override=causal)
diff = (v_full - v_causal).square().mean(dim=-1)
flow_surprisal = diff[:, 0]
packet_diff = diff[:, 1:]
packet_mask = mask[:, 1:]
packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
packet_mean = (packet_diff * packet_mask).sum(dim=-1) / packet_count
packet_median = self._masked_median(packet_diff, packet_mask)
masked_for_max = packet_diff.masked_fill(packet_mask == 0, float('-inf'))
packet_max = masked_for_max.max(dim=-1).values
packet_trimmed = self._masked_trimmed_mean(packet_diff, packet_mask)
total = (diff * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
return {'causal_surprisal_total': total, 'causal_surprisal_flow': flow_surprisal, 'causal_surprisal_packet_mean': packet_mean, 'causal_surprisal_packet_median': packet_median, 'causal_surprisal_packet_max': packet_max, 'causal_surprisal_packet_trimmed10_mean': packet_trimmed}
@torch.no_grad()
def direction_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.2, 0.4, 0.6, 0.8, 1.0)) -> dict[str, torch.Tensor]:
x = self.build_tokens(flow, packets)
mask = self._loss_mask(lens)
kpm = mask == 0
(B, L, _) = x.shape
t_eval = tuple(t_eval)
if len(t_eval) < 2:
raise ValueError('direction_consistency_score needs >=2 t values')
prev_v: torch.Tensor | None = None
drift = x.new_zeros(B, L)
n_pairs = len(t_eval) - 1
for t_val in t_eval:
t = torch.full((B,), float(t_val), device=x.device)
v = self.velocity(x, t, key_padding_mask=kpm)
if prev_v is not None:
num = (prev_v * v).sum(dim=-1)
denom = prev_v.norm(dim=-1).clamp_min(1e-08) * v.norm(dim=-1).clamp_min(1e-08)
cos = num / denom
drift = drift + (1.0 - cos)
prev_v = v
drift = drift / max(n_pairs, 1)
flow_drift = drift[:, 0]
packet_drift = drift[:, 1:]
packet_mask = mask[:, 1:]
packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
packet_mean = (packet_drift * packet_mask).sum(dim=-1) / packet_count
packet_median = self._masked_median(packet_drift, packet_mask)
masked_for_max = packet_drift.masked_fill(packet_mask == 0, float('-inf'))
packet_max = masked_for_max.max(dim=-1).values
packet_trimmed = self._masked_trimmed_mean(packet_drift, packet_mask)
total = (drift * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
return {'direction_drift_total': total, 'direction_drift_flow': flow_drift, 'direction_drift_packet_mean': packet_mean, 'direction_drift_packet_median': packet_median, 'direction_drift_packet_max': packet_max, 'direction_drift_packet_trimmed10_mean': packet_trimmed}
def inverse_flow_nll_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16, n_eps: int=4, compute_divergence: bool=True, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]:
z = self.build_tokens(flow, packets)
mask = self._loss_mask(lens)
kpm = mask == 0
(B, L, D) = z.shape
dt = 1.0 / n_steps
accum_div = torch.zeros(B, device=z.device)
if compute_divergence:
for k in range(n_steps):
t_val = 1.0 - k * dt
t = torch.full((B,), t_val, device=z.device)
z_req = z.detach().clone().requires_grad_(True)
v = self.velocity(z_req, t, key_padding_mask=kpm)
div_step = torch.zeros(B, device=z.device)
for j in range(n_eps):
eps = torch.randn_like(v)
eps_masked = eps * mask[:, :, None]
retain = j < n_eps - 1
(g,) = torch.autograd.grad(outputs=v, inputs=z_req, grad_outputs=eps_masked, retain_graph=retain, create_graph=False)
div_step = div_step + (eps_masked * g).sum(dim=(1, 2))
div_step = div_step / float(n_eps)
accum_div = accum_div + div_step * dt
with torch.no_grad():
z = (z_req - v * dt).detach()
else:
with torch.no_grad():
for k in range(n_steps):
t_val = 1.0 - k * dt
t = torch.full((B,), t_val, device=z.device)
v = self.velocity(z, t, key_padding_mask=kpm)
z = z - v * dt
with torch.no_grad():
z_masked = z * mask[:, :, None]
n_real = mask.sum(dim=-1).clamp_min(1.0)
x0_quadratic = z_masked.reshape(B, -1).square().sum(dim=-1) / (n_real * float(D))
nll_x0_only = x0_quadratic
nll_div_only = accum_div / (n_real * float(D))
nll_full = nll_x0_only + nll_div_only
return {'nll_x0_only': nll_x0_only.detach(), 'nll_div_only': nll_div_only.detach(), 'nll_full': nll_full.detach()}
def jacobian_spectral_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5, n_eps: int=4, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]:
x = self.build_tokens(flow, packets)
mask = self._loss_mask(lens)
kpm = mask == 0
(B, L, D) = x.shape
t = torch.full((B,), float(t_eval), device=x.device)
packet_mask = mask[:, 1:]
packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
norms_total: list[torch.Tensor] = []
norms_flow: list[torch.Tensor] = []
norms_packet: list[torch.Tensor] = []
for _ in range(n_eps):
x_req = x.detach().clone().requires_grad_(True)
v = self.velocity(x_req, t, key_padding_mask=kpm)
eps = torch.randn(v.shape, device=v.device, generator=generator)
(g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False)
e = g.square().mean(dim=-1)
n_total = (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
n_flow = e[:, 0]
n_packet = (e[:, 1:] * packet_mask).sum(dim=-1) / packet_count
norms_total.append(n_total.detach())
norms_flow.append(n_flow.detach())
norms_packet.append(n_packet.detach())
def _spectral_summary(samples: list[torch.Tensor]) -> dict[str, torch.Tensor]:
stack = torch.stack(samples, dim=1)
mean = stack.mean(dim=1).clamp_min(1e-12)
mx = stack.max(dim=1).values
mn = stack.min(dim=1).values
logfro = torch.log(mean)
aniso = mx / mean
min_over_max = mn / mx.clamp_min(1e-12)
p = stack / stack.sum(dim=1, keepdim=True).clamp_min(1e-12)
entropy = -(p * p.clamp_min(1e-12).log()).sum(dim=1)
eff_rank = torch.exp(entropy)
return {'logfro': logfro, 'anisotropy': aniso, 'min_over_max': min_over_max, 'eff_rank': eff_rank}
out: dict[str, torch.Tensor] = {}
for (tag, samples) in (('total', norms_total), ('flow', norms_flow), ('packet', norms_packet)):
summ = _spectral_summary(samples)
for (stat_name, val) in summ.items():
out[f'jac_{stat_name}_{tag}'] = val
return out
@torch.no_grad()
def sample(self, n: int, lens: torch.Tensor, device: torch.device, n_steps: int=50, method: str='euler') -> torch.Tensor:
z = torch.randn(n, self.seq_len, self.token_dim, device=device)
ts = torch.linspace(0.0, 1.0, n_steps + 1, device=device)
kpm = self.key_padding_mask(lens.to(device))
def f(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
return self.velocity(x, t.expand(x.shape[0]), key_padding_mask=kpm)
if method == 'euler':
for i in range(n_steps):
z = z + f(ts[i], z) * (ts[i + 1] - ts[i])
return z
return odeint(f, z, ts, method=method)[-1]
def param_count(self) -> int:
return sum((p.numel() for p in self.parameters()))

View File

@@ -0,0 +1,157 @@
import sys
from pathlib import Path
import torch
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from model import UnifiedCFMConfig, UnifiedTokenCFM
def _build_model():
return UnifiedTokenCFM(UnifiedCFMConfig(T=4, packet_dim=3, flow_dim=5, d_model=16, n_layers=1, n_heads=4, time_dim=8))
def _build_reference_model(reference_mode: str):
return UnifiedTokenCFM(UnifiedCFMConfig(T=4, packet_dim=3, flow_dim=5, d_model=16, n_layers=1, n_heads=4, time_dim=8, reference_mode=reference_mode))
def _sample_batch(seed: int=0):
torch.manual_seed(seed)
flow = torch.randn(2, 5)
packets = torch.randn(2, 4, 3)
lens = torch.tensor([4, 2])
return (flow, packets, lens)
def test_unified_cfm_shapes_and_scores():
model = _build_model()
(flow, packets, lens) = _sample_batch()
tokens = model.build_tokens(flow, packets)
assert tokens.shape == (2, 5, 6)
loss = model.compute_loss(flow, packets, lens)
assert loss.ndim == 0
assert torch.isfinite(loss)
traj = model.trajectory_metrics(flow, packets, lens, n_steps=2)
assert 'terminal_norm' in traj
assert traj['terminal_norm'].shape == (2,)
vel = model.velocity_score(flow, packets, lens)
assert set(vel) == {'velocity_total', 'velocity_flow', 'velocity_packet'}
def test_reference_mode_independent_token_shapes_and_scores():
model = _build_reference_model('independent_token')
(flow, packets, lens) = _sample_batch(seed=9)
loss = model.compute_loss(flow, packets, lens)
assert loss.ndim == 0
assert torch.isfinite(loss)
traj = model.trajectory_metrics(flow, packets, lens, n_steps=2)
assert traj['terminal_norm'].shape == (2,)
assert torch.all(torch.isfinite(traj['curvature_packet']))
def test_reference_mode_block_diagonal_shapes_and_scores():
model = _build_reference_model('block_diagonal')
(flow, packets, lens) = _sample_batch(seed=10)
loss = model.compute_loss(flow, packets, lens)
assert loss.ndim == 0
assert torch.isfinite(loss)
vel = model.velocity_score(flow, packets, lens)
assert set(vel) == {'velocity_total', 'velocity_flow', 'velocity_packet'}
def test_trajectory_curvature_keys_and_shapes():
model = _build_model()
(flow, packets, lens) = _sample_batch(seed=1)
traj = model.trajectory_metrics(flow, packets, lens, n_steps=4)
for key in ('curvature_total', 'curvature_flow', 'curvature_packet'):
assert key in traj, f'missing {key}'
assert traj[key].shape == (2,)
assert torch.all(torch.isfinite(traj[key]))
assert torch.all(traj[key] >= 0)
def test_trajectory_curvature_zero_with_one_step():
model = _build_model()
(flow, packets, lens) = _sample_batch(seed=2)
traj = model.trajectory_metrics(flow, packets, lens, n_steps=1)
for key in ('curvature_total', 'curvature_flow', 'curvature_packet'):
assert traj[key].abs().sum().item() == 0.0
def test_speed_normalized_packet_curvature_scores():
model = _build_model()
(flow, packets, lens) = _sample_batch(seed=11)
traj = model.trajectory_metrics(flow, packets, lens, n_steps=4)
keys = ('kappa2_speed2norm_packet_mean', 'kappa2_speed2norm_packet_median', 'kappa2_speed2norm_packet_trimmed10_mean')
for key in keys:
assert key in traj, f'missing {key}'
assert traj[key].shape == (2,)
assert torch.all(torch.isfinite(traj[key]))
assert torch.all(traj[key] >= 0)
one_step = model.trajectory_metrics(flow, packets, lens, n_steps=1)
for key in keys:
assert one_step[key].abs().sum().item() == 0.0
def test_score_profile_vt_shapes():
model = _build_model()
(flow, packets, lens) = _sample_batch(seed=3)
t_eval = (0.1, 0.3, 0.5, 0.7, 0.9, 1.0)
prof = model.score_profile_vt(flow, packets, lens, t_eval=t_eval)
assert len(prof) == 3 * len(t_eval)
for (k, v) in prof.items():
assert v.shape == (2,), k
assert torch.all(torch.isfinite(v))
assert torch.all(v >= 0)
assert 'velocity_total_t05' in prof
assert 'velocity_flow_t10' in prof
assert 'velocity_packet_t01' in prof
def test_compute_loss_backward_compat():
model = _build_model()
(flow, packets, lens) = _sample_batch(seed=5)
torch.manual_seed(0)
a = model.compute_loss(flow, packets, lens)
torch.manual_seed(0)
b = model.compute_loss(flow, packets, lens, lambda_flow=0.0, lambda_packet=0.0)
assert torch.allclose(a, b), f'λ=0 must match old loss; got {a.item()} vs {b.item()}'
def test_compute_loss_aux_components_finite():
model = _build_model()
(flow, packets, lens) = _sample_batch(seed=6)
torch.manual_seed(7)
comp = model.compute_loss(flow, packets, lens, lambda_flow=0.1, lambda_packet=0.1, return_components=True)
assert set(comp) == {'total', 'main', 'aux_flow', 'aux_packet'}
for (k, v) in comp.items():
assert torch.isfinite(v), k
assert v >= 0, f'{k} negative: {v.item()}'
def test_compute_loss_aux_affects_gradient():
model = _build_model()
with torch.no_grad():
model.velocity.out.weight.normal_(std=0.01)
for block in model.velocity.blocks:
block.cond_proj.weight.normal_(std=0.01)
(flow, packets, lens) = _sample_batch(seed=8)
torch.manual_seed(10)
total = model.compute_loss(flow, packets, lens, lambda_flow=1.0, lambda_packet=1.0)
total.backward()
some_grad = False
for p in model.parameters():
if p.grad is not None and p.grad.abs().sum().item() > 0:
some_grad = True
break
assert some_grad, 'no gradient flowed through aux losses'
def test_consistency_score_shapes():
model = _build_model()
(flow, packets, lens) = _sample_batch(seed=9)
cs = model.consistency_score(flow, packets, lens)
assert set(cs) == {'flow_consistency', 'packet_consistency', 'consistency_total'}
for (k, v) in cs.items():
assert v.shape == (2,), k
assert torch.all(torch.isfinite(v))
assert torch.all(v >= 0), k
def test_jacobian_hutchinson_shapes_and_nonneg():
model = _build_model()
with torch.no_grad():
model.velocity.out.weight.normal_(std=0.01)
for block in model.velocity.blocks:
block.cond_proj.weight.normal_(std=0.01)
(flow, packets, lens) = _sample_batch(seed=4)
gen = torch.Generator().manual_seed(42)
jac = model.jacobian_hutchinson(flow, packets, lens, t_eval=(0.5,), n_eps=2, generator=gen)
assert set(jac) == {'jacobian_total', 'jacobian_flow', 'jacobian_packet'}
for (k, v) in jac.items():
assert v.shape == (2,), k
assert torch.all(torch.isfinite(v))
assert torch.all(v >= 0), f'{k} has negative value'

147
Unified_CFM/train.py Normal file
View File

@@ -0,0 +1,147 @@
from __future__ import annotations
import argparse
import json
import time
from dataclasses import asdict
from pathlib import Path
from typing import Any
import numpy as np
import torch
import yaml
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, TensorDataset
from data import UnifiedData, load_unified_data, subsample_train
from model import UnifiedCFMConfig, UnifiedTokenCFM
def _device(dev_arg: str) -> torch.device:
if dev_arg == 'auto':
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
return torch.device(dev_arg)
def _batch_score(model: UnifiedTokenCFM, flow_np: np.ndarray, packet_np: np.ndarray, len_np: np.ndarray, device: torch.device, *, batch_size: int, n_steps: int) -> dict[str, np.ndarray]:
out: dict[str, list[np.ndarray]] = {}
model.eval()
for start in range(0, len(flow_np), batch_size):
sl = slice(start, start + batch_size)
flow = torch.from_numpy(flow_np[sl]).float().to(device)
packets = torch.from_numpy(packet_np[sl]).float().to(device)
lens = torch.from_numpy(len_np[sl]).long().to(device)
metrics = model.trajectory_metrics(flow, packets, lens, n_steps=n_steps)
vel = model.velocity_score(flow, packets, lens)
metrics.update(vel)
for (k, v) in metrics.items():
out.setdefault(k, []).append(v.detach().cpu().numpy())
return {k: np.concatenate(v, axis=0) for (k, v) in out.items()}
def _quick_eval(model: UnifiedTokenCFM, data: UnifiedData, device: torch.device, cfg: dict[str, Any]) -> dict[str, float]:
n_eval = int(cfg.get('eval_n', 2000))
rng = np.random.default_rng(0)
def pick(n: int) -> np.ndarray:
m = min(n_eval, n)
return rng.choice(n, m, replace=False)
vi = pick(len(data.val_flow))
ai = pick(len(data.attack_flow))
v = _batch_score(model, data.val_flow[vi], data.val_packets[vi], data.val_len[vi], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)))
a = _batch_score(model, data.attack_flow[ai], data.attack_packets[ai], data.attack_len[ai], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)))
y = np.concatenate([np.zeros(len(vi)), np.ones(len(ai))])
result: dict[str, float] = {}
for key in sorted(v.keys()):
s = np.concatenate([v[key], a[key]])
s = np.nan_to_num(s, nan=0.0, posinf=1000000000000.0, neginf=-1000000000000.0)
result[f'auroc_{key}'] = float(roc_auc_score(y, s))
return result
def train(cfg: dict[str, Any]) -> Path:
device = _device(str(cfg.get('device', 'auto')))
save_dir = Path(cfg['save_dir'])
save_dir.mkdir(parents=True, exist_ok=True)
with open(save_dir / 'config.yaml', 'w') as f:
yaml.safe_dump(cfg, f)
seed = int(cfg.get('seed', 42))
data_seed = int(cfg.get('data_seed', seed))
torch.manual_seed(seed)
np.random.seed(seed)
print(f'Device: {device}')
print(f'[seed] model={seed} data={data_seed}')
feature_columns = cfg.get('flow_feature_columns')
data = load_unified_data(packets_npz=Path(cfg['packets_npz']) if cfg.get('packets_npz') else None, source_store=Path(cfg['source_store']) if cfg.get('source_store') else None, flows_parquet=Path(cfg['flows_parquet']), flow_features_path=Path(cfg['flow_features_path']) if cfg.get('flow_features_path') else None, flow_feature_columns=feature_columns, flow_features_align=str(cfg.get('flow_features_align', 'auto')), T=int(cfg['T']), split_seed=data_seed, train_ratio=float(cfg.get('train_ratio', 0.8)), benign_label=str(cfg.get('benign_label', 'normal')), min_len=int(cfg.get('min_len', 2)), packet_preprocess=str(cfg.get('packet_preprocess', 'mixed_dequant')), attack_cap=int(cfg['attack_cap']) if cfg.get('attack_cap') else None, val_cap=int(cfg['val_cap']) if cfg.get('val_cap') else None)
print(f'[data] T={data.T} packet_D={data.packet_dim} flow_D={data.flow_dim} train={len(data.train_flow):,} val={len(data.val_flow):,} attack={len(data.attack_flow):,}')
(tr_f, tr_p, tr_l) = subsample_train(data, int(cfg.get('n_train', 0)), data_seed)
ds = TensorDataset(torch.from_numpy(tr_f).float(), torch.from_numpy(tr_p).float(), torch.from_numpy(tr_l).long())
loader = DataLoader(ds, batch_size=int(cfg['batch_size']), shuffle=True, drop_last=True, num_workers=int(cfg.get('num_workers', 0)), pin_memory=device.type == 'cuda')
print(f'[data] using {len(ds):,} benign training flows')
model_cfg = UnifiedCFMConfig(T=data.T, packet_dim=data.packet_dim, flow_dim=data.flow_dim, token_dim=cfg.get('token_dim'), d_model=int(cfg['d_model']), n_layers=int(cfg['n_layers']), n_heads=int(cfg['n_heads']), mlp_ratio=float(cfg.get('mlp_ratio', 4.0)), time_dim=int(cfg.get('time_dim', 64)), sigma=float(cfg.get('sigma', 0.1)), use_ot=bool(cfg.get('use_ot', False)), reference_mode=cfg.get('reference_mode'))
model = UnifiedTokenCFM(model_cfg).to(device)
print(f'[model] params={model.param_count():,} token_dim={model.token_dim} seq_len={model.seq_len} sigma={model_cfg.sigma} use_ot={model_cfg.use_ot} reference_mode={model_cfg.reference_mode}')
opt = torch.optim.AdamW(model.parameters(), lr=float(cfg['lr']), weight_decay=float(cfg.get('weight_decay', 0.01)))
total_steps = max(1, int(cfg['epochs']) * len(loader))
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)
history: dict[str, list[Any]] = {'epoch': [], 'loss': [], 'eval': []}
lambda_flow = float(cfg.get('lambda_flow', 0.0))
lambda_packet = float(cfg.get('lambda_packet', 0.0))
packet_mask_ratio = float(cfg.get('packet_mask_ratio', 0.5))
aux_enabled = lambda_flow > 0.0 or lambda_packet > 0.0
if aux_enabled:
print(f'[loss] λ_flow={lambda_flow} λ_packet={lambda_packet} packet_mask_ratio={packet_mask_ratio}')
for epoch in range(1, int(cfg['epochs']) + 1):
model.train()
losses: list[float] = []
aux_flow_sum = 0.0
aux_packet_sum = 0.0
n_steps_this_epoch = 0
t0 = time.time()
for (flow, packets, lens) in loader:
flow = flow.to(device, non_blocking=True)
packets = packets.to(device, non_blocking=True)
lens = lens.to(device, non_blocking=True)
if aux_enabled:
comp = model.compute_loss(flow, packets, lens, lambda_flow=lambda_flow, lambda_packet=lambda_packet, packet_mask_ratio=packet_mask_ratio, return_components=True)
loss = comp['total']
aux_flow_sum += float(comp['aux_flow'].item())
aux_packet_sum += float(comp['aux_packet'].item())
else:
loss = model.compute_loss(flow, packets, lens)
opt.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get('grad_clip', 1.0)))
opt.step()
sched.step()
losses.append(float(loss.item()))
n_steps_this_epoch += 1
mean_loss = float(np.mean(losses)) if losses else float('nan')
eval_metrics: dict[str, float] | None = None
if epoch % int(cfg.get('eval_every', 5)) == 0 or epoch == int(cfg['epochs']):
eval_metrics = _quick_eval(model, data, device, cfg)
history['epoch'].append(epoch)
history['loss'].append(mean_loss)
history['eval'].append(eval_metrics)
elapsed = time.time() - t0
terminal = ''
if eval_metrics:
terminal = f" auroc_terminal={eval_metrics['auroc_terminal_norm']:.3f}"
if aux_enabled and n_steps_this_epoch:
terminal += f' aux_flow={aux_flow_sum / n_steps_this_epoch:.4f} aux_pkt={aux_packet_sum / n_steps_this_epoch:.4f}'
print(f"[epoch {epoch:>3d}/{cfg['epochs']:<3d}] ({elapsed:.1f}s) loss={mean_loss:.4f}{terminal}")
if not np.isfinite(mean_loss):
raise RuntimeError(f'non-finite loss at epoch {epoch}')
payload = {'model_state_dict': model.state_dict(), 'model_cfg': asdict(model_cfg), 'packet_mean': data.packet_mean, 'packet_std': data.packet_std, 'flow_mean': data.flow_mean, 'flow_std': data.flow_std, 'packet_preprocess': data.packet_preprocess, 'flow_feature_names': np.asarray(data.flow_feature_names), 'packet_feature_names': np.asarray(data.packet_feature_names)}
torch.save(payload, save_dir / 'model.pt')
with open(save_dir / 'history.json', 'w') as f:
json.dump(history, f, indent=2, default=str)
print(f"[saved] {save_dir / 'model.pt'}")
return save_dir
def main() -> None:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument('--config', type=Path, required=True)
p.add_argument('--override', type=str, nargs='*', default=[])
args = p.parse_args()
with open(args.config) as f:
cfg = yaml.safe_load(f)
for override in args.override:
(key, value) = override.split('=', 1)
cfg[key] = yaml.safe_load(value)
train(cfg)
if __name__ == '__main__':
main()