refactor: native tool calling + generic forced-retry + terminal exit

- llm_client: switch tool_call_loop from text-based <tool_call> regex
  parsing to OpenAI-native tools=[...] with the structured tool_calls
  field; accumulate delta.reasoning_content so it can be echoed back in
  DeepSeek thinking mode; context fold preserves the system message and
  aligns its boundary so no role:tool message is orphaned
- base_agent: generic forced-retry via mandatory_record_tools class attr
  (filesystem -> add_phenomenon, timeline -> add_temporal_edge,
  hypothesis -> add_hypothesis, report -> save_report); count via
  executor wrapper
- terminal_tools class attr + loop short-circuit: when a terminal tool
  is called, loop exits with its raw return as final_text. ReportAgent
  declares save_report as terminal - replaces the <answer>-tag stop
  signal that native tool calling broke
- _execute_*: return (raw, formatted) - terminal exit uses untruncated
  raw, conversation history uses 3000-char-capped formatted
- evidence_graph + orchestrator: LLM-derived InvestigationArea support
  (hypothesis-driven coverage check, replaces hardcoded _AREA_KEYWORDS /
  _AREA_TOOLS); manual yaml block kept as optional seed
- strip <answer> references from agent prompts (no longer load-bearing)
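
A minimal sketch of how streamed tool-call fragments can be merged by index, as the llm_client bullet describes. The FnDelta and ToolCallDelta classes and the accumulate_tool_calls helper are hypothetical stand-ins for the SDK's streamed delta objects, not names from this commit:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FnDelta:
    """Hypothetical stand-in for the function part of a streamed fragment."""
    name: Optional[str] = None
    arguments: Optional[str] = None


@dataclass
class ToolCallDelta:
    """Hypothetical stand-in for one streamed tool-call fragment."""
    index: int
    id: Optional[str] = None
    function: Optional[FnDelta] = None


def accumulate_tool_calls(deltas: list[ToolCallDelta]) -> list[dict]:
    """Merge streamed fragments into complete calls, keyed by their index."""
    acc: dict[int, dict] = {}
    for d in deltas:
        entry = acc.setdefault(d.index, {"id": None, "name": None, "arguments": ""})
        if d.id:
            entry["id"] = d.id
        if d.function:
            if d.function.name:
                entry["name"] = d.function.name
            if d.function.arguments:
                # Argument JSON arrives in pieces; concatenate in arrival order.
                entry["arguments"] += d.function.arguments
    # Return calls ordered by index, regardless of fragment interleaving.
    return [acc[i] for i in sorted(acc)]


# Fragments of two calls, interleaved as a stream might deliver them.
stream = [
    ToolCallDelta(0, id="call_a", function=FnDelta(name="list_dir", arguments='{"pa')),
    ToolCallDelta(1, id="call_b", function=FnDelta(name="read_file", arguments='{"path": "/b"}')),
    ToolCallDelta(0, function=FnDelta(arguments='th": "/a"}')),
]
calls = accumulate_tool_calls(stream)
```

The arguments field stays a raw JSON string until the whole stream has been consumed; parsing it earlier would fail on the partial fragments.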

Verified on CFReDS image across 4 smoke runs: 0 JSON parse failures
(was 3); 22 temporal edges from Phase 4 (was 0); ReportAgent exits via
save_report (was max_iterations regression). 78/78 unit tests pass.
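
A rough sketch of the terminal exit and the raw/formatted split, under simplifying assumptions: executors are synchronous, truncation is a plain slice, and run_tools is a hypothetical helper (the real loop is async and lives in tool_call_loop):

```python
MAX_CHARS = 3000  # cap taken from the commit message; the real helper may differ


def run_tools(calls, executors, terminal_tools=()):
    """Execute calls in order; return (final_text or None, tool messages)."""
    terminal = set(terminal_tools)
    tool_msgs = []
    final_text = None
    for call in calls:
        raw = executors[call["name"]](**call["arguments"])
        # History gets a capped copy; the raw value is kept for the caller.
        tool_msgs.append({
            "role": "tool",
            "tool_call_id": call["id"],
            "content": f"[{call['name']}] {raw[:MAX_CHARS]}",
        })
        # The first terminal tool's untruncated result becomes final_text.
        if final_text is None and call["name"] in terminal:
            final_text = raw
    return final_text, tool_msgs


report = "x" * 5000
final, msgs = run_tools(
    [{"id": "c1", "name": "save_report", "arguments": {"text": report}}],
    {"save_report": lambda text: text},
    terminal_tools=("save_report",),
)
```

Here final carries all 5000 characters, while the tool message stored in the history is capped, which is the distinction the (raw, formatted) return is meant to preserve.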

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
BattleTag
2026-05-13 13:51:19 +08:00
parent 0a2b344c84
commit 444d58726a
9 changed files with 1356 additions and 298 deletions


@@ -1,9 +1,10 @@
"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API).
Tool calling is text-based (ReAct pattern): tool schemas are embedded in
the system prompt and tool calls are parsed as <tool_call> JSON blocks
from model output. This keeps the protocol independent of whether the
underlying API supports native function/tool calling.
Tool calling uses the OpenAI-native `tools=[...]` parameter. The model
returns structured tool_calls via the streaming protocol; we accumulate
them, dispatch to our executors, and feed results back as `role: "tool"`
messages. This eliminates the fragile "model writes JSON inside free
text" problem of the previous ReAct text mode.
"""
from __future__ import annotations
@@ -32,52 +33,51 @@ class LLMAPIError(Exception):
self.attempts = attempts
# Markers the model uses to signal tool calls and final answers
TOOL_CALL_TAG = "<tool_call>"
TOOL_CALL_END = "</tool_call>"
TOOL_RESULT_TAG = "<tool_result>"
TOOL_RESULT_END = "</tool_result>"
# Optional answer tags — kept for backward compat with prompts that wrap
# their final response in <answer>...</answer>. Native tool calling does
# not need these (no tool_calls = final), but if the model continues to
# emit them, we strip the tags so callers see clean text.
ANSWER_TAG = "<answer>"
ANSWER_END = "</answer>"
def _build_tools_prompt(tools: list[dict]) -> str:
"""Format tool definitions for inclusion in the system prompt."""
lines = ["You have access to the following tools:\n"]
for t in tools:
schema = t.get("input_schema", {})
props = schema.get("properties", {})
required = schema.get("required", [])
def _to_openai_tools(tools: list[dict]) -> list[dict]:
"""Convert internal tool definitions to OpenAI native function-tools format."""
return [
{
"type": "function",
"function": {
"name": t["name"],
"description": t["description"],
"parameters": t.get("input_schema", {"type": "object", "properties": {}}),
},
}
for t in tools
]
params = []
for pname, pdef in props.items():
req = " (required)" if pname in required else ""
desc = pdef.get("description", "")
ptype = pdef.get("type", "string")
enum_vals = pdef.get("enum")
if enum_vals:
allowed = ", ".join(f'"{v}"' for v in enum_vals)
params.append(f" - {pname}: {ptype}{req}{desc} Allowed values: [{allowed}]")
else:
params.append(f" - {pname}: {ptype}{req}{desc}")
param_block = "\n".join(params) if params else " (no parameters)"
lines.append(f"## {t['name']}\n{t['description']}\nParameters:\n{param_block}\n")
def _extract_first_balanced(text: str, open_char: str, close_char: str) -> str | None:
"""Return the first balanced [...] or {...} substring, or None if no balanced pair.
lines.append(
"## How to use tools\n"
"To call a tool, output a JSON block wrapped in XML tags like this:\n"
f"{TOOL_CALL_TAG}\n"
'{"name": "tool_name", "arguments": {"param1": "value1"}}\n'
f"{TOOL_CALL_END}\n\n"
"You can call multiple tools in sequence. After each tool call, you will receive the result in:\n"
f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n"
"When you have finished your analysis and have a final answer, wrap it in:\n"
f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n"
"Think step by step. Call tools to gather evidence before drawing conclusions.\n"
"You MUST call at least one tool before giving your final answer."
)
return "\n".join(lines)
Depth-counting scan that handles nested brackets correctly (a regex with
.*? would truncate at the first inner closing bracket, and one with .*
would over-eat trailing text). Brackets inside JSON string literals are
not treated specially here, so a string containing an unbalanced bracket
can skew the match; callers pass the result through json.loads, which
rejects a malformed slice.
"""
start = text.find(open_char)
if start < 0:
return None
depth = 0
for i in range(start, len(text)):
c = text[i]
if c == open_char:
depth += 1
elif c == close_char:
depth -= 1
if depth == 0:
return text[start:i + 1]
return None
def _safe_json_loads(text: str):
@@ -87,8 +87,8 @@ def _safe_json_loads(text: str):
(\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX). On final failure, logs raw
input (first 600 chars) so we can diagnose what the model emitted.
Used both by orchestrator JSON callsites and by _extract_tool_calls
when parsing <tool_call> blocks from model output.
Used by orchestrator JSON callsites (_call_llm_for_json) and by
tool_call_loop when parsing tool_call arguments returned by the API.
"""
try:
return json.loads(text)
@@ -110,23 +110,6 @@ def _safe_json_loads(text: str):
raise
def _extract_tool_calls(text: str) -> list[dict]:
"""Extract tool call JSON blocks from model output."""
pattern = re.compile(
re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END),
re.DOTALL,
)
calls = []
for match in pattern.finditer(text):
raw = match.group(1).strip()
try:
parsed = _safe_json_loads(raw)
calls.append(parsed)
except json.JSONDecodeError:
logger.warning("Failed to parse tool call JSON: %s", raw[:200])
return calls
def _extract_answer(text: str) -> str | None:
"""Extract the final answer from model output."""
pattern = re.compile(
@@ -266,50 +249,41 @@ _DECAY_TIERS: list[tuple[int, int]] = [
def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
"""Truncate tool results in older messages to save context space.
"""Truncate the `content` of older `role: "tool"` messages to save context.
Operates in-place-style on a copy. Only touches user messages that
contain <tool_result> blocks (these are the tool-result messages
generated by tool_call_loop).
Each `role: "tool"` message in the conversation corresponds to one tool
call's result. We rank these messages by recency and progressively
truncate older ones according to `_DECAY_TIERS`.
"""
# Count rounds from the end. A "round" is a (assistant, user) pair.
# messages alternate: [user, assistant, user, assistant, user, ...]
# The initial user message is index 0, then pairs start at index 1.
total = len(messages)
if total <= 10: # not enough messages to bother
if total <= 10:
return messages
result = []
# Count tool-result user messages from the end
tool_result_indices = [
i for i, m in enumerate(messages)
if m["role"] == "user" and TOOL_RESULT_TAG in m.get("content", "")
tool_msg_indices = [
i for i, m in enumerate(messages) if m.get("role") == "tool"
]
# Build a set of indices that need decay, mapped to their max_chars
decay_map: dict[int, int] = {}
n_tool_msgs = len(tool_result_indices)
for rank, idx in enumerate(reversed(tool_result_indices)):
rounds_ago = rank # 0 = most recent, 1 = second most recent, ...
for rank, idx in enumerate(reversed(tool_msg_indices)):
rounds_ago = rank
for threshold, max_chars in _DECAY_TIERS:
if rounds_ago < threshold:
decay_map[idx] = max_chars
break
result = []
for i, msg in enumerate(messages):
if i in decay_map:
max_chars = decay_map[i]
content = msg["content"]
content = msg.get("content", "") or ""
if len(content) > max_chars + 200:
# Truncate but preserve the tool_result tags structure
truncated = content[:max_chars]
# Count how many tool results are in this message
n_results = content.count(TOOL_RESULT_TAG)
truncated += (
f"\n... [context compressed: {len(content)} -> {max_chars} chars, "
f"{n_results} tool result(s)]"
truncated = (
content[:max_chars]
+ f"\n... [context compressed: {len(content)} -> {max_chars} chars]"
)
result.append({"role": msg["role"], "content": truncated})
new_msg = dict(msg)
new_msg["content"] = truncated
result.append(new_msg)
else:
result.append(msg)
else:
@@ -434,6 +408,95 @@ class LLMClient:
return ""
async def _chat_with_tools(
self,
messages: list[dict],
openai_tools: list[dict],
max_retries: int = 5,
) -> tuple[str, str | None, list[dict]]:
"""Stream a chat completion with native tool calling enabled.
Returns:
(text_content, reasoning_content, raw_tool_calls).
- reasoning_content is non-None when DeepSeek thinking mode is
active; the caller MUST echo it back in the assistant message
on subsequent requests, or the API returns HTTP 400.
- raw_tool_calls is a list of {"id","name","arguments"} dicts;
arguments is the raw JSON string returned by the API.
"""
kwargs: dict[str, Any] = {
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"stream": True,
"tools": openai_tools,
}
if self.reasoning_effort:
kwargs["reasoning_effort"] = self.reasoning_effort
if self.thinking_enabled:
kwargs["extra_body"] = {"thinking": {"type": "enabled"}}
for attempt in range(max_retries):
logger.debug(
"LLM request (stream+tools): %d messages, %d tools (attempt %d)",
len(messages), len(openai_tools), attempt + 1,
)
text_parts: list[str] = []
reasoning_parts: list[str] = []
tool_calls_acc: dict[int, dict] = {} # index -> {id, name, arguments}
try:
stream = await self._client.chat.completions.create(**kwargs)
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if delta.content:
text_parts.append(delta.content)
# DeepSeek thinking-mode: reasoning_content is returned
# alongside content and MUST be echoed back on subsequent
# requests, otherwise the API rejects with HTTP 400.
rc = getattr(delta, "reasoning_content", None)
if rc:
reasoning_parts.append(rc)
if delta.tool_calls:
for tc_delta in delta.tool_calls:
idx = tc_delta.index
entry = tool_calls_acc.setdefault(
idx, {"id": None, "name": None, "arguments": ""},
)
if tc_delta.id:
entry["id"] = tc_delta.id
fn = tc_delta.function
if fn:
if fn.name:
entry["name"] = fn.name
if fn.arguments:
entry["arguments"] += fn.arguments
text = "".join(text_parts)
reasoning = "".join(reasoning_parts) or None
ordered = [tool_calls_acc[i] for i in sorted(tool_calls_acc)]
logger.debug(
"LLM response (stream+tools): %d chars, %d reasoning chars, %d tool calls",
len(text), len(reasoning or ""), len(ordered),
)
return text, reasoning, ordered
except (APIConnectionError, APITimeoutError, APIError) as e:
if attempt < max_retries - 1:
wait = 2 ** attempt * 10
logger.warning(
"Tool-call request failed (%s), retrying in %ds...", e, wait,
)
await asyncio.sleep(wait)
else:
raise LLMAPIError(
f"LLM API unreachable after {max_retries} attempts: {e}",
attempts=max_retries,
) from e
return "", None, []
async def tool_call_loop(
self,
messages: list[dict],
@@ -441,87 +504,159 @@ class LLMClient:
tool_executor: dict[str, Any],
system: str | None = None,
max_iterations: int = 40,
terminal_tools: tuple[str, ...] = (),
) -> tuple[str, list[dict]]:
"""Run a ReAct-style tool-calling loop.
"""Run a tool-calling loop using OpenAI-native tool calls.
The model outputs <tool_call> blocks which we parse and execute,
feeding results back as <tool_result> blocks until the model
outputs an <answer> block.
The model returns structured `tool_calls` in its message; we
dispatch them through our executor dict and feed each result back
as a `role: "tool"` message with the matching `tool_call_id`. The
loop ends when:
- the model returns a message with no tool_calls (normal exit), or
- any tool in `terminal_tools` is called — in that case, the loop
short-circuits with that tool's result text as final_text. This
gives agents (notably ReportAgent) an explicit completion signal
that the old `<answer>` text tag used to provide.
Returns:
(final_text, all_messages)
(final_text, full_message_history)
"""
# Build system prompt with tool definitions
tools_prompt = _build_tools_prompt(tools)
full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt
terminal_set = set(terminal_tools)
openai_tools = _to_openai_tools(tools)
messages = list(messages) # don't mutate caller's list
_folded = False # Track whether we've already folded once this loop
# The caller may pass `messages` either as raw conversation (no system)
# together with `system=...`, OR as a complete history that already
# starts with the system message (retry path). Accept both shapes.
if messages and messages[0].get("role") == "system":
full_messages: list[dict] = list(messages)
else:
full_messages = []
if system:
full_messages.append({"role": "system", "content": system})
full_messages.extend(messages)
_folded = False
for i in range(max_iterations):
for _i in range(max_iterations):
# ── Context compression before each API call ──────────────
# Stage A: progressively decay old tool results
messages = _apply_progressive_decay(messages)
# Stage B: fold oldest messages into LLM summary if too long
if not _folded and len(messages) > _FOLD_THRESHOLD:
messages = await self._fold_old_messages(messages, full_system)
full_messages = _apply_progressive_decay(full_messages)
if not _folded and len(full_messages) > _FOLD_THRESHOLD:
full_messages = await self._fold_old_messages(full_messages)
_folded = True
elif _folded and len(messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
# Allow a second fold if messages grew back significantly
messages = await self._fold_old_messages(messages, full_system)
elif _folded and len(full_messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
full_messages = await self._fold_old_messages(full_messages)
text = await self.chat(messages, system=full_system)
text, reasoning, raw_tool_calls = await self._chat_with_tools(
full_messages, openai_tools,
)
# Check for final answer
answer = _extract_answer(text)
if answer is not None:
messages.append({"role": "assistant", "content": text})
return answer, messages
if not raw_tool_calls:
# Model produced a final response. Strip optional <answer>
# tags for backward compatibility with old prompts.
final_msg: dict[str, Any] = {"role": "assistant", "content": text}
if reasoning:
final_msg["reasoning_content"] = reasoning
full_messages.append(final_msg)
answer = _extract_answer(text)
return (answer if answer is not None else text), full_messages
# Check for tool calls
tool_calls = _extract_tool_calls(text)
# Parse arguments + build internal call dicts
parsed_calls: list[dict] = []
for rc in raw_tool_calls:
args_str = rc.get("arguments", "") or ""
try:
args = _safe_json_loads(args_str) if args_str.strip() else {}
except (json.JSONDecodeError, ValueError) as e:
logger.warning(
"Failed to parse arguments for tool %s: %s",
rc.get("name"), e,
)
args = {}
parsed_calls.append({
"id": rc.get("id"),
"name": rc.get("name", ""),
"arguments": args,
})
if not tool_calls:
# No tool calls and no answer tag — treat entire text as answer
messages.append({"role": "assistant", "content": text})
return text, messages
# Append the assistant turn with the raw tool_calls (and the
# DeepSeek-mandated reasoning_content echo-back), then execute.
asst_msg: dict[str, Any] = {
"role": "assistant",
"content": text or None,
"tool_calls": [
{
"id": rc.get("id"),
"type": "function",
"function": {
"name": rc.get("name", ""),
"arguments": rc.get("arguments", "") or "",
},
}
for rc in raw_tool_calls
],
}
if reasoning:
asst_msg["reasoning_content"] = reasoning
full_messages.append(asst_msg)
# Execute tool calls — read-only tools run in parallel
messages.append({"role": "assistant", "content": text})
result_parts = []
batches = _partition_tool_calls(tool_calls)
batches = _partition_tool_calls(parsed_calls)
t_batch_start = time.monotonic()
# Each entry: (tool_call_dict, raw_result, formatted_for_llm)
executed: list[tuple[dict, str, str]] = []
for batch in batches:
if batch.is_read_only and len(batch.calls) > 1:
batch_results = await self._execute_tool_batch_parallel(
results = await self._execute_tool_batch_parallel(
batch.calls, tool_executor, tools,
)
result_parts.extend(batch_results)
for tc, (raw, formatted) in zip(batch.calls, results):
executed.append((tc, raw, formatted))
else:
for tc in batch.calls:
result_parts.append(
await self._execute_single_tool(tc, tool_executor, tools)
raw, formatted = await self._execute_single_tool(
tc, tool_executor, tools,
)
executed.append((tc, raw, formatted))
# Emit folded tool-call summary for the terminal
t_batch_elapsed = time.monotonic() - t_batch_start
_emit_tool_call_summary(tool_calls, t_batch_elapsed)
_emit_tool_call_summary(parsed_calls, t_batch_elapsed)
# Feed results back as a user message
result_message = "\n\n".join(result_parts)
messages.append({"role": "user", "content": result_message})
# Append formatted tool results to the conversation (this is
# what the LLM sees on subsequent rounds — truncated for context
# economy).
for tc, _raw, formatted in executed:
full_messages.append({
"role": "tool",
"tool_call_id": tc["id"],
"content": formatted,
})
# Terminal-tool short-circuit: if the model called any tool in
# `terminal_tools`, end the loop immediately. The terminal tool's
# RAW result (untruncated) becomes final_text — the LLM may have
# produced a 20K-char report via save_report and we must not
# truncate it just because the LLM-facing copy is truncated.
if terminal_set:
for tc, raw, _formatted in executed:
name = tc.get("name", "")
if name in terminal_set:
logger.info(
"Terminal tool %s called — exiting tool_call_loop", name,
)
return raw, full_messages
logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
return "[Max tool call iterations reached]", messages
return "[Max tool call iterations reached]", full_messages
async def _execute_single_tool(
self, tc: dict, tool_executor: dict[str, Any],
tools: list[dict] | None = None,
) -> str:
"""Execute a single tool call and return the formatted result."""
) -> tuple[str, str]:
"""Execute a single tool call.
Returns (raw_result, formatted_for_llm). `raw_result` is the
unmodified executor return (used by terminal-tool short-circuit as
final_text). `formatted_for_llm` is `[tool_name] {truncated}` and
is what gets fed back to the model as the tool message content.
"""
tool_name = tc.get("name", "")
tool_args = tc.get("arguments", {})
@@ -532,72 +667,106 @@ class LLMClient:
executor = tool_executor.get(tool_name)
if executor is None:
result_text = f"Error: unknown tool '{tool_name}'"
raw = f"Error: unknown tool '{tool_name}'"
else:
try:
result_text = await executor(**tool_args)
raw = await executor(**tool_args)
except Exception as e:
logger.error("Tool %s failed: %s", tool_name, e)
result_text = f"Error executing {tool_name}: {e}"
raw = f"Error executing {tool_name}: {e}"
return (
f"{TOOL_RESULT_TAG}\n"
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
f"{TOOL_RESULT_END}"
)
formatted = f"[{tool_name}] {_truncate_tool_result(raw)}"
return raw, formatted
async def _execute_tool_batch_parallel(
self, calls: list[dict], tool_executor: dict[str, Any],
tools: list[dict] | None = None,
) -> list[str]:
"""Execute multiple read-only tool calls concurrently."""
) -> list[tuple[str, str]]:
"""Execute multiple read-only tool calls concurrently.
Returns a list of (raw_result, formatted_for_llm) tuples in the
same order as `calls`.
"""
logger.info("Executing %d read-only tools in parallel", len(calls))
async def _run_one(tc: dict) -> str:
async def _run_one(tc: dict) -> tuple[str, str]:
tool_name = tc.get("name", "")
tool_args = tc.get("arguments", {})
if tools:
tool_args = _fix_tool_args(tool_name, tool_args, tools)
logger.info("Calling tool (parallel): %s(%s)", tool_name, json.dumps(tool_args, ensure_ascii=False))
logger.info(
"Calling tool (parallel): %s(%s)",
tool_name, json.dumps(tool_args, ensure_ascii=False),
)
executor = tool_executor.get(tool_name)
if executor is None:
result_text = f"Error: unknown tool '{tool_name}'"
raw = f"Error: unknown tool '{tool_name}'"
else:
try:
result_text = await executor(**tool_args)
raw = await executor(**tool_args)
except Exception as e:
logger.error("Tool %s failed: %s", tool_name, e)
result_text = f"Error executing {tool_name}: {e}"
return (
f"{TOOL_RESULT_TAG}\n"
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
f"{TOOL_RESULT_END}"
)
raw = f"Error executing {tool_name}: {e}"
formatted = f"[{tool_name}] {_truncate_tool_result(raw)}"
return raw, formatted
results = await asyncio.gather(*[_run_one(tc) for tc in calls])
return list(results)
async def _fold_old_messages(
self, messages: list[dict], system: str,
self, messages: list[dict],
) -> list[dict]:
"""Fold old messages into an LLM-generated summary (Stage B).
Keeps the most recent _FOLD_KEEP_RECENT messages intact and
replaces earlier ones with a single summary message.
Preserves the leading system message (if any), keeps the most
recent _FOLD_KEEP_RECENT messages intact, and replaces the older
middle slice with a single summary user message.
"""
n_to_fold = len(messages) - _FOLD_KEEP_RECENT
# Pin the system message — it must NEVER be summarized away.
system_msgs: list[dict] = []
body = messages
if messages and messages[0].get("role") == "system":
system_msgs = [messages[0]]
body = messages[1:]
n_to_fold = len(body) - _FOLD_KEEP_RECENT
if n_to_fold <= 2:
return messages
old_messages = messages[:n_to_fold]
recent_messages = messages[n_to_fold:]
# Pull the fold boundary forward so we never split an assistant turn
# from its matching tool results. The API rejects (HTTP 400) any
# `role: "tool"` message that is not preceded by an `assistant`
# message carrying the matching `tool_calls`. While the first message
# that would be kept is a `role: "tool"` message, move it into the
# folded slice so the assistant turn and all of its results are
# summarized together.
while n_to_fold < len(body):
head = body[n_to_fold]
if head.get("role") == "tool":
n_to_fold += 1
continue
break
if n_to_fold >= len(body):
# The boundary walked past the end, so there is nothing recent to
# keep. Skip folding rather than discard unsummarized messages.
return messages
old_messages = body[:n_to_fold]
recent_messages = body[n_to_fold:]
# Build a text dump of old messages for summarization
old_text_parts = []
for msg in old_messages:
role = msg["role"]
content = msg.get("content", "")
# Truncate each message for the summary prompt to avoid overload
role = msg.get("role", "?")
content = msg.get("content") or ""
# Render tool_calls (assistant turn) compactly.
if role == "assistant" and msg.get("tool_calls"):
tc_names = [
tc.get("function", {}).get("name", "?")
for tc in msg["tool_calls"]
]
content = (content + " " if content else "") + (
"called: " + ", ".join(tc_names)
)
if len(content) > 1000:
content = content[:1000] + "..."
old_text_parts.append(f"[{role}]: {content}")
@@ -621,7 +790,6 @@ class LLMClient:
logger.warning("Context folding failed: %s — keeping original messages", e)
return messages
# Replace old messages with a single summary
summary_message = {
"role": "user",
"content": (
@@ -629,4 +797,4 @@ class LLMClient:
f"messages in this conversation]\n\n{summary}"
),
}
return [summary_message] + recent_messages
return system_msgs + [summary_message] + recent_messages
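
As an illustration of the boundary alignment in _fold_old_messages, the following sketch isolates the walk over leading role: "tool" messages. The align_fold_boundary name and the sample history are hypothetical and exist only for this example:

```python
def align_fold_boundary(body: list[dict], n_to_fold: int) -> int:
    """Advance the boundary past any leading role: "tool" messages."""
    while n_to_fold < len(body) and body[n_to_fold].get("role") == "tool":
        n_to_fold += 1
    return n_to_fold


body = [
    {"role": "user", "content": "start"},
    {"role": "assistant", "content": None, "tool_calls": [{"id": "c1"}, {"id": "c2"}]},
    {"role": "tool", "tool_call_id": "c1", "content": "r1"},
    {"role": "tool", "tool_call_id": "c2", "content": "r2"},
    {"role": "assistant", "content": "done"},
]

# A naive cut at index 3 would keep the c2 result without its assistant turn.
cut = align_fold_boundary(body, 3)
old, recent = body[:cut], body[cut:]
```

After alignment the kept slice starts with an assistant message, and both tool results travel into the summarized slice together with the assistant turn that requested them.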