Files
MASForensic/llm_client.py
BattleTag 0a2b344c84 fix: share _safe_json_loads with tool-call parser, not just orchestrator
Move _safe_json_loads from orchestrator.py to llm_client.py and have
_extract_tool_calls use it when parsing <tool_call> JSON blocks from
model output. orchestrator now imports it from llm_client.

Background: in the first full DeepSeek run (runs/2026-05-12T17-25-38),
~10 'Failed to parse tool call JSON' warnings appeared, all from regex
patterns where the LLM wrote \. or \* inside JSON string values:

  Failed to parse tool call JSON: {..., "pattern": "Outlook Express|...|\.dbx"}
  Failed to parse tool call JSON: {..., "pattern": "ethereal.*\.pcap"}
  Failed to parse tool call JSON: {..., "pattern": "lookatlan.*\.txt|..."}

These are exactly the kind of stray-backslash errors that the stage-1
sanitizer already handles for orchestrator JSON calls — but tool-call
extraction was using bare json.loads. Result: each failed tool call was
silently dropped, the LLM never got a result, and at least one network
agent burned 14m26s spinning before hitting max_iterations=40.

Now the sanitize/log-on-failure path is shared. Verified against the
three failure cases from yesterday's log: all three now parse cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 20:29:21 +08:00

633 lines
23 KiB
Python

"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API).
Tool calling is text-based (ReAct pattern): tool schemas are embedded in
the system prompt and tool calls are parsed as <tool_call> JSON blocks
from model output. This keeps the protocol independent of whether the
underlying API supports native function/tool calling.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import time
from collections import Counter
from dataclasses import dataclass, field
from typing import Any
import httpx
from openai import APIConnectionError, APIError, APITimeoutError, AsyncOpenAI
logger = logging.getLogger(name=__name__)


class LLMAPIError(Exception):
    """Raised when an LLM API request fails and will not be retried further.

    Attributes:
        attempts: number of request attempts made before giving up.
    """

    def __init__(self, message: str, attempts: int) -> None:
        super().__init__(message)
        self.attempts = attempts
# Text markers of the ReAct protocol, used both in model output and in the
# messages fed back to the model: tool calls, tool results, final answer.
TOOL_CALL_TAG, TOOL_CALL_END = "<tool_call>", "</tool_call>"
TOOL_RESULT_TAG, TOOL_RESULT_END = "<tool_result>", "</tool_result>"
ANSWER_TAG, ANSWER_END = "<answer>", "</answer>"
def _build_tools_prompt(tools: list[dict]) -> str:
    """Format tool definitions for inclusion in the system prompt.

    Args:
        tools: tool definitions. Each needs a ``name``; ``description`` and
            an ``input_schema`` (JSON-Schema-style ``properties`` /
            ``required``, with per-property ``type``, ``description`` and
            ``enum``) are optional.

    Returns:
        A prompt section listing every tool and its parameters, followed by
        instructions for the <tool_call> / <tool_result> / <answer> protocol.
    """
    lines = ["You have access to the following tools:\n"]
    for t in tools:
        schema = t.get("input_schema", {})
        props = schema.get("properties", {})
        required = schema.get("required", [])
        params = []
        for pname, pdef in props.items():
            req = " (required)" if pname in required else ""
            ptype = pdef.get("type", "string")
            desc = pdef.get("description", "")
            line = f" - {pname}: {ptype}{req}"
            # Separate the description from the type / "(required)" marker;
            # concatenated directly they ran together, e.g.
            # "string (required)Regex to match".
            if desc:
                line += f" — {desc}"
            enum_vals = pdef.get("enum")
            if enum_vals:
                allowed = ", ".join(f'"{v}"' for v in enum_vals)
                line += f" Allowed values: [{allowed}]"
            params.append(line)
        param_block = "\n".join(params) if params else " (no parameters)"
        # .get: a tool without a description should not abort prompt building.
        description = t.get("description", "")
        lines.append(f"## {t['name']}\n{description}\nParameters:\n{param_block}\n")
    lines.append(
        "## How to use tools\n"
        "To call a tool, output a JSON block wrapped in XML tags like this:\n"
        f"{TOOL_CALL_TAG}\n"
        '{"name": "tool_name", "arguments": {"param1": "value1"}}\n'
        f"{TOOL_CALL_END}\n\n"
        "You can call multiple tools in sequence. After each tool call, you will receive the result in:\n"
        f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n"
        "When you have finished your analysis and have a final answer, wrap it in:\n"
        f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n"
        "Think step by step. Call tools to gather evidence before drawing conclusions.\n"
        "You MUST call at least one tool before giving your final answer."
    )
    return "\n".join(lines)
