Initial commit: ER-TP-DGP research prototype
Event-Reified Temporal Provenance Dual-Granularity Prompting for LLM-based APT detection on DARPA provenance datasets. Includes phase 0-14 method spec, IR/graph/metapath/trimming/prompt modules, scripts for THEIA candidate universe, landmark CSG construction, hybrid prompting, and LLM inference. Excludes data/, reports/, and local LLM config from version control.
This commit is contained in:
226
scripts/run_evaluation.py
Normal file
226
scripts/run_evaluation.py
Normal file
@@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python3
|
||||
"""End-to-end evaluation: join LLM predictions with labels and aggregate metrics.
|
||||
|
||||
Inputs:
|
||||
--predictions-jsonl One file per method variant, produced by
|
||||
run_llm_inference.py. The file's basename is used as
|
||||
the method name in the metrics table.
|
||||
--labeled-targets evaluation_batch jsonl (target_id, label, ...)
|
||||
|
||||
Output:
|
||||
--output-dir/metrics.md Paper-Table-2-style table:
|
||||
method | AUPRC | AUROC | Macro-F1 |
|
||||
Recall@10 | FPR@0.9 | avg_tokens |
|
||||
avg_latency | evidence_path_hit_rate
|
||||
--output-dir/metrics.json Machine-readable equivalent.
|
||||
|
||||
Each row uses the calibrated first-token softmax score from
|
||||
``LLMInferenceResult.first_token_score`` (DGP paper formula 14). If a row's
|
||||
score is missing, it is excluded from the metrics with a warning.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.metrics import PredictionRecord, evaluate_classification
|
||||
|
||||
|
||||
_log = logging.getLogger("run_evaluation")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__.split("\n", 1)[0])
|
||||
parser.add_argument(
|
||||
"--predictions-jsonl",
|
||||
action="append",
|
||||
required=True,
|
||||
help="Repeat once per method variant. Filename stem is used as method name.",
|
||||
)
|
||||
parser.add_argument("--labeled-targets", required=True)
|
||||
parser.add_argument("--output-dir", required=True)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 5, 10],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recall-levels",
|
||||
type=float,
|
||||
nargs="+",
|
||||
default=[0.8, 0.9],
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
labels = _index_labels(Path(args.labeled_targets))
|
||||
|
||||
method_metrics: dict[str, dict] = {}
|
||||
for path in args.predictions_jsonl:
|
||||
prediction_path = Path(path)
|
||||
method_name = prediction_path.stem
|
||||
records = _build_prediction_records(prediction_path, labels)
|
||||
if not records:
|
||||
_log.warning("No usable predictions in %s; skipping.", prediction_path)
|
||||
continue
|
||||
metrics = evaluate_classification(
|
||||
records, k_values=args.k_values, recall_levels=args.recall_levels
|
||||
)
|
||||
method_metrics[method_name] = {
|
||||
"metrics": metrics.to_dict(),
|
||||
"num_records_used": len(records),
|
||||
"predictions_path": str(prediction_path),
|
||||
}
|
||||
|
||||
(output_dir / "metrics.json").write_text(
|
||||
json.dumps(method_metrics, ensure_ascii=False, sort_keys=True, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
(output_dir / "metrics.md").write_text(_render_markdown_table(method_metrics), encoding="utf-8")
|
||||
print(f"wrote {output_dir/'metrics.md'}")
|
||||
print(f"wrote {output_dir/'metrics.json'}")
|
||||
return 0
|
||||
|
||||
|
||||
def _index_labels(path: Path) -> dict[str, dict]:
|
||||
labels: dict[str, dict] = {}
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
target_id = row.get("target_id")
|
||||
if target_id:
|
||||
labels[target_id] = row
|
||||
return labels
|
||||
|
||||
|
||||
def _build_prediction_records(
|
||||
predictions_path: Path, labels: dict[str, dict]
|
||||
) -> list[PredictionRecord]:
|
||||
records: list[PredictionRecord] = []
|
||||
with predictions_path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
target_id = payload.get("target_id")
|
||||
output = payload.get("output") or {}
|
||||
score = (
|
||||
payload.get("first_token_score")
|
||||
if payload.get("first_token_score") is not None
|
||||
else output.get("score")
|
||||
)
|
||||
if score is None:
|
||||
# Fallback: many OpenAI-compatible endpoints don't honor logprobs.
|
||||
# Derive a degraded binary score from the first-token label so the
|
||||
# row is still usable (Macro-F1 / Precision@K stay valid; AUROC
|
||||
# collapses but AUPRC still works on rank order).
|
||||
first_label = (output.get("first_token_label") or "").upper()
|
||||
predicted_upper = str(output.get("predicted_label") or "").upper()
|
||||
if first_label == "MALICIOUS" or predicted_upper == "MALICIOUS":
|
||||
score = 1.0
|
||||
elif first_label == "BENIGN" or predicted_upper == "BENIGN":
|
||||
score = 0.0
|
||||
else:
|
||||
_log.warning(
|
||||
"missing first-token score AND no usable label for %s; skipping",
|
||||
target_id,
|
||||
)
|
||||
continue
|
||||
# Prompt-batch filenames carry an "NNNN_<uuid>" prefix (see
|
||||
# build_theia_prompt_batch.py:_safe_id). Recover the bare UUID
|
||||
# so that labeled_targets.jsonl lookups succeed.
|
||||
label_row = labels.get(target_id)
|
||||
if not label_row and isinstance(target_id, str) and "_" in target_id:
|
||||
bare = target_id.split("_", 1)[1]
|
||||
label_row = labels.get(bare)
|
||||
if label_row:
|
||||
target_id = bare
|
||||
if not label_row:
|
||||
_log.warning("no label for %s; skipping", target_id)
|
||||
continue
|
||||
true_label = "malicious" if label_row.get("label") == "malicious" else "benign"
|
||||
predicted = output.get("predicted_label", "BENIGN")
|
||||
predicted_label = "malicious" if str(predicted).upper() == "MALICIOUS" else "benign"
|
||||
records.append(
|
||||
PredictionRecord(
|
||||
target_id=target_id,
|
||||
target_type=label_row.get("target_type", "PROCESS"),
|
||||
score=float(max(0.0, min(1.0, score))),
|
||||
predicted_label=predicted_label,
|
||||
true_label=true_label,
|
||||
timestamp=label_row.get("anchor_timestamp"),
|
||||
evidence_path_ids=tuple(output.get("evidence_path_ids") or ()),
|
||||
prompt_tokens=payload.get("prompt_tokens"),
|
||||
inference_cost=None,
|
||||
prompt_construction_time=None,
|
||||
)
|
||||
)
|
||||
return records
|
||||
|
||||
|
||||
def _render_markdown_table(method_metrics: dict[str, dict]) -> str:
|
||||
if not method_metrics:
|
||||
return "# ER-TP-DGP Evaluation\n\nNo method metrics produced.\n"
|
||||
headers = [
|
||||
"method",
|
||||
"n",
|
||||
"n+",
|
||||
"AUPRC",
|
||||
"AUROC",
|
||||
"Macro-F1",
|
||||
"Recall@10",
|
||||
"FPR@0.9",
|
||||
"avg_tokens",
|
||||
"evidence_hit",
|
||||
]
|
||||
rows: list[list[str]] = []
|
||||
for method_name, payload in sorted(method_metrics.items()):
|
||||
m = payload["metrics"]
|
||||
rows.append(
|
||||
[
|
||||
method_name,
|
||||
str(m["num_examples"]),
|
||||
str(m["num_positive"]),
|
||||
_fmt(m["auprc"]),
|
||||
_fmt(m["auroc"]),
|
||||
_fmt(m["macro_f1"]),
|
||||
_fmt(m["recall_at_k"].get(10)),
|
||||
_fmt(m["fpr_at_recall"].get(0.9)),
|
||||
_fmt(m["avg_prompt_tokens"]),
|
||||
_fmt(m["evidence_path_hit_rate"]),
|
||||
]
|
||||
)
|
||||
lines = [
|
||||
"# ER-TP-DGP Evaluation",
|
||||
"",
|
||||
"Per-method metrics. Score column is calibrated first-token softmax over (Yes, No)",
|
||||
"(DGP paper formula 14). Records missing logprobs are excluded with a warning.",
|
||||
"",
|
||||
"| " + " | ".join(headers) + " |",
|
||||
"|" + "|".join(["---"] * len(headers)) + "|",
|
||||
]
|
||||
for row in rows:
|
||||
lines.append("| " + " | ".join(row) + " |")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _fmt(value) -> str:
|
||||
if isinstance(value, float):
|
||||
return f"{value:.4f}"
|
||||
if value is None:
|
||||
return "n/a"
|
||||
return str(value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user