150 lines
6.0 KiB
Python
150 lines
6.0 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
|
|
import concurrent.futures
|
|
import json
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
# Directory one level above the folder that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Prepend it to the import search path before the project imports further down
# (presumably so that packages living at the root, such as `src`, resolve —
# confirm against the repository layout).
sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from src.evaluation.metrics import compute_error_taxonomy, format_results_table
|
|
from src.evaluation.scorer import score_batch
|
|
|
|
from EGPv3_2.api_client import OpenAICompatClient
|
|
from EGPv3_2.benchmark import collect_episode_paths, load_resume_state
|
|
from EGPv3_2.pipeline import evaluate_episode_with_egpv3
|
|
|
|
# Directory handed to collect_episode_paths() as the benchmark root.
BENCHMARK_DIR = PROJECT_ROOT.joinpath("data", "benchmark")
# Default parent directory for per-model output folders.
RESULTS_DIR = PROJECT_ROOT.joinpath("results")
|
|
|
|
|
|
def _load_episode(path: Path) -> dict:
    """Read the UTF-8 JSON file at *path* and return the decoded object."""
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point for the EGPv3.2 benchmark run.

    Steps, in order:
      1. Parse the CLI options.
      2. Collect episode paths under ``BENCHMARK_DIR`` (filtered by the
         ``--sq`` / ``--variant`` / ``--episode-id`` / ``--max-episodes`` options).
      3. Unless ``--no-resume`` is given, reload earlier results from
         ``results.jsonl`` and skip episodes that are already recorded.
      4. Evaluate the remaining episodes, serially or on a thread pool
         (``--workers``), writing one JSON line per result as it arrives.
      5. Score all results, write ``summary.json`` and print the results table.

    With ``--no-resume`` the existing ``results.jsonl`` is truncated, so the
    file only ever contains the records of the current run.

    Raises:
        SystemExit: if ``--model`` or ``--api-base`` is missing and
            ``--preview-only`` was not requested.
    """
    parser = argparse.ArgumentParser(description="SafeHome EGPv3.2 runner")
    parser.add_argument("--model", default="")
    parser.add_argument("--api-base", default="")
    parser.add_argument("--api-key", default="")
    parser.add_argument("--sq", nargs="*", default=None)
    parser.add_argument("--variant", nargs="*", default=None)
    parser.add_argument("--episode-id", nargs="*", default=None)
    parser.add_argument("--max-episodes", type=int, default=None)
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--max-tokens", type=int, default=2048)
    parser.add_argument("--timeout", type=int, default=300)
    parser.add_argument("--thinking", action="store_true")
    parser.add_argument("--no-thinking", "--no_thinking", dest="no_thinking", action="store_true")
    parser.add_argument("--extra-json", default=None)
    parser.add_argument("--preview-only", action="store_true")
    parser.add_argument("--output-dir", default=None)
    parser.add_argument("--no-resume", action="store_true")
    parser.add_argument("--chunk-size", type=int, default=80)
    parser.add_argument("--max-focus-chunks", type=int, default=6)
    args = parser.parse_args()

    if not args.preview_only and (not args.model or not args.api_base):
        raise SystemExit("--model and --api-base are required unless --preview-only is used")

    episode_paths = collect_episode_paths(
        BENCHMARK_DIR,
        sq_types=args.sq,
        variants=args.variant,
        max_episodes=args.max_episodes,
        episode_ids=args.episode_id,
    )
    # Make the model name usable as a directory name.
    model_safe = (args.model or "preview").replace("/", "_").replace(":", "_")
    output_dir = Path(args.output_dir) if args.output_dir else RESULTS_DIR / f"{model_safe}_EGPv3_2"
    output_dir.mkdir(parents=True, exist_ok=True)
    results_file = output_dir / "results.jsonl"

    results = []
    completed = set()
    if not args.no_resume:
        results, completed = load_resume_state(results_file)
    # NOTE(review): assumes the ids returned by load_resume_state match the
    # episode file stems — confirm against EGPv3_2.benchmark.
    remaining = [p for p in episode_paths if p.stem not in completed]

    client = None
    if not args.preview_only:
        client = OpenAICompatClient(
            model=args.model,
            api_base=args.api_base,
            api_key=args.api_key,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
            thinking=args.thinking,
            no_thinking=args.no_thinking,
            extra_json=args.extra_json,
        )
        print("Testing API connectivity...")
        print((client.test_connection() or "<empty-response>")[:80])

    def evaluate_path(path):
        # Single code path for serial and threaded runs: the episode is loaded
        # right before it is evaluated, and the tuning options are always
        # passed by keyword so both modes call the pipeline identically.
        return evaluate_episode_with_egpv3(
            _load_episode(path),
            str(path),
            client,
            chunk_size=args.chunk_size,
            max_focus_chunks=args.max_focus_chunks,
            preview_only=args.preview_only,
        )

    print(f"EGPv3.2 run: {len(remaining)} / {len(episode_paths)} episodes pending")
    # A fresh (--no-resume) run replaces the old file; appending would mix
    # stale records with the new ones and duplicate every episode.
    file_mode = "w" if args.no_resume else "a"
    with open(results_file, file_mode, encoding="utf-8") as out_f:

        def record(result, position):
            # Persist each result immediately so an interrupted run can resume.
            results.append(result)
            out_f.write(json.dumps(result, ensure_ascii=False) + "\n")
            out_f.flush()
            print(f"[{position}/{len(remaining)}] {result['episode_id']}: {result['scores'].get('detection_label')}")

        if args.workers <= 1:
            for idx, path in enumerate(remaining, 1):
                record(evaluate_path(path), idx)
        else:
            with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
                futures = [executor.submit(evaluate_path, path) for path in remaining]
                try:
                    # Results are written from this thread only, in completion order.
                    for done, future in enumerate(concurrent.futures.as_completed(futures), 1):
                        record(future.result(), done)
                except BaseException:
                    # On failure or Ctrl-C, drop episodes that have not started
                    # yet instead of running them only to discard the output.
                    for pending in futures:
                        pending.cancel()
                    raise

    summary = score_batch(results)
    errors = compute_error_taxonomy(results)
    summary_data = {
        "model": args.model or "preview-only",
        "api_base": args.api_base,
        "timestamp": datetime.now().isoformat(),
        "total_evaluated": len(results),
        "summary": summary,
        "errors": errors,
        "pipeline": "EGPv3.2",
    }
    with open(output_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary_data, f, ensure_ascii=False, indent=2)

    if args.preview_only:
        print("\n[preview-only] No model call was made. UNPARSEABLE here is expected and does not indicate a pipeline failure.")
    print()
    print(format_results_table(summary, f"{args.model or 'preview-only'} + EGPv3.2"))
    print(f"\nResults saved to: {output_dir}")


if __name__ == "__main__":
    main()
|