95 lines
3.2 KiB
Python
95 lines
3.2 KiB
Python
import json
|
|
import re
|
|
import time
|
|
from typing import Any, Dict
|
|
|
|
from src.evaluation.scorer import parse_model_response, score_episode
|
|
|
|
from EGP.evidence import extract_evidence_packet
|
|
from EGP.prompts import build_stage1_prompt, build_stage2_prompt
|
|
|
|
|
|
def _extract_jsonish(text: str) -> Dict[str, Any]:
|
|
raw = text.strip()
|
|
block = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", raw, re.DOTALL)
|
|
if block:
|
|
raw = block.group(1).strip()
|
|
try:
|
|
return json.loads(raw)
|
|
except json.JSONDecodeError:
|
|
start = raw.find("{")
|
|
end = raw.rfind("}")
|
|
if start != -1 and end != -1 and end > start:
|
|
try:
|
|
return json.loads(raw[start : end + 1])
|
|
except json.JSONDecodeError:
|
|
pass
|
|
return {"_parse_failed": True, "raw_preview": text[:1000]}
|
|
|
|
|
|
def evaluate_episode_with_egp(
    episode: Dict[str, Any],
    episode_path: str,
    client,
    max_salient_events: int = 60,
    max_focus_events: int = 120,
    preview_only: bool = False,
) -> Dict[str, Any]:
    """Evaluate one episode with the two-stage EGP prompting pipeline.

    Steps:

    1. Extract an evidence packet from ``episode`` and build a stage-1
       prompt from it.
    2. Unless ``preview_only`` is set, send the stage-1 prompt to
       ``client``.
    3. Feed the raw stage-1 reply into the stage-2 prompt and send that.
    4. Parse the stage-2 reply with
       ``parse_model_response(..., mode="baseline")``.
    5. Score the result against the episode's ground truth.

    Args:
        episode: Episode dict. It is indexed for ``"metadata"`` and
            ``"ground_truth"``; ``"episode_id"`` is optional.
        episode_path: Path of the episode; copied verbatim into the result.
        client: Object exposing ``chat(system, user)`` (assumed to return
            the model's reply text -- confirm against the client class).
        max_salient_events: Forwarded to ``extract_evidence_packet``.
        max_focus_events: Forwarded to ``extract_evidence_packet``.
        preview_only: If true, both prompts are built but ``client`` is
            never called. The prediction is recorded as a parse failure.

    Returns:
        A result dict containing:

        - episode id, path, metadata and ground truth;
        - the raw stage-2 reply and the parsed prediction;
        - the scores;
        - latency in seconds;
        - ``api_error`` (``None`` on success, otherwise a non-empty message);
        - an ``egp_trace`` with packet statistics, the stage-1 reply (raw
          and parsed) and prompt previews truncated to 2000 characters.
    """
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    packet = extract_evidence_packet(
        episode,
        max_salient_events=max_salient_events,
        max_focus_events=max_focus_events,
    )

    stage1_prompt = build_stage1_prompt(packet)
    stage2_prompt = None
    stage1_raw = ""
    stage2_raw = ""
    stage1_parsed: Dict[str, Any] = {}
    api_error = None
    # perf_counter is monotonic, so the measured latency cannot be skewed
    # (or go negative) if the wall clock is adjusted during the calls.
    start = time.perf_counter()

    if preview_only:
        # Dry run: render the stage-2 template with a placeholder in place of
        # a real stage-1 reply; no model calls are made.
        stage2_prompt = build_stage2_prompt(packet, "<preview-only>")
        parsed = {"is_anomaly": None, "_parse_failed": True}
    else:
        try:
            stage1_raw = client.chat(stage1_prompt["system"], stage1_prompt["user"])
            stage1_parsed = _extract_jsonish(stage1_raw)
            # Stage 2 is built from the raw stage-1 text, not the parsed dict.
            stage2_prompt = build_stage2_prompt(packet, stage1_raw)
            stage2_raw = client.chat(stage2_prompt["system"], stage2_prompt["user"])
            parsed = parse_model_response(stage2_raw, mode="baseline")
        except Exception as exc:
            # Any failure is recorded in the result instead of being raised.
            # Exceptions created without a message (e.g. TimeoutError())
            # stringify to ""; fall back to the class name so api_error is
            # never an empty, falsy string that hides the failure.
            api_error = str(exc) or type(exc).__name__
            parsed = {"is_anomaly": None, "_parse_failed": True}

    latency = time.perf_counter() - start
    scores = score_episode(gt, parsed, meta.get("variant", ""))

    return {
        "episode_id": episode.get("episode_id", ""),
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": stage2_raw,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": api_error,
        "egp_trace": {
            "packet_summary": {
                "event_count": packet["event_count"],
                "salient_event_count": packet["salient_event_count"],
                "focus_event_count": packet["focus_event_count"],
                "temperature_patterns": packet["temperature_patterns"],
                "room_activity": packet["room_activity"],
                "top_suspicious_signals": packet["suspicious_signals"][:8],
            },
            "stage1_raw": stage1_raw,
            "stage1_parsed": stage1_parsed,
            "preview_stage1_prompt": stage1_prompt["user"][:2000],
            # stage2_prompt stays None when stage 1 failed before it was built.
            "preview_stage2_prompt": stage2_prompt["user"][:2000] if stage2_prompt else "",
        },
    }
|