# Files
# llmiotsafe/EGPv2_1/pipeline.py
# 2026-05-12 17:01:39 +08:00
#
# 375 lines
# 16 KiB
# Python

import json
import re
import time
from typing import Any, Dict, List
from src.evaluation.scorer import parse_model_response, score_episode
from EGPv2_1.prompts import (
build_investigator_prompt,
build_supervisor_prompt,
build_triage_prompt,
build_verifier_prompt,
)
from EGPv2_1.signals import build_case_material, materialize_chunks
def _extract_jsonish(text: str) -> Dict[str, Any]:
raw = text.strip()
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", raw, re.DOTALL)
if match:
raw = match.group(1).strip()
try:
return json.loads(raw)
except json.JSONDecodeError:
start = raw.find("{")
end = raw.rfind("}")
if start != -1 and end != -1 and end > start:
try:
return json.loads(raw[start : end + 1])
except json.JSONDecodeError:
pass
return {"_parse_failed": True, "raw_preview": text[:1000]}
def _normalize_triage(parsed: Dict[str, Any]) -> Dict[str, Any]:
primary = str(parsed.get("primary_task_profile", "")).strip()
secondary = str(parsed.get("secondary_task_profile", "none")).strip()
if primary and "latent_task_profile" not in parsed:
parsed["latent_task_profile"] = primary if secondary in ("", "none") else f"{primary} | {secondary}"
return parsed
def _dump_json(data: Dict[str, Any], fallback: str = "") -> str:
if not data or data.get("_parse_failed"):
return fallback
return json.dumps(data, ensure_ascii=False, indent=2)
def _infer_query_profile(query: str) -> str:
query = str(query or "")
if ("工作正常" in query) or ("故障类型" in query):
return "device-health"
if ("应急响应" in query) or ("严重程度" in query and "威胁类型" in query):
return "emergency-response"
if ("综合分析" in query and "安全状况" in query) or ("潜在风险" in query):
return "composite-safety"
if "异常行为模式" in query:
return "behavior-sequence"
if "安全威胁" in query:
return "single-event-safety"
return ""
def _apply_query_intent_guard(parsed: Dict[str, Any], query: str) -> Dict[str, Any]:
    """Reconcile the triage profile with the intent inferred from the query.

    When the query maps to a profile (via ``_infer_query_profile``):

    * the inferred profile is recorded under ``query_intent_profile``;
    * a missing/``null`` ``primary_task_profile`` is filled with it;
    * a ``device-health`` primary that contradicts the query is replaced by
      the inferred profile, ``device-health`` is kept as the secondary
      profile (unless a real secondary already exists), and a note is
      appended to ``guardrail_notes``.

    Args:
        parsed: Triage result; mutated in place.
        query: The task query text.

    Returns:
        The same ``parsed`` dict, passed through ``_normalize_triage``.
    """
    inferred = _infer_query_profile(query)
    if not inferred:
        return parsed
    parsed["query_intent_profile"] = inferred
    # ``or`` makes JSON null behave like a missing key instead of the text "None".
    primary = str(parsed.get("primary_task_profile") or "").strip()
    secondary = str(parsed.get("secondary_task_profile") or "none").strip() or "none"
    if not primary:
        parsed["primary_task_profile"] = inferred
        return _normalize_triage(parsed)
    if primary == "device-health" and inferred != "device-health":
        # A latent label led by the demoted profile is stale after the
        # correction; drop it so _normalize_triage rebuilds it from the
        # corrected fields (it never overwrites an existing label).
        latent = parsed.get("latent_task_profile")
        if isinstance(latent, str) and (latent == primary or latent.startswith(f"{primary} | ")):
            del parsed["latent_task_profile"]
        parsed["primary_task_profile"] = inferred
        parsed["secondary_task_profile"] = "device-health" if secondary.lower() in ("", "none") else secondary
        notes = parsed.get("guardrail_notes", [])
        if not isinstance(notes, list):
            notes = [str(notes)]
        notes.append("Primary task profile corrected from query intent to avoid telemetry-only drift.")
        parsed["guardrail_notes"] = notes
    return _normalize_triage(parsed)
