Files
llmiotsafe/EGPv3_2/pipeline.py
2026-05-12 17:01:39 +08:00

230 lines
9.1 KiB
Python

import json
import re
import time
from typing import Any, Dict, List
from src.evaluation.scorer import parse_model_response, score_episode
from EGPv3_2.prompts import (
build_defender_prompt,
build_extractor_prompt,
build_judge_prompt,
build_prosecutor_prompt,
)
from EGPv3_2.signals import build_case_material, materialize_chunks
def _extract_jsonish(text: str) -> Dict[str, Any]:
raw = text.strip()
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", raw, re.DOTALL)
if match:
raw = match.group(1).strip()
try:
return json.loads(raw)
except json.JSONDecodeError:
start = raw.find("{")
end = raw.rfind("}")
if start != -1 and end != -1 and end > start:
try:
return json.loads(raw[start : end + 1])
except json.JSONDecodeError:
pass
return {"_parse_failed": True, "raw_preview": text[:1000]}
def _dump_json(data: Dict[str, Any], fallback: str = "") -> str:
if not data or data.get("_parse_failed"):
return fallback
return json.dumps(data, ensure_ascii=False, indent=2)
def _normalize_extractor(parsed: Dict[str, Any]) -> Dict[str, Any]:
primary = str(parsed.get("primary_task_profile", "")).strip()
secondary = str(parsed.get("secondary_task_profile", "none")).strip() or "none"
if primary and "latent_task_profile" not in parsed:
parsed["latent_task_profile"] = primary if secondary == "none" else f"{primary} | {secondary}"
return parsed
def _infer_query_profile(query: str) -> str:
query = str(query or "")
if ("工作正常" in query) or ("故障类型" in query):
return "device-health"
if ("应急响应" in query) or ("严重程度" in query and "威胁类型" in query):
return "emergency-response"
if ("综合分析" in query and "安全状况" in query) or ("潜在风险" in query):
return "composite-safety"
if "异常行为模式" in query:
return "behavior-sequence"
if "安全威胁" in query:
return "single-event-safety"
return ""
def _apply_query_intent_guard(parsed: Dict[str, Any], query: str) -> Dict[str, Any]:
    """Reconcile the extractor's task profile with the intent of the query.

    The profile inferred from ``query`` (via ``_infer_query_profile``) is
    recorded under ``"query_intent_profile"``. It then:

    * becomes the primary profile when the extractor supplied none;
    * replaces a ``"device-health"`` primary profile when the query points to
      a different profile. ``"device-health"`` is demoted to the secondary
      slot unless another secondary profile is already set, and a note is
      appended to ``"guardrail_notes"``.

    If no profile can be inferred, ``parsed`` is returned untouched.

    Args:
        parsed: Extractor output; mutated in place.
        query: The user query text.

    Returns:
        The same ``parsed`` dict, passed through ``_normalize_extractor``
        whenever a profile was inferred.
    """
    inferred = _infer_query_profile(query)
    if not inferred:
        return parsed
    parsed["query_intent_profile"] = inferred

    def _text(key: str) -> str:
        # A JSON null must read as "missing", not as the literal text "None";
        # otherwise a null primary profile would never be filled in below.
        value = parsed.get(key)
        return "" if value is None else str(value).strip()

    primary = _text("primary_task_profile")
    secondary = _text("secondary_task_profile") or "none"

    if not primary:
        parsed["primary_task_profile"] = inferred
        return _normalize_extractor(parsed)

    if primary == "device-health" and inferred != "device-health":
        # The normalizer never overwrites an existing latent profile. If the
        # current one was merely derived from the profiles being replaced,
        # drop it so it is rebuilt from the corrected values. A latent
        # profile with any other content is left alone.
        derived_latent = primary if secondary == "none" else f"{primary} | {secondary}"
        if parsed.get("latent_task_profile") == derived_latent:
            del parsed["latent_task_profile"]

        parsed["primary_task_profile"] = inferred
        parsed["secondary_task_profile"] = "device-health" if secondary == "none" else secondary

        notes = parsed.get("guardrail_notes")
        if notes is None:
            notes = []
        elif not isinstance(notes, list):
            notes = [str(notes)]
        notes.append("Primary task profile corrected from query intent to avoid device-health drift.")
        parsed["guardrail_notes"] = notes

    return _normalize_extractor(parsed)
