Files
llmiotsafe/EGPv3/run_egpv3.py
2026-05-12 17:01:39 +08:00

150 lines
5.9 KiB
Python

#!/usr/bin/env python3
import argparse
import concurrent.futures
import json
import sys
from datetime import datetime
from pathlib import Path
# Directory one level above the folder that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Put the project root first on sys.path so the `src.*` and `EGPv3.*` imports
# that follow resolve when this file is executed directly as a script.
sys.path.insert(0, str(PROJECT_ROOT))
from src.evaluation.metrics import compute_error_taxonomy, format_results_table
from src.evaluation.scorer import score_batch
from EGPv3.api_client import OpenAICompatClient
from EGPv3.benchmark import collect_episode_paths, load_resume_state
from EGPv3.pipeline import evaluate_episode_with_egpv3
# Directory handed to collect_episode_paths() as the benchmark location.
BENCHMARK_DIR = PROJECT_ROOT.joinpath("data", "benchmark")
# Parent directory for per-model output folders when --output-dir is not given.
RESULTS_DIR = PROJECT_ROOT.joinpath("results")
def _load_episode(path: Path) -> dict:
    """Read one episode file and return its parsed JSON content.

    Args:
        path: Location of a UTF-8 encoded JSON file.

    Returns:
        The object produced by ``json.load`` for that file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser that defines the runner's command line."""
    parser = argparse.ArgumentParser(description="SafeHome EGPv3 runner")
    parser.add_argument("--model", default="")
    parser.add_argument("--api-base", default="")
    parser.add_argument("--api-key", default="")
    parser.add_argument("--sq", nargs="*", default=None)
    parser.add_argument("--variant", nargs="*", default=None)
    parser.add_argument("--episode-id", nargs="*", default=None)
    parser.add_argument("--max-episodes", type=int, default=None)
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--max-tokens", type=int, default=2048)
    parser.add_argument("--timeout", type=int, default=300)
    parser.add_argument("--thinking", action="store_true")
    parser.add_argument("--no-thinking", "--no_thinking", dest="no_thinking", action="store_true")
    parser.add_argument("--extra-json", default=None)
    parser.add_argument("--preview-only", action="store_true")
    parser.add_argument("--output-dir", default=None)
    parser.add_argument("--no-resume", action="store_true")
    parser.add_argument("--chunk-size", type=int, default=80)
    parser.add_argument("--max-focus-chunks", type=int, default=6)
    return parser


def main() -> None:
    """Evaluate the selected benchmark episodes with EGPv3 and save the results.

    Steps, in order:
      1. Parse the command line and collect the matching episode paths.
      2. Unless ``--no-resume`` is given, reload earlier results from
         ``results.jsonl`` and skip episodes whose file stem is already done.
      3. Unless ``--preview-only`` is given, build the API client and print the
         start of its connectivity-test response.
      4. Evaluate the pending episodes (sequentially, or in a thread pool when
         ``--workers`` is greater than 1), appending one JSON line per result.
      5. Score everything, write ``summary.json`` and print the results table.

    Raises:
        SystemExit: If ``--model`` or ``--api-base`` is missing and
            ``--preview-only`` was not requested.
    """
    args = _build_parser().parse_args()
    if not args.preview_only and (not args.model or not args.api_base):
        raise SystemExit("--model and --api-base are required unless --preview-only is used")

    episode_paths = collect_episode_paths(
        BENCHMARK_DIR,
        sq_types=args.sq,
        variants=args.variant,
        max_episodes=args.max_episodes,
        episode_ids=args.episode_id,
    )

    # "/" and ":" are replaced so the model name is usable as a folder name.
    model_safe = (args.model or "preview").replace("/", "_").replace(":", "_")
    output_dir = Path(args.output_dir) if args.output_dir else RESULTS_DIR / f"{model_safe}_EGPv3"
    output_dir.mkdir(parents=True, exist_ok=True)
    results_file = output_dir / "results.jsonl"

    results = []
    completed = set()
    if not args.no_resume:
        results, completed = load_resume_state(results_file)
    remaining = [p for p in episode_paths if p.stem not in completed]

    client = None
    if not args.preview_only:
        client = OpenAICompatClient(
            model=args.model,
            api_base=args.api_base,
            api_key=args.api_key,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
            thinking=args.thinking,
            no_thinking=args.no_thinking,
            extra_json=args.extra_json,
        )
        print("Testing API connectivity...")
        print((client.test_connection() or "<empty-response>")[:80])

    total = len(remaining)
    print(f"EGPv3 run: {total} / {len(episode_paths)} episodes pending")

    def evaluate(path: Path) -> dict:
        # Single call site for both execution modes: options always go by
        # keyword, and the episode is loaded only when its turn comes.
        return evaluate_episode_with_egpv3(
            _load_episode(path),
            str(path),
            client,
            chunk_size=args.chunk_size,
            max_focus_chunks=args.max_focus_chunks,
            preview_only=args.preview_only,
        )

    with open(results_file, "a", encoding="utf-8") as out_f:

        def record(result: dict, position: int) -> None:
            # Always called from the main thread. Flushing after every line
            # keeps finished episodes on disk if the run is interrupted.
            results.append(result)
            out_f.write(json.dumps(result, ensure_ascii=False) + "\n")
            out_f.flush()
            print(f"[{position}/{total}] {result['episode_id']}: {result['scores'].get('detection_label')}")

        if args.workers <= 1:
            for idx, path in enumerate(remaining, 1):
                record(evaluate(path), idx)
        else:
            with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
                futures = [executor.submit(evaluate, path) for path in remaining]
                # Results are recorded in completion order, not submission order.
                for done, future in enumerate(concurrent.futures.as_completed(futures), 1):
                    record(future.result(), done)

    summary = score_batch(results)
    errors = compute_error_taxonomy(results)
    summary_data = {
        "model": args.model or "preview-only",
        "api_base": args.api_base,
        "timestamp": datetime.now().isoformat(),
        "total_evaluated": len(results),
        "summary": summary,
        "errors": errors,
        "pipeline": "EGPv3",
    }
    with open(output_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary_data, f, ensure_ascii=False, indent=2)

    if args.preview_only:
        print("\n[preview-only] No model call was made. UNPARSEABLE here is expected and does not indicate a pipeline failure.")
    print()
    print(format_results_table(summary, f"{args.model or 'preview-only'} + EGPv3"))
    print(f"\nResults saved to: {output_dir}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()