Initial commit: code, paper, small artifacts
This commit is contained in:
157
Unified_CFM/tests/test_model_shapes.py
Normal file
157
Unified_CFM/tests/test_model_shapes.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import torch
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
from model import UnifiedCFMConfig, UnifiedTokenCFM
|
||||
|
||||
def _build_model():
|
||||
return UnifiedTokenCFM(UnifiedCFMConfig(T=4, packet_dim=3, flow_dim=5, d_model=16, n_layers=1, n_heads=4, time_dim=8))
|
||||
|
||||
def _build_reference_model(reference_mode: str):
|
||||
return UnifiedTokenCFM(UnifiedCFMConfig(T=4, packet_dim=3, flow_dim=5, d_model=16, n_layers=1, n_heads=4, time_dim=8, reference_mode=reference_mode))
|
||||
|
||||
def _sample_batch(seed: int=0):
|
||||
torch.manual_seed(seed)
|
||||
flow = torch.randn(2, 5)
|
||||
packets = torch.randn(2, 4, 3)
|
||||
lens = torch.tensor([4, 2])
|
||||
return (flow, packets, lens)
|
||||
|
||||
def test_unified_cfm_shapes_and_scores():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch()
|
||||
tokens = model.build_tokens(flow, packets)
|
||||
assert tokens.shape == (2, 5, 6)
|
||||
loss = model.compute_loss(flow, packets, lens)
|
||||
assert loss.ndim == 0
|
||||
assert torch.isfinite(loss)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=2)
|
||||
assert 'terminal_norm' in traj
|
||||
assert traj['terminal_norm'].shape == (2,)
|
||||
vel = model.velocity_score(flow, packets, lens)
|
||||
assert set(vel) == {'velocity_total', 'velocity_flow', 'velocity_packet'}
|
||||
|
||||
def test_reference_mode_independent_token_shapes_and_scores():
|
||||
model = _build_reference_model('independent_token')
|
||||
(flow, packets, lens) = _sample_batch(seed=9)
|
||||
loss = model.compute_loss(flow, packets, lens)
|
||||
assert loss.ndim == 0
|
||||
assert torch.isfinite(loss)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=2)
|
||||
assert traj['terminal_norm'].shape == (2,)
|
||||
assert torch.all(torch.isfinite(traj['curvature_packet']))
|
||||
|
||||
def test_reference_mode_block_diagonal_shapes_and_scores():
|
||||
model = _build_reference_model('block_diagonal')
|
||||
(flow, packets, lens) = _sample_batch(seed=10)
|
||||
loss = model.compute_loss(flow, packets, lens)
|
||||
assert loss.ndim == 0
|
||||
assert torch.isfinite(loss)
|
||||
vel = model.velocity_score(flow, packets, lens)
|
||||
assert set(vel) == {'velocity_total', 'velocity_flow', 'velocity_packet'}
|
||||
|
||||
def test_trajectory_curvature_keys_and_shapes():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=1)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=4)
|
||||
for key in ('curvature_total', 'curvature_flow', 'curvature_packet'):
|
||||
assert key in traj, f'missing {key}'
|
||||
assert traj[key].shape == (2,)
|
||||
assert torch.all(torch.isfinite(traj[key]))
|
||||
assert torch.all(traj[key] >= 0)
|
||||
|
||||
def test_trajectory_curvature_zero_with_one_step():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=2)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=1)
|
||||
for key in ('curvature_total', 'curvature_flow', 'curvature_packet'):
|
||||
assert traj[key].abs().sum().item() == 0.0
|
||||
|
||||
def test_speed_normalized_packet_curvature_scores():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=11)
|
||||
traj = model.trajectory_metrics(flow, packets, lens, n_steps=4)
|
||||
keys = ('kappa2_speed2norm_packet_mean', 'kappa2_speed2norm_packet_median', 'kappa2_speed2norm_packet_trimmed10_mean')
|
||||
for key in keys:
|
||||
assert key in traj, f'missing {key}'
|
||||
assert traj[key].shape == (2,)
|
||||
assert torch.all(torch.isfinite(traj[key]))
|
||||
assert torch.all(traj[key] >= 0)
|
||||
one_step = model.trajectory_metrics(flow, packets, lens, n_steps=1)
|
||||
for key in keys:
|
||||
assert one_step[key].abs().sum().item() == 0.0
|
||||
|
||||
def test_score_profile_vt_shapes():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=3)
|
||||
t_eval = (0.1, 0.3, 0.5, 0.7, 0.9, 1.0)
|
||||
prof = model.score_profile_vt(flow, packets, lens, t_eval=t_eval)
|
||||
assert len(prof) == 3 * len(t_eval)
|
||||
for (k, v) in prof.items():
|
||||
assert v.shape == (2,), k
|
||||
assert torch.all(torch.isfinite(v))
|
||||
assert torch.all(v >= 0)
|
||||
assert 'velocity_total_t05' in prof
|
||||
assert 'velocity_flow_t10' in prof
|
||||
assert 'velocity_packet_t01' in prof
|
||||
|
||||
def test_compute_loss_backward_compat():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=5)
|
||||
torch.manual_seed(0)
|
||||
a = model.compute_loss(flow, packets, lens)
|
||||
torch.manual_seed(0)
|
||||
b = model.compute_loss(flow, packets, lens, lambda_flow=0.0, lambda_packet=0.0)
|
||||
assert torch.allclose(a, b), f'λ=0 must match old loss; got {a.item()} vs {b.item()}'
|
||||
|
||||
def test_compute_loss_aux_components_finite():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=6)
|
||||
torch.manual_seed(7)
|
||||
comp = model.compute_loss(flow, packets, lens, lambda_flow=0.1, lambda_packet=0.1, return_components=True)
|
||||
assert set(comp) == {'total', 'main', 'aux_flow', 'aux_packet'}
|
||||
for (k, v) in comp.items():
|
||||
assert torch.isfinite(v), k
|
||||
assert v >= 0, f'{k} negative: {v.item()}'
|
||||
|
||||
def test_compute_loss_aux_affects_gradient():
|
||||
model = _build_model()
|
||||
with torch.no_grad():
|
||||
model.velocity.out.weight.normal_(std=0.01)
|
||||
for block in model.velocity.blocks:
|
||||
block.cond_proj.weight.normal_(std=0.01)
|
||||
(flow, packets, lens) = _sample_batch(seed=8)
|
||||
torch.manual_seed(10)
|
||||
total = model.compute_loss(flow, packets, lens, lambda_flow=1.0, lambda_packet=1.0)
|
||||
total.backward()
|
||||
some_grad = False
|
||||
for p in model.parameters():
|
||||
if p.grad is not None and p.grad.abs().sum().item() > 0:
|
||||
some_grad = True
|
||||
break
|
||||
assert some_grad, 'no gradient flowed through aux losses'
|
||||
|
||||
def test_consistency_score_shapes():
|
||||
model = _build_model()
|
||||
(flow, packets, lens) = _sample_batch(seed=9)
|
||||
cs = model.consistency_score(flow, packets, lens)
|
||||
assert set(cs) == {'flow_consistency', 'packet_consistency', 'consistency_total'}
|
||||
for (k, v) in cs.items():
|
||||
assert v.shape == (2,), k
|
||||
assert torch.all(torch.isfinite(v))
|
||||
assert torch.all(v >= 0), k
|
||||
|
||||
def test_jacobian_hutchinson_shapes_and_nonneg():
|
||||
model = _build_model()
|
||||
with torch.no_grad():
|
||||
model.velocity.out.weight.normal_(std=0.01)
|
||||
for block in model.velocity.blocks:
|
||||
block.cond_proj.weight.normal_(std=0.01)
|
||||
(flow, packets, lens) = _sample_batch(seed=4)
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
jac = model.jacobian_hutchinson(flow, packets, lens, t_eval=(0.5,), n_eps=2, generator=gen)
|
||||
assert set(jac) == {'jacobian_total', 'jacobian_flow', 'jacobian_packet'}
|
||||
for (k, v) in jac.items():
|
||||
assert v.shape == (2,), k
|
||||
assert torch.all(torch.isfinite(v))
|
||||
assert torch.all(v >= 0), f'{k} has negative value'
|
||||
Reference in New Issue
Block a user