45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
import json
|
|
from pathlib import Path
|
|
from typing import Iterable, List, Optional, Set, Tuple
|
|
|
|
|
|
def _as_str_list(values: Optional[Iterable[str]]) -> List[str]:
    """Materialize *values* into a list; a bare string counts as one item."""
    if values is None:
        return []
    if isinstance(values, str):
        # A str is itself an Iterable[str]; without this it would be split
        # into single characters.
        return [values]
    return list(values)


def collect_episode_paths(
    benchmark_dir: Path,
    sq_types: Optional[Iterable[str]] = None,
    variants: Optional[Iterable[str]] = None,
    max_episodes: Optional[int] = None,
    episode_ids: Optional[Iterable[str]] = None,
) -> List[Path]:
    """Collect episode JSON files, filtered by SQ type, variant and episode id.

    For each SQ type, the directory ``benchmark_dir / <sq type lower-cased>``
    is scanned for ``*.json`` files (sorted by path). A file is kept when its
    stem contains ``_<VARIANT>_`` for one of the requested variants.

    Args:
        benchmark_dir: Root directory holding one sub-directory per SQ type.
        sq_types: SQ types to include, case-insensitive. ``None`` or empty
            means ``SQ1``..``SQ5``. A bare string is treated as one value;
            duplicates are ignored.
        variants: Variant tags to include, case-insensitive. ``None`` or
            empty means ``TP``, ``FP`` and ``TN``. A bare string is treated
            as one value.
        max_episodes: If given, the result is truncated with
            ``paths[:max_episodes]``.
        episode_ids: If given and non-empty, only files whose stem is one of
            these ids are kept. A bare string is treated as one id.

    Returns:
        Matching paths, grouped by SQ type in the requested order and sorted
        within each SQ directory. SQ directories that do not exist are
        skipped.
    """
    # Materialize before testing emptiness: a generator is always truthy, so
    # an empty one must not bypass the defaults the way it used to.
    # dict.fromkeys de-duplicates while preserving the caller's order.
    sq_list = list(dict.fromkeys(s.upper() for s in _as_str_list(sq_types)))
    if not sq_list:
        sq_list = ["SQ1", "SQ2", "SQ3", "SQ4", "SQ5"]

    variant_list = [v.upper() for v in _as_str_list(variants)]
    if not variant_list:
        variant_list = ["TP", "FP", "TN"]
    # Variant tags are matched as "_TAG_" substrings of the file stem.
    markers = [f"_{v}_" for v in variant_list]

    # None or empty -> no id filtering.
    wanted_ids: Optional[Set[str]] = set(_as_str_list(episode_ids)) or None

    paths: List[Path] = []
    for sq in sq_list:
        sq_dir = benchmark_dir / sq.lower()
        if not sq_dir.is_dir():
            continue
        for path in sorted(sq_dir.glob("*.json")):
            if wanted_ids is not None and path.stem not in wanted_ids:
                continue
            if any(marker in path.stem for marker in markers):
                paths.append(path)

    if max_episodes is not None:
        paths = paths[:max_episodes]
    return paths
|
|
|
|
|
|
def load_resume_state(results_file: Path) -> Tuple[List[dict], Set[str]]:
    """Load previously written results so an interrupted run can resume.

    *results_file* is read as JSON Lines: one JSON object per line, each
    carrying an ``"episode_id"`` key. Blank lines are ignored.

    If the last non-blank line is not valid JSON, it is skipped. Such a line
    is most likely a partial write from an interrupted run. A malformed line
    anywhere else is treated as corruption and re-raised.

    Args:
        results_file: Path to the JSONL results file; it may not exist yet.

    Returns:
        ``(results, completed)``: the parsed rows in file order and the set of
        their ``episode_id`` values. ``([], set())`` if the file is missing.

    Raises:
        json.JSONDecodeError: If a line other than the last non-blank one is
            not valid JSON.
        KeyError: If a parsed row has no ``"episode_id"`` key.
    """
    try:
        with open(results_file, "r", encoding="utf-8") as f:
            lines = [line for line in f if line.strip()]
    except FileNotFoundError:
        # No previous run: nothing to resume.
        return [], set()

    results: List[dict] = []
    completed: Set[str] = set()
    last_index = len(lines) - 1
    for index, line in enumerate(lines):
        try:
            row = json.loads(line)
        except json.JSONDecodeError:
            if index == last_index:
                # NOTE(review): if the caller keeps appending to this file,
                # new rows will follow this fragment; consider rewriting the
                # file without it.
                break
            raise
        results.append(row)
        completed.add(row["episode_id"])
    return results, completed
|
|
|