refactor: native tool calling + generic forced-retry + terminal exit
- llm_client: switch tool_call_loop from text-based <tool_call> regex to OpenAI-native tools=[...] / structured tool_calls field; accumulate delta.reasoning_content for DeepSeek thinking-mode echo-back; fold preserves system msg and aligns boundary to never orphan role:tool
- base_agent: generic forced-retry via mandatory_record_tools class attr (filesystem -> add_phenomenon, timeline -> add_temporal_edge, hypothesis -> add_hypothesis, report -> save_report); count via executor wrapper
- terminal_tools class attr + loop short-circuit: when a terminal tool is called, loop exits with its raw return as final_text. ReportAgent declares save_report as terminal — replaces the <answer>-tag stop signal that native tool calling broke
- _execute_*: return (raw, formatted) — terminal exit uses untruncated raw, conversation history uses 3000-char-capped formatted
- evidence_graph + orchestrator: LLM-derived InvestigationArea support (hypothesis-driven coverage check, replaces hardcoded _AREA_KEYWORDS / _AREA_TOOLS); manual yaml block kept as optional seed
- strip <answer> references from agent prompts (no longer load-bearing)

Verified on CFReDS image across 4 smoke runs: 0 JSON parse failures (was 3); 22 temporal edges from Phase 4 (was 0); ReportAgent exits via save_report (was max_iterations regression). 78/78 unit tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
508
llm_client.py
508
llm_client.py
@@ -1,9 +1,10 @@
|
||||
"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API).
|
||||
|
||||
Tool calling is text-based (ReAct pattern): tool schemas are embedded in
|
||||
the system prompt and tool calls are parsed as <tool_call> JSON blocks
|
||||
from model output. This keeps the protocol independent of whether the
|
||||
underlying API supports native function/tool calling.
|
||||
Tool calling uses the OpenAI-native `tools=[...]` parameter. The model
|
||||
returns structured tool_calls via the streaming protocol; we accumulate
|
||||
them, dispatch to our executors, and feed results back as `role: "tool"`
|
||||
messages. This eliminates the fragile "model writes JSON inside free
|
||||
text" problem of the previous ReAct text mode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -32,52 +33,51 @@ class LLMAPIError(Exception):
|
||||
self.attempts = attempts
|
||||
|
||||
|
||||
# Markers the model uses to signal tool calls and final answers
|
||||
TOOL_CALL_TAG = "<tool_call>"
|
||||
TOOL_CALL_END = "</tool_call>"
|
||||
TOOL_RESULT_TAG = "<tool_result>"
|
||||
TOOL_RESULT_END = "</tool_result>"
|
||||
# Optional answer tags — kept for backward compat with prompts that wrap
|
||||
# their final response in <answer>...</answer>. Native tool calling does
|
||||
# not need these (no tool_calls = final), but if the model continues to
|
||||
# emit them, we strip the tags so callers see clean text.
|
||||
ANSWER_TAG = "<answer>"
|
||||
ANSWER_END = "</answer>"
|
||||
|
||||
|
||||
def _build_tools_prompt(tools: list[dict]) -> str:
|
||||
"""Format tool definitions for inclusion in the system prompt."""
|
||||
lines = ["You have access to the following tools:\n"]
|
||||
for t in tools:
|
||||
schema = t.get("input_schema", {})
|
||||
props = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
def _to_openai_tools(tools: list[dict]) -> list[dict]:
|
||||
"""Convert internal tool definitions to OpenAI native function-tools format."""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": t["name"],
|
||||
"description": t["description"],
|
||||
"parameters": t.get("input_schema", {"type": "object", "properties": {}}),
|
||||
},
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
|
||||
params = []
|
||||
for pname, pdef in props.items():
|
||||
req = " (required)" if pname in required else ""
|
||||
desc = pdef.get("description", "")
|
||||
ptype = pdef.get("type", "string")
|
||||
enum_vals = pdef.get("enum")
|
||||
if enum_vals:
|
||||
allowed = ", ".join(f'"{v}"' for v in enum_vals)
|
||||
params.append(f" - {pname}: {ptype}{req} — {desc} Allowed values: [{allowed}]")
|
||||
else:
|
||||
params.append(f" - {pname}: {ptype}{req} — {desc}")
|
||||
|
||||
param_block = "\n".join(params) if params else " (no parameters)"
|
||||
lines.append(f"## {t['name']}\n{t['description']}\nParameters:\n{param_block}\n")
|
||||
def _extract_first_balanced(text: str, open_char: str, close_char: str) -> str | None:
|
||||
"""Return the first balanced [...] or {...} substring, or None if no balanced pair.
|
||||
|
||||
lines.append(
|
||||
"## How to use tools\n"
|
||||
"To call a tool, output a JSON block wrapped in XML tags like this:\n"
|
||||
f"{TOOL_CALL_TAG}\n"
|
||||
'{"name": "tool_name", "arguments": {"param1": "value1"}}\n'
|
||||
f"{TOOL_CALL_END}\n\n"
|
||||
"You can call multiple tools in sequence. After each tool call, you will receive the result in:\n"
|
||||
f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n"
|
||||
"When you have finished your analysis and have a final answer, wrap it in:\n"
|
||||
f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n"
|
||||
"Think step by step. Call tools to gather evidence before drawing conclusions.\n"
|
||||
"You MUST call at least one tool before giving your final answer."
|
||||
)
|
||||
return "\n".join(lines)
|
||||
Stack-based — handles nested brackets correctly (regex with .*? would
|
||||
truncate at the first inner closing bracket, regex with .* would over-eat
|
||||
trailing text). Brackets inside JSON string literals are ignored by
|
||||
callers because the caller passes the result through json.loads which
|
||||
re-parses with proper string handling.
|
||||
"""
|
||||
start = text.find(open_char)
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
for i in range(start, len(text)):
|
||||
c = text[i]
|
||||
if c == open_char:
|
||||
depth += 1
|
||||
elif c == close_char:
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start:i + 1]
|
||||
return None
|
||||
|
||||
|
||||
def _safe_json_loads(text: str):
|
||||
@@ -87,8 +87,8 @@ def _safe_json_loads(text: str):
|
||||
(\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX). On final failure, logs raw
|
||||
input (first 600 chars) so we can diagnose what the model emitted.
|
||||
|
||||
Used both by orchestrator JSON callsites and by _extract_tool_calls
|
||||
when parsing <tool_call> blocks from model output.
|
||||
Used by orchestrator JSON callsites (_call_llm_for_json) and by
|
||||
tool_call_loop when parsing tool_call arguments returned by the API.
|
||||
"""
|
||||
try:
|
||||
return json.loads(text)
|
||||
@@ -110,23 +110,6 @@ def _safe_json_loads(text: str):
|
||||
raise
|
||||
|
||||
|
||||
def _extract_tool_calls(text: str) -> list[dict]:
|
||||
"""Extract tool call JSON blocks from model output."""
|
||||
pattern = re.compile(
|
||||
re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END),
|
||||
re.DOTALL,
|
||||
)
|
||||
calls = []
|
||||
for match in pattern.finditer(text):
|
||||
raw = match.group(1).strip()
|
||||
try:
|
||||
parsed = _safe_json_loads(raw)
|
||||
calls.append(parsed)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse tool call JSON: %s", raw[:200])
|
||||
return calls
|
||||
|
||||
|
||||
def _extract_answer(text: str) -> str | None:
|
||||
"""Extract the final answer from model output."""
|
||||
pattern = re.compile(
|
||||
@@ -266,50 +249,41 @@ _DECAY_TIERS: list[tuple[int, int]] = [
|
||||
|
||||
|
||||
def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
|
||||
"""Truncate tool results in older messages to save context space.
|
||||
"""Truncate the `content` of older `role: "tool"` messages to save context.
|
||||
|
||||
Operates in-place-style on a copy. Only touches user messages that
|
||||
contain <tool_result> blocks (these are the tool-result messages
|
||||
generated by tool_call_loop).
|
||||
Each `role: "tool"` message in the conversation corresponds to one tool
|
||||
call's result. We rank these messages by recency and progressively
|
||||
truncate older ones according to `_DECAY_TIERS`.
|
||||
"""
|
||||
# Count rounds from the end. A "round" is a (assistant, user) pair.
|
||||
# messages alternate: [user, assistant, user, assistant, user, ...]
|
||||
# The initial user message is index 0, then pairs start at index 1.
|
||||
total = len(messages)
|
||||
if total <= 10: # not enough messages to bother
|
||||
if total <= 10:
|
||||
return messages
|
||||
|
||||
result = []
|
||||
# Count tool-result user messages from the end
|
||||
tool_result_indices = [
|
||||
i for i, m in enumerate(messages)
|
||||
if m["role"] == "user" and TOOL_RESULT_TAG in m.get("content", "")
|
||||
tool_msg_indices = [
|
||||
i for i, m in enumerate(messages) if m.get("role") == "tool"
|
||||
]
|
||||
|
||||
# Build a set of indices that need decay, mapped to their max_chars
|
||||
decay_map: dict[int, int] = {}
|
||||
n_tool_msgs = len(tool_result_indices)
|
||||
for rank, idx in enumerate(reversed(tool_result_indices)):
|
||||
rounds_ago = rank # 0 = most recent, 1 = second most recent, ...
|
||||
for rank, idx in enumerate(reversed(tool_msg_indices)):
|
||||
rounds_ago = rank
|
||||
for threshold, max_chars in _DECAY_TIERS:
|
||||
if rounds_ago < threshold:
|
||||
decay_map[idx] = max_chars
|
||||
break
|
||||
|
||||
result = []
|
||||
for i, msg in enumerate(messages):
|
||||
if i in decay_map:
|
||||
max_chars = decay_map[i]
|
||||
content = msg["content"]
|
||||
content = msg.get("content", "") or ""
|
||||
if len(content) > max_chars + 200:
|
||||
# Truncate but preserve the tool_result tags structure
|
||||
truncated = content[:max_chars]
|
||||
# Count how many tool results are in this message
|
||||
n_results = content.count(TOOL_RESULT_TAG)
|
||||
truncated += (
|
||||
f"\n... [context compressed: {len(content)} -> {max_chars} chars, "
|
||||
f"{n_results} tool result(s)]"
|
||||
truncated = (
|
||||
content[:max_chars]
|
||||
+ f"\n... [context compressed: {len(content)} -> {max_chars} chars]"
|
||||
)
|
||||
result.append({"role": msg["role"], "content": truncated})
|
||||
new_msg = dict(msg)
|
||||
new_msg["content"] = truncated
|
||||
result.append(new_msg)
|
||||
else:
|
||||
result.append(msg)
|
||||
else:
|
||||
@@ -434,6 +408,95 @@ class LLMClient:
|
||||
|
||||
return ""
|
||||
|
||||
async def _chat_with_tools(
    self,
    messages: list[dict],
    openai_tools: list[dict],
    max_retries: int = 5,
) -> tuple[str, str | None, list[dict]]:
    """Stream a chat completion with native tool calling enabled.

    Args:
        messages: Full conversation to send (passed through unchanged).
        openai_tools: Tool definitions already in OpenAI function-tools
            format.
        max_retries: Maximum number of attempts. Transient failures are
            retried with exponential backoff (10s, 20s, 40s, ...). If this
            is 0 or negative no request is made and an empty result is
            returned.

    Returns:
        (text_content, reasoning_content, raw_tool_calls).
        - reasoning_content is non-None when DeepSeek thinking mode is
          active; the caller MUST echo it back in the assistant message
          on subsequent requests, or the API returns HTTP 400.
        - raw_tool_calls is a list of {"id","name","arguments"} dicts;
          arguments is the raw JSON string returned by the API.

    Raises:
        LLMAPIError: When the API rejects the request with a non-retryable
            4xx status (raised immediately), or when every attempt failed.
    """
    kwargs: dict[str, Any] = {
        "model": self.model,
        "messages": messages,
        "max_tokens": self.max_tokens,
        "stream": True,
        "tools": openai_tools,
    }
    if self.reasoning_effort:
        kwargs["reasoning_effort"] = self.reasoning_effort
    if self.thinking_enabled:
        kwargs["extra_body"] = {"thinking": {"type": "enabled"}}

    for attempt in range(max_retries):
        logger.debug(
            "LLM request (stream+tools): %d messages, %d tools (attempt %d)",
            len(messages), len(openai_tools), attempt + 1,
        )
        # Accumulators are re-created per attempt so a stream that dies
        # half-way never leaks partial output into the retry.
        text_parts: list[str] = []
        reasoning_parts: list[str] = []
        tool_calls_acc: dict[int, dict] = {}  # index -> {id, name, arguments}
        try:
            stream = await self._client.chat.completions.create(**kwargs)
            async for chunk in stream:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.content:
                    text_parts.append(delta.content)
                # DeepSeek thinking-mode: reasoning_content is returned
                # alongside content and MUST be echoed back on subsequent
                # requests, otherwise the API rejects with HTTP 400.
                rc = getattr(delta, "reasoning_content", None)
                if rc:
                    reasoning_parts.append(rc)
                if delta.tool_calls:
                    # Tool calls arrive as fragments keyed by `index`: id and
                    # name are set when present, arguments are concatenated.
                    for tc_delta in delta.tool_calls:
                        idx = tc_delta.index
                        entry = tool_calls_acc.setdefault(
                            idx, {"id": None, "name": None, "arguments": ""},
                        )
                        if tc_delta.id:
                            entry["id"] = tc_delta.id
                        fn = tc_delta.function
                        if fn:
                            if fn.name:
                                entry["name"] = fn.name
                            if fn.arguments:
                                entry["arguments"] += fn.arguments

            text = "".join(text_parts)
            reasoning = "".join(reasoning_parts) or None
            ordered = [tool_calls_acc[i] for i in sorted(tool_calls_acc)]
            logger.debug(
                "LLM response (stream+tools): %d chars, %d reasoning chars, %d tool calls",
                len(text), len(reasoning or ""), len(ordered),
            )
            return text, reasoning, ordered

        except (APIConnectionError, APITimeoutError, APIError) as e:
            # Status errors carry `status_code`; connection/timeout errors
            # do not. A 4xx other than 408/409/429 means the request itself
            # is wrong (bad payload, auth, ...): resending the identical
            # request cannot succeed, so fail fast instead of sleeping
            # through the whole backoff schedule.
            status = getattr(e, "status_code", None)
            if (
                isinstance(status, int)
                and 400 <= status < 500
                and status not in (408, 409, 429)
            ):
                raise LLMAPIError(
                    f"LLM API rejected the request (HTTP {status}): {e}",
                    attempts=attempt + 1,
                ) from e
            if attempt < max_retries - 1:
                wait = 2 ** attempt * 10
                logger.warning(
                    "Tool-call request failed (%s), retrying in %ds...", e, wait,
                )
                await asyncio.sleep(wait)
            else:
                raise LLMAPIError(
                    f"LLM API unreachable after {max_retries} attempts: {e}",
                    attempts=max_retries,
                ) from e

    # Only reached when max_retries <= 0 (the loop body never ran).
    return "", None, []
|
||||
|
||||
async def tool_call_loop(
|
||||
self,
|
||||
messages: list[dict],
|
||||
@@ -441,87 +504,159 @@ class LLMClient:
|
||||
tool_executor: dict[str, Any],
|
||||
system: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
terminal_tools: tuple[str, ...] = (),
|
||||
) -> tuple[str, list[dict]]:
|
||||
"""Run a ReAct-style tool-calling loop.
|
||||
"""Run a tool-calling loop using OpenAI-native tool calls.
|
||||
|
||||
The model outputs <tool_call> blocks which we parse and execute,
|
||||
feeding results back as <tool_result> blocks until the model
|
||||
outputs an <answer> block.
|
||||
The model returns structured `tool_calls` in its message; we
|
||||
dispatch them through our executor dict and feed each result back
|
||||
as a `role: "tool"` message with the matching `tool_call_id`. The
|
||||
loop ends when:
|
||||
- the model returns a message with no tool_calls (normal exit), or
|
||||
- any tool in `terminal_tools` is called — in that case, the loop
|
||||
short-circuits with that tool's result text as final_text. This
|
||||
gives agents (notably ReportAgent) an explicit completion signal
|
||||
that the old `<answer>` text tag used to provide.
|
||||
|
||||
Returns:
|
||||
(final_text, all_messages)
|
||||
(final_text, full_message_history)
|
||||
"""
|
||||
# Build system prompt with tool definitions
|
||||
tools_prompt = _build_tools_prompt(tools)
|
||||
full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt
|
||||
terminal_set = set(terminal_tools)
|
||||
openai_tools = _to_openai_tools(tools)
|
||||
|
||||
messages = list(messages) # don't mutate caller's list
|
||||
_folded = False # Track whether we've already folded once this loop
|
||||
# The caller may pass `messages` either as raw conversation (no system)
|
||||
# together with `system=...`, OR as a complete history that already
|
||||
# starts with the system message (retry path). Accept both shapes.
|
||||
if messages and messages[0].get("role") == "system":
|
||||
full_messages: list[dict] = list(messages)
|
||||
else:
|
||||
full_messages = []
|
||||
if system:
|
||||
full_messages.append({"role": "system", "content": system})
|
||||
full_messages.extend(messages)
|
||||
_folded = False
|
||||
|
||||
for i in range(max_iterations):
|
||||
for _i in range(max_iterations):
|
||||
# ── Context compression before each API call ──────────────
|
||||
# Stage A: progressively decay old tool results
|
||||
messages = _apply_progressive_decay(messages)
|
||||
|
||||
# Stage B: fold oldest messages into LLM summary if too long
|
||||
if not _folded and len(messages) > _FOLD_THRESHOLD:
|
||||
messages = await self._fold_old_messages(messages, full_system)
|
||||
full_messages = _apply_progressive_decay(full_messages)
|
||||
if not _folded and len(full_messages) > _FOLD_THRESHOLD:
|
||||
full_messages = await self._fold_old_messages(full_messages)
|
||||
_folded = True
|
||||
elif _folded and len(messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
|
||||
# Allow a second fold if messages grew back significantly
|
||||
messages = await self._fold_old_messages(messages, full_system)
|
||||
elif _folded and len(full_messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
|
||||
full_messages = await self._fold_old_messages(full_messages)
|
||||
|
||||
text = await self.chat(messages, system=full_system)
|
||||
text, reasoning, raw_tool_calls = await self._chat_with_tools(
|
||||
full_messages, openai_tools,
|
||||
)
|
||||
|
||||
# Check for final answer
|
||||
answer = _extract_answer(text)
|
||||
if answer is not None:
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
return answer, messages
|
||||
if not raw_tool_calls:
|
||||
# Model produced a final response. Strip optional <answer>
|
||||
# tags for backward compatibility with old prompts.
|
||||
final_msg: dict[str, Any] = {"role": "assistant", "content": text}
|
||||
if reasoning:
|
||||
final_msg["reasoning_content"] = reasoning
|
||||
full_messages.append(final_msg)
|
||||
answer = _extract_answer(text)
|
||||
return (answer if answer is not None else text), full_messages
|
||||
|
||||
# Check for tool calls
|
||||
tool_calls = _extract_tool_calls(text)
|
||||
# Parse arguments + build internal call dicts
|
||||
parsed_calls: list[dict] = []
|
||||
for rc in raw_tool_calls:
|
||||
args_str = rc.get("arguments", "") or ""
|
||||
try:
|
||||
args = _safe_json_loads(args_str) if args_str.strip() else {}
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(
|
||||
"Failed to parse arguments for tool %s: %s",
|
||||
rc.get("name"), e,
|
||||
)
|
||||
args = {}
|
||||
parsed_calls.append({
|
||||
"id": rc.get("id"),
|
||||
"name": rc.get("name", ""),
|
||||
"arguments": args,
|
||||
})
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls and no answer tag — treat entire text as answer
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
return text, messages
|
||||
# Append the assistant turn with the raw tool_calls (and the
|
||||
# DeepSeek-mandated reasoning_content echo-back), then execute.
|
||||
asst_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": text or None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": rc.get("id"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": rc.get("name", ""),
|
||||
"arguments": rc.get("arguments", "") or "",
|
||||
},
|
||||
}
|
||||
for rc in raw_tool_calls
|
||||
],
|
||||
}
|
||||
if reasoning:
|
||||
asst_msg["reasoning_content"] = reasoning
|
||||
full_messages.append(asst_msg)
|
||||
|
||||
# Execute tool calls — read-only tools run in parallel
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
|
||||
result_parts = []
|
||||
batches = _partition_tool_calls(tool_calls)
|
||||
batches = _partition_tool_calls(parsed_calls)
|
||||
t_batch_start = time.monotonic()
|
||||
|
||||
# Each entry: (tool_call_dict, raw_result, formatted_for_llm)
|
||||
executed: list[tuple[dict, str, str]] = []
|
||||
for batch in batches:
|
||||
if batch.is_read_only and len(batch.calls) > 1:
|
||||
batch_results = await self._execute_tool_batch_parallel(
|
||||
results = await self._execute_tool_batch_parallel(
|
||||
batch.calls, tool_executor, tools,
|
||||
)
|
||||
result_parts.extend(batch_results)
|
||||
for tc, (raw, formatted) in zip(batch.calls, results):
|
||||
executed.append((tc, raw, formatted))
|
||||
else:
|
||||
for tc in batch.calls:
|
||||
result_parts.append(
|
||||
await self._execute_single_tool(tc, tool_executor, tools)
|
||||
raw, formatted = await self._execute_single_tool(
|
||||
tc, tool_executor, tools,
|
||||
)
|
||||
executed.append((tc, raw, formatted))
|
||||
|
||||
# Emit folded tool-call summary for the terminal
|
||||
t_batch_elapsed = time.monotonic() - t_batch_start
|
||||
_emit_tool_call_summary(tool_calls, t_batch_elapsed)
|
||||
_emit_tool_call_summary(parsed_calls, t_batch_elapsed)
|
||||
|
||||
# Feed results back as a user message
|
||||
result_message = "\n\n".join(result_parts)
|
||||
messages.append({"role": "user", "content": result_message})
|
||||
# Append formatted tool results to the conversation (this is
|
||||
# what the LLM sees on subsequent rounds — truncated for context
|
||||
# economy).
|
||||
for tc, _raw, formatted in executed:
|
||||
full_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tc["id"],
|
||||
"content": formatted,
|
||||
})
|
||||
|
||||
# Terminal-tool short-circuit: if the model called any tool in
|
||||
# `terminal_tools`, end the loop immediately. The terminal tool's
|
||||
# RAW result (untruncated) becomes final_text — the LLM may have
|
||||
# produced a 20K-char report via save_report and we must not
|
||||
# truncate it just because the LLM-facing copy is truncated.
|
||||
if terminal_set:
|
||||
for tc, raw, _formatted in executed:
|
||||
name = tc.get("name", "")
|
||||
if name in terminal_set:
|
||||
logger.info(
|
||||
"Terminal tool %s called — exiting tool_call_loop", name,
|
||||
)
|
||||
return raw, full_messages
|
||||
|
||||
logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
|
||||
return "[Max tool call iterations reached]", messages
|
||||
return "[Max tool call iterations reached]", full_messages
|
||||
|
||||
async def _execute_single_tool(
|
||||
self, tc: dict, tool_executor: dict[str, Any],
|
||||
tools: list[dict] | None = None,
|
||||
) -> str:
|
||||
"""Execute a single tool call and return the formatted result."""
|
||||
) -> tuple[str, str]:
|
||||
"""Execute a single tool call.
|
||||
|
||||
Returns (raw_result, formatted_for_llm). `raw_result` is the
|
||||
unmodified executor return (used by terminal-tool short-circuit as
|
||||
final_text). `formatted_for_llm` is `[tool_name] {truncated}` and
|
||||
is what gets fed back to the model as the tool message content.
|
||||
"""
|
||||
tool_name = tc.get("name", "")
|
||||
tool_args = tc.get("arguments", {})
|
||||
|
||||
@@ -532,72 +667,106 @@ class LLMClient:
|
||||
|
||||
executor = tool_executor.get(tool_name)
|
||||
if executor is None:
|
||||
result_text = f"Error: unknown tool '{tool_name}'"
|
||||
raw = f"Error: unknown tool '{tool_name}'"
|
||||
else:
|
||||
try:
|
||||
result_text = await executor(**tool_args)
|
||||
raw = await executor(**tool_args)
|
||||
except Exception as e:
|
||||
logger.error("Tool %s failed: %s", tool_name, e)
|
||||
result_text = f"Error executing {tool_name}: {e}"
|
||||
raw = f"Error executing {tool_name}: {e}"
|
||||
|
||||
return (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
)
|
||||
formatted = f"[{tool_name}] {_truncate_tool_result(raw)}"
|
||||
return raw, formatted
|
||||
|
||||
async def _execute_tool_batch_parallel(
|
||||
self, calls: list[dict], tool_executor: dict[str, Any],
|
||||
tools: list[dict] | None = None,
|
||||
) -> list[str]:
|
||||
"""Execute multiple read-only tool calls concurrently."""
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Execute multiple read-only tool calls concurrently.
|
||||
|
||||
Returns a list of (raw_result, formatted_for_llm) tuples in the
|
||||
same order as `calls`.
|
||||
"""
|
||||
logger.info("Executing %d read-only tools in parallel", len(calls))
|
||||
|
||||
async def _run_one(tc: dict) -> str:
|
||||
async def _run_one(tc: dict) -> tuple[str, str]:
|
||||
tool_name = tc.get("name", "")
|
||||
tool_args = tc.get("arguments", {})
|
||||
if tools:
|
||||
tool_args = _fix_tool_args(tool_name, tool_args, tools)
|
||||
logger.info("Calling tool (parallel): %s(%s)", tool_name, json.dumps(tool_args, ensure_ascii=False))
|
||||
logger.info(
|
||||
"Calling tool (parallel): %s(%s)",
|
||||
tool_name, json.dumps(tool_args, ensure_ascii=False),
|
||||
)
|
||||
executor = tool_executor.get(tool_name)
|
||||
if executor is None:
|
||||
result_text = f"Error: unknown tool '{tool_name}'"
|
||||
raw = f"Error: unknown tool '{tool_name}'"
|
||||
else:
|
||||
try:
|
||||
result_text = await executor(**tool_args)
|
||||
raw = await executor(**tool_args)
|
||||
except Exception as e:
|
||||
logger.error("Tool %s failed: %s", tool_name, e)
|
||||
result_text = f"Error executing {tool_name}: {e}"
|
||||
return (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
)
|
||||
raw = f"Error executing {tool_name}: {e}"
|
||||
formatted = f"[{tool_name}] {_truncate_tool_result(raw)}"
|
||||
return raw, formatted
|
||||
|
||||
results = await asyncio.gather(*[_run_one(tc) for tc in calls])
|
||||
return list(results)
|
||||
|
||||
async def _fold_old_messages(
|
||||
self, messages: list[dict], system: str,
|
||||
self, messages: list[dict],
|
||||
) -> list[dict]:
|
||||
"""Fold old messages into an LLM-generated summary (Stage B).
|
||||
|
||||
Keeps the most recent _FOLD_KEEP_RECENT messages intact and
|
||||
replaces earlier ones with a single summary message.
|
||||
Preserves the leading system message (if any), keeps the most
|
||||
recent _FOLD_KEEP_RECENT messages intact, and replaces the older
|
||||
middle slice with a single summary user message.
|
||||
"""
|
||||
n_to_fold = len(messages) - _FOLD_KEEP_RECENT
|
||||
# Pin the system message — it must NEVER be summarized away.
|
||||
system_msgs: list[dict] = []
|
||||
body = messages
|
||||
if messages and messages[0].get("role") == "system":
|
||||
system_msgs = [messages[0]]
|
||||
body = messages[1:]
|
||||
|
||||
n_to_fold = len(body) - _FOLD_KEEP_RECENT
|
||||
if n_to_fold <= 2:
|
||||
return messages
|
||||
|
||||
old_messages = messages[:n_to_fold]
|
||||
recent_messages = messages[n_to_fold:]
|
||||
# Pull the fold boundary forward so we never split an assistant turn
|
||||
# from its matching tool results. The API rejects (HTTP 400) any
|
||||
# `role: "tool"` message that does not immediately follow an
|
||||
# `assistant` message with `tool_calls`. We walk the boundary into
|
||||
# `recent_messages` while its head is a `role: "tool"` message, or
|
||||
# while the prior `recent` message is `assistant{tool_calls}` whose
|
||||
# paired tools span the boundary.
|
||||
while n_to_fold < len(body):
|
||||
head = body[n_to_fold]
|
||||
if head.get("role") == "tool":
|
||||
n_to_fold += 1
|
||||
continue
|
||||
break
|
||||
|
||||
if n_to_fold >= len(body):
|
||||
# Everything got folded — nothing recent to keep.
|
||||
return system_msgs + [body[0]] if system_msgs else messages
|
||||
|
||||
old_messages = body[:n_to_fold]
|
||||
recent_messages = body[n_to_fold:]
|
||||
|
||||
# Build a text dump of old messages for summarization
|
||||
old_text_parts = []
|
||||
for msg in old_messages:
|
||||
role = msg["role"]
|
||||
content = msg.get("content", "")
|
||||
# Truncate each message for the summary prompt to avoid overload
|
||||
role = msg.get("role", "?")
|
||||
content = msg.get("content") or ""
|
||||
# Render tool_calls (assistant turn) compactly.
|
||||
if role == "assistant" and msg.get("tool_calls"):
|
||||
tc_names = [
|
||||
tc.get("function", {}).get("name", "?")
|
||||
for tc in msg["tool_calls"]
|
||||
]
|
||||
content = (content + " " if content else "") + (
|
||||
"called: " + ", ".join(tc_names)
|
||||
)
|
||||
if len(content) > 1000:
|
||||
content = content[:1000] + "..."
|
||||
old_text_parts.append(f"[{role}]: {content}")
|
||||
@@ -621,7 +790,6 @@ class LLMClient:
|
||||
logger.warning("Context folding failed: %s — keeping original messages", e)
|
||||
return messages
|
||||
|
||||
# Replace old messages with a single summary
|
||||
summary_message = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
@@ -629,4 +797,4 @@ class LLMClient:
|
||||
f"messages in this conversation]\n\n{summary}"
|
||||
),
|
||||
}
|
||||
return [summary_message] + recent_messages
|
||||
return system_msgs + [summary_message] + recent_messages
|
||||
|
||||
Reference in New Issue
Block a user