refactor: native tool calling + generic forced-retry + terminal exit

- llm_client: switch tool_call_loop from text-based <tool_call> regex
  to OpenAI-native tools=[...] / structured tool_calls field; accumulate
  delta.reasoning_content for DeepSeek thinking-mode echo-back; fold
  preserves system msg and aligns boundary to never orphan role:tool
- base_agent: generic forced-retry via mandatory_record_tools class attr
  (filesystem -> add_phenomenon, timeline -> add_temporal_edge,
  hypothesis -> add_hypothesis, report -> save_report); count via
  executor wrapper
- terminal_tools class attr + loop short-circuit: when a terminal tool
  is called, loop exits with its raw return as final_text. ReportAgent
  declares save_report as terminal - replaces the <answer>-tag stop
  signal that native tool calling broke
- _execute_*: return (raw, formatted) - terminal exit uses untruncated
  raw, conversation history uses 3000-char-capped formatted
- evidence_graph + orchestrator: LLM-derived InvestigationArea support
  (hypothesis-driven coverage check, replaces hardcoded _AREA_KEYWORDS /
  _AREA_TOOLS); manual yaml block kept as optional seed
- strip <answer> references from agent prompts (no longer load-bearing)

Verified on CFReDS image across 4 smoke runs: 0 JSON parse failures
(was 3); 22 temporal edges from Phase 4 (was 0); ReportAgent exits via
save_report (was max_iterations regression). 78/78 unit tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
BattleTag
2026-05-13 13:51:19 +08:00
parent 0a2b344c84
commit 444d58726a
9 changed files with 1356 additions and 298 deletions

View File

@@ -67,14 +67,18 @@ Phenomenon → Hypothesis 的边类型与权重写死在 `HYPOTHESIS_EDGE_WEIGHT
| **Phase 4** | TimelineAgent 用 `build_filesystem_timeline` 生成 MAC 时间线,与 Phenomenon 时间戳关联 |
| **Phase 5** | ReportAgent 综合假设、证据、实体,生成 Markdown 报告 |
### Gap AnalysisPhase 3 末
### Investigation Areashypothesis-derived
`config.yaml:investigation_areas` 列出必须覆盖的调查领域系统信息、用户账户、网络配置、邮件配置、IRC 日志、PCAP、删除文件、Prefetch 等。Orchestrator 两层判定覆盖情况
Phase 2 末尾 orchestrator 调一次 LLM 从所有 active hypothesis 派生 5-12 个 **InvestigationArea**snake_case slug、description、suggested_agent、expected_keywords、expected_tools、priority、motivating_hypothesis_ids。Areas 存进 `graph.investigation_areas`,序列化到 `runs/<ts>/investigation_areas.json`。两个用途
1. **关键词匹配**`_AREA_KEYWORDS`)— 扫现有 Phenomenon 标题/描述
2. **工具命中**`_AREA_TOOLS`)— 检查是否调用过该领域的关键工具(如 `enumerate_users``parse_pcap_strings`
1. **Phase 3 主循环提示** — 每个 hypothesis 块附 `Expected areas: a, b, c`LLM 仍自由选 lead 但有软引导
2. **Phase 3 末尾 Gap Analysis** — 两层判定覆盖情况:
- **关键词匹配**:扫 Phenomenon 标题/描述对照 area.expected_keywords
- **工具命中**:检查 area.expected_tools 是否实际调用过
未覆盖的领域自动派 lead最多 3 轮补漏。
未覆盖的 area 自动派 lead`suggested_agent` + `priority` + `motivating_hypothesis_ids[0]` 透传给 `Lead.hypothesis_id` 保留 provenance,最多 3 轮补漏。
**手动 override**`config.yaml:investigation_areas` 默认注释掉,纯 LLM 派生。取消注释可添加强制必查的领域,会先于 LLM 写入并通过 slug-based dedupe 保护不被覆盖LLM 只会 augment keyword/tool 列表)。这是跨案件/跨平台适配的关键 —— 不再 hardcode Windows-specific 领域。
## Agent 体系
@@ -183,11 +187,12 @@ max_investigation_rounds: 5 # Phase 3 最大迭代轮数
# - title: "嫌疑人主动实施网络嗅探"
# description: "..."
investigation_areas: # Gap Analysis 必须覆盖的领域
- area: system_info
agent: registry
task: "..."
# ...
# investigation_areas: # 可选:手动 override默认全 LLM 派生)
# - area: shutdown_time # LLM 通过 slug dedupe 只 augment
# agent: registry # keyword/tool 列表,不覆盖 manual
# priority: 3
# keywords: [shutdown]
# tools: [get_shutdown_time]
```
未配置 `hypotheses` 时由 HypothesisAgent 自动生成。

View File

@@ -24,6 +24,7 @@ class HypothesisAgent(BaseAgent):
"and formulate investigative hypotheses about what happened on this system. "
"Your ultimate goal: build the most complete picture of events that occurred."
)
mandatory_record_tools = ("add_hypothesis",)
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
super().__init__(llm, graph)
@@ -68,7 +69,7 @@ class HypothesisAgent(BaseAgent):
f"WORKFLOW:\n"
f"1. Call list_phenomena and search_graph to review existing findings.\n"
f"2. For each hypothesis you want to record, call add_hypothesis (title + description).\n"
f"3. Wrap a short summary in <answer> when you have generated 3-7 hypotheses.\n\n"
f"3. STOP after you have generated 3-7 hypotheses. Do not call any more tools.\n\n"
f"STRICT BOUNDARIES:\n"
f"- Your only mutation tool is add_hypothesis. Do NOT attempt list_directory, "
f"parse_registry_key, extract_file, or any disk-image investigation tools — "

View File

@@ -2,9 +2,6 @@
from __future__ import annotations
import json
import os
from base_agent import BaseAgent
from evidence_graph import EvidenceGraph
from llm_client import LLMClient
@@ -15,11 +12,16 @@ class ReportAgent(BaseAgent):
role = (
"Forensic report writer. You synthesize all findings from the investigation "
"into a structured, professional forensic analysis report organized by hypotheses.\n\n"
"IMPORTANT: Only include findings that have a source_tool attribution (marked VERIFIED). "
"Only include findings that have a source_tool attribution (marked VERIFIED). "
"If evidence lacks source attribution, mark it as UNVERIFIED. "
"Do NOT invent or fabricate any data, timestamps, or findings not present in the evidence.\n\n"
"CRITICAL: You MUST call save_report to write the final report."
"Do NOT invent or fabricate any data, timestamps, or findings not present in the evidence."
)
# Calling save_report is BOTH the recording action and the completion
# signal. tool_call_loop returns the moment save_report executes; the
# tool's return value becomes the agent's final_text. The forced-retry
# mechanism fires if save_report is never called.
mandatory_record_tools = ("save_report",)
terminal_tools = ("save_report",)
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
super().__init__(llm, graph)
@@ -30,23 +32,26 @@ class ReportAgent(BaseAgent):
self._register_graph_read_tools()
def _build_system_prompt(self, task: str) -> str:
"""Report agent gets a clean prompt — no Phase A/B/C/D workflow."""
return (
f"You are a forensic report writer.\n"
f"Role: {self.role}\n\n"
f"Investigation state:\n{self.graph.stats_summary()}\n\n"
f"Your task: {task}\n\n"
f"WORKFLOW:\n"
f"1. Call get_hypotheses_with_evidence to get all hypotheses and their linked evidence\n"
f"2. Call get_all_phenomena to get detailed findings by category\n"
f"3. Call get_entities to get people, programs, and hosts\n"
f"4. Call get_case_info for case metadata\n"
f"5. Write the complete report directly in your <answer> block\n\n"
f"1. Call get_hypotheses_with_evidence, get_all_phenomena, get_entities, get_case_info "
f" to gather all the data needed for the report. Make these calls in parallel.\n"
f"2. Assemble the complete markdown forensic report.\n"
f"3. Call save_report(content=<full markdown>, output_path=\"report.md\").\n"
f" This single call is the completion signal — the run ENDS the moment it executes.\n"
f" Do NOT call any read tools after this point; they will not run.\n"
f" Do NOT write the report as free text outside of save_report; only the\n"
f" `content` argument of save_report is persisted.\n\n"
f"RULES:\n"
f"- Write the report DIRECTLY in <answer> — do NOT use save_report tool\n"
f"- Only include findings present in the evidence graph\n"
f"- Do NOT invent timestamps, file paths, or data not in the phenomena\n"
f"- The report must be complete — do not cut off mid-section\n"
f"- The report must be the complete markdown — do not cut off mid-section.\n"
f"- Only include findings present in the evidence graph.\n"
f"- Do NOT invent timestamps, file paths, or data not in the phenomena.\n"
f"- The `content` argument can be 10K+ chars. JSON-escape inner quotes (\\\") and\n"
f" backslashes (\\\\) and newlines (\\n) correctly.\n"
)
def _register_tools(self) -> None:
@@ -186,10 +191,16 @@ class ReportAgent(BaseAgent):
return "\n".join(lines)
async def _save_report(self, content: str, output_path: str) -> str:
try:
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
with open(output_path, "w") as f:
f.write(content)
return f"Report saved to {output_path} ({len(content)} chars)"
except Exception as e:
return f"Error saving report: {e}"
"""Save the report and return the content itself.
The content is returned (rather than a "saved to ..." status string)
so that when tool_call_loop short-circuits on this terminal tool,
`final_text` is the full markdown — orchestrator writes it to the
canonical report.md path under runs/<ts>/.
The output_path argument is kept for backward compat but the model's
chosen path is ignored — the orchestrator owns the persistence path.
"""
if not content:
return ""
return content

View File

