Event-Reified Temporal Provenance Dual-Granularity Prompting for LLM-based APT detection on DARPA provenance datasets. Includes phase 0-14 method spec, IR/graph/metapath/trimming/prompt modules, scripts for THEIA candidate universe, landmark CSG construction, hybrid prompting, and LLM inference. Excludes data/, reports/, and local LLM config from version control.
227 lines
8.1 KiB
Python
227 lines
8.1 KiB
Python
#!/usr/bin/env python3
|
|
"""End-to-end evaluation: join LLM predictions with labels and aggregate metrics.

Inputs:
    --predictions-jsonl  One file per method variant, produced by
                         run_llm_inference.py. The file's stem (basename
                         without extension) is used as the method name in
                         the metrics table. Repeat the flag once per file.
    --labeled-targets    evaluation_batch JSONL (target_id, label, ...).

Output:
    --output-dir/metrics.md    Paper-Table-2-style table with columns:
                               method | n | n+ | AUPRC | AUROC | Macro-F1 |
                               Recall@10 | FPR@0.9 | avg_tokens | evidence_hit
    --output-dir/metrics.json  Machine-readable equivalent.

Each row uses the calibrated first-token softmax score from
``LLMInferenceResult.first_token_score`` (DGP paper formula 14), falling back
to the row's ``output.score``. If both are missing, a degraded binary score
(1.0 for MALICIOUS, 0.0 for BENIGN) is derived from the first-token or
predicted label. Rows with no usable score or label, or with no matching
labeled target, are excluded from the metrics with a warning.
"""
|
|
|
|
from __future__ import annotations

import argparse
import json
import logging
import math
from pathlib import Path

from er_tp_dgp.metrics import PredictionRecord, evaluate_classification
|
|
|
|
|
|
# Script-wide logger; every helper below reports skipped rows through it.
_log = logging.getLogger(name="run_evaluation")
|
|
|
|
|
|
def main() -> int:
    """Evaluate each predictions file against the labels and write the reports.

    Every ``--predictions-jsonl`` file that yields at least one usable record
    becomes one method entry, named after the file's stem. Files with no
    usable records are skipped with a warning. If two files share a stem
    (for example ``runs/a/preds.jsonl`` and ``runs/b/preds.jsonl``), the later
    one is stored as ``preds#2`` (then ``#3``, ...) with a warning, instead of
    silently overwriting the earlier result.

    Writes ``metrics.json`` and ``metrics.md`` into ``--output-dir`` (created
    if needed) and returns ``0`` as the process exit status.
    """
    args = _parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    labels = _index_labels(Path(args.labeled_targets))

    method_metrics: dict[str, dict] = {}
    for path in args.predictions_jsonl:
        prediction_path = Path(path)
        records = _build_prediction_records(prediction_path, labels)
        if not records:
            _log.warning("No usable predictions in %s; skipping.", prediction_path)
            continue
        method_name = _unique_method_name(prediction_path.stem, method_metrics)
        if method_name != prediction_path.stem:
            # Same stem from a different directory (or the same file twice):
            # keep both results rather than letting the later one win.
            _log.warning(
                "method name %r from %s is already taken; recording it as %r",
                prediction_path.stem,
                prediction_path,
                method_name,
            )
        metrics = evaluate_classification(
            records, k_values=args.k_values, recall_levels=args.recall_levels
        )
        method_metrics[method_name] = {
            "metrics": metrics.to_dict(),
            "num_records_used": len(records),
            "predictions_path": str(prediction_path),
        }

    json_path = output_dir / "metrics.json"
    md_path = output_dir / "metrics.md"
    json_path.write_text(
        json.dumps(method_metrics, ensure_ascii=False, sort_keys=True, indent=2),
        encoding="utf-8",
    )
    md_path.write_text(_render_markdown_table(method_metrics), encoding="utf-8")
    print(f"wrote {md_path}")
    print(f"wrote {json_path}")
    return 0


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line; ``argv`` defaults to ``sys.argv[1:]``."""
    # The module docstring is stripped under ``python -OO``; don't crash then.
    description = (__doc__ or "").split("\n", 1)[0]
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--predictions-jsonl",
        action="append",
        required=True,
        help="Repeat once per method variant. Filename stem is used as method name.",
    )
    parser.add_argument(
        "--labeled-targets",
        required=True,
        help="evaluation_batch JSONL with one {target_id, label, ...} object per line.",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Directory that receives metrics.json and metrics.md (created if missing).",
    )
    parser.add_argument(
        "--k-values",
        type=int,
        nargs="+",
        default=[1, 5, 10],
        help="K values passed to evaluate_classification (the Markdown table shows K=10).",
    )
    parser.add_argument(
        "--recall-levels",
        type=float,
        nargs="+",
        default=[0.8, 0.9],
        help="Recall levels passed to evaluate_classification (the Markdown table shows FPR@0.9).",
    )
    return parser.parse_args(argv)


def _unique_method_name(base_name: str, taken: dict[str, dict]) -> str:
    """Return ``base_name`` if unused in ``taken``, else ``base_name#N`` for the smallest free N >= 2."""
    candidate = base_name
    suffix = 2
    while candidate in taken:
        candidate = f"{base_name}#{suffix}"
        suffix += 1
    return candidate
