DESIGN_STRATEGIST.md §2. Four read-only view tools the strategist uses
to ground its decision each round.
graph_overview() — hypotheses table (log_odds, conf, edges_in,
distinct_sources, recent_flip), sources table,
pending leads. distinct_sources is the
critical signal: a hypothesis with 23 edges
but only 1 distinct_source has fragile cross-
source independence and is a candidate for
a corroboration-seeking lead.
source_coverage(src) — per-source ✓/✗ against an expected-artefact
catalogue. Catalogue is heuristic hints,
NOT a forced checklist. Footer reminds the
strategist to investigate ✗ items only when
an active hypothesis depends on them — this
is the "应试能力存在但不被绑死" ("checklist skills exist but are not binding") guardrail.
marginal_yield(N) — new phenomena / edges / status flips per
recent round. Two consecutive zero-yield
rounds = strong signal to declare complete.
budget_status() — usage vs caps (tool_calls, rounds, wall
clock). Pacing warnings at 70% / 90%.
tools/strategy.py also exports EXPECTED_ARTEFACTS, a per-source-type
table of (name, detector, value_for) entries. Detectors are
substring patterns on tool name + args; the matcher resolves at
call time against graph.tool_invocations. Catalogue covers iOS /
Android / Windows disk / media-collection / archive source types.
All four tools registered in tool_registry, listed as read-only in
llm_client.READ_ONLY_TOOLS for parallel execution. They go through
the invocation-logging wrapper so the strategist's reads are
themselves auditable (the wrapper does NOT cache them — graph
state changes between calls).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
809 lines
31 KiB
Python
809 lines
31 KiB
Python
"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API).
|
|
|
|
Tool calling uses the OpenAI-native `tools=[...]` parameter. The model
|
|
returns structured tool_calls via the streaming protocol; we accumulate
|
|
them, dispatch to our executors, and feed results back as `role: "tool"`
|
|
messages. This eliminates the fragile "model writes JSON inside free
|
|
text" problem of the previous ReAct text mode.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import time
|
|
from collections import Counter
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from openai import APIConnectionError, APIError, APITimeoutError, AsyncOpenAI
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMAPIError(Exception):
    """Signals that an LLM request could not be completed despite retrying.

    Attributes:
        attempts: number of attempts the raiser reports having made before
            giving up.
    """

    def __init__(self, message: str, attempts: int) -> None:
        # Initialise the Exception part first so str(err) == message.
        super().__init__(message)
        self.attempts = attempts
|
|
|
|
|
|
# Optional answer tags, kept for backward compatibility with prompts that
# wrap the final response in <answer>...</answer>. With native tool calling a
# reply without tool_calls is already final; when the model still emits the
# tags, tool_call_loop returns only the enclosed text so callers see clean
# output.
ANSWER_TAG, ANSWER_END = "<answer>", "</answer>"
|
|
|
|
|
|
def _to_openai_tools(tools: list[dict]) -> list[dict]:
    """Wrap internal tool specs in the OpenAI ``{"type": "function"}`` envelope.

    Each spec must provide ``name`` and ``description``; its ``input_schema``
    becomes the function's ``parameters``. Specs without a schema get an
    empty object schema.
    """
    converted: list[dict] = []
    for spec in tools:
        function = {
            "name": spec["name"],
            "description": spec["description"],
            "parameters": spec.get(
                "input_schema", {"type": "object", "properties": {}},
            ),
        }
        converted.append({"type": "function", "function": function})
    return converted