def _safe_json_loads(text: str):
"""Parse JSON with progressive sanitization for LLM-produced output.
Tries (0) as-is, (1) escape stray backslashes outside valid JSON escapes
(\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX). On final failure, logs raw
input (first 600 chars) so we can diagnose what the model emitted.
Used both by orchestrator JSON callsites and by _extract_tool_calls
when parsing <tool_call> blocks from model output.
"""
try:
return json.loads(text)
except json.JSONDecodeError:
pass
stage1 = re.sub(
r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})',
r'\\\\',
text,
)
try:
return json.loads(stage1)
except json.JSONDecodeError as e:
logger.warning(
"_safe_json_loads failed after sanitize (%s); raw head[:600]=%r",
e, text[:600],
)
raise
def _extract_tool_calls(text: str) -> list[dict]:
    """Extract tool call JSON blocks from model output.

    Each <tool_call>…</tool_call> span is decoded with _safe_json_loads.
    Blocks that are not valid JSON, or that decode to something other than
    a JSON object (list, string, number, ...), are logged and skipped:
    callers use ``.get`` on every returned entry, so a non-dict would raise
    AttributeError and abort the tool-call loop.

    Returns:
        The decoded tool-call objects, in the order they appear in ``text``.
    """
    pattern = re.compile(
        re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END),
        re.DOTALL,
    )
    calls: list[dict] = []
    for match in pattern.finditer(text):
        raw = match.group(1).strip()
        try:
            parsed = _safe_json_loads(raw)
        except json.JSONDecodeError:
            logger.warning("Failed to parse tool call JSON: %s", raw[:200])
            continue
        if not isinstance(parsed, dict):
            logger.warning(
                "Tool call JSON is not an object (%s): %s",
                type(parsed).__name__, raw[:200],
            )
            continue
        calls.append(parsed)
    return calls
def _extract_answer(text: str) -> str | None:
    """Return the stripped body of the first <answer>…</answer> block.

    Returns None when the model output contains no complete answer block.
    """
    answer_re = re.compile(
        "".join((re.escape(ANSWER_TAG), r"\s*(.*?)\s*", re.escape(ANSWER_END))),
        re.DOTALL,
    )
    found = answer_re.search(text)
    return found.group(1).strip() if found else None
