282 lines
12 KiB
Python
282 lines
12 KiB
Python
import json
|
|
import re
|
|
import time
|
|
from typing import Any, Dict, List
|
|
|
|
from src.evaluation.scorer import parse_model_response, score_episode
|
|
|
|
from EGPv2.prompts import (
|
|
build_investigator_prompt,
|
|
build_supervisor_prompt,
|
|
build_triage_prompt,
|
|
build_verifier_prompt,
|
|
)
|
|
from EGPv2.signals import build_case_material, materialize_chunks
|
|
|
|
|
|
def _extract_jsonish(text: str) -> Dict[str, Any]:
|
|
raw = text.strip()
|
|
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", raw, re.DOTALL)
|
|
if match:
|
|
raw = match.group(1).strip()
|
|
try:
|
|
return json.loads(raw)
|
|
except json.JSONDecodeError:
|
|
start = raw.find("{")
|
|
end = raw.rfind("}")
|
|
if start != -1 and end != -1 and end > start:
|
|
try:
|
|
return json.loads(raw[start : end + 1])
|
|
except json.JSONDecodeError:
|
|
pass
|
|
return {"_parse_failed": True, "raw_preview": text[:1000]}
|
|
|
|
|
|
def _normalize_triage(parsed: Dict[str, Any]) -> Dict[str, Any]:
|
|
primary = str(parsed.get("primary_task_profile", "")).strip()
|
|
secondary = str(parsed.get("secondary_task_profile", "none")).strip()
|
|
if primary and "latent_task_profile" not in parsed:
|
|
parsed["latent_task_profile"] = primary if secondary in ("", "none") else f"{primary} | {secondary}"
|
|
return parsed
|
|
|
|
|
|
def _valid_chunk_ids(material: Dict) -> List[str]:
|
|
return [chunk["chunk_id"] for chunk in material["chunk_index"]]
|
|
|
|
|
|
def _dedupe_chunk_ids(chunk_ids: List[str], valid_ids: List[str]) -> List[str]:
|
|
valid_set = set(valid_ids)
|
|
deduped = []
|
|
for chunk_id in chunk_ids:
|
|
if chunk_id in valid_set and chunk_id not in deduped:
|
|
deduped.append(chunk_id)
|
|
return deduped
|
|
|
|
|
|
def _select_focus_ids(material: Dict, triage_parsed: Dict[str, Any], limit: int) -> List[str]:
    """Pick the chunk ids to examine first, based on the triage output.

    The ids under ``focus_chunk_ids`` in ``triage_parsed`` are filtered down
    to known, unique ids.  If none survive (or the field is not a list), the
    first two known chunk ids are used instead.  The result is truncated to
    ``limit`` entries.
    """
    known_ids = _valid_chunk_ids(material)

    requested = triage_parsed.get("focus_chunk_ids", [])
    candidates = requested if isinstance(requested, list) else []

    selected = _dedupe_chunk_ids(candidates, known_ids)
    if not selected:
        # Fallback when triage named nothing usable.
        selected = known_ids[:2]
    return selected[:limit]