|
|
|
|
|
|
def _index_labels(path: Path) -> dict[str, dict]:
    """Load a labeled-targets JSONL file into a ``{target_id: row}`` mapping.

    Blank (whitespace-only) lines are ignored, and so are rows whose
    ``target_id`` is missing or falsy. If a ``target_id`` appears more than
    once, the last row for it wins.
    """
    with path.open("r", encoding="utf-8") as handle:
        parsed_rows = [json.loads(text) for text in map(str.strip, handle) if text]
    return {
        entry.get("target_id"): entry
        for entry in parsed_rows
        if entry.get("target_id")
    }
|
|
|
|
|
|
def _build_prediction_records(
    predictions_path: Path, labels: dict[str, dict]
) -> list[PredictionRecord]:
    """Join one predictions JSONL file with the label index.

    Each non-blank line is parsed as a JSON object carrying ``target_id``,
    optionally ``first_token_score`` and ``prompt_tokens``, and an ``output``
    mapping (``score``, ``first_token_label``, ``predicted_label``,
    ``evidence_path_ids``). A missing or non-mapping ``output`` is treated as
    empty.

    The score is taken from the first usable source, in order:

    1. ``first_token_score`` (calibrated first-token softmax);
    2. ``output["score"]``;
    3. a degraded binary score, 1.0 / 0.0, from a ``MALICIOUS`` / ``BENIGN``
       ``first_token_label`` or ``predicted_label``.

    Numeric strings are accepted. ``None``, non-numeric values and NaN count
    as missing; a NaN would otherwise be clamped to 1.0, i.e. "certainly
    malicious". Scores are clamped into ``[0, 1]``.

    Rows with no usable score, or with no matching label row, are skipped with
    a warning. The true label is ``"malicious"`` when the label row's
    ``label`` equals ``malicious`` case-insensitively, otherwise ``"benign"``.

    Args:
        predictions_path: JSONL file produced by run_llm_inference.py.
        labels: ``{target_id: label_row}`` index of the labeled targets.

    Returns:
        One ``PredictionRecord`` per usable row, in file order.

    Raises:
        ValueError: if a line is not valid JSON; the message names the file
            and the 1-based line number.
    """
    records: list[PredictionRecord] = []
    with predictions_path.open("r", encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                payload = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"{predictions_path}:{line_number}: invalid JSON ({err.msg})"
                ) from err
            target_id = payload.get("target_id")
            output = payload.get("output")
            if not isinstance(output, dict):
                output = {}

            score = _resolve_numeric_score(payload, output)
            if score is None:
                # Fallback: many OpenAI-compatible endpoints don't honor
                # logprobs. Derive a degraded binary score from the label so
                # the row is still usable. With only 0/1 scores, rank-based
                # metrics (AUPRC, AUROC, Recall@K, FPR@recall) see large ties
                # and should be read as coarse approximations for such rows.
                score = _label_derived_score(output)
                if score is None:
                    _log.warning(
                        "missing first-token score AND no usable label for %s; skipping",
                        target_id,
                    )
                    continue

            target_id, label_row = _find_label_row(target_id, labels)
            if not label_row:
                _log.warning("no label for %s; skipping", target_id)
                continue

            # Case-insensitive, matching how predicted_label is parsed below.
            label_text = str(label_row.get("label") or "").strip().lower()
            true_label = "malicious" if label_text == "malicious" else "benign"
            predicted = output.get("predicted_label", "BENIGN")
            predicted_label = "malicious" if str(predicted).upper() == "MALICIOUS" else "benign"
            records.append(
                PredictionRecord(
                    target_id=target_id,
                    target_type=label_row.get("target_type", "PROCESS"),
                    score=score,
                    predicted_label=predicted_label,
                    true_label=true_label,
                    timestamp=label_row.get("anchor_timestamp"),
                    evidence_path_ids=tuple(output.get("evidence_path_ids") or ()),
                    prompt_tokens=payload.get("prompt_tokens"),
                    inference_cost=None,
                    prompt_construction_time=None,
                )
            )
    return records


def _coerce_unit_score(value: object) -> float | None:
    """Return ``value`` as a float clamped to [0, 1], or ``None`` if missing, non-numeric or NaN."""
    if value is None:
        return None
    try:
        score = float(value)
    except (TypeError, ValueError):
        return None
    if math.isnan(score):
        # max/min with NaN is order-dependent; here it would yield 1.0.
        return None
    return max(0.0, min(1.0, score))


