Files
llmiotsafe/EGPv3/benchmark.py
2026-05-12 17:01:39 +08:00

45 lines
1.5 KiB
Python

import json
from pathlib import Path
from typing import Iterable, List, Optional, Set, Tuple
def collect_episode_paths(
    benchmark_dir: Path,
    sq_types: Optional[Iterable[str]] = None,
    variants: Optional[Iterable[str]] = None,
    max_episodes: Optional[int] = None,
    episode_ids: Optional[Iterable[str]] = None,
) -> List[Path]:
    """Collect benchmark episode JSON files filtered by SQ type, variant and id.

    Episodes are looked up as ``benchmark_dir/<sq lower-cased>/*.json``. SQ
    directories are visited in the order given, and the files inside each
    directory are visited in sorted order.

    Args:
        benchmark_dir: Root directory holding one sub-directory per SQ type.
        sq_types: SQ types to include (case-insensitive). ``None`` or empty
            means ``SQ1``..``SQ5``. Duplicates are ignored.
        variants: Variant tags to include (case-insensitive). A file matches
            when its stem contains ``_<VARIANT>_``. ``None`` or empty means
            ``TP``, ``FP`` and ``TN``. Duplicates are ignored.
        max_episodes: Upper bound on the number of paths returned. ``None``
            means unlimited. Negative values are treated as 0.
        episode_ids: If non-empty, only files whose stem is one of these ids
            are included.

    Returns:
        The matching episode paths, at most ``max_episodes`` of them.
    """
    # Materialize each input once, so generators behave like lists (empty ->
    # defaults). De-duplicate while keeping the caller's order, so that
    # ["sq1", "SQ1"] does not return every SQ1 episode twice.
    sq_list = list(dict.fromkeys(s.upper() for s in sq_types or ())) or [
        "SQ1", "SQ2", "SQ3", "SQ4", "SQ5",
    ]
    variant_list = list(dict.fromkeys(v.upper() for v in variants or ())) or [
        "TP", "FP", "TN",
    ]
    wanted_ids = set(episode_ids or ())

    # A negative limit previously sliced from the end (paths[:-1]);
    # clamp it so the limit always means "at most N".
    limit = None if max_episodes is None else max(0, max_episodes)
    if limit == 0:
        return []

    # Build the substring markers once instead of once per file.
    markers = tuple(f"_{v}_" for v in variant_list)

    paths: List[Path] = []
    for sq in sq_list:
        sq_dir = benchmark_dir / sq.lower()
        if not sq_dir.is_dir():
            continue
        for path in sorted(sq_dir.glob("*.json")):
            stem = path.stem
            if wanted_ids and stem not in wanted_ids:
                continue
            if not any(marker in stem for marker in markers):
                continue
            paths.append(path)
            if limit is not None and len(paths) >= limit:
                # Stop scanning as soon as the limit is reached.
                return paths
    return paths
def load_resume_state(results_file: Path) -> Tuple[list, Set[str]]:
    """Load previously written results from a JSON-lines file.

    Each non-blank line of ``results_file`` is parsed as one JSON object and
    must carry an ``"episode_id"`` key.

    Args:
        results_file: Path to the results file. It may not exist yet.

    Returns:
        A ``(rows, completed_ids)`` pair. ``rows`` holds the parsed objects in
        file order, and ``completed_ids`` is the set of their ``episode_id``
        values. Both are empty when the file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a parsed row has no ``"episode_id"`` key.
    """
    rows: list = []
    seen_ids: Set[str] = set()
    if results_file.exists():
        with results_file.open("r", encoding="utf-8") as handle:
            for raw in handle:
                # Skip blank/whitespace-only lines.
                if raw.strip():
                    record = json.loads(raw)
                    rows.append(record)
                    seen_ids.add(record["episode_id"])
    return rows, seen_ids