230 lines
9.1 KiB
Python
230 lines
9.1 KiB
Python
import json
|
|
import re
|
|
import time
|
|
from typing import Any, Dict, List
|
|
|
|
from src.evaluation.scorer import parse_model_response, score_episode
|
|
|
|
from EGPv3_1.prompts import (
|
|
build_defender_prompt,
|
|
build_extractor_prompt,
|
|
build_judge_prompt,
|
|
build_prosecutor_prompt,
|
|
)
|
|
from EGPv3_1.signals import build_case_material, materialize_chunks
|
|
|
|
|
|
def _extract_jsonish(text: str) -> Dict[str, Any]:
|
|
raw = text.strip()
|
|
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", raw, re.DOTALL)
|
|
if match:
|
|
raw = match.group(1).strip()
|
|
try:
|
|
return json.loads(raw)
|
|
except json.JSONDecodeError:
|
|
start = raw.find("{")
|
|
end = raw.rfind("}")
|
|
if start != -1 and end != -1 and end > start:
|
|
try:
|
|
return json.loads(raw[start : end + 1])
|
|
except json.JSONDecodeError:
|
|
pass
|
|
return {"_parse_failed": True, "raw_preview": text[:1000]}
|
|
|
|
|
|
def _dump_json(data: Dict[str, Any], fallback: str = "") -> str:
|
|
if not data or data.get("_parse_failed"):
|
|
return fallback
|
|
return json.dumps(data, ensure_ascii=False, indent=2)
|
|
|
|
|
|
def _normalize_extractor(parsed: Dict[str, Any]) -> Dict[str, Any]:
|
|
primary = str(parsed.get("primary_task_profile", "")).strip()
|
|
secondary = str(parsed.get("secondary_task_profile", "none")).strip() or "none"
|
|
if primary and "latent_task_profile" not in parsed:
|
|
parsed["latent_task_profile"] = primary if secondary == "none" else f"{primary} | {secondary}"
|
|
return parsed
|
|
|
|
|
|
def _infer_query_profile(query: str) -> str:
|
|
query = str(query or "")
|
|
if ("工作正常" in query) or ("故障类型" in query):
|
|
return "device-health"
|
|
if ("应急响应" in query) or ("严重程度" in query and "威胁类型" in query):
|
|
return "emergency-response"
|
|
if ("综合分析" in query and "安全状况" in query) or ("潜在风险" in query):
|
|
return "composite-safety"
|
|
if "异常行为模式" in query:
|
|
return "behavior-sequence"
|
|
if "安全威胁" in query:
|
|
return "single-event-safety"
|
|
return ""
|
|
|
|
|
|
def _apply_query_intent_guard(parsed: Dict[str, Any], query: str) -> Dict[str, Any]:
    """Reconcile the extractor's task profile with the intent of the query.

    When a profile can be inferred from ``query`` it is recorded under
    ``"query_intent_profile"`` and then:

    * if the extractor gave no primary profile, the inferred one is used;
    * if the extractor said ``"device-health"`` but the query implies a
      different profile, the inferred profile becomes primary,
      ``"device-health"`` becomes secondary (unless a secondary profile was
      already given) and a note is appended to ``"guardrail_notes"``.

    When nothing can be inferred, ``parsed`` is returned untouched.

    Args:
        parsed: Extractor output; updated in place.
        query: The user query the episode was built around.

    Returns:
        The same ``parsed`` object after normalization.
    """
    inferred = _infer_query_profile(query)
    if not inferred:
        return parsed

    parsed["query_intent_profile"] = inferred
    primary = str(parsed.get("primary_task_profile", "")).strip()
    secondary = str(parsed.get("secondary_task_profile", "none")).strip() or "none"

    if not primary:
        parsed["primary_task_profile"] = inferred
        return _normalize_extractor(parsed)

    if primary == "device-health" and inferred != "device-health":
        # _normalize_extractor only fills latent_task_profile when it is
        # missing, so a value derived from the pre-correction profiles would
        # stay stale. Drop it only if it matches that derived form; any other
        # value is left untouched.
        derived_before = primary if secondary == "none" else f"{primary} | {secondary}"
        if parsed.get("latent_task_profile") == derived_before:
            del parsed["latent_task_profile"]

        parsed["primary_task_profile"] = inferred
        parsed["secondary_task_profile"] = "device-health" if secondary in ("", "none") else secondary

        notes = parsed.get("guardrail_notes")
        if notes is None:
            # A missing or null field must not turn into the literal note "None".
            notes = []
        elif not isinstance(notes, list):
            notes = [str(notes)]
        notes.append("Primary task profile corrected from query intent to avoid device-health drift.")
        parsed["guardrail_notes"] = notes

    return _normalize_extractor(parsed)