def _valid_chunk_ids(material: Dict) -> List[str]:
return [chunk["chunk_id"] for chunk in material["chunk_index"]]
def _dedupe_chunk_ids(chunk_ids: List[str], valid_ids: List[str]) -> List[str]:
valid_set = set(valid_ids)
deduped = []
for chunk_id in chunk_ids:
if chunk_id in valid_set and chunk_id not in deduped:
deduped.append(chunk_id)
return deduped
def _select_focus_ids(material: Dict, triage_parsed: Dict[str, Any], limit: int) -> List[str]:
    """Choose the chunk ids to examine in the first investigation round.

    Args:
        material: Case material; its ``chunk_index`` defines the valid ids.
        triage_parsed: Parsed triage output; ``focus_chunk_ids`` is honoured
            only when it is a list.
        limit: Maximum number of ids to return.

    Returns:
        The valid, de-duplicated ids requested by triage, or the first two
        valid ids when triage requested none, truncated to ``limit``.
    """
    known_ids = _valid_chunk_ids(material)
    requested = triage_parsed.get("focus_chunk_ids", [])
    candidates = requested if isinstance(requested, list) else []
    chosen = _dedupe_chunk_ids(candidates, known_ids)
    if chosen:
        return chosen[:limit]
    # Nothing usable from triage: fall back to the leading chunks.
    return known_ids[:2][:limit]
def _merge_focus_ids(base_ids: List[str], extra_ids: List[str], valid_ids: List[str], limit: int) -> List[str]:
merged = list(base_ids)
for chunk_id in extra_ids:
if chunk_id in valid_ids and chunk_id not in merged:
merged.append(chunk_id)
return merged[:limit]
def _chunk_order(chunk_id: str) -> int:
match = re.search(r"(\d+)$", str(chunk_id))
return int(match.group(1)) if match else -1
def _neighbor_chunk_ids(chunk_ids: List[str], valid_ids: List[str], radius: int = 1) -> List[str]:
order = {chunk_id: idx for idx, chunk_id in enumerate(valid_ids)}
neighbors: List[str] = []
for chunk_id in chunk_ids:
idx = order.get(chunk_id)
if idx is None:
continue
for offset in range(1, radius + 1):
for neighbor_idx in (idx - offset, idx + offset):
if 0 <= neighbor_idx < len(valid_ids):
neighbor_id = valid_ids[neighbor_idx]
if neighbor_id not in neighbors:
neighbors.append(neighbor_id)
return neighbors
def _build_second_round_focus(
    triage_parsed: Dict[str, Any],
    base_ids: List[str],
    extra_ids: List[str],
    valid_ids: List[str],
    limit: int,
) -> List[str]:
    """Order the chunk ids for the second investigation round.

    For the safety-oriented primary profiles the chunks requested by the
    supervisor (and their direct neighbours) come first; for every other
    profile the first-round chunks keep priority.

    Args:
        triage_parsed: Parsed triage output; ``primary_task_profile`` is read.
        base_ids: Ids examined in the first round.
        extra_ids: Ids additionally requested for the second round.
        valid_ids: All ids of the case, in document order.
        limit: Maximum number of ids to return.

    Returns:
        Valid, de-duplicated ids in priority order, truncated to ``limit``.
    """
    profile = str(triage_parsed.get("primary_task_profile", "")).strip()
    requested = _dedupe_chunk_ids(extra_ids, valid_ids)
    around_requested = _neighbor_chunk_ids(requested, valid_ids, radius=1)
    around_base = _neighbor_chunk_ids(base_ids, valid_ids, radius=1)

    safety_profiles = {"single-event-safety", "behavior-sequence", "composite-safety", "emergency-response"}
    if profile in safety_profiles:
        ordered = [*requested, *around_requested, *base_ids, *around_base]
    else:
        ordered = [*base_ids, *requested, *around_requested, *around_base]
    return _dedupe_chunk_ids(ordered, valid_ids)[:limit]