def _valid_chunk_ids(material: Dict) -> List[str]:
return [chunk["chunk_id"] for chunk in material["chunk_index"]]
def _dedupe_chunk_ids(chunk_ids: List[str], valid_ids: List[str]) -> List[str]:
valid_set = set(valid_ids)
out: List[str] = []
for chunk_id in chunk_ids:
if chunk_id in valid_set and chunk_id not in out:
out.append(chunk_id)
return out
def _neighbor_chunk_ids(chunk_ids: List[str], valid_ids: List[str], radius: int = 1) -> List[str]:
order = {chunk_id: idx for idx, chunk_id in enumerate(valid_ids)}
neighbors: List[str] = []
for chunk_id in chunk_ids:
idx = order.get(chunk_id)
if idx is None:
continue
for offset in range(1, radius + 1):
for neighbor_idx in (idx - offset, idx + offset):
if 0 <= neighbor_idx < len(valid_ids):
neighbor_id = valid_ids[neighbor_idx]
if neighbor_id not in neighbors:
neighbors.append(neighbor_id)
return neighbors
def _select_focus_ids(material: Dict, extractor_parsed: Dict[str, Any], limit: int) -> List[str]:
    """Choose at most ``limit`` chunk ids to focus on.

    The ids requested under ``extractor_parsed["focus_chunk_ids"]`` are used
    after filtering and de-duplication; a non-list value is treated as no
    request. If nothing usable remains, the leading chunks of the material
    are used instead. For the safety/behaviour/emergency profiles the
    selection is widened with directly adjacent chunks before truncation.
    """
    known_ids = _valid_chunk_ids(material)

    requested = extractor_parsed.get("focus_chunk_ids", [])
    selected = _dedupe_chunk_ids(requested if isinstance(requested, list) else [], known_ids)
    if not selected:
        selected = known_ids[: min(limit, len(known_ids))]

    profile = str(extractor_parsed.get("primary_task_profile", "")).strip()
    wants_context = profile in {
        "single-event-safety",
        "behavior-sequence",
        "composite-safety",
        "emergency-response",
    }
    if wants_context:
        widened = selected + _neighbor_chunk_ids(selected, known_ids, radius=1)
        selected = _dedupe_chunk_ids(widened, known_ids)

    return selected[:limit]
