#!/usr/bin/env python3
"""End-to-end evaluation: join LLM predictions with labels and aggregate metrics.

Inputs:
    --predictions-jsonl   One file per method variant, produced by
                          run_llm_inference.py. The file's stem is used as
                          the method name in the metrics table (a repeated
                          stem is disambiguated with a ``#N`` suffix).
    --labeled-targets     evaluation_batch JSONL (target_id, label, ...).

Outputs:
    --output-dir/metrics.md    Paper-Table-2-style table:
                               method | n | n+ | AUPRC | AUROC | Macro-F1 |
                               Recall@10 | FPR@0.9 | avg_tokens | evidence_hit
    --output-dir/metrics.json  Machine-readable equivalent: per method, the
                               ``to_dict()`` metrics, the number of records
                               used, and the predictions file path.

Scoring: each row uses the calibrated first-token softmax score from
``LLMInferenceResult.first_token_score`` (DGP paper formula 14), falling back
to ``output["score"]``. If neither is a finite number, a degraded binary score
(1.0 for MALICIOUS, 0.0 for BENIGN) is derived from the first-token or
predicted label. Rows with no usable score and no usable label, or with no
matching ground-truth label, are excluded with a warning.
"""

from __future__ import annotations

import argparse
import json
import logging
import math
from pathlib import Path

from er_tp_dgp.metrics import PredictionRecord, evaluate_classification

_log = logging.getLogger("run_evaluation")

# Canonical (upper-case) label spellings compared against normalized labels.
_MALICIOUS = "MALICIOUS"
_BENIGN = "BENIGN"