|
|
|
|
|
|
def _valid_chunk_ids(material: Dict) -> List[str]:
|
|
return [chunk["chunk_id"] for chunk in material["chunk_index"]]
|
|
|
|
|
|
def _dedupe_chunk_ids(chunk_ids: List[str], valid_ids: List[str]) -> List[str]:
|
|
valid_set = set(valid_ids)
|
|
out: List[str] = []
|
|
for chunk_id in chunk_ids:
|
|
if chunk_id in valid_set and chunk_id not in out:
|
|
out.append(chunk_id)
|
|
return out
|
|
|
|
|
|
def _neighbor_chunk_ids(chunk_ids: List[str], valid_ids: List[str], radius: int = 1) -> List[str]:
|
|
order = {chunk_id: idx for idx, chunk_id in enumerate(valid_ids)}
|
|
neighbors: List[str] = []
|
|
for chunk_id in chunk_ids:
|
|
idx = order.get(chunk_id)
|
|
if idx is None:
|
|
continue
|
|
for offset in range(1, radius + 1):
|
|
for neighbor_idx in (idx - offset, idx + offset):
|
|
if 0 <= neighbor_idx < len(valid_ids):
|
|
neighbor_id = valid_ids[neighbor_idx]
|
|
if neighbor_id not in neighbors:
|
|
neighbors.append(neighbor_id)
|
|
return neighbors
|
|
|
|
|
|
def _select_focus_ids(material: Dict, extractor_parsed: Dict[str, Any], limit: int) -> List[str]:
    """Choose which chunk ids to show to the later stages.

    Starts from the extractor's ``"focus_chunk_ids"`` (ignored unless it is
    a list), filtered to known ids without duplicates. If that leaves
    nothing, the leading ids of the material are used instead. For the
    profiles listed below, the immediate neighbours of the chosen chunks
    are appended as well. The result is cut to at most ``limit`` ids.
    """
    valid_ids = _valid_chunk_ids(material)

    requested = extractor_parsed.get("focus_chunk_ids", [])
    focus_ids = _dedupe_chunk_ids(requested if isinstance(requested, list) else [], valid_ids)
    if not focus_ids:
        focus_ids = valid_ids[: min(limit, len(valid_ids))]

    profile = str(extractor_parsed.get("primary_task_profile", "")).strip()
    expand_for = (
        "single-event-safety",
        "behavior-sequence",
        "composite-safety",
        "emergency-response",
    )
    if profile in expand_for:
        nearby = _neighbor_chunk_ids(focus_ids, valid_ids, radius=1)
        focus_ids = _dedupe_chunk_ids(focus_ids + nearby, valid_ids)

    return focus_ids[:limit]