|
|
|
|
|
|
def _merge_focus_ids(base_ids: List[str], extra_ids: List[str], valid_ids: List[str], limit: int) -> List[str]:
|
|
merged = list(base_ids)
|
|
for chunk_id in extra_ids:
|
|
if chunk_id in valid_ids and chunk_id not in merged:
|
|
merged.append(chunk_id)
|
|
return merged[:limit]
|
|
|
|
|
|
def _supervisor_decision(parsed: Dict[str, Any], round_index: int) -> str:
|
|
action = str(parsed.get("recommended_action", "")).strip().lower()
|
|
if action in {"allow_final_verdict", "refine_investigation", "abstain"}:
|
|
return action
|
|
if parsed.get("_parse_failed"):
|
|
return "allow_final_verdict"
|
|
if parsed.get("on_topic") is False:
|
|
return "refine_investigation" if round_index == 1 else "abstain"
|
|
if parsed.get("evidence_sufficient") is True:
|
|
return "allow_final_verdict"
|
|
risk = str(parsed.get("risk_of_false_alarm", "medium")).strip().lower()
|
|
if round_index == 1:
|
|
return "refine_investigation"
|
|
if risk == "high" and parsed.get("on_topic") is False:
|
|
return "abstain"
|
|
return "allow_final_verdict"
|
|
|
|
|
|
def _build_abstain_response(stage: Dict[str, Any], reason: str) -> Dict[str, Any]:
|
|
triage = stage.get("triage_parsed", {})
|
|
supervisor = stage.get("supervisor_rounds", [{}])[-1].get("parsed", {}) if stage.get("supervisor_rounds") else {}
|
|
notes = supervisor.get("supervisor_notes", [])[:2]
|
|
missing = supervisor.get("missing_checks", [])[:2]
|
|
reasoning = [
|
|
f"Supervisor gate enforced: {reason}.",
|
|
f"Primary task profile: {triage.get('primary_task_profile') or triage.get('latent_task_profile', 'unknown')}.",
|
|
]
|
|
reasoning.extend(str(note) for note in notes)
|
|
actions = ["Gather the missing evidence before issuing a high-confidence anomaly verdict."]
|
|
actions.extend(str(item) for item in missing)
|
|
return {
|
|
"is_anomaly": False,
|
|
"confidence": "low",
|
|
"threat_type": "none",
|
|
"threat_description": "Evidence remains insufficient for a supported anomaly conclusion, so the workflow defaults to a conservative non-anomaly result.",
|
|
"reasoning": reasoning,
|
|
"key_evidence": [],
|
|
"recommended_actions": actions[:3],
|
|
"_supervisor_abstain": True,
|
|
}
|
|
|
|
|
|
def _run_review_round(
    client,
    material: Dict[str, Any],
    stage: Dict[str, Any],
    focus_ids: List[str],
    max_chunks: int,
    round_index: int,
    supervisor_text: Any = None,
) -> tuple:
    """Run one investigator pass followed by one supervisor pass.

    Each reply is appended to the per-round lists in ``stage`` and also
    written to the summary fields (``investigator_raw`` etc.) as soon as it
    is available, so the trace stays consistent if a later call fails.

    Args:
        client: Object exposing ``chat(system, user)``.
        material: Case material for the episode.
        stage: Pipeline state; updated in place.
        focus_ids: Chunk ids to materialize for this round.
        max_chunks: Cap passed to ``materialize_chunks``.
        round_index: Round number forwarded to the prompt builders.
        supervisor_text: Previous supervisor reply; forwarded to the
            investigator prompt only when not ``None``.

    Returns:
        ``(focused_chunks, supervisor_raw, supervisor_parsed)``.
    """
    focused_chunks = materialize_chunks(material["chunks"], focus_ids, max_chunks=max_chunks)

    # Only pass supervisor_text when there is one, so the first-round call
    # keeps relying on the prompt builder's own default.
    feedback = {} if supervisor_text is None else {"supervisor_text": supervisor_text}
    investigator_prompt = build_investigator_prompt(
        material,
        stage["triage_raw"],
        focused_chunks,
        round_index=round_index,
        **feedback,
    )
    investigator_raw = client.chat(investigator_prompt["system"], investigator_prompt["user"])
    investigator_parsed = _extract_jsonish(investigator_raw)
    stage["investigator_rounds"].append(
        {"round": round_index, "focus_ids": focus_ids, "raw": investigator_raw, "parsed": investigator_parsed}
    )
    stage["investigator_raw"] = investigator_raw
    stage["investigator_parsed"] = investigator_parsed

    supervisor_prompt = build_supervisor_prompt(
        material,
        stage["triage_raw"],
        investigator_raw,
        focused_chunks,
        round_index=round_index,
    )
    supervisor_raw = client.chat(supervisor_prompt["system"], supervisor_prompt["user"])
    supervisor_parsed = _extract_jsonish(supervisor_raw)
    stage["supervisor_rounds"].append({"round": round_index, "raw": supervisor_raw, "parsed": supervisor_parsed})
    stage["supervisor_raw"] = supervisor_raw
    stage["supervisor_parsed"] = supervisor_parsed

    return focused_chunks, supervisor_raw, supervisor_parsed


