Files
llmiotsafe/SFT_FULLLOG_QWEN2B_SERVER_BUNDLE/scripts/compare_eval_results.py
2026-05-12 17:01:39 +08:00

121 lines
4.1 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
from pathlib import Path
# Metrics shown in the "Overall Metrics" table, in display order.
# Each entry is (key looked up in the result's "summary" dict, column label).
# The key also drives formatting: fmt_metric/fmt_delta render keys that
# contain "rate" or "accuracy" as percentages and all others as decimals.
METRICS = [
    ("detection_accuracy", "Detection Accuracy"),
    ("f1_security", "F1-Security"),
    ("precision", "Precision"),
    ("recall", "Recall"),
    ("miss_rate", "Miss Rate"),
    ("false_alarm_rate", "False Alarm Rate"),
    ("threat_type_accuracy", "Threat Type Accuracy"),
    ("parse_failure_rate", "Parse Failure Rate"),
]
def resolve_summary(path_str: str) -> Path:
    """Return the summary file path that ``path_str`` points at.

    An existing directory is mapped to the ``summary.json`` inside it. Any
    other input (a file, or a path that does not exist) is returned
    unchanged as a ``Path``.
    """
    candidate = Path(path_str)
    return candidate / "summary.json" if candidate.is_dir() else candidate
def load_summary(path_str: str) -> dict:
    """Parse and return the JSON document that ``path_str`` refers to.

    ``path_str`` is first passed through ``resolve_summary``, so it may name
    either the JSON file itself or a directory containing ``summary.json``.
    The file is read as UTF-8.
    """
    summary_file = resolve_summary(path_str)
    raw_text = summary_file.read_text(encoding="utf-8")
    return json.loads(raw_text)
def fmt_metric(name: str, value: float) -> str:
    """Format a single metric value for display.

    Metrics whose name contains ``"rate"`` or ``"accuracy"`` are rendered as
    a percentage with one decimal place; every other metric is rendered as a
    plain number with three decimal places.
    """
    shown_as_percent = any(tag in name for tag in ("rate", "accuracy"))
    spec = ".1%" if shown_as_percent else ".3f"
    return format(value, spec)
def fmt_delta(name: str, before: float, after: float) -> str:
    """Format the signed change ``after - before`` for one metric.

    The sign is always shown. Metrics whose name contains ``"rate"`` or
    ``"accuracy"`` use a one-decimal percentage; all others use three
    decimal places.
    """
    change = after - before
    if "rate" not in name and "accuracy" not in name:
        return f"{change:+.3f}"
    return f"{change:+.1%}"
def print_metric_table(before: dict, after: dict) -> None:
    """Print the "Overall Metrics" table comparing two loaded results.

    One row is printed per entry in ``METRICS``. Values are read from the
    ``"summary"`` mapping of each result; a metric missing from a summary is
    treated as ``0.0``. A blank line follows the table.
    """
    print("Overall Metrics")
    print(f"{'Metric':<24} {'Before':>12} {'After':>12} {'Delta':>12}")
    print("-" * 64)

    baseline = before["summary"]
    candidate = after["summary"]
    for key, label in METRICS:
        old_value = baseline.get(key, 0.0)
        new_value = candidate.get(key, 0.0)
        cells = (
            f"{label:<24}",
            f"{fmt_metric(key, old_value):>12}",
            f"{fmt_metric(key, new_value):>12}",
            f"{fmt_delta(key, old_value, new_value):>12}",
        )
        print(" ".join(cells))
    print()
def print_variant_table(before: dict, after: dict) -> None:
    """Print the "By Variant Accuracy" table for the TP, FP and TN variants.

    Args:
        before: Loaded baseline result. Accuracy is read from
            ``before["summary"]["breakdown"][variant]["accuracy"]``.
        after: Loaded comparison result, read the same way.

    A missing ``"breakdown"``, variant, or ``"accuracy"`` entry is treated
    as ``0.0``. A blank line follows the table.

    Raises:
        KeyError: If either result has no ``"summary"`` key.
    """
    print("By Variant Accuracy")
    print(f"{'Variant':<10} {'Before':>12} {'After':>12} {'Delta':>12}")
    print("-" * 48)
    for variant in ["TP", "FP", "TN"]:
        b = before["summary"].get("breakdown", {}).get(variant, {}).get("accuracy", 0.0)
        a = after["summary"].get("breakdown", {}).get(variant, {}).get("accuracy", 0.0)
        # Width 12 matches the header columns; the "%" sign is already
        # counted in the field width, so no extra column is needed for it.
        print(f"{variant:<10} {b:>12.1%} {a:>12.1%} {a - b:>+12.1%}")
    print()
def print_sq_table(before: dict, after: dict) -> None:
    """Print the "By Query Type" table for SQ1 through SQ5.

    For each query type the row shows accuracy, miss count and false-alarm
    count as ``before -> after`` pairs. Values come from
    ``result["summary"]["per_sq"][sq]``; a missing ``"per_sq"``, query type
    or field falls back to ``0.0`` for accuracy and ``0`` for the counts.
    A blank line follows the table.
    """
    print("By Query Type")
    print(f"{'SQ':<6} {'Acc(B/A)':>18} {'Miss(B/A)':>16} {'FA(B/A)':>14}")
    print("-" * 60)

    old_per_sq = before["summary"].get("per_sq", {})
    new_per_sq = after["summary"].get("per_sq", {})
    for sq in ("SQ1", "SQ2", "SQ3", "SQ4", "SQ5"):
        old = old_per_sq.get(sq, {})
        new = new_per_sq.get(sq, {})
        accuracy = f"{old.get('accuracy', 0.0):>7.1%} -> {new.get('accuracy', 0.0):>7.1%}"
        misses = f"{old.get('miss_count', 0):>6} -> {new.get('miss_count', 0):<6}"
        false_alarms = f"{old.get('false_alarm_count', 0):>5} -> {new.get('false_alarm_count', 0):<5}"
        print(f"{sq:<6} {accuracy} {misses} {false_alarms}")
    print()
def main() -> None:
    """Command-line entry point: compare two evaluation results.

    Parses ``--before``, ``--after`` and the optional ``--output-json``
    arguments, loads both summaries, prints the resolved paths followed by
    the overall, per-variant and per-query-type tables, and, when
    ``--output-json`` is given, writes the per-metric deltas to that file
    (creating parent directories as needed).
    """
    cli = argparse.ArgumentParser(description="Compare two SafeHome eval summary.json files or result directories.")
    cli.add_argument("--before", required=True, help="baseline result dir or summary.json")
    cli.add_argument("--after", required=True, help="finetuned result dir or summary.json")
    cli.add_argument("--output-json", default=None, help="optional path to save deltas as json")
    opts = cli.parse_args()

    baseline = load_summary(opts.before)
    finetuned = load_summary(opts.after)

    # Resolve each input once and reuse the result for both the console
    # output and the JSON payload.
    baseline_path = resolve_summary(opts.before)
    finetuned_path = resolve_summary(opts.after)
    print(f"Before: {baseline_path}")
    print(f"After: {finetuned_path}")
    print()

    print_metric_table(baseline, finetuned)
    print_variant_table(baseline, finetuned)
    print_sq_table(baseline, finetuned)

    if not opts.output_json:
        return

    deltas = {}
    for key, _ in METRICS:
        deltas[key] = finetuned["summary"].get(key, 0.0) - baseline["summary"].get(key, 0.0)
    payload = {
        "before": str(baseline_path),
        "after": str(finetuned_path),
        "metric_deltas": deltas,
    }

    out_path = Path(opts.output_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"saved json diff -> {out_path}")
# Run the comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()