@@ -24,6 +24,7 @@ class TimelineAgent(BaseAgent):
"MAC timestamps and correlate events across all phenomena categories in the "
"evidence graph to reconstruct the sequence of activities on the system."
)
mandatory_record_tools = ("add_temporal_edge",)
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
super().__init__(llm, graph)
@@ -95,7 +96,7 @@ class TimelineAgent(BaseAgent):
f" - 'Tool installation' (before) 'Tool execution'\n"
f"4. Aim for 15-40 temporal edges that connect the major events into a "
f"forensic story.\n"
f"5. Wrap a short summary in <answer> when done.\n\n"
f"5. STOP after recording all meaningful temporal edges. Do not call any more tools.\n\n"
f"STRICT BOUNDARIES:\n"
f"- Your job is to CONNECT existing phenomena, NOT to discover new ones. "
f"You CANNOT call add_phenomenon — the tool isn't yours.\n"

View File

@@ -31,11 +31,26 @@ class BaseAgent:
name: str = "base"
role: str = "A forensic analysis agent."
# Tools the agent MUST invoke at least once for the run to count as productive.
# If none of these were called when tool_call_loop returns, run() fires a
# forced retry with an explicit "you forgot to record" instruction.
# Subclasses override to declare their own recording responsibility
# (timeline → add_temporal_edge, hypothesis → add_hypothesis, report → save_report).
mandatory_record_tools: tuple[str, ...] = ("add_phenomenon",)
# Tools whose invocation ends the run immediately. After any terminal tool
# is called, tool_call_loop returns with that tool's result text as
# final_text. Used by agents whose "completion" is a single explicit
# action rather than "model decides to stop calling tools". For multi-call
# agents (filesystem records many phenomena) leave empty.
terminal_tools: tuple[str, ...] = ()
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
self.llm = llm
self.graph = graph
self._tools: dict[str, dict] = {} # name -> schema
self._executors: dict[str, Any] = {} # name -> async callable
self._record_call_counts: dict[str, int] = {}
self._work_log: list[str] = []
self._current_lead_id: str | None = None
@@ -52,7 +67,18 @@ class BaseAgent:
"description": description,
"input_schema": input_schema,
}
self._executors[name] = executor
if name in self.mandatory_record_tools:
self._executors[name] = self._wrap_record_executor(name, executor)
else:
self._executors[name] = executor
def _wrap_record_executor(self, name: str, executor: Any) -> Any:
"""Wrap a mandatory-record executor to count successful invocations."""
async def wrapped(*args, **kwargs):
result = await executor(*args, **kwargs)
self._record_call_counts[name] = self._record_call_counts.get(name, 0) + 1
return result
return wrapped
def get_tool_definitions(self) -> list[dict]:
"""Get tool definitions in Claude API format."""
@@ -91,20 +117,19 @@ class BaseAgent:
f" FIRST call list_phenomena to get the current IDs — do NOT rely on memory.\n"
f" Then call link_to_entity for each relevant phenomenon.\n"
f" NEVER guess or fabricate a phenomenon ID. If an ID is not in list_phenomena output, it does not exist.\n\n"
f"Phase D — ANSWER:\n"
f" Only give your <answer> AFTER completing Phases B and C.\n\n"
f"Phase D — STOP:\n"
f" Once all phenomena are recorded and entities linked, you are DONE.\n"
f" Do not call any more tools. The orchestrator picks up automatically.\n\n"
f"CRITICAL — RECORDING REQUIREMENT:\n"
f"- Your <answer> block is DISCARDED by the orchestrator. Only graph mutations propagate.\n"
f"- Other agents and the final report read ONLY the evidence graph "
f"(phenomena, entities, edges).\n"
f"- You MUST call add_phenomenon for EVERY significant finding BEFORE you end.\n"
f"- Only graph mutations propagate to other agents and the final report.\n"
f"- You MUST call add_phenomenon for EVERY significant finding BEFORE you stop.\n"
f"- NEGATIVE findings count too. If you searched X (a directory, a pattern, "
f"a registry key) and found NOTHING, that absence IS evidence — call "
f"add_phenomenon with a 'No matches for X' title and the search scope in "
f"raw_data. Negative findings constrain the hypothesis space and prevent "
f"the next agent from wasting time re-searching.\n"
f"- If you produce <answer> without having called add_phenomenon at least once, "
f"the task is FAILED regardless of what you wrote in <answer>.\n"
f"- If you stop without having called add_phenomenon at least once, the task "
f"is FAILED and a forced retry will fire.\n"
f"- Include exact file paths, inode numbers, timestamps, and the source_tool "
f"that produced each finding.\n\n"
f"ANTI-HALLUCINATION RULES — STRICTLY ENFORCED:\n"
@@ -124,6 +149,7 @@ class BaseAgent:
self._current_lead_id = lead_id
self._register_graph_tools()
self._record_call_counts.clear()
system = self._build_system_prompt(task)
messages = [{"role": "user", "content": task}]
@@ -132,12 +158,60 @@ class BaseAgent:
ph_before = len(self.graph.phenomena)
try:
final_text, _ = await self.llm.tool_call_loop(
final_text, conversation = await self.llm.tool_call_loop(
messages=messages,
tools=self.get_tool_definitions(),
tool_executor=self._executors,
system=system,
terminal_tools=self.terminal_tools,
)
# Forced-record retry: if the agent has any mandatory recording
# tools but never invoked any of them, force one more round with
# an explicit "you forgot to record" instruction. The mandatory
# set is declared on the class — Timeline → add_temporal_edge,
# Hypothesis → add_hypothesis, ReportAgent → (). For agents with
# empty mandatory_record_tools this branch is a no-op.
registered_mandatory = [
t for t in self.mandatory_record_tools if t in self._executors
]
recorded_any = any(
self._record_call_counts.get(t, 0) > 0
for t in registered_mandatory
)
if registered_mandatory and not recorded_any:
missing = "/".join(registered_mandatory)
logger.warning(
"[%s] finished without calling any of [%s] — forcing RECORD retry",
self.name, missing,
)
conversation.append({
"role": "user",
"content": (
f"STOP. You produced an answer without ever calling "
f"{missing}. Your answer is DISCARDED — only graph "
f"mutations propagate to other agents and the final "
f"report.\n\n"
f"You MUST now call {missing} for every significant "
f"finding from your prior investigation, including "
f"exact identifiers, timestamps, and the source_tool "
f"that produced each finding. If you genuinely found "
f"NOTHING noteworthy, call the recording tool ONCE "
f"with a 'No significant findings' style entry "
f"summarizing what you searched.\n\n"
f"Do not run more investigation tools. Just record "
f"what you already found. Then end."
),
})
final_text, _ = await self.llm.tool_call_loop(
messages=conversation,
tools=self.get_tool_definitions(),
tool_executor=self._executors,
system=system,
max_iterations=10,
terminal_tools=self.terminal_tools,
)
self._work_log.append(f"[Task: {task[:80]}] -> {final_text[:150]}")
except Exception:
self.graph.agent_status[self.name] = "failed"

View File

@@ -197,6 +197,41 @@ class Lead:
return cls(**d)
@dataclass
class InvestigationArea:
"""An area to investigate to confirm/refute one or more hypotheses.
Derived by the orchestrator from active hypotheses after Phase 2; also
seeded from config.yaml:investigation_areas as an optional manual
override. Each area carries its own keywords + expected tools so the
gap-analysis coverage check is generic, not tied to hard-coded constants.
"""
id: str # "area-{slug}"
area: str # snake_case slug (dedupe key)
description: str
suggested_agent: str # filesystem / registry / communication / network / timeline
expected_keywords: list[str] = field(default_factory=list)
expected_tools: list[str] = field(default_factory=list)
priority: int = 5 # 1 (highest) - 10 (lowest)
motivating_hypothesis_ids: list[str] = field(default_factory=list)
created_by: str = "" # "manual" | "llm_derive" | "fallback"
created_at: str = ""
def to_dict(self) -> dict:
return asdict(self)
@classmethod
def from_dict(cls, d: dict) -> InvestigationArea:
return cls(**d)
def summary(self) -> str:
return (
f"[{self.area}] P{self.priority} agent={self.suggested_agent} "
f"(motivating: {len(self.motivating_hypothesis_ids)})"
)
@dataclass
class ExtractedAsset:
"""A file extracted from the disk image and tracked in the asset library."""
@@ -270,6 +305,11 @@ class EvidenceGraph:
self.asset_library: dict[str, ExtractedAsset] = {}
self._inode_index: dict[str, str] = {} # inode → asset_id
# Investigation areas — derived from hypotheses (LLM) and/or seeded
# from config.yaml:investigation_areas (manual override). Drives the
# gap-analysis coverage check.
self.investigation_areas: dict[str, InvestigationArea] = {}
# Set by BaseAgent.run() before each agent execution
self._current_agent: str = ""
@@ -295,6 +335,9 @@ class EvidenceGraph:
"leads": [l.to_dict() for l in self.leads],
"agent_status": dict(self.agent_status),
"asset_library": {aid: a.to_dict() for aid, a in self.asset_library.items()},
"investigation_areas": {
aid: a.to_dict() for aid, a in self.investigation_areas.items()
},
"saved_at": datetime.now().isoformat(),
}
tmp = self._persist_path.with_suffix(".tmp")
@@ -345,6 +388,10 @@ class EvidenceGraph:
asset = ExtractedAsset.from_dict(a_data)
graph.asset_library[aid] = asset
graph._inode_index[asset.inode] = aid
graph.investigation_areas = {
aid: InvestigationArea.from_dict(a)
for aid, a in data.get("investigation_areas", {}).items()
}
graph._rebuild_adjacency()
logger.info(
"EvidenceGraph restored: %d phenomena, %d hypotheses, %d entities, "
@@ -656,6 +703,57 @@ class EvidenceGraph:
break
self._auto_save()
# ---- Investigation areas -------------------------------------------------
async def add_investigation_area(
self,
area: str,
description: str,
suggested_agent: str,
expected_keywords: list[str] | None = None,
expected_tools: list[str] | None = None,
priority: int = 5,
motivating_hypothesis_ids: list[str] | None = None,
created_by: str = "",
) -> tuple[str, bool]:
"""Add or merge an investigation area. Dedupe key is the `area` slug.
On collision, union the three list fields (keywords / tools /
motivating_hypothesis_ids); description / suggested_agent / priority
are preserved from the first writer (manual seed wins over LLM derive).
Returns (id, was_existing).
"""
async with self._lock:
for existing in self.investigation_areas.values():
if existing.area == area:
for kw in (expected_keywords or []):
if kw not in existing.expected_keywords:
existing.expected_keywords.append(kw)
for t in (expected_tools or []):
if t not in existing.expected_tools:
existing.expected_tools.append(t)
for hid in (motivating_hypothesis_ids or []):
if hid not in existing.motivating_hypothesis_ids:
existing.motivating_hypothesis_ids.append(hid)
self._auto_save()
return existing.id, True
aid = f"area-{area}"
self.investigation_areas[aid] = InvestigationArea(
id=aid,
area=area,
description=description,
suggested_agent=suggested_agent,
expected_keywords=list(expected_keywords or []),
expected_tools=list(expected_tools or []),
priority=priority,
motivating_hypothesis_ids=list(motivating_hypothesis_ids or []),
created_by=created_by,
created_at=datetime.now().isoformat(),
)
self._auto_save()
return aid, False
# ---- Asset library -------------------------------------------------------
async def register_asset(

View File

@@ -1,9 +1,10 @@
"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API).
Tool calling is text-based (ReAct pattern): tool schemas are embedded in
the system prompt and tool calls are parsed as <tool_call> JSON blocks
from model output. This keeps the protocol independent of whether the
underlying API supports native function/tool calling.
Tool calling uses the OpenAI-native `tools=[...]` parameter. The model
returns structured tool_calls via the streaming protocol; we accumulate
them, dispatch to our executors, and feed results back as `role: "tool"`
messages. This eliminates the fragile "model writes JSON inside free
text" problem of the previous ReAct text mode.
"""
from __future__ import annotations
@@ -32,52 +33,51 @@ class LLMAPIError(Exception):
self.attempts = attempts
# Markers the model uses to signal tool calls and final answers
TOOL_CALL_TAG = "<tool_call>"
TOOL_CALL_END = "</tool_call>"
TOOL_RESULT_TAG = "<tool_result>"
TOOL_RESULT_END = "</tool_result>"
# Optional answer tags — kept for backward compat with prompts that wrap
# their final response in <answer>...</answer>. Native tool calling does
# not need these (no tool_calls = final), but if the model continues to
# emit them, we strip the tags so callers see clean text.
ANSWER_TAG = "<answer>"
ANSWER_END = "</answer>"
def _build_tools_prompt(tools: list[dict]) -> str:
"""Format tool definitions for inclusion in the system prompt."""
lines = ["You have access to the following tools:\n"]
for t in tools:
schema = t.get("input_schema", {})
props = schema.get("properties", {})
required = schema.get("required", [])
def _to_openai_tools(tools: list[dict]) -> list[dict]:
"""Convert internal tool definitions to OpenAI native function-tools format."""
return [
{
"type": "function",
"function": {
"name": t["name"],
"description": t["description"],
"parameters": t.get("input_schema", {"type": "object", "properties": {}}),
},
}
for t in tools
]
params = []
for pname, pdef in props.items():
req = " (required)" if pname in required else ""
desc = pdef.get("description", "")
ptype = pdef.get("type", "string")
enum_vals = pdef.get("enum")
if enum_vals:
allowed = ", ".join(f'"{v}"' for v in enum_vals)
params.append(f" - {pname}: {ptype}{req}{desc} Allowed values: [{allowed}]")
else:
params.append(f" - {pname}: {ptype}{req}{desc}")
param_block = "\n".join(params) if params else " (no parameters)"
lines.append(f"## {t['name']}\n{t['description']}\nParameters:\n{param_block}\n")
def _extract_first_balanced(text: str, open_char: str, close_char: str) -> str | None:
"""Return the first balanced [...] or {...} substring, or None if no balanced pair.
lines.append(
"## How to use tools\n"
"To call a tool, output a JSON block wrapped in XML tags like this:\n"
f"{TOOL_CALL_TAG}\n"
'{"name": "tool_name", "arguments": {"param1": "value1"}}\n'
f"{TOOL_CALL_END}\n\n"
"You can call multiple tools in sequence. After each tool call, you will receive the result in:\n"
f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n"
"When you have finished your analysis and have a final answer, wrap it in:\n"
f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n"
"Think step by step. Call tools to gather evidence before drawing conclusions.\n"
"You MUST call at least one tool before giving your final answer."
)
return "\n".join(lines)
Stack-based — handles nested brackets correctly (regex with .*? would
truncate at the first inner closing bracket, regex with .* would over-eat
trailing text). Brackets inside JSON string literals are ignored by
callers because the caller passes the result through json.loads which
re-parses with proper string handling.
"""
start = text.find(open_char)
if start < 0:
return None
depth = 0
for i in range(start, len(text)):
c = text[i]
if c == open_char:
depth += 1
elif c == close_char:
depth -= 1
if depth == 0:
return text[start:i + 1]
return None
def _safe_json_loads(text: str):
@@ -87,8 +87,8 @@ def _safe_json_loads(text: str):
(\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX). On final failure, logs raw
input (first 600 chars) so we can diagnose what the model emitted.
Used both by orchestrator JSON callsites and by _extract_tool_calls
when parsing <tool_call> blocks from model output.
Used by orchestrator JSON callsites (_call_llm_for_json) and by
tool_call_loop when parsing tool_call arguments returned by the API.
"""
try:
return json.loads(text)
@@ -110,23 +110,6 @@ def _safe_json_loads(text: str):
raise
def _extract_tool_calls(text: str) -> list[dict]:
"""Extract tool call JSON blocks from model output."""
pattern = re.compile(
re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END),
re.DOTALL,
)
calls = []
for match in pattern.finditer(text):
raw = match.group(1).strip()
try:
parsed = _safe_json_loads(raw)
calls.append(parsed)
except json.JSONDecodeError:
logger.warning("Failed to parse tool call JSON: %s", raw[:200])
return calls
def _extract_answer(text: str) -> str | None:
"""Extract the final answer from model output."""
pattern = re.compile(
@@ -266,50 +249,41 @@ _DECAY_TIERS: list[tuple[int, int]] = [
def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
"""Truncate tool results in older messages to save context space.
"""Truncate the `content` of older `role: "tool"` messages to save context.
Operates in-place-style on a copy. Only touches user messages that
contain <tool_result> blocks (these are the tool-result messages
generated by tool_call_loop).
Each `role: "tool"` message in the conversation corresponds to one tool
call's result. We rank these messages by recency and progressively
truncate older ones according to `_DECAY_TIERS`.
"""
# Count rounds from the end. A "round" is a (assistant, user) pair.
# messages alternate: [user, assistant, user, assistant, user, ...]
# The initial user message is index 0, then pairs start at index 1.
total = len(messages)
if total <= 10: # not enough messages to bother
if total <= 10:
return messages
result = []
# Count tool-result user messages from the end
tool_result_indices = [
i for i, m in enumerate(messages)
if m["role"] == "user" and TOOL_RESULT_TAG in m.get("content", "")
tool_msg_indices = [
i for i, m in enumerate(messages) if m.get("role") == "tool"
]
# Build a set of indices that need decay, mapped to their max_chars
decay_map: dict[int, int] = {}
n_tool_msgs = len(tool_result_indices)
for rank, idx in enumerate(reversed(tool_result_indices)):
rounds_ago = rank # 0 = most recent, 1 = second most recent, ...
for rank, idx in enumerate(reversed(tool_msg_indices)):
rounds_ago = rank
for threshold, max_chars in _DECAY_TIERS:
if rounds_ago < threshold:
decay_map[idx] = max_chars
break
result = []
for i, msg in enumerate(messages):
if i in decay_map:
max_chars = decay_map[i]
content = msg["content"]
content = msg.get("content", "") or ""
if len(content) > max_chars + 200:
# Truncate but preserve the tool_result tags structure
truncated = content[:max_chars]
# Count how many tool results are in this message
n_results = content.count(TOOL_RESULT_TAG)
truncated += (
f"\n... [context compressed: {len(content)} -> {max_chars} chars, "
f"{n_results} tool result(s)]"
truncated = (
content[:max_chars]
+ f"\n... [context compressed: {len(content)} -> {max_chars} chars]"
)
result.append({"role": msg["role"], "content": truncated})
new_msg = dict(msg)
new_msg["content"] = truncated
result.append(new_msg)
else:
result.append(msg)
else:
@@ -434,6 +408,95 @@ class LLMClient:
return ""
async def _chat_with_tools(
self,
messages: list[dict],
openai_tools: list[dict],
max_retries: int = 5,
) -> tuple[str, str | None, list[dict]]:
"""Stream a chat completion with native tool calling enabled.
Returns:
(text_content, reasoning_content, raw_tool_calls).
- reasoning_content is non-None when DeepSeek thinking mode is
active; the caller MUST echo it back in the assistant message
on subsequent requests, or the API returns HTTP 400.
- raw_tool_calls is a list of {"id","name","arguments"} dicts;
arguments is the raw JSON string returned by the API.
"""
kwargs: dict[str, Any] = {
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"stream": True,
"tools": openai_tools,
}
if self.reasoning_effort:
kwargs["reasoning_effort"] = self.reasoning_effort
if self.thinking_enabled:
kwargs["extra_body"] = {"thinking": {"type": "enabled"}}
for attempt in range(max_retries):
logger.debug(
"LLM request (stream+tools): %d messages, %d tools (attempt %d)",
len(messages), len(openai_tools), attempt + 1,
)
text_parts: list[str] = []
reasoning_parts: list[str] = []
tool_calls_acc: dict[int, dict] = {} # index -> {id, name, arguments}
try:
stream = await self._client.chat.completions.create(**kwargs)
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if delta.content:
text_parts.append(delta.content)
# DeepSeek thinking-mode: reasoning_content is returned
# alongside content and MUST be echoed back on subsequent
# requests, otherwise the API rejects with HTTP 400.
rc = getattr(delta, "reasoning_content", None)
if rc:
reasoning_parts.append(rc)
if delta.tool_calls:
for tc_delta in delta.tool_calls:
idx = tc_delta.index
entry = tool_calls_acc.setdefault(
idx, {"id": None, "name": None, "arguments": ""},
)
if tc_delta.id:
entry["id"] = tc_delta.id
fn = tc_delta.function
if fn:
if fn.name:
entry["name"] = fn.name
if fn.arguments:
entry["arguments"] += fn.arguments
text = "".join(text_parts)
reasoning = "".join(reasoning_parts) or None
ordered = [tool_calls_acc[i] for i in sorted(tool_calls_acc)]
logger.debug(
"LLM response (stream+tools): %d chars, %d reasoning chars, %d tool calls",
len(text), len(reasoning or ""), len(ordered),
)
return text, reasoning, ordered
except (APIConnectionError, APITimeoutError, APIError) as e:
if attempt < max_retries - 1:
wait = 2 ** attempt * 10
logger.warning(
"Tool-call request failed (%s), retrying in %ds...", e, wait,
)
await asyncio.sleep(wait)
else:
raise LLMAPIError(
f"LLM API unreachable after {max_retries} attempts: {e}",
attempts=max_retries,
) from e
return "", None, []
async def tool_call_loop(
self,
messages: list[dict],
@@ -441,87 +504,159 @@ class LLMClient:
tool_executor: dict[str, Any],
system: str | None = None,
max_iterations: int = 40,
terminal_tools: tuple[str, ...] = (),
) -> tuple[str, list[dict]]:
"""Run a ReAct-style tool-calling loop.
"""Run a tool-calling loop using OpenAI-native tool calls.
The model outputs <tool_call> blocks which we parse and execute,
feeding results back as <tool_result> blocks until the model
outputs an <answer> block.
The model returns structured `tool_calls` in its message; we
dispatch them through our executor dict and feed each result back
as a `role: "tool"` message with the matching `tool_call_id`. The
loop ends when:
- the model returns a message with no tool_calls (normal exit), or
- any tool in `terminal_tools` is called — in that case, the loop
short-circuits with that tool's result text as final_text. This
gives agents (notably ReportAgent) an explicit completion signal
that the old `<answer>` text tag used to provide.
Returns:
(final_text, all_messages)
(final_text, full_message_history)
"""
# Build system prompt with tool definitions
tools_prompt = _build_tools_prompt(tools)
full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt
terminal_set = set(terminal_tools)
openai_tools = _to_openai_tools(tools)
messages = list(messages) # don't mutate caller's list
_folded = False # Track whether we've already folded once this loop
# The caller may pass `messages` either as raw conversation (no system)
# together with `system=...`, OR as a complete history that already
# starts with the system message (retry path). Accept both shapes.
if messages and messages[0].get("role") == "system":
full_messages: list[dict] = list(messages)
else:
full_messages = []
if system:
full_messages.append({"role": "system", "content": system})
full_messages.extend(messages)
_folded = False
for i in range(max_iterations):
for _i in range(max_iterations):
# ── Context compression before each API call ──────────────
# Stage A: progressively decay old tool results
messages = _apply_progressive_decay(messages)
# Stage B: fold oldest messages into LLM summary if too long
if not _folded and len(messages) > _FOLD_THRESHOLD:
messages = await self._fold_old_messages(messages, full_system)
full_messages = _apply_progressive_decay(full_messages)
if not _folded and len(full_messages) > _FOLD_THRESHOLD:
full_messages = await self._fold_old_messages(full_messages)
_folded = True
elif _folded and len(messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
# Allow a second fold if messages grew back significantly
messages = await self._fold_old_messages(messages, full_system)
elif _folded and len(full_messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
full_messages = await self._fold_old_messages(full_messages)
text = await self.chat(messages, system=full_system)
text, reasoning, raw_tool_calls = await self._chat_with_tools(
full_messages, openai_tools,
)
# Check for final answer
answer = _extract_answer(text)
if answer is not None:
messages.append({"role": "assistant", "content": text})
return answer, messages
if not raw_tool_calls:
# Model produced a final response. Strip optional <answer>
# tags for backward compatibility with old prompts.
final_msg: dict[str, Any] = {"role": "assistant", "content": text}
if reasoning:
final_msg["reasoning_content"] = reasoning
full_messages.append(final_msg)
answer = _extract_answer(text)
return (answer if answer is not None else text), full_messages
# Check for tool calls
tool_calls = _extract_tool_calls(text)
# Parse arguments + build internal call dicts
parsed_calls: list[dict] = []
for rc in raw_tool_calls:
args_str = rc.get("arguments", "") or ""
try:
args = _safe_json_loads(args_str) if args_str.strip() else {}
except (json.JSONDecodeError, ValueError) as e:
logger.warning(
"Failed to parse arguments for tool %s: %s",
rc.get("name"), e,
)
args = {}
parsed_calls.append({
"id": rc.get("id"),
"name": rc.get("name", ""),
"arguments": args,
})
if not tool_calls:
# No tool calls and no answer tag — treat entire text as answer
messages.append({"role": "assistant", "content": text})
return text, messages
# Append the assistant turn with the raw tool_calls (and the
# DeepSeek-mandated reasoning_content echo-back), then execute.
asst_msg: dict[str, Any] = {
"role": "assistant",
"content": text or None,
"tool_calls": [
{
"id": rc.get("id"),
"type": "function",
"function": {
"name": rc.get("name", ""),
"arguments": rc.get("arguments", "") or "",
},
}
for rc in raw_tool_calls
],
}
if reasoning:
asst_msg["reasoning_content"] = reasoning
full_messages.append(asst_msg)
# Execute tool calls — read-only tools run in parallel
messages.append({"role": "assistant", "content": text})
result_parts = []
batches = _partition_tool_calls(tool_calls)
batches = _partition_tool_calls(parsed_calls)
t_batch_start = time.monotonic()
# Each entry: (tool_call_dict, raw_result, formatted_for_llm)
executed: list[tuple[dict, str, str]] = []
for batch in batches:
if batch.is_read_only and len(batch.calls) > 1:
batch_results = await self._execute_tool_batch_parallel(
results = await self._execute_tool_batch_parallel(
batch.calls, tool_executor, tools,
)
result_parts.extend(batch_results)
for tc, (raw, formatted) in zip(batch.calls, results):
executed.append((tc, raw, formatted))
else:
for tc in batch.calls:
result_parts.append(
await self._execute_single_tool(tc, tool_executor, tools)
raw, formatted = await self._execute_single_tool(
tc, tool_executor, tools,
)
executed.append((tc, raw, formatted))
# Emit folded tool-call summary for the terminal
t_batch_elapsed = time.monotonic() - t_batch_start
_emit_tool_call_summary(tool_calls, t_batch_elapsed)
_emit_tool_call_summary(parsed_calls, t_batch_elapsed)
# Feed results back as a user message
result_message = "\n\n".join(result_parts)
messages.append({"role": "user", "content": result_message})
# Append formatted tool results to the conversation (this is
# what the LLM sees on subsequent rounds — truncated for context
# economy).
for tc, _raw, formatted in executed:
full_messages.append({
"role": "tool",
"tool_call_id": tc["id"],
"content": formatted,
})
# Terminal-tool short-circuit: if the model called any tool in
# `terminal_tools`, end the loop immediately. The terminal tool's
# RAW result (untruncated) becomes final_text — the LLM may have
# produced a 20K-char report via save_report and we must not
# truncate it just because the LLM-facing copy is truncated.
if terminal_set:
for tc, raw, _formatted in executed:
name = tc.get("name", "")
if name in terminal_set:
logger.info(
"Terminal tool %s called — exiting tool_call_loop", name,
)
return raw, full_messages
logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
return "[Max tool call iterations reached]", messages
return "[Max tool call iterations reached]", full_messages
async def _execute_single_tool(
self, tc: dict, tool_executor: dict[str, Any],
tools: list[dict] | None = None,
) -> str:
"""Execute a single tool call and return the formatted result."""
) -> tuple[str, str]:
"""Execute a single tool call.
Returns (raw_result, formatted_for_llm). `raw_result` is the
unmodified executor return (used by terminal-tool short-circuit as
final_text). `formatted_for_llm` is `[tool_name] {truncated}` and
is what gets fed back to the model as the tool message content.
"""
tool_name = tc.get("name", "")
tool_args = tc.get("arguments", {})
@@ -532,72 +667,106 @@ class LLMClient:
executor = tool_executor.get(tool_name)
if executor is None:
result_text = f"Error: unknown tool '{tool_name}'"
raw = f"Error: unknown tool '{tool_name}'"
else:
try:
result_text = await executor(**tool_args)
raw = await executor(**tool_args)
except Exception as e:
logger.error("Tool %s failed: %s", tool_name, e)
result_text = f"Error executing {tool_name}: {e}"
raw = f"Error executing {tool_name}: {e}"
return (
f"{TOOL_RESULT_TAG}\n"
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
f"{TOOL_RESULT_END}"
)
formatted = f"[{tool_name}] {_truncate_tool_result(raw)}"
return raw, formatted
async def _execute_tool_batch_parallel(
self, calls: list[dict], tool_executor: dict[str, Any],
tools: list[dict] | None = None,
) -> list[str]:
"""Execute multiple read-only tool calls concurrently."""
) -> list[tuple[str, str]]:
"""Execute multiple read-only tool calls concurrently.
Returns a list of (raw_result, formatted_for_llm) tuples in the
same order as `calls`.
"""
logger.info("Executing %d read-only tools in parallel", len(calls))
async def _run_one(tc: dict) -> str:
async def _run_one(tc: dict) -> tuple[str, str]:
tool_name = tc.get("name", "")
tool_args = tc.get("arguments", {})
if tools:
tool_args = _fix_tool_args(tool_name, tool_args, tools)
logger.info("Calling tool (parallel): %s(%s)", tool_name, json.dumps(tool_args, ensure_ascii=False))
logger.info(
"Calling tool (parallel): %s(%s)",
tool_name, json.dumps(tool_args, ensure_ascii=False),
)
executor = tool_executor.get(tool_name)
if executor is None:
result_text = f"Error: unknown tool '{tool_name}'"
raw = f"Error: unknown tool '{tool_name}'"
else:
try:
result_text = await executor(**tool_args)
raw = await executor(**tool_args)
except Exception as e:
logger.error("Tool %s failed: %s", tool_name, e)
result_text = f"Error executing {tool_name}: {e}"
return (
f"{TOOL_RESULT_TAG}\n"
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
f"{TOOL_RESULT_END}"
)
raw = f"Error executing {tool_name}: {e}"
formatted = f"[{tool_name}] {_truncate_tool_result(raw)}"
return raw, formatted
results = await asyncio.gather(*[_run_one(tc) for tc in calls])
return list(results)
async def _fold_old_messages(
self, messages: list[dict], system: str,
self, messages: list[dict],
) -> list[dict]:
"""Fold old messages into an LLM-generated summary (Stage B).
Keeps the most recent _FOLD_KEEP_RECENT messages intact and
replaces earlier ones with a single summary message.
Preserves the leading system message (if any), keeps the most
recent _FOLD_KEEP_RECENT messages intact, and replaces the older
middle slice with a single summary user message.
"""
n_to_fold = len(messages) - _FOLD_KEEP_RECENT
# Pin the system message — it must NEVER be summarized away.
system_msgs: list[dict] = []
body = messages
if messages and messages[0].get("role") == "system":
system_msgs = [messages[0]]
body = messages[1:]
n_to_fold = len(body) - _FOLD_KEEP_RECENT
if n_to_fold <= 2:
return messages
old_messages = messages[:n_to_fold]
recent_messages = messages[n_to_fold:]
# Pull the fold boundary forward so we never split an assistant turn
# from its matching tool results. The API rejects (HTTP 400) any
# `role: "tool"` message that does not immediately follow an
# `assistant` message with `tool_calls`. We walk the boundary into
# `recent_messages` while its head is a `role: "tool"` message, or
# while the prior `recent` message is `assistant{tool_calls}` whose
# paired tools span the boundary.
while n_to_fold < len(body):
head = body[n_to_fold]
if head.get("role") == "tool":
n_to_fold += 1
continue
break
if n_to_fold >= len(body):
# Everything got folded — nothing recent to keep.
return system_msgs + [body[0]] if system_msgs else messages
old_messages = body[:n_to_fold]
recent_messages = body[n_to_fold:]
# Build a text dump of old messages for summarization
old_text_parts = []
for msg in old_messages:
role = msg["role"]
content = msg.get("content", "")
# Truncate each message for the summary prompt to avoid overload
role = msg.get("role", "?")
content = msg.get("content") or ""
# Render tool_calls (assistant turn) compactly.
if role == "assistant" and msg.get("tool_calls"):
tc_names = [
tc.get("function", {}).get("name", "?")
for tc in msg["tool_calls"]
]
content = (content + " " if content else "") + (
"called: " + ", ".join(tc_names)
)
if len(content) > 1000:
content = content[:1000] + "..."
old_text_parts.append(f"[{role}]: {content}")
@@ -621,7 +790,6 @@ class LLMClient:
logger.warning("Context folding failed: %s — keeping original messages", e)
return messages
# Replace old messages with a single summary
summary_message = {
"role": "user",
"content": (
@@ -629,4 +797,4 @@ class LLMClient:
f"messages in this conversation]\n\n{summary}"
),
}
return [summary_message] + recent_messages
return system_msgs + [summary_message] + recent_messages

View File

@@ -12,7 +12,8 @@ from pathlib import Path
from agent_factory import AgentFactory
from evidence_graph import EvidenceGraph
from llm_client import LLMClient, _safe_json_loads
from llm_client import LLMClient, _extract_first_balanced, _safe_json_loads
from tool_registry import TOOL_CATALOG
logger = logging.getLogger(__name__)
@@ -93,6 +94,14 @@ class Orchestrator:
"Omit phenomena that are unrelated. Be conservative — only link genuinely relevant evidence."
)
_AREA_DERIVE_SYSTEM = (
"You are a forensic investigation strategist. Given a set of hypotheses, "
"decompose them into a minimal aggregate set of investigation areas. An "
"area is a focused, concrete question with the keywords and tool names an "
"answering phenomenon would mention. Aggregate aggressively — when two "
"hypotheses share an area, emit it once and list both hypothesis_ids."
)
def __init__(
self,
llm: LLMClient,
@@ -216,6 +225,125 @@ class Orchestrator:
"The ultimate goal is to reconstruct a detailed timeline of what happened on this host."
)
# ---- Investigation areas (manual seed + LLM derive) ----------------------
_VALID_AGENT_TYPES = {"filesystem", "registry", "communication", "network", "timeline"}
async def _seed_manual_investigation_areas(self) -> None:
"""Import config.yaml:investigation_areas entries (manual override).
Run early in Phase 2 so manual entries are in the graph before LLM
derivation; LLM derive then augments via slug-based dedupe.
"""
for entry in self.config.get("investigation_areas", []):
area = entry.get("area")
if not area:
continue
await self.graph.add_investigation_area(
area=area,
description=entry.get("description", entry.get("task", "")),
suggested_agent=entry.get("agent", "filesystem"),
expected_keywords=entry.get("keywords", []),
expected_tools=entry.get("tools", []),
priority=entry.get("priority", 3),
created_by="manual",
)
async def _derive_investigation_areas(self) -> None:
"""Ask LLM to derive investigation areas from active hypotheses.
Manual-seeded or already-populated graph (resume) → no-op. On LLM
failure or empty output, falls back to one-area-per-hypothesis.
"""
if self.graph.investigation_areas:
return
active = [h for h in self.graph.hypotheses.values() if h.status == "active"]
if not active:
return
available_tools = sorted(TOOL_CATALOG.keys())
hyp_lines = "\n".join(
f" [{h.id}] {h.title}: {h.description}" for h in active
)
prompt = (
f"Active hypotheses:\n{hyp_lines}\n\n"
f"Available agents: {sorted(self._VALID_AGENT_TYPES)}\n"
f"Available tool names (pick 1-3 per area for expected_tools): {available_tools}\n\n"
f"Emit 5-12 distinct investigation areas covering the FULL hypothesis set.\n"
f"Each area must include:\n"
f" - area: snake_case slug (dedupe key)\n"
f" - description: one sentence on what to find\n"
f" - suggested_agent: one of the agents above\n"
f" - expected_keywords: 3-8 lowercase tokens that an answering phenomenon would mention\n"
f" - expected_tools: 1-3 tool names from the list above\n"
f" - priority: 1 (highest) to 10\n"
f" - motivating_hypothesis_ids: at least one [hyp-xxx] from above\n\n"
f"Aggregate aggressively — when two hypotheses share an area, emit it ONCE "
f"and list both ids in motivating_hypothesis_ids.\n\n"
f"Respond ONLY with JSON:\n"
f'[{{"area":"...","description":"...","suggested_agent":"...",'
f'"expected_keywords":[...],"expected_tools":[...],"priority":1-10,'
f'"motivating_hypothesis_ids":["hyp-xxx"]}}]'
)
try:
items = await self._call_llm_for_json(
system=self._AREA_DERIVE_SYSTEM,
user_prompt=prompt,
schema="array",
)
except Exception as e:
logger.warning("Area derivation LLM failed: %s — falling back", e)
await self._derive_areas_fallback(active)
return
valid_hyp_ids = set(self.graph.hypotheses.keys())
for it in items:
area = it.get("area", "").strip()
if not area:
continue
agent = it.get("suggested_agent", "filesystem")
if agent not in self._VALID_AGENT_TYPES:
agent = AGENT_ALIASES.get(agent, "filesystem")
tools = [t for t in it.get("expected_tools", []) if t in TOOL_CATALOG]
motivating = [
h for h in it.get("motivating_hypothesis_ids", [])
if h in valid_hyp_ids
]
priority = max(1, min(10, int(it.get("priority", 5))))
await self.graph.add_investigation_area(
area=area,
description=it.get("description", ""),
suggested_agent=agent,
expected_keywords=[
str(kw).lower() for kw in it.get("expected_keywords", [])
],
expected_tools=tools,
priority=priority,
motivating_hypothesis_ids=motivating,
created_by="llm_derive",
)
if not self.graph.investigation_areas:
await self._derive_areas_fallback(active)
async def _derive_areas_fallback(self, active: list) -> None:
"""One area per active hypothesis as a minimal safety net."""
for h in active:
slug = re.sub(r"[^a-z0-9_]+", "_", h.title.lower())[:40].strip("_")
if not slug:
slug = h.id.replace("-", "_")
await self.graph.add_investigation_area(
area=slug,
description=h.title,
suggested_agent="filesystem",
expected_keywords=h.title.lower().split()[:6],
expected_tools=[],
priority=5,
motivating_hypothesis_ids=[h.id],
created_by="fallback",
)
# ---- LLM JSON helper -----------------------------------------------------
async def _call_llm_for_json(
@@ -234,13 +362,12 @@ class Orchestrator:
"""
error_hint = ""
last_err: Exception | None = None
pattern = r'\[.*?\]' if schema == "array" else r'\{.*?\}'
open_c, close_c = ('[', ']') if schema == "array" else ('{', '}')
for attempt in range(max_retries + 1):
messages = [{"role": "user", "content": user_prompt + error_hint}]
response = await self.llm.chat(messages=messages, system=system)
match = re.search(pattern, response, re.DOTALL)
candidate = match.group() if match else response
candidate = _extract_first_balanced(response, open_c, close_c) or response
try:
return _safe_json_loads(candidate)
except (json.JSONDecodeError, ValueError) as e:
@@ -278,10 +405,19 @@ class Orchestrator:
existing = "\n".join(
f" - {r['node']} [{r['edge_type']}]" for r in related
) or " (none yet)"
related_areas = [
a.area for a in self.graph.investigation_areas.values()
if hyp.id in a.motivating_hypothesis_ids
]
expected_line = (
f" Expected areas to investigate: {', '.join(related_areas)}\n"
if related_areas else ""
)
hyp_blocks.append(
f"Hypothesis [{hyp.id}]: {hyp.title}\n"
f" Description: {hyp.description}\n"
f" Current confidence: {hyp.confidence:.2f}\n"
f"{expected_line}"
f" Existing evidence:\n{existing}"
)
@@ -328,12 +464,21 @@ class Orchestrator:
existing_evidence = "\n".join(
f" - {r['node']} [{r['edge_type']}]" for r in related
) or " (none yet)"
related_areas = [
a.area for a in self.graph.investigation_areas.values()
if hyp.id in a.motivating_hypothesis_ids
]
expected_line = (
f"Expected areas to investigate: {', '.join(related_areas)}\n\n"
if related_areas else ""
)
prompt = (
f"Hypothesis: {hyp.title}\n"
f"Description: {hyp.description}\n"
f"Current confidence: {hyp.confidence:.2f}\n\n"
f"Existing evidence linked to this hypothesis:\n{existing_evidence}\n\n"
f"{expected_line}"
f"What additional evidence should we look for to CONFIRM or DENY this hypothesis?\n"
f"List 1-3 specific, actionable investigation tasks.\n"
f"For each, specify which agent type should handle it: "
@@ -464,74 +609,57 @@ class Orchestrator:
# ---- Gap analysis (coverage check) ---------------------------------------
_AREA_KEYWORDS: dict[str, list[str]] = {
"system_info": ["install date", "registered owner", "product name", "windows xp", "system information"],
"user_accounts": ["user account", "enumerate", "sam hive", "administrator", "mr. evil"],
"shutdown_time": ["shutdown"],
"network_config": ["network interface", "network adapter", "ip address", "dhcp", "mac address", "network config"],
"installed_software": ["installed software", "program files", "installed program"],
"email_config": ["smtp", "pop3", "nntp", "email account", "email config"],
"chat_logs": ["irc", "mirc", "chat log", "channel"],
"network_activity": ["packet capture", "pcap", "interception", "http request", "user-agent"],
"deleted_files": ["deleted file", "recycle", "recycler"],
"execution_evidence": ["prefetch", "execution", "run count", "last execution"],
}
def _check_coverage(self) -> set[str]:
"""Return slugs of investigation_areas already covered by phenomena.
# Deterministic coverage: if the canonical tool was called, the area is covered.
_AREA_TOOLS: dict[str, list[str]] = {
"system_info": ["get_system_info"],
"user_accounts": ["enumerate_users"],
"shutdown_time": ["get_shutdown_time"],
"network_config": ["get_network_interfaces"],
"installed_software": ["list_installed_software"],
"email_config": ["get_email_config"],
"network_activity": ["parse_pcap_strings"],
"deleted_files": ["count_deleted_files"],
"execution_evidence": ["parse_prefetch"],
}
Layer A: any expected_keyword found in evidence text (category +
title + description, lowercased).
Layer B: any expected_tool present in the source_tool set of recorded
phenomena (deterministic — the canonical tool was actually called).
"""
evidence_text = " ".join(
f"{ph.category} {ph.title} {ph.description}".lower()
for ph in self.graph.phenomena.values()
)
used_tools: set[str] = {
ph.source_tool for ph in self.graph.phenomena.values() if ph.source_tool
}
def _check_coverage(self, areas: list[dict]) -> set[str]:
# Layer 1: keyword matching on category + title + description
evidence_text = ""
for ph in self.graph.phenomena.values():
evidence_text += f" {ph.category} {ph.title} {ph.description} ".lower()
# Layer 2: collect all source_tools that produced phenomena
used_tools: set[str] = {ph.source_tool for ph in self.graph.phenomena.values() if ph.source_tool}
covered = set()
for area in areas:
area_name = area["area"]
# Check keywords
keywords = self._AREA_KEYWORDS.get(area_name, [])
if any(kw in evidence_text for kw in keywords):
covered.add(area_name)
covered: set[str] = set()
for a in self.graph.investigation_areas.values():
if any(kw.lower() in evidence_text for kw in a.expected_keywords):
covered.add(a.area)
continue
# Check source_tool
area_tools = self._AREA_TOOLS.get(area_name, [])
if any(tool in used_tools for tool in area_tools):
covered.add(area_name)
if any(t in used_tools for t in a.expected_tools):
covered.add(a.area)
return covered
async def _run_gap_analysis(self) -> None:
areas = self.config.get("investigation_areas", [])
areas = list(self.graph.investigation_areas.values())
if not areas:
return
covered = self._check_coverage(areas)
uncovered = [a for a in areas if a["area"] not in covered]
covered = self._check_coverage()
uncovered = [a for a in areas if a.area not in covered]
if not uncovered:
_log(f"All {len(areas)} investigation areas covered", event="progress")
return
uncovered_names = ", ".join(a["area"] for a in uncovered)
_log(f"{len(uncovered)}/{len(areas)} areas uncovered: {uncovered_names}", event="dispatch")
for area in uncovered:
uncovered_names = ", ".join(a.area for a in uncovered)
_log(
f"{len(uncovered)}/{len(areas)} areas uncovered: {uncovered_names}",
event="dispatch",
)
for a in uncovered:
await self.graph.add_lead(
target_agent=area["agent"],
description=area["task"],
priority=3,
target_agent=a.suggested_agent,
description=a.description,
priority=a.priority,
hypothesis_id=(
a.motivating_hypothesis_ids[0]
if a.motivating_hypothesis_ids else None
),
)
for round_num in range(3):
@@ -578,6 +706,15 @@ class Orchestrator:
json.dumps(leads_data, ensure_ascii=False, indent=2)
)
# Investigation areas export
areas_data = {
aid: a.to_dict()
for aid, a in self.graph.investigation_areas.items()
}
(self.run_dir / "investigation_areas.json").write_text(
json.dumps(areas_data, ensure_ascii=False, indent=2)
)
# Run metadata
end_time = datetime.now()
metadata = {
@@ -637,6 +774,11 @@ class Orchestrator:
if resume_phase <= 2:
_log("Phase 2: Hypothesis Generation", event="phase")
t0 = time.monotonic()
# Seed manual investigation areas (if any) BEFORE LLM derive,
# so manual entries win the dedupe and LLM only augments.
await self._seed_manual_investigation_areas()
manual_hypotheses = self.config.get("hypotheses", [])
if manual_hypotheses:
await self._generate_hypotheses_manual(manual_hypotheses)
@@ -647,10 +789,17 @@ class Orchestrator:
if self.graph.phenomena and self.graph.hypotheses:
await self._judge_new_phenomena()
# Derive investigation areas from active hypotheses.
# No-op if manual seed already populated or resume restored areas.
await self._derive_investigation_areas()
for h in self.graph.hypotheses.values():
_log(f" {h.summary()}", event="hypothesis")
for a in self.graph.investigation_areas.values():
_log(f" {a.summary()}", event="area")
_log(
f"+{len(self.graph.hypotheses)} hypotheses generated",
f"+{len(self.graph.hypotheses)} hypotheses, "
f"{len(self.graph.investigation_areas)} areas",
event="progress", elapsed=time.monotonic() - t0,
)

View File

@@ -14,7 +14,6 @@ from evidence_graph import (
from llm_client import (
_truncate_tool_result, _partition_tool_calls, _ToolBatch, READ_ONLY_TOOLS,
_apply_progressive_decay, _FOLD_THRESHOLD, _FOLD_KEEP_RECENT,
TOOL_RESULT_TAG, TOOL_RESULT_END,
)
from tool_registry import (
_tool_result_cache, _cache_key, _make_cached, CACHEABLE_TOOLS,
@@ -598,7 +597,8 @@ class TestParallelToolExecution:
elapsed = time.monotonic() - start
assert len(results) == 3
assert all("ok" in r for r in results)
# results are (raw, formatted) tuples; both contain "ok"
assert all("ok" in raw and "ok" in formatted for raw, formatted in results)
# 3 tasks × 50ms each should take ~50ms parallel, not ~150ms serial
assert elapsed < 0.12, f"Expected parallel execution but took {elapsed:.3f}s"
@@ -627,8 +627,10 @@ class TestParallelToolExecution:
results = await client._execute_tool_batch_parallel(calls, tool_executor)
assert len(results) == 2
assert "success" in results[0]
assert "Error" in results[1]
raw0, formatted0 = results[0]
raw1, formatted1 = results[1]
assert "success" in raw0 and "success" in formatted0
assert "Error" in raw1 and "Error" in formatted1
# ---------------------------------------------------------------------------
@@ -843,21 +845,27 @@ class TestToolResultCache:
class TestProgressiveDecay:
def _make_messages(self, n_rounds: int) -> list[dict]:
"""Build a synthetic message list with n_rounds of (assistant, user) pairs."""
"""Build a synthetic message list shaped like native tool calling:
user → (assistant w/ tool_calls → tool result)+
"""
messages = [{"role": "user", "content": "Start task"}]
for i in range(n_rounds):
tc_id = f"call_{i}"
messages.append({
"role": "assistant",
"content": f"<tool_call>{{'name': 'tool_{i}'}}</tool_call>",
"content": None,
"tool_calls": [
{
"id": tc_id,
"type": "function",
"function": {"name": f"tool_{i}", "arguments": "{}"},
},
],
})
# Tool result message with substantial content
messages.append({
"role": "user",
"content": (
f"{TOOL_RESULT_TAG}\n"
f"[tool_{i}] {'x' * 2500}\n"
f"{TOOL_RESULT_END}"
),
"role": "tool",
"tool_call_id": tc_id,
"content": f"[tool_{i}] {'x' * 2500}",
})
return messages
@@ -865,21 +873,16 @@ class TestProgressiveDecay:
msgs = self._make_messages(3)
result = _apply_progressive_decay(msgs)
assert len(result) == len(msgs)
# Content should be identical for short conversations
for orig, decayed in zip(msgs, result):
assert orig["content"] == decayed["content"]
assert orig.get("content") == decayed.get("content")
def test_old_messages_truncated(self):
msgs = self._make_messages(20)
result = _apply_progressive_decay(msgs)
# Recent tool results should be full length
last_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][-1]
assert len(last_tool_result["content"]) > 2000
# Oldest tool results should be truncated
first_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][0]
assert len(first_tool_result["content"]) < 500
tool_msgs = [m for m in result if m.get("role") == "tool"]
assert len(tool_msgs[-1]["content"]) > 2000
assert len(tool_msgs[0]["content"]) < 500
def test_message_count_preserved(self):
msgs = self._make_messages(20)
@@ -915,7 +918,7 @@ class TestMessageFolding:
messages.append({"role": "assistant", "content": f"thinking step {i}"})
messages.append({"role": "user", "content": f"tool result {i}: {'data ' * 50}"})
result = await client._fold_old_messages(messages, "system prompt")
result = await client._fold_old_messages(messages)
# Should be significantly shorter
assert len(result) < len(messages)
@@ -937,7 +940,7 @@ class TestMessageFolding:
{"role": "assistant", "content": "hi"},
]
result = await client._fold_old_messages(messages, "system")
result = await client._fold_old_messages(messages)
# Should return original (n_to_fold = 2 - 10 = negative, so no folding)
assert result == messages
client.chat.assert_not_called()
@@ -951,7 +954,555 @@ class TestMessageFolding:
client.chat = AsyncMock(side_effect=Exception("API error"))
messages = [{"role": "user", "content": f"msg {i}"} for i in range(40)]
result = await client._fold_old_messages(messages, "system")
result = await client._fold_old_messages(messages)
# On failure, should return original messages unchanged
assert len(result) == 40
@pytest.mark.asyncio
async def test_fold_boundary_never_orphans_tool_message(self):
"""If the natural fold boundary would leave `role: "tool"` at the
head of `recent_messages`, fold must walk the boundary forward
until the head is non-tool. The API rejects orphan tool messages
with HTTP 400."""
from llm_client import LLMClient
from unittest.mock import AsyncMock
client = LLMClient.__new__(LLMClient)
client.chat = AsyncMock(return_value="summary")
# Build a long conversation of (assistant{tool_calls}, tool) pairs.
# Place the assistant at the exact n_to_fold boundary so its paired
# tool would otherwise be orphaned at the head of recent_messages.
msgs: list[dict] = [{"role": "user", "content": "task"}]
for i in range(30):
tc_id = f"call_{i}"
msgs.append({
"role": "assistant", "content": None,
"tool_calls": [{
"id": tc_id, "type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}],
})
msgs.append({"role": "tool", "tool_call_id": tc_id, "content": "ok"})
result = await client._fold_old_messages(msgs)
# No `role: "tool"` may appear without an `assistant{tool_calls}`
# immediately preceding it.
for i, m in enumerate(result):
if m.get("role") == "tool":
assert i > 0, "tool message cannot be first"
prev = result[i - 1]
assert prev.get("role") == "assistant" and prev.get("tool_calls"), (
f"tool at index {i} preceded by {prev.get('role')} "
f"(tool_calls={bool(prev.get('tool_calls'))})"
)
# ---------------------------------------------------------------------------
# Investigation areas: dataclass + derivation + coverage + dispatch
# ---------------------------------------------------------------------------
class TestInvestigationAreaDerivation:
@pytest.fixture
def graph(self):
return EvidenceGraph()
@pytest.mark.asyncio
async def test_add_investigation_area_dedupes_and_merges_lists(self, graph):
aid1, existed1 = await graph.add_investigation_area(
area="password_hashes", description="SAM dump",
suggested_agent="filesystem",
expected_keywords=["sam", "pwdump"],
expected_tools=["search_strings"],
motivating_hypothesis_ids=["hyp-a"],
created_by="llm_derive",
)
aid2, existed2 = await graph.add_investigation_area(
area="password_hashes", description="Other description",
suggested_agent="registry",
expected_keywords=["sam", "hashdump"], # one new
expected_tools=["search_strings", "parse_registry_key"], # one new
motivating_hypothesis_ids=["hyp-b"], # new
created_by="manual",
)
assert not existed1
assert existed2
assert aid1 == aid2
a = graph.investigation_areas[aid1]
# Description and suggested_agent NOT overwritten (first-write wins)
assert a.description == "SAM dump"
assert a.suggested_agent == "filesystem"
# Three list fields are unioned
assert set(a.expected_keywords) == {"sam", "pwdump", "hashdump"}
assert set(a.expected_tools) == {"search_strings", "parse_registry_key"}
assert set(a.motivating_hypothesis_ids) == {"hyp-a", "hyp-b"}
@pytest.mark.asyncio
async def test_check_coverage_keyword_layer(self, graph):
await graph.add_phenomenon(
"fs", "filesystem", "Cain SAM dump artifact",
"Found sam.lst in the Cain folder",
source_tool="list_directory",
)
await graph.add_investigation_area(
area="password_hashes", description="SAM dump",
suggested_agent="filesystem",
expected_keywords=["sam.lst", "pwdump"],
expected_tools=["nonexistent_tool"],
)
from orchestrator import Orchestrator
from agent_factory import AgentFactory
from unittest.mock import AsyncMock
orch = Orchestrator(AsyncMock(), graph, AgentFactory(AsyncMock(), graph))
covered = orch._check_coverage()
assert "password_hashes" in covered
@pytest.mark.asyncio
async def test_check_coverage_tool_layer(self, graph):
await graph.add_phenomenon(
"reg", "registry", "User accounts",
"Found accounts",
source_tool="enumerate_users",
)
await graph.add_investigation_area(
area="user_accounts", description="Enum users",
suggested_agent="registry",
expected_keywords=["irrelevant"],
expected_tools=["enumerate_users"],
)
from orchestrator import Orchestrator
from agent_factory import AgentFactory
from unittest.mock import AsyncMock
orch = Orchestrator(AsyncMock(), graph, AgentFactory(AsyncMock(), graph))
covered = orch._check_coverage()
assert "user_accounts" in covered
@pytest.mark.asyncio
async def test_load_state_round_trip_areas(self, graph, tmp_path):
await graph.add_investigation_area(
area="x", description="d", suggested_agent="filesystem",
expected_keywords=["k1"], expected_tools=["t1"],
priority=2, motivating_hypothesis_ids=["hyp-a"],
created_by="manual",
)
path = tmp_path / "state.json"
graph.save_state(path)
g2 = EvidenceGraph.load_state(path)
assert len(g2.investigation_areas) == 1
a = list(g2.investigation_areas.values())[0]
assert a.area == "x"
assert a.expected_keywords == ["k1"]
assert a.priority == 2
assert a.motivating_hypothesis_ids == ["hyp-a"]
assert a.created_by == "manual"
@pytest.mark.asyncio
async def test_derive_no_op_when_areas_already_populated(self, graph):
"""Resume safety: if areas are already in the graph (manual seed or
restored from disk), _derive_investigation_areas does nothing."""
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
await graph.add_hypothesis("test", "desc", created_by="t")
await graph.add_investigation_area(
area="pre_existing", description="d", suggested_agent="filesystem",
created_by="manual",
)
llm = AsyncMock()
orch = Orchestrator(llm, graph, AgentFactory(llm, graph))
await orch._derive_investigation_areas()
# LLM should not have been called
assert llm.chat.call_count == 0
# Area count unchanged
assert len(graph.investigation_areas) == 1
@pytest.mark.asyncio
async def test_fallback_when_llm_returns_empty_list(self, graph):
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
await graph.add_hypothesis("Some compromise", "desc", created_by="t")
llm = AsyncMock()
llm.chat.return_value = "[]"
orch = Orchestrator(llm, graph, AgentFactory(llm, graph))
await orch._derive_investigation_areas()
# Fallback creates one area per hypothesis
assert len(graph.investigation_areas) == 1
a = list(graph.investigation_areas.values())[0]
assert a.created_by == "fallback"
@pytest.mark.asyncio
async def test_unknown_tool_filtered_kept_keywords(self, graph):
"""LLM emits a tool name not in TOOL_CATALOG; tool is filtered,
but the area itself with its keywords is kept."""
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
h = await graph.add_hypothesis("h", "desc", created_by="t")
llm = AsyncMock()
llm.chat.return_value = (
'[{"area":"foo","description":"desc","suggested_agent":"filesystem",'
'"expected_keywords":["kw1","kw2"],"expected_tools":["nonexistent_tool"],'
f'"priority":2,"motivating_hypothesis_ids":["{h}"]}}]'
)
orch = Orchestrator(llm, graph, AgentFactory(llm, graph))
await orch._derive_investigation_areas()
assert "area-foo" in graph.investigation_areas
a = graph.investigation_areas["area-foo"]
assert a.expected_keywords == ["kw1", "kw2"]
assert a.expected_tools == [] # filtered out
@pytest.mark.asyncio
async def test_unknown_agent_resolved_via_AGENT_ALIASES(self, graph):
"""LLM emits 'chat' (which is in AGENT_ALIASES → 'communication').
The area should land with the resolved agent name."""
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
h = await graph.add_hypothesis("h", "desc", created_by="t")
llm = AsyncMock()
llm.chat.return_value = (
'[{"area":"chat_stuff","description":"d","suggested_agent":"chat",'
'"expected_keywords":["irc"],"expected_tools":[],'
f'"priority":3,"motivating_hypothesis_ids":["{h}"]}}]'
)
orch = Orchestrator(llm, graph, AgentFactory(llm, graph))
await orch._derive_investigation_areas()
a = graph.investigation_areas["area-chat_stuff"]
assert a.suggested_agent == "communication"
@staticmethod
def _agent_with_executor(graph, llm, tool_name: str, real_executor):
"""Build a BaseAgent that registers tool_name via the real register_tool
path so the mandatory-record wrapper is engaged."""
from base_agent import BaseAgent
agent = BaseAgent(llm, graph)
agent.name = "test_agent"
# Bypass _register_graph_tools side-effects in run() — we register
# only what the test needs.
agent._register_graph_tools = lambda: None
agent.register_tool(
name=tool_name, description="", input_schema={}, executor=real_executor,
)
return agent
@pytest.mark.asyncio
async def test_forced_record_retry_fires_when_zero_phenomena(self):
"""BaseAgent.run should automatically retry one more LLM round if
the agent finished without calling any mandatory recording tool."""
from unittest.mock import AsyncMock
graph = EvidenceGraph()
llm = AsyncMock()
async def real_add(**kw):
await graph.add_phenomenon(
source_agent="test", category="filesystem",
title="Forced retry record",
description="Recorded after STOP prompt",
source_tool="forced_retry",
)
agent = self._agent_with_executor(graph, llm, "add_phenomenon", real_add)
async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()):
already_retrying = any(
"STOP." in (m.get("content", "") if isinstance(m, dict) else "")
for m in messages
)
if not already_retrying:
return "Final answer without recording.", list(messages) + [
{"role": "assistant", "content": "Final answer without recording."}
]
await tool_executor["add_phenomenon"]() # goes through wrapper
return "Recorded.", []
llm.tool_call_loop = fake_tool_call_loop
await agent.run("test task")
assert len(graph.phenomena) == 1
assert agent._record_call_counts["add_phenomenon"] == 1
@pytest.mark.asyncio
async def test_no_retry_when_mandatory_tool_was_called(self):
"""Retry should NOT fire if a mandatory recording tool was invoked."""
from unittest.mock import AsyncMock
graph = EvidenceGraph()
llm = AsyncMock()
call_count = {"n": 0}
async def real_add(**kw):
await graph.add_phenomenon(
source_agent="test", category="filesystem", title="x",
description="y", source_tool="t",
)
agent = self._agent_with_executor(graph, llm, "add_phenomenon", real_add)
async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()):
call_count["n"] += 1
await tool_executor["add_phenomenon"]() # wrapper increments count
return "done.", list(messages)
llm.tool_call_loop = fake_tool_call_loop
await agent.run("test task")
assert call_count["n"] == 1 # no retry
@pytest.mark.asyncio
async def test_no_retry_when_mandatory_tools_empty(self):
"""ReportAgent declares mandatory_record_tools=() — retry should
not fire even with zero graph mutations (final text IS the output)."""
from unittest.mock import AsyncMock
from base_agent import BaseAgent
graph = EvidenceGraph()
llm = AsyncMock()
call_count = {"n": 0}
async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()):
call_count["n"] += 1
return "report body here", list(messages)
llm.tool_call_loop = fake_tool_call_loop
class ReportLike(BaseAgent):
mandatory_record_tools = ()
agent = ReportLike(llm, graph)
agent.name = "report_like"
agent._register_graph_tools = lambda: None
await agent.run("test task")
assert call_count["n"] == 1
@pytest.mark.asyncio
async def test_forced_retry_fires_for_timeline_agent(self):
"""TimelineAgent.mandatory_record_tools=('add_temporal_edge',) — retry
should fire when timeline finishes without creating any temporal edges,
even though the agent does not have add_phenomenon."""
from unittest.mock import AsyncMock
from base_agent import BaseAgent
graph = EvidenceGraph()
llm = AsyncMock()
call_count = {"n": 0}
edge_added = {"n": 0}
async def real_add_edge(**kw):
edge_added["n"] += 1
class TimelineLike(BaseAgent):
mandatory_record_tools = ("add_temporal_edge",)
agent = TimelineLike(llm, graph)
agent.name = "timeline_like"
agent._register_graph_tools = lambda: None
agent.register_tool("add_temporal_edge", "", {}, real_add_edge)
async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()):
call_count["n"] += 1
already_retrying = any(
"STOP." in (m.get("content", "") if isinstance(m, dict) else "")
for m in messages
)
if not already_retrying:
return "answer", list(messages)
await tool_executor["add_temporal_edge"]()
return "recorded.", []
llm.tool_call_loop = fake_tool_call_loop
await agent.run("build timeline")
assert call_count["n"] == 2 # first + retry
assert edge_added["n"] == 1
assert agent._record_call_counts["add_temporal_edge"] == 1
# ---- terminal_tools: real LLMClient.tool_call_loop short-circuit -----
@pytest.mark.asyncio
async def test_terminal_tool_exits_loop_immediately(self):
"""When a terminal tool is called, tool_call_loop must return
with that tool's result text as final_text — no further LLM calls."""
from unittest.mock import AsyncMock
from llm_client import LLMClient
client = LLMClient.__new__(LLMClient)
client.max_tokens = 4096
client.reasoning_effort = None
client.thinking_enabled = False
client.model = "test"
client._client = None
call_count = {"n": 0}
async def fake_chat_with_tools(messages, openai_tools):
call_count["n"] += 1
if call_count["n"] == 1:
# First turn: model calls a read tool then the terminal tool.
return "thinking aloud", None, [
{"id": "tc1", "name": "read_tool", "arguments": "{}"},
{"id": "tc2", "name": "save_report",
"arguments": '{"content":"FINAL REPORT BODY","output_path":"r.md"}'},
]
raise AssertionError("loop should have exited after terminal tool")
client._chat_with_tools = fake_chat_with_tools
async def read_tool():
return "some data"
async def save_report(content, output_path):
return content # terminal tool returns content as final_text
tools = [
{"name": "read_tool", "description": "", "input_schema": {"type": "object", "properties": {}}},
{"name": "save_report", "description": "",
"input_schema": {"type": "object", "properties": {
"content": {"type": "string"}, "output_path": {"type": "string"}}}},
]
executors = {"read_tool": read_tool, "save_report": save_report}
final_text, _ = await client.tool_call_loop(
messages=[{"role": "user", "content": "go"}],
tools=tools, tool_executor=executors,
system="sys", terminal_tools=("save_report",),
)
assert final_text == "FINAL REPORT BODY"
assert call_count["n"] == 1 # never called a 2nd round
@pytest.mark.asyncio
async def test_no_terminal_short_circuit_when_not_declared(self):
"""When terminal_tools is empty, the same call sequence should
run the read tool, run save_report-like tool, AND continue the loop
(i.e. another LLM round) until the model stops calling tools."""
from unittest.mock import AsyncMock
from llm_client import LLMClient
client = LLMClient.__new__(LLMClient)
client.max_tokens = 4096
client.reasoning_effort = None
client.thinking_enabled = False
client.model = "test"
client._client = None
call_count = {"n": 0}
async def fake_chat_with_tools(messages, openai_tools):
call_count["n"] += 1
if call_count["n"] == 1:
return "", None, [
{"id": "tc1", "name": "add_phenomenon",
"arguments": '{"title":"x","description":"y"}'},
]
return "all done", None, [] # model stops calling tools
client._chat_with_tools = fake_chat_with_tools
async def add_phenomenon(title, description):
return f"recorded {title}"
tools = [
{"name": "add_phenomenon", "description": "",
"input_schema": {"type": "object", "properties": {
"title": {"type": "string"}, "description": {"type": "string"}}}},
]
executors = {"add_phenomenon": add_phenomenon}
final_text, _ = await client.tool_call_loop(
messages=[{"role": "user", "content": "go"}],
tools=tools, tool_executor=executors,
system="sys", terminal_tools=(), # NOT terminal
)
assert final_text == "all done"
assert call_count["n"] == 2 # 2 rounds — terminal_tools empty, loop continues
@pytest.mark.asyncio
async def test_report_agent_terminal_tool_declared(self):
"""ReportAgent should declare save_report as both mandatory and terminal."""
from agents.report import ReportAgent
assert ReportAgent.terminal_tools == ("save_report",)
assert ReportAgent.mandatory_record_tools == ("save_report",)
@pytest.mark.asyncio
async def test_terminal_tool_result_not_truncated(self):
"""Terminal tool's raw return is used as final_text and must NOT
be truncated to 3000 chars (the truncation cap applies only to
LLM-context tool result messages). A 20K-char markdown report
passed through save_report should reach the caller intact."""
from llm_client import LLMClient
client = LLMClient.__new__(LLMClient)
client.max_tokens = 4096
client.reasoning_effort = None
client.thinking_enabled = False
client.model = "test"
client._client = None
long_report = "# Report\n" + ("- finding " + "x" * 100 + "\n") * 200
assert len(long_report) > 10000
call_count = {"n": 0}
async def fake_chat_with_tools(messages, openai_tools):
call_count["n"] += 1
if call_count["n"] == 1:
return "", None, [
{"id": "tc1", "name": "save_report",
"arguments": '{"content":"placeholder","output_path":"r.md"}'},
]
raise AssertionError("loop should have exited")
client._chat_with_tools = fake_chat_with_tools
async def save_report(content, output_path):
return long_report # ignore content arg; return long content
tools = [{"name": "save_report", "description": "",
"input_schema": {"type": "object", "properties": {
"content": {"type": "string"}, "output_path": {"type": "string"}}}}]
executors = {"save_report": save_report}
final_text, _ = await client.tool_call_loop(
messages=[{"role": "user", "content": "go"}],
tools=tools, tool_executor=executors,
system="sys", terminal_tools=("save_report",),
)
assert final_text == long_report
assert len(final_text) > 10000 # not truncated to 3000
@pytest.mark.asyncio
async def test_dispatch_uses_hypothesis_id_when_motivating_ids_present(self, graph):
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
h = await graph.add_hypothesis("h", "desc", created_by="t")
await graph.add_investigation_area(
area="uncovered_area", description="d", suggested_agent="registry",
expected_keywords=["xyz_no_match"], expected_tools=[],
priority=2, motivating_hypothesis_ids=[h],
created_by="llm_derive",
)
orch = Orchestrator(AsyncMock(), graph, AgentFactory(AsyncMock(), graph))
# Don't actually dispatch (would call agents) — just hit the lead-add path
# by manually replicating what _run_gap_analysis does.
covered = orch._check_coverage()
assert "uncovered_area" not in covered
# Simulate dispatch
for a in graph.investigation_areas.values():
if a.area not in covered:
await graph.add_lead(
target_agent=a.suggested_agent,
description=a.description,
priority=a.priority,
hypothesis_id=(a.motivating_hypothesis_ids[0]
if a.motivating_hypothesis_ids else None),
)
assert len(graph.leads) == 1
assert graph.leads[0].hypothesis_id == h
assert graph.leads[0].target_agent == "registry"
assert graph.leads[0].priority == 2