def evaluate_episode_with_egpv2(
    episode: Dict[str, Any],
    episode_path: str,
    client,
    chunk_size: int = 80,
    max_focus_chunks: int = 5,
    preview_only: bool = False,
) -> Dict[str, Any]:
    """Evaluate one episode with the triage / investigator / supervisor / verifier flow.

    Flow when ``preview_only`` is False:

    1. Triage picks the chunks to focus on.
    2. An investigator and a supervisor review those chunks.
    3. If the supervisor asks for refinement, a second round runs with the
       chunks it requested added (``needs_more_chunks``).
    4. If the final decision is ``"abstain"`` a conservative non-anomaly
       verdict is synthesized; otherwise the verifier is called and its reply
       parsed as the verdict.

    Any exception raised during these steps is recorded in ``api_error`` and
    the episode is scored with a parse-failed verdict instead of propagating.

    Args:
        episode: Episode dict; must contain ``metadata`` and ``ground_truth``.
        episode_path: Echoed back in the result.
        client: Object exposing ``chat(system, user)`` that returns the reply
            text.
        chunk_size: Forwarded to ``build_case_material``.
        max_focus_chunks: Base chunk budget; round 1 uses at most 4 of it and
            round 2 allows two more than the base.
        preview_only: Build the prompts from the first three chunks without
            calling the model.

    Returns:
        A result dict with the verdict (``model_response``), ``scores``,
        ``latency`` in seconds, ``api_error`` and an ``egpv2_trace`` of every
        stage.
    """
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    material = build_case_material(episode, chunk_size=chunk_size)

    triage_prompt = build_triage_prompt(material)
    stage = {
        "triage_raw": "",
        "triage_parsed": {},
        "investigator_raw": "",
        "investigator_parsed": {},
        "supervisor_raw": "",
        "supervisor_parsed": {},
        "investigator_rounds": [],
        "supervisor_rounds": [],
        "supervisor_gate": {"final_action": "", "abstained": False},
    }
    raw_response = ""
    api_error = None
    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()

    if preview_only:
        preview_ids = [c["chunk_id"] for c in material["chunk_index"][:3]]
        focused_chunks = materialize_chunks(material["chunks"], preview_ids)
        verifier_prompt = build_verifier_prompt(material, "<preview>", "<preview>", "<preview>", focused_chunks)
        parsed = {"is_anomaly": None, "_parse_failed": True}
    else:
        try:
            valid_chunk_ids = _valid_chunk_ids(material)
            first_round_limit = min(max_focus_chunks, 4)
            second_round_limit = max_focus_chunks + 2

            stage["triage_raw"] = client.chat(triage_prompt["system"], triage_prompt["user"])
            stage["triage_parsed"] = _normalize_triage(_extract_jsonish(stage["triage_raw"]))
            focus_ids = _select_focus_ids(material, stage["triage_parsed"], first_round_limit)

            focused_chunks, supervisor_raw, supervisor_parsed = _run_review_round(
                client, material, stage, focus_ids, first_round_limit, round_index=1
            )
            decision = _supervisor_decision(supervisor_parsed, round_index=1)

            if decision == "refine_investigation":
                extra_focus = supervisor_parsed.get("needs_more_chunks", [])
                if not isinstance(extra_focus, list):
                    extra_focus = []
                final_focus_ids = _merge_focus_ids(focus_ids, extra_focus, valid_chunk_ids, second_round_limit)

                focused_chunks, _, final_supervisor_parsed = _run_review_round(
                    client,
                    material,
                    stage,
                    final_focus_ids,
                    second_round_limit,
                    round_index=2,
                    supervisor_text=supervisor_raw,
                )
                decision = _supervisor_decision(final_supervisor_parsed, round_index=2)

            stage["supervisor_gate"]["final_action"] = decision

            # The verifier prompt is recorded in the trace on both paths, so
            # it is built once regardless of the decision.
            verifier_prompt = build_verifier_prompt(
                material,
                stage["triage_raw"],
                stage["investigator_raw"],
                stage["supervisor_raw"],
                focused_chunks,
            )
            if decision == "abstain":
                stage["supervisor_gate"]["abstained"] = True
                parsed = _build_abstain_response(stage, reason=decision or "high false-alarm risk")
            else:
                raw_response = client.chat(verifier_prompt["system"], verifier_prompt["user"])
                parsed = parse_model_response(raw_response, mode="baseline")
        except Exception as exc:
            # Evaluation boundary: one failing episode is reported through
            # api_error and scored as a parse failure rather than raised.
            api_error = str(exc)
            focused_chunks = ""
            verifier_prompt = build_verifier_prompt(
                material,
                stage["triage_raw"],
                stage["investigator_raw"],
                stage["supervisor_raw"],
                focused_chunks,
            )
            parsed = {"is_anomaly": None, "_parse_failed": True}

    latency = time.perf_counter() - start
    scores = score_episode(gt, parsed, meta.get("variant", ""))
    return {
        "episode_id": episode.get("episode_id", ""),
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw_response,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": api_error,
        "egpv2_trace": {
            "case_summary": {
                "event_count": material["event_count"],
                "chunk_count": len(material["chunk_index"]),
                "signals": material["signals"],
                "protocol_notes": material["protocol_notes"],
            },
            "triage_raw": stage["triage_raw"],
            "triage_parsed": stage["triage_parsed"],
            "investigator_raw": stage["investigator_raw"],
            "investigator_parsed": stage["investigator_parsed"],
            "supervisor_raw": stage["supervisor_raw"],
            "supervisor_parsed": stage["supervisor_parsed"],
            "investigator_rounds": stage["investigator_rounds"],
            "supervisor_rounds": stage["supervisor_rounds"],
            "supervisor_gate": stage["supervisor_gate"],
            "preview_triage_prompt": triage_prompt["user"][:2500],
            "preview_verifier_prompt": verifier_prompt["user"][:2500],
        },
    }
|