"""EGPv2 evaluation workflow: triage, investigator, supervisor gate, verifier.

The helpers in this module parse loosely structured model output, pick which
chunks of a case to focus on, and turn a supervisor reply into a gate decision.
``evaluate_episode_with_egpv2`` wires them together and scores the result.
"""

import json
import re
import time
from typing import Any, Dict, List

from src.evaluation.scorer import parse_model_response, score_episode
from EGPv2_1.prompts import (
    build_investigator_prompt,
    build_supervisor_prompt,
    build_triage_prompt,
    build_verifier_prompt,
)
from EGPv2_1.signals import build_case_material, materialize_chunks

# Matches the body of the first Markdown code fence (optionally tagged "json").
_FENCE_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?\s*```", re.DOTALL)


def _extract_jsonish(text: str) -> Dict[str, Any]:
    """Best-effort extraction of a JSON object from a model reply.

    The reply is tried as-is (after unwrapping a Markdown code fence, if any),
    then the span from the first ``{`` to the last ``}`` is tried.

    Only JSON *objects* are accepted: a reply that parses to a list, string or
    number is treated like any other unusable reply, because every caller goes
    on to use ``dict`` methods on the result.

    Returns:
        The parsed dict, or ``{"_parse_failed": True, "raw_preview": ...}``
        (preview limited to the first 1000 characters of ``text``).
    """
    raw = text.strip()
    fenced = _FENCE_RE.search(raw)
    if fenced:
        raw = fenced.group(1).strip()

    candidates = [raw]
    start = raw.find("{")
    end = raw.rfind("}")
    if start != -1 and end > start:
        candidates.append(raw[start : end + 1])

    for candidate in candidates:
        try:
            value = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(value, dict):
            return value
    return {"_parse_failed": True, "raw_preview": text[:1000]}


def _normalize_triage(parsed: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in ``latent_task_profile`` from the primary/secondary profiles.

    Mutates and returns ``parsed``. An existing ``latent_task_profile`` is never
    overwritten, and nothing is added when the primary profile is empty.
    """
    primary = str(parsed.get("primary_task_profile", "")).strip()
    secondary = str(parsed.get("secondary_task_profile", "none")).strip()
    if primary and "latent_task_profile" not in parsed:
        if secondary in ("", "none"):
            parsed["latent_task_profile"] = primary
        else:
            parsed["latent_task_profile"] = f"{primary} | {secondary}"
    return parsed


def _dump_json(data: Dict[str, Any], fallback: str = "") -> str:
    """Serialize ``data`` as pretty JSON, or return ``fallback``.

    ``fallback`` is used when ``data`` is empty or flagged ``_parse_failed``.
    """
    if not data or data.get("_parse_failed"):
        return fallback
    return json.dumps(data, ensure_ascii=False, indent=2)


def _infer_query_profile(query: str) -> str:
    """Map keywords in the (Chinese) query text to a task profile name.

    Checks run in a fixed priority order; the first match wins. Returns an
    empty string when no keyword matches.
    """
    query = str(query or "")
    # "working normally" / "fault type"
    if ("工作正常" in query) or ("故障类型" in query):
        return "device-health"
    # "emergency response", or "severity" together with "threat type"
    if ("应急响应" in query) or ("严重程度" in query and "威胁类型" in query):
        return "emergency-response"
    # "comprehensive analysis" with "safety status", or "potential risk"
    if ("综合分析" in query and "安全状况" in query) or ("潜在风险" in query):
        return "composite-safety"
    # "abnormal behavior pattern"
    if "异常行为模式" in query:
        return "behavior-sequence"
    # "security threat"
    if "安全威胁" in query:
        return "single-event-safety"
    return ""


def _apply_query_intent_guard(parsed: Dict[str, Any], query: str) -> Dict[str, Any]:
    """Correct the triage profile when it disagrees with the query's intent.

    Mutates and returns ``parsed``:

    * no profile inferred from ``query``: returned untouched;
    * triage gave no primary profile: the inferred one is used;
    * triage said ``device-health`` but the query implies something else: the
      inferred profile becomes primary, ``device-health`` is kept as secondary
      (unless a secondary was already set) and a guardrail note is recorded.
    """
    inferred = _infer_query_profile(query)
    if not inferred:
        return parsed

    parsed["query_intent_profile"] = inferred
    primary = str(parsed.get("primary_task_profile", "")).strip()
    secondary = str(parsed.get("secondary_task_profile", "none")).strip() or "none"

    if not primary:
        parsed["primary_task_profile"] = inferred
        return _normalize_triage(parsed)

    if primary == "device-health" and inferred != "device-health":
        parsed["primary_task_profile"] = inferred
        parsed["secondary_task_profile"] = "device-health" if secondary in ("", "none") else secondary
        notes = parsed.get("guardrail_notes", [])
        if not isinstance(notes, list):
            notes = [str(notes)]
        notes.append("Primary task profile corrected from query intent to avoid telemetry-only drift.")
        parsed["guardrail_notes"] = notes
    return _normalize_triage(parsed)


def _valid_chunk_ids(material: Dict) -> List[str]:
    """Return the chunk ids listed in ``material["chunk_index"]``, in order."""
    return [chunk["chunk_id"] for chunk in material["chunk_index"]]


def _dedupe_chunk_ids(chunk_ids: List[str], valid_ids: List[str]) -> List[str]:
    """Keep the first occurrence of each id in ``chunk_ids`` that is valid.

    ``chunk_ids`` usually comes straight from model output, so entries that
    cannot be looked up at all (unhashable values such as dicts or lists) are
    skipped instead of aborting the episode.
    """
    valid_set = set(valid_ids)
    seen = set()
    deduped = []
    for chunk_id in chunk_ids:
        try:
            keep = chunk_id in valid_set and chunk_id not in seen
        except TypeError:
            # Unhashable entry: cannot be one of the valid ids.
            continue
        if keep:
            seen.add(chunk_id)
            deduped.append(chunk_id)
    return deduped


def _select_focus_ids(material: Dict, triage_parsed: Dict[str, Any], limit: int) -> List[str]:
    """Pick the first-round focus chunks from the triage output.

    Uses the valid, de-duplicated ``focus_chunk_ids`` from triage; falls back to
    the first two chunks of the case when triage supplied none. The result is
    truncated to ``limit`` entries.
    """
    valid_ids = _valid_chunk_ids(material)
    raw_ids = triage_parsed.get("focus_chunk_ids", [])
    if not isinstance(raw_ids, list):
        raw_ids = []
    focus_ids = _dedupe_chunk_ids(raw_ids, valid_ids)
    if not focus_ids:
        focus_ids = valid_ids[:2]
    return focus_ids[:limit]


def _merge_focus_ids(base_ids: List[str], extra_ids: List[str], valid_ids: List[str], limit: int) -> List[str]:
    """Append valid, not-yet-present ``extra_ids`` to ``base_ids`` (max ``limit``)."""
    merged = list(base_ids)
    for chunk_id in extra_ids:
        if chunk_id in valid_ids and chunk_id not in merged:
            merged.append(chunk_id)
    return merged[:limit]


def _chunk_order(chunk_id: str) -> int:
    """Return the trailing integer of ``chunk_id``, or -1 if it has none."""
    match = re.search(r"(\d+)$", str(chunk_id))
    return int(match.group(1)) if match else -1


def _neighbor_chunk_ids(chunk_ids: List[str], valid_ids: List[str], radius: int = 1) -> List[str]:
    """Return ids positioned within ``radius`` of each given chunk.

    Position is the index inside ``valid_ids``. Ids not found in ``valid_ids``
    are ignored. For each chunk the nearer neighbours come first (previous
    before next); duplicates are dropped. The input chunks themselves are not
    excluded if they happen to neighbour one another.
    """
    order = {chunk_id: idx for idx, chunk_id in enumerate(valid_ids)}
    neighbors: List[str] = []
    for chunk_id in chunk_ids:
        idx = order.get(chunk_id)
        if idx is None:
            continue
        for offset in range(1, radius + 1):
            for neighbor_idx in (idx - offset, idx + offset):
                if 0 <= neighbor_idx < len(valid_ids):
                    neighbor_id = valid_ids[neighbor_idx]
                    if neighbor_id not in neighbors:
                        neighbors.append(neighbor_id)
    return neighbors


def _build_second_round_focus(
    triage_parsed: Dict[str, Any],
    base_ids: List[str],
    extra_ids: List[str],
    valid_ids: List[str],
    limit: int,
) -> List[str]:
    """Order the chunks for the second investigation round.

    For the safety/response profiles the supervisor-requested chunks (and their
    neighbours) are ranked ahead of the first-round chunks; for every other
    profile the first-round chunks stay in front. The ranked list is
    de-duplicated, restricted to valid ids and cut to ``limit``.
    """
    primary = str(triage_parsed.get("primary_task_profile", "")).strip()
    extra_ids = _dedupe_chunk_ids(extra_ids, valid_ids)
    extra_neighbors = _neighbor_chunk_ids(extra_ids, valid_ids, radius=1)
    base_neighbors = _neighbor_chunk_ids(base_ids, valid_ids, radius=1)
    if primary in {"single-event-safety", "behavior-sequence", "composite-safety", "emergency-response"}:
        priority = extra_ids + extra_neighbors + base_ids + base_neighbors
    else:
        priority = base_ids + extra_ids + extra_neighbors + base_neighbors
    return _dedupe_chunk_ids(priority, valid_ids)[:limit]


def _supervisor_decision(parsed: Dict[str, Any], round_index: int) -> str:
    """Translate a parsed supervisor reply into a gate action.

    Returns one of ``"allow_final_verdict"``, ``"refine_investigation"`` or
    ``"abstain"``. An explicit, recognised ``recommended_action`` always wins;
    otherwise the action is derived from the other fields.
    """
    action = str(parsed.get("recommended_action", "")).strip().lower()
    if action in {"allow_final_verdict", "refine_investigation", "abstain"}:
        return action
    if parsed.get("_parse_failed"):
        # An unreadable supervisor reply must not block the final verdict.
        return "allow_final_verdict"
    if parsed.get("on_topic") is False:
        return "refine_investigation" if round_index == 1 else "abstain"
    if parsed.get("evidence_sufficient") is True:
        return "allow_final_verdict"
    risk = str(parsed.get("risk_of_false_alarm", "medium")).strip().lower()
    if round_index == 1:
        return "refine_investigation"
    # NOTE(review): this branch can never fire - `on_topic is False` has already
    # returned above - so `risk_of_false_alarm` currently has no effect on the
    # decision. Left as-is because the intended rule is not clear from here.
    if risk == "high" and parsed.get("on_topic") is False:
        return "abstain"
    return "allow_final_verdict"


def _as_list(value: Any) -> List[Any]:
    """Coerce a model-supplied field to a list (None -> [], scalar -> [scalar])."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def _build_abstain_response(stage: Dict[str, Any], reason: str) -> Dict[str, Any]:
    """Build the conservative non-anomaly verdict used when the gate abstains.

    Up to two ``supervisor_notes`` and two ``missing_checks`` from the latest
    supervisor round are folded into the reasoning and recommended actions.
    Those fields come from model output, so they are coerced to lists first: a
    bare string is kept whole and ``null`` is treated as empty.
    """
    triage = stage.get("triage_parsed", {})
    if not isinstance(triage, dict):
        triage = {}
    rounds = stage.get("supervisor_rounds")
    supervisor = rounds[-1].get("parsed", {}) if rounds else {}
    if not isinstance(supervisor, dict):
        supervisor = {}

    notes = _as_list(supervisor.get("supervisor_notes", []))[:2]
    missing = _as_list(supervisor.get("missing_checks", []))[:2]

    reasoning = [
        f"Supervisor gate enforced: {reason}.",
        f"Primary task profile: {triage.get('primary_task_profile') or triage.get('latent_task_profile', 'unknown')}.",
    ]
    reasoning.extend(str(note) for note in notes)

    actions = ["Gather the missing evidence before issuing a high-confidence anomaly verdict."]
    actions.extend(str(item) for item in missing)

    return {
        "is_anomaly": False,
        "confidence": "low",
        "threat_type": "none",
        "threat_description": "Evidence remains insufficient for a supported anomaly conclusion, so the workflow defaults to a conservative non-anomaly result.",
        "reasoning": reasoning,
        "key_evidence": [],
        "recommended_actions": actions[:3],
        "_supervisor_abstain": True,
    }


def evaluate_episode_with_egpv2(
    episode: Dict[str, Any],
    episode_path: str,
    client,
    chunk_size: int = 80,
    max_focus_chunks: int = 5,
    preview_only: bool = False,
) -> Dict[str, Any]:
    """Run one episode through the EGPv2 workflow and score the outcome.

    Workflow (skipped entirely when ``preview_only`` is true):

    1. triage prompt -> task profile and focus chunks;
    2. investigator and supervisor, round 1;
    3. if the supervisor asks for refinement, a second investigator/supervisor
       round on a widened chunk selection;
    4. either a synthetic abstain verdict or a final verifier call.

    Args:
        episode: Episode dict; must contain ``"metadata"`` and ``"ground_truth"``.
        episode_path: Echoed back in the result.
        client: Object exposing ``chat(system, user)`` and returning the reply text.
        chunk_size: Forwarded to ``build_case_material``.
        max_focus_chunks: Round 1 uses at most ``min(max_focus_chunks, 4)``
            chunks, round 2 at most ``max_focus_chunks + 2``.
        preview_only: Build the prompts only; no model call is made and the
            returned verdict is a parse-failed placeholder.

    Returns:
        A result dict with the verdict, scores, latency, any captured error
        (``api_error``) and a full ``egpv2_trace`` of the intermediate stages.
        Any exception raised during the model workflow is captured in
        ``api_error`` rather than propagated.
    """
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    material = build_case_material(episode, chunk_size=chunk_size)
    triage_prompt = build_triage_prompt(material)

    stage = {
        "triage_raw": "",
        "triage_parsed": {},
        "investigator_raw": "",
        "investigator_parsed": {},
        "supervisor_raw": "",
        "supervisor_parsed": {},
        "investigator_rounds": [],
        "supervisor_rounds": [],
        "supervisor_gate": {"final_action": "", "abstained": False},
    }

    raw_response = ""
    api_error = None
    start = time.time()

    if preview_only:
        focused_chunks = materialize_chunks(material["chunks"], [c["chunk_id"] for c in material["chunk_index"][:3]])
        verifier_prompt = build_verifier_prompt(material, "", "", "", focused_chunks)
        parsed = {"is_anomaly": None, "_parse_failed": True}
    else:
        try:
            valid_chunk_ids = _valid_chunk_ids(material)
            first_round_limit = min(max_focus_chunks, 4)
            second_round_limit = max_focus_chunks + 2

            # --- Triage -------------------------------------------------
            stage["triage_raw"] = client.chat(triage_prompt["system"], triage_prompt["user"])
            stage["triage_parsed"] = _apply_query_intent_guard(
                _normalize_triage(_extract_jsonish(stage["triage_raw"])),
                material.get("query", ""),
            )
            # Later prompts get the cleaned JSON, or the raw text if unparsable.
            triage_context = _dump_json(stage["triage_parsed"], fallback=stage["triage_raw"])

            # --- Round 1 ------------------------------------------------
            focus_ids = _select_focus_ids(material, stage["triage_parsed"], first_round_limit)
            focused_chunks = materialize_chunks(material["chunks"], focus_ids, max_chunks=first_round_limit)
            investigator_prompt = build_investigator_prompt(
                material,
                triage_context,
                focused_chunks,
                round_index=1,
            )
            investigator_raw = client.chat(investigator_prompt["system"], investigator_prompt["user"])
            investigator_parsed = _extract_jsonish(investigator_raw)
            stage["investigator_rounds"].append(
                {"round": 1, "focus_ids": focus_ids, "raw": investigator_raw, "parsed": investigator_parsed}
            )

            supervisor_prompt = build_supervisor_prompt(
                material,
                triage_context,
                investigator_raw,
                focused_chunks,
                round_index=1,
            )
            supervisor_raw = client.chat(supervisor_prompt["system"], supervisor_prompt["user"])
            supervisor_parsed = _extract_jsonish(supervisor_raw)
            stage["supervisor_rounds"].append({"round": 1, "raw": supervisor_raw, "parsed": supervisor_parsed})
            decision = _supervisor_decision(supervisor_parsed, round_index=1)

            final_focus_ids = list(focus_ids)
            final_investigator_raw = investigator_raw
            final_investigator_parsed = investigator_parsed
            final_supervisor_raw = supervisor_raw
            final_supervisor_parsed = supervisor_parsed

            # --- Round 2 (only when the supervisor asks for refinement) --
            if decision == "refine_investigation":
                extra_focus = supervisor_parsed.get("needs_more_chunks", [])
                if not isinstance(extra_focus, list):
                    extra_focus = []
                final_focus_ids = _build_second_round_focus(
                    stage["triage_parsed"],
                    focus_ids,
                    extra_focus,
                    valid_chunk_ids,
                    second_round_limit,
                )
                if not final_focus_ids:
                    final_focus_ids = _merge_focus_ids(focus_ids, extra_focus, valid_chunk_ids, second_round_limit)
                focused_chunks = materialize_chunks(material["chunks"], final_focus_ids, max_chunks=second_round_limit)

                investigator_prompt = build_investigator_prompt(
                    material,
                    triage_context,
                    focused_chunks,
                    supervisor_text=supervisor_raw,
                    round_index=2,
                )
                final_investigator_raw = client.chat(investigator_prompt["system"], investigator_prompt["user"])
                final_investigator_parsed = _extract_jsonish(final_investigator_raw)
                stage["investigator_rounds"].append(
                    {
                        "round": 2,
                        "focus_ids": final_focus_ids,
                        "raw": final_investigator_raw,
                        "parsed": final_investigator_parsed,
                    }
                )

                supervisor_prompt = build_supervisor_prompt(
                    material,
                    triage_context,
                    final_investigator_raw,
                    focused_chunks,
                    round_index=2,
                )
                final_supervisor_raw = client.chat(supervisor_prompt["system"], supervisor_prompt["user"])
                final_supervisor_parsed = _extract_jsonish(final_supervisor_raw)
                stage["supervisor_rounds"].append(
                    {"round": 2, "raw": final_supervisor_raw, "parsed": final_supervisor_parsed}
                )
                decision = _supervisor_decision(final_supervisor_parsed, round_index=2)

            stage["investigator_raw"] = final_investigator_raw
            stage["investigator_parsed"] = final_investigator_parsed
            stage["supervisor_raw"] = final_supervisor_raw
            stage["supervisor_parsed"] = final_supervisor_parsed
            stage["supervisor_gate"]["final_action"] = decision

            # --- Gate and final verdict ---------------------------------
            abstained = decision == "abstain"
            if abstained:
                stage["supervisor_gate"]["abstained"] = True
                parsed = _build_abstain_response(stage, reason=decision or "high false-alarm risk")

            # The verifier prompt is built in both cases so the trace always
            # carries a preview; the verifier is only *called* when not abstaining.
            verifier_prompt = build_verifier_prompt(
                material,
                triage_context,
                stage["investigator_raw"],
                stage["supervisor_raw"],
                focused_chunks,
            )
            if not abstained:
                raw_response = client.chat(verifier_prompt["system"], verifier_prompt["user"])
                parsed = parse_model_response(raw_response, mode="baseline")
        except Exception as exc:
            # Deliberate catch-all at the episode boundary: one failing episode
            # is reported via ``api_error`` instead of aborting the whole run.
            api_error = str(exc)
            focused_chunks = ""
            verifier_prompt = build_verifier_prompt(
                material,
                stage["triage_raw"],
                stage["investigator_raw"],
                stage["supervisor_raw"],
                focused_chunks,
            )
            parsed = {"is_anomaly": None, "_parse_failed": True}

    latency = time.time() - start
    scores = score_episode(gt, parsed, meta.get("variant", ""))

    return {
        "episode_id": episode.get("episode_id", ""),
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw_response,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": api_error,
        "egpv2_trace": {
            "case_summary": {
                "event_count": material["event_count"],
                "chunk_count": len(material["chunk_index"]),
                "signals": material["signals"],
                "protocol_notes": material["protocol_notes"],
            },
            "triage_raw": stage["triage_raw"],
            "triage_parsed": stage["triage_parsed"],
            "investigator_raw": stage["investigator_raw"],
            "investigator_parsed": stage["investigator_parsed"],
            "supervisor_raw": stage["supervisor_raw"],
            "supervisor_parsed": stage["supervisor_parsed"],
            "investigator_rounds": stage["investigator_rounds"],
            "supervisor_rounds": stage["supervisor_rounds"],
            "supervisor_gate": stage["supervisor_gate"],
            "preview_triage_prompt": triage_prompt["user"][:2500],
            "preview_verifier_prompt": verifier_prompt["user"][:2500],
        },
    }