Add baseline reproduction: Shafir NF 2-NF ensemble (17/18 cells), ConMD Table I citation, JANUS thresholded F1 across 4 datasets

This commit is contained in:
2026-05-08 11:47:27 +08:00
parent dc22e20616
commit c33efc290a
4 changed files with 498 additions and 27 deletions

View File

@@ -12,7 +12,7 @@ import pandas as pd
os.environ.setdefault('JAX_PLATFORMS', 'cpu')
warnings.filterwarnings('ignore')
import optax
from pzflow import Flow
from pzflow import Flow, FlowEnsemble
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.preprocessing import StandardScaler
REPO = Path(__file__).resolve().parents[2]
@@ -175,7 +175,7 @@ def _safe_metric(fn, y, s) -> float:
except ValueError:
return float('nan')
def _train_and_score(train, val, atk, feat_cols, *, epochs, lr, optimizer):
def _train_and_score(train, val, atk, feat_cols, *, epochs, lr, optimizer, n_flows=1, seed=0):
raw_train = train[feat_cols].astype(np.float64).values
keep = raw_train.std(axis=0) > 0
if not keep.all():
@@ -200,9 +200,16 @@ def _train_and_score(train, val, atk, feat_cols, *, epochs, lr, optimizer):
opt = optax.sgd(learning_rate=lr)
else:
opt = optax.adam(learning_rate=lr)
flow = Flow(df_train.columns.tolist())
if n_flows > 1:
flow = FlowEnsemble(df_train.columns.tolist(), N=n_flows)
else:
flow = Flow(df_train.columns.tolist())
t0 = time.time()
losses = flow.train(df_train, optimizer=opt, epochs=epochs, verbose=False)
if n_flows > 1:
losses_dict = flow.train(df_train, optimizer=opt, epochs=epochs, verbose=False, seed=seed)
losses = list(losses_dict.values())[0]
else:
losses = flow.train(df_train, optimizer=opt, epochs=epochs, verbose=False)
t_train = time.time() - t0
t0 = time.time()
lp_val = np.asarray(flow.log_prob(df_val))
@@ -229,6 +236,7 @@ def main():
p.add_argument('--epochs', type=int, default=100)
p.add_argument('--lr', type=float, default=0.001)
p.add_argument('--optimizer', choices=['sgd', 'adam'], default='sgd')
p.add_argument('--n-flows', type=int, default=1, help='1 = single NF (Shafir paper baseline mode); 2 = paper headline ensemble')
args = p.parse_args()
args.out_dir.mkdir(parents=True, exist_ok=True)
(src_name, tgt_name, caps) = PROTOCOL_CONFIG[args.protocol]
@@ -247,15 +255,16 @@ def main():
print(f' [features] within: {len(feat_cols)} cols')
(train, val, atk) = _sample_within(src_df, caps, args.seed)
print(f' [data] train={len(train):,} val={len(val):,} attack={len(atk):,} D={len(feat_cols)}')
res = _train_and_score(train, val, atk, feat_cols, epochs=args.epochs, lr=args.lr, optimizer=args.optimizer)
res = _train_and_score(train, val, atk, feat_cols, epochs=args.epochs, lr=args.lr, optimizer=args.optimizer, n_flows=args.n_flows, seed=args.seed)
(val_score, atk_score) = (res['score_val'], res['score_atk'])
y = np.r_[np.zeros(len(val_score)), np.ones(len(atk_score))]
s = np.r_[val_score, atk_score]
overall = {'neg_log_prob': {'auroc': _safe_metric(roc_auc_score, y, s), 'auprc': _safe_metric(average_precision_score, y, s)}}
a_labels = atk['cls_label'].astype(str).to_numpy()
per_cls = _per_class(val_score, atk_score, a_labels)
out = {'method': 'shafir_nf_csv', 'protocol': args.protocol, 'seed': args.seed, 'src_dataset': src_name, 'tgt_dataset': tgt_name, 'feature_set': feat_cols, 'n_features': len(feat_cols), 'n_train': len(train), 'n_val': len(val), 'n_atk': len(atk), 'epochs': args.epochs, 'lr': args.lr, 'optimizer': args.optimizer, 't_train_sec': round(res['t_train'], 2), 't_score_sec': round(res['t_score'], 2), 'loss_first_last': [float(res['losses'][0]), float(res['losses'][-1])], 'overall': overall, 'per_class': per_cls}
out_json = args.out_dir / f'{args.protocol}_seed{args.seed}.json'
out = {'method': 'shafir_nf_csv', 'protocol': args.protocol, 'seed': args.seed, 'n_flows': args.n_flows, 'src_dataset': src_name, 'tgt_dataset': tgt_name, 'feature_set': feat_cols, 'n_features': len(feat_cols), 'n_train': len(train), 'n_val': len(val), 'n_atk': len(atk), 'epochs': args.epochs, 'lr': args.lr, 'optimizer': args.optimizer, 't_train_sec': round(res['t_train'], 2), 't_score_sec': round(res['t_score'], 2), 'loss_first_last': [float(res['losses'][0]), float(res['losses'][-1])], 'overall': overall, 'per_class': per_cls}
suffix = f"_n{args.n_flows}" if args.n_flows > 1 else ""
out_json = args.out_dir / f'{args.protocol}_seed{args.seed}{suffix}.json'
out_json.write_text(json.dumps(out, indent=2))
npz_path = out_json.with_suffix('.npz')
np.savez_compressed(npz_path, b_neg_log_prob=val_score, a_neg_log_prob=atk_score, a_labels=a_labels.astype(str), losses=res['losses'])