def _supervisor_decision(parsed: Dict[str, Any], round_index: int) -> str:
action = str(parsed.get("recommended_action", "")).strip().lower()
if action in {"allow_final_verdict", "refine_investigation", "abstain"}:
return action
if parsed.get("_parse_failed"):
return "allow_final_verdict"
if parsed.get("on_topic") is False:
return "refine_investigation" if round_index == 1 else "abstain"
if parsed.get("evidence_sufficient") is True:
return "allow_final_verdict"
risk = str(parsed.get("risk_of_false_alarm", "medium")).strip().lower()
if round_index == 1:
return "refine_investigation"
if risk == "high" and parsed.get("on_topic") is False:
return "abstain"
return "allow_final_verdict"
def _build_abstain_response(stage: Dict[str, Any], reason: str) -> Dict[str, Any]:
triage = stage.get("triage_parsed", {})
supervisor = stage.get("supervisor_rounds", [{}])[-1].get("parsed", {}) if stage.get("supervisor_rounds") else {}
notes = supervisor.get("supervisor_notes", [])[:2]
missing = supervisor.get("missing_checks", [])[:2]
reasoning = [
f"Supervisor gate enforced: {reason}.",
f"Primary task profile: {triage.get('primary_task_profile') or triage.get('latent_task_profile', 'unknown')}.",
]
reasoning.extend(str(note) for note in notes)
actions = ["Gather the missing evidence before issuing a high-confidence anomaly verdict."]
actions.extend(str(item) for item in missing)
return {
"is_anomaly": False,
"confidence": "low",
"threat_type": "none",
"threat_description": "Evidence remains insufficient for a supported anomaly conclusion, so the workflow defaults to a conservative non-anomaly result.",
"reasoning": reasoning,
"key_evidence": [],
"recommended_actions": actions[:3],
"_supervisor_abstain": True,
}
def _run_investigation_round(
    client,
    material: Dict[str, Any],
    triage_context: str,
    focused_chunks: Any,
    focus_ids: List[str],
    round_index: int,
    stage: Dict[str, Any],
    supervisor_text: Any = None,
) -> tuple:
    """Run one investigator -> supervisor exchange and record it in ``stage``.

    Each reply is appended to ``stage["investigator_rounds"]`` /
    ``stage["supervisor_rounds"]`` as soon as it is parsed, so a failure in a
    later call still leaves the earlier reply in the trace.

    Args:
        client: Object whose ``chat(system, user)`` returns the model reply.
        material: Case material passed through to the prompt builders.
        triage_context: Triage summary text given to both prompts.
        focused_chunks: Materialised chunk content for this round.
        focus_ids: Chunk ids behind ``focused_chunks`` (recorded in the trace).
        round_index: Round number passed to the prompt builders and recorded.
        stage: Pipeline trace; mutated in place.
        supervisor_text: Previous supervisor reply.  Forwarded to the
            investigator prompt as ``supervisor_text`` only when not None.

    Returns:
        ``(investigator_raw, investigator_parsed, supervisor_raw,
        supervisor_parsed, decision)``.
    """
    prompt_kwargs: Dict[str, Any] = {"round_index": round_index}
    if supervisor_text is not None:
        prompt_kwargs["supervisor_text"] = supervisor_text
    investigator_prompt = build_investigator_prompt(material, triage_context, focused_chunks, **prompt_kwargs)
    investigator_raw = client.chat(investigator_prompt["system"], investigator_prompt["user"])
    investigator_parsed = _extract_jsonish(investigator_raw)
    stage["investigator_rounds"].append(
        {"round": round_index, "focus_ids": focus_ids, "raw": investigator_raw, "parsed": investigator_parsed}
    )

    supervisor_prompt = build_supervisor_prompt(
        material,
        triage_context,
        investigator_raw,
        focused_chunks,
        round_index=round_index,
    )
    supervisor_raw = client.chat(supervisor_prompt["system"], supervisor_prompt["user"])
    supervisor_parsed = _extract_jsonish(supervisor_raw)
    stage["supervisor_rounds"].append({"round": round_index, "raw": supervisor_raw, "parsed": supervisor_parsed})

    decision = _supervisor_decision(supervisor_parsed, round_index=round_index)
    return investigator_raw, investigator_parsed, supervisor_raw, supervisor_parsed, decision


