Initial commit: code, paper, small artifacts
This commit is contained in:
147
Unified_CFM/train.py
Normal file
147
Unified_CFM/train.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from data import UnifiedData, load_unified_data, subsample_train
|
||||
from model import UnifiedCFMConfig, UnifiedTokenCFM
|
||||
|
||||
def _device(dev_arg: str) -> torch.device:
|
||||
if dev_arg == 'auto':
|
||||
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
return torch.device(dev_arg)
|
||||
|
||||
def _batch_score(model: UnifiedTokenCFM, flow_np: np.ndarray, packet_np: np.ndarray, len_np: np.ndarray, device: torch.device, *, batch_size: int, n_steps: int) -> dict[str, np.ndarray]:
|
||||
out: dict[str, list[np.ndarray]] = {}
|
||||
model.eval()
|
||||
for start in range(0, len(flow_np), batch_size):
|
||||
sl = slice(start, start + batch_size)
|
||||
flow = torch.from_numpy(flow_np[sl]).float().to(device)
|
||||
packets = torch.from_numpy(packet_np[sl]).float().to(device)
|
||||
lens = torch.from_numpy(len_np[sl]).long().to(device)
|
||||
metrics = model.trajectory_metrics(flow, packets, lens, n_steps=n_steps)
|
||||
vel = model.velocity_score(flow, packets, lens)
|
||||
metrics.update(vel)
|
||||
for (k, v) in metrics.items():
|
||||
out.setdefault(k, []).append(v.detach().cpu().numpy())
|
||||
return {k: np.concatenate(v, axis=0) for (k, v) in out.items()}
|
||||
|
||||
def _quick_eval(model: UnifiedTokenCFM, data: UnifiedData, device: torch.device, cfg: dict[str, Any]) -> dict[str, float]:
|
||||
n_eval = int(cfg.get('eval_n', 2000))
|
||||
rng = np.random.default_rng(0)
|
||||
|
||||
def pick(n: int) -> np.ndarray:
|
||||
m = min(n_eval, n)
|
||||
return rng.choice(n, m, replace=False)
|
||||
vi = pick(len(data.val_flow))
|
||||
ai = pick(len(data.attack_flow))
|
||||
v = _batch_score(model, data.val_flow[vi], data.val_packets[vi], data.val_len[vi], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)))
|
||||
a = _batch_score(model, data.attack_flow[ai], data.attack_packets[ai], data.attack_len[ai], device, batch_size=int(cfg.get('eval_batch_size', 512)), n_steps=int(cfg.get('eval_n_steps', 8)))
|
||||
y = np.concatenate([np.zeros(len(vi)), np.ones(len(ai))])
|
||||
result: dict[str, float] = {}
|
||||
for key in sorted(v.keys()):
|
||||
s = np.concatenate([v[key], a[key]])
|
||||
s = np.nan_to_num(s, nan=0.0, posinf=1000000000000.0, neginf=-1000000000000.0)
|
||||
result[f'auroc_{key}'] = float(roc_auc_score(y, s))
|
||||
return result
|
||||
|
||||
def train(cfg: dict[str, Any]) -> Path:
|
||||
device = _device(str(cfg.get('device', 'auto')))
|
||||
save_dir = Path(cfg['save_dir'])
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(save_dir / 'config.yaml', 'w') as f:
|
||||
yaml.safe_dump(cfg, f)
|
||||
seed = int(cfg.get('seed', 42))
|
||||
data_seed = int(cfg.get('data_seed', seed))
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
print(f'Device: {device}')
|
||||
print(f'[seed] model={seed} data={data_seed}')
|
||||
feature_columns = cfg.get('flow_feature_columns')
|
||||
data = load_unified_data(packets_npz=Path(cfg['packets_npz']) if cfg.get('packets_npz') else None, source_store=Path(cfg['source_store']) if cfg.get('source_store') else None, flows_parquet=Path(cfg['flows_parquet']), flow_features_path=Path(cfg['flow_features_path']) if cfg.get('flow_features_path') else None, flow_feature_columns=feature_columns, flow_features_align=str(cfg.get('flow_features_align', 'auto')), T=int(cfg['T']), split_seed=data_seed, train_ratio=float(cfg.get('train_ratio', 0.8)), benign_label=str(cfg.get('benign_label', 'normal')), min_len=int(cfg.get('min_len', 2)), packet_preprocess=str(cfg.get('packet_preprocess', 'mixed_dequant')), attack_cap=int(cfg['attack_cap']) if cfg.get('attack_cap') else None, val_cap=int(cfg['val_cap']) if cfg.get('val_cap') else None)
|
||||
print(f'[data] T={data.T} packet_D={data.packet_dim} flow_D={data.flow_dim} train={len(data.train_flow):,} val={len(data.val_flow):,} attack={len(data.attack_flow):,}')
|
||||
(tr_f, tr_p, tr_l) = subsample_train(data, int(cfg.get('n_train', 0)), data_seed)
|
||||
ds = TensorDataset(torch.from_numpy(tr_f).float(), torch.from_numpy(tr_p).float(), torch.from_numpy(tr_l).long())
|
||||
loader = DataLoader(ds, batch_size=int(cfg['batch_size']), shuffle=True, drop_last=True, num_workers=int(cfg.get('num_workers', 0)), pin_memory=device.type == 'cuda')
|
||||
print(f'[data] using {len(ds):,} benign training flows')
|
||||
model_cfg = UnifiedCFMConfig(T=data.T, packet_dim=data.packet_dim, flow_dim=data.flow_dim, token_dim=cfg.get('token_dim'), d_model=int(cfg['d_model']), n_layers=int(cfg['n_layers']), n_heads=int(cfg['n_heads']), mlp_ratio=float(cfg.get('mlp_ratio', 4.0)), time_dim=int(cfg.get('time_dim', 64)), sigma=float(cfg.get('sigma', 0.1)), use_ot=bool(cfg.get('use_ot', False)), reference_mode=cfg.get('reference_mode'))
|
||||
model = UnifiedTokenCFM(model_cfg).to(device)
|
||||
print(f'[model] params={model.param_count():,} token_dim={model.token_dim} seq_len={model.seq_len} sigma={model_cfg.sigma} use_ot={model_cfg.use_ot} reference_mode={model_cfg.reference_mode}')
|
||||
opt = torch.optim.AdamW(model.parameters(), lr=float(cfg['lr']), weight_decay=float(cfg.get('weight_decay', 0.01)))
|
||||
total_steps = max(1, int(cfg['epochs']) * len(loader))
|
||||
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)
|
||||
history: dict[str, list[Any]] = {'epoch': [], 'loss': [], 'eval': []}
|
||||
lambda_flow = float(cfg.get('lambda_flow', 0.0))
|
||||
lambda_packet = float(cfg.get('lambda_packet', 0.0))
|
||||
packet_mask_ratio = float(cfg.get('packet_mask_ratio', 0.5))
|
||||
aux_enabled = lambda_flow > 0.0 or lambda_packet > 0.0
|
||||
if aux_enabled:
|
||||
print(f'[loss] λ_flow={lambda_flow} λ_packet={lambda_packet} packet_mask_ratio={packet_mask_ratio}')
|
||||
for epoch in range(1, int(cfg['epochs']) + 1):
|
||||
model.train()
|
||||
losses: list[float] = []
|
||||
aux_flow_sum = 0.0
|
||||
aux_packet_sum = 0.0
|
||||
n_steps_this_epoch = 0
|
||||
t0 = time.time()
|
||||
for (flow, packets, lens) in loader:
|
||||
flow = flow.to(device, non_blocking=True)
|
||||
packets = packets.to(device, non_blocking=True)
|
||||
lens = lens.to(device, non_blocking=True)
|
||||
if aux_enabled:
|
||||
comp = model.compute_loss(flow, packets, lens, lambda_flow=lambda_flow, lambda_packet=lambda_packet, packet_mask_ratio=packet_mask_ratio, return_components=True)
|
||||
loss = comp['total']
|
||||
aux_flow_sum += float(comp['aux_flow'].item())
|
||||
aux_packet_sum += float(comp['aux_packet'].item())
|
||||
else:
|
||||
loss = model.compute_loss(flow, packets, lens)
|
||||
opt.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get('grad_clip', 1.0)))
|
||||
opt.step()
|
||||
sched.step()
|
||||
losses.append(float(loss.item()))
|
||||
n_steps_this_epoch += 1
|
||||
mean_loss = float(np.mean(losses)) if losses else float('nan')
|
||||
eval_metrics: dict[str, float] | None = None
|
||||
if epoch % int(cfg.get('eval_every', 5)) == 0 or epoch == int(cfg['epochs']):
|
||||
eval_metrics = _quick_eval(model, data, device, cfg)
|
||||
history['epoch'].append(epoch)
|
||||
history['loss'].append(mean_loss)
|
||||
history['eval'].append(eval_metrics)
|
||||
elapsed = time.time() - t0
|
||||
terminal = ''
|
||||
if eval_metrics:
|
||||
terminal = f" auroc_terminal={eval_metrics['auroc_terminal_norm']:.3f}"
|
||||
if aux_enabled and n_steps_this_epoch:
|
||||
terminal += f' aux_flow={aux_flow_sum / n_steps_this_epoch:.4f} aux_pkt={aux_packet_sum / n_steps_this_epoch:.4f}'
|
||||
print(f"[epoch {epoch:>3d}/{cfg['epochs']:<3d}] ({elapsed:.1f}s) loss={mean_loss:.4f}{terminal}")
|
||||
if not np.isfinite(mean_loss):
|
||||
raise RuntimeError(f'non-finite loss at epoch {epoch}')
|
||||
payload = {'model_state_dict': model.state_dict(), 'model_cfg': asdict(model_cfg), 'packet_mean': data.packet_mean, 'packet_std': data.packet_std, 'flow_mean': data.flow_mean, 'flow_std': data.flow_std, 'packet_preprocess': data.packet_preprocess, 'flow_feature_names': np.asarray(data.flow_feature_names), 'packet_feature_names': np.asarray(data.packet_feature_names)}
|
||||
torch.save(payload, save_dir / 'model.pt')
|
||||
with open(save_dir / 'history.json', 'w') as f:
|
||||
json.dump(history, f, indent=2, default=str)
|
||||
print(f"[saved] {save_dir / 'model.pt'}")
|
||||
return save_dir
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument('--config', type=Path, required=True)
|
||||
p.add_argument('--override', type=str, nargs='*', default=[])
|
||||
args = p.parse_args()
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
for override in args.override:
|
||||
(key, value) = override.split('=', 1)
|
||||
cfg[key] = yaml.safe_load(value)
|
||||
train(cfg)
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user