def _truncate_tool_result(result_text: str, max_chars: int = 3000) -> str:
"""Truncate a tool result if it exceeds max_chars."""
if len(result_text) <= max_chars:
return result_text
return result_text[: max_chars - 200] + f"\n... [truncated, {len(result_text)} total chars]"
# Tool names that only read state and never mutate it. Consecutive calls to
# these tools are grouped into one batch and, when a batch holds more than
# one call, executed concurrently.
READ_ONLY_TOOLS: set[str] = set(
    (
        # Graph queries
        "list_phenomena",
        "get_phenomenon",
        "search_graph",
        "get_related",
        "get_hypothesis_status",
        "list_assets",
        "find_extracted_file",
        # Sleuth Kit reads
        "partition_info",
        "filesystem_info",
        "list_directory",
        "find_file",
        "search_strings",
        "count_deleted_files",
        "build_filesystem_timeline",
        # Registry reads (without auto-record wrappers)
        "parse_registry_key",
        "search_registry",
        "get_user_activity",
        # Parser reads
        "read_text_file",
        "read_binary_preview",
        "search_text_file",
        "read_text_file_section",
        "list_extracted_dir",
        "parse_pcap_strings",
    )
)
def _fix_tool_args(tool_name: str, tool_args: dict, tools: list[dict]) -> dict:
"""Try to fix misnamed tool arguments from LLM hallucination.
The LLM sometimes confuses parameter names across tools (e.g. passing
`key_path` to search_registry which expects `pattern`). This function
maps unknown kwargs to missing expected params by position/best-effort.
"""
# Find the schema for this tool
schema = None
for t in tools:
if t.get("name") == tool_name:
schema = t.get("input_schema", {})
break
if schema is None:
return tool_args
props = schema.get("properties", {})
required = set(schema.get("required", []))
unknown = [k for k in tool_args if k not in props]
if not unknown:
return tool_args # all args are valid, nothing to fix
# Build the fixed args: start with valid args
fixed = {k: v for k, v in tool_args.items() if k in props}
# Find which expected params are still missing
missing = [p for p in (required or props.keys()) if p not in fixed]
# Try to map unknown args to missing params, in order
unknown_values = [(k, tool_args[k]) for k in unknown]
for wrong_name, value in unknown_values:
if not missing:
break
# Pick the best match from missing params
best = missing.pop(0)
logger.warning(
"Auto-fixing tool arg: %s(%s=...) -> %s(%s=...)",
tool_name, wrong_name, tool_name, best,
)
fixed[best] = value
return fixed
def _emit_tool_call_summary(tool_calls: list[dict], elapsed: float) -> None:
    """Log one compact line summarising a round of tool calls.

    Calls are tallied by name (a missing name counts as "?") and listed from
    most to least frequent, e.g.
    "list_directory x27, extract_file x3, read_text_file x3"; a name called
    only once appears without a count. The record carries
    ``event="tool_calls"`` and ``elapsed`` as extras for the terminal
    formatter.
    """
    tally = Counter(call.get("name", "?") for call in tool_calls)
    pieces = (
        f"{tool} x{n}" if n > 1 else tool
        for tool, n in tally.most_common()
    )
    logger.info(", ".join(pieces), extra={"event": "tool_calls", "elapsed": elapsed})
@dataclass
class _ToolBatch:
"""A batch of tool calls with the same read/write classification."""
is_read_only: bool
calls: list[dict] = field(default_factory=list)
def _partition_tool_calls(
tool_calls: list[dict],
read_only: set[str] | None = None,
) -> list[_ToolBatch]:
"""Partition tool calls into batches: consecutive read-only tools are
grouped together (will run in parallel), write tools are isolated."""
if read_only is None:
read_only = READ_ONLY_TOOLS
batches: list[_ToolBatch] = []
for tc in tool_calls:
is_ro = tc.get("name", "") in read_only
if batches and batches[-1].is_read_only and is_ro:
batches[-1].calls.append(tc)
else:
batches.append(_ToolBatch(is_read_only=is_ro, calls=[tc]))
return batches
# ---------------------------------------------------------------------------
# Context compression — keeps the message list from growing unboundedly.
# ---------------------------------------------------------------------------
# Stage A: Progressive tool result decay thresholds.
# Messages are counted in (assistant, user) pairs from the END of the list.
# "Round" = one pair of (assistant tool-calling msg, user tool-result msg).
_DECAY_TIERS: list[tuple[int, int]] = [
# (rounds_ago_threshold, max_chars_for_tool_results)
(5, 3000), # recent 5 rounds: keep full (3000 chars per tool result)
(15, 500), # 5-15 rounds ago: aggressive truncation
(999, 100), # older than 15 rounds: minimal stub
]
def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
    """Truncate tool results in older messages to save context space.

    Only user messages whose content contains a <tool_result> block (the
    tool-result messages built by tool_call_loop) are touched. They are
    ranked by recency (0 for the newest), and each gets the character
    budget of the first _DECAY_TIERS entry whose threshold exceeds its rank.
    A rank beyond every threshold falls back to the last (smallest) tier
    instead of escaping truncation. A message is only cut when it exceeds
    its budget by more than 200 chars, so a message already compressed at
    its current tier is not compressed again.

    Returns a new list; the input list and its message dicts are not
    mutated (truncated messages are shallow copies). Lists of 10 messages or
    fewer are returned as-is.
    """
    if len(messages) <= 10:  # not enough messages to bother
        return messages

    tool_result_indices = [
        i for i, m in enumerate(messages)
        if m["role"] == "user" and TOOL_RESULT_TAG in m.get("content", "")
    ]

    # Message index -> max_chars budget. rounds_ago 0 = most recent.
    fallback_chars = _DECAY_TIERS[-1][1]
    decay_map: dict[int, int] = {}
    for rounds_ago, idx in enumerate(reversed(tool_result_indices)):
        decay_map[idx] = next(
            (max_chars for threshold, max_chars in _DECAY_TIERS
             if rounds_ago < threshold),
            fallback_chars,
        )

    result: list[dict] = []
    for i, msg in enumerate(messages):
        max_chars = decay_map.get(i)
        if max_chars is None or len(msg["content"]) <= max_chars + 200:
            result.append(msg)
            continue
        content = msg["content"]
        kept = content[:max_chars]
        # The cut can land inside a <tool_result> block and leave it open;
        # close it so the compressed message keeps balanced tags.
        needs_close = kept.count(TOOL_RESULT_TAG) > kept.count(TOOL_RESULT_END)
        n_results = content.count(TOOL_RESULT_TAG)
        truncated = kept + (
            f"\n... [context compressed: {len(content)} -> {max_chars} chars, "
            f"{n_results} tool result(s)]"
        )
        if needs_close:
            truncated += f"\n{TOOL_RESULT_END}"
        # Shallow copy keeps any extra message keys intact.
        result.append({**msg, "content": truncated})
    return result