def main() -> int:
    """Evaluate every predictions file against the labels and write reports.

    Returns:
        Process exit code: 0 once ``metrics.json`` and ``metrics.md`` have
        been written (argparse itself exits on invalid arguments).
    """
    # __doc__ is None under ``python -OO``; don't crash building --help text.
    description = (__doc__ or "").split("\n", 1)[0]
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--predictions-jsonl",
        action="append",
        required=True,
        help="Repeat once per method variant. Filename stem is used as method name.",
    )
    parser.add_argument("--labeled-targets", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--k-values", type=int, nargs="+", default=[1, 5, 10])
    parser.add_argument("--recall-levels", type=float, nargs="+", default=[0.8, 0.9])
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    labels = _index_labels(Path(args.labeled_targets))

    method_metrics: dict[str, dict] = {}
    for path in args.predictions_jsonl:
        prediction_path = Path(path)
        records = _build_prediction_records(prediction_path, labels)
        if not records:
            _log.warning("No usable predictions in %s; skipping.", prediction_path)
            continue
        # Two files with the same stem (e.g. from different run directories)
        # would otherwise silently overwrite each other's results.
        method_name = _unique_method_name(prediction_path.stem, method_metrics)
        if method_name != prediction_path.stem:
            _log.warning(
                "Duplicate method name %r from %s; reporting it as %r.",
                prediction_path.stem,
                prediction_path,
                method_name,
            )
        metrics = evaluate_classification(
            records, k_values=args.k_values, recall_levels=args.recall_levels
        )
        method_metrics[method_name] = {
            "metrics": metrics.to_dict(),
            "num_records_used": len(records),
            "predictions_path": str(prediction_path),
        }

    json_path = output_dir / "metrics.json"
    md_path = output_dir / "metrics.md"
    json_path.write_text(
        json.dumps(method_metrics, ensure_ascii=False, sort_keys=True, indent=2),
        encoding="utf-8",
    )
    md_path.write_text(_render_markdown_table(method_metrics), encoding="utf-8")
    print(f"wrote {md_path}")
    print(f"wrote {json_path}")
    return 0


def _unique_method_name(stem: str, taken: dict[str, dict]) -> str:
    """Return *stem*, or ``stem#N`` with the smallest N >= 2 not already in *taken*."""
    if stem not in taken:
        return stem
    suffix = 2
    while f"{stem}#{suffix}" in taken:
        suffix += 1
    return f"{stem}#{suffix}"


def _index_labels(path: Path) -> dict[str, dict]:
    """Load a labeled-targets JSONL file into a ``target_id -> row`` mapping.

    Blank lines and rows without a truthy ``target_id`` are ignored. When a
    ``target_id`` repeats, the last row wins and a warning is logged.
    """
    labels: dict[str, dict] = {}
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            target_id = row.get("target_id")
            if not target_id:
                continue
            if target_id in labels:
                _log.warning(
                    "duplicate label row for %s in %s; keeping the last one",
                    target_id,
                    path,
                )
            labels[target_id] = row
    return labels


def _normalize_label(value: object) -> str:
    """Canonicalize a label for comparison: None/empty -> "", else stripped upper-case text."""
    return str(value or "").strip().upper()


def _as_finite_float(value: object) -> float | None:
    """Return *value* as a finite float, or None if it is absent, non-numeric, or NaN/inf."""
    if value is None:
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if math.isfinite(number) else None


def _resolve_score(payload: dict, output: dict, target_id: object) -> tuple[float | None, bool]:
    """Pick the score for one prediction row.

    Preference order: ``payload["first_token_score"]``, then
    ``output["score"]``, then a binary score derived from the labels.

    Returns:
        ``(score, is_fallback)``. *score* is clamped to ``[0, 1]``, or is None
        when nothing usable exists. *is_fallback* is True when the score was
        derived from a label rather than taken from a numeric score field.
    """
    for candidate in (payload.get("first_token_score"), output.get("score")):
        score = _as_finite_float(candidate)
        if score is not None:
            return max(0.0, min(1.0, score)), False
        if candidate is not None:
            # A NaN would otherwise clamp to 1.0 (max/min never pick NaN here)
            # and be scored as maximally malicious.
            _log.warning("ignoring unusable score %r for %s", candidate, target_id)

    # Fallback: many OpenAI-compatible endpoints don't honor logprobs. Derive a
    # degraded binary score from the first-token / predicted label so the row
    # still counts toward threshold metrics such as Macro-F1. Binary scores tie
    # heavily, so rank-based metrics lose resolution for these rows.
    seen = (
        _normalize_label(output.get("first_token_label")),
        _normalize_label(output.get("predicted_label")),
    )
    if _MALICIOUS in seen:
        return 1.0, True
    if _BENIGN in seen:
        return 0.0, True
    return None, False


def _lookup_label_row(target_id: object, labels: dict[str, dict]) -> tuple[object, dict | None]:
    """Find the label row for *target_id*, returning ``(resolved_id, row_or_None)``.

    Prompt-batch filenames carry an "NNNN_" prefix (see
    build_theia_prompt_batch.py:_safe_id). When the id is not found as-is,
    the text after the first "_" is tried and, on a hit, returned as the
    resolved id so records use the bare id from labeled_targets.jsonl.
    """
    row = labels.get(target_id)
    if not row and isinstance(target_id, str) and "_" in target_id:
        bare = target_id.split("_", 1)[1]
        bare_row = labels.get(bare)
        if bare_row:
            return bare, bare_row
    return target_id, row


def _build_prediction_records(
    predictions_path: Path, labels: dict[str, dict]
) -> list[PredictionRecord]:
    """Join one predictions JSONL file with *labels* into ``PredictionRecord``s.

    Args:
        predictions_path: JSONL file of inference results. Each row is read
            for ``target_id``, ``first_token_score``, ``prompt_tokens`` and an
            ``output`` object (``score``, ``first_token_label``,
            ``predicted_label``, ``evidence_path_ids``).
        labels: ``target_id -> label row`` mapping from ``_index_labels``.

    Returns:
        One record per usable row, in file order. Rows with no resolvable
        score or no matching label are skipped with a warning.
    """
    records: list[PredictionRecord] = []
    fallback_rows = 0
    with predictions_path.open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            payload = json.loads(line)
            target_id = payload.get("target_id")
            output = payload.get("output")
            # A non-object "output" (e.g. raw completion text) has none of the
            # fields read below; treat it as empty instead of crashing on .get.
            if not isinstance(output, dict):
                output = {}

            score, is_fallback = _resolve_score(payload, output, target_id)
            if score is None:
                _log.warning(
                    "missing first-token score AND no usable label for %s; skipping",
                    target_id,
                )
                continue

            target_id, label_row = _lookup_label_row(target_id, labels)
            if not label_row:
                _log.warning("no label for %s; skipping", target_id)
                continue

            if is_fallback:
                fallback_rows += 1
            # Normalize both sides the same way so "Malicious"/"MALICIOUS"
            # ground truth is not silently counted as benign.
            true_label = (
                "malicious" if _normalize_label(label_row.get("label")) == _MALICIOUS else "benign"
            )
            predicted_label = (
                "malicious"
                if _normalize_label(output.get("predicted_label")) == _MALICIOUS
                else "benign"
            )
            records.append(
                PredictionRecord(
                    target_id=target_id,
                    target_type=label_row.get("target_type", "PROCESS"),
                    score=score,
                    predicted_label=predicted_label,
                    true_label=true_label,
                    timestamp=label_row.get("anchor_timestamp"),
                    evidence_path_ids=tuple(output.get("evidence_path_ids") or ()),
                    prompt_tokens=payload.get("prompt_tokens"),
                    inference_cost=None,
                    prompt_construction_time=None,
                )
            )

    if fallback_rows:
        _log.warning(
            "%d of %d records in %s have no numeric score and use a binary 0/1 "
            "fallback; rank-based metrics are degraded by the resulting ties",
            fallback_rows,
            len(records),
            predictions_path,
        )
    return records


def _keyed_metric(mapping: dict | None, key: int | float) -> object:
    """Look up *key* in a per-K / per-recall-level metric map.

    Accepts maps keyed by the number itself or by its ``str`` form (e.g.
    after a JSON round-trip); returns None when the map or key is absent.
    """
    if not mapping:
        return None
    if key in mapping:
        return mapping[key]
    return mapping.get(str(key))


def _render_markdown_table(
    method_metrics: dict[str, dict],
    *,
    recall_k: int = 10,
    fpr_recall_level: float = 0.9,
) -> str:
    """Render *method_metrics* as a Markdown report with one row per method.

    Args:
        method_metrics: ``method name -> {"metrics": to_dict() output, ...}``
            as built by ``main``.
        recall_k: K whose ``recall_at_k`` value fills the Recall@K column.
        fpr_recall_level: Recall level whose ``fpr_at_recall`` value fills
            the FPR@level column.

    Returns:
        The Markdown document text, newline-terminated.
    """
    if not method_metrics:
        return "# ER-TP-DGP Evaluation\n\nNo method metrics produced.\n"

    headers = [
        "method",
        "n",
        "n+",
        "AUPRC",
        "AUROC",
        "Macro-F1",
        f"Recall@{recall_k}",
        f"FPR@{fpr_recall_level:g}",
        "avg_tokens",
        "evidence_hit",
    ]
    lines = [
        "# ER-TP-DGP Evaluation",
        "",
        "Per-method metrics. Scores are the calibrated first-token softmax over (Yes, No)",
        "(DGP paper formula 14). Records without logprobs fall back to a binary 0/1 score",
        "derived from the predicted label; records with neither are excluded with a warning.",
        "",
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(["---"] * len(headers)) + "|",
    ]
    for method_name, payload in sorted(method_metrics.items()):
        m = payload["metrics"]
        row = [
            # An unescaped "|" in a method name would split the table cell.
            method_name.replace("|", "\\|"),
            str(m["num_examples"]),
            str(m["num_positive"]),
            _fmt(m["auprc"]),
            _fmt(m["auroc"]),
            _fmt(m["macro_f1"]),
            _fmt(_keyed_metric(m.get("recall_at_k"), recall_k)),
            _fmt(_keyed_metric(m.get("fpr_at_recall"), fpr_recall_level)),
            _fmt(m["avg_prompt_tokens"]),
            _fmt(m["evidence_path_hit_rate"]),
        ]
        lines.append("| " + " | ".join(row) + " |")
    return "\n".join(lines) + "\n"


def _fmt(value: object) -> str:
    """Format a metric cell: floats to 4 decimals, None/NaN as "n/a", else ``str()``."""
    if value is None:
        return "n/a"
    if isinstance(value, float):
        return "n/a" if math.isnan(value) else f"{value:.4f}"
    return str(value)


if __name__ == "__main__":
    raise SystemExit(main())