Compare commits
2 Commits
76df34ed79
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
444d58726a | ||
|
|
0a2b344c84 |
25
README.md
25
README.md
@@ -67,14 +67,18 @@ Phenomenon → Hypothesis 的边类型与权重写死在 `HYPOTHESIS_EDGE_WEIGHT
|
||||
| **Phase 4** | TimelineAgent 用 `build_filesystem_timeline` 生成 MAC 时间线,与 Phenomenon 时间戳关联 |
|
||||
| **Phase 5** | ReportAgent 综合假设、证据、实体,生成 Markdown 报告 |
|
||||
|
||||
### Gap Analysis(Phase 3 末)
|
||||
### Investigation Areas(hypothesis-derived)
|
||||
|
||||
`config.yaml:investigation_areas` 列出必须覆盖的调查领域(系统信息、用户账户、网络配置、邮件配置、IRC 日志、PCAP、删除文件、Prefetch 等)。Orchestrator 两层判定覆盖情况:
|
||||
Phase 2 末尾 orchestrator 调一次 LLM 从所有 active hypothesis 派生 5-12 个 **InvestigationArea**(snake_case slug、description、suggested_agent、expected_keywords、expected_tools、priority、motivating_hypothesis_ids)。Areas 存进 `graph.investigation_areas`,序列化到 `runs/<ts>/investigation_areas.json`。两个用途:
|
||||
|
||||
1. **关键词匹配**(`_AREA_KEYWORDS`)— 扫现有 Phenomenon 标题/描述
|
||||
2. **工具命中**(`_AREA_TOOLS`)— 检查是否调用过该领域的关键工具(如 `enumerate_users`、`parse_pcap_strings`)
|
||||
1. **Phase 3 主循环提示** — 每个 hypothesis 块附 `Expected areas: a, b, c`,LLM 仍自由选 lead 但有软引导
|
||||
2. **Phase 3 末尾 Gap Analysis** — 两层判定覆盖情况:
|
||||
- **关键词匹配**:扫 Phenomenon 标题/描述对照 area.expected_keywords
|
||||
- **工具命中**:检查 area.expected_tools 是否实际调用过
|
||||
|
||||
未覆盖的领域自动派发 lead,最多 3 轮补漏。
|
||||
未覆盖的 area 自动派 lead(`suggested_agent` + `priority` + `motivating_hypothesis_ids[0]` 透传给 `Lead.hypothesis_id` 保留 provenance),最多 3 轮补漏。
|
||||
|
||||
**手动 override**:`config.yaml:investigation_areas` 默认注释掉,纯 LLM 派生。取消注释可添加强制必查的领域,会先于 LLM 写入并通过 slug-based dedupe 保护不被覆盖(LLM 只会 augment keyword/tool 列表)。这是跨案件/跨平台适配的关键 —— 不再 hardcode Windows-specific 领域。
|
||||
|
||||
## Agent 体系
|
||||
|
||||
@@ -183,11 +187,12 @@ max_investigation_rounds: 5 # Phase 3 最大迭代轮数
|
||||
# - title: "嫌疑人主动实施网络嗅探"
|
||||
# description: "..."
|
||||
|
||||
investigation_areas: # Gap Analysis 必须覆盖的领域
|
||||
- area: system_info
|
||||
agent: registry
|
||||
task: "..."
|
||||
# ...
|
||||
# investigation_areas: # 可选:手动 override(默认全 LLM 派生)
|
||||
# - area: shutdown_time # LLM 通过 slug dedupe 只 augment
|
||||
# agent: registry # keyword/tool 列表,不覆盖 manual
|
||||
# priority: 3
|
||||
# keywords: [shutdown]
|
||||
# tools: [get_shutdown_time]
|
||||
```
|
||||
|
||||
未配置 `hypotheses` 时由 HypothesisAgent 自动生成。
|
||||
|
||||
@@ -24,6 +24,7 @@ class HypothesisAgent(BaseAgent):
|
||||
"and formulate investigative hypotheses about what happened on this system. "
|
||||
"Your ultimate goal: build the most complete picture of events that occurred."
|
||||
)
|
||||
mandatory_record_tools = ("add_hypothesis",)
|
||||
|
||||
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
|
||||
super().__init__(llm, graph)
|
||||
@@ -68,7 +69,7 @@ class HypothesisAgent(BaseAgent):
|
||||
f"WORKFLOW:\n"
|
||||
f"1. Call list_phenomena and search_graph to review existing findings.\n"
|
||||
f"2. For each hypothesis you want to record, call add_hypothesis (title + description).\n"
|
||||
f"3. Wrap a short summary in <answer> when you have generated 3-7 hypotheses.\n\n"
|
||||
f"3. STOP after you have generated 3-7 hypotheses. Do not call any more tools.\n\n"
|
||||
f"STRICT BOUNDARIES:\n"
|
||||
f"- Your only mutation tool is add_hypothesis. Do NOT attempt list_directory, "
|
||||
f"parse_registry_key, extract_file, or any disk-image investigation tools — "
|
||||
|
||||
@@ -2,9 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from base_agent import BaseAgent
|
||||
from evidence_graph import EvidenceGraph
|
||||
from llm_client import LLMClient
|
||||
@@ -15,11 +12,16 @@ class ReportAgent(BaseAgent):
|
||||
role = (
|
||||
"Forensic report writer. You synthesize all findings from the investigation "
|
||||
"into a structured, professional forensic analysis report organized by hypotheses.\n\n"
|
||||
"IMPORTANT: Only include findings that have a source_tool attribution (marked VERIFIED). "
|
||||
"Only include findings that have a source_tool attribution (marked VERIFIED). "
|
||||
"If evidence lacks source attribution, mark it as UNVERIFIED. "
|
||||
"Do NOT invent or fabricate any data, timestamps, or findings not present in the evidence.\n\n"
|
||||
"CRITICAL: You MUST call save_report to write the final report."
|
||||
"Do NOT invent or fabricate any data, timestamps, or findings not present in the evidence."
|
||||
)
|
||||
# Calling save_report is BOTH the recording action and the completion
|
||||
# signal. tool_call_loop returns the moment save_report executes; the
|
||||
# tool's return value becomes the agent's final_text. The forced-retry
|
||||
# mechanism fires if save_report is never called.
|
||||
mandatory_record_tools = ("save_report",)
|
||||
terminal_tools = ("save_report",)
|
||||
|
||||
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
|
||||
super().__init__(llm, graph)
|
||||
@@ -30,23 +32,26 @@ class ReportAgent(BaseAgent):
|
||||
self._register_graph_read_tools()
|
||||
|
||||
def _build_system_prompt(self, task: str) -> str:
|
||||
"""Report agent gets a clean prompt — no Phase A/B/C/D workflow."""
|
||||
return (
|
||||
f"You are a forensic report writer.\n"
|
||||
f"Role: {self.role}\n\n"
|
||||
f"Investigation state:\n{self.graph.stats_summary()}\n\n"
|
||||
f"Your task: {task}\n\n"
|
||||
f"WORKFLOW:\n"
|
||||
f"1. Call get_hypotheses_with_evidence to get all hypotheses and their linked evidence\n"
|
||||
f"2. Call get_all_phenomena to get detailed findings by category\n"
|
||||
f"3. Call get_entities to get people, programs, and hosts\n"
|
||||
f"4. Call get_case_info for case metadata\n"
|
||||
f"5. Write the complete report directly in your <answer> block\n\n"
|
||||
f"1. Call get_hypotheses_with_evidence, get_all_phenomena, get_entities, get_case_info "
|
||||
f" to gather all the data needed for the report. Make these calls in parallel.\n"
|
||||
f"2. Assemble the complete markdown forensic report.\n"
|
||||
f"3. Call save_report(content=<full markdown>, output_path=\"report.md\").\n"
|
||||
f" This single call is the completion signal — the run ENDS the moment it executes.\n"
|
||||
f" Do NOT call any read tools after this point; they will not run.\n"
|
||||
f" Do NOT write the report as free text outside of save_report; only the\n"
|
||||
f" `content` argument of save_report is persisted.\n\n"
|
||||
f"RULES:\n"
|
||||
f"- Write the report DIRECTLY in <answer> — do NOT use save_report tool\n"
|
||||
f"- Only include findings present in the evidence graph\n"
|
||||
f"- Do NOT invent timestamps, file paths, or data not in the phenomena\n"
|
||||
f"- The report must be complete — do not cut off mid-section\n"
|
||||
f"- The report must be the complete markdown — do not cut off mid-section.\n"
|
||||
f"- Only include findings present in the evidence graph.\n"
|
||||
f"- Do NOT invent timestamps, file paths, or data not in the phenomena.\n"
|
||||
f"- The `content` argument can be 10K+ chars. JSON-escape inner quotes (\\\") and\n"
|
||||
f" backslashes (\\\\) and newlines (\\n) correctly.\n"
|
||||
)
|
||||
|
||||
def _register_tools(self) -> None:
|
||||
@@ -186,10 +191,16 @@ class ReportAgent(BaseAgent):
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _save_report(self, content: str, output_path: str) -> str:
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
|
||||
with open(output_path, "w") as f:
|
||||
f.write(content)
|
||||
return f"Report saved to {output_path} ({len(content)} chars)"
|
||||
except Exception as e:
|
||||
return f"Error saving report: {e}"
|
||||
"""Save the report and return the content itself.
|
||||
|
||||
The content is returned (rather than a "saved to ..." status string)
|
||||
so that when tool_call_loop short-circuits on this terminal tool,
|
||||
`final_text` is the full markdown — orchestrator writes it to the
|
||||
canonical report.md path under runs/<ts>/.
|
||||
|
||||
The output_path argument is kept for backward compat but the model's
|
||||
chosen path is ignored — the orchestrator owns the persistence path.
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
return content
|
||||
|
||||
@@ -24,6 +24,7 @@ class TimelineAgent(BaseAgent):
|
||||
"MAC timestamps and correlate events across all phenomena categories in the "
|
||||
"evidence graph to reconstruct the sequence of activities on the system."
|
||||
)
|
||||
mandatory_record_tools = ("add_temporal_edge",)
|
||||
|
||||
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
|
||||
super().__init__(llm, graph)
|
||||
@@ -95,7 +96,7 @@ class TimelineAgent(BaseAgent):
|
||||
f" - 'Tool installation' (before) 'Tool execution'\n"
|
||||
f"4. Aim for 15-40 temporal edges that connect the major events into a "
|
||||
f"forensic story.\n"
|
||||
f"5. Wrap a short summary in <answer> when done.\n\n"
|
||||
f"5. STOP after recording all meaningful temporal edges. Do not call any more tools.\n\n"
|
||||
f"STRICT BOUNDARIES:\n"
|
||||
f"- Your job is to CONNECT existing phenomena, NOT to discover new ones. "
|
||||
f"You CANNOT call add_phenomenon — the tool isn't yours.\n"
|
||||
|
||||
@@ -31,11 +31,26 @@ class BaseAgent:
|
||||
name: str = "base"
|
||||
role: str = "A forensic analysis agent."
|
||||
|
||||
# Tools the agent MUST invoke at least once for the run to count as productive.
|
||||
# If none of these were called when tool_call_loop returns, run() fires a
|
||||
# forced retry with an explicit "you forgot to record" instruction.
|
||||
# Subclasses override to declare their own recording responsibility
|
||||
# (timeline → add_temporal_edge, hypothesis → add_hypothesis, report → save_report).
|
||||
mandatory_record_tools: tuple[str, ...] = ("add_phenomenon",)
|
||||
|
||||
# Tools whose invocation ends the run immediately. After any terminal tool
|
||||
# is called, tool_call_loop returns with that tool's result text as
|
||||
# final_text. Used by agents whose "completion" is a single explicit
|
||||
# action rather than "model decides to stop calling tools". For multi-call
|
||||
# agents (filesystem records many phenomena) leave empty.
|
||||
terminal_tools: tuple[str, ...] = ()
|
||||
|
||||
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
|
||||
self.llm = llm
|
||||
self.graph = graph
|
||||
self._tools: dict[str, dict] = {} # name -> schema
|
||||
self._executors: dict[str, Any] = {} # name -> async callable
|
||||
self._record_call_counts: dict[str, int] = {}
|
||||
self._work_log: list[str] = []
|
||||
self._current_lead_id: str | None = None
|
||||
|
||||
@@ -52,7 +67,18 @@ class BaseAgent:
|
||||
"description": description,
|
||||
"input_schema": input_schema,
|
||||
}
|
||||
self._executors[name] = executor
|
||||
if name in self.mandatory_record_tools:
|
||||
self._executors[name] = self._wrap_record_executor(name, executor)
|
||||
else:
|
||||
self._executors[name] = executor
|
||||
|
||||
def _wrap_record_executor(self, name: str, executor: Any) -> Any:
    """Wrap a mandatory-record executor so completed invocations are counted.

    Args:
        name: Tool name; used as the key in ``self._record_call_counts``.
        executor: Async callable implementing the tool.

    Returns:
        An async callable that forwards all positional and keyword
        arguments to ``executor`` and returns its result unchanged. The
        counter for ``name`` is incremented only after ``executor`` has
        returned, so a call that raises is not counted.
    """
    # Local import keeps this block self-contained regardless of the
    # module's top-level import list.
    import functools

    # functools.wraps copies __name__/__doc__ and sets __wrapped__, so logs
    # and signature introspection still point at the real tool executor.
    @functools.wraps(executor)
    async def wrapped(*args, **kwargs):
        result = await executor(*args, **kwargs)
        counts = self._record_call_counts
        counts[name] = counts.get(name, 0) + 1
        return result

    return wrapped
|
||||
|
||||
def get_tool_definitions(self) -> list[dict]:
|
||||
"""Get tool definitions in Claude API format."""
|
||||
@@ -91,20 +117,19 @@ class BaseAgent:
|
||||
f" FIRST call list_phenomena to get the current IDs — do NOT rely on memory.\n"
|
||||
f" Then call link_to_entity for each relevant phenomenon.\n"
|
||||
f" NEVER guess or fabricate a phenomenon ID. If an ID is not in list_phenomena output, it does not exist.\n\n"
|
||||
f"Phase D — ANSWER:\n"
|
||||
f" Only give your <answer> AFTER completing Phases B and C.\n\n"
|
||||
f"Phase D — STOP:\n"
|
||||
f" Once all phenomena are recorded and entities linked, you are DONE.\n"
|
||||
f" Do not call any more tools. The orchestrator picks up automatically.\n\n"
|
||||
f"CRITICAL — RECORDING REQUIREMENT:\n"
|
||||
f"- Your <answer> block is DISCARDED by the orchestrator. Only graph mutations propagate.\n"
|
||||
f"- Other agents and the final report read ONLY the evidence graph "
|
||||
f"(phenomena, entities, edges).\n"
|
||||
f"- You MUST call add_phenomenon for EVERY significant finding BEFORE you end.\n"
|
||||
f"- Only graph mutations propagate to other agents and the final report.\n"
|
||||
f"- You MUST call add_phenomenon for EVERY significant finding BEFORE you stop.\n"
|
||||
f"- NEGATIVE findings count too. If you searched X (a directory, a pattern, "
|
||||
f"a registry key) and found NOTHING, that absence IS evidence — call "
|
||||
f"add_phenomenon with a 'No matches for X' title and the search scope in "
|
||||
f"raw_data. Negative findings constrain the hypothesis space and prevent "
|
||||
f"the next agent from wasting time re-searching.\n"
|
||||
f"- If you produce <answer> without having called add_phenomenon at least once, "
|
||||
f"the task is FAILED regardless of what you wrote in <answer>.\n"
|
||||
f"- If you stop without having called add_phenomenon at least once, the task "
|
||||
f"is FAILED and a forced retry will fire.\n"
|
||||
f"- Include exact file paths, inode numbers, timestamps, and the source_tool "
|
||||
f"that produced each finding.\n\n"
|
||||
f"ANTI-HALLUCINATION RULES — STRICTLY ENFORCED:\n"
|
||||
@@ -124,6 +149,7 @@ class BaseAgent:
|
||||
self._current_lead_id = lead_id
|
||||
|
||||
self._register_graph_tools()
|
||||
self._record_call_counts.clear()
|
||||
|
||||
system = self._build_system_prompt(task)
|
||||
messages = [{"role": "user", "content": task}]
|
||||
@@ -132,12 +158,60 @@ class BaseAgent:
|
||||
ph_before = len(self.graph.phenomena)
|
||||
|
||||
try:
|
||||
final_text, _ = await self.llm.tool_call_loop(
|
||||
final_text, conversation = await self.llm.tool_call_loop(
|
||||
messages=messages,
|
||||
tools=self.get_tool_definitions(),
|
||||
tool_executor=self._executors,
|
||||
system=system,
|
||||
terminal_tools=self.terminal_tools,
|
||||
)
|
||||
|
||||
# Forced-record retry: if the agent has any mandatory recording
|
||||
# tools but never invoked any of them, force one more round with
|
||||
# an explicit "you forgot to record" instruction. The mandatory
|
||||
# set is declared on the class — Timeline → add_temporal_edge,
|
||||
# Hypothesis → add_hypothesis, Report → save_report. For agents with
|
||||
# empty mandatory_record_tools this branch is a no-op.
|
||||
registered_mandatory = [
|
||||
t for t in self.mandatory_record_tools if t in self._executors
|
||||
]
|
||||
recorded_any = any(
|
||||
self._record_call_counts.get(t, 0) > 0
|
||||
for t in registered_mandatory
|
||||
)
|
||||
if registered_mandatory and not recorded_any:
|
||||
missing = "/".join(registered_mandatory)
|
||||
logger.warning(
|
||||
"[%s] finished without calling any of [%s] — forcing RECORD retry",
|
||||
self.name, missing,
|
||||
)
|
||||
conversation.append({
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"STOP. You produced an answer without ever calling "
|
||||
f"{missing}. Your answer is DISCARDED — only graph "
|
||||
f"mutations propagate to other agents and the final "
|
||||
f"report.\n\n"
|
||||
f"You MUST now call {missing} for every significant "
|
||||
f"finding from your prior investigation, including "
|
||||
f"exact identifiers, timestamps, and the source_tool "
|
||||
f"that produced each finding. If you genuinely found "
|
||||
f"NOTHING noteworthy, call the recording tool ONCE "
|
||||
f"with a 'No significant findings' style entry "
|
||||
f"summarizing what you searched.\n\n"
|
||||
f"Do not run more investigation tools. Just record "
|
||||
f"what you already found. Then end."
|
||||
),
|
||||
})
|
||||
final_text, _ = await self.llm.tool_call_loop(
|
||||
messages=conversation,
|
||||
tools=self.get_tool_definitions(),
|
||||
tool_executor=self._executors,
|
||||
system=system,
|
||||
max_iterations=10,
|
||||
terminal_tools=self.terminal_tools,
|
||||
)
|
||||
|
||||
self._work_log.append(f"[Task: {task[:80]}] -> {final_text[:150]}")
|
||||
except Exception:
|
||||
self.graph.agent_status[self.name] = "failed"
|
||||
|
||||
@@ -197,6 +197,41 @@ class Lead:
|
||||
return cls(**d)
|
||||
|
||||
|
||||
@dataclass
class InvestigationArea:
    """An area to investigate to confirm/refute one or more hypotheses.

    Derived by the orchestrator from active hypotheses after Phase 2; also
    seeded from config.yaml:investigation_areas as an optional manual
    override. Each area carries its own keywords + expected tools so the
    gap-analysis coverage check is generic, not tied to hard-coded constants.
    """

    id: str  # "area-{slug}"
    area: str  # snake_case slug (dedupe key)
    description: str
    suggested_agent: str  # filesystem / registry / communication / network / timeline
    expected_keywords: list[str] = field(default_factory=list)
    expected_tools: list[str] = field(default_factory=list)
    priority: int = 5  # 1 (highest) - 10 (lowest)
    motivating_hypothesis_ids: list[str] = field(default_factory=list)
    created_by: str = ""  # "manual" | "llm_derive" | "fallback"
    created_at: str = ""

    def to_dict(self) -> dict:
        """Return a plain-dict copy of every field (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "InvestigationArea":
        """Build an area from a dict such as the one produced by ``to_dict``.

        Keys that are not declared fields are ignored, so a dict carrying
        extra keys does not raise. Missing required fields still raise
        ``TypeError``.
        """
        # Local import: the block only relies on dataclass/field/asdict
        # being imported at module level.
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in d.items() if key in known})

    def summary(self) -> str:
        """Return a one-line description: slug, priority, agent, hypothesis count."""
        return (
            f"[{self.area}] P{self.priority} agent={self.suggested_agent} "
            f"(motivating: {len(self.motivating_hypothesis_ids)})"
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedAsset:
|
||||
"""A file extracted from the disk image and tracked in the asset library."""
|
||||
@@ -270,6 +305,11 @@ class EvidenceGraph:
|
||||
self.asset_library: dict[str, ExtractedAsset] = {}
|
||||
self._inode_index: dict[str, str] = {} # inode → asset_id
|
||||
|
||||
# Investigation areas — derived from hypotheses (LLM) and/or seeded
|
||||
# from config.yaml:investigation_areas (manual override). Drives the
|
||||
# gap-analysis coverage check.
|
||||
self.investigation_areas: dict[str, InvestigationArea] = {}
|
||||
|
||||
# Set by BaseAgent.run() before each agent execution
|
||||
self._current_agent: str = ""
|
||||
|
||||
@@ -295,6 +335,9 @@ class EvidenceGraph:
|
||||
"leads": [l.to_dict() for l in self.leads],
|
||||
"agent_status": dict(self.agent_status),
|
||||
"asset_library": {aid: a.to_dict() for aid, a in self.asset_library.items()},
|
||||
"investigation_areas": {
|
||||
aid: a.to_dict() for aid, a in self.investigation_areas.items()
|
||||
},
|
||||
"saved_at": datetime.now().isoformat(),
|
||||
}
|
||||
tmp = self._persist_path.with_suffix(".tmp")
|
||||
@@ -345,6 +388,10 @@ class EvidenceGraph:
|
||||
asset = ExtractedAsset.from_dict(a_data)
|
||||
graph.asset_library[aid] = asset
|
||||
graph._inode_index[asset.inode] = aid
|
||||
graph.investigation_areas = {
|
||||
aid: InvestigationArea.from_dict(a)
|
||||
for aid, a in data.get("investigation_areas", {}).items()
|
||||
}
|
||||
graph._rebuild_adjacency()
|
||||
logger.info(
|
||||
"EvidenceGraph restored: %d phenomena, %d hypotheses, %d entities, "
|
||||
@@ -656,6 +703,57 @@ class EvidenceGraph:
|
||||
break
|
||||
self._auto_save()
|
||||
|
||||
# ---- Investigation areas -------------------------------------------------
|
||||
|
||||
async def add_investigation_area(
    self,
    area: str,
    description: str,
    suggested_agent: str,
    expected_keywords: list[str] | None = None,
    expected_tools: list[str] | None = None,
    priority: int = 5,
    motivating_hypothesis_ids: list[str] | None = None,
    created_by: str = "",
) -> tuple[str, bool]:
    """Register an investigation area, merging into an existing one by slug.

    The ``area`` slug is the dedupe key. When an area with the same slug is
    already stored, only its three list fields (keywords, tools, motivating
    hypothesis ids) are extended with entries not yet present; its
    description, suggested_agent and priority are left as the first writer
    set them. Otherwise a new area with id ``area-{slug}`` is created.

    Returns:
        ``(area_id, was_existing)``.
    """
    incoming_keywords = list(expected_keywords or [])
    incoming_tools = list(expected_tools or [])
    incoming_hypotheses = list(motivating_hypothesis_ids or [])

    async with self._lock:
        found = next(
            (
                candidate
                for candidate in self.investigation_areas.values()
                if candidate.area == area
            ),
            None,
        )

        if found is not None:
            # Order-preserving union for each list field.
            merges = (
                (found.expected_keywords, incoming_keywords),
                (found.expected_tools, incoming_tools),
                (found.motivating_hypothesis_ids, incoming_hypotheses),
            )
            for stored, additions in merges:
                for item in additions:
                    if item not in stored:
                        stored.append(item)
            self._auto_save()
            return found.id, True

        new_id = f"area-{area}"
        self.investigation_areas[new_id] = InvestigationArea(
            id=new_id,
            area=area,
            description=description,
            suggested_agent=suggested_agent,
            expected_keywords=incoming_keywords,
            expected_tools=incoming_tools,
            priority=priority,
            motivating_hypothesis_ids=incoming_hypotheses,
            created_by=created_by,
            created_at=datetime.now().isoformat(),
        )
        self._auto_save()
        return new_id, False
|
||||
|
||||
# ---- Asset library -------------------------------------------------------
|
||||
|
||||
async def register_asset(
|
||||
|
||||
532
llm_client.py
532
llm_client.py
@@ -1,9 +1,10 @@
|
||||
"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API).
|
||||
|
||||
Tool calling is text-based (ReAct pattern): tool schemas are embedded in
|
||||
the system prompt and tool calls are parsed as <tool_call> JSON blocks
|
||||
from model output. This keeps the protocol independent of whether the
|
||||
underlying API supports native function/tool calling.
|
||||
Tool calling uses the OpenAI-native `tools=[...]` parameter. The model
|
||||
returns structured tool_calls via the streaming protocol; we accumulate
|
||||
them, dispatch to our executors, and feed results back as `role: "tool"`
|
||||
messages. This eliminates the fragile "model writes JSON inside free
|
||||
text" problem of the previous ReAct text mode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -32,69 +33,81 @@ class LLMAPIError(Exception):
|
||||
self.attempts = attempts
|
||||
|
||||
|
||||
# Markers the model uses to signal tool calls and final answers
|
||||
TOOL_CALL_TAG = "<tool_call>"
|
||||
TOOL_CALL_END = "</tool_call>"
|
||||
TOOL_RESULT_TAG = "<tool_result>"
|
||||
TOOL_RESULT_END = "</tool_result>"
|
||||
# Optional answer tags — kept for backward compat with prompts that wrap
|
||||
# their final response in <answer>...</answer>. Native tool calling does
|
||||
# not need these (no tool_calls = final), but if the model continues to
|
||||
# emit them, we strip the tags so callers see clean text.
|
||||
ANSWER_TAG = "<answer>"
|
||||
ANSWER_END = "</answer>"
|
||||
|
||||
|
||||
def _build_tools_prompt(tools: list[dict]) -> str:
|
||||
"""Format tool definitions for inclusion in the system prompt."""
|
||||
lines = ["You have access to the following tools:\n"]
|
||||
for t in tools:
|
||||
schema = t.get("input_schema", {})
|
||||
props = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
def _to_openai_tools(tools: list[dict]) -> list[dict]:
    """Convert internal tool definitions to OpenAI native function-tools format.

    Args:
        tools: Internal definitions. Each dict must have ``name``; it may
            have ``description`` and ``input_schema`` (a JSON Schema object).

    Returns:
        One ``{"type": "function", "function": {...}}`` dict per input tool,
        in the same order. A missing description becomes ``""``; a missing,
        ``None`` or empty ``input_schema`` becomes an object schema with no
        properties.

    Raises:
        KeyError: if a tool definition has no ``name``.
    """
    return [
        {
            "type": "function",
            "function": {
                "name": tool["name"],
                "description": tool.get("description", ""),
                # `or` also covers an explicit None / {} schema; the literal
                # is rebuilt per tool so no default dict is shared.
                "parameters": tool.get("input_schema")
                or {"type": "object", "properties": {}},
            },
        }
        for tool in tools
    ]
|
||||
|
||||
params = []
|
||||
for pname, pdef in props.items():
|
||||
req = " (required)" if pname in required else ""
|
||||
desc = pdef.get("description", "")
|
||||
ptype = pdef.get("type", "string")
|
||||
enum_vals = pdef.get("enum")
|
||||
if enum_vals:
|
||||
allowed = ", ".join(f'"{v}"' for v in enum_vals)
|
||||
params.append(f" - {pname}: {ptype}{req} — {desc} Allowed values: [{allowed}]")
|
||||
else:
|
||||
params.append(f" - {pname}: {ptype}{req} — {desc}")
|
||||
|
||||
param_block = "\n".join(params) if params else " (no parameters)"
|
||||
lines.append(f"## {t['name']}\n{t['description']}\nParameters:\n{param_block}\n")
|
||||
def _extract_first_balanced(text: str, open_char: str, close_char: str) -> str | None:
|
||||
"""Return the first balanced [...] or {...} substring, or None if no balanced pair.
|
||||
|
||||
lines.append(
|
||||
"## How to use tools\n"
|
||||
"To call a tool, output a JSON block wrapped in XML tags like this:\n"
|
||||
f"{TOOL_CALL_TAG}\n"
|
||||
'{"name": "tool_name", "arguments": {"param1": "value1"}}\n'
|
||||
f"{TOOL_CALL_END}\n\n"
|
||||
"You can call multiple tools in sequence. After each tool call, you will receive the result in:\n"
|
||||
f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n"
|
||||
"When you have finished your analysis and have a final answer, wrap it in:\n"
|
||||
f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n"
|
||||
"Think step by step. Call tools to gather evidence before drawing conclusions.\n"
|
||||
"You MUST call at least one tool before giving your final answer."
|
||||
Stack-based — handles nested brackets correctly (regex with .*? would
|
||||
truncate at the first inner closing bracket, regex with .* would over-eat
|
||||
trailing text). Brackets inside JSON string literals are ignored by
|
||||
callers because the caller passes the result through json.loads which
|
||||
re-parses with proper string handling.
|
||||
"""
|
||||
start = text.find(open_char)
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
for i in range(start, len(text)):
|
||||
c = text[i]
|
||||
if c == open_char:
|
||||
depth += 1
|
||||
elif c == close_char:
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start:i + 1]
|
||||
return None
|
||||
|
||||
|
||||
def _safe_json_loads(text: str):
|
||||
"""Parse JSON with progressive sanitization for LLM-produced output.
|
||||
|
||||
Tries (0) as-is, (1) escape stray backslashes outside valid JSON escapes
|
||||
(\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX). On final failure, logs raw
|
||||
input (first 600 chars) so we can diagnose what the model emitted.
|
||||
|
||||
Used by orchestrator JSON callsites (_call_llm_for_json) and by
|
||||
tool_call_loop when parsing tool_call arguments returned by the API.
|
||||
"""
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
stage1 = re.sub(
|
||||
r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})',
|
||||
r'\\\\',
|
||||
text,
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_tool_calls(text: str) -> list[dict]:
|
||||
"""Extract tool call JSON blocks from model output."""
|
||||
pattern = re.compile(
|
||||
re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END),
|
||||
re.DOTALL,
|
||||
)
|
||||
calls = []
|
||||
for match in pattern.finditer(text):
|
||||
raw = match.group(1).strip()
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
calls.append(parsed)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse tool call JSON: %s", raw[:200])
|
||||
return calls
|
||||
try:
|
||||
return json.loads(stage1)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
"_safe_json_loads failed after sanitize (%s); raw head[:600]=%r",
|
||||
e, text[:600],
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _extract_answer(text: str) -> str | None:
|
||||
@@ -236,50 +249,41 @@ _DECAY_TIERS: list[tuple[int, int]] = [
|
||||
|
||||
|
||||
def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
|
||||
"""Truncate tool results in older messages to save context space.
|
||||
"""Truncate the `content` of older `role: "tool"` messages to save context.
|
||||
|
||||
Operates in-place-style on a copy. Only touches user messages that
|
||||
contain <tool_result> blocks (these are the tool-result messages
|
||||
generated by tool_call_loop).
|
||||
Each `role: "tool"` message in the conversation corresponds to one tool
|
||||
call's result. We rank these messages by recency and progressively
|
||||
truncate older ones according to `_DECAY_TIERS`.
|
||||
"""
|
||||
# Count rounds from the end. A "round" is a (assistant, user) pair.
|
||||
# messages alternate: [user, assistant, user, assistant, user, ...]
|
||||
# The initial user message is index 0, then pairs start at index 1.
|
||||
total = len(messages)
|
||||
if total <= 10: # not enough messages to bother
|
||||
if total <= 10:
|
||||
return messages
|
||||
|
||||
result = []
|
||||
# Count tool-result user messages from the end
|
||||
tool_result_indices = [
|
||||
i for i, m in enumerate(messages)
|
||||
if m["role"] == "user" and TOOL_RESULT_TAG in m.get("content", "")
|
||||
tool_msg_indices = [
|
||||
i for i, m in enumerate(messages) if m.get("role") == "tool"
|
||||
]
|
||||
|
||||
# Build a set of indices that need decay, mapped to their max_chars
|
||||
decay_map: dict[int, int] = {}
|
||||
n_tool_msgs = len(tool_result_indices)
|
||||
for rank, idx in enumerate(reversed(tool_result_indices)):
|
||||
rounds_ago = rank # 0 = most recent, 1 = second most recent, ...
|
||||
for rank, idx in enumerate(reversed(tool_msg_indices)):
|
||||
rounds_ago = rank
|
||||
for threshold, max_chars in _DECAY_TIERS:
|
||||
if rounds_ago < threshold:
|
||||
decay_map[idx] = max_chars
|
||||
break
|
||||
|
||||
result = []
|
||||
for i, msg in enumerate(messages):
|
||||
if i in decay_map:
|
||||
max_chars = decay_map[i]
|
||||
content = msg["content"]
|
||||
content = msg.get("content", "") or ""
|
||||
if len(content) > max_chars + 200:
|
||||
# Truncate but preserve the tool_result tags structure
|
||||
truncated = content[:max_chars]
|
||||
# Count how many tool results are in this message
|
||||
n_results = content.count(TOOL_RESULT_TAG)
|
||||
truncated += (
|
||||
f"\n... [context compressed: {len(content)} -> {max_chars} chars, "
|
||||
f"{n_results} tool result(s)]"
|
||||
truncated = (
|
||||
content[:max_chars]
|
||||
+ f"\n... [context compressed: {len(content)} -> {max_chars} chars]"
|
||||
)
|
||||
result.append({"role": msg["role"], "content": truncated})
|
||||
new_msg = dict(msg)
|
||||
new_msg["content"] = truncated
|
||||
result.append(new_msg)
|
||||
else:
|
||||
result.append(msg)
|
||||
else:
|
||||
@@ -404,6 +408,95 @@ class LLMClient:
|
||||
|
||||
return ""
|
||||
|
||||
async def _chat_with_tools(
    self,
    messages: list[dict],
    openai_tools: list[dict],
    max_retries: int = 5,
) -> tuple[str, str | None, list[dict]]:
    """Stream a chat completion with native tool calling enabled.

    Args:
        messages: Conversation passed as the ``messages`` request field.
        openai_tools: Tool definitions passed as the ``tools`` request field.
        max_retries: Maximum number of request attempts.

    Returns:
        (text_content, reasoning_content, raw_tool_calls).
        - reasoning_content is non-None when DeepSeek thinking mode is
          active; the caller MUST echo it back in the assistant message
          on subsequent requests, or the API returns HTTP 400. It is None
          when no reasoning text was streamed.
        - raw_tool_calls is a list of {"id","name","arguments"} dicts,
          ordered by the stream's tool-call index; arguments is the raw
          JSON string returned by the API.

    Raises:
        LLMAPIError: when the final attempt also fails with one of the
            caught API exceptions.
    """
    kwargs: dict[str, Any] = {
        "model": self.model,
        "messages": messages,
        "max_tokens": self.max_tokens,
        "stream": True,
        "tools": openai_tools,
    }
    # Optional request fields are only sent when configured on the client.
    if self.reasoning_effort:
        kwargs["reasoning_effort"] = self.reasoning_effort
    if self.thinking_enabled:
        kwargs["extra_body"] = {"thinking": {"type": "enabled"}}

    for attempt in range(max_retries):
        logger.debug(
            "LLM request (stream+tools): %d messages, %d tools (attempt %d)",
            len(messages), len(openai_tools), attempt + 1,
        )
        # Accumulators are rebuilt on every attempt so a stream that failed
        # midway cannot leak partial output into the retry.
        text_parts: list[str] = []
        reasoning_parts: list[str] = []
        tool_calls_acc: dict[int, dict] = {}  # index -> {id, name, arguments}
        try:
            stream = await self._client.chat.completions.create(**kwargs)
            async for chunk in stream:
                # Chunks without choices carry nothing to accumulate.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.content:
                    text_parts.append(delta.content)
                # DeepSeek thinking-mode: reasoning_content is returned
                # alongside content and MUST be echoed back on subsequent
                # requests, otherwise the API rejects with HTTP 400.
                rc = getattr(delta, "reasoning_content", None)
                if rc:
                    reasoning_parts.append(rc)
                if delta.tool_calls:
                    for tc_delta in delta.tool_calls:
                        # Fragments of one tool call share an index; id and
                        # name are overwritten when present, argument text
                        # is concatenated in arrival order.
                        idx = tc_delta.index
                        entry = tool_calls_acc.setdefault(
                            idx, {"id": None, "name": None, "arguments": ""},
                        )
                        if tc_delta.id:
                            entry["id"] = tc_delta.id
                        fn = tc_delta.function
                        if fn:
                            if fn.name:
                                entry["name"] = fn.name
                            if fn.arguments:
                                entry["arguments"] += fn.arguments

            text = "".join(text_parts)
            reasoning = "".join(reasoning_parts) or None
            ordered = [tool_calls_acc[i] for i in sorted(tool_calls_acc)]
            logger.debug(
                "LLM response (stream+tools): %d chars, %d reasoning chars, %d tool calls",
                len(text), len(reasoning or ""), len(ordered),
            )
            return text, reasoning, ordered

        # NOTE(review): if these are the OpenAI SDK exception classes,
        # APIError is presumably the base of the other two, so every API
        # error (including non-retryable client errors) is retried — confirm
        # this is intended.
        except (APIConnectionError, APITimeoutError, APIError) as e:
            if attempt < max_retries - 1:
                # Exponential backoff: 10s, 20s, 40s, 80s, ...
                wait = 2 ** attempt * 10
                logger.warning(
                    "Tool-call request failed (%s), retrying in %ds...", e, wait,
                )
                await asyncio.sleep(wait)
            else:
                raise LLMAPIError(
                    f"LLM API unreachable after {max_retries} attempts: {e}",
                    attempts=max_retries,
                ) from e

    # Only reached when the loop body never runs (max_retries <= 0).
    return "", None, []
|
||||
|
||||
async def tool_call_loop(
|
||||
self,
|
||||
messages: list[dict],
|
||||
@@ -411,87 +504,159 @@ class LLMClient:
|
||||
tool_executor: dict[str, Any],
|
||||
system: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
terminal_tools: tuple[str, ...] = (),
|
||||
) -> tuple[str, list[dict]]:
|
||||
"""Run a ReAct-style tool-calling loop.
|
||||
"""Run a tool-calling loop using OpenAI-native tool calls.
|
||||
|
||||
The model outputs <tool_call> blocks which we parse and execute,
|
||||
feeding results back as <tool_result> blocks until the model
|
||||
outputs an <answer> block.
|
||||
The model returns structured `tool_calls` in its message; we
|
||||
dispatch them through our executor dict and feed each result back
|
||||
as a `role: "tool"` message with the matching `tool_call_id`. The
|
||||
loop ends when:
|
||||
- the model returns a message with no tool_calls (normal exit), or
|
||||
- any tool in `terminal_tools` is called — in that case, the loop
|
||||
short-circuits with that tool's result text as final_text. This
|
||||
gives agents (notably ReportAgent) an explicit completion signal
|
||||
that the old `<answer>` text tag used to provide.
|
||||
|
||||
Returns:
|
||||
(final_text, all_messages)
|
||||
(final_text, full_message_history)
|
||||
"""
|
||||
# Build system prompt with tool definitions
|
||||
tools_prompt = _build_tools_prompt(tools)
|
||||
full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt
|
||||
terminal_set = set(terminal_tools)
|
||||
openai_tools = _to_openai_tools(tools)
|
||||
|
||||
messages = list(messages) # don't mutate caller's list
|
||||
_folded = False # Track whether we've already folded once this loop
|
||||
# The caller may pass `messages` either as raw conversation (no system)
|
||||
# together with `system=...`, OR as a complete history that already
|
||||
# starts with the system message (retry path). Accept both shapes.
|
||||
if messages and messages[0].get("role") == "system":
|
||||
full_messages: list[dict] = list(messages)
|
||||
else:
|
||||
full_messages = []
|
||||
if system:
|
||||
full_messages.append({"role": "system", "content": system})
|
||||
full_messages.extend(messages)
|
||||
_folded = False
|
||||
|
||||
for i in range(max_iterations):
|
||||
for _i in range(max_iterations):
|
||||
# ── Context compression before each API call ──────────────
|
||||
# Stage A: progressively decay old tool results
|
||||
messages = _apply_progressive_decay(messages)
|
||||
|
||||
# Stage B: fold oldest messages into LLM summary if too long
|
||||
if not _folded and len(messages) > _FOLD_THRESHOLD:
|
||||
messages = await self._fold_old_messages(messages, full_system)
|
||||
full_messages = _apply_progressive_decay(full_messages)
|
||||
if not _folded and len(full_messages) > _FOLD_THRESHOLD:
|
||||
full_messages = await self._fold_old_messages(full_messages)
|
||||
_folded = True
|
||||
elif _folded and len(messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
|
||||
# Allow a second fold if messages grew back significantly
|
||||
messages = await self._fold_old_messages(messages, full_system)
|
||||
elif _folded and len(full_messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
|
||||
full_messages = await self._fold_old_messages(full_messages)
|
||||
|
||||
text = await self.chat(messages, system=full_system)
|
||||
text, reasoning, raw_tool_calls = await self._chat_with_tools(
|
||||
full_messages, openai_tools,
|
||||
)
|
||||
|
||||
# Check for final answer
|
||||
answer = _extract_answer(text)
|
||||
if answer is not None:
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
return answer, messages
|
||||
if not raw_tool_calls:
|
||||
# Model produced a final response. Strip optional <answer>
|
||||
# tags for backward compatibility with old prompts.
|
||||
final_msg: dict[str, Any] = {"role": "assistant", "content": text}
|
||||
if reasoning:
|
||||
final_msg["reasoning_content"] = reasoning
|
||||
full_messages.append(final_msg)
|
||||
answer = _extract_answer(text)
|
||||
return (answer if answer is not None else text), full_messages
|
||||
|
||||
# Check for tool calls
|
||||
tool_calls = _extract_tool_calls(text)
|
||||
# Parse arguments + build internal call dicts
|
||||
parsed_calls: list[dict] = []
|
||||
for rc in raw_tool_calls:
|
||||
args_str = rc.get("arguments", "") or ""
|
||||
try:
|
||||
args = _safe_json_loads(args_str) if args_str.strip() else {}
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(
|
||||
"Failed to parse arguments for tool %s: %s",
|
||||
rc.get("name"), e,
|
||||
)
|
||||
args = {}
|
||||
parsed_calls.append({
|
||||
"id": rc.get("id"),
|
||||
"name": rc.get("name", ""),
|
||||
"arguments": args,
|
||||
})
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls and no answer tag — treat entire text as answer
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
return text, messages
|
||||
# Append the assistant turn with the raw tool_calls (and the
|
||||
# DeepSeek-mandated reasoning_content echo-back), then execute.
|
||||
asst_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": text or None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": rc.get("id"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": rc.get("name", ""),
|
||||
"arguments": rc.get("arguments", "") or "",
|
||||
},
|
||||
}
|
||||
for rc in raw_tool_calls
|
||||
],
|
||||
}
|
||||
if reasoning:
|
||||
asst_msg["reasoning_content"] = reasoning
|
||||
full_messages.append(asst_msg)
|
||||
|
||||
# Execute tool calls — read-only tools run in parallel
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
|
||||
result_parts = []
|
||||
batches = _partition_tool_calls(tool_calls)
|
||||
batches = _partition_tool_calls(parsed_calls)
|
||||
t_batch_start = time.monotonic()
|
||||
|
||||
# Each entry: (tool_call_dict, raw_result, formatted_for_llm)
|
||||
executed: list[tuple[dict, str, str]] = []
|
||||
for batch in batches:
|
||||
if batch.is_read_only and len(batch.calls) > 1:
|
||||
batch_results = await self._execute_tool_batch_parallel(
|
||||
results = await self._execute_tool_batch_parallel(
|
||||
batch.calls, tool_executor, tools,
|
||||
)
|
||||
result_parts.extend(batch_results)
|
||||
for tc, (raw, formatted) in zip(batch.calls, results):
|
||||
executed.append((tc, raw, formatted))
|
||||
else:
|
||||
for tc in batch.calls:
|
||||
result_parts.append(
|
||||
await self._execute_single_tool(tc, tool_executor, tools)
|
||||
raw, formatted = await self._execute_single_tool(
|
||||
tc, tool_executor, tools,
|
||||
)
|
||||
executed.append((tc, raw, formatted))
|
||||
|
||||
# Emit folded tool-call summary for the terminal
|
||||
t_batch_elapsed = time.monotonic() - t_batch_start
|
||||
_emit_tool_call_summary(tool_calls, t_batch_elapsed)
|
||||
_emit_tool_call_summary(parsed_calls, t_batch_elapsed)
|
||||
|
||||
# Feed results back as a user message
|
||||
result_message = "\n\n".join(result_parts)
|
||||
messages.append({"role": "user", "content": result_message})
|
||||
# Append formatted tool results to the conversation (this is
|
||||
# what the LLM sees on subsequent rounds — truncated for context
|
||||
# economy).
|
||||
for tc, _raw, formatted in executed:
|
||||
full_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tc["id"],
|
||||
"content": formatted,
|
||||
})
|
||||
|
||||
# Terminal-tool short-circuit: if the model called any tool in
|
||||
# `terminal_tools`, end the loop immediately. The terminal tool's
|
||||
# RAW result (untruncated) becomes final_text — the LLM may have
|
||||
# produced a 20K-char report via save_report and we must not
|
||||
# truncate it just because the LLM-facing copy is truncated.
|
||||
if terminal_set:
|
||||
for tc, raw, _formatted in executed:
|
||||
name = tc.get("name", "")
|
||||
if name in terminal_set:
|
||||
logger.info(
|
||||
"Terminal tool %s called — exiting tool_call_loop", name,
|
||||
)
|
||||
return raw, full_messages
|
||||
|
||||
logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
|
||||
return "[Max tool call iterations reached]", messages
|
||||
return "[Max tool call iterations reached]", full_messages
|
||||
|
||||
async def _execute_single_tool(
|
||||
self, tc: dict, tool_executor: dict[str, Any],
|
||||
tools: list[dict] | None = None,
|
||||
) -> str:
|
||||
"""Execute a single tool call and return the formatted result."""
|
||||
) -> tuple[str, str]:
|
||||
"""Execute a single tool call.
|
||||
|
||||
Returns (raw_result, formatted_for_llm). `raw_result` is the
|
||||
unmodified executor return (used by terminal-tool short-circuit as
|
||||
final_text). `formatted_for_llm` is `[tool_name] {truncated}` and
|
||||
is what gets fed back to the model as the tool message content.
|
||||
"""
|
||||
tool_name = tc.get("name", "")
|
||||
tool_args = tc.get("arguments", {})
|
||||
|
||||
@@ -502,72 +667,106 @@ class LLMClient:
|
||||
|
||||
executor = tool_executor.get(tool_name)
|
||||
if executor is None:
|
||||
result_text = f"Error: unknown tool '{tool_name}'"
|
||||
raw = f"Error: unknown tool '{tool_name}'"
|
||||
else:
|
||||
try:
|
||||
result_text = await executor(**tool_args)
|
||||
raw = await executor(**tool_args)
|
||||
except Exception as e:
|
||||
logger.error("Tool %s failed: %s", tool_name, e)
|
||||
result_text = f"Error executing {tool_name}: {e}"
|
||||
raw = f"Error executing {tool_name}: {e}"
|
||||
|
||||
return (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
)
|
||||
formatted = f"[{tool_name}] {_truncate_tool_result(raw)}"
|
||||
return raw, formatted
|
||||
|
||||
async def _execute_tool_batch_parallel(
|
||||
self, calls: list[dict], tool_executor: dict[str, Any],
|
||||
tools: list[dict] | None = None,
|
||||
) -> list[str]:
|
||||
"""Execute multiple read-only tool calls concurrently."""
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Execute multiple read-only tool calls concurrently.
|
||||
|
||||
Returns a list of (raw_result, formatted_for_llm) tuples in the
|
||||
same order as `calls`.
|
||||
"""
|
||||
logger.info("Executing %d read-only tools in parallel", len(calls))
|
||||
|
||||
async def _run_one(tc: dict) -> str:
|
||||
async def _run_one(tc: dict) -> tuple[str, str]:
|
||||
tool_name = tc.get("name", "")
|
||||
tool_args = tc.get("arguments", {})
|
||||
if tools:
|
||||
tool_args = _fix_tool_args(tool_name, tool_args, tools)
|
||||
logger.info("Calling tool (parallel): %s(%s)", tool_name, json.dumps(tool_args, ensure_ascii=False))
|
||||
logger.info(
|
||||
"Calling tool (parallel): %s(%s)",
|
||||
tool_name, json.dumps(tool_args, ensure_ascii=False),
|
||||
)
|
||||
executor = tool_executor.get(tool_name)
|
||||
if executor is None:
|
||||
result_text = f"Error: unknown tool '{tool_name}'"
|
||||
raw = f"Error: unknown tool '{tool_name}'"
|
||||
else:
|
||||
try:
|
||||
result_text = await executor(**tool_args)
|
||||
raw = await executor(**tool_args)
|
||||
except Exception as e:
|
||||
logger.error("Tool %s failed: %s", tool_name, e)
|
||||
result_text = f"Error executing {tool_name}: {e}"
|
||||
return (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
)
|
||||
raw = f"Error executing {tool_name}: {e}"
|
||||
formatted = f"[{tool_name}] {_truncate_tool_result(raw)}"
|
||||
return raw, formatted
|
||||
|
||||
results = await asyncio.gather(*[_run_one(tc) for tc in calls])
|
||||
return list(results)
|
||||
|
||||
async def _fold_old_messages(
|
||||
self, messages: list[dict], system: str,
|
||||
self, messages: list[dict],
|
||||
) -> list[dict]:
|
||||
"""Fold old messages into an LLM-generated summary (Stage B).
|
||||
|
||||
Keeps the most recent _FOLD_KEEP_RECENT messages intact and
|
||||
replaces earlier ones with a single summary message.
|
||||
Preserves the leading system message (if any), keeps the most
|
||||
recent _FOLD_KEEP_RECENT messages intact, and replaces the older
|
||||
middle slice with a single summary user message.
|
||||
"""
|
||||
n_to_fold = len(messages) - _FOLD_KEEP_RECENT
|
||||
# Pin the system message — it must NEVER be summarized away.
|
||||
system_msgs: list[dict] = []
|
||||
body = messages
|
||||
if messages and messages[0].get("role") == "system":
|
||||
system_msgs = [messages[0]]
|
||||
body = messages[1:]
|
||||
|
||||
n_to_fold = len(body) - _FOLD_KEEP_RECENT
|
||||
if n_to_fold <= 2:
|
||||
return messages
|
||||
|
||||
old_messages = messages[:n_to_fold]
|
||||
recent_messages = messages[n_to_fold:]
|
||||
# Pull the fold boundary forward so we never split an assistant turn
|
||||
# from its matching tool results. The API rejects (HTTP 400) any
|
||||
# `role: "tool"` message that does not immediately follow an
|
||||
# `assistant` message with `tool_calls`. We walk the boundary into
|
||||
# `recent_messages` while its head is a `role: "tool"` message, or
|
||||
# while the prior `recent` message is `assistant{tool_calls}` whose
|
||||
# paired tools span the boundary.
|
||||
while n_to_fold < len(body):
|
||||
head = body[n_to_fold]
|
||||
if head.get("role") == "tool":
|
||||
n_to_fold += 1
|
||||
continue
|
||||
break
|
||||
|
||||
if n_to_fold >= len(body):
|
||||
# Everything got folded — nothing recent to keep.
|
||||
return system_msgs + [body[0]] if system_msgs else messages
|
||||
|
||||
old_messages = body[:n_to_fold]
|
||||
recent_messages = body[n_to_fold:]
|
||||
|
||||
# Build a text dump of old messages for summarization
|
||||
old_text_parts = []
|
||||
for msg in old_messages:
|
||||
role = msg["role"]
|
||||
content = msg.get("content", "")
|
||||
# Truncate each message for the summary prompt to avoid overload
|
||||
role = msg.get("role", "?")
|
||||
content = msg.get("content") or ""
|
||||
# Render tool_calls (assistant turn) compactly.
|
||||
if role == "assistant" and msg.get("tool_calls"):
|
||||
tc_names = [
|
||||
tc.get("function", {}).get("name", "?")
|
||||
for tc in msg["tool_calls"]
|
||||
]
|
||||
content = (content + " " if content else "") + (
|
||||
"called: " + ", ".join(tc_names)
|
||||
)
|
||||
if len(content) > 1000:
|
||||
content = content[:1000] + "..."
|
||||
old_text_parts.append(f"[{role}]: {content}")
|
||||
@@ -591,7 +790,6 @@ class LLMClient:
|
||||
logger.warning("Context folding failed: %s — keeping original messages", e)
|
||||
return messages
|
||||
|
||||
# Replace old messages with a single summary
|
||||
summary_message = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
@@ -599,4 +797,4 @@ class LLMClient:
|
||||
f"messages in this conversation]\n\n{summary}"
|
||||
),
|
||||
}
|
||||
return [summary_message] + recent_messages
|
||||
return system_msgs + [summary_message] + recent_messages
|
||||
|
||||
295
orchestrator.py
295
orchestrator.py
@@ -12,41 +12,12 @@ from pathlib import Path
|
||||
|
||||
from agent_factory import AgentFactory
|
||||
from evidence_graph import EvidenceGraph
|
||||
from llm_client import LLMClient
|
||||
from llm_client import LLMClient, _extract_first_balanced, _safe_json_loads
|
||||
from tool_registry import TOOL_CATALOG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_json_loads(text: str):
|
||||
"""Parse JSON with progressive sanitization for LLM-produced output.
|
||||
|
||||
Tries: (0) as-is, (1) escape stray backslashes outside valid JSON
|
||||
escapes (\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX). On final failure,
|
||||
logs raw input (first 600 chars) so we can diagnose what the model
|
||||
actually emitted.
|
||||
"""
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Escape backslashes not followed by a valid JSON escape character.
|
||||
# NOTE: \\u must be followed by exactly 4 hex digits to be valid.
|
||||
stage1 = re.sub(
|
||||
r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})',
|
||||
r'\\\\',
|
||||
text,
|
||||
)
|
||||
try:
|
||||
return json.loads(stage1)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
"_safe_json_loads failed after sanitize (%s); raw head[:600]=%r",
|
||||
e, text[:600],
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _log(msg: str, **extra) -> None:
|
||||
"""Emit a structured log message with extra fields for the terminal formatter."""
|
||||
logger.info(msg, extra=extra)
|
||||
@@ -123,6 +94,14 @@ class Orchestrator:
|
||||
"Omit phenomena that are unrelated. Be conservative — only link genuinely relevant evidence."
|
||||
)
|
||||
|
||||
_AREA_DERIVE_SYSTEM = (
|
||||
"You are a forensic investigation strategist. Given a set of hypotheses, "
|
||||
"decompose them into a minimal aggregate set of investigation areas. An "
|
||||
"area is a focused, concrete question with the keywords and tool names an "
|
||||
"answering phenomenon would mention. Aggregate aggressively — when two "
|
||||
"hypotheses share an area, emit it once and list both hypothesis_ids."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: LLMClient,
|
||||
@@ -246,6 +225,125 @@ class Orchestrator:
|
||||
"The ultimate goal is to reconstruct a detailed timeline of what happened on this host."
|
||||
)
|
||||
|
||||
# ---- Investigation areas (manual seed + LLM derive) ----------------------
|
||||
|
||||
_VALID_AGENT_TYPES = {"filesystem", "registry", "communication", "network", "timeline"}
|
||||
|
||||
async def _seed_manual_investigation_areas(self) -> None:
    """Import config.yaml:investigation_areas entries (manual override).

    Each usable entry is written to the graph with created_by="manual".
    Intended to run early in Phase 2 so manual entries are in the graph
    before LLM derivation is attempted.

    Entries that are not mappings, or that have no ``area`` slug, are
    skipped. Missing or null optional fields fall back to defaults
    (agent "filesystem", priority 3, empty keyword/tool lists).

    NOTE(review): the README describes LLM derivation augmenting manual
    entries, while `_derive_investigation_areas` returns early as soon as
    the graph holds any area — confirm which behaviour is intended.
    """
    # `investigation_areas:` with every child commented out loads as None
    # rather than as a missing key, so `.get(key, [])` alone is not enough.
    entries = self.config.get("investigation_areas") or []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        area = entry.get("area")
        if not area:
            continue
        priority = entry.get("priority")
        await self.graph.add_investigation_area(
            area=area,
            # Older configs used `task` for the description text.
            description=entry.get("description") or entry.get("task") or "",
            suggested_agent=entry.get("agent") or "filesystem",
            expected_keywords=entry.get("keywords") or [],
            expected_tools=entry.get("tools") or [],
            priority=3 if priority is None else priority,
            created_by="manual",
        )
|
||||
|
||||
async def _derive_investigation_areas(self) -> None:
    """Ask LLM to derive investigation areas from active hypotheses.

    Manual-seeded or already-populated graph (resume) → no-op. On LLM
    failure or empty output, falls back to one-area-per-hypothesis.

    The LLM reply is untrusted: malformed items are skipped and malformed
    fields are replaced by defaults, so one bad entry cannot abort the
    phase or poison the coverage check.
    """
    if self.graph.investigation_areas:
        return
    active = [h for h in self.graph.hypotheses.values() if h.status == "active"]
    if not active:
        return

    available_tools = sorted(TOOL_CATALOG.keys())
    hyp_lines = "\n".join(
        f"  [{h.id}] {h.title}: {h.description}" for h in active
    )
    prompt = (
        f"Active hypotheses:\n{hyp_lines}\n\n"
        f"Available agents: {sorted(self._VALID_AGENT_TYPES)}\n"
        f"Available tool names (pick 1-3 per area for expected_tools): {available_tools}\n\n"
        f"Emit 5-12 distinct investigation areas covering the FULL hypothesis set.\n"
        f"Each area must include:\n"
        f"  - area: snake_case slug (dedupe key)\n"
        f"  - description: one sentence on what to find\n"
        f"  - suggested_agent: one of the agents above\n"
        f"  - expected_keywords: 3-8 lowercase tokens that an answering phenomenon would mention\n"
        f"  - expected_tools: 1-3 tool names from the list above\n"
        f"  - priority: 1 (highest) to 10\n"
        f"  - motivating_hypothesis_ids: at least one [hyp-xxx] from above\n\n"
        f"Aggregate aggressively — when two hypotheses share an area, emit it ONCE "
        f"and list both ids in motivating_hypothesis_ids.\n\n"
        f"Respond ONLY with JSON:\n"
        f'[{{"area":"...","description":"...","suggested_agent":"...",'
        f'"expected_keywords":[...],"expected_tools":[...],"priority":1-10,'
        f'"motivating_hypothesis_ids":["hyp-xxx"]}}]'
    )

    try:
        items = await self._call_llm_for_json(
            system=self._AREA_DERIVE_SYSTEM,
            user_prompt=prompt,
            schema="array",
        )
    except Exception as e:
        logger.warning("Area derivation LLM failed: %s — falling back", e)
        await self._derive_areas_fallback(active)
        return

    if not isinstance(items, list):
        logger.warning(
            "Area derivation returned %s instead of a JSON array — falling back",
            type(items).__name__,
        )
        items = []

    def _listed(value) -> list:
        # A bare string would otherwise be iterated character by character,
        # yielding one-letter keywords that match any evidence text.
        return value if isinstance(value, list) else []

    valid_hyp_ids = set(self.graph.hypotheses.keys())
    for it in items:
        if not isinstance(it, dict):
            continue
        area = str(it.get("area") or "").strip()
        if not area:
            continue

        agent = it.get("suggested_agent")
        if not isinstance(agent, str) or not agent:
            agent = "filesystem"
        if agent not in self._VALID_AGENT_TYPES:
            agent = AGENT_ALIASES.get(agent, "filesystem")

        tools = [
            t for t in _listed(it.get("expected_tools"))
            if isinstance(t, str) and t in TOOL_CATALOG
        ]
        motivating = [
            h for h in _listed(it.get("motivating_hypothesis_ids"))
            if isinstance(h, str) and h in valid_hyp_ids
        ]
        # Blank keywords are dropped: an empty string is a substring of any
        # text and would mark the area as covered unconditionally.
        keywords = [
            kw_text
            for kw in _listed(it.get("expected_keywords"))
            if (kw_text := str(kw).strip().lower())
        ]

        try:
            priority = int(it.get("priority", 5))
        except (TypeError, ValueError, OverflowError):
            # e.g. "high", null or Infinity — use the neutral default.
            priority = 5
        priority = max(1, min(10, priority))

        description = it.get("description")
        await self.graph.add_investigation_area(
            area=area,
            description=description if isinstance(description, str) else "",
            suggested_agent=agent,
            expected_keywords=keywords,
            expected_tools=tools,
            priority=priority,
            motivating_hypothesis_ids=motivating,
            created_by="llm_derive",
        )

    if not self.graph.investigation_areas:
        await self._derive_areas_fallback(active)
|
||||
|
||||
async def _derive_areas_fallback(self, active: list) -> None:
    """One area per active hypothesis as a minimal safety net.

    Args:
        active: Hypotheses to cover; each must expose ``id`` and ``title``.

    The slug is derived from the title (ASCII ``[a-z0-9_]``, at most 40
    characters). When that yields nothing, or repeats a slug already used
    in this call, the hypothesis id is used instead so two hypotheses are
    never collapsed into one area.

    Keywords are up to six distinct title tokens with surrounding
    punctuation removed. Tokens shorter than three characters and common
    English stopwords are dropped, because coverage is decided by substring
    match and tokens like "a" or "the" would match any evidence text.
    """
    stopwords = frozenset({
        "the", "and", "for", "with", "from", "that", "this", "was", "were",
        "has", "have", "had", "not", "are", "its", "into", "via", "any",
        "all", "but",
    })
    seen_slugs: set[str] = set()
    for h in active:
        title = h.title.lower()
        slug = re.sub(r"[^a-z0-9_]+", "_", title)[:40].strip("_")
        if not slug or slug in seen_slugs:
            slug = h.id.replace("-", "_")
        seen_slugs.add(slug)

        keywords: list[str] = []
        for raw in title.split():
            token = re.sub(r"^\W+|\W+$", "", raw)
            if len(token) < 3 or token in stopwords or token in keywords:
                continue
            keywords.append(token)
            if len(keywords) == 6:
                break

        await self.graph.add_investigation_area(
            area=slug,
            description=h.title,
            suggested_agent="filesystem",
            expected_keywords=keywords,
            expected_tools=[],
            priority=5,
            motivating_hypothesis_ids=[h.id],
            created_by="fallback",
        )
|
||||
|
||||
# ---- LLM JSON helper -----------------------------------------------------
|
||||
|
||||
async def _call_llm_for_json(
|
||||
@@ -264,13 +362,12 @@ class Orchestrator:
|
||||
"""
|
||||
error_hint = ""
|
||||
last_err: Exception | None = None
|
||||
pattern = r'\[.*?\]' if schema == "array" else r'\{.*?\}'
|
||||
open_c, close_c = ('[', ']') if schema == "array" else ('{', '}')
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
messages = [{"role": "user", "content": user_prompt + error_hint}]
|
||||
response = await self.llm.chat(messages=messages, system=system)
|
||||
match = re.search(pattern, response, re.DOTALL)
|
||||
candidate = match.group() if match else response
|
||||
candidate = _extract_first_balanced(response, open_c, close_c) or response
|
||||
try:
|
||||
return _safe_json_loads(candidate)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
@@ -308,10 +405,19 @@ class Orchestrator:
|
||||
existing = "\n".join(
|
||||
f" - {r['node']} [{r['edge_type']}]" for r in related
|
||||
) or " (none yet)"
|
||||
related_areas = [
|
||||
a.area for a in self.graph.investigation_areas.values()
|
||||
if hyp.id in a.motivating_hypothesis_ids
|
||||
]
|
||||
expected_line = (
|
||||
f" Expected areas to investigate: {', '.join(related_areas)}\n"
|
||||
if related_areas else ""
|
||||
)
|
||||
hyp_blocks.append(
|
||||
f"Hypothesis [{hyp.id}]: {hyp.title}\n"
|
||||
f" Description: {hyp.description}\n"
|
||||
f" Current confidence: {hyp.confidence:.2f}\n"
|
||||
f"{expected_line}"
|
||||
f" Existing evidence:\n{existing}"
|
||||
)
|
||||
|
||||
@@ -358,12 +464,21 @@ class Orchestrator:
|
||||
existing_evidence = "\n".join(
|
||||
f" - {r['node']} [{r['edge_type']}]" for r in related
|
||||
) or " (none yet)"
|
||||
related_areas = [
|
||||
a.area for a in self.graph.investigation_areas.values()
|
||||
if hyp.id in a.motivating_hypothesis_ids
|
||||
]
|
||||
expected_line = (
|
||||
f"Expected areas to investigate: {', '.join(related_areas)}\n\n"
|
||||
if related_areas else ""
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"Hypothesis: {hyp.title}\n"
|
||||
f"Description: {hyp.description}\n"
|
||||
f"Current confidence: {hyp.confidence:.2f}\n\n"
|
||||
f"Existing evidence linked to this hypothesis:\n{existing_evidence}\n\n"
|
||||
f"{expected_line}"
|
||||
f"What additional evidence should we look for to CONFIRM or DENY this hypothesis?\n"
|
||||
f"List 1-3 specific, actionable investigation tasks.\n"
|
||||
f"For each, specify which agent type should handle it: "
|
||||
@@ -494,74 +609,57 @@ class Orchestrator:
|
||||
|
||||
# ---- Gap analysis (coverage check) ---------------------------------------
|
||||
|
||||
_AREA_KEYWORDS: dict[str, list[str]] = {
|
||||
"system_info": ["install date", "registered owner", "product name", "windows xp", "system information"],
|
||||
"user_accounts": ["user account", "enumerate", "sam hive", "administrator", "mr. evil"],
|
||||
"shutdown_time": ["shutdown"],
|
||||
"network_config": ["network interface", "network adapter", "ip address", "dhcp", "mac address", "network config"],
|
||||
"installed_software": ["installed software", "program files", "installed program"],
|
||||
"email_config": ["smtp", "pop3", "nntp", "email account", "email config"],
|
||||
"chat_logs": ["irc", "mirc", "chat log", "channel"],
|
||||
"network_activity": ["packet capture", "pcap", "interception", "http request", "user-agent"],
|
||||
"deleted_files": ["deleted file", "recycle", "recycler"],
|
||||
"execution_evidence": ["prefetch", "execution", "run count", "last execution"],
|
||||
}
|
||||
def _check_coverage(self) -> set[str]:
    """Return slugs of investigation_areas already covered by phenomena.

    Layer A: any expected_keyword found in evidence text (category +
    title + description, lowercased). Matching is by substring; blank or
    non-string keywords are ignored.
    Layer B: any expected_tool present in the source_tool set of recorded
    phenomena (deterministic — the canonical tool was actually called).
    """
    phenomena = list(self.graph.phenomena.values())
    evidence_text = " ".join(
        f"{ph.category} {ph.title} {ph.description}".lower()
        for ph in phenomena
    )
    used_tools: set[str] = {ph.source_tool for ph in phenomena if ph.source_tool}

    covered: set[str] = set()
    for a in self.graph.investigation_areas.values():
        # An empty string is a substring of every string (even of ""), so a
        # blank keyword would mark the area covered with no evidence at all.
        keywords = [
            kw.lower()
            for kw in a.expected_keywords
            if isinstance(kw, str) and kw.strip()
        ]
        keyword_hit = any(kw in evidence_text for kw in keywords)
        if keyword_hit or any(t in used_tools for t in a.expected_tools):
            covered.add(a.area)
    return covered
|
||||
|
||||
async def _run_gap_analysis(self) -> None:
|
||||
areas = self.config.get("investigation_areas", [])
|
||||
areas = list(self.graph.investigation_areas.values())
|
||||
if not areas:
|
||||
return
|
||||
|
||||
covered = self._check_coverage(areas)
|
||||
uncovered = [a for a in areas if a["area"] not in covered]
|
||||
covered = self._check_coverage()
|
||||
uncovered = [a for a in areas if a.area not in covered]
|
||||
|
||||
if not uncovered:
|
||||
_log(f"All {len(areas)} investigation areas covered", event="progress")
|
||||
return
|
||||
|
||||
uncovered_names = ", ".join(a["area"] for a in uncovered)
|
||||
_log(f"{len(uncovered)}/{len(areas)} areas uncovered: {uncovered_names}", event="dispatch")
|
||||
for area in uncovered:
|
||||
uncovered_names = ", ".join(a.area for a in uncovered)
|
||||
_log(
|
||||
f"{len(uncovered)}/{len(areas)} areas uncovered: {uncovered_names}",
|
||||
event="dispatch",
|
||||
)
|
||||
for a in uncovered:
|
||||
await self.graph.add_lead(
|
||||
target_agent=area["agent"],
|
||||
description=area["task"],
|
||||
priority=3,
|
||||
target_agent=a.suggested_agent,
|
||||
description=a.description,
|
||||
priority=a.priority,
|
||||
hypothesis_id=(
|
||||
a.motivating_hypothesis_ids[0]
|
||||
if a.motivating_hypothesis_ids else None
|
||||
),
|
||||
)
|
||||
|
||||
for round_num in range(3):
|
||||
@@ -608,6 +706,15 @@ class Orchestrator:
|
||||
json.dumps(leads_data, ensure_ascii=False, indent=2)
|
||||
)
|
||||
|
||||
# Investigation areas export
|
||||
areas_data = {
|
||||
aid: a.to_dict()
|
||||
for aid, a in self.graph.investigation_areas.items()
|
||||
}
|
||||
(self.run_dir / "investigation_areas.json").write_text(
|
||||
json.dumps(areas_data, ensure_ascii=False, indent=2)
|
||||
)
|
||||
|
||||
# Run metadata
|
||||
end_time = datetime.now()
|
||||
metadata = {
|
||||
@@ -667,6 +774,11 @@ class Orchestrator:
|
||||
if resume_phase <= 2:
|
||||
_log("Phase 2: Hypothesis Generation", event="phase")
|
||||
t0 = time.monotonic()
|
||||
|
||||
# Seed manual investigation areas (if any) BEFORE LLM derive,
|
||||
# so manual entries win the dedupe and LLM only augments.
|
||||
await self._seed_manual_investigation_areas()
|
||||
|
||||
manual_hypotheses = self.config.get("hypotheses", [])
|
||||
if manual_hypotheses:
|
||||
await self._generate_hypotheses_manual(manual_hypotheses)
|
||||
@@ -677,10 +789,17 @@ class Orchestrator:
|
||||
if self.graph.phenomena and self.graph.hypotheses:
|
||||
await self._judge_new_phenomena()
|
||||
|
||||
# Derive investigation areas from active hypotheses.
|
||||
# No-op if manual seed already populated or resume restored areas.
|
||||
await self._derive_investigation_areas()
|
||||
|
||||
for h in self.graph.hypotheses.values():
|
||||
_log(f" {h.summary()}", event="hypothesis")
|
||||
for a in self.graph.investigation_areas.values():
|
||||
_log(f" {a.summary()}", event="area")
|
||||
_log(
|
||||
f"+{len(self.graph.hypotheses)} hypotheses generated",
|
||||
f"+{len(self.graph.hypotheses)} hypotheses, "
|
||||
f"{len(self.graph.investigation_areas)} areas",
|
||||
event="progress", elapsed=time.monotonic() - t0,
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from evidence_graph import (
|
||||
from llm_client import (
|
||||
_truncate_tool_result, _partition_tool_calls, _ToolBatch, READ_ONLY_TOOLS,
|
||||
_apply_progressive_decay, _FOLD_THRESHOLD, _FOLD_KEEP_RECENT,
|
||||
TOOL_RESULT_TAG, TOOL_RESULT_END,
|
||||
)
|
||||
from tool_registry import (
|
||||
_tool_result_cache, _cache_key, _make_cached, CACHEABLE_TOOLS,
|
||||
@@ -598,7 +597,8 @@ class TestParallelToolExecution:
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert len(results) == 3
|
||||
assert all("ok" in r for r in results)
|
||||
# results are (raw, formatted) tuples; both contain "ok"
|
||||
assert all("ok" in raw and "ok" in formatted for raw, formatted in results)
|
||||
# 3 tasks × 50ms each should take ~50ms parallel, not ~150ms serial
|
||||
assert elapsed < 0.12, f"Expected parallel execution but took {elapsed:.3f}s"
|
||||
|
||||
@@ -627,8 +627,10 @@ class TestParallelToolExecution:
|
||||
|
||||
results = await client._execute_tool_batch_parallel(calls, tool_executor)
|
||||
assert len(results) == 2
|
||||
assert "success" in results[0]
|
||||
assert "Error" in results[1]
|
||||
raw0, formatted0 = results[0]
|
||||
raw1, formatted1 = results[1]
|
||||
assert "success" in raw0 and "success" in formatted0
|
||||
assert "Error" in raw1 and "Error" in formatted1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -843,21 +845,27 @@ class TestToolResultCache:
|
||||
|
||||
class TestProgressiveDecay:
|
||||
def _make_messages(self, n_rounds: int) -> list[dict]:
|
||||
"""Build a synthetic message list with n_rounds of (assistant, user) pairs."""
|
||||
"""Build a synthetic message list shaped like native tool calling:
|
||||
user → (assistant w/ tool_calls → tool result)+
|
||||
"""
|
||||
messages = [{"role": "user", "content": "Start task"}]
|
||||
for i in range(n_rounds):
|
||||
tc_id = f"call_{i}"
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>{{'name': 'tool_{i}'}}</tool_call>",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc_id,
|
||||
"type": "function",
|
||||
"function": {"name": f"tool_{i}", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
})
|
||||
# Tool result message with substantial content
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[tool_{i}] {'x' * 2500}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
),
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_id,
|
||||
"content": f"[tool_{i}] {'x' * 2500}",
|
||||
})
|
||||
return messages
|
||||
|
||||
@@ -865,21 +873,16 @@ class TestProgressiveDecay:
|
||||
msgs = self._make_messages(3)
|
||||
result = _apply_progressive_decay(msgs)
|
||||
assert len(result) == len(msgs)
|
||||
# Content should be identical for short conversations
|
||||
for orig, decayed in zip(msgs, result):
|
||||
assert orig["content"] == decayed["content"]
|
||||
assert orig.get("content") == decayed.get("content")
|
||||
|
||||
def test_old_messages_truncated(self):
    """Progressive decay shortens old tool results but keeps recent ones."""
    msgs = self._make_messages(20)
    result = _apply_progressive_decay(msgs)

    # Tool results are identified by role; the fixture emits them as
    # native ``role: "tool"`` messages.
    tool_msgs = [m for m in result if m.get("role") == "tool"]

    # Recent tool results should be full length
    assert len(tool_msgs[-1]["content"]) > 2000
    # Oldest tool results should be truncated
    assert len(tool_msgs[0]["content"]) < 500