# Stage B: LLM-powered message folding.
# When messages exceed this count, fold the oldest ones into a summary.
_FOLD_THRESHOLD = 24 # trigger folding when messages exceed this count
_FOLD_KEEP_RECENT = 10 # always keep the most recent N messages intact
_FOLD_SUMMARY_SYSTEM = (
"You are a concise summarizer for an ongoing forensic investigation conversation. "
"Summarize the following early conversation between a forensic analysis agent and its "
"tool execution results. Preserve:\n"
"- Key findings and evidence discovered (file paths, inode numbers, timestamps)\n"
"- Tools that were called and their important results\n"
"- Decisions made and current investigation direction\n"
"Keep the summary under 800 words. Use bullet points."
)
class LLMClient:
    """Async LLM client via the OpenAI SDK.

    Works with any OpenAI-compatible endpoint (OpenAI, DeepSeek, ...).
    Tool calling is text-based (ReAct) — see module docstring.
    """

    # HTTP statuses in the 4xx range that are still worth retrying (request
    # timeout, conflict, rate limit). Any other 4xx means the request itself
    # is wrong, so repeating it only burns the backoff delay.
    _RETRYABLE_4XX = frozenset({408, 409, 429})

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "deepseek-v4-pro",
        max_tokens: int = 4096,
        proxy: str | None = "auto",
        reasoning_effort: str | None = None,
        thinking_enabled: bool = False,
    ) -> None:
        """Configure the client.

        Args:
            base_url: API base URL; a trailing "/" is stripped.
            api_key: API key passed to the OpenAI SDK.
            model: model name sent with every request.
            max_tokens: ``max_tokens`` sent with every request.
            proxy: "auto" uses ``https_proxy`` / ``HTTPS_PROXY`` from the
                environment when set; None, "" or "none" (any case) disables
                proxying; any other value is used as the proxy URL.
            reasoning_effort: forwarded as ``reasoning_effort`` when set.
            thinking_enabled: when True, sends
                ``{"thinking": {"type": "enabled"}}`` as extra request body.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self.reasoning_effort = reasoning_effort
        self.thinking_enabled = thinking_enabled

        if proxy == "auto":
            proxy_url = os.environ.get("https_proxy") or os.environ.get("HTTPS_PROXY")
            http_client = (
                httpx.AsyncClient(proxy=proxy_url, timeout=300.0)
                if proxy_url else None
            )
        elif proxy and proxy.lower() != "none":
            http_client = httpx.AsyncClient(proxy=proxy, timeout=300.0)
        else:
            # The SDK's default httpx client honours HTTP(S)_PROXY / ALL_PROXY
            # from the environment (httpx trust_env defaults to True), so
            # http_client=None would still proxy. Opt out explicitly to honour
            # "no proxy" (this also ignores other httpx env settings such as
            # SSL_CERT_FILE).
            http_client = httpx.AsyncClient(timeout=300.0, trust_env=False)

        self._client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=300.0,
            http_client=http_client,
        )

    async def close(self) -> None:
        """Close the underlying AsyncOpenAI client."""
        await self._client.close()

    async def chat(
        self,
        messages: list[dict],
        system: str | None = None,
        max_retries: int = 5,
    ) -> str:
        """Send a streaming chat completion and return the assembled text.

        Args:
            messages: conversation messages (without the system message).
            system: optional system prompt, prepended as a "system" message.
            max_retries: total number of attempts for retryable failures;
                waits 10 s, 20 s, 40 s, ... between attempts.

        Returns:
            The concatenated ``delta.content`` of all streamed chunks, or ""
            without making any request when ``max_retries`` < 1.

        Raises:
            LLMAPIError: when all attempts fail, or immediately for an HTTP
                4xx status other than 408/409/429 (retrying cannot fix it).
        """
        full_messages: list[dict] = []
        if system:
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)
        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": full_messages,
            "max_tokens": self.max_tokens,
            "stream": True,
        }
        if self.reasoning_effort:
            kwargs["reasoning_effort"] = self.reasoning_effort
        if self.thinking_enabled:
            kwargs["extra_body"] = {"thinking": {"type": "enabled"}}

        for attempt in range(max_retries):
            logger.debug(
                "LLM request (stream): %d messages (attempt %d)",
                len(messages), attempt + 1,
            )
            text_parts: list[str] = []
            try:
                stream = await self._client.chat.completions.create(**kwargs)
                async for chunk in stream:
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta
                    if delta.content:
                        text_parts.append(delta.content)
                text = "".join(text_parts)
                logger.debug("LLM response (stream): %d chars", len(text))
                return text
            except (APIConnectionError, APITimeoutError, APIError) as e:
                # HTTP status errors carry status_code; connection and timeout
                # errors do not (getattr -> None, always retried).
                status = getattr(e, "status_code", None)
                if (
                    isinstance(status, int)
                    and 400 <= status < 500
                    and status not in self._RETRYABLE_4XX
                ):
                    raise LLMAPIError(
                        f"LLM API rejected the request (HTTP {status}): {e}",
                        attempts=attempt + 1,
                    ) from e
                if attempt < max_retries - 1:
                    wait = 2 ** attempt * 10
                    logger.warning("Request failed (%s), retrying in %ds...", e, wait)
                    await asyncio.sleep(wait)
                else:
                    raise LLMAPIError(
                        f"LLM API unreachable after {max_retries} attempts: {e}",
                        attempts=max_retries,
                    ) from e
        # Only reached when max_retries < 1, i.e. no request was attempted.
        return ""

    async def tool_call_loop(
        self,
        messages: list[dict],
        tools: list[dict],
        tool_executor: dict[str, Any],
        system: str | None = None,
        max_iterations: int = 40,
    ) -> tuple[str, list[dict]]:
        """Run a ReAct-style tool-calling loop.

        The model outputs <tool_call> blocks which we parse and execute,
        feeding results back as <tool_result> blocks until the model
        outputs an <answer> block. A reply with neither an answer nor any
        <tool_call> tag is returned as the answer. Tool-call blocks that
        could not be turned into calls (e.g. invalid JSON) are reported back
        to the model as an error <tool_result> so it can correct itself;
        a reply consisting only of such blocks is no longer mistaken for the
        final answer.

        Args:
            messages: starting conversation; the caller's list is not mutated.
            tools: tool definitions, rendered into the system prompt.
            tool_executor: maps tool name -> async callable that takes the
                tool arguments as keyword arguments.
            system: optional system prompt placed before the tool section.
            max_iterations: maximum number of loop iterations (one model
                reply each).

        Returns:
            (final_text, all_messages). final_text is
            "[Max tool call iterations reached]" if the budget runs out.
        """
        # Build system prompt with tool definitions
        tools_prompt = _build_tools_prompt(tools)
        full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt
        messages = list(messages)  # don't mutate caller's list
        folded = False  # whether the history has been folded at least once

        for _ in range(max_iterations):
            # ── Context compression before each API call ──────────────
            # Stage A: progressively decay old tool results
            messages = _apply_progressive_decay(messages)
            # Stage B: fold oldest messages into an LLM summary if too long.
            # After the first fold, fold again only once the history has grown
            # back well past the threshold.
            fold_limit = _FOLD_THRESHOLD + (_FOLD_KEEP_RECENT if folded else 0)
            if len(messages) > fold_limit:
                messages = await self._fold_old_messages(messages, full_system)
                folded = True

            text = await self.chat(messages, system=full_system)
            messages.append({"role": "assistant", "content": text})

            # Check for final answer
            answer = _extract_answer(text)
            if answer is not None:
                return answer, messages

            tool_calls = _extract_tool_calls(text)
            # Opened <tool_call> blocks that _extract_tool_calls did not
            # return (e.g. invalid JSON or a missing closing tag).
            n_unparsed = text.count(TOOL_CALL_TAG) - len(tool_calls)

            if not tool_calls:
                if n_unparsed <= 0:
                    # No tool calls and no answer tag — treat entire text as answer
                    return text, messages
                logger.warning(
                    "No parsable tool call in model output (%d block(s) failed); "
                    "asking the model to retry",
                    n_unparsed,
                )
                messages.append(
                    {"role": "user", "content": self._unparsed_tool_calls_result(n_unparsed)}
                )
                continue

            # Execute tool calls — read-only tools run in parallel
            result_parts: list[str] = []
            t_batch_start = time.monotonic()
            for batch in _partition_tool_calls(tool_calls):
                if batch.is_read_only and len(batch.calls) > 1:
                    result_parts.extend(
                        await self._execute_tool_batch_parallel(
                            batch.calls, tool_executor, tools,
                        )
                    )
                else:
                    for tc in batch.calls:
                        result_parts.append(
                            await self._execute_single_tool(tc, tool_executor, tools)
                        )
            # Emit folded tool-call summary for the terminal
            _emit_tool_call_summary(tool_calls, time.monotonic() - t_batch_start)
            if n_unparsed > 0:
                # Tell the model part of its reply was ignored instead of
                # silently dropping the broken calls.
                result_parts.append(self._unparsed_tool_calls_result(n_unparsed))
            # Feed results back as a user message
            messages.append({"role": "user", "content": "\n\n".join(result_parts)})

        logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
        return "[Max tool call iterations reached]", messages

    @staticmethod
    def _unparsed_tool_calls_result(n_failed: int) -> str:
        """Build the <tool_result> block reporting tool calls that failed to parse."""
        return (
            f"{TOOL_RESULT_TAG}\n"
            f"Error: {n_failed} {TOOL_CALL_TAG} block(s) could not be parsed and "
            "were not executed. Each block must contain a single JSON object "
            'such as {"name": "tool_name", "arguments": {"param1": "value1"}}, '
            f"closed by {TOOL_CALL_END}. Inside JSON strings a literal backslash "
            "must be written as \\\\.\n"
            f"{TOOL_RESULT_END}"
        )

    async def _execute_single_tool(
        self, tc: dict, tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
    ) -> str:
        """Execute a single tool call and return the formatted result."""
        return await self._run_tool(tc, tool_executor, tools, "Calling tool: %s(%s)")

    async def _execute_tool_batch_parallel(
        self, calls: list[dict], tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
    ) -> list[str]:
        """Execute multiple read-only tool calls concurrently.

        Results are returned in the same order as ``calls``.
        """
        logger.info("Executing %d read-only tools in parallel", len(calls))
        results = await asyncio.gather(*(
            self._run_tool(tc, tool_executor, tools, "Calling tool (parallel): %s(%s)")
            for tc in calls
        ))
        return list(results)

    async def _run_tool(
        self, tc: dict, tool_executor: dict[str, Any],
        tools: list[dict] | None, log_fmt: str,
    ) -> str:
        """Run one tool call and wrap its (truncated) output in <tool_result> tags.

        Shared by the sequential and parallel paths; ``log_fmt`` is the
        logging format string used to announce the call. Executor failures,
        unknown tools and non-object arguments become an error string fed
        back to the model instead of raising.
        """
        tool_name = tc.get("name", "")
        tool_args = tc.get("arguments", {})
        if tools and isinstance(tool_args, dict):
            tool_args = _fix_tool_args(tool_name, tool_args, tools)
        logger.info(log_fmt, tool_name, json.dumps(tool_args, ensure_ascii=False))

        executor = tool_executor.get(tool_name)
        if executor is None:
            result_text = f"Error: unknown tool '{tool_name}'"
        elif not isinstance(tool_args, dict):
            result_text = (
                f"Error executing {tool_name}: arguments must be a JSON object, "
                f"got {type(tool_args).__name__}"
            )
        else:
            try:
                result_text = await executor(**tool_args)
            except Exception as e:
                logger.error("Tool %s failed: %s", tool_name, e)
                result_text = f"Error executing {tool_name}: {e}"
        # A non-str return would otherwise crash the truncation below,
        # outside the try above.
        if not isinstance(result_text, str):
            result_text = str(result_text)
        return (
            f"{TOOL_RESULT_TAG}\n"
            f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
            f"{TOOL_RESULT_END}"
        )

    async def _fold_old_messages(
        self, messages: list[dict], system: str,
    ) -> list[dict]:
        """Fold old messages into an LLM-generated summary (Stage B).

        Keeps the most recent _FOLD_KEEP_RECENT messages intact and replaces
        earlier ones with a single "user" summary message. Each folded
        message is clipped to 1000 chars and the whole dump to 15000 chars
        before summarizing. Best effort: when 2 or fewer messages would be
        folded, or the summarization call raises, ``messages`` is returned
        unchanged.

        Args:
            messages: current conversation history.
            system: the loop's system prompt (not used here; the summarizer
                runs with _FOLD_SUMMARY_SYSTEM instead).
        """
        n_to_fold = len(messages) - _FOLD_KEEP_RECENT
        if n_to_fold <= 2:
            return messages
        # NOTE(review): this folds messages[0] too, i.e. the caller's initial
        # task message; only what the summary retains of it survives.
        old_messages = messages[:n_to_fold]
        recent_messages = messages[n_to_fold:]

        # Build a text dump of old messages for summarization
        old_text_parts = []
        for msg in old_messages:
            role = msg["role"]
            content = msg.get("content", "")
            # Truncate each message for the summary prompt to avoid overload
            if len(content) > 1000:
                content = content[:1000] + "..."
            old_text_parts.append(f"[{role}]: {content}")
        old_text = "\n\n".join(old_text_parts)
        # Cap total size sent to summarizer
        if len(old_text) > 15000:
            old_text = old_text[:15000] + "\n\n... [further messages omitted for brevity]"

        logger.info(
            "Context folding: summarizing %d old messages (%d chars) into summary",
            n_to_fold, len(old_text),
        )
        try:
            summary = await self.chat(
                messages=[{"role": "user", "content": old_text}],
                system=_FOLD_SUMMARY_SYSTEM,
            )
        except Exception as e:
            # Deliberate best effort: folding is an optimization, never fatal.
            logger.warning("Context folding failed: %s — keeping original messages", e)
            return messages

        # Replace old messages with a single summary
        summary_message = {
            "role": "user",
            "content": (
                f"[Context summary — the following summarizes {n_to_fold} earlier "
                f"messages in this conversation]\n\n{summary}"
            ),
        }
        return [summary_message] + recent_messages