def evaluate_episode_with_egpv2(
    episode: Dict[str, Any],
    episode_path: str,
    client,
    chunk_size: int = 80,
    max_focus_chunks: int = 5,
    preview_only: bool = False,
) -> Dict[str, Any]:
    """Evaluate one episode with the triage -> investigate -> supervise -> verify flow.

    Flow (when ``preview_only`` is False):

    1. Triage: one model call; the reply is parsed, normalised and reconciled
       with the query intent.
    2. Round 1: investigator and supervisor run on the triage-selected chunks
       (at most ``min(max_focus_chunks, 4)``).
    3. Round 2 (only when the supervisor asks to refine): the focus is rebuilt
       with up to ``max_focus_chunks + 2`` chunks and both roles run again.
    4. If the final decision is ``"abstain"`` a canned non-anomaly response is
       used; otherwise the verifier is called and its reply is parsed with
       ``parse_model_response``.

    Any exception raised in steps 1-4 is caught: its message is stored in
    ``api_error`` and the episode is scored with an unparsed response.

    Args:
        episode: Episode dict; ``metadata`` and ``ground_truth`` are required,
            ``episode_id`` is optional.
        episode_path: Path of the episode, echoed in the result.
        client: Object whose ``chat(system, user)`` returns the model reply.
        chunk_size: Forwarded to ``build_case_material``.
        max_focus_chunks: Base budget for the number of focused chunks.
        preview_only: When True no model call is made; only prompts are built.

    Returns:
        A result dict with the scores, the parsed model response, the latency
        in seconds, ``api_error`` and an ``egpv2_trace`` holding every
        intermediate stage plus 2500-character prompt previews.
    """
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    material = build_case_material(episode, chunk_size=chunk_size)
    triage_prompt = build_triage_prompt(material)
    stage = {
        "triage_raw": "",
        "triage_parsed": {},
        "investigator_raw": "",
        "investigator_parsed": {},
        "supervisor_raw": "",
        "supervisor_parsed": {},
        "investigator_rounds": [],
        "supervisor_rounds": [],
        "supervisor_gate": {"final_action": "", "abstained": False},
    }
    raw_response = ""
    api_error = None
    # perf_counter is monotonic, so the latency cannot be distorted by
    # wall-clock adjustments during a long model call.
    start = time.perf_counter()
    if preview_only:
        preview_ids = [c["chunk_id"] for c in material["chunk_index"][:3]]
        focused_chunks = materialize_chunks(material["chunks"], preview_ids)
        verifier_prompt = build_verifier_prompt(material, "<preview>", "<preview>", "<preview>", focused_chunks)
        parsed = {"is_anomaly": None, "_parse_failed": True}
    else:
        try:
            valid_chunk_ids = _valid_chunk_ids(material)
            first_round_limit = min(max_focus_chunks, 4)
            second_round_limit = max_focus_chunks + 2

            # --- Triage -----------------------------------------------------
            stage["triage_raw"] = client.chat(triage_prompt["system"], triage_prompt["user"])
            stage["triage_parsed"] = _apply_query_intent_guard(
                _normalize_triage(_extract_jsonish(stage["triage_raw"])),
                material.get("query", ""),
            )
            triage_context = _dump_json(stage["triage_parsed"], fallback=stage["triage_raw"])

            # --- Round 1 ----------------------------------------------------
            focus_ids = _select_focus_ids(material, stage["triage_parsed"], first_round_limit)
            focused_chunks = materialize_chunks(material["chunks"], focus_ids, max_chunks=first_round_limit)
            (
                final_investigator_raw,
                final_investigator_parsed,
                final_supervisor_raw,
                final_supervisor_parsed,
                decision,
            ) = _run_investigation_round(
                client,
                material,
                triage_context,
                focused_chunks,
                focus_ids,
                1,
                stage,
            )

            # --- Round 2 (optional refinement) -------------------------------
            if decision == "refine_investigation":
                extra_focus = final_supervisor_parsed.get("needs_more_chunks", [])
                if not isinstance(extra_focus, list):
                    extra_focus = []
                final_focus_ids = _build_second_round_focus(
                    stage["triage_parsed"],
                    focus_ids,
                    extra_focus,
                    valid_chunk_ids,
                    second_round_limit,
                )
                if not final_focus_ids:
                    final_focus_ids = _merge_focus_ids(focus_ids, extra_focus, valid_chunk_ids, second_round_limit)
                focused_chunks = materialize_chunks(material["chunks"], final_focus_ids, max_chunks=second_round_limit)
                (
                    final_investigator_raw,
                    final_investigator_parsed,
                    final_supervisor_raw,
                    final_supervisor_parsed,
                    decision,
                ) = _run_investigation_round(
                    client,
                    material,
                    triage_context,
                    focused_chunks,
                    final_focus_ids,
                    2,
                    stage,
                    supervisor_text=final_supervisor_raw,
                )

            stage["investigator_raw"] = final_investigator_raw
            stage["investigator_parsed"] = final_investigator_parsed
            stage["supervisor_raw"] = final_supervisor_raw
            stage["supervisor_parsed"] = final_supervisor_parsed
            stage["supervisor_gate"]["final_action"] = decision
            abstained = decision == "abstain"
            stage["supervisor_gate"]["abstained"] = abstained

            # --- Verdict ------------------------------------------------------
            # The verifier prompt is built in both cases so the trace always
            # carries a preview; it is only sent when the gate did not abstain.
            verifier_prompt = build_verifier_prompt(
                material,
                triage_context,
                stage["investigator_raw"],
                stage["supervisor_raw"],
                focused_chunks,
            )
            if abstained:
                parsed = _build_abstain_response(stage, reason=decision)
            else:
                raw_response = client.chat(verifier_prompt["system"], verifier_prompt["user"])
                parsed = parse_model_response(raw_response, mode="baseline")
        except Exception as exc:
            # Deliberate catch-all at the episode boundary: one failing
            # episode is recorded and scored as unparsed instead of aborting
            # the whole evaluation run.
            api_error = str(exc)
            focused_chunks = ""
            verifier_prompt = build_verifier_prompt(
                material,
                stage["triage_raw"],
                stage["investigator_raw"],
                stage["supervisor_raw"],
                focused_chunks,
            )
            parsed = {"is_anomaly": None, "_parse_failed": True}
    latency = time.perf_counter() - start
    scores = score_episode(gt, parsed, meta.get("variant", ""))
    return {
        "episode_id": episode.get("episode_id", ""),
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw_response,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": api_error,
        "egpv2_trace": {
            "case_summary": {
                "event_count": material["event_count"],
                "chunk_count": len(material["chunk_index"]),
                "signals": material["signals"],
                "protocol_notes": material["protocol_notes"],
            },
            "triage_raw": stage["triage_raw"],
            "triage_parsed": stage["triage_parsed"],
            "investigator_raw": stage["investigator_raw"],
            "investigator_parsed": stage["investigator_parsed"],
            "supervisor_raw": stage["supervisor_raw"],
            "supervisor_parsed": stage["supervisor_parsed"],
            "investigator_rounds": stage["investigator_rounds"],
            "supervisor_rounds": stage["supervisor_rounds"],
            "supervisor_gate": stage["supervisor_gate"],
            "preview_triage_prompt": triage_prompt["user"][:2500],
            "preview_verifier_prompt": verifier_prompt["user"][:2500],
        },
    }