|
|
|
|
|
|
def _extract_first_balanced(text: str, open_char: str, close_char: str) -> str | None:
    """Return the first balanced [...] or {...} substring, or None if no balanced pair.

    Depth-counting scan: handles nested brackets correctly. A regex with
    ``.*?`` would truncate at the first inner closing bracket, and one with
    ``.*`` would over-eat trailing text. Brackets that appear inside
    double-quoted JSON string literals are not counted (backslash escapes
    inside strings are honoured), so ``{"a": "}"}`` is returned whole instead
    of being cut at the quoted brace.

    Args:
        text: text that may contain a JSON object/array among other prose.
        open_char: the opening bracket, e.g. ``"{"`` or ``"["``.
        close_char: the matching closing bracket.

    Returns:
        The substring from the first ``open_char`` through its matching
        ``close_char``, or None if ``open_char`` is absent or never closed.
    """
    start = text.find(open_char)
    if start < 0:
        return None

    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        c = text[i]
        if in_string:
            # Inside a string literal only an unescaped quote is significant.
            if escaped:
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == '"':
                in_string = False
        elif c == '"':
            in_string = True
        elif c == open_char:
            depth += 1
        elif c == close_char:
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None
|
|
|
|
|
|
def _safe_json_loads(text: str):
    """Parse JSON with progressive sanitization for LLM-produced output.

    Tries (0) the text as-is, then (1) with every backslash that does not
    start a valid JSON escape doubled. Valid escapes are the quote, backslash,
    slash, b, f, n, r, t, and u followed by four hex digits. Valid escape
    pairs are consumed as a unit, so the second backslash of an
    already-escaped double backslash is never mistaken for a stray one. For
    example, a path whose separators are correctly doubled but which also
    contains a stray single backslash is still repaired.

    Used by orchestrator JSON callsites (_call_llm_for_json) and by
    tool_call_loop when parsing tool_call arguments returned by the API.

    Raises:
        json.JSONDecodeError: if the sanitized text still does not parse.
            The first 600 chars of the raw input are logged first so the
            emitted payload can be diagnosed.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    def _repair(match: re.Match) -> str:
        # group(1) is set only when the backslash starts a valid escape;
        # keep those verbatim and double every stray backslash.
        return match.group(0) if match.group(1) else "\\\\"

    stage1 = re.sub(r'\\(["\\/bfnrt]|u[0-9a-fA-F]{4})?', _repair, text)
    try:
        return json.loads(stage1)
    except json.JSONDecodeError as e:
        logger.warning(
            "_safe_json_loads failed after sanitize (%s); raw head[:600]=%r",
            e, text[:600],
        )
        raise
|
|
|
|
|
|
def _extract_answer(text: str) -> str | None:
    """Return the stripped text between the first ANSWER_TAG and the next ANSWER_END.

    Returns None when the opening tag is missing or is never followed by a
    closing tag.
    """
    open_at = text.find(ANSWER_TAG)
    if open_at == -1:
        return None
    body_start = open_at + len(ANSWER_TAG)
    # The first closing tag after the opening one ends the answer; a later
    # opening tag cannot have a closing tag if the first one does not.
    close_at = text.find(ANSWER_END, body_start)
    if close_at == -1:
        return None
    return text[body_start:close_at].strip()
|
|
|
|
|
|
def _truncate_tool_result(result_text: str, max_chars: int = 3000) -> str:
    """Shorten a tool result that is longer than ``max_chars``.

    Text within the limit is returned unchanged. Longer text keeps its head
    and gains a ``[truncated, N total chars]`` marker. Up to 200 chars of the
    budget (but at most half of it) are reserved for the marker, so the
    default keeps a 2800-char head. Small limits never produce a negative
    slice bound, which would otherwise keep almost the whole text.
    """
    if len(result_text) <= max_chars:
        return result_text
    reserve = min(200, max_chars // 2)
    keep = max(max_chars - reserve, 0)
    return result_text[:keep] + f"\n... [truncated, {len(result_text)} total chars]"
|
|
|
|
|
|
# Tools that only read and never mutate state — safe to run concurrently.
# _partition_tool_calls groups consecutive calls to these names into one
# batch, and tool_call_loop runs such a batch concurrently when it holds more
# than one call.
READ_ONLY_TOOLS: set[str] = set().union(
    # Graph queries
    (
        "list_phenomena", "get_phenomenon", "search_graph", "get_related",
        "get_hypothesis_status", "list_assets", "find_extracted_file",
    ),
    # Sleuth Kit reads
    (
        "partition_info", "filesystem_info", "list_directory", "find_file",
        "search_strings", "count_deleted_files", "build_filesystem_timeline",
    ),
    # Registry reads (without auto-record wrappers)
    ("parse_registry_key", "search_registry", "get_user_activity"),
    # Parser reads
    (
        "read_text_file", "read_binary_preview", "search_text_file",
        "read_text_file_section", "list_extracted_dir", "parse_pcap_strings",
        "find_files",
    ),
    # iOS plugin reads (S4)
    (
        "parse_plist", "sqlite_tables", "sqlite_query",
        "parse_ios_keychain", "read_idevice_info",
    ),
    # Android + media reads (S6) — set_active_partition is NOT read-only.
    ("probe_android_partitions", "ocr_image"),
    # Strategist view tools (DESIGN_STRATEGIST.md §2) — pure renders.
    ("graph_overview", "source_coverage", "marginal_yield", "budget_status"),
)
|
|
|
|
|
|
def _fix_tool_args(tool_name: str, tool_args: dict, tools: list[dict]) -> dict:
    """Try to fix misnamed tool arguments from LLM hallucination.

    The LLM sometimes confuses parameter names across tools (e.g. passing
    `key_path` to search_registry which expects `pattern`). Unknown kwargs
    are mapped, in the order they were given, onto the tool's still-missing
    parameters. The candidates are the schema's ``required`` list in its
    declared order, or all ``properties`` when nothing is required. Unknown
    kwargs left over after that are dropped, with a warning, so the executor
    is not called with unexpected keyword arguments.

    Args:
        tool_name: name of the tool being called.
        tool_args: arguments parsed from the model's tool call.
        tools: internal tool definitions (``name`` / ``input_schema``).

    Returns:
        ``tool_args`` itself when nothing needs fixing, the tool is unknown,
        or the arguments are not a dict; otherwise a new dict.
    """
    # Non-dict arguments (e.g. the model emitted a JSON array) cannot be
    # remapped; hand them back so the caller's executor error handling
    # deals with them instead of crashing here.
    if not isinstance(tool_args, dict):
        return tool_args

    schema = next(
        (t.get("input_schema") or {} for t in tools if t.get("name") == tool_name),
        None,
    )
    if schema is None:
        return tool_args

    props = schema.get("properties") or {}
    unknown = [k for k in tool_args if k not in props]
    if not unknown:
        return tool_args  # all args are valid, nothing to fix

    fixed = {k: v for k, v in tool_args.items() if k in props}

    # Use the schema's declared order (lists), never a set: iteration order
    # of a set of str depends on hash randomisation, which made the mapping
    # differ from run to run.
    required = list(dict.fromkeys(schema.get("required") or []))
    candidates = required or list(props)
    missing = [p for p in candidates if p not in fixed]

    for wrong_name, target in zip(unknown, missing):
        logger.warning(
            "Auto-fixing tool arg: %s(%s=...) -> %s(%s=...)",
            tool_name, wrong_name, tool_name, target,
        )
        fixed[target] = tool_args[wrong_name]

    dropped = unknown[len(missing):]
    if dropped:
        logger.warning(
            "Dropping unrecognised tool args for %s: %s", tool_name, dropped,
        )
    return fixed
|
|
|
|
|
|
def _emit_tool_call_summary(tool_calls: list[dict], elapsed: float) -> None:
    """Log one folded line summarising a round of tool calls.

    Calls are grouped by name, most frequent first, e.g.
    "list_directory x27, extract_file x3, read_text_file x3". A name called
    only once appears without a count. The record carries ``event`` and
    ``elapsed`` extras for the terminal formatter.
    """
    tally = Counter(tc.get("name", "?") for tc in tool_calls)
    summary = ", ".join(
        name if n == 1 else f"{name} x{n}"
        for name, n in tally.most_common()
    )
    logger.info(summary, extra={"event": "tool_calls", "elapsed": elapsed})
|
|
|
|
|
|
@dataclass
class _ToolBatch:
    """A run of tool calls that share one execution mode.

    As built by _partition_tool_calls, a read-only batch may hold several
    consecutive read-only calls, while every other batch holds exactly one
    call.
    """

    # True when the calls' names are in the read-only name set.
    is_read_only: bool
    # Tool-call dicts, kept in the order the model issued them.
    calls: list[dict] = field(default_factory=list)
|
|
|
|
|
|
def _partition_tool_calls(
    tool_calls: list[dict],
    read_only: set[str] | None = None,
) -> list[_ToolBatch]:
    """Split tool calls into ordered execution batches.

    Consecutive calls whose names are in ``read_only`` (default:
    READ_ONLY_TOOLS) share one batch, which may then run in parallel; every
    other call gets a batch of its own. Concatenating the batches' calls
    reproduces the input order.
    """
    ro_names = READ_ONLY_TOOLS if read_only is None else read_only
    grouped: list[_ToolBatch] = []
    for call in tool_calls:
        readonly = call.get("name", "") in ro_names
        if readonly and grouped and grouped[-1].is_read_only:
            grouped[-1].calls.append(call)
            continue
        grouped.append(_ToolBatch(is_read_only=readonly, calls=[call]))
    return grouped
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Context compression — keeps the message list from growing unboundedly.
# ---------------------------------------------------------------------------

# Stage A: progressive tool-result decay.
# Each entry is (rank_threshold, max_chars). Individual `role: "tool"`
# messages are ranked by recency (0 = newest tool result; a single assistant
# turn with several tool calls contributes several ranks), and a message is
# capped by the first tier whose threshold is greater than its rank:
#   ranks 0-4   -> 3000 chars (kept in full)
#   ranks 5-14  ->  500 chars (aggressive truncation)
#   ranks 15+   ->  100 chars (minimal stub)
_DECAY_TIERS: list[tuple[int, int]] = [(5, 3000), (15, 500), (999, 100)]
|
|
|
|
|
|
def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
    """Truncate the `content` of older `role: "tool"` messages to save context.

    Each `role: "tool"` message is one tool call's result. Tool messages are
    ranked by recency (0 = newest), per message rather than per assistant
    round. Rank ``r`` is capped at the ``max_chars`` of the first
    `_DECAY_TIERS` entry whose threshold exceeds ``r``. Ranks beyond the last
    threshold use the last tier's limit instead of escaping truncation.

    A message is only rewritten when its content exceeds the cap by more
    than 200 chars. Rewritten messages are shallow copies, so neither the
    input list nor its dicts are mutated. Lists of 10 or fewer messages are
    returned unchanged (the same object).
    """
    if len(messages) <= 10:
        return messages

    tool_msg_indices = [
        i for i, m in enumerate(messages) if m.get("role") == "tool"
    ]

    # Anything older than the last threshold falls back to the last tier.
    fallback = _DECAY_TIERS[-1][1] if _DECAY_TIERS else None
    decay_map: dict[int, int] = {}
    for rank, idx in enumerate(reversed(tool_msg_indices)):
        limit = next(
            (cap for threshold, cap in _DECAY_TIERS if rank < threshold),
            fallback,
        )
        if limit is not None:
            decay_map[idx] = limit

    result: list[dict] = []
    for i, msg in enumerate(messages):
        max_chars = decay_map.get(i)
        if max_chars is None:
            result.append(msg)
            continue
        content = msg.get("content", "") or ""
        # The +200 slack avoids rewriting messages that are barely over.
        if len(content) <= max_chars + 200:
            result.append(msg)
            continue
        result.append({
            **msg,
            "content": (
                content[:max_chars]
                + f"\n... [context compressed: {len(content)} -> {max_chars} chars]"
            ),
        })
    return result
|
|
|
|
|
|
# Stage B: LLM-powered message folding. Once the conversation grows past
# _FOLD_THRESHOLD messages, the oldest messages are replaced by one
# LLM-written summary. Up to the newest _FOLD_KEEP_RECENT messages stay
# verbatim; the exact split is adjusted so tool results stay with their
# assistant turn.
_FOLD_THRESHOLD = 24
_FOLD_KEEP_RECENT = 10

# System prompt for the summarisation call; lines are joined with "\n".
_FOLD_SUMMARY_SYSTEM = "\n".join([
    "You are a concise summarizer for an ongoing forensic investigation conversation. "
    "Summarize the following early conversation between a forensic analysis agent and its "
    "tool execution results. Preserve:",
    "- Key findings and evidence discovered (file paths, inode numbers, timestamps)",
    "- Tools that were called and their important results",
    "- Decisions made and current investigation direction",
    "Keep the summary under 800 words. Use bullet points.",
])
|
|
|
|
|
|
class LLMClient:
    """Async LLM client via the OpenAI SDK.

    Works with any OpenAI-compatible endpoint (OpenAI, DeepSeek, ...).
    Tool calling uses the native ``tools=[...]`` protocol: the model streams
    structured ``tool_calls``, ``tool_call_loop`` dispatches them to executor
    coroutines and feeds each result back as a ``role: "tool"`` message.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "deepseek-v4-pro",
        max_tokens: int = 4096,
        proxy: str | None = "auto",
        reasoning_effort: str | None = None,
        thinking_enabled: bool = False,
    ) -> None:
        """Create the client.

        Args:
            base_url: API base URL; a trailing slash is stripped.
            api_key: API key passed to ``AsyncOpenAI``.
            model: model name sent with every request.
            max_tokens: ``max_tokens`` sent with every request.
            proxy: ``"auto"`` reads ``https_proxy`` / ``HTTPS_PROXY`` from the
                environment; ``None``, ``""`` or ``"none"`` (any case) means
                no explicit proxy client is built; any other value is used
                as the proxy URL.
            reasoning_effort: forwarded as ``reasoning_effort`` when set.
            thinking_enabled: when True, requests DeepSeek thinking mode via
                ``extra_body``.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self.reasoning_effort = reasoning_effort
        self.thinking_enabled = thinking_enabled

        if proxy == "auto":
            proxy_url = os.environ.get("https_proxy") or os.environ.get("HTTPS_PROXY")
        elif proxy and proxy.lower() != "none":
            proxy_url = proxy
        else:
            proxy_url = None

        # Only build a custom HTTP client when a proxy is configured;
        # otherwise the SDK creates its default client.
        http_client = (
            httpx.AsyncClient(proxy=proxy_url, timeout=300.0)
            if proxy_url else None
        )

        self._client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=300.0,
            http_client=http_client,
        )

    async def close(self) -> None:
        """Close the underlying AsyncOpenAI client."""
        await self._client.close()

    # -- request helpers ---------------------------------------------------

    def _request_kwargs(
        self, messages: list[dict], tools: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Build the streaming ``chat.completions.create`` keyword arguments.

        ``tools`` is only included when non-empty. The OpenAI API rejects an
        empty ``tools`` array, and omitting it means the same thing.
        """
        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "stream": True,
        }
        if tools:
            kwargs["tools"] = tools
        if self.reasoning_effort:
            kwargs["reasoning_effort"] = self.reasoning_effort
        if self.thinking_enabled:
            kwargs["extra_body"] = {"thinking": {"type": "enabled"}}
        return kwargs

    @staticmethod
    def _is_retryable(exc: BaseException) -> bool:
        """Return False for client errors that retrying cannot fix.

        Exceptions without an integer ``status_code`` are retryable. That
        covers connection errors, timeouts and mid-stream errors. 5xx and
        408/409/429 responses are also retryable. Any other 4xx status (bad
        request, auth, not found, ...) is deterministic, and waiting through
        the backoff schedule would only delay the failure.
        """
        status = getattr(exc, "status_code", None)
        if not isinstance(status, int):
            return True
        return not 400 <= status < 500 or status in (408, 409, 429)

    async def _backoff_or_raise(
        self, exc: Exception, attempt: int, max_retries: int, label: str,
    ) -> None:
        """Sleep before the next attempt, or raise LLMAPIError to give up.

        ``attempt`` is zero-based. The backoff is ``2 ** attempt * 10``
        seconds (10s, 20s, 40s, ...). ``label`` prefixes the retry warning.

        Raises:
            LLMAPIError: when ``exc`` is not retryable, or when this was the
                last of ``max_retries`` attempts.
        """
        attempts = attempt + 1
        if not self._is_retryable(exc):
            raise LLMAPIError(
                f"LLM API request rejected (non-retryable) after {attempts} attempts: {exc}",
                attempts=attempts,
            ) from exc
        if attempts >= max_retries:
            raise LLMAPIError(
                f"LLM API unreachable after {max_retries} attempts: {exc}",
                attempts=max_retries,
            ) from exc
        wait = 2 ** attempt * 10
        logger.warning("%s failed (%s), retrying in %ds...", label, exc, wait)
        await asyncio.sleep(wait)

    # -- completions -------------------------------------------------------

    async def chat(
        self,
        messages: list[dict],
        system: str | None = None,
        max_retries: int = 5,
    ) -> str:
        """Send a streaming chat completion and return the assembled text.

        Args:
            messages: conversation messages (without the system prompt).
            system: optional system prompt prepended to ``messages``.
            max_retries: total number of attempts.

        Returns:
            The concatenated content deltas, or "" if ``max_retries < 1``.

        Raises:
            LLMAPIError: after ``max_retries`` failed attempts, or at once
                for a non-retryable 4xx response.
        """
        full_messages: list[dict] = []
        if system:
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)
        kwargs = self._request_kwargs(full_messages)

        for attempt in range(max_retries):
            logger.debug(
                "LLM request (stream): %d messages (attempt %d)",
                len(messages), attempt + 1,
            )
            text_parts: list[str] = []
            try:
                stream = await self._client.chat.completions.create(**kwargs)
                async for chunk in stream:
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta
                    if delta.content:
                        text_parts.append(delta.content)
            except (APIConnectionError, APITimeoutError, APIError) as e:
                await self._backoff_or_raise(e, attempt, max_retries, "Request")
                continue

            text = "".join(text_parts)
            logger.debug("LLM response (stream): %d chars", len(text))
            return text

        return ""

    async def _chat_with_tools(
        self,
        messages: list[dict],
        openai_tools: list[dict],
        max_retries: int = 5,
    ) -> tuple[str, str | None, list[dict]]:
        """Stream a chat completion with native tool calling enabled.

        Returns:
            (text_content, reasoning_content, raw_tool_calls).
            - reasoning_content is non-None when DeepSeek thinking mode is
              active; the caller MUST echo it back in the assistant message
              on subsequent requests, or the API returns HTTP 400.
            - raw_tool_calls is a list of {"id","name","arguments"} dicts,
              ordered by the stream's tool-call index; arguments is the raw
              JSON string returned by the API.

        Raises:
            LLMAPIError: after ``max_retries`` failed attempts, or at once
                for a non-retryable 4xx response.
        """
        kwargs = self._request_kwargs(messages, openai_tools)

        for attempt in range(max_retries):
            logger.debug(
                "LLM request (stream+tools): %d messages, %d tools (attempt %d)",
                len(messages), len(openai_tools), attempt + 1,
            )
            text_parts: list[str] = []
            reasoning_parts: list[str] = []
            tool_calls_acc: dict[int, dict] = {}  # index -> {id, name, arguments}
            try:
                stream = await self._client.chat.completions.create(**kwargs)
                async for chunk in stream:
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta
                    if delta.content:
                        text_parts.append(delta.content)
                    # DeepSeek thinking-mode: reasoning_content is returned
                    # alongside content and MUST be echoed back on subsequent
                    # requests, otherwise the API rejects with HTTP 400.
                    rc = getattr(delta, "reasoning_content", None)
                    if rc:
                        reasoning_parts.append(rc)
                    # Tool calls arrive as fragments keyed by index; the id
                    # and name come once, the arguments string in pieces.
                    for tc_delta in delta.tool_calls or ():
                        entry = tool_calls_acc.setdefault(
                            tc_delta.index,
                            {"id": None, "name": None, "arguments": ""},
                        )
                        if tc_delta.id:
                            entry["id"] = tc_delta.id
                        fn = tc_delta.function
                        if fn:
                            if fn.name:
                                entry["name"] = fn.name
                            if fn.arguments:
                                entry["arguments"] += fn.arguments
            except (APIConnectionError, APITimeoutError, APIError) as e:
                await self._backoff_or_raise(
                    e, attempt, max_retries, "Tool-call request",
                )
                continue

            text = "".join(text_parts)
            reasoning = "".join(reasoning_parts) or None
            ordered = [tool_calls_acc[i] for i in sorted(tool_calls_acc)]
            logger.debug(
                "LLM response (stream+tools): %d chars, %d reasoning chars, %d tool calls",
                len(text), len(reasoning or ""), len(ordered),
            )
            return text, reasoning, ordered

        return "", None, []

    # -- tool loop ---------------------------------------------------------

    @staticmethod
    def _parse_tool_call(rc: dict) -> dict:
        """Turn one raw streamed tool call into ``{"id", "name", "arguments"}``.

        Arguments that fail to parse, or that parse to something other than
        a JSON object, are replaced by ``{}`` (with a warning). The executor
        is called with ``**arguments``, so a non-dict would otherwise break
        the call.
        """
        name = rc.get("name") or ""
        args_str = rc.get("arguments", "") or ""
        try:
            args = _safe_json_loads(args_str) if args_str.strip() else {}
        except (json.JSONDecodeError, ValueError) as e:
            logger.warning(
                "Failed to parse arguments for tool %s: %s",
                rc.get("name"), e,
            )
            args = {}
        if not isinstance(args, dict):
            logger.warning(
                "Non-object arguments for tool %s (%s); using {}",
                name, type(args).__name__,
            )
            args = {}
        return {"id": rc.get("id"), "name": name, "arguments": args}

    async def tool_call_loop(
        self,
        messages: list[dict],
        tools: list[dict],
        tool_executor: dict[str, Any],
        system: str | None = None,
        max_iterations: int = 60,
        terminal_tools: tuple[str, ...] = (),
    ) -> tuple[str, list[dict]]:
        """Run a tool-calling loop using OpenAI-native tool calls.

        The model returns structured `tool_calls` in its message; we
        dispatch them through our executor dict and feed each result back
        as a `role: "tool"` message with the matching `tool_call_id`. The
        loop ends when:
        - the model returns a message with no tool_calls (normal exit), or
        - any tool in `terminal_tools` is called — in that case, the loop
          short-circuits with that tool's result text as final_text. This
          gives agents (notably ReportAgent) an explicit completion signal
          that the old `<answer>` text tag used to provide.

        Before every request the history is compressed: old tool results are
        truncated (_apply_progressive_decay) and, past _FOLD_THRESHOLD
        messages, older messages are folded into an LLM summary.

        Returns:
            (final_text, full_message_history)
        """
        terminal_set = set(terminal_tools)
        openai_tools = _to_openai_tools(tools)

        # The caller may pass `messages` either as raw conversation (no system)
        # together with `system=...`, OR as a complete history that already
        # starts with the system message (retry path). Accept both shapes.
        if messages and messages[0].get("role") == "system":
            full_messages: list[dict] = list(messages)
        else:
            full_messages = []
            if system:
                full_messages.append({"role": "system", "content": system})
            full_messages.extend(messages)

        folded = False
        for _ in range(max_iterations):
            # ── Context compression before each API call ──────────────
            full_messages = _apply_progressive_decay(full_messages)
            # After the first fold, allow _FOLD_KEEP_RECENT extra messages
            # of growth before folding again.
            fold_limit = (
                _FOLD_THRESHOLD + _FOLD_KEEP_RECENT if folded else _FOLD_THRESHOLD
            )
            if len(full_messages) > fold_limit:
                full_messages = await self._fold_old_messages(full_messages)
                folded = True

            text, reasoning, raw_tool_calls = await self._chat_with_tools(
                full_messages, openai_tools,
            )

            if not raw_tool_calls:
                # Model produced a final response. Strip optional <answer>
                # tags for backward compatibility with old prompts.
                final_msg: dict[str, Any] = {"role": "assistant", "content": text}
                if reasoning:
                    final_msg["reasoning_content"] = reasoning
                full_messages.append(final_msg)
                answer = _extract_answer(text)
                return (answer if answer is not None else text), full_messages

            parsed_calls = [self._parse_tool_call(rc) for rc in raw_tool_calls]

            # Append the assistant turn with the raw tool_calls (and the
            # DeepSeek-mandated reasoning_content echo-back), then execute.
            asst_msg: dict[str, Any] = {
                "role": "assistant",
                "content": text or None,
                "tool_calls": [
                    {
                        "id": rc.get("id"),
                        "type": "function",
                        "function": {
                            "name": rc.get("name") or "",
                            "arguments": rc.get("arguments", "") or "",
                        },
                    }
                    for rc in raw_tool_calls
                ],
            }
            if reasoning:
                asst_msg["reasoning_content"] = reasoning
            full_messages.append(asst_msg)

            t_batch_start = time.monotonic()
            # Each entry: (tool_call_dict, raw_result, formatted_for_llm)
            executed: list[tuple[dict, str, str]] = []
            for batch in _partition_tool_calls(parsed_calls):
                if batch.is_read_only and len(batch.calls) > 1:
                    results = await self._execute_tool_batch_parallel(
                        batch.calls, tool_executor, tools,
                    )
                else:
                    results = []
                    for tc in batch.calls:
                        results.append(
                            await self._execute_single_tool(tc, tool_executor, tools),
                        )
                executed.extend(
                    (tc, raw, formatted)
                    for tc, (raw, formatted) in zip(batch.calls, results)
                )
            _emit_tool_call_summary(parsed_calls, time.monotonic() - t_batch_start)

            # Append formatted tool results to the conversation (this is
            # what the LLM sees on subsequent rounds — truncated for context
            # economy).
            for tc, _raw, formatted in executed:
                full_messages.append({
                    "role": "tool",
                    "tool_call_id": tc["id"],
                    "content": formatted,
                })

            # Terminal-tool short-circuit: the terminal tool's RAW result
            # (untruncated) becomes final_text — e.g. a 20K-char report from
            # save_report must not be cut just because the LLM-facing copy
            # is truncated.
            if terminal_set:
                for tc, raw, _formatted in executed:
                    name = tc.get("name", "")
                    if name in terminal_set:
                        logger.info(
                            "Terminal tool %s called — exiting tool_call_loop", name,
                        )
                        return raw, full_messages

        logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
        return "[Max tool call iterations reached]", full_messages

    # -- tool execution ----------------------------------------------------

    async def _invoke_tool(
        self,
        tc: dict,
        tool_executor: dict[str, Any],
        tools: list[dict] | None,
        *,
        parallel: bool,
    ) -> tuple[str, str]:
        """Run one tool call; shared by the sequential and parallel paths.

        Executor exceptions are caught and turned into an error string so a
        failing tool never aborts the loop. ``parallel`` only selects the
        log message.

        Returns:
            (raw_result, formatted_for_llm) where formatted is
            ``[tool_name] {truncated raw}``.
        """
        tool_name = tc.get("name", "")
        tool_args = tc.get("arguments", {})
        if tools:
            tool_args = _fix_tool_args(tool_name, tool_args, tools)

        logger.info(
            "Calling tool (parallel): %s(%s)" if parallel else "Calling tool: %s(%s)",
            tool_name, json.dumps(tool_args, ensure_ascii=False),
        )

        executor = tool_executor.get(tool_name)
        if executor is None:
            raw = f"Error: unknown tool '{tool_name}'"
        else:
            try:
                raw = await executor(**tool_args)
            except Exception as e:
                logger.error("Tool %s failed: %s", tool_name, e)
                raw = f"Error executing {tool_name}: {e}"

        formatted = f"[{tool_name}] {_truncate_tool_result(raw)}"
        return raw, formatted

    async def _execute_single_tool(
        self, tc: dict, tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
    ) -> tuple[str, str]:
        """Execute a single tool call.

        Returns (raw_result, formatted_for_llm). `raw_result` is the
        unmodified executor return (used by terminal-tool short-circuit as
        final_text). `formatted_for_llm` is `[tool_name] {truncated}` and
        is what gets fed back to the model as the tool message content.
        """
        return await self._invoke_tool(tc, tool_executor, tools, parallel=False)

    async def _execute_tool_batch_parallel(
        self, calls: list[dict], tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
    ) -> list[tuple[str, str]]:
        """Execute multiple read-only tool calls concurrently.

        Returns a list of (raw_result, formatted_for_llm) tuples in the
        same order as `calls` (asyncio.gather preserves input order).
        """
        logger.info("Executing %d read-only tools in parallel", len(calls))
        results = await asyncio.gather(*(
            self._invoke_tool(tc, tool_executor, tools, parallel=True)
            for tc in calls
        ))
        return list(results)

    # -- context folding ---------------------------------------------------

    @staticmethod
    def _fold_boundary(body: list[dict], keep_recent: int) -> int | None:
        """Return how many leading messages of ``body`` to fold, or None.

        The kept tail must never start with a ``role: "tool"`` message: the
        API rejects (HTTP 400) a tool result that does not follow its
        ``assistant`` message with ``tool_calls``. The boundary is first
        pushed forward past tool results, folding them together with their
        assistant turn. If that runs off the end, the newest round issued at
        least ``keep_recent`` tool calls. In that case the boundary is pulled
        back to the assistant turn that issued them, so the latest round is
        kept whole rather than folded away.

        Returns None when fewer than three messages would be folded.
        """
        boundary = len(body) - keep_recent
        if boundary <= 2:
            return None

        def is_tool(i: int) -> bool:
            return body[i].get("role") == "tool"

        forward = boundary
        while forward < len(body) and is_tool(forward):
            forward += 1
        if forward < len(body):
            return forward

        backward = min(boundary, len(body) - 1)
        while backward > 0 and is_tool(backward):
            backward -= 1
        return backward if backward > 2 else None

    async def _fold_old_messages(
        self, messages: list[dict],
    ) -> list[dict]:
        """Fold old messages into an LLM-generated summary (Stage B).

        Preserves the leading system message (if any), keeps roughly the
        most recent _FOLD_KEEP_RECENT messages intact (the split never
        separates tool results from their assistant turn), and replaces the
        older slice with a single summary user message. Returns `messages`
        unchanged when there is too little to fold or the summary call fails.
        """
        # Pin the system message — it must NEVER be summarized away.
        system_msgs: list[dict] = []
        body = messages
        if messages and messages[0].get("role") == "system":
            system_msgs = [messages[0]]
            body = messages[1:]

        n_to_fold = self._fold_boundary(body, _FOLD_KEEP_RECENT)
        if n_to_fold is None:
            return messages

        old_messages = body[:n_to_fold]
        recent_messages = body[n_to_fold:]

        old_text_parts: list[str] = []
        for msg in old_messages:
            role = msg.get("role", "?")
            content = msg.get("content") or ""
            # Render tool_calls (assistant turn) compactly.
            if role == "assistant" and msg.get("tool_calls"):
                tc_names = [
                    tc.get("function", {}).get("name", "?")
                    for tc in msg["tool_calls"]
                ]
                content = (content + " " if content else "") + (
                    "called: " + ", ".join(tc_names)
                )
            if len(content) > 1000:
                content = content[:1000] + "..."
            old_text_parts.append(f"[{role}]: {content}")
        old_text = "\n\n".join(old_text_parts)

        # Cap total size sent to summarizer
        if len(old_text) > 15000:
            old_text = old_text[:15000] + "\n\n... [further messages omitted for brevity]"

        logger.info(
            "Context folding: summarizing %d old messages (%d chars) into summary",
            n_to_fold, len(old_text),
        )

        try:
            summary = await self.chat(
                messages=[{"role": "user", "content": old_text}],
                system=_FOLD_SUMMARY_SYSTEM,
            )
        except Exception as e:
            logger.warning("Context folding failed: %s — keeping original messages", e)
            return messages

        summary_message = {
            "role": "user",
            "content": (
                f"[Context summary — the following summarizes {n_to_fold} earlier "
                f"messages in this conversation]\n\n{summary}"
            ),
        }
        return system_msgs + [summary_message] + recent_messages
|