"""Multi-stage episode evaluation: extractor -> prosecutor -> defender -> judge.

Each stage is one ``client.chat(system, user)`` call whose prompt is produced
by the corresponding ``build_*_prompt`` helper.  The helpers in this module
parse the (possibly malformed) JSON the stages return, pick which chunks of
the case material to show to the later stages, and assemble the final
result record.
"""

import json
import re
import time
from typing import Any, Dict, List

from src.evaluation.scorer import parse_model_response, score_episode

from EGPv3_2.prompts import (
    build_defender_prompt,
    build_extractor_prompt,
    build_judge_prompt,
    build_prosecutor_prompt,
)
from EGPv3_2.signals import build_case_material, materialize_chunks

# Body of the first Markdown code fence (optionally tagged ``json``).
_FENCED_BLOCK_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?\s*```", re.DOTALL)

# Primary task profiles for which the focus set is widened with adjacent chunks.
_NEIGHBOR_EXPANSION_PROFILES = frozenset(
    {
        "single-event-safety",
        "behavior-sequence",
        "composite-safety",
        "emergency-response",
    }
)

_PROMPT_PREVIEW_CHARS = 2500  # max characters of a prompt kept in the trace
_RAW_PREVIEW_CHARS = 1000  # max characters of unparseable text kept
_PREVIEW_FOCUS_CHUNKS = 4  # chunks used when only previewing prompts


def _extract_jsonish(text: str) -> Dict[str, Any]:
    """Best-effort extraction of a JSON object from a model reply.

    Tries, in order: the contents of the first fenced code block (or the whole
    stripped text when there is no fence), then the substring between the
    first ``{`` and the last ``}`` of that text.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON object, or
        ``{"_parse_failed": True, "raw_preview": text[:1000]}`` when no JSON
        *object* could be recovered.
    """
    raw = text.strip()
    fenced = _FENCED_BLOCK_RE.search(raw)
    if fenced:
        raw = fenced.group(1).strip()

    candidates = [raw]
    first_brace = raw.find("{")
    last_brace = raw.rfind("}")
    if first_brace != -1 and last_brace > first_brace:
        braced = raw[first_brace : last_brace + 1]
        if braced != raw:
            candidates.append(braced)

    for candidate in candidates:
        try:
            decoded = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # json.loads also accepts arrays, strings, numbers and null.  Callers
        # use dict methods on the result, so anything else is a parse failure.
        if isinstance(decoded, dict):
            return decoded

    return {"_parse_failed": True, "raw_preview": text[:_RAW_PREVIEW_CHARS]}


def _dump_json(data: Dict[str, Any], fallback: str = "") -> str:
    """Serialize ``data`` as indented JSON, or return ``fallback``.

    ``fallback`` is returned when ``data`` is empty or carries a truthy
    ``"_parse_failed"`` marker.
    """
    if not data or data.get("_parse_failed"):
        return fallback
    return json.dumps(data, ensure_ascii=False, indent=2)


def _normalize_extractor(parsed: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in ``latent_task_profile`` from the primary/secondary profiles.

    The key is only added when a non-empty primary profile exists and the key
    is not already present.  ``parsed`` is modified in place and returned.
    """
    primary = str(parsed.get("primary_task_profile", "")).strip()
    secondary = str(parsed.get("secondary_task_profile", "none")).strip() or "none"
    if primary and "latent_task_profile" not in parsed:
        if secondary == "none":
            parsed["latent_task_profile"] = primary
        else:
            parsed["latent_task_profile"] = f"{primary} | {secondary}"
    return parsed


def _infer_query_profile(query: str) -> str:
    """Map a query to a task profile via keyword matching.

    Rules are checked in order and the first match wins.  Returns ``""`` when
    no rule matches (including for ``None`` or empty queries).
    """
    query = str(query or "")

    def has(*keywords: str) -> bool:
        return all(keyword in query for keyword in keywords)

    if has("工作正常") or has("故障类型"):
        return "device-health"
    if has("应急响应") or has("严重程度", "威胁类型"):
        return "emergency-response"
    if has("综合分析", "安全状况") or has("潜在风险"):
        return "composite-safety"
    if has("异常行为模式"):
        return "behavior-sequence"
    if has("安全威胁"):
        return "single-event-safety"
    return ""


def _apply_query_intent_guard(parsed: Dict[str, Any], query: str) -> Dict[str, Any]:
    """Reconcile the extractor's task profile with the one implied by the query.

    When a profile can be inferred from ``query`` it is recorded under
    ``"query_intent_profile"``.  A missing primary profile is replaced by the
    inferred one.  A primary of ``"device-health"`` that disagrees with the
    inferred profile is overridden: the inferred profile becomes primary,
    ``"device-health"`` becomes secondary unless a secondary already exists,
    and a note is appended to ``"guardrail_notes"``.

    ``parsed`` is modified in place and returned.
    """
    inferred = _infer_query_profile(query)
    if not inferred:
        return parsed

    parsed["query_intent_profile"] = inferred
    primary = str(parsed.get("primary_task_profile", "")).strip()
    secondary = str(parsed.get("secondary_task_profile", "none")).strip() or "none"

    if not primary:
        parsed["primary_task_profile"] = inferred
        return _normalize_extractor(parsed)

    if primary == "device-health" and inferred != "device-health":
        parsed["primary_task_profile"] = inferred
        if secondary in ("", "none"):
            parsed["secondary_task_profile"] = "device-health"
        else:
            parsed["secondary_task_profile"] = secondary
        notes = parsed.get("guardrail_notes", [])
        if not isinstance(notes, list):
            notes = [str(notes)]
        notes.append(
            "Primary task profile corrected from query intent to avoid device-health drift."
        )
        parsed["guardrail_notes"] = notes

    # NOTE(review): _normalize_extractor never overwrites an existing
    # "latent_task_profile", so a value derived before the correction above
    # still names the old primary profile -- confirm whether that is intended.
    return _normalize_extractor(parsed)


def _valid_chunk_ids(material: Dict) -> List[str]:
    """Return the ``chunk_id`` of every entry in ``material["chunk_index"]``."""
    return [chunk["chunk_id"] for chunk in material["chunk_index"]]


def _dedupe_chunk_ids(chunk_ids: List[str], valid_ids: List[str]) -> List[str]:
    """Keep the first occurrence of each id that is also in ``valid_ids``.

    Order of ``chunk_ids`` is preserved.  Unknown ids are dropped, and so are
    unhashable entries (``chunk_ids`` may come straight from model output and
    contain, for example, dicts).
    """
    valid_set = set(valid_ids)
    seen = set()
    unique: List[str] = []
    for chunk_id in chunk_ids:
        try:
            keep = chunk_id in valid_set and chunk_id not in seen
        except TypeError:
            # Unhashable value: it cannot be a valid chunk id.
            continue
        if keep:
            seen.add(chunk_id)
            unique.append(chunk_id)
    return unique


def _neighbor_chunk_ids(chunk_ids: List[str], valid_ids: List[str], radius: int = 1) -> List[str]:
    """Return ids located within ``radius`` positions of any id in ``chunk_ids``.

    Positions refer to ``valid_ids``.  For each input id, neighbours are
    visited nearest first, the lower index before the higher one.  The result
    has no duplicates but may contain ids that are themselves in
    ``chunk_ids``; ids not found in ``valid_ids`` are ignored.
    """
    position = {chunk_id: idx for idx, chunk_id in enumerate(valid_ids)}
    total = len(valid_ids)
    seen = set()
    neighbors: List[str] = []
    for chunk_id in chunk_ids:
        idx = position.get(chunk_id)
        if idx is None:
            continue
        for offset in range(1, radius + 1):
            for neighbor_idx in (idx - offset, idx + offset):
                if not 0 <= neighbor_idx < total:
                    continue
                neighbor_id = valid_ids[neighbor_idx]
                if neighbor_id not in seen:
                    seen.add(neighbor_id)
                    neighbors.append(neighbor_id)
    return neighbors


def _select_focus_ids(material: Dict, extractor_parsed: Dict[str, Any], limit: int) -> List[str]:
    """Choose at most ``limit`` chunk ids to show to the later stages.

    Uses the extractor's ``"focus_chunk_ids"`` (validated and de-duplicated),
    falling back to the first ``limit`` chunks when none are usable.  For the
    profiles in ``_NEIGHBOR_EXPANSION_PROFILES`` the directly adjacent chunks
    are appended after the focus ids before truncating to ``limit``.
    """
    # A negative limit would otherwise slice from the end of the list.
    limit = max(limit, 0)
    valid_ids = _valid_chunk_ids(material)

    raw_ids = extractor_parsed.get("focus_chunk_ids", [])
    if not isinstance(raw_ids, list):
        raw_ids = []

    focus_ids = _dedupe_chunk_ids(raw_ids, valid_ids)
    if not focus_ids:
        focus_ids = valid_ids[:limit]

    primary = str(extractor_parsed.get("primary_task_profile", "")).strip()
    if primary in _NEIGHBOR_EXPANSION_PROFILES:
        expanded = focus_ids + _neighbor_chunk_ids(focus_ids, valid_ids, radius=1)
        focus_ids = _dedupe_chunk_ids(expanded, valid_ids)

    return focus_ids[:limit]


def evaluate_episode_with_egpv3(
    episode: Dict[str, Any],
    episode_path: str,
    client,
    chunk_size: int = 80,
    max_focus_chunks: int = 6,
    preview_only: bool = False,
) -> Dict[str, Any]:
    """Run the extractor/prosecutor/defender/judge pipeline on one episode.

    Args:
        episode: Episode record; must contain ``"metadata"`` and
            ``"ground_truth"``.  ``"episode_id"`` is optional.
        episode_path: Copied verbatim into the result.
        client: Object exposing ``chat(system_prompt, user_prompt)``; its
            return value is treated as the stage's raw text.
        chunk_size: Forwarded to ``build_case_material``.
        max_focus_chunks: Upper bound on the chunks handed to the prosecutor,
            defender and judge stages.
        preview_only: When true, no model call is made; only the prompts are
            built and an unparsed placeholder response is scored.

    Returns:
        A dict with the episode identity, ground truth, the judge's raw and
        parsed response, the scores, the elapsed time in seconds, an
        ``"api_error"`` string (``None`` on success) and an ``"egpv3_trace"``
        holding the per-stage outputs and prompt previews.

    Raises:
        KeyError: If ``episode`` lacks ``"metadata"`` or ``"ground_truth"``.
        Exceptions raised while building the case material or the extractor
        prompt, or while scoring, are not caught here.  Failures inside the
        four-stage pipeline are caught and reported through ``"api_error"``.
    """
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    material = build_case_material(episode, chunk_size=chunk_size)
    extractor_prompt = build_extractor_prompt(material)

    stage = {
        "extractor_raw": "",
        "extractor_parsed": {},
        "prosecutor_raw": "",
        "prosecutor_parsed": {},
        "defender_raw": "",
        "defender_parsed": {},
        "judge_prompt_preview": "",
    }
    unparsed = {"is_anomaly": None, "_parse_failed": True}
    raw_response = ""
    api_error = None

    # Monotonic clock: the measured duration is immune to wall-clock changes.
    start = time.perf_counter()

    if preview_only:
        focus_ids = [
            chunk["chunk_id"] for chunk in material["chunk_index"][:_PREVIEW_FOCUS_CHUNKS]
        ]
        focused_chunks = materialize_chunks(
            material["chunks"], focus_ids, max_chunks=max_focus_chunks
        )
        judge_prompt = build_judge_prompt(material, "", "", "", focused_chunks)
        # Record the prompt; otherwise the trace's "preview_judge_prompt"
        # stays empty in the one mode whose purpose is previewing prompts.
        stage["judge_prompt_preview"] = judge_prompt["user"][:_PROMPT_PREVIEW_CHARS]
        parsed = unparsed
    else:
        try:
            # Stage 1: extractor.
            stage["extractor_raw"] = client.chat(
                extractor_prompt["system"], extractor_prompt["user"]
            )
            stage["extractor_parsed"] = _apply_query_intent_guard(
                _normalize_extractor(_extract_jsonish(stage["extractor_raw"])),
                material.get("query", ""),
            )
            # Later stages get the cleaned JSON, or the raw text if parsing failed.
            extractor_context = _dump_json(
                stage["extractor_parsed"], fallback=stage["extractor_raw"]
            )
            focus_ids = _select_focus_ids(material, stage["extractor_parsed"], max_focus_chunks)
            focused_chunks = materialize_chunks(
                material["chunks"], focus_ids, max_chunks=max_focus_chunks
            )

            # Stage 2: prosecutor.
            prosecutor_prompt = build_prosecutor_prompt(
                material, extractor_context, focused_chunks
            )
            stage["prosecutor_raw"] = client.chat(
                prosecutor_prompt["system"], prosecutor_prompt["user"]
            )
            stage["prosecutor_parsed"] = _extract_jsonish(stage["prosecutor_raw"])

            # Stage 3: defender, which also sees the prosecutor's raw output.
            defender_prompt = build_defender_prompt(
                material,
                extractor_context,
                focused_chunks,
                stage["prosecutor_raw"],
            )
            stage["defender_raw"] = client.chat(
                defender_prompt["system"], defender_prompt["user"]
            )
            stage["defender_parsed"] = _extract_jsonish(stage["defender_raw"])

            # Stage 4: judge.
            judge_prompt = build_judge_prompt(
                material,
                extractor_context,
                stage["prosecutor_raw"],
                stage["defender_raw"],
                focused_chunks,
            )
            stage["judge_prompt_preview"] = judge_prompt["user"][:_PROMPT_PREVIEW_CHARS]
            raw_response = client.chat(judge_prompt["system"], judge_prompt["user"])
            parsed = parse_model_response(raw_response, mode="baseline")
        except Exception as exc:
            # Best effort: one failing episode must not abort a whole run.
            # str(exc) is empty for exceptions raised without a message, so
            # fall back to repr to keep api_error truthy whenever we fail.
            api_error = str(exc) or repr(exc)
            focused_chunks = ""
            # Rebuild the judge prompt from whatever stages completed so the
            # trace still shows how far the pipeline got.
            judge_prompt = build_judge_prompt(
                material,
                stage["extractor_raw"],
                stage["prosecutor_raw"],
                stage["defender_raw"],
                focused_chunks,
            )
            stage["judge_prompt_preview"] = judge_prompt["user"][:_PROMPT_PREVIEW_CHARS]
            parsed = unparsed

    latency = time.perf_counter() - start
    scores = score_episode(gt, parsed, meta.get("variant", ""))

    return {
        "episode_id": episode.get("episode_id", ""),
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw_response,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": api_error,
        "egpv3_trace": {
            "case_summary": {
                "event_count": material["event_count"],
                "chunk_count": len(material["chunk_index"]),
                "signals": material["signals"],
                "protocol_notes": material["protocol_notes"],
            },
            "extractor_raw": stage["extractor_raw"],
            "extractor_parsed": stage["extractor_parsed"],
            "prosecutor_raw": stage["prosecutor_raw"],
            "prosecutor_parsed": stage["prosecutor_parsed"],
            "defender_raw": stage["defender_raw"],
            "defender_parsed": stage["defender_parsed"],
            "preview_extractor_prompt": extractor_prompt["user"][:_PROMPT_PREVIEW_CHARS],
            "preview_judge_prompt": stage["judge_prompt_preview"],
        },
    }