From 444d58726a375dd4ffe272fdbe07d54b118e186c Mon Sep 17 00:00:00 2001 From: BattleTag Date: Wed, 13 May 2026 13:51:19 +0800 Subject: [PATCH] refactor: native tool calling + generic forced-retry + terminal exit - llm_client: switch tool_call_loop from text-based regex to OpenAI-native tools=[...] / structured tool_calls field; accumulate delta.reasoning_content for DeepSeek thinking-mode echo-back; fold preserves system msg and aligns boundary to never orphan role:tool - base_agent: generic forced-retry via mandatory_record_tools class attr (filesystem -> add_phenomenon, timeline -> add_temporal_edge, hypothesis -> add_hypothesis, report -> save_report); count via executor wrapper - terminal_tools class attr + loop short-circuit: when a terminal tool is called, loop exits with its raw return as final_text. ReportAgent declares save_report as terminal - replaces the -tag stop signal that native tool calling broke - _execute_*: return (raw, formatted) - terminal exit uses untruncated raw, conversation history uses 3000-char-capped formatted - evidence_graph + orchestrator: LLM-derived InvestigationArea support (hypothesis-driven coverage check, replaces hardcoded _AREA_KEYWORDS / _AREA_TOOLS); manual yaml block kept as optional seed - strip references from agent prompts (no longer load-bearing) Verified on CFReDS image across 4 smoke runs: 0 JSON parse failures (was 3); 22 temporal edges from Phase 4 (was 0); ReportAgent exits via save_report (was max_iterations regression). 78/78 unit tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 25 +- agents/hypothesis.py | 3 +- agents/report.py | 57 ++-- agents/timeline.py | 3 +- base_agent.py | 94 +++++- evidence_graph.py | 98 ++++++ llm_client.py | 508 ++++++++++++++++++++---------- orchestrator.py | 265 ++++++++++++---- tests/test_optimizations.py | 601 ++++++++++++++++++++++++++++++++++-- 9 files changed, 1356 insertions(+), 298 deletions(-) diff --git a/README.md b/README.md index 30c8538..c45b48c 100644 --- a/README.md +++ b/README.md @@ -67,14 +67,18 @@ Phenomenon → Hypothesis 的边类型与权重写死在 `HYPOTHESIS_EDGE_WEIGHT | **Phase 4** | TimelineAgent 用 `build_filesystem_timeline` 生成 MAC 时间线,与 Phenomenon 时间戳关联 | | **Phase 5** | ReportAgent 综合假设、证据、实体,生成 Markdown 报告 | -### Gap Analysis(Phase 3 末) +### Investigation Areas(hypothesis-derived) -`config.yaml:investigation_areas` 列出必须覆盖的调查领域(系统信息、用户账户、网络配置、邮件配置、IRC 日志、PCAP、删除文件、Prefetch 等)。Orchestrator 两层判定覆盖情况: +Phase 2 末尾 orchestrator 调一次 LLM 从所有 active hypothesis 派生 5-12 个 **InvestigationArea**(snake_case slug、description、suggested_agent、expected_keywords、expected_tools、priority、motivating_hypothesis_ids)。Areas 存进 `graph.investigation_areas`,序列化到 `runs//investigation_areas.json`。两个用途: -1. **关键词匹配**(`_AREA_KEYWORDS`)— 扫现有 Phenomenon 标题/描述 -2. **工具命中**(`_AREA_TOOLS`)— 检查是否调用过该领域的关键工具(如 `enumerate_users`、`parse_pcap_strings`) +1. **Phase 3 主循环提示** — 每个 hypothesis 块附 `Expected areas: a, b, c`,LLM 仍自由选 lead 但有软引导 +2. **Phase 3 末尾 Gap Analysis** — 两层判定覆盖情况: + - **关键词匹配**:扫 Phenomenon 标题/描述对照 area.expected_keywords + - **工具命中**:检查 area.expected_tools 是否实际调用过 -未覆盖的领域自动派发 lead,最多 3 轮补漏。 +未覆盖的 area 自动派 lead(`suggested_agent` + `priority` + `motivating_hypothesis_ids[0]` 透传给 `Lead.hypothesis_id` 保留 provenance),最多 3 轮补漏。 + +**手动 override**:`config.yaml:investigation_areas` 默认注释掉,纯 LLM 派生。取消注释可添加强制必查的领域,会先于 LLM 写入并通过 slug-based dedupe 保护不被覆盖(LLM 只会 augment keyword/tool 列表)。这是跨案件/跨平台适配的关键 —— 不再 hardcode Windows-specific 领域。 ## Agent 体系 @@ -183,11 +187,12 @@ max_investigation_rounds: 5 # Phase 3 最大迭代轮数 # - title: "嫌疑人主动实施网络嗅探" # description: "..." -investigation_areas: # Gap Analysis 必须覆盖的领域 - - area: system_info - agent: registry - task: "..." - # ... +# investigation_areas: # 可选:手动 override(默认全 LLM 派生) +# - area: shutdown_time # LLM 通过 slug dedupe 只 augment +# agent: registry # keyword/tool 列表,不覆盖 manual +# priority: 3 +# keywords: [shutdown] +# tools: [get_shutdown_time] ``` 未配置 `hypotheses` 时由 HypothesisAgent 自动生成。 diff --git a/agents/hypothesis.py b/agents/hypothesis.py index 6284d8f..6c4bb56 100644 --- a/agents/hypothesis.py +++ b/agents/hypothesis.py @@ -24,6 +24,7 @@ class HypothesisAgent(BaseAgent): "and formulate investigative hypotheses about what happened on this system. " "Your ultimate goal: build the most complete picture of events that occurred." ) + mandatory_record_tools = ("add_hypothesis",) def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None: super().__init__(llm, graph) @@ -68,7 +69,7 @@ class HypothesisAgent(BaseAgent): f"WORKFLOW:\n" f"1. Call list_phenomena and search_graph to review existing findings.\n" f"2. For each hypothesis you want to record, call add_hypothesis (title + description).\n" - f"3. Wrap a short summary in when you have generated 3-7 hypotheses.\n\n" + f"3. STOP after you have generated 3-7 hypotheses. Do not call any more tools.\n\n" f"STRICT BOUNDARIES:\n" f"- Your only mutation tool is add_hypothesis. Do NOT attempt list_directory, " f"parse_registry_key, extract_file, or any disk-image investigation tools — " diff --git a/agents/report.py b/agents/report.py index 0ce0b15..2df76ea 100644 --- a/agents/report.py +++ b/agents/report.py @@ -2,9 +2,6 @@ from __future__ import annotations -import json -import os - from base_agent import BaseAgent from evidence_graph import EvidenceGraph from llm_client import LLMClient @@ -15,11 +12,16 @@ class ReportAgent(BaseAgent): role = ( "Forensic report writer. You synthesize all findings from the investigation " "into a structured, professional forensic analysis report organized by hypotheses.\n\n" - "IMPORTANT: Only include findings that have a source_tool attribution (marked VERIFIED). " + "Only include findings that have a source_tool attribution (marked VERIFIED). " "If evidence lacks source attribution, mark it as UNVERIFIED. " - "Do NOT invent or fabricate any data, timestamps, or findings not present in the evidence.\n\n" - "CRITICAL: You MUST call save_report to write the final report." + "Do NOT invent or fabricate any data, timestamps, or findings not present in the evidence." ) + # Calling save_report is BOTH the recording action and the completion + # signal. tool_call_loop returns the moment save_report executes; the + # tool's return value becomes the agent's final_text. The forced-retry + # mechanism fires if save_report is never called. + mandatory_record_tools = ("save_report",) + terminal_tools = ("save_report",) def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None: super().__init__(llm, graph) @@ -30,23 +32,26 @@ class ReportAgent(BaseAgent): self._register_graph_read_tools() def _build_system_prompt(self, task: str) -> str: - """Report agent gets a clean prompt — no Phase A/B/C/D workflow.""" return ( f"You are a forensic report writer.\n" f"Role: {self.role}\n\n" f"Investigation state:\n{self.graph.stats_summary()}\n\n" f"Your task: {task}\n\n" f"WORKFLOW:\n" - f"1. Call get_hypotheses_with_evidence to get all hypotheses and their linked evidence\n" - f"2. Call get_all_phenomena to get detailed findings by category\n" - f"3. Call get_entities to get people, programs, and hosts\n" - f"4. Call get_case_info for case metadata\n" - f"5. Write the complete report directly in your block\n\n" + f"1. Call get_hypotheses_with_evidence, get_all_phenomena, get_entities, get_case_info " + f" to gather all the data needed for the report. Make these calls in parallel.\n" + f"2. Assemble the complete markdown forensic report.\n" + f"3. Call save_report(content=, output_path=\"report.md\").\n" + f" This single call is the completion signal — the run ENDS the moment it executes.\n" + f" Do NOT call any read tools after this point; they will not run.\n" + f" Do NOT write the report as free text outside of save_report; only the\n" + f" `content` argument of save_report is persisted.\n\n" f"RULES:\n" - f"- Write the report DIRECTLY in — do NOT use save_report tool\n" - f"- Only include findings present in the evidence graph\n" - f"- Do NOT invent timestamps, file paths, or data not in the phenomena\n" - f"- The report must be complete — do not cut off mid-section\n" + f"- The report must be the complete markdown — do not cut off mid-section.\n" + f"- Only include findings present in the evidence graph.\n" + f"- Do NOT invent timestamps, file paths, or data not in the phenomena.\n" + f"- The `content` argument can be 10K+ chars. JSON-escape inner quotes (\\\") and\n" + f" backslashes (\\\\) and newlines (\\n) correctly.\n" ) def _register_tools(self) -> None: @@ -186,10 +191,16 @@ class ReportAgent(BaseAgent): return "\n".join(lines) async def _save_report(self, content: str, output_path: str) -> str: - try: - os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) - with open(output_path, "w") as f: - f.write(content) - return f"Report saved to {output_path} ({len(content)} chars)" - except Exception as e: - return f"Error saving report: {e}" + """Save the report and return the content itself. + + The content is returned (rather than a "saved to ..." status string) + so that when tool_call_loop short-circuits on this terminal tool, + `final_text` is the full markdown — orchestrator writes it to the + canonical report.md path under runs//. + + The output_path argument is kept for backward compat but the model's + chosen path is ignored — the orchestrator owns the persistence path. + """ + if not content: + return "" + return content diff --git a/agents/timeline.py b/agents/timeline.py index 752fb33..8efb955 100644 --- a/agents/timeline.py +++ b/agents/timeline.py @@ -24,6 +24,7 @@ class TimelineAgent(BaseAgent): "MAC timestamps and correlate events across all phenomena categories in the " "evidence graph to reconstruct the sequence of activities on the system." ) + mandatory_record_tools = ("add_temporal_edge",) def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None: super().__init__(llm, graph) @@ -95,7 +96,7 @@ class TimelineAgent(BaseAgent): f" - 'Tool installation' (before) 'Tool execution'\n" f"4. Aim for 15-40 temporal edges that connect the major events into a " f"forensic story.\n" - f"5. Wrap a short summary in when done.\n\n" + f"5. STOP after recording all meaningful temporal edges. Do not call any more tools.\n\n" f"STRICT BOUNDARIES:\n" f"- Your job is to CONNECT existing phenomena, NOT to discover new ones. " f"You CANNOT call add_phenomenon — the tool isn't yours.\n" diff --git a/base_agent.py b/base_agent.py index 41691bb..5357fe7 100644 --- a/base_agent.py +++ b/base_agent.py @@ -31,11 +31,26 @@ class BaseAgent: name: str = "base" role: str = "A forensic analysis agent." + # Tools the agent MUST invoke at least once for the run to count as productive. + # If none of these were called when tool_call_loop returns, run() fires a + # forced retry with an explicit "you forgot to record" instruction. + # Subclasses override to declare their own recording responsibility + # (timeline → add_temporal_edge, hypothesis → add_hypothesis, report → save_report). + mandatory_record_tools: tuple[str, ...] = ("add_phenomenon",) + + # Tools whose invocation ends the run immediately. After any terminal tool + # is called, tool_call_loop returns with that tool's result text as + # final_text. Used by agents whose "completion" is a single explicit + # action rather than "model decides to stop calling tools". For multi-call + # agents (filesystem records many phenomena) leave empty. + terminal_tools: tuple[str, ...] = () + def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None: self.llm = llm self.graph = graph self._tools: dict[str, dict] = {} # name -> schema self._executors: dict[str, Any] = {} # name -> async callable + self._record_call_counts: dict[str, int] = {} self._work_log: list[str] = [] self._current_lead_id: str | None = None @@ -52,7 +67,18 @@ class BaseAgent: "description": description, "input_schema": input_schema, } - self._executors[name] = executor + if name in self.mandatory_record_tools: + self._executors[name] = self._wrap_record_executor(name, executor) + else: + self._executors[name] = executor + + def _wrap_record_executor(self, name: str, executor: Any) -> Any: + """Wrap a mandatory-record executor to count successful invocations.""" + async def wrapped(*args, **kwargs): + result = await executor(*args, **kwargs) + self._record_call_counts[name] = self._record_call_counts.get(name, 0) + 1 + return result + return wrapped def get_tool_definitions(self) -> list[dict]: """Get tool definitions in Claude API format.""" @@ -91,20 +117,19 @@ class BaseAgent: f" FIRST call list_phenomena to get the current IDs — do NOT rely on memory.\n" f" Then call link_to_entity for each relevant phenomenon.\n" f" NEVER guess or fabricate a phenomenon ID. If an ID is not in list_phenomena output, it does not exist.\n\n" - f"Phase D — ANSWER:\n" - f" Only give your AFTER completing Phases B and C.\n\n" + f"Phase D — STOP:\n" + f" Once all phenomena are recorded and entities linked, you are DONE.\n" + f" Do not call any more tools. The orchestrator picks up automatically.\n\n" f"CRITICAL — RECORDING REQUIREMENT:\n" - f"- Your block is DISCARDED by the orchestrator. Only graph mutations propagate.\n" - f"- Other agents and the final report read ONLY the evidence graph " - f"(phenomena, entities, edges).\n" - f"- You MUST call add_phenomenon for EVERY significant finding BEFORE you end.\n" + f"- Only graph mutations propagate to other agents and the final report.\n" + f"- You MUST call add_phenomenon for EVERY significant finding BEFORE you stop.\n" f"- NEGATIVE findings count too. If you searched X (a directory, a pattern, " f"a registry key) and found NOTHING, that absence IS evidence — call " f"add_phenomenon with a 'No matches for X' title and the search scope in " f"raw_data. Negative findings constrain the hypothesis space and prevent " f"the next agent from wasting time re-searching.\n" - f"- If you produce without having called add_phenomenon at least once, " - f"the task is FAILED regardless of what you wrote in .\n" + f"- If you stop without having called add_phenomenon at least once, the task " + f"is FAILED and a forced retry will fire.\n" f"- Include exact file paths, inode numbers, timestamps, and the source_tool " f"that produced each finding.\n\n" f"ANTI-HALLUCINATION RULES — STRICTLY ENFORCED:\n" @@ -124,6 +149,7 @@ class BaseAgent: self._current_lead_id = lead_id self._register_graph_tools() + self._record_call_counts.clear() system = self._build_system_prompt(task) messages = [{"role": "user", "content": task}] @@ -132,12 +158,60 @@ class BaseAgent: ph_before = len(self.graph.phenomena) try: - final_text, _ = await self.llm.tool_call_loop( + final_text, conversation = await self.llm.tool_call_loop( messages=messages, tools=self.get_tool_definitions(), tool_executor=self._executors, system=system, + terminal_tools=self.terminal_tools, ) + + # Forced-record retry: if the agent has any mandatory recording + # tools but never invoked any of them, force one more round with + # an explicit "you forgot to record" instruction. The mandatory + # set is declared on the class — Timeline → add_temporal_edge, + # Hypothesis → add_hypothesis, ReportAgent → (). For agents with + # empty mandatory_record_tools this branch is a no-op. + registered_mandatory = [ + t for t in self.mandatory_record_tools if t in self._executors + ] + recorded_any = any( + self._record_call_counts.get(t, 0) > 0 + for t in registered_mandatory + ) + if registered_mandatory and not recorded_any: + missing = "/".join(registered_mandatory) + logger.warning( + "[%s] finished without calling any of [%s] — forcing RECORD retry", + self.name, missing, + ) + conversation.append({ + "role": "user", + "content": ( + f"STOP. You produced an answer without ever calling " + f"{missing}. Your answer is DISCARDED — only graph " + f"mutations propagate to other agents and the final " + f"report.\n\n" + f"You MUST now call {missing} for every significant " + f"finding from your prior investigation, including " + f"exact identifiers, timestamps, and the source_tool " + f"that produced each finding. If you genuinely found " + f"NOTHING noteworthy, call the recording tool ONCE " + f"with a 'No significant findings' style entry " + f"summarizing what you searched.\n\n" + f"Do not run more investigation tools. Just record " + f"what you already found. Then end." + ), + }) + final_text, _ = await self.llm.tool_call_loop( + messages=conversation, + tools=self.get_tool_definitions(), + tool_executor=self._executors, + system=system, + max_iterations=10, + terminal_tools=self.terminal_tools, + ) + self._work_log.append(f"[Task: {task[:80]}] -> {final_text[:150]}") except Exception: self.graph.agent_status[self.name] = "failed" diff --git a/evidence_graph.py b/evidence_graph.py index 99ac1cd..c1f9999 100644 --- a/evidence_graph.py +++ b/evidence_graph.py @@ -197,6 +197,41 @@ class Lead: return cls(**d) +@dataclass +class InvestigationArea: + """An area to investigate to confirm/refute one or more hypotheses. + + Derived by the orchestrator from active hypotheses after Phase 2; also + seeded from config.yaml:investigation_areas as an optional manual + override. Each area carries its own keywords + expected tools so the + gap-analysis coverage check is generic, not tied to hard-coded constants. + """ + + id: str # "area-{slug}" + area: str # snake_case slug (dedupe key) + description: str + suggested_agent: str # filesystem / registry / communication / network / timeline + expected_keywords: list[str] = field(default_factory=list) + expected_tools: list[str] = field(default_factory=list) + priority: int = 5 # 1 (highest) - 10 (lowest) + motivating_hypothesis_ids: list[str] = field(default_factory=list) + created_by: str = "" # "manual" | "llm_derive" | "fallback" + created_at: str = "" + + def to_dict(self) -> dict: + return asdict(self) + + @classmethod + def from_dict(cls, d: dict) -> InvestigationArea: + return cls(**d) + + def summary(self) -> str: + return ( + f"[{self.area}] P{self.priority} agent={self.suggested_agent} " + f"(motivating: {len(self.motivating_hypothesis_ids)})" + ) + + @dataclass class ExtractedAsset: """A file extracted from the disk image and tracked in the asset library.""" @@ -270,6 +305,11 @@ class EvidenceGraph: self.asset_library: dict[str, ExtractedAsset] = {} self._inode_index: dict[str, str] = {} # inode → asset_id + # Investigation areas — derived from hypotheses (LLM) and/or seeded + # from config.yaml:investigation_areas (manual override). Drives the + # gap-analysis coverage check. + self.investigation_areas: dict[str, InvestigationArea] = {} + # Set by BaseAgent.run() before each agent execution self._current_agent: str = "" @@ -295,6 +335,9 @@ class EvidenceGraph: "leads": [l.to_dict() for l in self.leads], "agent_status": dict(self.agent_status), "asset_library": {aid: a.to_dict() for aid, a in self.asset_library.items()}, + "investigation_areas": { + aid: a.to_dict() for aid, a in self.investigation_areas.items() + }, "saved_at": datetime.now().isoformat(), } tmp = self._persist_path.with_suffix(".tmp") @@ -345,6 +388,10 @@ class EvidenceGraph: asset = ExtractedAsset.from_dict(a_data) graph.asset_library[aid] = asset graph._inode_index[asset.inode] = aid + graph.investigation_areas = { + aid: InvestigationArea.from_dict(a) + for aid, a in data.get("investigation_areas", {}).items() + } graph._rebuild_adjacency() logger.info( "EvidenceGraph restored: %d phenomena, %d hypotheses, %d entities, " @@ -656,6 +703,57 @@ class EvidenceGraph: break self._auto_save() + # ---- Investigation areas ------------------------------------------------- + + async def add_investigation_area( + self, + area: str, + description: str, + suggested_agent: str, + expected_keywords: list[str] | None = None, + expected_tools: list[str] | None = None, + priority: int = 5, + motivating_hypothesis_ids: list[str] | None = None, + created_by: str = "", + ) -> tuple[str, bool]: + """Add or merge an investigation area. Dedupe key is the `area` slug. + + On collision, union the three list fields (keywords / tools / + motivating_hypothesis_ids); description / suggested_agent / priority + are preserved from the first writer (manual seed wins over LLM derive). + Returns (id, was_existing). + """ + async with self._lock: + for existing in self.investigation_areas.values(): + if existing.area == area: + for kw in (expected_keywords or []): + if kw not in existing.expected_keywords: + existing.expected_keywords.append(kw) + for t in (expected_tools or []): + if t not in existing.expected_tools: + existing.expected_tools.append(t) + for hid in (motivating_hypothesis_ids or []): + if hid not in existing.motivating_hypothesis_ids: + existing.motivating_hypothesis_ids.append(hid) + self._auto_save() + return existing.id, True + + aid = f"area-{area}" + self.investigation_areas[aid] = InvestigationArea( + id=aid, + area=area, + description=description, + suggested_agent=suggested_agent, + expected_keywords=list(expected_keywords or []), + expected_tools=list(expected_tools or []), + priority=priority, + motivating_hypothesis_ids=list(motivating_hypothesis_ids or []), + created_by=created_by, + created_at=datetime.now().isoformat(), + ) + self._auto_save() + return aid, False + # ---- Asset library ------------------------------------------------------- async def register_asset( diff --git a/llm_client.py b/llm_client.py index f6abd09..6b6e7d3 100644 --- a/llm_client.py +++ b/llm_client.py @@ -1,9 +1,10 @@ """LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API). -Tool calling is text-based (ReAct pattern): tool schemas are embedded in -the system prompt and tool calls are parsed as JSON blocks -from model output. This keeps the protocol independent of whether the -underlying API supports native function/tool calling. +Tool calling uses the OpenAI-native `tools=[...]` parameter. The model +returns structured tool_calls via the streaming protocol; we accumulate +them, dispatch to our executors, and feed results back as `role: "tool"` +messages. This eliminates the fragile "model writes JSON inside free +text" problem of the previous ReAct text mode. """ from __future__ import annotations @@ -32,52 +33,51 @@ class LLMAPIError(Exception): self.attempts = attempts -# Markers the model uses to signal tool calls and final answers -TOOL_CALL_TAG = "" -TOOL_CALL_END = "" -TOOL_RESULT_TAG = "" -TOOL_RESULT_END = "" +# Optional answer tags — kept for backward compat with prompts that wrap +# their final response in .... Native tool calling does +# not need these (no tool_calls = final), but if the model continues to +# emit them, we strip the tags so callers see clean text. ANSWER_TAG = "" ANSWER_END = "" -def _build_tools_prompt(tools: list[dict]) -> str: - """Format tool definitions for inclusion in the system prompt.""" - lines = ["You have access to the following tools:\n"] - for t in tools: - schema = t.get("input_schema", {}) - props = schema.get("properties", {}) - required = schema.get("required", []) +def _to_openai_tools(tools: list[dict]) -> list[dict]: + """Convert internal tool definitions to OpenAI native function-tools format.""" + return [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t["description"], + "parameters": t.get("input_schema", {"type": "object", "properties": {}}), + }, + } + for t in tools + ] - params = [] - for pname, pdef in props.items(): - req = " (required)" if pname in required else "" - desc = pdef.get("description", "") - ptype = pdef.get("type", "string") - enum_vals = pdef.get("enum") - if enum_vals: - allowed = ", ".join(f'"{v}"' for v in enum_vals) - params.append(f" - {pname}: {ptype}{req} — {desc} Allowed values: [{allowed}]") - else: - params.append(f" - {pname}: {ptype}{req} — {desc}") - param_block = "\n".join(params) if params else " (no parameters)" - lines.append(f"## {t['name']}\n{t['description']}\nParameters:\n{param_block}\n") +def _extract_first_balanced(text: str, open_char: str, close_char: str) -> str | None: + """Return the first balanced [...] or {...} substring, or None if no balanced pair. - lines.append( - "## How to use tools\n" - "To call a tool, output a JSON block wrapped in XML tags like this:\n" - f"{TOOL_CALL_TAG}\n" - '{"name": "tool_name", "arguments": {"param1": "value1"}}\n' - f"{TOOL_CALL_END}\n\n" - "You can call multiple tools in sequence. After each tool call, you will receive the result in:\n" - f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n" - "When you have finished your analysis and have a final answer, wrap it in:\n" - f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n" - "Think step by step. Call tools to gather evidence before drawing conclusions.\n" - "You MUST call at least one tool before giving your final answer." - ) - return "\n".join(lines) + Stack-based — handles nested brackets correctly (regex with .*? would + truncate at the first inner closing bracket, regex with .* would over-eat + trailing text). Brackets inside JSON string literals are ignored by + callers because the caller passes the result through json.loads which + re-parses with proper string handling. + """ + start = text.find(open_char) + if start < 0: + return None + depth = 0 + for i in range(start, len(text)): + c = text[i] + if c == open_char: + depth += 1 + elif c == close_char: + depth -= 1 + if depth == 0: + return text[start:i + 1] + return None def _safe_json_loads(text: str): @@ -87,8 +87,8 @@ def _safe_json_loads(text: str): (\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX). On final failure, logs raw input (first 600 chars) so we can diagnose what the model emitted. - Used both by orchestrator JSON callsites and by _extract_tool_calls - when parsing blocks from model output. + Used by orchestrator JSON callsites (_call_llm_for_json) and by + tool_call_loop when parsing tool_call arguments returned by the API. """ try: return json.loads(text) @@ -110,23 +110,6 @@ def _safe_json_loads(text: str): raise -def _extract_tool_calls(text: str) -> list[dict]: - """Extract tool call JSON blocks from model output.""" - pattern = re.compile( - re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END), - re.DOTALL, - ) - calls = [] - for match in pattern.finditer(text): - raw = match.group(1).strip() - try: - parsed = _safe_json_loads(raw) - calls.append(parsed) - except json.JSONDecodeError: - logger.warning("Failed to parse tool call JSON: %s", raw[:200]) - return calls - - def _extract_answer(text: str) -> str | None: """Extract the final answer from model output.""" pattern = re.compile( @@ -266,50 +249,41 @@ _DECAY_TIERS: list[tuple[int, int]] = [ def _apply_progressive_decay(messages: list[dict]) -> list[dict]: - """Truncate tool results in older messages to save context space. + """Truncate the `content` of older `role: "tool"` messages to save context. - Operates in-place-style on a copy. Only touches user messages that - contain blocks (these are the tool-result messages - generated by tool_call_loop). + Each `role: "tool"` message in the conversation corresponds to one tool + call's result. We rank these messages by recency and progressively + truncate older ones according to `_DECAY_TIERS`. """ - # Count rounds from the end. A "round" is a (assistant, user) pair. - # messages alternate: [user, assistant, user, assistant, user, ...] - # The initial user message is index 0, then pairs start at index 1. total = len(messages) - if total <= 10: # not enough messages to bother + if total <= 10: return messages - result = [] - # Count tool-result user messages from the end - tool_result_indices = [ - i for i, m in enumerate(messages) - if m["role"] == "user" and TOOL_RESULT_TAG in m.get("content", "") + tool_msg_indices = [ + i for i, m in enumerate(messages) if m.get("role") == "tool" ] - # Build a set of indices that need decay, mapped to their max_chars decay_map: dict[int, int] = {} - n_tool_msgs = len(tool_result_indices) - for rank, idx in enumerate(reversed(tool_result_indices)): - rounds_ago = rank # 0 = most recent, 1 = second most recent, ... + for rank, idx in enumerate(reversed(tool_msg_indices)): + rounds_ago = rank for threshold, max_chars in _DECAY_TIERS: if rounds_ago < threshold: decay_map[idx] = max_chars break + result = [] for i, msg in enumerate(messages): if i in decay_map: max_chars = decay_map[i] - content = msg["content"] + content = msg.get("content", "") or "" if len(content) > max_chars + 200: - # Truncate but preserve the tool_result tags structure - truncated = content[:max_chars] - # Count how many tool results are in this message - n_results = content.count(TOOL_RESULT_TAG) - truncated += ( - f"\n... [context compressed: {len(content)} -> {max_chars} chars, " - f"{n_results} tool result(s)]" + truncated = ( + content[:max_chars] + + f"\n... [context compressed: {len(content)} -> {max_chars} chars]" ) - result.append({"role": msg["role"], "content": truncated}) + new_msg = dict(msg) + new_msg["content"] = truncated + result.append(new_msg) else: result.append(msg) else: @@ -434,6 +408,95 @@ class LLMClient: return "" + async def _chat_with_tools( + self, + messages: list[dict], + openai_tools: list[dict], + max_retries: int = 5, + ) -> tuple[str, str | None, list[dict]]: + """Stream a chat completion with native tool calling enabled. + + Returns: + (text_content, reasoning_content, raw_tool_calls). + - reasoning_content is non-None when DeepSeek thinking mode is + active; the caller MUST echo it back in the assistant message + on subsequent requests, or the API returns HTTP 400. + - raw_tool_calls is a list of {"id","name","arguments"} dicts; + arguments is the raw JSON string returned by the API. + """ + kwargs: dict[str, Any] = { + "model": self.model, + "messages": messages, + "max_tokens": self.max_tokens, + "stream": True, + "tools": openai_tools, + } + if self.reasoning_effort: + kwargs["reasoning_effort"] = self.reasoning_effort + if self.thinking_enabled: + kwargs["extra_body"] = {"thinking": {"type": "enabled"}} + + for attempt in range(max_retries): + logger.debug( + "LLM request (stream+tools): %d messages, %d tools (attempt %d)", + len(messages), len(openai_tools), attempt + 1, + ) + text_parts: list[str] = [] + reasoning_parts: list[str] = [] + tool_calls_acc: dict[int, dict] = {} # index -> {id, name, arguments} + try: + stream = await self._client.chat.completions.create(**kwargs) + async for chunk in stream: + if not chunk.choices: + continue + delta = chunk.choices[0].delta + if delta.content: + text_parts.append(delta.content) + # DeepSeek thinking-mode: reasoning_content is returned + # alongside content and MUST be echoed back on subsequent + # requests, otherwise the API rejects with HTTP 400. + rc = getattr(delta, "reasoning_content", None) + if rc: + reasoning_parts.append(rc) + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + entry = tool_calls_acc.setdefault( + idx, {"id": None, "name": None, "arguments": ""}, + ) + if tc_delta.id: + entry["id"] = tc_delta.id + fn = tc_delta.function + if fn: + if fn.name: + entry["name"] = fn.name + if fn.arguments: + entry["arguments"] += fn.arguments + + text = "".join(text_parts) + reasoning = "".join(reasoning_parts) or None + ordered = [tool_calls_acc[i] for i in sorted(tool_calls_acc)] + logger.debug( + "LLM response (stream+tools): %d chars, %d reasoning chars, %d tool calls", + len(text), len(reasoning or ""), len(ordered), + ) + return text, reasoning, ordered + + except (APIConnectionError, APITimeoutError, APIError) as e: + if attempt < max_retries - 1: + wait = 2 ** attempt * 10 + logger.warning( + "Tool-call request failed (%s), retrying in %ds...", e, wait, + ) + await asyncio.sleep(wait) + else: + raise LLMAPIError( + f"LLM API unreachable after {max_retries} attempts: {e}", + attempts=max_retries, + ) from e + + return "", None, [] + async def tool_call_loop( self, messages: list[dict], @@ -441,87 +504,159 @@ class LLMClient: tool_executor: dict[str, Any], system: str | None = None, max_iterations: int = 40, + terminal_tools: tuple[str, ...] = (), ) -> tuple[str, list[dict]]: - """Run a ReAct-style tool-calling loop. + """Run a tool-calling loop using OpenAI-native tool calls. - The model outputs blocks which we parse and execute, - feeding results back as blocks until the model - outputs an block. + The model returns structured `tool_calls` in its message; we + dispatch them through our executor dict and feed each result back + as a `role: "tool"` message with the matching `tool_call_id`. The + loop ends when: + - the model returns a message with no tool_calls (normal exit), or + - any tool in `terminal_tools` is called — in that case, the loop + short-circuits with that tool's result text as final_text. This + gives agents (notably ReportAgent) an explicit completion signal + that the old `` text tag used to provide. Returns: - (final_text, all_messages) + (final_text, full_message_history) """ - # Build system prompt with tool definitions - tools_prompt = _build_tools_prompt(tools) - full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt + terminal_set = set(terminal_tools) + openai_tools = _to_openai_tools(tools) - messages = list(messages) # don't mutate caller's list - _folded = False # Track whether we've already folded once this loop + # The caller may pass `messages` either as raw conversation (no system) + # together with `system=...`, OR as a complete history that already + # starts with the system message (retry path). Accept both shapes. + if messages and messages[0].get("role") == "system": + full_messages: list[dict] = list(messages) + else: + full_messages = [] + if system: + full_messages.append({"role": "system", "content": system}) + full_messages.extend(messages) + _folded = False - for i in range(max_iterations): + for _i in range(max_iterations): # ── Context compression before each API call ────────────── - # Stage A: progressively decay old tool results - messages = _apply_progressive_decay(messages) - - # Stage B: fold oldest messages into LLM summary if too long - if not _folded and len(messages) > _FOLD_THRESHOLD: - messages = await self._fold_old_messages(messages, full_system) + full_messages = _apply_progressive_decay(full_messages) + if not _folded and len(full_messages) > _FOLD_THRESHOLD: + full_messages = await self._fold_old_messages(full_messages) _folded = True - elif _folded and len(messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT: - # Allow a second fold if messages grew back significantly - messages = await self._fold_old_messages(messages, full_system) + elif _folded and len(full_messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT: + full_messages = await self._fold_old_messages(full_messages) - text = await self.chat(messages, system=full_system) + text, reasoning, raw_tool_calls = await self._chat_with_tools( + full_messages, openai_tools, + ) - # Check for final answer - answer = _extract_answer(text) - if answer is not None: - messages.append({"role": "assistant", "content": text}) - return answer, messages + if not raw_tool_calls: + # Model produced a final response. Strip optional + # tags for backward compatibility with old prompts. + final_msg: dict[str, Any] = {"role": "assistant", "content": text} + if reasoning: + final_msg["reasoning_content"] = reasoning + full_messages.append(final_msg) + answer = _extract_answer(text) + return (answer if answer is not None else text), full_messages - # Check for tool calls - tool_calls = _extract_tool_calls(text) + # Parse arguments + build internal call dicts + parsed_calls: list[dict] = [] + for rc in raw_tool_calls: + args_str = rc.get("arguments", "") or "" + try: + args = _safe_json_loads(args_str) if args_str.strip() else {} + except (json.JSONDecodeError, ValueError) as e: + logger.warning( + "Failed to parse arguments for tool %s: %s", + rc.get("name"), e, + ) + args = {} + parsed_calls.append({ + "id": rc.get("id"), + "name": rc.get("name", ""), + "arguments": args, + }) - if not tool_calls: - # No tool calls and no answer tag — treat entire text as answer - messages.append({"role": "assistant", "content": text}) - return text, messages + # Append the assistant turn with the raw tool_calls (and the + # DeepSeek-mandated reasoning_content echo-back), then execute. + asst_msg: dict[str, Any] = { + "role": "assistant", + "content": text or None, + "tool_calls": [ + { + "id": rc.get("id"), + "type": "function", + "function": { + "name": rc.get("name", ""), + "arguments": rc.get("arguments", "") or "", + }, + } + for rc in raw_tool_calls + ], + } + if reasoning: + asst_msg["reasoning_content"] = reasoning + full_messages.append(asst_msg) - # Execute tool calls — read-only tools run in parallel - messages.append({"role": "assistant", "content": text}) - - result_parts = [] - batches = _partition_tool_calls(tool_calls) + batches = _partition_tool_calls(parsed_calls) t_batch_start = time.monotonic() - + # Each entry: (tool_call_dict, raw_result, formatted_for_llm) + executed: list[tuple[dict, str, str]] = [] for batch in batches: if batch.is_read_only and len(batch.calls) > 1: - batch_results = await self._execute_tool_batch_parallel( + results = await self._execute_tool_batch_parallel( batch.calls, tool_executor, tools, ) - result_parts.extend(batch_results) + for tc, (raw, formatted) in zip(batch.calls, results): + executed.append((tc, raw, formatted)) else: for tc in batch.calls: - result_parts.append( - await self._execute_single_tool(tc, tool_executor, tools) + raw, formatted = await self._execute_single_tool( + tc, tool_executor, tools, ) + executed.append((tc, raw, formatted)) - # Emit folded tool-call summary for the terminal t_batch_elapsed = time.monotonic() - t_batch_start - _emit_tool_call_summary(tool_calls, t_batch_elapsed) + _emit_tool_call_summary(parsed_calls, t_batch_elapsed) - # Feed results back as a user message - result_message = "\n\n".join(result_parts) - messages.append({"role": "user", "content": result_message}) + # Append formatted tool results to the conversation (this is + # what the LLM sees on subsequent rounds — truncated for context + # economy). + for tc, _raw, formatted in executed: + full_messages.append({ + "role": "tool", + "tool_call_id": tc["id"], + "content": formatted, + }) + + # Terminal-tool short-circuit: if the model called any tool in + # `terminal_tools`, end the loop immediately. The terminal tool's + # RAW result (untruncated) becomes final_text — the LLM may have + # produced a 20K-char report via save_report and we must not + # truncate it just because the LLM-facing copy is truncated. + if terminal_set: + for tc, raw, _formatted in executed: + name = tc.get("name", "") + if name in terminal_set: + logger.info( + "Terminal tool %s called — exiting tool_call_loop", name, + ) + return raw, full_messages logger.warning("Tool call loop hit max iterations (%d)", max_iterations) - return "[Max tool call iterations reached]", messages + return "[Max tool call iterations reached]", full_messages async def _execute_single_tool( self, tc: dict, tool_executor: dict[str, Any], tools: list[dict] | None = None, - ) -> str: - """Execute a single tool call and return the formatted result.""" + ) -> tuple[str, str]: + """Execute a single tool call. + + Returns (raw_result, formatted_for_llm). `raw_result` is the + unmodified executor return (used by terminal-tool short-circuit as + final_text). `formatted_for_llm` is `[tool_name] {truncated}` and + is what gets fed back to the model as the tool message content. + """ tool_name = tc.get("name", "") tool_args = tc.get("arguments", {}) @@ -532,72 +667,106 @@ class LLMClient: executor = tool_executor.get(tool_name) if executor is None: - result_text = f"Error: unknown tool '{tool_name}'" + raw = f"Error: unknown tool '{tool_name}'" else: try: - result_text = await executor(**tool_args) + raw = await executor(**tool_args) except Exception as e: logger.error("Tool %s failed: %s", tool_name, e) - result_text = f"Error executing {tool_name}: {e}" + raw = f"Error executing {tool_name}: {e}" - return ( - f"{TOOL_RESULT_TAG}\n" - f"[{tool_name}] {_truncate_tool_result(result_text)}\n" - f"{TOOL_RESULT_END}" - ) + formatted = f"[{tool_name}] {_truncate_tool_result(raw)}" + return raw, formatted async def _execute_tool_batch_parallel( self, calls: list[dict], tool_executor: dict[str, Any], tools: list[dict] | None = None, - ) -> list[str]: - """Execute multiple read-only tool calls concurrently.""" + ) -> list[tuple[str, str]]: + """Execute multiple read-only tool calls concurrently. + + Returns a list of (raw_result, formatted_for_llm) tuples in the + same order as `calls`. + """ logger.info("Executing %d read-only tools in parallel", len(calls)) - async def _run_one(tc: dict) -> str: + async def _run_one(tc: dict) -> tuple[str, str]: tool_name = tc.get("name", "") tool_args = tc.get("arguments", {}) if tools: tool_args = _fix_tool_args(tool_name, tool_args, tools) - logger.info("Calling tool (parallel): %s(%s)", tool_name, json.dumps(tool_args, ensure_ascii=False)) + logger.info( + "Calling tool (parallel): %s(%s)", + tool_name, json.dumps(tool_args, ensure_ascii=False), + ) executor = tool_executor.get(tool_name) if executor is None: - result_text = f"Error: unknown tool '{tool_name}'" + raw = f"Error: unknown tool '{tool_name}'" else: try: - result_text = await executor(**tool_args) + raw = await executor(**tool_args) except Exception as e: logger.error("Tool %s failed: %s", tool_name, e) - result_text = f"Error executing {tool_name}: {e}" - return ( - f"{TOOL_RESULT_TAG}\n" - f"[{tool_name}] {_truncate_tool_result(result_text)}\n" - f"{TOOL_RESULT_END}" - ) + raw = f"Error executing {tool_name}: {e}" + formatted = f"[{tool_name}] {_truncate_tool_result(raw)}" + return raw, formatted results = await asyncio.gather(*[_run_one(tc) for tc in calls]) return list(results) async def _fold_old_messages( - self, messages: list[dict], system: str, + self, messages: list[dict], ) -> list[dict]: """Fold old messages into an LLM-generated summary (Stage B). - Keeps the most recent _FOLD_KEEP_RECENT messages intact and - replaces earlier ones with a single summary message. + Preserves the leading system message (if any), keeps the most + recent _FOLD_KEEP_RECENT messages intact, and replaces the older + middle slice with a single summary user message. """ - n_to_fold = len(messages) - _FOLD_KEEP_RECENT + # Pin the system message — it must NEVER be summarized away. + system_msgs: list[dict] = [] + body = messages + if messages and messages[0].get("role") == "system": + system_msgs = [messages[0]] + body = messages[1:] + + n_to_fold = len(body) - _FOLD_KEEP_RECENT if n_to_fold <= 2: return messages - old_messages = messages[:n_to_fold] - recent_messages = messages[n_to_fold:] + # Pull the fold boundary forward so we never split an assistant turn + # from its matching tool results. The API rejects (HTTP 400) any + # `role: "tool"` message that does not immediately follow an + # `assistant` message with `tool_calls`. We walk the boundary into + # `recent_messages` while its head is a `role: "tool"` message, or + # while the prior `recent` message is `assistant{tool_calls}` whose + # paired tools span the boundary. + while n_to_fold < len(body): + head = body[n_to_fold] + if head.get("role") == "tool": + n_to_fold += 1 + continue + break + + if n_to_fold >= len(body): + # Everything got folded — nothing recent to keep. + return system_msgs + [body[0]] if system_msgs else messages + + old_messages = body[:n_to_fold] + recent_messages = body[n_to_fold:] - # Build a text dump of old messages for summarization old_text_parts = [] for msg in old_messages: - role = msg["role"] - content = msg.get("content", "") - # Truncate each message for the summary prompt to avoid overload + role = msg.get("role", "?") + content = msg.get("content") or "" + # Render tool_calls (assistant turn) compactly. + if role == "assistant" and msg.get("tool_calls"): + tc_names = [ + tc.get("function", {}).get("name", "?") + for tc in msg["tool_calls"] + ] + content = (content + " " if content else "") + ( + "called: " + ", ".join(tc_names) + ) if len(content) > 1000: content = content[:1000] + "..." old_text_parts.append(f"[{role}]: {content}") @@ -621,7 +790,6 @@ class LLMClient: logger.warning("Context folding failed: %s — keeping original messages", e) return messages - # Replace old messages with a single summary summary_message = { "role": "user", "content": ( @@ -629,4 +797,4 @@ class LLMClient: f"messages in this conversation]\n\n{summary}" ), } - return [summary_message] + recent_messages + return system_msgs + [summary_message] + recent_messages diff --git a/orchestrator.py b/orchestrator.py index 8f654b7..a71ade8 100644 --- a/orchestrator.py +++ b/orchestrator.py @@ -12,7 +12,8 @@ from pathlib import Path from agent_factory import AgentFactory from evidence_graph import EvidenceGraph -from llm_client import LLMClient, _safe_json_loads +from llm_client import LLMClient, _extract_first_balanced, _safe_json_loads +from tool_registry import TOOL_CATALOG logger = logging.getLogger(__name__) @@ -93,6 +94,14 @@ class Orchestrator: "Omit phenomena that are unrelated. Be conservative — only link genuinely relevant evidence." ) + _AREA_DERIVE_SYSTEM = ( + "You are a forensic investigation strategist. Given a set of hypotheses, " + "decompose them into a minimal aggregate set of investigation areas. An " + "area is a focused, concrete question with the keywords and tool names an " + "answering phenomenon would mention. Aggregate aggressively — when two " + "hypotheses share an area, emit it once and list both hypothesis_ids." + ) + def __init__( self, llm: LLMClient, @@ -216,6 +225,125 @@ class Orchestrator: "The ultimate goal is to reconstruct a detailed timeline of what happened on this host." ) + # ---- Investigation areas (manual seed + LLM derive) ---------------------- + + _VALID_AGENT_TYPES = {"filesystem", "registry", "communication", "network", "timeline"} + + async def _seed_manual_investigation_areas(self) -> None: + """Import config.yaml:investigation_areas entries (manual override). + + Run early in Phase 2 so manual entries are in the graph before LLM + derivation; LLM derive then augments via slug-based dedupe. + """ + for entry in self.config.get("investigation_areas", []): + area = entry.get("area") + if not area: + continue + await self.graph.add_investigation_area( + area=area, + description=entry.get("description", entry.get("task", "")), + suggested_agent=entry.get("agent", "filesystem"), + expected_keywords=entry.get("keywords", []), + expected_tools=entry.get("tools", []), + priority=entry.get("priority", 3), + created_by="manual", + ) + + async def _derive_investigation_areas(self) -> None: + """Ask LLM to derive investigation areas from active hypotheses. + + Manual-seeded or already-populated graph (resume) → no-op. On LLM + failure or empty output, falls back to one-area-per-hypothesis. + """ + if self.graph.investigation_areas: + return + active = [h for h in self.graph.hypotheses.values() if h.status == "active"] + if not active: + return + + available_tools = sorted(TOOL_CATALOG.keys()) + hyp_lines = "\n".join( + f" [{h.id}] {h.title}: {h.description}" for h in active + ) + prompt = ( + f"Active hypotheses:\n{hyp_lines}\n\n" + f"Available agents: {sorted(self._VALID_AGENT_TYPES)}\n" + f"Available tool names (pick 1-3 per area for expected_tools): {available_tools}\n\n" + f"Emit 5-12 distinct investigation areas covering the FULL hypothesis set.\n" + f"Each area must include:\n" + f" - area: snake_case slug (dedupe key)\n" + f" - description: one sentence on what to find\n" + f" - suggested_agent: one of the agents above\n" + f" - expected_keywords: 3-8 lowercase tokens that an answering phenomenon would mention\n" + f" - expected_tools: 1-3 tool names from the list above\n" + f" - priority: 1 (highest) to 10\n" + f" - motivating_hypothesis_ids: at least one [hyp-xxx] from above\n\n" + f"Aggregate aggressively — when two hypotheses share an area, emit it ONCE " + f"and list both ids in motivating_hypothesis_ids.\n\n" + f"Respond ONLY with JSON:\n" + f'[{{"area":"...","description":"...","suggested_agent":"...",' + f'"expected_keywords":[...],"expected_tools":[...],"priority":1-10,' + f'"motivating_hypothesis_ids":["hyp-xxx"]}}]' + ) + + try: + items = await self._call_llm_for_json( + system=self._AREA_DERIVE_SYSTEM, + user_prompt=prompt, + schema="array", + ) + except Exception as e: + logger.warning("Area derivation LLM failed: %s — falling back", e) + await self._derive_areas_fallback(active) + return + + valid_hyp_ids = set(self.graph.hypotheses.keys()) + for it in items: + area = it.get("area", "").strip() + if not area: + continue + agent = it.get("suggested_agent", "filesystem") + if agent not in self._VALID_AGENT_TYPES: + agent = AGENT_ALIASES.get(agent, "filesystem") + tools = [t for t in it.get("expected_tools", []) if t in TOOL_CATALOG] + motivating = [ + h for h in it.get("motivating_hypothesis_ids", []) + if h in valid_hyp_ids + ] + priority = max(1, min(10, int(it.get("priority", 5)))) + await self.graph.add_investigation_area( + area=area, + description=it.get("description", ""), + suggested_agent=agent, + expected_keywords=[ + str(kw).lower() for kw in it.get("expected_keywords", []) + ], + expected_tools=tools, + priority=priority, + motivating_hypothesis_ids=motivating, + created_by="llm_derive", + ) + + if not self.graph.investigation_areas: + await self._derive_areas_fallback(active) + + async def _derive_areas_fallback(self, active: list) -> None: + """One area per active hypothesis as a minimal safety net.""" + for h in active: + slug = re.sub(r"[^a-z0-9_]+", "_", h.title.lower())[:40].strip("_") + if not slug: + slug = h.id.replace("-", "_") + await self.graph.add_investigation_area( + area=slug, + description=h.title, + suggested_agent="filesystem", + expected_keywords=h.title.lower().split()[:6], + expected_tools=[], + priority=5, + motivating_hypothesis_ids=[h.id], + created_by="fallback", + ) + # ---- LLM JSON helper ----------------------------------------------------- async def _call_llm_for_json( @@ -234,13 +362,12 @@ class Orchestrator: """ error_hint = "" last_err: Exception | None = None - pattern = r'\[.*?\]' if schema == "array" else r'\{.*?\}' + open_c, close_c = ('[', ']') if schema == "array" else ('{', '}') for attempt in range(max_retries + 1): messages = [{"role": "user", "content": user_prompt + error_hint}] response = await self.llm.chat(messages=messages, system=system) - match = re.search(pattern, response, re.DOTALL) - candidate = match.group() if match else response + candidate = _extract_first_balanced(response, open_c, close_c) or response try: return _safe_json_loads(candidate) except (json.JSONDecodeError, ValueError) as e: @@ -278,10 +405,19 @@ class Orchestrator: existing = "\n".join( f" - {r['node']} [{r['edge_type']}]" for r in related ) or " (none yet)" + related_areas = [ + a.area for a in self.graph.investigation_areas.values() + if hyp.id in a.motivating_hypothesis_ids + ] + expected_line = ( + f" Expected areas to investigate: {', '.join(related_areas)}\n" + if related_areas else "" + ) hyp_blocks.append( f"Hypothesis [{hyp.id}]: {hyp.title}\n" f" Description: {hyp.description}\n" f" Current confidence: {hyp.confidence:.2f}\n" + f"{expected_line}" f" Existing evidence:\n{existing}" ) @@ -328,12 +464,21 @@ class Orchestrator: existing_evidence = "\n".join( f" - {r['node']} [{r['edge_type']}]" for r in related ) or " (none yet)" + related_areas = [ + a.area for a in self.graph.investigation_areas.values() + if hyp.id in a.motivating_hypothesis_ids + ] + expected_line = ( + f"Expected areas to investigate: {', '.join(related_areas)}\n\n" + if related_areas else "" + ) prompt = ( f"Hypothesis: {hyp.title}\n" f"Description: {hyp.description}\n" f"Current confidence: {hyp.confidence:.2f}\n\n" f"Existing evidence linked to this hypothesis:\n{existing_evidence}\n\n" + f"{expected_line}" f"What additional evidence should we look for to CONFIRM or DENY this hypothesis?\n" f"List 1-3 specific, actionable investigation tasks.\n" f"For each, specify which agent type should handle it: " @@ -464,74 +609,57 @@ class Orchestrator: # ---- Gap analysis (coverage check) --------------------------------------- - _AREA_KEYWORDS: dict[str, list[str]] = { - "system_info": ["install date", "registered owner", "product name", "windows xp", "system information"], - "user_accounts": ["user account", "enumerate", "sam hive", "administrator", "mr. evil"], - "shutdown_time": ["shutdown"], - "network_config": ["network interface", "network adapter", "ip address", "dhcp", "mac address", "network config"], - "installed_software": ["installed software", "program files", "installed program"], - "email_config": ["smtp", "pop3", "nntp", "email account", "email config"], - "chat_logs": ["irc", "mirc", "chat log", "channel"], - "network_activity": ["packet capture", "pcap", "interception", "http request", "user-agent"], - "deleted_files": ["deleted file", "recycle", "recycler"], - "execution_evidence": ["prefetch", "execution", "run count", "last execution"], - } + def _check_coverage(self) -> set[str]: + """Return slugs of investigation_areas already covered by phenomena. - # Deterministic coverage: if the canonical tool was called, the area is covered. - _AREA_TOOLS: dict[str, list[str]] = { - "system_info": ["get_system_info"], - "user_accounts": ["enumerate_users"], - "shutdown_time": ["get_shutdown_time"], - "network_config": ["get_network_interfaces"], - "installed_software": ["list_installed_software"], - "email_config": ["get_email_config"], - "network_activity": ["parse_pcap_strings"], - "deleted_files": ["count_deleted_files"], - "execution_evidence": ["parse_prefetch"], - } + Layer A: any expected_keyword found in evidence text (category + + title + description, lowercased). + Layer B: any expected_tool present in the source_tool set of recorded + phenomena (deterministic — the canonical tool was actually called). + """ + evidence_text = " ".join( + f"{ph.category} {ph.title} {ph.description}".lower() + for ph in self.graph.phenomena.values() + ) + used_tools: set[str] = { + ph.source_tool for ph in self.graph.phenomena.values() if ph.source_tool + } - def _check_coverage(self, areas: list[dict]) -> set[str]: - # Layer 1: keyword matching on category + title + description - evidence_text = "" - for ph in self.graph.phenomena.values(): - evidence_text += f" {ph.category} {ph.title} {ph.description} ".lower() - - # Layer 2: collect all source_tools that produced phenomena - used_tools: set[str] = {ph.source_tool for ph in self.graph.phenomena.values() if ph.source_tool} - - covered = set() - for area in areas: - area_name = area["area"] - # Check keywords - keywords = self._AREA_KEYWORDS.get(area_name, []) - if any(kw in evidence_text for kw in keywords): - covered.add(area_name) + covered: set[str] = set() + for a in self.graph.investigation_areas.values(): + if any(kw.lower() in evidence_text for kw in a.expected_keywords): + covered.add(a.area) continue - # Check source_tool - area_tools = self._AREA_TOOLS.get(area_name, []) - if any(tool in used_tools for tool in area_tools): - covered.add(area_name) + if any(t in used_tools for t in a.expected_tools): + covered.add(a.area) return covered async def _run_gap_analysis(self) -> None: - areas = self.config.get("investigation_areas", []) + areas = list(self.graph.investigation_areas.values()) if not areas: return - covered = self._check_coverage(areas) - uncovered = [a for a in areas if a["area"] not in covered] + covered = self._check_coverage() + uncovered = [a for a in areas if a.area not in covered] if not uncovered: _log(f"All {len(areas)} investigation areas covered", event="progress") return - uncovered_names = ", ".join(a["area"] for a in uncovered) - _log(f"{len(uncovered)}/{len(areas)} areas uncovered: {uncovered_names}", event="dispatch") - for area in uncovered: + uncovered_names = ", ".join(a.area for a in uncovered) + _log( + f"{len(uncovered)}/{len(areas)} areas uncovered: {uncovered_names}", + event="dispatch", + ) + for a in uncovered: await self.graph.add_lead( - target_agent=area["agent"], - description=area["task"], - priority=3, + target_agent=a.suggested_agent, + description=a.description, + priority=a.priority, + hypothesis_id=( + a.motivating_hypothesis_ids[0] + if a.motivating_hypothesis_ids else None + ), ) for round_num in range(3): @@ -578,6 +706,15 @@ class Orchestrator: json.dumps(leads_data, ensure_ascii=False, indent=2) ) + # Investigation areas export + areas_data = { + aid: a.to_dict() + for aid, a in self.graph.investigation_areas.items() + } + (self.run_dir / "investigation_areas.json").write_text( + json.dumps(areas_data, ensure_ascii=False, indent=2) + ) + # Run metadata end_time = datetime.now() metadata = { @@ -637,6 +774,11 @@ class Orchestrator: if resume_phase <= 2: _log("Phase 2: Hypothesis Generation", event="phase") t0 = time.monotonic() + + # Seed manual investigation areas (if any) BEFORE LLM derive, + # so manual entries win the dedupe and LLM only augments. + await self._seed_manual_investigation_areas() + manual_hypotheses = self.config.get("hypotheses", []) if manual_hypotheses: await self._generate_hypotheses_manual(manual_hypotheses) @@ -647,10 +789,17 @@ class Orchestrator: if self.graph.phenomena and self.graph.hypotheses: await self._judge_new_phenomena() + # Derive investigation areas from active hypotheses. + # No-op if manual seed already populated or resume restored areas. + await self._derive_investigation_areas() + for h in self.graph.hypotheses.values(): _log(f" {h.summary()}", event="hypothesis") + for a in self.graph.investigation_areas.values(): + _log(f" {a.summary()}", event="area") _log( - f"+{len(self.graph.hypotheses)} hypotheses generated", + f"+{len(self.graph.hypotheses)} hypotheses, " + f"{len(self.graph.investigation_areas)} areas", event="progress", elapsed=time.monotonic() - t0, ) diff --git a/tests/test_optimizations.py b/tests/test_optimizations.py index e9a7530..2b703ee 100644 --- a/tests/test_optimizations.py +++ b/tests/test_optimizations.py @@ -14,7 +14,6 @@ from evidence_graph import ( from llm_client import ( _truncate_tool_result, _partition_tool_calls, _ToolBatch, READ_ONLY_TOOLS, _apply_progressive_decay, _FOLD_THRESHOLD, _FOLD_KEEP_RECENT, - TOOL_RESULT_TAG, TOOL_RESULT_END, ) from tool_registry import ( _tool_result_cache, _cache_key, _make_cached, CACHEABLE_TOOLS, @@ -598,7 +597,8 @@ class TestParallelToolExecution: elapsed = time.monotonic() - start assert len(results) == 3 - assert all("ok" in r for r in results) + # results are (raw, formatted) tuples; both contain "ok" + assert all("ok" in raw and "ok" in formatted for raw, formatted in results) # 3 tasks × 50ms each should take ~50ms parallel, not ~150ms serial assert elapsed < 0.12, f"Expected parallel execution but took {elapsed:.3f}s" @@ -627,8 +627,10 @@ class TestParallelToolExecution: results = await client._execute_tool_batch_parallel(calls, tool_executor) assert len(results) == 2 - assert "success" in results[0] - assert "Error" in results[1] + raw0, formatted0 = results[0] + raw1, formatted1 = results[1] + assert "success" in raw0 and "success" in formatted0 + assert "Error" in raw1 and "Error" in formatted1 # --------------------------------------------------------------------------- @@ -843,21 +845,27 @@ class TestToolResultCache: class TestProgressiveDecay: def _make_messages(self, n_rounds: int) -> list[dict]: - """Build a synthetic message list with n_rounds of (assistant, user) pairs.""" + """Build a synthetic message list shaped like native tool calling: + user → (assistant w/ tool_calls → tool result)+ + """ messages = [{"role": "user", "content": "Start task"}] for i in range(n_rounds): + tc_id = f"call_{i}" messages.append({ "role": "assistant", - "content": f"{{'name': 'tool_{i}'}}", + "content": None, + "tool_calls": [ + { + "id": tc_id, + "type": "function", + "function": {"name": f"tool_{i}", "arguments": "{}"}, + }, + ], }) - # Tool result message with substantial content messages.append({ - "role": "user", - "content": ( - f"{TOOL_RESULT_TAG}\n" - f"[tool_{i}] {'x' * 2500}\n" - f"{TOOL_RESULT_END}" - ), + "role": "tool", + "tool_call_id": tc_id, + "content": f"[tool_{i}] {'x' * 2500}", }) return messages @@ -865,21 +873,16 @@ class TestProgressiveDecay: msgs = self._make_messages(3) result = _apply_progressive_decay(msgs) assert len(result) == len(msgs) - # Content should be identical for short conversations for orig, decayed in zip(msgs, result): - assert orig["content"] == decayed["content"] + assert orig.get("content") == decayed.get("content") def test_old_messages_truncated(self): msgs = self._make_messages(20) result = _apply_progressive_decay(msgs) - # Recent tool results should be full length - last_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][-1] - assert len(last_tool_result["content"]) > 2000 - - # Oldest tool results should be truncated - first_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][0] - assert len(first_tool_result["content"]) < 500 + tool_msgs = [m for m in result if m.get("role") == "tool"] + assert len(tool_msgs[-1]["content"]) > 2000 + assert len(tool_msgs[0]["content"]) < 500 def test_message_count_preserved(self): msgs = self._make_messages(20) @@ -915,7 +918,7 @@ class TestMessageFolding: messages.append({"role": "assistant", "content": f"thinking step {i}"}) messages.append({"role": "user", "content": f"tool result {i}: {'data ' * 50}"}) - result = await client._fold_old_messages(messages, "system prompt") + result = await client._fold_old_messages(messages) # Should be significantly shorter assert len(result) < len(messages) @@ -937,7 +940,7 @@ class TestMessageFolding: {"role": "assistant", "content": "hi"}, ] - result = await client._fold_old_messages(messages, "system") + result = await client._fold_old_messages(messages) # Should return original (n_to_fold = 2 - 10 = negative, so no folding) assert result == messages client.chat.assert_not_called() @@ -951,7 +954,555 @@ class TestMessageFolding: client.chat = AsyncMock(side_effect=Exception("API error")) messages = [{"role": "user", "content": f"msg {i}"} for i in range(40)] - result = await client._fold_old_messages(messages, "system") + result = await client._fold_old_messages(messages) # On failure, should return original messages unchanged assert len(result) == 40 + + @pytest.mark.asyncio + async def test_fold_boundary_never_orphans_tool_message(self): + """If the natural fold boundary would leave `role: "tool"` at the + head of `recent_messages`, fold must walk the boundary forward + until the head is non-tool. The API rejects orphan tool messages + with HTTP 400.""" + from llm_client import LLMClient + from unittest.mock import AsyncMock + + client = LLMClient.__new__(LLMClient) + client.chat = AsyncMock(return_value="summary") + + # Build a long conversation of (assistant{tool_calls}, tool) pairs. + # Place the assistant at the exact n_to_fold boundary so its paired + # tool would otherwise be orphaned at the head of recent_messages. + msgs: list[dict] = [{"role": "user", "content": "task"}] + for i in range(30): + tc_id = f"call_{i}" + msgs.append({ + "role": "assistant", "content": None, + "tool_calls": [{ + "id": tc_id, "type": "function", + "function": {"name": f"t_{i}", "arguments": "{}"}, + }], + }) + msgs.append({"role": "tool", "tool_call_id": tc_id, "content": "ok"}) + + result = await client._fold_old_messages(msgs) + # No `role: "tool"` may appear without an `assistant{tool_calls}` + # immediately preceding it. + for i, m in enumerate(result): + if m.get("role") == "tool": + assert i > 0, "tool message cannot be first" + prev = result[i - 1] + assert prev.get("role") == "assistant" and prev.get("tool_calls"), ( + f"tool at index {i} preceded by {prev.get('role')} " + f"(tool_calls={bool(prev.get('tool_calls'))})" + ) + + +# --------------------------------------------------------------------------- +# Investigation areas: dataclass + derivation + coverage + dispatch +# --------------------------------------------------------------------------- + +class TestInvestigationAreaDerivation: + @pytest.fixture + def graph(self): + return EvidenceGraph() + + @pytest.mark.asyncio + async def test_add_investigation_area_dedupes_and_merges_lists(self, graph): + aid1, existed1 = await graph.add_investigation_area( + area="password_hashes", description="SAM dump", + suggested_agent="filesystem", + expected_keywords=["sam", "pwdump"], + expected_tools=["search_strings"], + motivating_hypothesis_ids=["hyp-a"], + created_by="llm_derive", + ) + aid2, existed2 = await graph.add_investigation_area( + area="password_hashes", description="Other description", + suggested_agent="registry", + expected_keywords=["sam", "hashdump"], # one new + expected_tools=["search_strings", "parse_registry_key"], # one new + motivating_hypothesis_ids=["hyp-b"], # new + created_by="manual", + ) + assert not existed1 + assert existed2 + assert aid1 == aid2 + a = graph.investigation_areas[aid1] + # Description and suggested_agent NOT overwritten (first-write wins) + assert a.description == "SAM dump" + assert a.suggested_agent == "filesystem" + # Three list fields are unioned + assert set(a.expected_keywords) == {"sam", "pwdump", "hashdump"} + assert set(a.expected_tools) == {"search_strings", "parse_registry_key"} + assert set(a.motivating_hypothesis_ids) == {"hyp-a", "hyp-b"} + + @pytest.mark.asyncio + async def test_check_coverage_keyword_layer(self, graph): + await graph.add_phenomenon( + "fs", "filesystem", "Cain SAM dump artifact", + "Found sam.lst in the Cain folder", + source_tool="list_directory", + ) + await graph.add_investigation_area( + area="password_hashes", description="SAM dump", + suggested_agent="filesystem", + expected_keywords=["sam.lst", "pwdump"], + expected_tools=["nonexistent_tool"], + ) + from orchestrator import Orchestrator + from agent_factory import AgentFactory + from unittest.mock import AsyncMock + orch = Orchestrator(AsyncMock(), graph, AgentFactory(AsyncMock(), graph)) + covered = orch._check_coverage() + assert "password_hashes" in covered + + @pytest.mark.asyncio + async def test_check_coverage_tool_layer(self, graph): + await graph.add_phenomenon( + "reg", "registry", "User accounts", + "Found accounts", + source_tool="enumerate_users", + ) + await graph.add_investigation_area( + area="user_accounts", description="Enum users", + suggested_agent="registry", + expected_keywords=["irrelevant"], + expected_tools=["enumerate_users"], + ) + from orchestrator import Orchestrator + from agent_factory import AgentFactory + from unittest.mock import AsyncMock + orch = Orchestrator(AsyncMock(), graph, AgentFactory(AsyncMock(), graph)) + covered = orch._check_coverage() + assert "user_accounts" in covered + + @pytest.mark.asyncio + async def test_load_state_round_trip_areas(self, graph, tmp_path): + await graph.add_investigation_area( + area="x", description="d", suggested_agent="filesystem", + expected_keywords=["k1"], expected_tools=["t1"], + priority=2, motivating_hypothesis_ids=["hyp-a"], + created_by="manual", + ) + path = tmp_path / "state.json" + graph.save_state(path) + g2 = EvidenceGraph.load_state(path) + assert len(g2.investigation_areas) == 1 + a = list(g2.investigation_areas.values())[0] + assert a.area == "x" + assert a.expected_keywords == ["k1"] + assert a.priority == 2 + assert a.motivating_hypothesis_ids == ["hyp-a"] + assert a.created_by == "manual" + + @pytest.mark.asyncio + async def test_derive_no_op_when_areas_already_populated(self, graph): + """Resume safety: if areas are already in the graph (manual seed or + restored from disk), _derive_investigation_areas does nothing.""" + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + await graph.add_hypothesis("test", "desc", created_by="t") + await graph.add_investigation_area( + area="pre_existing", description="d", suggested_agent="filesystem", + created_by="manual", + ) + + llm = AsyncMock() + orch = Orchestrator(llm, graph, AgentFactory(llm, graph)) + await orch._derive_investigation_areas() + # LLM should not have been called + assert llm.chat.call_count == 0 + # Area count unchanged + assert len(graph.investigation_areas) == 1 + + @pytest.mark.asyncio + async def test_fallback_when_llm_returns_empty_list(self, graph): + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + await graph.add_hypothesis("Some compromise", "desc", created_by="t") + llm = AsyncMock() + llm.chat.return_value = "[]" + orch = Orchestrator(llm, graph, AgentFactory(llm, graph)) + await orch._derive_investigation_areas() + # Fallback creates one area per hypothesis + assert len(graph.investigation_areas) == 1 + a = list(graph.investigation_areas.values())[0] + assert a.created_by == "fallback" + + @pytest.mark.asyncio + async def test_unknown_tool_filtered_kept_keywords(self, graph): + """LLM emits a tool name not in TOOL_CATALOG; tool is filtered, + but the area itself with its keywords is kept.""" + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + h = await graph.add_hypothesis("h", "desc", created_by="t") + llm = AsyncMock() + llm.chat.return_value = ( + '[{"area":"foo","description":"desc","suggested_agent":"filesystem",' + '"expected_keywords":["kw1","kw2"],"expected_tools":["nonexistent_tool"],' + f'"priority":2,"motivating_hypothesis_ids":["{h}"]}}]' + ) + orch = Orchestrator(llm, graph, AgentFactory(llm, graph)) + await orch._derive_investigation_areas() + assert "area-foo" in graph.investigation_areas + a = graph.investigation_areas["area-foo"] + assert a.expected_keywords == ["kw1", "kw2"] + assert a.expected_tools == [] # filtered out + + @pytest.mark.asyncio + async def test_unknown_agent_resolved_via_AGENT_ALIASES(self, graph): + """LLM emits 'chat' (which is in AGENT_ALIASES → 'communication'). + The area should land with the resolved agent name.""" + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + h = await graph.add_hypothesis("h", "desc", created_by="t") + llm = AsyncMock() + llm.chat.return_value = ( + '[{"area":"chat_stuff","description":"d","suggested_agent":"chat",' + '"expected_keywords":["irc"],"expected_tools":[],' + f'"priority":3,"motivating_hypothesis_ids":["{h}"]}}]' + ) + orch = Orchestrator(llm, graph, AgentFactory(llm, graph)) + await orch._derive_investigation_areas() + a = graph.investigation_areas["area-chat_stuff"] + assert a.suggested_agent == "communication" + + @staticmethod + def _agent_with_executor(graph, llm, tool_name: str, real_executor): + """Build a BaseAgent that registers tool_name via the real register_tool + path so the mandatory-record wrapper is engaged.""" + from base_agent import BaseAgent + agent = BaseAgent(llm, graph) + agent.name = "test_agent" + # Bypass _register_graph_tools side-effects in run() — we register + # only what the test needs. + agent._register_graph_tools = lambda: None + agent.register_tool( + name=tool_name, description="", input_schema={}, executor=real_executor, + ) + return agent + + @pytest.mark.asyncio + async def test_forced_record_retry_fires_when_zero_phenomena(self): + """BaseAgent.run should automatically retry one more LLM round if + the agent finished without calling any mandatory recording tool.""" + from unittest.mock import AsyncMock + + graph = EvidenceGraph() + llm = AsyncMock() + + async def real_add(**kw): + await graph.add_phenomenon( + source_agent="test", category="filesystem", + title="Forced retry record", + description="Recorded after STOP prompt", + source_tool="forced_retry", + ) + + agent = self._agent_with_executor(graph, llm, "add_phenomenon", real_add) + + async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()): + already_retrying = any( + "STOP." in (m.get("content", "") if isinstance(m, dict) else "") + for m in messages + ) + if not already_retrying: + return "Final answer without recording.", list(messages) + [ + {"role": "assistant", "content": "Final answer without recording."} + ] + await tool_executor["add_phenomenon"]() # goes through wrapper + return "Recorded.", [] + + llm.tool_call_loop = fake_tool_call_loop + await agent.run("test task") + assert len(graph.phenomena) == 1 + assert agent._record_call_counts["add_phenomenon"] == 1 + + @pytest.mark.asyncio + async def test_no_retry_when_mandatory_tool_was_called(self): + """Retry should NOT fire if a mandatory recording tool was invoked.""" + from unittest.mock import AsyncMock + + graph = EvidenceGraph() + llm = AsyncMock() + call_count = {"n": 0} + + async def real_add(**kw): + await graph.add_phenomenon( + source_agent="test", category="filesystem", title="x", + description="y", source_tool="t", + ) + + agent = self._agent_with_executor(graph, llm, "add_phenomenon", real_add) + + async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()): + call_count["n"] += 1 + await tool_executor["add_phenomenon"]() # wrapper increments count + return "done.", list(messages) + + llm.tool_call_loop = fake_tool_call_loop + await agent.run("test task") + assert call_count["n"] == 1 # no retry + + @pytest.mark.asyncio + async def test_no_retry_when_mandatory_tools_empty(self): + """ReportAgent declares mandatory_record_tools=() — retry should + not fire even with zero graph mutations (final text IS the output).""" + from unittest.mock import AsyncMock + from base_agent import BaseAgent + + graph = EvidenceGraph() + llm = AsyncMock() + call_count = {"n": 0} + + async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()): + call_count["n"] += 1 + return "report body here", list(messages) + + llm.tool_call_loop = fake_tool_call_loop + + class ReportLike(BaseAgent): + mandatory_record_tools = () + + agent = ReportLike(llm, graph) + agent.name = "report_like" + agent._register_graph_tools = lambda: None + await agent.run("test task") + assert call_count["n"] == 1 + + @pytest.mark.asyncio + async def test_forced_retry_fires_for_timeline_agent(self): + """TimelineAgent.mandatory_record_tools=('add_temporal_edge',) — retry + should fire when timeline finishes without creating any temporal edges, + even though the agent does not have add_phenomenon.""" + from unittest.mock import AsyncMock + from base_agent import BaseAgent + + graph = EvidenceGraph() + llm = AsyncMock() + call_count = {"n": 0} + + edge_added = {"n": 0} + async def real_add_edge(**kw): + edge_added["n"] += 1 + + class TimelineLike(BaseAgent): + mandatory_record_tools = ("add_temporal_edge",) + + agent = TimelineLike(llm, graph) + agent.name = "timeline_like" + agent._register_graph_tools = lambda: None + agent.register_tool("add_temporal_edge", "", {}, real_add_edge) + + async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()): + call_count["n"] += 1 + already_retrying = any( + "STOP." in (m.get("content", "") if isinstance(m, dict) else "") + for m in messages + ) + if not already_retrying: + return "answer", list(messages) + await tool_executor["add_temporal_edge"]() + return "recorded.", [] + + llm.tool_call_loop = fake_tool_call_loop + await agent.run("build timeline") + assert call_count["n"] == 2 # first + retry + assert edge_added["n"] == 1 + assert agent._record_call_counts["add_temporal_edge"] == 1 + + # ---- terminal_tools: real LLMClient.tool_call_loop short-circuit ----- + + @pytest.mark.asyncio + async def test_terminal_tool_exits_loop_immediately(self): + """When a terminal tool is called, tool_call_loop must return + with that tool's result text as final_text — no further LLM calls.""" + from unittest.mock import AsyncMock + from llm_client import LLMClient + + client = LLMClient.__new__(LLMClient) + client.max_tokens = 4096 + client.reasoning_effort = None + client.thinking_enabled = False + client.model = "test" + client._client = None + + call_count = {"n": 0} + + async def fake_chat_with_tools(messages, openai_tools): + call_count["n"] += 1 + if call_count["n"] == 1: + # First turn: model calls a read tool then the terminal tool. + return "thinking aloud", None, [ + {"id": "tc1", "name": "read_tool", "arguments": "{}"}, + {"id": "tc2", "name": "save_report", + "arguments": '{"content":"FINAL REPORT BODY","output_path":"r.md"}'}, + ] + raise AssertionError("loop should have exited after terminal tool") + + client._chat_with_tools = fake_chat_with_tools + + async def read_tool(): + return "some data" + + async def save_report(content, output_path): + return content # terminal tool returns content as final_text + + tools = [ + {"name": "read_tool", "description": "", "input_schema": {"type": "object", "properties": {}}}, + {"name": "save_report", "description": "", + "input_schema": {"type": "object", "properties": { + "content": {"type": "string"}, "output_path": {"type": "string"}}}}, + ] + executors = {"read_tool": read_tool, "save_report": save_report} + + final_text, _ = await client.tool_call_loop( + messages=[{"role": "user", "content": "go"}], + tools=tools, tool_executor=executors, + system="sys", terminal_tools=("save_report",), + ) + assert final_text == "FINAL REPORT BODY" + assert call_count["n"] == 1 # never called a 2nd round + + @pytest.mark.asyncio + async def test_no_terminal_short_circuit_when_not_declared(self): + """When terminal_tools is empty, the same call sequence should + run the read tool, run save_report-like tool, AND continue the loop + (i.e. another LLM round) until the model stops calling tools.""" + from unittest.mock import AsyncMock + from llm_client import LLMClient + + client = LLMClient.__new__(LLMClient) + client.max_tokens = 4096 + client.reasoning_effort = None + client.thinking_enabled = False + client.model = "test" + client._client = None + + call_count = {"n": 0} + + async def fake_chat_with_tools(messages, openai_tools): + call_count["n"] += 1 + if call_count["n"] == 1: + return "", None, [ + {"id": "tc1", "name": "add_phenomenon", + "arguments": '{"title":"x","description":"y"}'}, + ] + return "all done", None, [] # model stops calling tools + + client._chat_with_tools = fake_chat_with_tools + + async def add_phenomenon(title, description): + return f"recorded {title}" + + tools = [ + {"name": "add_phenomenon", "description": "", + "input_schema": {"type": "object", "properties": { + "title": {"type": "string"}, "description": {"type": "string"}}}}, + ] + executors = {"add_phenomenon": add_phenomenon} + + final_text, _ = await client.tool_call_loop( + messages=[{"role": "user", "content": "go"}], + tools=tools, tool_executor=executors, + system="sys", terminal_tools=(), # NOT terminal + ) + assert final_text == "all done" + assert call_count["n"] == 2 # 2 rounds — terminal_tools empty, loop continues + + @pytest.mark.asyncio + async def test_report_agent_terminal_tool_declared(self): + """ReportAgent should declare save_report as both mandatory and terminal.""" + from agents.report import ReportAgent + assert ReportAgent.terminal_tools == ("save_report",) + assert ReportAgent.mandatory_record_tools == ("save_report",) + + @pytest.mark.asyncio + async def test_terminal_tool_result_not_truncated(self): + """Terminal tool's raw return is used as final_text and must NOT + be truncated to 3000 chars (the truncation cap applies only to + LLM-context tool result messages). A 20K-char markdown report + passed through save_report should reach the caller intact.""" + from llm_client import LLMClient + + client = LLMClient.__new__(LLMClient) + client.max_tokens = 4096 + client.reasoning_effort = None + client.thinking_enabled = False + client.model = "test" + client._client = None + + long_report = "# Report\n" + ("- finding " + "x" * 100 + "\n") * 200 + assert len(long_report) > 10000 + + call_count = {"n": 0} + + async def fake_chat_with_tools(messages, openai_tools): + call_count["n"] += 1 + if call_count["n"] == 1: + return "", None, [ + {"id": "tc1", "name": "save_report", + "arguments": '{"content":"placeholder","output_path":"r.md"}'}, + ] + raise AssertionError("loop should have exited") + + client._chat_with_tools = fake_chat_with_tools + + async def save_report(content, output_path): + return long_report # ignore content arg; return long content + + tools = [{"name": "save_report", "description": "", + "input_schema": {"type": "object", "properties": { + "content": {"type": "string"}, "output_path": {"type": "string"}}}}] + executors = {"save_report": save_report} + + final_text, _ = await client.tool_call_loop( + messages=[{"role": "user", "content": "go"}], + tools=tools, tool_executor=executors, + system="sys", terminal_tools=("save_report",), + ) + assert final_text == long_report + assert len(final_text) > 10000 # not truncated to 3000 + + @pytest.mark.asyncio + async def test_dispatch_uses_hypothesis_id_when_motivating_ids_present(self, graph): + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + h = await graph.add_hypothesis("h", "desc", created_by="t") + await graph.add_investigation_area( + area="uncovered_area", description="d", suggested_agent="registry", + expected_keywords=["xyz_no_match"], expected_tools=[], + priority=2, motivating_hypothesis_ids=[h], + created_by="llm_derive", + ) + orch = Orchestrator(AsyncMock(), graph, AgentFactory(AsyncMock(), graph)) + # Don't actually dispatch (would call agents) — just hit the lead-add path + # by manually replicating what _run_gap_analysis does. + covered = orch._check_coverage() + assert "uncovered_area" not in covered + # Simulate dispatch + for a in graph.investigation_areas.values(): + if a.area not in covered: + await graph.add_lead( + target_agent=a.suggested_agent, + description=a.description, + priority=a.priority, + hypothesis_id=(a.motivating_hypothesis_ids[0] + if a.motivating_hypothesis_ids else None), + ) + assert len(graph.leads) == 1 + assert graph.leads[0].hypothesis_id == h + assert graph.leads[0].target_agent == "registry" + assert graph.leads[0].priority == 2