|
|
|
|
|
|
def evaluate_episode_with_egpv3(
    episode: Dict[str, Any],
    episode_path: str,
    client,
    chunk_size: int = 80,
    max_focus_chunks: int = 6,
    preview_only: bool = False,
) -> Dict[str, Any]:
    """Run one episode through the extractor/prosecutor/defender/judge chain.

    Stages, in order: the extractor reply is parsed and guarded against the
    query intent; focus chunks are selected and materialized; prosecutor
    and defender replies are collected; the judge reply is parsed with
    ``parse_model_response`` and scored with ``score_episode``.

    Any exception raised while the stages run is captured in ``api_error``
    and the episode is scored with a parse-failed placeholder instead of
    propagating.

    Args:
        episode: Episode dict; ``"metadata"`` and ``"ground_truth"`` are
            required, ``"episode_id"`` is optional.
        episode_path: Path echoed back in the result.
        client: Object whose ``chat(system, user)`` returns the model reply
            (presumably a string; it is stripped and sliced downstream).
        chunk_size: Forwarded to ``build_case_material``.
        max_focus_chunks: Upper bound on selected and materialized chunks.
        preview_only: When true, no model call is made; only the prompts
            are built so they can be inspected in the trace.

    Returns:
        A result dict with the raw and parsed judge response, scores,
        latency in seconds, ``api_error`` (``None`` on success) and an
        ``"egpv3_trace"`` holding per-stage outputs and prompt previews.
    """
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    material = build_case_material(episode, chunk_size=chunk_size)

    extractor_prompt = build_extractor_prompt(material)
    stage = {
        "extractor_raw": "",
        "extractor_parsed": {},
        "prosecutor_raw": "",
        "prosecutor_parsed": {},
        "defender_raw": "",
        "defender_parsed": {},
        "judge_prompt_preview": "",
    }
    raw_response = ""
    api_error = None
    # Monotonic clock: elapsed time must not be skewed by wall-clock changes.
    start = time.perf_counter()

    if preview_only:
        focus_ids = [c["chunk_id"] for c in material["chunk_index"][:4]]
        focused_chunks = materialize_chunks(material["chunks"], focus_ids, max_chunks=max_focus_chunks)
        judge_prompt = build_judge_prompt(material, "<preview>", "<preview>", "<preview>", focused_chunks)
        # Record the prompt; otherwise preview mode yields an empty preview.
        stage["judge_prompt_preview"] = judge_prompt["user"][:2500]
        parsed = {"is_anomaly": None, "_parse_failed": True}
    else:
        try:
            stage["extractor_raw"] = client.chat(extractor_prompt["system"], extractor_prompt["user"])
            stage["extractor_parsed"] = _apply_query_intent_guard(
                _normalize_extractor(_extract_jsonish(stage["extractor_raw"])),
                material.get("query", ""),
            )
            # Later stages get the normalized JSON, or the raw text if parsing failed.
            extractor_context = _dump_json(stage["extractor_parsed"], fallback=stage["extractor_raw"])

            focus_ids = _select_focus_ids(material, stage["extractor_parsed"], max_focus_chunks)
            focused_chunks = materialize_chunks(material["chunks"], focus_ids, max_chunks=max_focus_chunks)

            prosecutor_prompt = build_prosecutor_prompt(material, extractor_context, focused_chunks)
            stage["prosecutor_raw"] = client.chat(prosecutor_prompt["system"], prosecutor_prompt["user"])
            stage["prosecutor_parsed"] = _extract_jsonish(stage["prosecutor_raw"])

            defender_prompt = build_defender_prompt(
                material,
                extractor_context,
                focused_chunks,
                stage["prosecutor_raw"],
            )
            stage["defender_raw"] = client.chat(defender_prompt["system"], defender_prompt["user"])
            stage["defender_parsed"] = _extract_jsonish(stage["defender_raw"])

            judge_prompt = build_judge_prompt(
                material,
                extractor_context,
                stage["prosecutor_raw"],
                stage["defender_raw"],
                focused_chunks,
            )
            stage["judge_prompt_preview"] = judge_prompt["user"][:2500]
            raw_response = client.chat(judge_prompt["system"], judge_prompt["user"])
            parsed = parse_model_response(raw_response, mode="baseline")
        except Exception as exc:
            # Deliberate best-effort boundary: a failed stage must not abort
            # the run. str(exc) can be empty for message-less exceptions, so
            # fall back to repr to keep api_error truthy on failure.
            api_error = str(exc) or repr(exc)
            focused_chunks = ""
            judge_prompt = build_judge_prompt(
                material,
                stage["extractor_raw"],
                stage["prosecutor_raw"],
                stage["defender_raw"],
                focused_chunks,
            )
            stage["judge_prompt_preview"] = judge_prompt["user"][:2500]
            parsed = {"is_anomaly": None, "_parse_failed": True}

    latency = time.perf_counter() - start
    scores = score_episode(gt, parsed, meta.get("variant", ""))
    return {
        "episode_id": episode.get("episode_id", ""),
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw_response,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": api_error,
        "egpv3_trace": {
            "case_summary": {
                "event_count": material["event_count"],
                "chunk_count": len(material["chunk_index"]),
                "signals": material["signals"],
                "protocol_notes": material["protocol_notes"],
            },
            "extractor_raw": stage["extractor_raw"],
            "extractor_parsed": stage["extractor_parsed"],
            "prosecutor_raw": stage["prosecutor_raw"],
            "prosecutor_parsed": stage["prosecutor_parsed"],
            "defender_raw": stage["defender_raw"],
            "defender_parsed": stage["defender_parsed"],
            "preview_extractor_prompt": extractor_prompt["user"][:2500],
            "preview_judge_prompt": stage["judge_prompt_preview"],
        },
    }
|