def evaluate_episode_with_egpv3(
    episode: Dict[str, Any],
    episode_path: str,
    client,
    chunk_size: int = 80,
    max_focus_chunks: int = 6,
    preview_only: bool = False,
) -> Dict[str, Any]:
    """Run one episode through the extractor/prosecutor/defender/judge pipeline.

    Stages, each a ``client.chat(system, user)`` call:

    1. extractor - profiles the case and proposes focus chunks;
    2. prosecutor - argues over the focused chunks;
    3. defender - responds to the prosecutor's raw output;
    4. judge - produces the final reply, which is parsed and scored.

    Any exception raised while running the stages is captured in
    ``api_error`` and the episode is scored as a parse failure; it is not
    re-raised.

    Args:
        episode: Episode record. Must contain ``"metadata"`` and
            ``"ground_truth"``; ``"episode_id"`` is optional.
        episode_path: Path of the episode, echoed into the result.
        client: Object exposing ``chat(system_prompt, user_prompt)``.
        chunk_size: Forwarded to ``build_case_material``.
        max_focus_chunks: Upper bound on the chunks handed to the debate.
        preview_only: When true, no model call is made; only the extractor
            and judge prompt previews are produced.

    Returns:
        A result record with the raw and parsed judge response, scores,
        latency in seconds, ``api_error`` (``None`` on success) and an
        ``"egpv3_trace"`` holding per-stage outputs and prompt previews.

    Raises:
        KeyError: If ``episode`` lacks ``"metadata"`` or ``"ground_truth"``.
    """
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    material = build_case_material(episode, chunk_size=chunk_size)
    extractor_prompt = build_extractor_prompt(material)

    stage = {
        "extractor_raw": "",
        "extractor_parsed": {},
        "prosecutor_raw": "",
        "prosecutor_parsed": {},
        "defender_raw": "",
        "defender_parsed": {},
        "judge_prompt_preview": "",
    }
    raw_response = ""
    api_error = None
    # Monotonic timer: latency must not be skewed by wall-clock adjustments.
    start = time.perf_counter()

    if preview_only:
        chunk_index = material["chunk_index"]
        focus_ids = [c["chunk_id"] for c in chunk_index[: min(4, len(chunk_index))]]
        focused_chunks = materialize_chunks(material["chunks"], focus_ids, max_chunks=max_focus_chunks)
        judge_prompt = build_judge_prompt(material, "<preview>", "<preview>", "<preview>", focused_chunks)
        # Record the preview; otherwise the trace shows an empty judge prompt
        # in the one mode whose purpose is to inspect prompts.
        stage["judge_prompt_preview"] = judge_prompt["user"][:2500]
        parsed = {"is_anomaly": None, "_parse_failed": True}
    else:
        try:
            stage["extractor_raw"] = client.chat(extractor_prompt["system"], extractor_prompt["user"])
            stage["extractor_parsed"] = _apply_query_intent_guard(
                _normalize_extractor(_extract_jsonish(stage["extractor_raw"])),
                material.get("query", ""),
            )
            # Later stages get the normalized JSON, or the raw extractor text
            # when it could not be parsed.
            extractor_context = _dump_json(stage["extractor_parsed"], fallback=stage["extractor_raw"])

            focus_ids = _select_focus_ids(material, stage["extractor_parsed"], max_focus_chunks)
            focused_chunks = materialize_chunks(material["chunks"], focus_ids, max_chunks=max_focus_chunks)

            prosecutor_prompt = build_prosecutor_prompt(material, extractor_context, focused_chunks)
            stage["prosecutor_raw"] = client.chat(prosecutor_prompt["system"], prosecutor_prompt["user"])
            stage["prosecutor_parsed"] = _extract_jsonish(stage["prosecutor_raw"])

            defender_prompt = build_defender_prompt(
                material,
                extractor_context,
                focused_chunks,
                stage["prosecutor_raw"],
            )
            stage["defender_raw"] = client.chat(defender_prompt["system"], defender_prompt["user"])
            stage["defender_parsed"] = _extract_jsonish(stage["defender_raw"])

            judge_prompt = build_judge_prompt(
                material,
                extractor_context,
                stage["prosecutor_raw"],
                stage["defender_raw"],
                focused_chunks,
            )
            stage["judge_prompt_preview"] = judge_prompt["user"][:2500]
            raw_response = client.chat(judge_prompt["system"], judge_prompt["user"])
            parsed = parse_model_response(raw_response, mode="baseline")
        except Exception as exc:
            # Boundary handler: a failed stage must still yield a scored record.
            api_error = str(exc)
            # Keep the real judge prompt if the failure happened after it was
            # built; only synthesize a fallback preview (from whatever stage
            # outputs exist, with no focused chunks) when none was captured.
            if not stage["judge_prompt_preview"]:
                fallback_prompt = build_judge_prompt(
                    material,
                    stage["extractor_raw"],
                    stage["prosecutor_raw"],
                    stage["defender_raw"],
                    "",
                )
                stage["judge_prompt_preview"] = fallback_prompt["user"][:2500]
            parsed = {"is_anomaly": None, "_parse_failed": True}

    latency = time.perf_counter() - start
    scores = score_episode(gt, parsed, meta.get("variant", ""))

    return {
        "episode_id": episode.get("episode_id", ""),
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw_response,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": api_error,
        "egpv3_trace": {
            "case_summary": {
                "event_count": material["event_count"],
                "chunk_count": len(material["chunk_index"]),
                "signals": material["signals"],
                "protocol_notes": material["protocol_notes"],
            },
            "extractor_raw": stage["extractor_raw"],
            "extractor_parsed": stage["extractor_parsed"],
            "prosecutor_raw": stage["prosecutor_raw"],
            "prosecutor_parsed": stage["prosecutor_parsed"],
            "defender_raw": stage["defender_raw"],
            "defender_parsed": stage["defender_parsed"],
            "preview_extractor_prompt": extractor_prompt["user"][:2500],
            "preview_judge_prompt": stage["judge_prompt_preview"],
        },
    }