def _resolve_numeric_score(payload: dict, output: dict) -> float | None:
    """Return the first usable of ``payload["first_token_score"]`` and ``output["score"]``."""
    for candidate in (payload.get("first_token_score"), output.get("score")):
        score = _coerce_unit_score(candidate)
        if score is not None:
            return score
    return None


def _label_derived_score(output: dict) -> float | None:
    """Map a MALICIOUS / BENIGN first-token or predicted label to 1.0 / 0.0 (MALICIOUS wins)."""
    first_label = str(output.get("first_token_label") or "").upper()
    predicted_upper = str(output.get("predicted_label") or "").upper()
    if "MALICIOUS" in (first_label, predicted_upper):
        return 1.0
    if "BENIGN" in (first_label, predicted_upper):
        return 0.0
    return None


def _find_label_row(target_id: object, labels: dict[str, dict]) -> tuple[object, dict | None]:
    """Look up ``target_id`` in ``labels``, retrying without an ``NNNN_`` prefix on a miss.

    Returns ``(target_id, label_row)``. The returned id is the bare id when
    only the prefix-stripped lookup succeeded. ``label_row`` is falsy when
    neither lookup matched.
    """
    label_row = labels.get(target_id)
    if not label_row and isinstance(target_id, str) and "_" in target_id:
        # Prompt-batch filenames carry an "NNNN_<uuid>" prefix (see
        # build_theia_prompt_batch.py:_safe_id). Recover the bare UUID
        # so that labeled_targets.jsonl lookups succeed.
        bare = target_id.split("_", 1)[1]
        bare_row = labels.get(bare)
        if bare_row:
            return bare, bare_row
    return target_id, label_row
|
|
|
|
|
|
def _render_markdown_table(method_metrics: dict[str, dict]) -> str:
    """Render per-method metrics as a Markdown report (Paper-Table-2 style).

    ``method_metrics`` maps a method name to a payload whose ``"metrics"``
    entry is the ``to_dict()`` of an evaluation result. Rows are sorted by
    method name.

    The Recall@10 and FPR@0.9 lookups accept both raw (``10`` / ``0.9``) and
    stringified (``"10"`` / ``"0.9"``) keys, so a payload that went through a
    JSON round-trip renders the same. Method names are escaped so that a
    ``|`` cannot break the table. An empty mapping yields a short "no
    metrics" report.
    """
    if not method_metrics:
        return "# ER-TP-DGP Evaluation\n\nNo method metrics produced.\n"
    headers = [
        "method",
        "n",
        "n+",
        "AUPRC",
        "AUROC",
        "Macro-F1",
        "Recall@10",
        "FPR@0.9",
        "avg_tokens",
        "evidence_hit",
    ]
    rows: list[list[str]] = []
    for method_name, payload in sorted(method_metrics.items()):
        m = payload["metrics"]
        rows.append(
            [
                _escape_markdown_cell(method_name),
                str(m["num_examples"]),
                str(m["num_positive"]),
                _fmt(m["auprc"]),
                _fmt(m["auroc"]),
                _fmt(m["macro_f1"]),
                _fmt(_lookup_metric(m.get("recall_at_k"), 10)),
                _fmt(_lookup_metric(m.get("fpr_at_recall"), 0.9)),
                _fmt(m["avg_prompt_tokens"]),
                _fmt(m["evidence_path_hit_rate"]),
            ]
        )
    lines = [
        "# ER-TP-DGP Evaluation",
        "",
        "Per-method metrics. Scores are calibrated first-token softmax over (Yes, No)",
        "(DGP paper formula 14); rows without one fall back to output.score, then to a",
        "binary 1.0/0.0 score derived from the first-token/predicted label. Rows with",
        "neither are excluded with a warning.",
        "",
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(["---"] * len(headers)) + "|",
    ]
    lines.extend("| " + " | ".join(row) + " |" for row in rows)
    return "\n".join(lines) + "\n"


def _lookup_metric(mapping: dict | None, key: object) -> object:
    """Return ``mapping[key]``, falling back to ``mapping[str(key)]``; ``None`` if absent."""
    if not mapping:
        return None
    value = mapping.get(key)
    if value is None:
        value = mapping.get(str(key))
    return value


def _escape_markdown_cell(text: object) -> str:
    """Make ``text`` safe inside a Markdown table cell (escape pipes, flatten newlines)."""
    return str(text).replace("|", "\\|").replace("\n", " ")
|
|
|
|
|
|
def _fmt(value) -> str:
    """Render one metrics-table cell.

    ``None`` becomes ``n/a``, floats get four decimal places, and anything
    else is passed through ``str``.
    """
    if value is None:
        return "n/a"
    return f"{value:.4f}" if isinstance(value, float) else str(value)
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|