Initial commit: code, paper, small artifacts
This commit is contained in:
133
Unified_CFM/README.md
Normal file
133
Unified_CFM/README.md
Normal file
@@ -0,0 +1,133 @@
|
||||
# Unified_CFM
|
||||
|
||||
A single multi-scale OT-CFM over one token sequence per flow:
|
||||
|
||||
```text
|
||||
[FLOW_TOKEN, PACKET_1, ..., PACKET_T]
|
||||
```
|
||||
|
||||
This is **not** a Flow-CFM + Packet-CFM ensemble. Flow-level and packet-level
|
||||
signals interact inside one Transformer velocity field, and a Phase 2
|
||||
masked-prediction consistency loss explicitly trains the cross-modal
|
||||
dependency.
|
||||
|
||||
This is the **current SOTA model** in the repo (within-dataset SOTA on
|
||||
ISCXTor2016 / CICIDS2017 / CICDDoS2019; near-SOTA cross-dataset).
|
||||
|
||||
## Model
|
||||
|
||||
`UnifiedTokenCFM` uses fixed tokenization to avoid latent-collapse shortcuts:
|
||||
|
||||
```text
|
||||
flow token: [type=-1, normalized 20-d canonical flow features, zero pad]
|
||||
packet token: [type=+1, normalized 9-d packet features, zero pad]
|
||||
```
|
||||
|
||||
Velocity field: 4-layer AdaLN-Zero Transformer (`d_model=128, n_heads=4`),
|
||||
sinusoidal time embedding (`time_dim=64`). Total ≈ 1.23M parameters.
|
||||
|
||||
Loss with Phase 2 consistency:
|
||||
|
||||
```
|
||||
L = L_main + λ_flow · L_mask_flow + λ_packet · L_mask_packet
|
||||
|
||||
L_main: standard OT-CFM velocity regression with σ-band noise +
|
||||
Sinkhorn OT coupling.
|
||||
L_mask_flow: zero out the flow token's input at x_t; predict v[flow]
|
||||
from packet context only.
|
||||
L_mask_packet: zero out a random 50% of real packet tokens at x_t;
|
||||
predict their velocities from flow + remaining packets.
|
||||
```
|
||||
|
||||
Best hyperparameters from the σ × λ sweeps:
|
||||
|
||||
```
|
||||
lambda_flow = lambda_packet = 0.3
|
||||
packet_mask_ratio = 0.5
|
||||
sigma = 0.6 # cross-dataset best; σ=0.1 marginally better for some within
|
||||
use_ot = True
|
||||
```
|
||||
|
||||
## Scores
|
||||
|
||||
The model exposes three classes of scores at inference:
|
||||
|
||||
```text
|
||||
# primary
|
||||
terminal_norm
|
||||
|
||||
# decomposed (analysis only)
|
||||
terminal_flow terminal_packet
|
||||
arc_length kinetic_energy kinetic_flow kinetic_packet
|
||||
velocity_total velocity_flow velocity_packet
|
||||
|
||||
# Phase 1 diagnostics
|
||||
curvature_total curvature_flow curvature_packet # ∫ ||dv/dt||² dt
|
||||
kappa2_speed2norm_packet_{mean,median,trimmed10_mean} # packet curvature / speed²
|
||||
jacobian_total jacobian_flow jacobian_packet # Hutchinson VJP estimate of ||∂v/∂x||_F²
|
||||
velocity_*_t{01..10} # 18 time-profile scores
|
||||
|
||||
# Phase 2 cross-modal consistency
|
||||
flow_consistency packet_consistency consistency_total
|
||||
```
|
||||
|
||||
`terminal_norm` is the paper's primary score. The decomposed and diagnostic
|
||||
scores serve **per-attack-family analysis** — they are NOT competing
|
||||
SOTA claims. Multi-seed std on `terminal_norm` is ≤ 0.005 across all our
|
||||
runs.
|
||||
|
||||
The Phase 2 consistency scores have a notable property: they are
|
||||
**discriminative only when the model is trained with the consistency loss**.
|
||||
On a baseline model `flow_consistency` is roughly random (0.57 on
|
||||
CICIDS2017); after Phase 2 training it lifts to 0.88. On SSH-Patator,
|
||||
where standard density scores struggle (`terminal_norm` 0.64), Phase 2
|
||||
`flow_consistency` reaches 0.94.
|
||||
|
||||
## Train
|
||||
|
||||
```bash
|
||||
# baseline (no consistency loss)
|
||||
uv run python Unified_CFM/train.py --config Unified_CFM/configs/cicids2017_baseline.yaml
|
||||
|
||||
# Phase 2 with consistency loss (λ=0.1, σ=0.1)
|
||||
uv run python Unified_CFM/train.py --config Unified_CFM/configs/cicids2017_consistency.yaml
|
||||
|
||||
# σ × λ sweeps and multi-seed orchestrators live in
|
||||
# artifacts/verify_2026_04_24/run_*.sh
|
||||
```
|
||||
|
||||
The intended setup is to use the workspace-canonical 20-d packet-derived
|
||||
flow feature file:
|
||||
|
||||
```yaml
|
||||
flow_features_path: datasets/cicids2017/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
```
|
||||
|
||||
`flow_features.parquet` is row-aligned with the Packet_CFM artifacts via
|
||||
`flow_id`. With `flow_features_align: auto`, the loader uses direct
|
||||
row/`flow_id` alignment when possible; scan alignment remains only for
|
||||
legacy full CSV-derived caches.
|
||||
|
||||
For large datasets where a monolithic `packets.npz` would exceed memory,
|
||||
the loader supports the sharded backend:
|
||||
|
||||
```yaml
|
||||
source_store: datasets/cicddos2019/processed/full_store
|
||||
val_cap: 20000
|
||||
attack_cap: 20000
|
||||
```
|
||||
|
||||
If `flow_features_path` is empty, the loader derives compact 16-d flow-level
|
||||
statistics from the packet sequence. That fallback is for debugging only;
|
||||
new runs should use the canonical 20-d file generated by
|
||||
`scripts/generate_flow_features.py`.
|
||||
|
||||
## Evaluation
|
||||
|
||||
`artifacts/verify_2026_04_24/eval_phase1_unified.py` runs Phase 1 + Phase 2
|
||||
score battery on a trained checkpoint, with per-attack-class AUROC.
|
||||
|
||||
`artifacts/verify_2026_04_24/eval_phase2_cross_cicddos2019.py` runs
|
||||
cross-dataset CICIDS2017→CICDDoS2019 evaluation under the standard
|
||||
10k benign + 10k stratified attack protocol.
|
||||
1
Unified_CFM/__init__.py
Normal file
1
Unified_CFM/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
pass
|
||||
45
Unified_CFM/configs/cicddos2019_reference_blockdiag.yaml
Normal file
45
Unified_CFM/configs/cicddos2019_reference_blockdiag.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
save_dir: /home/chy/JANUS/artifacts/phaseC_reference_2026_04_25/cicddos2019_ref_blockdiag_seed42
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/cicddos2019/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/cicddos2019/processed/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/cicddos2019/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 20000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
reference_mode: block_diagonal
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
|
||||
lambda_flow: 0.0
|
||||
lambda_packet: 0.0
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
45
Unified_CFM/configs/cicddos2019_reference_independent.yaml
Normal file
45
Unified_CFM/configs/cicddos2019_reference_independent.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
save_dir: /home/chy/JANUS/artifacts/phaseC_reference_2026_04_25/cicddos2019_ref_independent_seed42
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/cicddos2019/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/cicddos2019/processed/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/cicddos2019/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 20000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
reference_mode: independent_token
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 10000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
|
||||
lambda_flow: 0.0
|
||||
lambda_packet: 0.0
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
41
Unified_CFM/configs/cicddos2019_within.yaml
Normal file
41
Unified_CFM/configs/cicddos2019_within.yaml
Normal file
@@ -0,0 +1,41 @@
|
||||
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_cicddos2019_within_2026_04_25
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/cicddos2019/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/cicddos2019/processed/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/cicddos2019/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
|
||||
val_cap: 20000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 10000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
|
||||
device: auto
|
||||
43
Unified_CFM/configs/cicddos2019_within_consistency.yaml
Normal file
43
Unified_CFM/configs/cicddos2019_within_consistency.yaml
Normal file
@@ -0,0 +1,43 @@
|
||||
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_cicddos2019_within_consistency_2026_04_25
|
||||
source_store: /home/chy/JANUS/datasets/cicddos2019/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/cicddos2019/processed/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/cicddos2019/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 20000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 10000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
|
||||
lambda_flow: 0.1
|
||||
lambda_packet: 0.1
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
38
Unified_CFM/configs/cicids2017_baseline.yaml
Normal file
38
Unified_CFM/configs/cicids2017_baseline.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_cicids2017_canonical_2026_04_24
|
||||
|
||||
packets_npz: /home/chy/JANUS/datasets/cicids2017/processed/packets.npz
|
||||
flows_parquet: /home/chy/JANUS/datasets/cicids2017/processed/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/cicids2017/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 2
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
|
||||
device: auto
|
||||
43
Unified_CFM/configs/cicids2017_consistency.yaml
Normal file
43
Unified_CFM/configs/cicids2017_consistency.yaml
Normal file
@@ -0,0 +1,43 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_cicids2017_consistency_2026_04_25
|
||||
|
||||
packets_npz: /home/chy/JANUS/datasets/cicids2017/processed/packets.npz
|
||||
flows_parquet: /home/chy/JANUS/datasets/cicids2017/processed/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/cicids2017/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 2
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
|
||||
lambda_flow: 0.1
|
||||
lambda_packet: 0.1
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
43
Unified_CFM/configs/ciciot2023.yaml
Normal file
43
Unified_CFM/configs/ciciot2023.yaml
Normal file
@@ -0,0 +1,43 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_ciciot2023_2026_04_29
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
45
Unified_CFM/configs/ciciot2023_baseline_seed42.yaml
Normal file
45
Unified_CFM/configs/ciciot2023_baseline_seed42.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/route_comparison/baseline_ciciot2023_seed42
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
reference_mode:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
45
Unified_CFM/configs/ciciot2023_baseline_seed43.yaml
Normal file
45
Unified_CFM/configs/ciciot2023_baseline_seed43.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/route_comparison/baseline_ciciot2023_seed43
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 43
|
||||
data_seed: 43
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
reference_mode:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
45
Unified_CFM/configs/ciciot2023_baseline_seed44.yaml
Normal file
45
Unified_CFM/configs/ciciot2023_baseline_seed44.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/route_comparison/baseline_ciciot2023_seed44
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 44
|
||||
data_seed: 44
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
reference_mode:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
45
Unified_CFM/configs/ciciot2023_route_a_causal.yaml
Normal file
45
Unified_CFM/configs/ciciot2023_route_a_causal.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_a_causal_ciciot2023_seed42
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
reference_mode: causal_packets
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
45
Unified_CFM/configs/ciciot2023_route_a_causal_seed43.yaml
Normal file
45
Unified_CFM/configs/ciciot2023_route_a_causal_seed43.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_a_causal_ciciot2023_seed43
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 43
|
||||
data_seed: 43
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
reference_mode: causal_packets
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
45
Unified_CFM/configs/ciciot2023_route_a_causal_seed44.yaml
Normal file
45
Unified_CFM/configs/ciciot2023_route_a_causal_seed44.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_a_causal_ciciot2023_seed44
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 44
|
||||
data_seed: 44
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
reference_mode: causal_packets
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
44
Unified_CFM/configs/ciciot2023_route_b_spectral_seed42.yaml
Normal file
44
Unified_CFM/configs/ciciot2023_route_b_spectral_seed42.yaml
Normal file
@@ -0,0 +1,44 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_b_spectral_ciciot2023_seed42
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features_spectral.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
44
Unified_CFM/configs/ciciot2023_route_b_spectral_seed43.yaml
Normal file
44
Unified_CFM/configs/ciciot2023_route_b_spectral_seed43.yaml
Normal file
@@ -0,0 +1,44 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_b_spectral_ciciot2023_seed43
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features_spectral.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 43
|
||||
data_seed: 43
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
44
Unified_CFM/configs/ciciot2023_route_b_spectral_seed44.yaml
Normal file
44
Unified_CFM/configs/ciciot2023_route_b_spectral_seed44.yaml
Normal file
@@ -0,0 +1,44 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/route_comparison/route_b_spectral_ciciot2023_seed44
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features_spectral.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 44
|
||||
data_seed: 44
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
attack_cap: 20000
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
45
Unified_CFM/configs/ciciot2023_shafir5.yaml
Normal file
45
Unified_CFM/configs/ciciot2023_shafir5.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_ciciot2023_shafir5_2026_04_29
|
||||
|
||||
source_store: /home/chy/JANUS/datasets/ciciot2023/processed/full_store
|
||||
flows_parquet: /home/chy/JANUS/datasets/ciciot2023/processed/full_store/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/ciciot2023/processed/flow_features_shafir5.parquet
|
||||
flow_feature_columns: ["HTTPS", "Protocol_Type", "Magnitude", "Variance", "fin_count"]
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: normal
|
||||
val_cap: 10000
|
||||
|
||||
flow_dim: 5
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 0
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
lambda_flow: 0.3
|
||||
lambda_packet: 0.3
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
39
Unified_CFM/configs/iscxtor2016.yaml
Normal file
39
Unified_CFM/configs/iscxtor2016.yaml
Normal file
@@ -0,0 +1,39 @@
|
||||
|
||||
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_iscxtor2016_2026_04_25
|
||||
|
||||
packets_npz: /home/chy/JANUS/datasets/iscxtor2016/processed/packets.npz
|
||||
flows_parquet: /home/chy/JANUS/datasets/iscxtor2016/processed/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/iscxtor2016/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: nontor
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 2
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
|
||||
device: auto
|
||||
41
Unified_CFM/configs/iscxtor2016_consistency.yaml
Normal file
41
Unified_CFM/configs/iscxtor2016_consistency.yaml
Normal file
@@ -0,0 +1,41 @@
|
||||
save_dir: /home/chy/JANUS/artifacts/runs/unified_cfm_iscxtor2016_consistency_2026_04_25
|
||||
packets_npz: /home/chy/JANUS/datasets/iscxtor2016/processed/packets.npz
|
||||
flows_parquet: /home/chy/JANUS/datasets/iscxtor2016/processed/flows.parquet
|
||||
flow_features_path: /home/chy/JANUS/datasets/iscxtor2016/processed/flow_features.parquet
|
||||
flow_features_align: auto
|
||||
|
||||
T: 64
|
||||
n_train: 10000
|
||||
min_len: 2
|
||||
packet_preprocess: mixed_dequant
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
train_ratio: 0.8
|
||||
benign_label: nontor
|
||||
|
||||
d_model: 128
|
||||
n_layers: 4
|
||||
n_heads: 4
|
||||
mlp_ratio: 4.0
|
||||
time_dim: 64
|
||||
token_dim:
|
||||
|
||||
batch_size: 256
|
||||
num_workers: 2
|
||||
epochs: 50
|
||||
lr: 3.0e-4
|
||||
weight_decay: 0.01
|
||||
grad_clip: 1.0
|
||||
eval_every: 10
|
||||
eval_n: 20000
|
||||
eval_batch_size: 512
|
||||
eval_n_steps: 8
|
||||
|
||||
sigma: 0.1
|
||||
use_ot: true
|
||||
|
||||
lambda_flow: 0.1
|
||||
lambda_packet: 0.1
|
||||
packet_mask_ratio: 0.5
|
||||
|
||||
device: auto
|
||||
275
Unified_CFM/data.py
Normal file
275
Unified_CFM/data.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sys as _sys
|
||||
from pathlib import Path as _Path
|
||||
_sys.path.insert(0, str(_Path(__file__).resolve().parents[1]))
|
||||
from common.data_contract import PACKET_FEATURE_NAMES, PACKET_CONTINUOUS_CHANNEL_IDX as CONTINUOUS_CHANNEL_IDX, PACKET_BINARY_CHANNEL_IDX as BINARY_CHANNEL_IDX, canonical_5tuple as _canonical_key, fit_packet_stats as _fit_packet_stats, zscore as _zscore, apply_mixed_dequant as _apply_mixed_dequant
|
||||
DEFAULT_FLOW_META_COLUMNS = {'flow_id', 'label', 'day', 'service', 'src_ip', 'dst_ip', 'src_port', 'dst_port', 'protocol', 'timestamp', 'start_ts', 'n_pkts'}
|
||||
DERIVED_FLOW_FEATURE_NAMES = ('log_len', 'fwd_frac', 'bwd_frac', 'log_size_mean', 'log_size_std', 'log_size_min', 'log_size_max', 'log_dt_mean', 'log_dt_std', 'log_dt_max', 'syn_frac', 'fin_frac', 'rst_frac', 'psh_frac', 'ack_frac', 'log_win_mean')
|
||||
|
||||
@dataclass
|
||||
class UnifiedData:
|
||||
train_flow: np.ndarray
|
||||
val_flow: np.ndarray
|
||||
attack_flow: np.ndarray
|
||||
train_packets: np.ndarray
|
||||
val_packets: np.ndarray
|
||||
attack_packets: np.ndarray
|
||||
train_len: np.ndarray
|
||||
val_len: np.ndarray
|
||||
attack_len: np.ndarray
|
||||
attack_labels: np.ndarray
|
||||
packet_mean: np.ndarray
|
||||
packet_std: np.ndarray
|
||||
flow_mean: np.ndarray
|
||||
flow_std: np.ndarray
|
||||
packet_preprocess: str
|
||||
flow_feature_names: tuple[str, ...]
|
||||
packet_feature_names: tuple[str, ...] = PACKET_FEATURE_NAMES
|
||||
|
||||
@property
|
||||
def T(self) -> int:
|
||||
return int(self.train_packets.shape[1])
|
||||
|
||||
@property
|
||||
def packet_dim(self) -> int:
|
||||
return int(self.train_packets.shape[2])
|
||||
|
||||
@property
|
||||
def flow_dim(self) -> int:
|
||||
return int(self.train_flow.shape[1])
|
||||
|
||||
def _preprocess_packets(train_x: np.ndarray, val_x: np.ndarray, attack_x: np.ndarray, train_l: np.ndarray, val_l: np.ndarray, attack_l: np.ndarray, preprocess: str, seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
if preprocess not in ('zscore', 'mixed_dequant'):
|
||||
raise ValueError("packet_preprocess must be 'zscore' or 'mixed_dequant'")
|
||||
(mean, std) = _fit_packet_stats(train_x, train_l)
|
||||
|
||||
def prep(x: np.ndarray, l: np.ndarray, tag: str) -> np.ndarray:
|
||||
if preprocess == 'zscore':
|
||||
z = _zscore(x, mean, std)
|
||||
mask = np.arange(x.shape[1])[None, :] < l[:, None]
|
||||
return (z * mask[:, :, None]).astype(np.float32)
|
||||
return _apply_mixed_dequant(x, l, mean, std, split_tag=tag, seed=seed)
|
||||
return (prep(train_x, train_l, 'train'), prep(val_x, val_l, 'val'), prep(attack_x, attack_l, 'attack'), mean, std)
|
||||
|
||||
def _derive_flow_features(tokens: np.ndarray, lens: np.ndarray) -> np.ndarray:
|
||||
(N, T, _) = tokens.shape
|
||||
out = np.zeros((N, len(DERIVED_FLOW_FEATURE_NAMES)), dtype=np.float32)
|
||||
for i in range(N):
|
||||
n = int(max(lens[i], 1))
|
||||
x = tokens[i, :n]
|
||||
direction = x[:, 2]
|
||||
size = x[:, 0]
|
||||
dt = x[:, 1]
|
||||
win = x[:, 8]
|
||||
out[i, 0] = np.log1p(n)
|
||||
out[i, 1] = np.mean(direction < 0.5)
|
||||
out[i, 2] = np.mean(direction >= 0.5)
|
||||
out[i, 3] = size.mean()
|
||||
out[i, 4] = size.std()
|
||||
out[i, 5] = size.min()
|
||||
out[i, 6] = size.max()
|
||||
out[i, 7] = dt.mean()
|
||||
out[i, 8] = dt.std()
|
||||
out[i, 9] = dt.max()
|
||||
out[i, 10] = x[:, 3].mean()
|
||||
out[i, 11] = x[:, 4].mean()
|
||||
out[i, 12] = x[:, 5].mean()
|
||||
out[i, 13] = x[:, 6].mean()
|
||||
out[i, 14] = x[:, 7].mean()
|
||||
out[i, 15] = win.mean()
|
||||
return out
|
||||
|
||||
def _read_flow_features(path: Path, *, expected_rows: int, feature_columns: Optional[list[str]]=None) -> tuple[np.ndarray, tuple[str, ...], np.ndarray | None]:
|
||||
path = Path(path)
|
||||
if path.suffix == '.npz':
|
||||
data = np.load(path, allow_pickle=True)
|
||||
x = data['features'].astype(np.float32)
|
||||
raw_names = data['feature_names'] if 'feature_names' in data.files else np.arange(x.shape[1])
|
||||
names = tuple((str(v) for v in raw_names))
|
||||
flow_id = data['flow_id'] if 'flow_id' in data.files else None
|
||||
elif path.suffix in ('.parquet', '.pq'):
|
||||
df = pd.read_parquet(path)
|
||||
flow_id = df['flow_id'].to_numpy() if 'flow_id' in df.columns else None
|
||||
if feature_columns:
|
||||
cols = feature_columns
|
||||
else:
|
||||
cols = [c for c in df.columns if c not in DEFAULT_FLOW_META_COLUMNS and pd.api.types.is_numeric_dtype(df[c])]
|
||||
if not cols:
|
||||
raise ValueError(f'no numeric flow feature columns found in {path}')
|
||||
x = df[cols].to_numpy(dtype=np.float32)
|
||||
names = tuple(cols)
|
||||
else:
|
||||
raise ValueError(f'unsupported flow feature file: {path}')
|
||||
if len(x) != expected_rows:
|
||||
raise ValueError(f'flow feature row count {len(x):,} != packet row count {expected_rows:,}')
|
||||
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
||||
return (x, names, flow_id)
|
||||
|
||||
def _feature_columns_from_df(df: pd.DataFrame, requested: Optional[list[str]]) -> list[str]:
|
||||
if requested:
|
||||
return requested
|
||||
return [c for c in df.columns if c not in DEFAULT_FLOW_META_COLUMNS and pd.api.types.is_numeric_dtype(df[c])]
|
||||
|
||||
def _align_flow_features_by_scan(feature_df: pd.DataFrame, packet_flows: pd.DataFrame, *, feature_columns: list[str]) -> tuple[np.ndarray, tuple[str, ...]]:
|
||||
required = ['label', 'src_ip', 'src_port', 'dst_ip', 'dst_port', 'protocol']
|
||||
missing_feature = [c for c in required if c not in feature_df.columns]
|
||||
missing_packet = [c for c in required if c not in packet_flows.columns]
|
||||
if missing_feature or missing_packet:
|
||||
raise ValueError(f'scan alignment requires label + 5-tuple metadata. missing in feature_df={missing_feature}, packet_flows={missing_packet}')
|
||||
packet_keys = [(str(lbl), _canonical_key(src, sp, dst, dp, proto)) for (lbl, src, sp, dst, dp, proto) in zip(packet_flows['label'].to_numpy(), packet_flows['src_ip'].to_numpy(), packet_flows['src_port'].to_numpy(), packet_flows['dst_ip'].to_numpy(), packet_flows['dst_port'].to_numpy(), packet_flows['protocol'].to_numpy())]
|
||||
labels = feature_df['label'].to_numpy()
|
||||
src_ip = feature_df['src_ip'].to_numpy()
|
||||
src_port = feature_df['src_port'].to_numpy()
|
||||
dst_ip = feature_df['dst_ip'].to_numpy()
|
||||
dst_port = feature_df['dst_port'].to_numpy()
|
||||
protocol = feature_df['protocol'].to_numpy()
|
||||
matched: list[int] = []
|
||||
j = 0
|
||||
n_csv = len(feature_df)
|
||||
for (i, target) in enumerate(packet_keys):
|
||||
while j < n_csv:
|
||||
cand = (str(labels[j]), _canonical_key(src_ip[j], src_port[j], dst_ip[j], dst_port[j], protocol[j]))
|
||||
j += 1
|
||||
if cand == target:
|
||||
matched.append(j - 1)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'failed to align packet flow row {i:,}/{len(packet_keys):,}; the CSV cache may not be the same one used for packet extraction')
|
||||
print(f'[data] scan-aligned CSV flow features: matched={len(matched):,} from csv_rows={n_csv:,} skipped={matched[-1] + 1 - len(matched):,}')
|
||||
x = feature_df.iloc[matched][feature_columns].to_numpy(dtype=np.float32)
|
||||
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
||||
return (x, tuple(feature_columns))
|
||||
|
||||
def _read_aligned_flow_features(path: Path, packet_flows: pd.DataFrame, *, feature_columns: Optional[list[str]]=None, align: str='auto') -> tuple[np.ndarray, tuple[str, ...]]:
|
||||
path = Path(path)
|
||||
if align not in ('auto', 'row', 'scan'):
|
||||
raise ValueError("flow_features_align must be 'auto', 'row', or 'scan'")
|
||||
if path.suffix == '.npz':
|
||||
(x, names, flow_id) = _read_flow_features(path, expected_rows=len(packet_flows), feature_columns=feature_columns)
|
||||
packet_id = packet_flows['flow_id'].to_numpy() if 'flow_id' in packet_flows else None
|
||||
if flow_id is not None and packet_id is not None and (not np.array_equal(flow_id, packet_id)):
|
||||
raise ValueError('NPZ flow_id does not align with Packet_CFM flows')
|
||||
return (x, names)
|
||||
if path.suffix not in ('.parquet', '.pq'):
|
||||
raise ValueError(f'unsupported flow feature file: {path}')
|
||||
feature_df = pd.read_parquet(path)
|
||||
cols = _feature_columns_from_df(feature_df, feature_columns)
|
||||
if not cols:
|
||||
raise ValueError(f'no numeric flow feature columns found in {path}')
|
||||
packet_id = packet_flows['flow_id'].to_numpy() if 'flow_id' in packet_flows else None
|
||||
if len(feature_df) == len(packet_flows):
|
||||
feature_id = feature_df['flow_id'].to_numpy() if 'flow_id' in feature_df.columns else None
|
||||
if feature_id is None or packet_id is None or np.array_equal(feature_id, packet_id):
|
||||
x = feature_df[cols].to_numpy(dtype=np.float32)
|
||||
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
||||
return (x, tuple(cols))
|
||||
if align == 'row':
|
||||
raise ValueError("flow_id mismatch with flow_features_align='row'")
|
||||
if align == 'row':
|
||||
raise ValueError(f'row alignment requested but feature rows={len(feature_df):,} packet rows={len(packet_flows):,}')
|
||||
return _align_flow_features_by_scan(feature_df, packet_flows, feature_columns=cols)
|
||||
|
||||
def _preprocess_flow(train: np.ndarray, val: np.ndarray, attack: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
mean = train.mean(axis=0).astype(np.float32)
|
||||
std = train.std(axis=0).astype(np.float32)
|
||||
return (_zscore(train, mean, std), _zscore(val, mean, std), _zscore(attack, mean, std), mean, std)
|
||||
|
||||
def load_unified_data(*, packets_npz: Path | None=None, source_store: Path | None=None, flows_parquet: Path, flow_features_path: Path | None=None, flow_feature_columns: Optional[list[str]]=None, flow_features_align: str='auto', T: int=128, split_seed: int=42, train_ratio: float=0.8, benign_label: str='normal', min_len: int=2, packet_preprocess: str='mixed_dequant', attack_cap: int | None=None, val_cap: int | None=None) -> UnifiedData:
|
||||
if (packets_npz is None) == (source_store is None):
|
||||
raise ValueError('pass exactly one of packets_npz or source_store')
|
||||
flows_parquet = Path(flows_parquet)
|
||||
print(f'[data] flows={flows_parquet} packets_source={(packets_npz if packets_npz else source_store)}')
|
||||
flow_cols = ['flow_id', 'label']
|
||||
if flow_features_path is not None:
|
||||
flow_cols += ['src_ip', 'src_port', 'dst_ip', 'dst_port', 'protocol']
|
||||
flows = pd.read_parquet(flows_parquet, columns=flow_cols)
|
||||
labels_full = flows['label'].to_numpy().astype(str)
|
||||
flow_id = flows['flow_id'].to_numpy()
|
||||
tokens_full: np.ndarray | None = None
|
||||
store = None
|
||||
if packets_npz is not None:
|
||||
pz = np.load(Path(packets_npz))
|
||||
tokens_full = pz['packet_tokens'].astype(np.float32)
|
||||
lens_full = pz['packet_lengths'].astype(np.int32)
|
||||
packet_flow_id = pz['flow_id'] if 'flow_id' in pz.files else None
|
||||
if T > tokens_full.shape[1]:
|
||||
raise ValueError(f'requested T={T} > stored T_full={tokens_full.shape[1]}')
|
||||
tokens_full = tokens_full[:, :T].copy()
|
||||
lens_full = np.minimum(lens_full, T).astype(np.int32)
|
||||
if packet_flow_id is not None and (not np.array_equal(packet_flow_id, flow_id)):
|
||||
raise ValueError('packets_npz and flows_parquet are not row-aligned by flow_id')
|
||||
else:
|
||||
if flow_features_path is None:
|
||||
raise ValueError('source_store path requires flow_features_path (derived features need tokens in memory)')
|
||||
from common.packet_store import PacketShardStore
|
||||
store = PacketShardStore.open(Path(source_store))
|
||||
store_flow_id = store.read_flows(columns=['flow_id'])['flow_id'].to_numpy()
|
||||
if not np.array_equal(store_flow_id, flow_id):
|
||||
raise ValueError('source_store and flows_parquet are not row-aligned by flow_id')
|
||||
lens_full = np.minimum(store.manifest['packet_length'].to_numpy(dtype=np.int32), T)
|
||||
if flow_features_path is None:
|
||||
assert tokens_full is not None
|
||||
flow_features = _derive_flow_features(tokens_full, lens_full)
|
||||
flow_names = DERIVED_FLOW_FEATURE_NAMES
|
||||
print(f'[data] using derived flow features D={flow_features.shape[1]}')
|
||||
else:
|
||||
(flow_features, flow_names) = _read_aligned_flow_features(Path(flow_features_path), flows, feature_columns=flow_feature_columns, align=flow_features_align)
|
||||
print(f'[data] using external flow features D={flow_features.shape[1]}')
|
||||
keep = lens_full >= min_len
|
||||
labels = labels_full[keep]
|
||||
flow_features = flow_features[keep]
|
||||
lens = lens_full[keep]
|
||||
global_idx = np.flatnonzero(keep).astype(np.int64)
|
||||
if tokens_full is not None:
|
||||
materialized_tokens = tokens_full[keep]
|
||||
else:
|
||||
materialized_tokens = None
|
||||
print(f'[data] rows total={len(keep):,} keep len>={min_len}: {keep.sum():,}')
|
||||
benign_local = np.where(labels == benign_label)[0]
|
||||
attack_local = np.where(labels != benign_label)[0]
|
||||
rng = np.random.default_rng(split_seed)
|
||||
rng.shuffle(benign_local)
|
||||
n_train = int(len(benign_local) * train_ratio)
|
||||
train_local = benign_local[:n_train]
|
||||
val_local = benign_local[n_train:]
|
||||
if val_cap is not None and len(val_local) > val_cap:
|
||||
val_local = np.sort(rng.choice(val_local, size=val_cap, replace=False))
|
||||
if attack_cap is not None and len(attack_local) > attack_cap:
|
||||
attack_local = np.sort(rng.choice(attack_local, size=attack_cap, replace=False))
|
||||
print(f'[data] benign={len(benign_local):,} attack={len(attack_local):,} -> train={len(train_local):,} val={len(val_local):,}')
|
||||
|
||||
def _materialize(local_indices: np.ndarray) -> np.ndarray:
|
||||
if materialized_tokens is not None:
|
||||
return materialized_tokens[local_indices].astype(np.float32, copy=False)
|
||||
assert store is not None
|
||||
g = global_idx[local_indices]
|
||||
(tok, _) = store.read_packets(g.astype(np.int64), T=T)
|
||||
return tok.astype(np.float32, copy=False)
|
||||
tr_p_raw = _materialize(train_local)
|
||||
va_p_raw = _materialize(val_local)
|
||||
at_p_raw = _materialize(attack_local)
|
||||
tr_l = lens[train_local]
|
||||
va_l = lens[val_local]
|
||||
at_l = lens[attack_local]
|
||||
tr_f_raw = flow_features[train_local]
|
||||
va_f_raw = flow_features[val_local]
|
||||
at_f_raw = flow_features[attack_local]
|
||||
train_idx = train_local
|
||||
val_idx = val_local
|
||||
attack_idx = attack_local
|
||||
(tr_p, va_p, at_p, p_mean, p_std) = _preprocess_packets(tr_p_raw, va_p_raw, at_p_raw, tr_l, va_l, at_l, preprocess=packet_preprocess, seed=split_seed)
|
||||
(tr_f, va_f, at_f, f_mean, f_std) = _preprocess_flow(tr_f_raw, va_f_raw, at_f_raw)
|
||||
return UnifiedData(train_flow=tr_f, val_flow=va_f, attack_flow=at_f, train_packets=tr_p, val_packets=va_p, attack_packets=at_p, train_len=tr_l, val_len=va_l, attack_len=at_l, attack_labels=labels[attack_idx], packet_mean=p_mean, packet_std=p_std, flow_mean=f_mean, flow_std=f_std, packet_preprocess=packet_preprocess, flow_feature_names=tuple(flow_names))
|
||||
|
||||
def subsample_train(data: UnifiedData, n_train: int, seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
if n_train <= 0 or n_train >= len(data.train_flow):
|
||||
return (data.train_flow, data.train_packets, data.train_len)
|
||||
rng = np.random.default_rng(seed)
|
||||
idx = rng.choice(len(data.train_flow), n_train, replace=False)
|
||||
idx.sort()
|
||||
return (data.train_flow[idx], data.train_packets[idx], data.train_len[idx])
|
||||
588
Unified_CFM/model.py
Normal file
588
Unified_CFM/model.py
Normal file
@@ -0,0 +1,588 @@
|
||||
from __future__ import annotations
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchdiffeq import odeint
|
||||
|
||||
@torch.no_grad()
|
||||
def _sinkhorn_coupling(C: torch.Tensor, reg: float=0.05, n_iter: int=20) -> torch.Tensor:
|
||||
C = C.float()
|
||||
log_k = -C / reg
|
||||
B = C.shape[0]
|
||||
log_u = torch.zeros(B, device=C.device)
|
||||
log_v = torch.zeros(B, device=C.device)
|
||||
for _ in range(n_iter):
|
||||
log_v = -torch.logsumexp(log_k + log_u.unsqueeze(1), dim=0)
|
||||
log_u = -torch.logsumexp(log_k + log_v.unsqueeze(0), dim=1)
|
||||
log_p = log_u.unsqueeze(1) + log_k + log_v.unsqueeze(0)
|
||||
return log_p.argmax(dim=1)
|
||||
|
||||
class SinusoidalTimeEmb(nn.Module):
|
||||
|
||||
def __init__(self, dim: int) -> None:
|
||||
super().__init__()
|
||||
if dim % 2 != 0:
|
||||
raise ValueError('time embedding dimension must be even')
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
||||
half = self.dim // 2
|
||||
freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device, dtype=t.dtype) / max(half - 1, 1))
|
||||
args = t[:, None] * freqs[None, :]
|
||||
return torch.cat([args.sin(), args.cos()], dim=-1)
|
||||
|
||||
class AdaLNBlock(nn.Module):
|
||||
|
||||
def __init__(self, d_model: int, n_heads: int, mlp_ratio: float, cond_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
|
||||
self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
|
||||
self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
|
||||
hidden = int(d_model * mlp_ratio)
|
||||
self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))
|
||||
self.cond_proj = nn.Linear(cond_dim, 6 * d_model)
|
||||
nn.init.zeros_(self.cond_proj.weight)
|
||||
nn.init.zeros_(self.cond_proj.bias)
|
||||
|
||||
@staticmethod
|
||||
def _modulate(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
|
||||
return x * (1.0 + gamma[:, None, :]) + beta[:, None, :]
|
||||
|
||||
def forward(self, x: torch.Tensor, cond: torch.Tensor, key_padding_mask: torch.Tensor | None, attn_mask: torch.Tensor | None=None) -> torch.Tensor:
|
||||
(g1, b1, a1, g2, b2, a2) = self.cond_proj(cond).chunk(6, dim=-1)
|
||||
h = self._modulate(self.norm1(x), g1, b1)
|
||||
(attn_out, _) = self.attn(h, h, h, key_padding_mask=key_padding_mask, attn_mask=attn_mask, need_weights=False)
|
||||
x = x + a1[:, None, :] * attn_out
|
||||
h = self._modulate(self.norm2(x), g2, b2)
|
||||
return x + a2[:, None, :] * self.mlp(h)
|
||||
|
||||
class UnifiedVelocity(nn.Module):
|
||||
|
||||
def __init__(self, token_dim: int, seq_len: int, d_model: int=128, n_layers: int=4, n_heads: int=4, mlp_ratio: float=4.0, time_dim: int=64, reference_mode: str | None=None) -> None:
|
||||
super().__init__()
|
||||
if reference_mode not in (None, 'independent_token', 'block_diagonal', 'causal_packets', 'causal_all'):
|
||||
raise ValueError(f'unknown reference_mode={reference_mode!r}')
|
||||
self.token_dim = token_dim
|
||||
self.seq_len = seq_len
|
||||
self.reference_mode = reference_mode
|
||||
self.input_proj = nn.Linear(token_dim, d_model)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
|
||||
self.type_emb = nn.Embedding(2, d_model)
|
||||
nn.init.trunc_normal_(self.pos_emb, std=0.02)
|
||||
nn.init.normal_(self.type_emb.weight, std=0.02)
|
||||
self.time_emb = SinusoidalTimeEmb(time_dim)
|
||||
self.cond_mlp = nn.Sequential(nn.Linear(time_dim, d_model), nn.SiLU(), nn.Linear(d_model, d_model))
|
||||
self.blocks = nn.ModuleList([AdaLNBlock(d_model, n_heads, mlp_ratio, cond_dim=d_model) for _ in range(n_layers)])
|
||||
self.out_norm = nn.LayerNorm(d_model, elementwise_affine=False)
|
||||
self.out = nn.Linear(d_model, token_dim)
|
||||
nn.init.zeros_(self.out.weight)
|
||||
nn.init.zeros_(self.out.bias)
|
||||
type_ids = torch.ones(seq_len, dtype=torch.long)
|
||||
type_ids[0] = 0
|
||||
self.register_buffer('type_ids', type_ids, persistent=False)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, key_padding_mask: torch.Tensor | None=None, attn_mask_override: torch.Tensor | None=None) -> torch.Tensor:
|
||||
(B, L, _) = x.shape
|
||||
if L > self.seq_len:
|
||||
raise ValueError(f'sequence length {L} exceeds configured {self.seq_len}')
|
||||
if t.dim() == 0:
|
||||
t = t.expand(B)
|
||||
h = self.input_proj(x)
|
||||
h = h + self.pos_emb[:, :L, :]
|
||||
h = h + self.type_emb(self.type_ids[:L])[None, :, :]
|
||||
cond = self.cond_mlp(self.time_emb(t))
|
||||
if attn_mask_override is not None:
|
||||
attn_mask = attn_mask_override
|
||||
else:
|
||||
attn_mask = self._reference_attn_mask(L, x.device)
|
||||
for block in self.blocks:
|
||||
h = block(h, cond, key_padding_mask, attn_mask=attn_mask)
|
||||
return self.out(self.out_norm(h))
|
||||
|
||||
def _reference_attn_mask(self, L: int, device: torch.device) -> torch.Tensor | None:
|
||||
if self.reference_mode is None:
|
||||
return None
|
||||
if self.reference_mode == 'independent_token':
|
||||
return ~torch.eye(L, dtype=torch.bool, device=device)
|
||||
if self.reference_mode == 'block_diagonal':
|
||||
mask = torch.ones((L, L), dtype=torch.bool, device=device)
|
||||
mask[0, 0] = False
|
||||
if L > 1:
|
||||
mask[1:, 1:] = False
|
||||
return mask
|
||||
if self.reference_mode == 'causal_packets':
|
||||
mask = torch.zeros((L, L), dtype=torch.bool, device=device)
|
||||
if L > 1:
|
||||
packet_causal = torch.triu(torch.ones(L - 1, L - 1, dtype=torch.bool, device=device), diagonal=1)
|
||||
mask[1:, 1:] = packet_causal
|
||||
return mask
|
||||
if self.reference_mode == 'causal_all':
|
||||
return torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), diagonal=1)
|
||||
raise AssertionError(self.reference_mode)
|
||||
|
||||
@dataclass
|
||||
class UnifiedCFMConfig:
|
||||
T: int = 128
|
||||
packet_dim: int = 9
|
||||
flow_dim: int = 16
|
||||
token_dim: int | None = None
|
||||
d_model: int = 128
|
||||
n_layers: int = 4
|
||||
n_heads: int = 4
|
||||
mlp_ratio: float = 4.0
|
||||
time_dim: int = 64
|
||||
sigma: float = 0.1
|
||||
use_ot: bool = False
|
||||
reference_mode: str | None = None
|
||||
|
||||
class UnifiedTokenCFM(nn.Module):
|
||||
|
||||
def __init__(self, cfg: UnifiedCFMConfig) -> None:
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.token_dim = cfg.token_dim or 1 + max(cfg.flow_dim, cfg.packet_dim)
|
||||
if self.token_dim < 1 + max(cfg.flow_dim, cfg.packet_dim):
|
||||
raise ValueError('token_dim is too small for flow_dim/packet_dim')
|
||||
self.seq_len = cfg.T + 1
|
||||
self.velocity = UnifiedVelocity(token_dim=self.token_dim, seq_len=self.seq_len, d_model=cfg.d_model, n_layers=cfg.n_layers, n_heads=cfg.n_heads, mlp_ratio=cfg.mlp_ratio, time_dim=cfg.time_dim, reference_mode=cfg.reference_mode)
|
||||
|
||||
def build_tokens(self, flow: torch.Tensor, packets: torch.Tensor) -> torch.Tensor:
|
||||
(B, T, Dp) = packets.shape
|
||||
if T != self.cfg.T:
|
||||
raise ValueError(f'packet T={T} but config T={self.cfg.T}')
|
||||
if Dp != self.cfg.packet_dim:
|
||||
raise ValueError(f'packet_dim={Dp} but config packet_dim={self.cfg.packet_dim}')
|
||||
if flow.shape[-1] != self.cfg.flow_dim:
|
||||
raise ValueError(f'flow_dim={flow.shape[-1]} but config flow_dim={self.cfg.flow_dim}')
|
||||
z = packets.new_zeros((B, T + 1, self.token_dim))
|
||||
z[:, 0, 0] = -1.0
|
||||
z[:, 0, 1:1 + self.cfg.flow_dim] = flow
|
||||
z[:, 1:, 0] = 1.0
|
||||
z[:, 1:, 1:1 + self.cfg.packet_dim] = packets
|
||||
return z
|
||||
|
||||
def key_padding_mask(self, lens: torch.Tensor) -> torch.Tensor:
|
||||
B = lens.shape[0]
|
||||
idx = torch.arange(self.cfg.T, device=lens.device)[None, :]
|
||||
packet_real = idx < lens[:, None]
|
||||
real = torch.cat([torch.ones(B, 1, dtype=torch.bool, device=lens.device), packet_real], dim=1)
|
||||
return ~real
|
||||
|
||||
def _loss_mask(self, lens: torch.Tensor) -> torch.Tensor:
|
||||
return (~self.key_padding_mask(lens)).float()
|
||||
|
||||
@staticmethod
|
||||
def _masked_trimmed_mean(values: torch.Tensor, mask: torch.Tensor, trim_frac: float=0.1) -> torch.Tensor:
|
||||
out = values.new_zeros(values.shape[0])
|
||||
for i in range(values.shape[0]):
|
||||
v = values[i][mask[i] > 0]
|
||||
if v.numel() == 0:
|
||||
continue
|
||||
if v.numel() < 5:
|
||||
out[i] = v.mean()
|
||||
continue
|
||||
v_sorted = torch.sort(v).values
|
||||
lo = int(trim_frac * v_sorted.numel())
|
||||
hi = int((1.0 - trim_frac) * v_sorted.numel())
|
||||
if hi <= lo:
|
||||
out[i] = v_sorted.mean()
|
||||
else:
|
||||
out[i] = v_sorted[lo:hi].mean()
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _masked_median(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = values.new_zeros(values.shape[0])
|
||||
for i in range(values.shape[0]):
|
||||
v = values[i][mask[i] > 0]
|
||||
if v.numel() == 0:
|
||||
continue
|
||||
v_sorted = torch.sort(v).values
|
||||
mid = v_sorted.numel() // 2
|
||||
if v_sorted.numel() % 2:
|
||||
out[i] = v_sorted[mid]
|
||||
else:
|
||||
out[i] = 0.5 * (v_sorted[mid - 1] + v_sorted[mid])
|
||||
return out
|
||||
|
||||
def compute_loss(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, *, lambda_flow: float=0.0, lambda_packet: float=0.0, packet_mask_ratio: float=0.5, return_components: bool=False) -> torch.Tensor | dict[str, torch.Tensor]:
|
||||
x1 = self.build_tokens(flow, packets)
|
||||
B = x1.shape[0]
|
||||
x0 = torch.randn_like(x1)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
if self.cfg.use_ot:
|
||||
flat0 = (x0 * mask[:, :, None]).reshape(B, -1)
|
||||
flat1 = (x1 * mask[:, :, None]).reshape(B, -1)
|
||||
col = _sinkhorn_coupling(torch.cdist(flat0.float(), flat1.float()))
|
||||
x1 = x1[col]
|
||||
flow = flow[col]
|
||||
packets = packets[col]
|
||||
lens = lens[col]
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
t = torch.rand(B, device=x1.device)
|
||||
x_t = (1.0 - t[:, None, None]) * x0 + t[:, None, None] * x1
|
||||
if self.cfg.sigma > 0:
|
||||
std = self.cfg.sigma * torch.sqrt(t * (1.0 - t))[:, None, None]
|
||||
x_t = x_t + std * torch.randn_like(x_t)
|
||||
target = x1 - x0
|
||||
pred = self.velocity(x_t, t, key_padding_mask=kpm)
|
||||
sq = (pred - target).square().mean(dim=-1)
|
||||
per_sample = (sq * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
main_loss = per_sample.mean()
|
||||
aux_flow_loss = x1.new_zeros(())
|
||||
aux_packet_loss = x1.new_zeros(())
|
||||
if lambda_flow > 0.0:
|
||||
x_t_mf = x_t.clone()
|
||||
x_t_mf[:, 0, :] = 0.0
|
||||
pred_mf = self.velocity(x_t_mf, t, key_padding_mask=kpm)
|
||||
err = (pred_mf[:, 0] - target[:, 0]).square().mean(dim=-1)
|
||||
aux_flow_loss = err.mean()
|
||||
if lambda_packet > 0.0:
|
||||
packet_real = mask[:, 1:] > 0
|
||||
rand_draw = torch.rand(packet_real.shape, device=x1.device)
|
||||
mask_pkt = (rand_draw < packet_mask_ratio) & packet_real
|
||||
pkt_mask_full = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x1.device), mask_pkt], dim=1)
|
||||
x_t_mp = x_t.clone()
|
||||
x_t_mp[pkt_mask_full] = 0.0
|
||||
pred_mp = self.velocity(x_t_mp, t, key_padding_mask=kpm)
|
||||
sq_mp = (pred_mp - target).square().mean(dim=-1)
|
||||
mask_f = pkt_mask_full.float()
|
||||
denom = mask_f.sum(dim=-1).clamp_min(1.0)
|
||||
aux_packet_loss = ((sq_mp * mask_f).sum(dim=-1) / denom).mean()
|
||||
total = main_loss + lambda_flow * aux_flow_loss + lambda_packet * aux_packet_loss
|
||||
if return_components:
|
||||
return {'total': total, 'main': main_loss.detach(), 'aux_flow': aux_flow_loss.detach(), 'aux_packet': aux_packet_loss.detach()}
|
||||
return total
|
||||
|
||||
@torch.no_grad()
|
||||
def velocity_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.5, 0.75, 1.0)) -> dict[str, torch.Tensor]:
|
||||
x = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
total = torch.zeros(x.shape[0], device=x.device)
|
||||
flow_s = torch.zeros_like(total)
|
||||
packet_s = torch.zeros_like(total)
|
||||
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
|
||||
for t_val in t_eval:
|
||||
t = torch.full((x.shape[0],), float(t_val), device=x.device)
|
||||
v = self.velocity(x, t, key_padding_mask=kpm)
|
||||
e = v.square().mean(dim=-1)
|
||||
total = total + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
flow_s = flow_s + e[:, 0]
|
||||
packet_s = packet_s + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
|
||||
denom = float(len(t_eval))
|
||||
return {'velocity_total': total / denom, 'velocity_flow': flow_s / denom, 'velocity_packet': packet_s / denom}
|
||||
|
||||
@torch.no_grad()
|
||||
def trajectory_metrics(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16) -> dict[str, torch.Tensor]:
|
||||
z = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
B = z.shape[0]
|
||||
dt = 1.0 / n_steps
|
||||
total_arc = torch.zeros(B, device=z.device)
|
||||
total_ke = torch.zeros(B, device=z.device)
|
||||
flow_ke = torch.zeros(B, device=z.device)
|
||||
packet_ke = torch.zeros(B, device=z.device)
|
||||
total_curv = torch.zeros(B, device=z.device)
|
||||
flow_curv = torch.zeros(B, device=z.device)
|
||||
packet_curv = torch.zeros(B, device=z.device)
|
||||
packet_kappa2_speed2 = torch.zeros(B, max(z.shape[1] - 1, 0), device=z.device)
|
||||
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
|
||||
v_prev = None
|
||||
v_prev_norm = None
|
||||
for k in range(n_steps):
|
||||
t_val = 1.0 - k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
v = self.velocity(z, t, key_padding_mask=kpm)
|
||||
e = v.square().mean(dim=-1)
|
||||
v_norm = v.square().sum(dim=-1).clamp_min(1e-12).sqrt()
|
||||
total_ke = total_ke + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0) * dt
|
||||
flow_ke = flow_ke + e[:, 0] * dt
|
||||
packet_ke = packet_ke + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count * dt
|
||||
if v_prev is not None:
|
||||
dv = v - v_prev
|
||||
dve = dv.square().mean(dim=-1)
|
||||
total_curv = total_curv + (dve * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
flow_curv = flow_curv + dve[:, 0]
|
||||
packet_curv = packet_curv + (dve[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
|
||||
dv2_sum = dv[:, 1:].square().sum(dim=-1)
|
||||
assert v_prev_norm is not None
|
||||
v_avg = 0.5 * (v_norm[:, 1:] + v_prev_norm[:, 1:])
|
||||
packet_kappa2_speed2 = packet_kappa2_speed2 + dv2_sum / v_avg.square().clamp_min(1e-06)
|
||||
v_prev = v
|
||||
v_prev_norm = v_norm
|
||||
z_new = z - v * dt
|
||||
dz = (z_new - z) * mask[:, :, None]
|
||||
total_arc = total_arc + dz.reshape(B, -1).norm(dim=-1) / mask.sum(dim=-1).sqrt()
|
||||
z = z_new
|
||||
z_masked = z * mask[:, :, None]
|
||||
terminal = z_masked.reshape(B, -1).norm(dim=-1) / (mask.sum(dim=-1) * self.token_dim).clamp_min(1.0).sqrt()
|
||||
terminal_flow = z[:, 0].norm(dim=-1) / math.sqrt(self.token_dim)
|
||||
terminal_packet = (z[:, 1:] * mask[:, 1:, None]).reshape(B, -1).norm(dim=-1) / (packet_count * self.token_dim).sqrt()
|
||||
packet_mask = mask[:, 1:]
|
||||
kappa2_speed2_mean = (packet_kappa2_speed2 * packet_mask).sum(dim=-1) / packet_count
|
||||
kappa2_speed2_median = self._masked_median(packet_kappa2_speed2, packet_mask)
|
||||
kappa2_speed2_trimmed = self._masked_trimmed_mean(packet_kappa2_speed2, packet_mask)
|
||||
return {'terminal_norm': terminal, 'terminal_flow': terminal_flow, 'terminal_packet': terminal_packet, 'arc_length': total_arc, 'kinetic_energy': total_ke, 'kinetic_flow': flow_ke, 'kinetic_packet': packet_ke, 'curvature_total': total_curv, 'curvature_flow': flow_curv, 'curvature_packet': packet_curv, 'kappa2_speed2norm_packet_mean': kappa2_speed2_mean, 'kappa2_speed2norm_packet_median': kappa2_speed2_median, 'kappa2_speed2norm_packet_trimmed10_mean': kappa2_speed2_trimmed}
|
||||
|
||||
@torch.no_grad()
|
||||
def score_profile_vt(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.1, 0.3, 0.5, 0.7, 0.9, 1.0)) -> dict[str, torch.Tensor]:
|
||||
x = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
|
||||
out: dict[str, torch.Tensor] = {}
|
||||
for t_val in t_eval:
|
||||
t = torch.full((x.shape[0],), float(t_val), device=x.device)
|
||||
v = self.velocity(x, t, key_padding_mask=kpm)
|
||||
e = v.square().mean(dim=-1)
|
||||
tag = f't{int(round(t_val * 10)):02d}'
|
||||
out[f'velocity_total_{tag}'] = (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
out[f'velocity_flow_{tag}'] = e[:, 0]
|
||||
out[f'velocity_packet_{tag}'] = (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5) -> dict[str, torch.Tensor]:
|
||||
x = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
B = x.shape[0]
|
||||
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
|
||||
t = torch.full((B,), float(t_eval), device=x.device)
|
||||
v_full = self.velocity(x, t, key_padding_mask=kpm)
|
||||
x_mf = x.clone()
|
||||
x_mf[:, 0, :] = 0.0
|
||||
v_mf = self.velocity(x_mf, t, key_padding_mask=kpm)
|
||||
flow_cons = (v_full[:, 0] - v_mf[:, 0]).square().mean(dim=-1)
|
||||
x_mp = x.clone()
|
||||
pkt_mask_full = mask[:, 1:] > 0
|
||||
idx_pkt_mask = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x.device), pkt_mask_full], dim=1)
|
||||
x_mp[idx_pkt_mask] = 0.0
|
||||
v_mp = self.velocity(x_mp, t, key_padding_mask=kpm)
|
||||
diff = (v_full - v_mp).square().mean(dim=-1)
|
||||
packet_cons = (diff[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
|
||||
return {'flow_consistency': flow_cons, 'packet_consistency': packet_cons, 'consistency_total': flow_cons + packet_cons}
|
||||
|
||||
def jacobian_hutchinson(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.5,), n_eps: int=4, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]:
|
||||
x = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
B = x.shape[0]
|
||||
packet_count = mask[:, 1:].sum(dim=-1).clamp_min(1.0)
|
||||
total = torch.zeros(B, device=x.device)
|
||||
flow_j = torch.zeros(B, device=x.device)
|
||||
packet_j = torch.zeros(B, device=x.device)
|
||||
n_draws = n_eps * len(t_eval)
|
||||
for t_val in t_eval:
|
||||
t_current = torch.full((B,), float(t_val), device=x.device)
|
||||
for _ in range(n_eps):
|
||||
x_req = x.detach().clone().requires_grad_(True)
|
||||
v = self.velocity(x_req, t_current, key_padding_mask=kpm)
|
||||
eps = torch.randn(v.shape, device=v.device, generator=generator)
|
||||
(g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False)
|
||||
e = g.square().mean(dim=-1)
|
||||
total = total + (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
flow_j = flow_j + e[:, 0]
|
||||
packet_j = packet_j + (e[:, 1:] * mask[:, 1:]).sum(dim=-1) / packet_count
|
||||
return {'jacobian_total': (total / n_draws).detach(), 'jacobian_flow': (flow_j / n_draws).detach(), 'jacobian_packet': (packet_j / n_draws).detach()}
|
||||
|
||||
@torch.no_grad()
|
||||
def pna_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16, flow_masked: bool=False) -> dict[str, torch.Tensor]:
|
||||
eps_v2 = 1e-06
|
||||
dt = 1.0 / n_steps
|
||||
z = self.build_tokens(flow, packets)
|
||||
if flow_masked:
|
||||
z = z.clone()
|
||||
z[:, 0, :] = 0.0
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
(B, L, _) = z.shape
|
||||
pna = torch.zeros(B, L, device=z.device)
|
||||
v_prev: torch.Tensor | None = None
|
||||
v_norm_prev: torch.Tensor | None = None
|
||||
for k in range(n_steps):
|
||||
t_val = 1.0 - k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
v = self.velocity(z, t, key_padding_mask=kpm)
|
||||
v_norm = (v.square().sum(dim=-1) + 1e-12).sqrt()
|
||||
if v_prev is not None:
|
||||
dv2 = (v - v_prev).square().sum(dim=-1)
|
||||
v_avg2 = (0.5 * (v_norm + v_norm_prev)).square().clamp_min(eps_v2)
|
||||
pna = pna + dv2 / v_avg2
|
||||
v_prev = v
|
||||
v_norm_prev = v_norm
|
||||
z = z - v * dt
|
||||
if flow_masked:
|
||||
z[:, 0, :] = 0.0
|
||||
flow_pna = pna[:, 0]
|
||||
packet_pna = pna[:, 1:]
|
||||
packet_mask = mask[:, 1:]
|
||||
packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
|
||||
pna_median = self._masked_median(packet_pna, packet_mask)
|
||||
pna_mean = (packet_pna * packet_mask).sum(dim=-1) / packet_count
|
||||
masked_for_max = packet_pna.masked_fill(packet_mask == 0, float('-inf'))
|
||||
pna_max = masked_for_max.max(dim=-1).values
|
||||
pna_trimmed = self._masked_trimmed_mean(packet_pna, packet_mask)
|
||||
return {'pna_packet_median': pna_median, 'pna_packet_mean': pna_mean, 'pna_packet_max': pna_max, 'pna_packet_trimmed10_mean': pna_trimmed, 'pna_flow': flow_pna}
|
||||
|
||||
@torch.no_grad()
|
||||
def causal_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5) -> dict[str, torch.Tensor]:
|
||||
x = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
(B, L, _) = x.shape
|
||||
t = torch.full((B,), float(t_eval), device=x.device)
|
||||
v_full = self.velocity(x, t, key_padding_mask=kpm)
|
||||
causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), diagonal=1)
|
||||
v_causal = self.velocity(x, t, key_padding_mask=kpm, attn_mask_override=causal)
|
||||
diff = (v_full - v_causal).square().mean(dim=-1)
|
||||
flow_surprisal = diff[:, 0]
|
||||
packet_diff = diff[:, 1:]
|
||||
packet_mask = mask[:, 1:]
|
||||
packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
|
||||
packet_mean = (packet_diff * packet_mask).sum(dim=-1) / packet_count
|
||||
packet_median = self._masked_median(packet_diff, packet_mask)
|
||||
masked_for_max = packet_diff.masked_fill(packet_mask == 0, float('-inf'))
|
||||
packet_max = masked_for_max.max(dim=-1).values
|
||||
packet_trimmed = self._masked_trimmed_mean(packet_diff, packet_mask)
|
||||
total = (diff * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
return {'causal_surprisal_total': total, 'causal_surprisal_flow': flow_surprisal, 'causal_surprisal_packet_mean': packet_mean, 'causal_surprisal_packet_median': packet_median, 'causal_surprisal_packet_max': packet_max, 'causal_surprisal_packet_trimmed10_mean': packet_trimmed}
|
||||
|
||||
@torch.no_grad()
|
||||
def direction_consistency_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: tuple[float, ...]=(0.2, 0.4, 0.6, 0.8, 1.0)) -> dict[str, torch.Tensor]:
|
||||
x = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
(B, L, _) = x.shape
|
||||
t_eval = tuple(t_eval)
|
||||
if len(t_eval) < 2:
|
||||
raise ValueError('direction_consistency_score needs >=2 t values')
|
||||
prev_v: torch.Tensor | None = None
|
||||
drift = x.new_zeros(B, L)
|
||||
n_pairs = len(t_eval) - 1
|
||||
for t_val in t_eval:
|
||||
t = torch.full((B,), float(t_val), device=x.device)
|
||||
v = self.velocity(x, t, key_padding_mask=kpm)
|
||||
if prev_v is not None:
|
||||
num = (prev_v * v).sum(dim=-1)
|
||||
denom = prev_v.norm(dim=-1).clamp_min(1e-08) * v.norm(dim=-1).clamp_min(1e-08)
|
||||
cos = num / denom
|
||||
drift = drift + (1.0 - cos)
|
||||
prev_v = v
|
||||
drift = drift / max(n_pairs, 1)
|
||||
flow_drift = drift[:, 0]
|
||||
packet_drift = drift[:, 1:]
|
||||
packet_mask = mask[:, 1:]
|
||||
packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
|
||||
packet_mean = (packet_drift * packet_mask).sum(dim=-1) / packet_count
|
||||
packet_median = self._masked_median(packet_drift, packet_mask)
|
||||
masked_for_max = packet_drift.masked_fill(packet_mask == 0, float('-inf'))
|
||||
packet_max = masked_for_max.max(dim=-1).values
|
||||
packet_trimmed = self._masked_trimmed_mean(packet_drift, packet_mask)
|
||||
total = (drift * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
return {'direction_drift_total': total, 'direction_drift_flow': flow_drift, 'direction_drift_packet_mean': packet_mean, 'direction_drift_packet_median': packet_median, 'direction_drift_packet_max': packet_max, 'direction_drift_packet_trimmed10_mean': packet_trimmed}
|
||||
|
||||
def inverse_flow_nll_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, n_steps: int=16, n_eps: int=4, compute_divergence: bool=True, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]:
|
||||
z = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
(B, L, D) = z.shape
|
||||
dt = 1.0 / n_steps
|
||||
accum_div = torch.zeros(B, device=z.device)
|
||||
if compute_divergence:
|
||||
for k in range(n_steps):
|
||||
t_val = 1.0 - k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
z_req = z.detach().clone().requires_grad_(True)
|
||||
v = self.velocity(z_req, t, key_padding_mask=kpm)
|
||||
div_step = torch.zeros(B, device=z.device)
|
||||
for j in range(n_eps):
|
||||
eps = torch.randn_like(v)
|
||||
eps_masked = eps * mask[:, :, None]
|
||||
retain = j < n_eps - 1
|
||||
(g,) = torch.autograd.grad(outputs=v, inputs=z_req, grad_outputs=eps_masked, retain_graph=retain, create_graph=False)
|
||||
div_step = div_step + (eps_masked * g).sum(dim=(1, 2))
|
||||
div_step = div_step / float(n_eps)
|
||||
accum_div = accum_div + div_step * dt
|
||||
with torch.no_grad():
|
||||
z = (z_req - v * dt).detach()
|
||||
else:
|
||||
with torch.no_grad():
|
||||
for k in range(n_steps):
|
||||
t_val = 1.0 - k * dt
|
||||
t = torch.full((B,), t_val, device=z.device)
|
||||
v = self.velocity(z, t, key_padding_mask=kpm)
|
||||
z = z - v * dt
|
||||
with torch.no_grad():
|
||||
z_masked = z * mask[:, :, None]
|
||||
n_real = mask.sum(dim=-1).clamp_min(1.0)
|
||||
x0_quadratic = z_masked.reshape(B, -1).square().sum(dim=-1) / (n_real * float(D))
|
||||
nll_x0_only = x0_quadratic
|
||||
nll_div_only = accum_div / (n_real * float(D))
|
||||
nll_full = nll_x0_only + nll_div_only
|
||||
return {'nll_x0_only': nll_x0_only.detach(), 'nll_div_only': nll_div_only.detach(), 'nll_full': nll_full.detach()}
|
||||
|
||||
def jacobian_spectral_score(self, flow: torch.Tensor, packets: torch.Tensor, lens: torch.Tensor, t_eval: float=0.5, n_eps: int=4, generator: torch.Generator | None=None) -> dict[str, torch.Tensor]:
|
||||
x = self.build_tokens(flow, packets)
|
||||
mask = self._loss_mask(lens)
|
||||
kpm = mask == 0
|
||||
(B, L, D) = x.shape
|
||||
t = torch.full((B,), float(t_eval), device=x.device)
|
||||
packet_mask = mask[:, 1:]
|
||||
packet_count = packet_mask.sum(dim=-1).clamp_min(1.0)
|
||||
norms_total: list[torch.Tensor] = []
|
||||
norms_flow: list[torch.Tensor] = []
|
||||
norms_packet: list[torch.Tensor] = []
|
||||
for _ in range(n_eps):
|
||||
x_req = x.detach().clone().requires_grad_(True)
|
||||
v = self.velocity(x_req, t, key_padding_mask=kpm)
|
||||
eps = torch.randn(v.shape, device=v.device, generator=generator)
|
||||
(g,) = torch.autograd.grad(outputs=v, inputs=x_req, grad_outputs=eps, retain_graph=False, create_graph=False)
|
||||
e = g.square().mean(dim=-1)
|
||||
n_total = (e * mask).sum(dim=-1) / mask.sum(dim=-1).clamp_min(1.0)
|
||||
n_flow = e[:, 0]
|
||||
n_packet = (e[:, 1:] * packet_mask).sum(dim=-1) / packet_count
|
||||
norms_total.append(n_total.detach())
|
||||
norms_flow.append(n_flow.detach())
|
||||
norms_packet.append(n_packet.detach())
|
||||
|
||||
def _spectral_summary(samples: list[torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
stack = torch.stack(samples, dim=1)
|
||||
mean = stack.mean(dim=1).clamp_min(1e-12)
|
||||
mx = stack.max(dim=1).values
|
||||
mn = stack.min(dim=1).values
|
||||
logfro = torch.log(mean)
|
||||
aniso = mx / mean
|
||||
min_over_max = mn / mx.clamp_min(1e-12)
|
||||
p = stack / stack.sum(dim=1, keepdim=True).clamp_min(1e-12)
|
||||
entropy = -(p * p.clamp_min(1e-12).log()).sum(dim=1)
|
||||
eff_rank = torch.exp(entropy)
|
||||
return {'logfro': logfro, 'anisotropy': aniso, 'min_over_max': min_over_max, 'eff_rank': eff_rank}
|
||||
out: dict[str, torch.Tensor] = {}
|
||||
for (tag, samples) in (('total', norms_total), ('flow', norms_flow), ('packet', norms_packet)):
|
||||
summ = _spectral_summary(samples)
|
||||
for (stat_name, val) in summ.items():
|
||||
out[f'jac_{stat_name}_{tag}'] = val
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, n: int, lens: torch.Tensor, device: torch.device, n_steps: int=50, method: str='euler') -> torch.Tensor:
|
||||
z = torch.randn(n, self.seq_len, self.token_dim, device=device)
|
||||
ts = torch.linspace(0.0, 1.0, n_steps + 1, device=device)
|
||||
kpm = self.key_padding_mask(lens.to(device))
|
||||
|
||||
def f(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.velocity(x, t.expand(x.shape[0]), key_padding_mask=kpm)
|
||||
if method == 'euler':
|
||||
for i in range(n_steps):
|
||||
z = z + f(ts[i], z) * (ts[i + 1] - ts[i])
|
||||
return z
|
||||
return odeint(f, z, ts, method=method)[-1]
|
||||
|
||||
def param_count(self) -> int:
|
||||
return sum((p.numel() for p in self.parameters()))
|
||||
157
Unified_CFM/tests/test_model_shapes.py
Normal file
157
Unified_CFM/tests/test_model_shapes.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import torch
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
from model import UnifiedCFMConfig, UnifiedTokenCFM
|
||||
|
||||
def _build_model():
|
||||
return UnifiedTokenCFM(UnifiedCFMConfig(T=4, packet_dim=3, flow_dim=5, d_model=16, n_layers=1, n_heads=4, time_dim=8))
|
||||
|
||||
def _build_reference_model(reference_mode: str):
|
||||
return UnifiedTokenCFM(UnifiedCFMConfig(T=4, packet_dim=3, flow_dim=5, d_model=16, n_layers=1, n_heads=4, time_dim=8, reference_mode=reference_mode))
|
||||
|
||||
def _sample_batch(seed: int=0):
|
||||
torch.manual_seed(seed)
|
||||
flow = torch.randn(2, 5)
|
||||
packets = torch.randn(2, 4, 3)
|
||||
lens = torch.tensor([4, 2])
|
||||
return (flow, packets, lens)
|
||||
|
||||
def test_unified_cfm_shapes_and_scores():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch()
|
||||
tokens = model.build_tokens(flow, packets)
|
||||
assert tokens.shape == (2, 5, 6)
|
||||
loss = model.compute_loss(flow, packets, lens)
|
||||
assert loss.ndim == 0
|
||||
assert torch.isfinite(loss)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=2)
|
||||
assert 'terminal_norm' in traj
|
||||
assert traj['terminal_norm'].shape == (2,)
|
||||
vel = model.velocity_score(flow, packets, lens)
|
||||
assert set(vel) == {'velocity_total', 'velocity_flow', 'velocity_packet'}
|
||||
|
||||
def test_reference_mode_independent_token_shapes_and_scores():
|
||||
model = _build_reference_model('independent_token')
|
||||
(flow, packets, lens) = _sample_batch(seed=9)
|
||||
loss = model.compute_loss(flow, packets, lens)
|
||||
assert loss.ndim == 0
|
||||
assert torch.isfinite(loss)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=2)
|
||||
assert traj['terminal_norm'].shape == (2,)
|
||||
assert torch.all(torch.isfinite(traj['curvature_packet']))
|
||||
|
||||
def test_reference_mode_block_diagonal_shapes_and_scores():
|
||||
model = _build_reference_model('block_diagonal')
|
||||
(flow, packets, lens) = _sample_batch(seed=10)
|
||||
loss = model.compute_loss(flow, packets, lens)
|
||||
assert loss.ndim == 0
|
||||
assert torch.isfinite(loss)
|
||||
vel = model.velocity_score(flow, packets, lens)
|
||||
assert set(vel) == {'velocity_total', 'velocity_flow', 'velocity_packet'}
|
||||
|
||||
def test_trajectory_curvature_keys_and_shapes():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=1)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=4)
|
||||
for key in ('curvature_total', 'curvature_flow', 'curvature_packet'):
|
||||
assert key in traj, f'missing {key}'
|
||||
assert traj[key].shape == (2,)
|
||||
assert torch.all(torch.isfinite(traj[key]))
|
||||
assert torch.all(traj[key] >= 0)
|
||||
|
||||
def test_trajectory_curvature_zero_with_one_step():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=2)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=1)
|
||||
for key in ('curvature_total', 'curvature_flow', 'curvature_packet'):
|
||||
assert traj[key].abs().sum().item() == 0.0
|
||||
|
||||
def test_speed_normalized_packet_curvature_scores():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=11)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=4)
|
||||
keys = ('kappa2_speed2norm_packet_mean', 'kappa2_speed2norm_packet_median', 'kappa2_speed2norm_packet_trimmed10_mean')
|
||||
for key in keys:
|
||||
assert key in traj, f'missing {key}'
|
||||
assert traj[key].shape == (2,)
|
||||
assert torch.all(torch.isfinite(traj[key]))
|
||||
assert torch.all(traj[key] >= 0)
|
||||
one_step = model.trajectory_metrics(flow, packets, lens, n_steps=1)
|
||||
for key in keys:
|
||||
assert one_step[key].abs().sum().item() == 0.0
|
||||
|
||||
def test_score_profile_vt_shapes():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=3)
|
||||
t_eval = (0.1, 0.3, 0.5, 0.7, 0.9, 1.0)
|
||||
prof = model.score_profile_vt(flow, packets, lens, t_eval=t_eval)
|
||||
assert len(prof) == 3 * len(t_eval)
|
||||
for (k, v) in prof.items():
|
||||
assert v.shape == (2,), k
|
||||
assert torch.all(torch.isfinite(v))
|
||||
assert torch.all(v >= 0)
|
||||
assert 'velocity_total_t05' in prof
|
||||
assert 'velocity_flow_t10' in prof
|
||||
assert 'velocity_packet_t01' in prof
|
||||
|
||||
def test_compute_loss_backward_compat():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=5)
|
||||
torch.manual_seed(0)
|
||||
a = model.compute_loss(flow, packets, lens)
|
||||
torch.manual_seed(0)
|
||||
b = model.compute_loss(flow, packets, lens, lambda_flow=0.0, lambda_packet=0.0)
|
||||
assert torch.allclose(a, b), f'λ=0 must match old loss; got {a.item()} vs {b.item()}'
|
||||
|
||||
def test_compute_loss_aux_components_finite():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=6)
|
||||
torch.manual_seed(7)
|
||||
comp = model.compute_loss(flow, packets, lens, lambda_flow=0.1, lambda_packet=0.1, return_components=True)
|
||||
assert set(comp) == {'total', 'main', 'aux_flow', 'aux_packet'}
|
||||
for (k, v) in comp.items():
|
||||
assert torch.isfinite(v), k
|
||||
assert v >= 0, f'{k} negative: {v.item()}'
|
||||
|
||||
def test_compute_loss_aux_affects_gradient():
|
||||
model = _build_model()
|
||||
with torch.no_grad():
|
||||
model.velocity.out.weight.normal_(std=0.01)
|
||||
for block in model.velocity.blocks:
|
||||
block.cond_proj.weight.normal_(std=0.01)
|
||||
(flow, packets, lens) = _sample_batch(seed=8)
|
||||
torch.manual_seed(10)
|
||||
total = model.compute_loss(flow, packets, lens, lambda_flow=1.0, lambda_packet=1.0)
|
||||
total.backward()
|
||||
some_grad = False
|
||||
for p in model.parameters():
|
||||
if p.grad is not None and p.grad.abs().sum().item() > 0:
|
||||
some_grad = True
|
||||
break
|
||||
assert some_grad, 'no gradient flowed through aux losses'
|
||||
|
||||
def test_consistency_score_shapes():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=9)
|
||||
cs = model.consistency_score(flow, packets, lens)
|
||||
assert set(cs) == {'flow_consistency', 'packet_consistency', 'consistency_total'}
|
||||
for (k, v) in cs.items():
|
||||
assert v.shape == (2,), k
|
||||
assert torch.all(torch.isfinite(v))
|
||||
assert torch.all(v >= 0), k
|
||||
|
||||
def test_jacobian_hutchinson_shapes_and_nonneg():
|
||||
model = _build_model()
|
||||
with torch.no_grad():
|
||||
model.velocity.out.weight.normal_(std=0.01)
|
||||
for block in model.velocity.blocks:
|
||||
block.cond_proj.weight.normal_(std=0.01)
|
||||
(flow, packets, lens) = _sample_batch(seed=4)
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
jac = model.jacobian_hutchinson(flow, packets, lens, t_eval=(0.5,), n_eps=2, generator=gen)
|
||||
assert set(jac) == {'jacobian_total', 'jacobian_flow', 'jacobian_packet'}
|
||||
for (k, v) in jac.items():
|
||||
assert v.shape == (2,), k
|
||||
assert torch.all(torch.isfinite(v))
|
||||
assert torch.all(v >= 0), f'{k} has negative value'
|
||||
147
Unified_CFM/train.py
Normal file
147
Unified_CFM/train.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from data import UnifiedData, load_unified_data, subsample_train
|
||||
from model import UnifiedCFMConfig, UnifiedTokenCFM
|
||||
|
||||
def _device(dev_arg: str) -> torch.device:
|
||||
if dev_arg == 'auto':
|
||||
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
return torch.device(dev_arg)
|
||||
|
||||
def _batch_score(model: UnifiedTokenCFM, flow_np: np.ndarray, packet_np: np.ndarray, len_np: np.ndarray, device: torch.device, *, batch_size: int, n_steps: int) -> dict[str, np.ndarray]:
|
||||
out: dict[str, list[np.ndarray]] = {}
|
||||
model.eval()
|
||||
for start in range(0, len(flow_np), batch_size):
|
||||
sl = slice(start, start + batch_size)
|
||||
flow = torch.from_numpy(flow_np[sl]).float().to(device)
|
||||
packets = torch.from_numpy(packet_np[sl]).float().to(device)
|
||||
lens = torch.from_numpy(len_np[sl]).long().to(device)
|
||||
metrics = model.trajectory_metrics(flow, packets, lens, n_steps=n_steps)
|
||||
vel = model.velocity_score(flow, packets, lens)
|
||||
metrics.update(vel)
|
||||
for (k, v) in metrics.items():
|
||||
out.setdefault(k, []).append(v.detach().cpu().numpy())
|
||||
return {k: np.concatenate(v, axis=0) for (k, v) in out.items()}
|
||||
|
||||
def _quick_eval(model: UnifiedTokenCFM, data: UnifiedData, device: torch.device, cfg: dict[str, Any]) -> dict[str, float]:
|
||||
n_eval = int(cfg.get('eval_n', 2000))
|
||||
rng = np.random.default_rng(0)
|
||||
|
||||
def pick(n: int) -> np.ndarray:
|
||||
m = min(n_eval, n)
|
||||
return rng.choice(n, m, replace=False)
|
||||
vi = pick(len(data.val_flow))
|
||||
ai = pick(len(data.attack_flow))
|
||||
v = _batch_score(model, data.val_flow[vi], data.val_packets[vi], data.val_len[vi], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)))
|
||||
a = _batch_score(model, data.attack_flow[ai], data.attack_packets[ai], data.attack_len[ai], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)))
|
||||
y = np.concatenate([np.zeros(len(vi)), np.ones(len(ai))])
|
||||
result: dict[str, float] = {}
|
||||
for key in sorted(v.keys()):
|
||||
s = np.concatenate([v[key], a[key]])
|
||||
s = np.nan_to_num(s, nan=0.0, posinf=1000000000000.0, neginf=-1000000000000.0)
|
||||
result[f'auroc_{key}'] = float(roc_auc_score(y, s))
|
||||
return result
|
||||
|
||||
def train(cfg: dict[str, Any]) -> Path:
|
||||
device = _device(str(cfg.get('device', 'auto')))
|
||||
save_dir = Path(cfg['save_dir'])
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(save_dir / 'config.yaml', 'w') as f:
|
||||
yaml.safe_dump(cfg, f)
|
||||
seed = int(cfg.get('seed', 42))
|
||||
data_seed = int(cfg.get('data_seed', seed))
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
print(f'Device: {device}')
|
||||
print(f'[seed] model={seed} data={data_seed}')
|
||||
feature_columns = cfg.get('flow_feature_columns')
|
||||
data = load_unified_data(packets_npz=Path(cfg['packets_npz']) if cfg.get('packets_npz') else None, source_store=Path(cfg['source_store']) if cfg.get('source_store') else None, flows_parquet=Path(cfg['flows_parquet']), flow_features_path=Path(cfg['flow_features_path']) if cfg.get('flow_features_path') else None, flow_feature_columns=feature_columns, flow_features_align=str(cfg.get('flow_features_align', 'auto')), T=int(cfg['T']), split_seed=data_seed, train_ratio=float(cfg.get('train_ratio', 0.8)), benign_label=str(cfg.get('benign_label', 'normal')), min_len=int(cfg.get('min_len', 2)), packet_preprocess=str(cfg.get('packet_preprocess', 'mixed_dequant')), attack_cap=int(cfg['attack_cap']) if cfg.get('attack_cap') else None, val_cap=int(cfg['val_cap']) if cfg.get('val_cap') else None)
|
||||
print(f'[data] T={data.T} packet_D={data.packet_dim} flow_D={data.flow_dim} train={len(data.train_flow):,} val={len(data.val_flow):,} attack={len(data.attack_flow):,}')
|
||||
(tr_f, tr_p, tr_l) = subsample_train(data, int(cfg.get('n_train', 0)), data_seed)
|
||||
ds = TensorDataset(torch.from_numpy(tr_f).float(), torch.from_numpy(tr_p).float(), torch.from_numpy(tr_l).long())
|
||||
loader = DataLoader(ds, batch_size=int(cfg['batch_size']), shuffle=True, drop_last=True, num_workers=int(cfg.get('num_workers', 0)), pin_memory=device.type == 'cuda')
|
||||
print(f'[data] using {len(ds):,} benign training flows')
|
||||
model_cfg = UnifiedCFMConfig(T=data.T, packet_dim=data.packet_dim, flow_dim=data.flow_dim, token_dim=cfg.get('token_dim'), d_model=int(cfg['d_model']), n_layers=int(cfg['n_layers']), n_heads=int(cfg['n_heads']), mlp_ratio=float(cfg.get('mlp_ratio', 4.0)), time_dim=int(cfg.get('time_dim', 64)), sigma=float(cfg.get('sigma', 0.1)), use_ot=bool(cfg.get('use_ot', False)), reference_mode=cfg.get('reference_mode'))
|
||||
model = UnifiedTokenCFM(model_cfg).to(device)
|
||||
print(f'[model] params={model.param_count():,} token_dim={model.token_dim} seq_len={model.seq_len} sigma={model_cfg.sigma} use_ot={model_cfg.use_ot} reference_mode={model_cfg.reference_mode}')
|
||||
opt = torch.optim.AdamW(model.parameters(), lr=float(cfg['lr']), weight_decay=float(cfg.get('weight_decay', 0.01)))
|
||||
total_steps = max(1, int(cfg['epochs']) * len(loader))
|
||||
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)
|
||||
history: dict[str, list[Any]] = {'epoch': [], 'loss': [], 'eval': []}
|
||||
lambda_flow = float(cfg.get('lambda_flow', 0.0))
|
||||
lambda_packet = float(cfg.get('lambda_packet', 0.0))
|
||||
packet_mask_ratio = float(cfg.get('packet_mask_ratio', 0.5))
|
||||
aux_enabled = lambda_flow > 0.0 or lambda_packet > 0.0
|
||||
if aux_enabled:
|
||||
print(f'[loss] λ_flow={lambda_flow} λ_packet={lambda_packet} packet_mask_ratio={packet_mask_ratio}')
|
||||
for epoch in range(1, int(cfg['epochs']) + 1):
|
||||
model.train()
|
||||
losses: list[float] = []
|
||||
aux_flow_sum = 0.0
|
||||
aux_packet_sum = 0.0
|
||||
n_steps_this_epoch = 0
|
||||
t0 = time.time()
|
||||
for (flow, packets, lens) in loader:
|
||||
flow = flow.to(device, non_blocking=True)
|
||||
packets = packets.to(device, non_blocking=True)
|
||||
lens = lens.to(device, non_blocking=True)
|
||||
if aux_enabled:
|
||||
comp = model.compute_loss(flow, packets, lens, lambda_flow=lambda_flow, lambda_packet=lambda_packet, packet_mask_ratio=packet_mask_ratio, return_components=True)
|
||||
loss = comp['total']
|
||||
aux_flow_sum += float(comp['aux_flow'].item())
|
||||
aux_packet_sum += float(comp['aux_packet'].item())
|
||||
else:
|
||||
loss = model.compute_loss(flow, packets, lens)
|
||||
opt.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get('grad_clip', 1.0)))
|
||||
opt.step()
|
||||
sched.step()
|
||||
losses.append(float(loss.item()))
|
||||
n_steps_this_epoch += 1
|
||||
mean_loss = float(np.mean(losses)) if losses else float('nan')
|
||||
eval_metrics: dict[str, float] | None = None
|
||||
if epoch % int(cfg.get('eval_every', 5)) == 0 or epoch == int(cfg['epochs']):
|
||||
eval_metrics = _quick_eval(model, data, device, cfg)
|
||||
history['epoch'].append(epoch)
|
||||
history['loss'].append(mean_loss)
|
||||
history['eval'].append(eval_metrics)
|
||||
elapsed = time.time() - t0
|
||||
terminal = ''
|
||||
if eval_metrics:
|
||||
terminal = f" auroc_terminal={eval_metrics['auroc_terminal_norm']:.3f}"
|
||||
if aux_enabled and n_steps_this_epoch:
|
||||
terminal += f' aux_flow={aux_flow_sum / n_steps_this_epoch:.4f} aux_pkt={aux_packet_sum / n_steps_this_epoch:.4f}'
|
||||
print(f"[epoch {epoch:>3d}/{cfg['epochs']:<3d}] ({elapsed:.1f}s) loss={mean_loss:.4f}{terminal}")
|
||||
if not np.isfinite(mean_loss):
|
||||
raise RuntimeError(f'non-finite loss at epoch {epoch}')
|
||||
payload = {'model_state_dict': model.state_dict(), 'model_cfg': asdict(model_cfg), 'packet_mean': data.packet_mean, 'packet_std': data.packet_std, 'flow_mean': data.flow_mean, 'flow_std': data.flow_std, 'packet_preprocess': data.packet_preprocess, 'flow_feature_names': np.asarray(data.flow_feature_names), 'packet_feature_names': np.asarray(data.packet_feature_names)}
|
||||
torch.save(payload, save_dir / 'model.pt')
|
||||
with open(save_dir / 'history.json', 'w') as f:
|
||||
json.dump(history, f, indent=2, default=str)
|
||||
print(f"[saved] {save_dir / 'model.pt'}")
|
||||
return save_dir
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument('--config', type=Path, required=True)
|
||||
p.add_argument('--override', type=str, nargs='*', default=[])
|
||||
args = p.parse_args()
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
for override in args.override:
|
||||
(key, value) = override.split('=', 1)
|
||||
cfg[key] = yaml.safe_load(value)
|
||||
train(cfg)
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user