121 lines
4.1 KiB
Python
121 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
|
|
|
|
# (summary key, display label) pairs for the overall-metrics comparison.
# The list order is the order in which rows are printed.
METRICS: list[tuple[str, str]] = [
    ("detection_accuracy", "Detection Accuracy"),
    ("f1_security", "F1-Security"),
    ("precision", "Precision"),
    ("recall", "Recall"),
    ("miss_rate", "Miss Rate"),
    ("false_alarm_rate", "False Alarm Rate"),
    ("threat_type_accuracy", "Threat Type Accuracy"),
    ("parse_failure_rate", "Parse Failure Rate"),
]
|
|
|
|
|
|
def resolve_summary(path_str: str) -> Path:
    """Return the summary file path that *path_str* refers to.

    Args:
        path_str: Either a directory or a path to a summary file.

    Returns:
        ``<path_str>/summary.json`` when *path_str* is an existing directory;
        otherwise *path_str* itself as a ``Path`` (its existence is not
        checked here).
    """
    candidate = Path(path_str)
    return candidate / "summary.json" if candidate.is_dir() else candidate
|
|
|
|
|
|
def load_summary(path_str: str) -> dict:
    """Read and parse the summary JSON that *path_str* points to.

    Args:
        path_str: A result directory or a summary file path; it is resolved
            with ``resolve_summary``.

    Returns:
        The decoded JSON document. The table printers in this file index it
        with ``["summary"]``.

    Raises:
        OSError: If the resolved file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    summary_path = resolve_summary(path_str)
    raw_text = summary_path.read_text(encoding="utf-8")
    return json.loads(raw_text)
|
|
|
|
|
|
def fmt_metric(name: str, value: float | None) -> str:
    """Format one metric value for display.

    Metrics whose key contains ``"rate"`` or ``"accuracy"`` are rendered as a
    percentage with one decimal place; every other metric is rendered as a
    plain number with three decimals.

    Args:
        name: Metric key, e.g. ``"miss_rate"`` or ``"precision"``.
        value: Metric value, or ``None`` when the metric is unavailable.

    Returns:
        The formatted value, or ``"n/a"`` when *value* is ``None``.
    """
    if value is None:
        # A JSON null decodes to None, and dict.get(key, default) returns that
        # None rather than the default. None rejects numeric format specs, so
        # show a placeholder instead of raising TypeError.
        return "n/a"
    spec = ".1%" if ("rate" in name or "accuracy" in name) else ".3f"
    return format(value, spec)
|
|
|
|
|
|
def fmt_delta(name: str, before: float | None, after: float | None) -> str:
    """Format the signed change of a metric from *before* to *after*.

    The difference ``after - before`` is always rendered with an explicit
    sign. Metrics whose key contains ``"rate"`` or ``"accuracy"`` show the
    difference as a percentage with one decimal place; every other metric
    shows it with three decimals.

    Args:
        name: Metric key, e.g. ``"miss_rate"`` or ``"precision"``.
        before: Baseline value, or ``None`` when unavailable.
        after: New value, or ``None`` when unavailable.

    Returns:
        The formatted difference, or ``"n/a"`` when either value is ``None``.
    """
    if before is None or after is None:
        # A JSON null decodes to None; subtracting it would raise TypeError,
        # and no meaningful delta exists when one side is missing.
        return "n/a"
    delta = after - before
    spec = "+.1%" if ("rate" in name or "accuracy" in name) else "+.3f"
    return format(delta, spec)
|
|
|
|
|
|
def print_metric_table(before: dict, after: dict) -> None:
    """Print the "Overall Metrics" before/after/delta table to stdout.

    One row is printed per entry in ``METRICS``. A metric missing from a
    summary is treated as ``0.0``.

    Args:
        before: Baseline document; must contain a ``"summary"`` mapping.
        after: New document; must contain a ``"summary"`` mapping.
    """
    # A single template keeps the header and the data rows in the same columns.
    row_template = "{:<24} {:>12} {:>12} {:>12}"

    print("Overall Metrics")
    print(row_template.format("Metric", "Before", "After", "Delta"))
    print("-" * 64)
    for key, label in METRICS:
        old_value = before["summary"].get(key, 0.0)
        new_value = after["summary"].get(key, 0.0)
        print(
            row_template.format(
                label,
                fmt_metric(key, old_value),
                fmt_metric(key, new_value),
                fmt_delta(key, old_value, new_value),
            )
        )
    print()
|
|
|
|
|
|
def print_variant_table(before: dict, after: dict) -> None:
    """Print the "By Variant Accuracy" table for the TP, FP and TN variants.

    Accuracy is read from ``summary -> breakdown -> <variant> -> accuracy``;
    any missing level falls back to ``0.0``.

    Args:
        before: Baseline document; must contain a ``"summary"`` mapping.
        after: New document; must contain a ``"summary"`` mapping.
    """

    def accuracy_of(result: dict, variant: str) -> float:
        # Walk the nested mapping, tolerating absent levels.
        return result["summary"].get("breakdown", {}).get(variant, {}).get("accuracy", 0.0)

    print("By Variant Accuracy")
    print(f"{'Variant':<10} {'Before':>12} {'After':>12} {'Delta':>12}")
    print("-" * 48)
    for variant in ("TP", "FP", "TN"):
        old_acc = accuracy_of(before, variant)
        new_acc = accuracy_of(after, variant)
        print(f"{variant:<10} {old_acc:>11.1%} {new_acc:>11.1%} {new_acc - old_acc:>+11.1%}")
    print()
|
|
|
|
|
|
def print_sq_table(before: dict, after: dict) -> None:
    """Print the "By Query Type" table for SQ1 through SQ5.

    Each row shows accuracy, miss count and false-alarm count as
    ``before -> after`` pairs, read from ``summary -> per_sq -> <SQ>``.
    Missing entries fall back to ``0.0`` for accuracy and ``0`` for counts.

    Args:
        before: Baseline document; must contain a ``"summary"`` mapping.
        after: New document; must contain a ``"summary"`` mapping.
    """
    print("By Query Type")
    print(f"{'SQ':<6} {'Acc(B/A)':>18} {'Miss(B/A)':>16} {'FA(B/A)':>14}")
    print("-" * 60)
    for sq in ("SQ1", "SQ2", "SQ3", "SQ4", "SQ5"):
        old = before["summary"].get("per_sq", {}).get(sq, {})
        new = after["summary"].get("per_sq", {}).get(sq, {})
        # Each cell is exactly as wide as its header column (18 / 16 / 14).
        acc_cell = f"{old.get('accuracy', 0.0):>7.1%} -> {new.get('accuracy', 0.0):>7.1%}"
        miss_cell = f"{old.get('miss_count', 0):>6} -> {new.get('miss_count', 0):<6}"
        fa_cell = f"{old.get('false_alarm_count', 0):>5} -> {new.get('false_alarm_count', 0):<5}"
        print(f"{sq:<6} {acc_cell} {miss_cell} {fa_cell}")
    print()
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: compare two evaluation summaries.

    Parses ``--before`` and ``--after`` (each a result directory or a
    ``summary.json`` path), prints the resolved paths followed by the three
    comparison tables, and, when ``--output-json`` is given, writes the
    per-metric deltas (``after - before``) to that file as JSON.
    """
    cli = argparse.ArgumentParser(
        description="Compare two SafeHome eval summary.json files or result directories."
    )
    cli.add_argument("--before", required=True, help="baseline result dir or summary.json")
    cli.add_argument("--after", required=True, help="finetuned result dir or summary.json")
    cli.add_argument("--output-json", default=None, help="optional path to save deltas as json")
    opts = cli.parse_args()

    baseline = load_summary(opts.before)
    candidate = load_summary(opts.after)

    print(f"Before: {resolve_summary(opts.before)}")
    print(f"After: {resolve_summary(opts.after)}")
    print()
    for render_table in (print_metric_table, print_variant_table, print_sq_table):
        render_table(baseline, candidate)

    # Nothing more to do unless a JSON export was requested.
    if not opts.output_json:
        return

    payload = {
        "before": str(resolve_summary(opts.before)),
        "after": str(resolve_summary(opts.after)),
    }
    deltas = {}
    for key, _label in METRICS:
        # Metrics missing from a summary count as 0.0, matching the tables.
        deltas[key] = candidate["summary"].get(key, 0.0) - baseline["summary"].get(key, 0.0)
    payload["metric_deltas"] = deltas

    destination = Path(opts.output_json)
    # Create missing parent directories so the export path need not pre-exist.
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        json.dump(payload, sink, ensure_ascii=False, indent=2)
    print(f"saved json diff -> {destination}")
|
|
|
|
|
|
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|