#!/usr/bin/env python3
"""Compare two SafeHome eval summaries and print before/after tables.

Each input is either a ``summary.json`` file or a directory that contains
one. The JSON document must have a top-level ``"summary"`` object; metrics
that are absent from it, or present with a JSON ``null`` value, are treated
as ``0.0`` (counts as ``0``) so that a partially filled summary can still be
compared.
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path

# (summary key, display label) pairs, in the order they are printed and saved.
METRICS = [
    ("detection_accuracy", "Detection Accuracy"),
    ("f1_security", "F1-Security"),
    ("precision", "Precision"),
    ("recall", "Recall"),
    ("miss_rate", "Miss Rate"),
    ("false_alarm_rate", "False Alarm Rate"),
    ("threat_type_accuracy", "Threat Type Accuracy"),
    ("parse_failure_rate", "Parse Failure Rate"),
]

# Row order of the per-variant and per-query-type tables.
_VARIANTS = ("TP", "FP", "TN")
_QUERY_TYPES = ("SQ1", "SQ2", "SQ3", "SQ4", "SQ5")


def resolve_summary(path_str: str) -> Path:
    """Return the summary file path for *path_str*.

    If *path_str* names an existing directory, the result is
    ``<dir>/summary.json``; otherwise the path is returned unchanged.
    """
    path = Path(path_str)
    if path.is_dir():
        return path / "summary.json"
    return path


def load_summary(path_str: str) -> dict:
    """Load and return the JSON document behind *path_str*.

    *path_str* is resolved with :func:`resolve_summary` first.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    path = resolve_summary(path_str)
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def _is_percent_metric(name: str) -> bool:
    """Return True when *name* is displayed as a percentage."""
    return "rate" in name or "accuracy" in name


def fmt_metric(name: str, value: float) -> str:
    """Format *value* as ``12.3%`` for rate/accuracy metrics, else ``0.123``."""
    if _is_percent_metric(name):
        return f"{value:.1%}"
    return f"{value:.3f}"


def fmt_delta(name: str, before: float, after: float) -> str:
    """Format ``after - before`` with an explicit sign.

    Uses the same percent-vs-decimal rule as :func:`fmt_metric`.
    """
    delta = after - before
    if _is_percent_metric(name):
        return f"{delta:+.1%}"
    return f"{delta:+.3f}"


def metric_value(section: dict | None, key: str, default: float = 0.0) -> float:
    """Return ``section[key]``, or *default* when it is absent or ``None``.

    ``dict.get(key, default)`` alone is not enough: a key that is present with
    a JSON ``null`` value yields ``None``, which can be neither formatted as a
    number nor subtracted.
    """
    if not section:
        return default
    value = section.get(key)
    return default if value is None else value


def _section(data: dict, *keys: str) -> dict:
    """Walk ``data["summary"]`` down *keys*; missing or null levels become ``{}``."""
    node = data["summary"]
    for key in keys:
        node = node.get(key) or {}
    return node


def compute_metric_deltas(before: dict, after: dict) -> dict[str, float]:
    """Return ``{metric key: after - before}`` for every entry in ``METRICS``.

    Both arguments are full summary documents (with a ``"summary"`` object).
    """
    b_summary = before["summary"]
    a_summary = after["summary"]
    return {
        key: metric_value(a_summary, key) - metric_value(b_summary, key)
        for key, _ in METRICS
    }


def print_metric_table(before: dict, after: dict) -> None:
    """Print the overall metrics with before, after and delta columns."""
    print("Overall Metrics")
    print(f"{'Metric':<24} {'Before':>12} {'After':>12} {'Delta':>12}")
    print("-" * 64)
    b_summary = before["summary"]
    a_summary = after["summary"]
    for key, label in METRICS:
        b = metric_value(b_summary, key)
        a = metric_value(a_summary, key)
        print(f"{label:<24} {fmt_metric(key, b):>12} {fmt_metric(key, a):>12} {fmt_delta(key, b, a):>12}")
    print()


def print_variant_table(before: dict, after: dict) -> None:
    """Print the accuracy of each variant found under ``summary.breakdown``."""
    print("By Variant Accuracy")
    print(f"{'Variant':<10} {'Before':>12} {'After':>12} {'Delta':>12}")
    print("-" * 48)
    for variant in _VARIANTS:
        b = metric_value(_section(before, "breakdown", variant), "accuracy")
        a = metric_value(_section(after, "breakdown", variant), "accuracy")
        # Width 12 matches the header columns so values sit under their titles.
        print(f"{variant:<10} {b:>12.1%} {a:>12.1%} {a - b:>+12.1%}")
    print()


def print_sq_table(before: dict, after: dict) -> None:
    """Print accuracy, miss count and false-alarm count per query type.

    Values are read from ``summary.per_sq`` and shown as ``before -> after``.
    """
    print("By Query Type")
    print(f"{'SQ':<6} {'Acc(B/A)':>18} {'Miss(B/A)':>16} {'FA(B/A)':>14}")
    print("-" * 60)
    for sq in _QUERY_TYPES:
        b_sq = _section(before, "per_sq", sq)
        a_sq = _section(after, "per_sq", sq)
        b_acc = metric_value(b_sq, "accuracy")
        a_acc = metric_value(a_sq, "accuracy")
        b_miss = metric_value(b_sq, "miss_count", 0)
        a_miss = metric_value(a_sq, "miss_count", 0)
        b_fa = metric_value(b_sq, "false_alarm_count", 0)
        a_fa = metric_value(a_sq, "false_alarm_count", 0)
        print(f"{sq:<6} {b_acc:>7.1%} -> {a_acc:>7.1%} {b_miss:>6} -> {a_miss:<6} {b_fa:>5} -> {a_fa:<5}")
    print()


def main(argv: list[str] | None = None) -> None:
    """Parse the command line, print the comparison tables, optionally save JSON.

    Args:
        argv: argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Compare two SafeHome eval summary.json files or result directories."
    )
    parser.add_argument("--before", required=True, help="baseline result dir or summary.json")
    parser.add_argument("--after", required=True, help="finetuned result dir or summary.json")
    parser.add_argument("--output-json", default=None, help="optional path to save deltas as json")
    args = parser.parse_args(argv)

    before = load_summary(args.before)
    after = load_summary(args.after)

    # Resolve once and reuse for both the console header and the JSON payload.
    before_path = resolve_summary(args.before)
    after_path = resolve_summary(args.after)

    print(f"Before: {before_path}")
    print(f"After: {after_path}")
    print()

    print_metric_table(before, after)
    print_variant_table(before, after)
    print_sq_table(before, after)

    if args.output_json:
        payload = {
            "before": str(before_path),
            "after": str(after_path),
            "metric_deltas": compute_metric_deltas(before, after),
        }
        out_path = Path(args.output_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        print(f"saved json diff -> {out_path}")


if __name__ == "__main__":
    main()