Files
llmiotsafe/EGPv4/pipeline.py
2026-05-12 17:01:39 +08:00

116 lines
4.3 KiB
Python

import json
import re
import time
from typing import Any, Dict
from src.evaluation.scorer import score_episode
from EGPv4.extractor import build_evidence_packet, materialize_evidence_packet
from EGPv4.prompts import build_defender_prompt, build_judge_prompt, build_prosecutor_prompt
def _extract_jsonish(text: str) -> Dict[str, Any]:
raw = text.strip()
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", raw, re.DOTALL)
if match:
raw = match.group(1).strip()
try:
return json.loads(raw)
except json.JSONDecodeError:
start = raw.find("{")
end = raw.rfind("}")
if start != -1 and end != -1 and end > start:
try:
return json.loads(raw[start : end + 1])
except json.JSONDecodeError:
pass
return {"_parse_failed": True, "raw_preview": text[:1000]}
def _normalize_judge_output(parsed: Dict[str, Any]) -> Dict[str, Any]:
if parsed.get("_parse_failed"):
return {"is_anomaly": None, "_parse_failed": True}
verdict = parsed.get("final_verdict", {}) if isinstance(parsed.get("final_verdict"), dict) else {}
return {
"is_anomaly": verdict.get("is_anomaly"),
"confidence": verdict.get("confidence", "low"),
"threat_type": verdict.get("threat_type", "none"),
"threat_description": parsed.get("threat_description", ""),
"reasoning": [json.dumps(parsed.get("reasoning", {}), ensure_ascii=False)],
"key_evidence": [],
"recommended_actions": parsed.get("recommended_actions", []),
"_judge_raw": parsed,
}
def evaluate_episode_with_egpv4(
    episode: Dict[str, Any],
    episode_path: str,
    client,
    chunk_size: int = 80,
    max_focus_chunks: int = 6,
    preview_only: bool = False,
) -> Dict[str, Any]:
    """Run the prosecutor -> defender -> judge chain on one episode and score it.

    An evidence packet is built from the episode and rendered to text. Unless
    ``preview_only`` is set, three sequential ``client.chat`` calls are made:
    the prosecutor sees the evidence, the defender sees the evidence plus the
    prosecutor's raw reply, and the judge sees both raw replies. The judge's
    reply is parsed and normalized into the model response, which is then
    passed to ``score_episode``.

    Args:
        episode: Episode record; must contain ``"metadata"`` and
            ``"ground_truth"``. ``"episode_id"`` is optional.
        episode_path: Path of the episode, echoed back in the result.
        client: Object exposing ``chat(system, user)`` and returning the
            reply text.
        chunk_size: Forwarded to ``build_evidence_packet``.
        max_focus_chunks: Forwarded to ``build_evidence_packet``.
        preview_only: When true, no model calls are made; the evidence packet
            is still built and the response is a parse-failed placeholder.

    Returns:
        A result dict with the episode identity, ground truth, the judge's raw
        reply, the normalized model response, scores, elapsed seconds, an
        ``api_error`` string (``None`` on success) and an ``egpv4_trace`` with
        every stage's raw and parsed output.

    Raises:
        KeyError: If ``episode`` lacks ``"metadata"`` or ``"ground_truth"``.
    """
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    evidence_packet = build_evidence_packet(episode, chunk_size=chunk_size, max_focus_chunks=max_focus_chunks)
    evidence_text = materialize_evidence_packet(evidence_packet)
    stage = {
        "evidence_packet": evidence_packet,
        "prosecutor_raw": "",
        "prosecutor_parsed": {},
        "defender_raw": "",
        "defender_parsed": {},
        "judge_raw": "",
        "judge_parsed": {},
    }
    raw_response = ""
    api_error = None
    # perf_counter is monotonic, so the measured latency cannot go negative or
    # jump when the system wall clock is adjusted mid-run.
    start = time.perf_counter()
    prosecutor_prompt = build_prosecutor_prompt(evidence_text)
    if preview_only:
        parsed = {"is_anomaly": None, "_parse_failed": True}
    else:
        try:
            stage["prosecutor_raw"] = client.chat(prosecutor_prompt["system"], prosecutor_prompt["user"])
            stage["prosecutor_parsed"] = _extract_jsonish(stage["prosecutor_raw"])
            defender_prompt = build_defender_prompt(evidence_text, stage["prosecutor_raw"])
            stage["defender_raw"] = client.chat(defender_prompt["system"], defender_prompt["user"])
            stage["defender_parsed"] = _extract_jsonish(stage["defender_raw"])
            judge_prompt = build_judge_prompt(stage["prosecutor_raw"], stage["defender_raw"])
            raw_response = client.chat(judge_prompt["system"], judge_prompt["user"])
            stage["judge_raw"] = raw_response
            stage["judge_parsed"] = _extract_jsonish(raw_response)
            parsed = _normalize_judge_output(stage["judge_parsed"])
        except Exception as exc:
            # Deliberate catch-all boundary: one failing episode must still
            # yield a scorable record. Note this also captures parsing and
            # prompt-building errors, not only client failures. Stages that
            # completed before the failure remain in the trace.
            # Some exceptions stringify to "", which would look like "no
            # error" to truthiness checks; fall back to the class name.
            api_error = str(exc) or type(exc).__name__
            parsed = {"is_anomaly": None, "_parse_failed": True}
    latency = time.perf_counter() - start
    scores = score_episode(gt, parsed, meta.get("variant", ""))
    return {
        "episode_id": episode.get("episode_id", ""),
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw_response,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": api_error,
        "egpv4_trace": {
            "evidence_packet": evidence_packet,
            "prosecutor_raw": stage["prosecutor_raw"],
            "prosecutor_parsed": stage["prosecutor_parsed"],
            "defender_raw": stage["defender_raw"],
            "defender_parsed": stage["defender_parsed"],
            "judge_raw": stage["judge_raw"],
            "judge_parsed": stage["judge_parsed"],
            # Truncated so the trace stays small; the full packet is above.
            "preview_evidence_text": evidence_text[:3000],
        },
    }