|
||||
|
||||
def test_message_count_preserved(self):
|
||||
msgs = self._make_messages(20)
|
||||
@@ -915,7 +918,7 @@ class TestMessageFolding:
|
||||
messages.append({"role": "assistant", "content": f"thinking step {i}"})
|
||||
messages.append({"role": "user", "content": f"tool result {i}: {'data ' * 50}"})
|
||||
|
||||
result = await client._fold_old_messages(messages, "system prompt")
|
||||
result = await client._fold_old_messages(messages)
|
||||
|
||||
# Should be significantly shorter
|
||||
assert len(result) < len(messages)
|
||||
@@ -937,7 +940,7 @@ class TestMessageFolding:
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
|
||||
result = await client._fold_old_messages(messages, "system")
|
||||
result = await client._fold_old_messages(messages)
|
||||
# Should return original (n_to_fold = 2 - 10 = negative, so no folding)
|
||||
assert result == messages
|
||||
client.chat.assert_not_called()
|
||||
@@ -951,7 +954,555 @@ class TestMessageFolding:
|
||||
client.chat = AsyncMock(side_effect=Exception("API error"))
|
||||
|
||||
messages = [{"role": "user", "content": f"msg {i}"} for i in range(40)]
|
||||
result = await client._fold_old_messages(messages, "system")
|
||||
result = await client._fold_old_messages(messages)
|
||||
|
||||
# On failure, should return original messages unchanged
|
||||
assert len(result) == 40
|
||||
|
||||
@pytest.mark.asyncio
async def test_fold_boundary_never_orphans_tool_message(self):
    """Folding must never leave a `role: "tool"` message without its caller.

    If the natural fold boundary would put a tool message at the head of
    the kept (recent) messages, the boundary has to move forward until
    the head is a non-tool message. The API rejects orphan tool messages
    with HTTP 400.
    """
    from llm_client import LLMClient
    from unittest.mock import AsyncMock

    client = LLMClient.__new__(LLMClient)
    client.chat = AsyncMock(return_value="summary")

    def _tool_round(idx: int) -> list[dict]:
        """Return one (assistant{tool_calls}, tool) message pair."""
        call_id = f"call_{idx}"
        assistant_msg = {
            "role": "assistant", "content": None,
            "tool_calls": [{
                "id": call_id, "type": "function",
                "function": {"name": f"t_{idx}", "arguments": "{}"},
            }],
        }
        tool_msg = {"role": "tool", "tool_call_id": call_id, "content": "ok"}
        return [assistant_msg, tool_msg]

    # A long conversation made only of assistant/tool pairs, so a fold
    # boundary can land between an assistant and its paired tool result.
    conversation: list[dict] = [{"role": "user", "content": "task"}]
    for idx in range(30):
        conversation.extend(_tool_round(idx))

    folded = await client._fold_old_messages(conversation)

    # Every tool message must be immediately preceded by an assistant
    # message that carries tool_calls.
    for pos, msg in enumerate(folded):
        if msg.get("role") != "tool":
            continue
        assert pos > 0, "tool message cannot be first"
        prev = folded[pos - 1]
        assert prev.get("role") == "assistant" and prev.get("tool_calls"), (
            f"tool at index {pos} preceded by {prev.get('role')} "
            f"(tool_calls={bool(prev.get('tool_calls'))})"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Investigation areas: dataclass + derivation + coverage + dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInvestigationAreaDerivation:
    """Tests for investigation areas: dedupe, coverage, derivation, dispatch.

    The class also hosts the BaseAgent forced-record retry tests and the
    LLMClient.tool_call_loop terminal-tool tests.
    """

    @pytest.fixture
    def graph(self):
        """Provide a fresh, empty EvidenceGraph for each test."""
        return EvidenceGraph()

    # ---- shared helpers ---------------------------------------------------

    @staticmethod
    def _make_orchestrator(graph, llm=None):
        """Build an Orchestrator over *graph* backed by mocked LLM clients.

        When *llm* is given it is shared by the orchestrator and its
        AgentFactory so the test can inspect calls on it; otherwise each
        gets its own throwaway AsyncMock.
        """
        from unittest.mock import AsyncMock
        from orchestrator import Orchestrator
        from agent_factory import AgentFactory

        if llm is None:
            return Orchestrator(AsyncMock(), graph, AgentFactory(AsyncMock(), graph))
        return Orchestrator(llm, graph, AgentFactory(llm, graph))

    @staticmethod
    def _bare_llm_client():
        """Return an LLMClient created without running ``__init__``.

        Only the attributes the terminal-tool tests set are populated.
        """
        from llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)
        client.max_tokens = 4096
        client.reasoning_effort = None
        client.thinking_enabled = False
        client.model = "test"
        client._client = None
        return client

    @staticmethod
    def _is_retry_round(messages) -> bool:
        """Return True if any dict message contains the "STOP." marker.

        The fake tool loops use this marker to tell the first round from
        the forced-record retry round. ``content`` may be ``None`` (e.g. an
        assistant message that only carries tool calls), so it is
        normalised to an empty string before the membership test.
        """
        return any(
            isinstance(m, dict) and "STOP." in (m.get("content") or "")
            for m in messages
        )

    @staticmethod
    def _agent_with_executor(graph, llm, tool_name: str, real_executor):
        """Build a BaseAgent that registers tool_name via the real register_tool
        path so the mandatory-record wrapper is engaged."""
        from base_agent import BaseAgent

        agent = BaseAgent(llm, graph)
        agent.name = "test_agent"
        # Bypass _register_graph_tools side-effects in run() — we register
        # only what the test needs.
        agent._register_graph_tools = lambda: None
        agent.register_tool(
            name=tool_name, description="", input_schema={}, executor=real_executor,
        )
        return agent

    # ---- investigation areas: graph-level behaviour -----------------------

    @pytest.mark.asyncio
    async def test_add_investigation_area_dedupes_and_merges_lists(self, graph):
        """Re-adding an area keeps the first scalars and unions the lists."""
        aid1, existed1 = await graph.add_investigation_area(
            area="password_hashes", description="SAM dump",
            suggested_agent="filesystem",
            expected_keywords=["sam", "pwdump"],
            expected_tools=["search_strings"],
            motivating_hypothesis_ids=["hyp-a"],
            created_by="llm_derive",
        )
        aid2, existed2 = await graph.add_investigation_area(
            area="password_hashes", description="Other description",
            suggested_agent="registry",
            expected_keywords=["sam", "hashdump"],  # one new
            expected_tools=["search_strings", "parse_registry_key"],  # one new
            motivating_hypothesis_ids=["hyp-b"],  # new
            created_by="manual",
        )
        assert not existed1
        assert existed2
        assert aid1 == aid2
        a = graph.investigation_areas[aid1]
        # Description and suggested_agent NOT overwritten (first-write wins)
        assert a.description == "SAM dump"
        assert a.suggested_agent == "filesystem"
        # Three list fields are unioned
        assert set(a.expected_keywords) == {"sam", "pwdump", "hashdump"}
        assert set(a.expected_tools) == {"search_strings", "parse_registry_key"}
        assert set(a.motivating_hypothesis_ids) == {"hyp-a", "hyp-b"}

    @pytest.mark.asyncio
    async def test_check_coverage_keyword_layer(self, graph):
        """An area counts as covered when a phenomenon matches a keyword."""
        await graph.add_phenomenon(
            "fs", "filesystem", "Cain SAM dump artifact",
            "Found sam.lst in the Cain folder",
            source_tool="list_directory",
        )
        await graph.add_investigation_area(
            area="password_hashes", description="SAM dump",
            suggested_agent="filesystem",
            expected_keywords=["sam.lst", "pwdump"],
            expected_tools=["nonexistent_tool"],
        )
        orch = self._make_orchestrator(graph)
        covered = orch._check_coverage()
        assert "password_hashes" in covered

    @pytest.mark.asyncio
    async def test_check_coverage_tool_layer(self, graph):
        """An area counts as covered when one of its expected tools was used."""
        await graph.add_phenomenon(
            "reg", "registry", "User accounts",
            "Found accounts",
            source_tool="enumerate_users",
        )
        await graph.add_investigation_area(
            area="user_accounts", description="Enum users",
            suggested_agent="registry",
            expected_keywords=["irrelevant"],
            expected_tools=["enumerate_users"],
        )
        orch = self._make_orchestrator(graph)
        covered = orch._check_coverage()
        assert "user_accounts" in covered

    @pytest.mark.asyncio
    async def test_load_state_round_trip_areas(self, graph, tmp_path):
        """Investigation areas survive a save_state / load_state round trip."""
        await graph.add_investigation_area(
            area="x", description="d", suggested_agent="filesystem",
            expected_keywords=["k1"], expected_tools=["t1"],
            priority=2, motivating_hypothesis_ids=["hyp-a"],
            created_by="manual",
        )
        path = tmp_path / "state.json"
        graph.save_state(path)
        g2 = EvidenceGraph.load_state(path)
        assert len(g2.investigation_areas) == 1
        a = list(g2.investigation_areas.values())[0]
        assert a.area == "x"
        assert a.expected_keywords == ["k1"]
        assert a.priority == 2
        assert a.motivating_hypothesis_ids == ["hyp-a"]
        assert a.created_by == "manual"

    # ---- investigation areas: LLM derivation ------------------------------

    @pytest.mark.asyncio
    async def test_derive_no_op_when_areas_already_populated(self, graph):
        """Resume safety: if areas are already in the graph (manual seed or
        restored from disk), _derive_investigation_areas does nothing."""
        from unittest.mock import AsyncMock

        await graph.add_hypothesis("test", "desc", created_by="t")
        await graph.add_investigation_area(
            area="pre_existing", description="d", suggested_agent="filesystem",
            created_by="manual",
        )

        llm = AsyncMock()
        orch = self._make_orchestrator(graph, llm)
        await orch._derive_investigation_areas()
        # LLM should not have been called
        assert llm.chat.call_count == 0
        # Area count unchanged
        assert len(graph.investigation_areas) == 1

    @pytest.mark.asyncio
    async def test_fallback_when_llm_returns_empty_list(self, graph):
        """An empty LLM answer falls back to one area per hypothesis."""
        from unittest.mock import AsyncMock

        await graph.add_hypothesis("Some compromise", "desc", created_by="t")
        llm = AsyncMock()
        llm.chat.return_value = "[]"
        orch = self._make_orchestrator(graph, llm)
        await orch._derive_investigation_areas()
        # Fallback creates one area per hypothesis
        assert len(graph.investigation_areas) == 1
        a = list(graph.investigation_areas.values())[0]
        assert a.created_by == "fallback"

    @pytest.mark.asyncio
    async def test_unknown_tool_filtered_kept_keywords(self, graph):
        """LLM emits a tool name not in TOOL_CATALOG; tool is filtered,
        but the area itself with its keywords is kept."""
        from unittest.mock import AsyncMock

        h = await graph.add_hypothesis("h", "desc", created_by="t")
        llm = AsyncMock()
        llm.chat.return_value = (
            '[{"area":"foo","description":"desc","suggested_agent":"filesystem",'
            '"expected_keywords":["kw1","kw2"],"expected_tools":["nonexistent_tool"],'
            f'"priority":2,"motivating_hypothesis_ids":["{h}"]}}]'
        )
        orch = self._make_orchestrator(graph, llm)
        await orch._derive_investigation_areas()
        assert "area-foo" in graph.investigation_areas
        a = graph.investigation_areas["area-foo"]
        assert a.expected_keywords == ["kw1", "kw2"]
        assert a.expected_tools == []  # filtered out

    @pytest.mark.asyncio
    async def test_unknown_agent_resolved_via_AGENT_ALIASES(self, graph):
        """LLM emits 'chat' (which is in AGENT_ALIASES → 'communication').
        The area should land with the resolved agent name."""
        from unittest.mock import AsyncMock

        h = await graph.add_hypothesis("h", "desc", created_by="t")
        llm = AsyncMock()
        llm.chat.return_value = (
            '[{"area":"chat_stuff","description":"d","suggested_agent":"chat",'
            '"expected_keywords":["irc"],"expected_tools":[],'
            f'"priority":3,"motivating_hypothesis_ids":["{h}"]}}]'
        )
        orch = self._make_orchestrator(graph, llm)
        await orch._derive_investigation_areas()
        a = graph.investigation_areas["area-chat_stuff"]
        assert a.suggested_agent == "communication"

    # ---- BaseAgent forced-record retry ------------------------------------

    @pytest.mark.asyncio
    async def test_forced_record_retry_fires_when_zero_phenomena(self):
        """BaseAgent.run should automatically retry one more LLM round if
        the agent finished without calling any mandatory recording tool."""
        from unittest.mock import AsyncMock

        graph = EvidenceGraph()
        llm = AsyncMock()

        async def real_add(**kw):
            await graph.add_phenomenon(
                source_agent="test", category="filesystem",
                title="Forced retry record",
                description="Recorded after STOP prompt",
                source_tool="forced_retry",
            )

        agent = self._agent_with_executor(graph, llm, "add_phenomenon", real_add)

        async def fake_tool_call_loop(messages, tools, tool_executor, system,
                                      max_iterations=40, terminal_tools=()):
            if not self._is_retry_round(messages):
                # First round: finish without recording anything.
                return "Final answer without recording.", list(messages) + [
                    {"role": "assistant", "content": "Final answer without recording."}
                ]
            await tool_executor["add_phenomenon"]()  # goes through wrapper
            return "Recorded.", []

        llm.tool_call_loop = fake_tool_call_loop
        await agent.run("test task")
        assert len(graph.phenomena) == 1
        assert agent._record_call_counts["add_phenomenon"] == 1

    @pytest.mark.asyncio
    async def test_no_retry_when_mandatory_tool_was_called(self):
        """Retry should NOT fire if a mandatory recording tool was invoked."""
        from unittest.mock import AsyncMock

        graph = EvidenceGraph()
        llm = AsyncMock()
        call_count = {"n": 0}

        async def real_add(**kw):
            await graph.add_phenomenon(
                source_agent="test", category="filesystem", title="x",
                description="y", source_tool="t",
            )

        agent = self._agent_with_executor(graph, llm, "add_phenomenon", real_add)

        async def fake_tool_call_loop(messages, tools, tool_executor, system,
                                      max_iterations=40, terminal_tools=()):
            call_count["n"] += 1
            await tool_executor["add_phenomenon"]()  # wrapper increments count
            return "done.", list(messages)

        llm.tool_call_loop = fake_tool_call_loop
        await agent.run("test task")
        assert call_count["n"] == 1  # no retry

    @pytest.mark.asyncio
    async def test_no_retry_when_mandatory_tools_empty(self):
        """ReportAgent declares mandatory_record_tools=() — retry should
        not fire even with zero graph mutations (final text IS the output)."""
        from unittest.mock import AsyncMock
        from base_agent import BaseAgent

        graph = EvidenceGraph()
        llm = AsyncMock()
        call_count = {"n": 0}

        async def fake_tool_call_loop(messages, tools, tool_executor, system,
                                      max_iterations=40, terminal_tools=()):
            call_count["n"] += 1
            return "report body here", list(messages)

        llm.tool_call_loop = fake_tool_call_loop

        class ReportLike(BaseAgent):
            mandatory_record_tools = ()

        agent = ReportLike(llm, graph)
        agent.name = "report_like"
        agent._register_graph_tools = lambda: None
        await agent.run("test task")
        assert call_count["n"] == 1

    @pytest.mark.asyncio
    async def test_forced_retry_fires_for_timeline_agent(self):
        """TimelineAgent.mandatory_record_tools=('add_temporal_edge',) — retry
        should fire when timeline finishes without creating any temporal edges,
        even though the agent does not have add_phenomenon."""
        from unittest.mock import AsyncMock
        from base_agent import BaseAgent

        graph = EvidenceGraph()
        llm = AsyncMock()
        call_count = {"n": 0}
        edge_added = {"n": 0}

        async def real_add_edge(**kw):
            edge_added["n"] += 1

        class TimelineLike(BaseAgent):
            mandatory_record_tools = ("add_temporal_edge",)

        agent = TimelineLike(llm, graph)
        agent.name = "timeline_like"
        agent._register_graph_tools = lambda: None
        agent.register_tool("add_temporal_edge", "", {}, real_add_edge)

        async def fake_tool_call_loop(messages, tools, tool_executor, system,
                                      max_iterations=40, terminal_tools=()):
            call_count["n"] += 1
            if not self._is_retry_round(messages):
                return "answer", list(messages)
            await tool_executor["add_temporal_edge"]()
            return "recorded.", []

        llm.tool_call_loop = fake_tool_call_loop
        await agent.run("build timeline")
        assert call_count["n"] == 2  # first + retry
        assert edge_added["n"] == 1
        assert agent._record_call_counts["add_temporal_edge"] == 1

    # ---- terminal_tools: real LLMClient.tool_call_loop short-circuit -----

    @pytest.mark.asyncio
    async def test_terminal_tool_exits_loop_immediately(self):
        """When a terminal tool is called, tool_call_loop must return
        with that tool's result text as final_text — no further LLM calls."""
        client = self._bare_llm_client()
        call_count = {"n": 0}

        async def fake_chat_with_tools(messages, openai_tools):
            call_count["n"] += 1
            if call_count["n"] == 1:
                # First turn: model calls a read tool then the terminal tool.
                return "thinking aloud", None, [
                    {"id": "tc1", "name": "read_tool", "arguments": "{}"},
                    {"id": "tc2", "name": "save_report",
                     "arguments": '{"content":"FINAL REPORT BODY","output_path":"r.md"}'},
                ]
            raise AssertionError("loop should have exited after terminal tool")

        client._chat_with_tools = fake_chat_with_tools

        async def read_tool():
            return "some data"

        async def save_report(content, output_path):
            return content  # terminal tool returns content as final_text

        tools = [
            {"name": "read_tool", "description": "",
             "input_schema": {"type": "object", "properties": {}}},
            {"name": "save_report", "description": "",
             "input_schema": {"type": "object", "properties": {
                 "content": {"type": "string"}, "output_path": {"type": "string"}}}},
        ]
        executors = {"read_tool": read_tool, "save_report": save_report}

        final_text, _ = await client.tool_call_loop(
            messages=[{"role": "user", "content": "go"}],
            tools=tools, tool_executor=executors,
            system="sys", terminal_tools=("save_report",),
        )
        assert final_text == "FINAL REPORT BODY"
        assert call_count["n"] == 1  # never called a 2nd round

    @pytest.mark.asyncio
    async def test_no_terminal_short_circuit_when_not_declared(self):
        """When terminal_tools is empty, the same call sequence should
        run the read tool, run save_report-like tool, AND continue the loop
        (i.e. another LLM round) until the model stops calling tools."""
        client = self._bare_llm_client()
        call_count = {"n": 0}

        async def fake_chat_with_tools(messages, openai_tools):
            call_count["n"] += 1
            if call_count["n"] == 1:
                return "", None, [
                    {"id": "tc1", "name": "add_phenomenon",
                     "arguments": '{"title":"x","description":"y"}'},
                ]
            return "all done", None, []  # model stops calling tools

        client._chat_with_tools = fake_chat_with_tools

        async def add_phenomenon(title, description):
            return f"recorded {title}"

        tools = [
            {"name": "add_phenomenon", "description": "",
             "input_schema": {"type": "object", "properties": {
                 "title": {"type": "string"}, "description": {"type": "string"}}}},
        ]
        executors = {"add_phenomenon": add_phenomenon}

        final_text, _ = await client.tool_call_loop(
            messages=[{"role": "user", "content": "go"}],
            tools=tools, tool_executor=executors,
            system="sys", terminal_tools=(),  # NOT terminal
        )
        assert final_text == "all done"
        assert call_count["n"] == 2  # 2 rounds — terminal_tools empty, loop continues

    def test_report_agent_terminal_tool_declared(self):
        """ReportAgent should declare save_report as both mandatory and terminal."""
        from agents.report import ReportAgent

        assert ReportAgent.terminal_tools == ("save_report",)
        assert ReportAgent.mandatory_record_tools == ("save_report",)

    @pytest.mark.asyncio
    async def test_terminal_tool_result_not_truncated(self):
        """Terminal tool's raw return is used as final_text and must NOT
        be truncated to 3000 chars (the truncation cap applies only to
        LLM-context tool result messages). A 20K-char markdown report
        passed through save_report should reach the caller intact."""
        client = self._bare_llm_client()

        long_report = "# Report\n" + ("- finding " + "x" * 100 + "\n") * 200
        assert len(long_report) > 10000

        call_count = {"n": 0}

        async def fake_chat_with_tools(messages, openai_tools):
            call_count["n"] += 1
            if call_count["n"] == 1:
                return "", None, [
                    {"id": "tc1", "name": "save_report",
                     "arguments": '{"content":"placeholder","output_path":"r.md"}'},
                ]
            raise AssertionError("loop should have exited")

        client._chat_with_tools = fake_chat_with_tools

        async def save_report(content, output_path):
            return long_report  # ignore content arg; return long content

        tools = [{"name": "save_report", "description": "",
                  "input_schema": {"type": "object", "properties": {
                      "content": {"type": "string"}, "output_path": {"type": "string"}}}}]
        executors = {"save_report": save_report}

        final_text, _ = await client.tool_call_loop(
            messages=[{"role": "user", "content": "go"}],
            tools=tools, tool_executor=executors,
            system="sys", terminal_tools=("save_report",),
        )
        assert final_text == long_report
        assert len(final_text) > 10000  # not truncated to 3000

    # ---- gap-analysis dispatch ---------------------------------------------

    @pytest.mark.asyncio
    async def test_dispatch_uses_hypothesis_id_when_motivating_ids_present(self, graph):
        """A lead for an uncovered area carries the area's first motivating
        hypothesis id, suggested agent and priority."""
        h = await graph.add_hypothesis("h", "desc", created_by="t")
        await graph.add_investigation_area(
            area="uncovered_area", description="d", suggested_agent="registry",
            expected_keywords=["xyz_no_match"], expected_tools=[],
            priority=2, motivating_hypothesis_ids=[h],
            created_by="llm_derive",
        )
        orch = self._make_orchestrator(graph)
        # Don't actually dispatch (would call agents) — just hit the lead-add path
        # by manually replicating what _run_gap_analysis does.
        covered = orch._check_coverage()
        assert "uncovered_area" not in covered
        # Simulate dispatch
        for a in graph.investigation_areas.values():
            if a.area not in covered:
                await graph.add_lead(
                    target_agent=a.suggested_agent,
                    description=a.description,
                    priority=a.priority,
                    hypothesis_id=(a.motivating_hypothesis_ids[0]
                                   if a.motivating_hypothesis_ids else None),
                )
        assert len(graph.leads) == 1
        assert graph.leads[0].hypothesis_id == h
        assert graph.leads[0].target_agent == "registry"
        assert graph.leads[0].priority == 2
|
||||
|
||||
Reference in New Issue
Block a user