"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API).
Tool calling uses the OpenAI-native `tools=[...]` parameter. The model
returns structured tool_calls via the streaming protocol; we accumulate
them, dispatch to our executors, and feed results back as `role: "tool"`
messages. This eliminates the fragile "model writes JSON inside free
text" problem of the previous ReAct text mode.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import time
from collections import Counter
from dataclasses import dataclass, field
from typing import Any
import httpx
from openai import APIConnectionError, APIError, APITimeoutError, AsyncOpenAI
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class LLMAPIError(Exception):
    """Raised when an LLM API request fails and is not (or no longer) retried.

    Attributes:
        attempts: Number of request attempts reported by the raiser.
    """

    def __init__(self, message: str, attempts: int) -> None:
        self.attempts = attempts
        super().__init__(message)
# Optional final-answer tag pair, kept for backward compatibility with prompts
# that wrap the final response between an opening and a closing tag. Native
# tool calling does not need them (a reply without tool_calls is final); if
# the model still emits them, only the text between the tags is returned.
# NOTE(review): both tags are empty strings as written. A regex built from
# empty tags matches at position 0 of any text, so confirm the intended tag
# values (the surrounding prose suggests an answer-tag pair).
ANSWER_TAG, ANSWER_END = "", ""
def _to_openai_tools(tools: list[dict]) -> list[dict]:
"""Convert internal tool definitions to OpenAI native function-tools format."""
return [
{
"type": "function",
"function": {
"name": t["name"],
"description": t["description"],
"parameters": t.get("input_schema", {"type": "object", "properties": {}}),
},
}
for t in tools
]
def _extract_first_balanced(text: str, open_char: str, close_char: str) -> str | None:
"""Return the first balanced [...] or {...} substring, or None if no balanced pair.
Stack-based — handles nested brackets correctly (regex with .*? would
truncate at the first inner closing bracket, regex with .* would over-eat
trailing text). Brackets inside JSON string literals are ignored by
callers because the caller passes the result through json.loads which
re-parses with proper string handling.
"""
start = text.find(open_char)
if start < 0:
return None
depth = 0
for i in range(start, len(text)):
c = text[i]
if c == open_char:
depth += 1
elif c == close_char:
depth -= 1
if depth == 0:
return text[start:i + 1]
return None
def _safe_json_loads(text: str):
"""Parse JSON with progressive sanitization for LLM-produced output.
Tries (0) as-is, (1) escape stray backslashes outside valid JSON escapes
(\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX). On final failure, logs raw
input (first 600 chars) so we can diagnose what the model emitted.
Used by orchestrator JSON callsites (_call_llm_for_json) and by
tool_call_loop when parsing tool_call arguments returned by the API.
"""
try:
return json.loads(text)
except json.JSONDecodeError:
pass
stage1 = re.sub(
r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})',
r'\\\\',
text,
)
try:
return json.loads(stage1)
except json.JSONDecodeError as e:
logger.warning(
"_safe_json_loads failed after sanitize (%s); raw head[:600]=%r",
e, text[:600],
)
raise
def _extract_answer(text: str) -> str | None:
"""Extract the final answer from model output."""
pattern = re.compile(
re.escape(ANSWER_TAG) + r"\s*(.*?)\s*" + re.escape(ANSWER_END),
re.DOTALL,
)
match = pattern.search(text)
if match:
return match.group(1).strip()
return None
def _truncate_tool_result(result_text: str, max_chars: int = 3000) -> str:
"""Truncate a tool result if it exceeds max_chars."""
if len(result_text) <= max_chars:
return result_text
return result_text[: max_chars - 200] + f"\n... [truncated, {len(result_text)} total chars]"
# Names of tools that only read and never mutate state. Consecutive calls to
# these are safe to run concurrently. Built as a union of per-area groups.
READ_ONLY_TOOLS: set[str] = (
    # Graph queries
    {
        "list_phenomena", "get_phenomenon", "search_graph", "get_related",
        "get_hypothesis_status", "list_assets", "find_extracted_file",
    }
    # Sleuth Kit reads
    | {
        "partition_info", "filesystem_info", "list_directory", "find_file",
        "search_strings", "count_deleted_files", "build_filesystem_timeline",
    }
    # Registry reads (without auto-record wrappers)
    | {"parse_registry_key", "search_registry", "get_user_activity"}
    # Parser reads
    | {
        "read_text_file", "read_binary_preview", "search_text_file",
        "read_text_file_section", "list_extracted_dir", "parse_pcap_strings",
        "find_files",
    }
    # iOS plugin reads (S4)
    | {
        "parse_plist", "sqlite_tables", "sqlite_query",
        "parse_ios_keychain", "read_idevice_info",
    }
    # Android + media reads (S6) — set_active_partition is NOT read-only.
    | {"probe_android_partitions", "ocr_image"}
)
def _fix_tool_args(tool_name: str, tool_args: dict, tools: list[dict]) -> dict:
"""Try to fix misnamed tool arguments from LLM hallucination.
The LLM sometimes confuses parameter names across tools (e.g. passing
`key_path` to search_registry which expects `pattern`). This function
maps unknown kwargs to missing expected params by position/best-effort.
"""
# Find the schema for this tool
schema = None
for t in tools:
if t.get("name") == tool_name:
schema = t.get("input_schema", {})
break
if schema is None:
return tool_args
props = schema.get("properties", {})
required = set(schema.get("required", []))
unknown = [k for k in tool_args if k not in props]
if not unknown:
return tool_args # all args are valid, nothing to fix
# Build the fixed args: start with valid args
fixed = {k: v for k, v in tool_args.items() if k in props}
# Find which expected params are still missing
missing = [p for p in (required or props.keys()) if p not in fixed]
# Try to map unknown args to missing params, in order
unknown_values = [(k, tool_args[k]) for k in unknown]
for wrong_name, value in unknown_values:
if not missing:
break
# Pick the best match from missing params
best = missing.pop(0)
logger.warning(
"Auto-fixing tool arg: %s(%s=...) -> %s(%s=...)",
tool_name, wrong_name, tool_name, best,
)
fixed[best] = value
return fixed
def _emit_tool_call_summary(tool_calls: list[dict], elapsed: float) -> None:
    """Log one folded summary line for a round of tool calls.

    Calls are tallied by name (``"?"`` when the key is absent) and listed
    most-frequent first. A name called more than once gets an ``xN`` suffix,
    e.g. "list_directory x27, extract_file x3, read_text_file". The record
    carries ``event`` and ``elapsed`` extras for the terminal formatter.
    """
    tally = Counter(call.get("name", "?") for call in tool_calls)
    labels = (
        f"{name} x{times}" if times > 1 else name
        for name, times in tally.most_common()
    )
    line = ", ".join(labels)
    logger.info(line, extra={"event": "tool_calls", "elapsed": elapsed})
@dataclass
class _ToolBatch:
    """A batch of tool calls with the same read/write classification.

    Built by _partition_tool_calls: consecutive read-only calls share one
    batch, and every call that is not read-only gets a batch of its own.
    """
    # True when every call's name is in the read-only set used for partitioning.
    is_read_only: bool
    # Tool-call dicts in their original order.
    calls: list[dict] = field(default_factory=list)
def _partition_tool_calls(
tool_calls: list[dict],
read_only: set[str] | None = None,
) -> list[_ToolBatch]:
"""Partition tool calls into batches: consecutive read-only tools are
grouped together (will run in parallel), write tools are isolated."""
if read_only is None:
read_only = READ_ONLY_TOOLS
batches: list[_ToolBatch] = []
for tc in tool_calls:
is_ro = tc.get("name", "") in read_only
if batches and batches[-1].is_read_only and is_ro:
batches[-1].calls.append(tc)
else:
batches.append(_ToolBatch(is_read_only=is_ro, calls=[tc]))
return batches
# ---------------------------------------------------------------------------
# Context compression — keeps the message list from growing unboundedly.
# ---------------------------------------------------------------------------
# Stage A: Progressive tool result decay thresholds.
# Messages are counted in (assistant, user) pairs from the END of the list.
# "Round" = one pair of (assistant tool-calling msg, user tool-result msg).
_DECAY_TIERS: list[tuple[int, int]] = [
# (rounds_ago_threshold, max_chars_for_tool_results)
(5, 3000), # recent 5 rounds: keep full (3000 chars per tool result)
(15, 500), # 5-15 rounds ago: aggressive truncation
(999, 100), # older than 15 rounds: minimal stub
]
def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
"""Truncate the `content` of older `role: "tool"` messages to save context.
Each `role: "tool"` message in the conversation corresponds to one tool
call's result. We rank these messages by recency and progressively
truncate older ones according to `_DECAY_TIERS`.
"""
total = len(messages)
if total <= 10:
return messages
tool_msg_indices = [
i for i, m in enumerate(messages) if m.get("role") == "tool"
]
decay_map: dict[int, int] = {}
for rank, idx in enumerate(reversed(tool_msg_indices)):
rounds_ago = rank
for threshold, max_chars in _DECAY_TIERS:
if rounds_ago < threshold:
decay_map[idx] = max_chars
break
result = []
for i, msg in enumerate(messages):
if i in decay_map:
max_chars = decay_map[i]
content = msg.get("content", "") or ""
if len(content) > max_chars + 200:
truncated = (
content[:max_chars]
+ f"\n... [context compressed: {len(content)} -> {max_chars} chars]"
)
new_msg = dict(msg)
new_msg["content"] = truncated
result.append(new_msg)
else:
result.append(msg)
else:
result.append(msg)
return result
# Stage B: LLM-powered message folding.
# When messages exceed this count, fold the oldest ones into a summary.
_FOLD_THRESHOLD = 24 # trigger folding when messages exceed this count
_FOLD_KEEP_RECENT = 10 # always keep the most recent N messages intact
_FOLD_SUMMARY_SYSTEM = (
"You are a concise summarizer for an ongoing forensic investigation conversation. "
"Summarize the following early conversation between a forensic analysis agent and its "
"tool execution results. Preserve:\n"
"- Key findings and evidence discovered (file paths, inode numbers, timestamps)\n"
"- Tools that were called and their important results\n"
"- Decisions made and current investigation direction\n"
"Keep the summary under 800 words. Use bullet points."
)
class LLMClient:
    """Async LLM client built on the OpenAI SDK.

    Works with any OpenAI-compatible endpoint (OpenAI, DeepSeek, ...).
    Tool calling uses the OpenAI-native ``tools`` parameter. The model returns
    structured ``tool_calls`` in the stream, and ``tool_call_loop`` dispatches
    them to executors and answers with ``role: "tool"`` messages.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "deepseek-v4-pro",
        max_tokens: int = 4096,
        proxy: str | None = "auto",
        reasoning_effort: str | None = None,
        thinking_enabled: bool = False,
    ) -> None:
        """Create the client.

        Args:
            base_url: API base URL; trailing slashes are stripped.
            api_key: API key handed to the SDK.
            model: Model name sent with every request.
            max_tokens: ``max_tokens`` sent with every request.
            proxy: ``"auto"`` reads ``https_proxy`` / ``HTTPS_PROXY`` from the
                environment. ``None``, ``""`` or ``"none"`` (any case) means
                no explicit proxy client is configured. Any other value is
                used as the proxy URL.
            reasoning_effort: Forwarded as ``reasoning_effort`` when set.
            thinking_enabled: When True, requests DeepSeek thinking mode via
                ``extra_body={"thinking": {"type": "enabled"}}``.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self.reasoning_effort = reasoning_effort
        self.thinking_enabled = thinking_enabled
        if proxy == "auto":
            proxy_url = os.environ.get("https_proxy") or os.environ.get("HTTPS_PROXY")
        elif proxy and proxy.lower() != "none":
            proxy_url = proxy
        else:
            proxy_url = None
        http_client = (
            httpx.AsyncClient(proxy=proxy_url, timeout=300.0)
            if proxy_url else None
        )
        self._client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=300.0,
            http_client=http_client,
        )

    async def close(self) -> None:
        """Close the underlying OpenAI SDK client."""
        await self._client.close()

    # ------------------------------------------------------------------
    # Request plumbing shared by chat() and _chat_with_tools()
    # ------------------------------------------------------------------

    def _request_kwargs(
        self,
        messages: list[dict],
        openai_tools: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Build the keyword arguments for one streaming completion request.

        ``tools`` is only included when ``openai_tools`` is non-empty.
        Presumably this avoids an empty ``tools`` array, which
        OpenAI-compatible servers may reject (NOTE(review): confirm per
        provider).
        """
        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "stream": True,
        }
        if openai_tools:
            kwargs["tools"] = openai_tools
        if self.reasoning_effort:
            kwargs["reasoning_effort"] = self.reasoning_effort
        if self.thinking_enabled:
            kwargs["extra_body"] = {"thinking": {"type": "enabled"}}
        return kwargs

    @staticmethod
    def _is_retryable(exc: BaseException) -> bool:
        """Return False for 4xx HTTP errors that retrying cannot fix.

        Errors without an integer ``status_code`` (connection failures,
        timeouts) are retryable, as are 408, 409, 429 and all non-4xx
        statuses. Other 4xx errors (e.g. 400 for a malformed request) fail
        the same way on every attempt, so backing off only wastes time.
        """
        status = getattr(exc, "status_code", None)
        if not isinstance(status, int):
            return True
        if 400 <= status < 500:
            return status in (408, 409, 429)
        return True

    async def _stream_once(
        self, kwargs: dict[str, Any],
    ) -> tuple[str, str | None, list[dict]]:
        """Issue one streaming request and assemble its pieces.

        Returns:
            ``(text, reasoning_content or None, tool_calls)``, where
            tool_calls is a list of ``{"id", "name", "arguments"}`` dicts
            ordered by stream index, and arguments is the raw JSON string.
        """
        text_parts: list[str] = []
        reasoning_parts: list[str] = []
        tool_calls_acc: dict[int, dict] = {}  # stream index -> {id, name, arguments}
        stream = await self._client.chat.completions.create(**kwargs)
        # Closing the stream releases the HTTP response even when iteration
        # fails midway and the retry loop opens a fresh request.
        async with stream:
            async for chunk in stream:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.content:
                    text_parts.append(delta.content)
                # DeepSeek thinking mode returns reasoning_content alongside
                # content; tool_call_loop echoes it back on later requests.
                rc = getattr(delta, "reasoning_content", None)
                if rc:
                    reasoning_parts.append(rc)
                for tc_delta in getattr(delta, "tool_calls", None) or ():
                    entry = tool_calls_acc.setdefault(
                        tc_delta.index, {"id": None, "name": None, "arguments": ""},
                    )
                    if tc_delta.id:
                        entry["id"] = tc_delta.id
                    fn = tc_delta.function
                    if fn:
                        if fn.name:
                            entry["name"] = fn.name
                        if fn.arguments:
                            entry["arguments"] += fn.arguments
        text = "".join(text_parts)
        reasoning = "".join(reasoning_parts) or None
        ordered = [tool_calls_acc[i] for i in sorted(tool_calls_acc)]
        return text, reasoning, ordered

    async def _stream_with_retries(
        self,
        kwargs: dict[str, Any],
        max_retries: int,
        label: str,
    ) -> tuple[str, str | None, list[dict]]:
        """Run _stream_once with exponential backoff (10 s, 20 s, 40 s, ...).

        Args:
            kwargs: Request keyword arguments from _request_kwargs.
            max_retries: Maximum number of attempts. If it is 0 or negative,
                no request is made and ``("", None, [])`` is returned.
            label: Short tag used in log lines.

        Raises:
            LLMAPIError: Immediately for a non-retryable client error, or
                after ``max_retries`` failed attempts.
        """
        n_messages = len(kwargs["messages"])
        n_tools = len(kwargs.get("tools") or ())
        for attempt in range(max_retries):
            logger.debug(
                "LLM request (%s): %d messages, %d tools (attempt %d)",
                label, n_messages, n_tools, attempt + 1,
            )
            try:
                text, reasoning, tool_calls = await self._stream_once(kwargs)
            except (APIConnectionError, APITimeoutError, APIError) as e:
                if not self._is_retryable(e):
                    raise LLMAPIError(
                        f"LLM API rejected the request (not retryable): {e}",
                        attempts=attempt + 1,
                    ) from e
                if attempt == max_retries - 1:
                    raise LLMAPIError(
                        f"LLM API unreachable after {max_retries} attempts: {e}",
                        attempts=max_retries,
                    ) from e
                wait = 2 ** attempt * 10
                logger.warning(
                    "LLM request (%s) failed (%s), retrying in %ds...",
                    label, e, wait,
                )
                await asyncio.sleep(wait)
                continue
            logger.debug(
                "LLM response (%s): %d chars, %d reasoning chars, %d tool calls",
                label, len(text), len(reasoning or ""), len(tool_calls),
            )
            return text, reasoning, tool_calls
        return "", None, []

    async def chat(
        self,
        messages: list[dict],
        system: str | None = None,
        max_retries: int = 5,
    ) -> str:
        """Send a streaming chat completion (no tools) and return the text.

        Args:
            messages: Conversation messages (without the system prompt).
            system: Optional system prompt prepended as a system message.
            max_retries: Maximum number of attempts.

        Raises:
            LLMAPIError: See _stream_with_retries.
        """
        full_messages: list[dict] = []
        if system:
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)
        text, _reasoning, _tool_calls = await self._stream_with_retries(
            self._request_kwargs(full_messages), max_retries, "stream",
        )
        return text

    async def _chat_with_tools(
        self,
        messages: list[dict],
        openai_tools: list[dict],
        max_retries: int = 5,
    ) -> tuple[str, str | None, list[dict]]:
        """Stream a chat completion with native tool calling enabled.

        Returns:
            ``(text_content, reasoning_content, raw_tool_calls)``.
            reasoning_content is non-None when DeepSeek thinking mode is
            active; the caller MUST echo it back in the assistant message on
            subsequent requests, or the API returns HTTP 400. raw_tool_calls
            is a list of ``{"id", "name", "arguments"}`` dicts, where
            arguments is the raw JSON string returned by the API.

        Raises:
            LLMAPIError: See _stream_with_retries.
        """
        return await self._stream_with_retries(
            self._request_kwargs(messages, openai_tools), max_retries, "stream+tools",
        )

    # ------------------------------------------------------------------
    # Tool-calling loop
    # ------------------------------------------------------------------

    async def tool_call_loop(
        self,
        messages: list[dict],
        tools: list[dict],
        tool_executor: dict[str, Any],
        system: str | None = None,
        max_iterations: int = 60,
        terminal_tools: tuple[str, ...] = (),
    ) -> tuple[str, list[dict]]:
        """Run a tool-calling loop using OpenAI-native tool calls.

        The model returns structured ``tool_calls`` in its message. Each call
        is dispatched through ``tool_executor`` (name -> async callable), and
        its result is fed back as a ``role: "tool"`` message with the
        matching ``tool_call_id``. The loop ends when:

        - the model returns a message with no tool_calls (normal exit), or
        - any tool in ``terminal_tools`` is called. The loop then
          short-circuits with that tool's raw (untruncated) result as
          final_text. This gives agents (notably ReportAgent) an explicit
          completion signal that the old final-answer text tag used to
          provide.

        A missing tool-call id is replaced by a synthetic ``call_<iter>_<pos>``
        id, a missing name by ``""``, and arguments that are unparseable or
        not a JSON object by ``{}``. Context is compressed before every
        request: old tool results decay, and old messages are folded into an
        LLM summary.

        Returns:
            ``(final_text, full_message_history)``.

        Raises:
            LLMAPIError: Propagated from the model requests.
        """
        terminal_set = set(terminal_tools)
        openai_tools = _to_openai_tools(tools)
        # The caller may pass `messages` either as raw conversation (no system)
        # together with `system=...`, OR as a complete history that already
        # starts with the system message (retry path). Accept both shapes.
        if messages and messages[0].get("role") == "system":
            full_messages: list[dict] = list(messages)
        else:
            full_messages = []
            if system:
                full_messages.append({"role": "system", "content": system})
            full_messages.extend(messages)

        folded = False
        for iteration in range(max_iterations):
            # ── Context compression before each API call ──────────────
            full_messages = _apply_progressive_decay(full_messages)
            if not folded and len(full_messages) > _FOLD_THRESHOLD:
                full_messages = await self._fold_old_messages(full_messages)
                folded = True
            elif folded and len(full_messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
                full_messages = await self._fold_old_messages(full_messages)

            text, reasoning, raw_tool_calls = await self._chat_with_tools(
                full_messages, openai_tools,
            )

            if not raw_tool_calls:
                # Final response; unwrap optional answer tags for backward
                # compatibility with old prompts.
                final_msg: dict[str, Any] = {"role": "assistant", "content": text}
                if reasoning:
                    final_msg["reasoning_content"] = reasoning
                full_messages.append(final_msg)
                answer = _extract_answer(text)
                return (answer if answer is not None else text), full_messages

            # Normalize the streamed accumulator: its id/name start as None,
            # and a None id or name would produce an invalid follow-up request.
            calls = [
                {
                    "id": rc.get("id") or f"call_{iteration}_{pos}",
                    "name": rc.get("name") or "",
                    "arguments": rc.get("arguments") or "",
                }
                for pos, rc in enumerate(raw_tool_calls)
            ]
            parsed_calls = [
                {
                    "id": c["id"],
                    "name": c["name"],
                    "arguments": self._parse_tool_arguments(c["name"], c["arguments"]),
                }
                for c in calls
            ]

            # Append the assistant turn with its tool_calls (and the
            # DeepSeek-mandated reasoning_content echo-back), then execute.
            asst_msg: dict[str, Any] = {
                "role": "assistant",
                "content": text or None,
                "tool_calls": [
                    {
                        "id": c["id"],
                        "type": "function",
                        "function": {"name": c["name"], "arguments": c["arguments"]},
                    }
                    for c in calls
                ],
            }
            if reasoning:
                asst_msg["reasoning_content"] = reasoning
            full_messages.append(asst_msg)

            t_batch_start = time.monotonic()
            # Each entry: (tool_call_dict, raw_result, formatted_for_llm)
            executed: list[tuple[dict, str, str]] = []
            for batch in _partition_tool_calls(parsed_calls):
                if batch.is_read_only and len(batch.calls) > 1:
                    results = await self._execute_tool_batch_parallel(
                        batch.calls, tool_executor, tools,
                    )
                else:
                    results = []
                    for tc in batch.calls:
                        results.append(
                            await self._execute_single_tool(tc, tool_executor, tools),
                        )
                for tc, (raw, formatted) in zip(batch.calls, results):
                    executed.append((tc, raw, formatted))
            _emit_tool_call_summary(parsed_calls, time.monotonic() - t_batch_start)

            # The LLM sees the truncated, formatted copy of each result.
            for tc, _raw, formatted in executed:
                full_messages.append({
                    "role": "tool",
                    "tool_call_id": tc["id"],
                    "content": formatted,
                })

            # Terminal-tool short-circuit: the RAW result becomes final_text so
            # a long report (e.g. from save_report) is not cut to the
            # LLM-facing length.
            if terminal_set:
                for tc, raw, _formatted in executed:
                    name = tc.get("name", "")
                    if name in terminal_set:
                        logger.info(
                            "Terminal tool %s called — exiting tool_call_loop", name,
                        )
                        return raw, full_messages

        logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
        return "[Max tool call iterations reached]", full_messages

    @staticmethod
    def _parse_tool_arguments(tool_name: str, args_str: str) -> dict:
        """Parse a tool call's raw JSON arguments into a dict.

        Blank input, unparseable JSON, and JSON that is not an object (e.g.
        ``null`` or a list) all yield ``{}``; the latter two are logged.
        Non-object values would otherwise break argument fixing and
        ``**kwargs`` dispatch.
        """
        if not args_str.strip():
            return {}
        try:
            args = _safe_json_loads(args_str)
        except (json.JSONDecodeError, ValueError) as e:
            logger.warning("Failed to parse arguments for tool %s: %s", tool_name, e)
            return {}
        if not isinstance(args, dict):
            logger.warning(
                "Ignoring non-object arguments for tool %s: %r", tool_name, args,
            )
            return {}
        return args

    async def _invoke_tool(
        self,
        tc: dict,
        tool_executor: dict[str, Any],
        tools: list[dict] | None,
        *,
        parallel: bool = False,
    ) -> tuple[str, str]:
        """Run one tool call and return ``(raw_result, formatted_for_llm)``.

        Unknown tools and executor exceptions become error strings, so a
        failing tool never aborts the loop. Non-string executor results are
        converted with ``str()`` because the formatting step slices text.
        ``formatted_for_llm`` is ``[tool_name] <truncated raw>``.
        """
        tool_name = tc.get("name", "")
        tool_args = tc.get("arguments", {})
        if tools:
            tool_args = _fix_tool_args(tool_name, tool_args, tools)
        logger.info(
            "Calling tool (parallel): %s(%s)" if parallel else "Calling tool: %s(%s)",
            tool_name, json.dumps(tool_args, ensure_ascii=False),
        )
        executor = tool_executor.get(tool_name)
        if executor is None:
            raw: Any = f"Error: unknown tool '{tool_name}'"
        else:
            try:
                raw = await executor(**tool_args)
            except Exception as e:
                logger.error("Tool %s failed: %s", tool_name, e)
                raw = f"Error executing {tool_name}: {e}"
        if not isinstance(raw, str):
            raw = str(raw)
        return raw, f"[{tool_name}] {_truncate_tool_result(raw)}"

    async def _execute_single_tool(
        self, tc: dict, tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
    ) -> tuple[str, str]:
        """Execute a single tool call.

        Returns ``(raw_result, formatted_for_llm)``. ``raw_result`` is the
        executor's return as text (used by the terminal-tool short-circuit as
        final_text). ``formatted_for_llm`` is ``[tool_name] {truncated}`` and
        is fed back to the model as the tool message content.
        """
        return await self._invoke_tool(tc, tool_executor, tools)

    async def _execute_tool_batch_parallel(
        self, calls: list[dict], tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
    ) -> list[tuple[str, str]]:
        """Execute multiple read-only tool calls concurrently.

        Returns a list of ``(raw_result, formatted_for_llm)`` tuples in the
        same order as ``calls``.
        """
        logger.info("Executing %d read-only tools in parallel", len(calls))
        results = await asyncio.gather(*(
            self._invoke_tool(tc, tool_executor, tools, parallel=True)
            for tc in calls
        ))
        return list(results)

    # ------------------------------------------------------------------
    # Context folding (Stage B)
    # ------------------------------------------------------------------

    @staticmethod
    def _fold_boundary(body: list[dict], keep_recent: int) -> int:
        """Return how many leading messages of ``body`` to fold (0 = don't fold).

        The boundary starts at ``len(body) - keep_recent``. While it points
        at a ``role: "tool"`` message, it moves BACK, so an assistant turn
        with ``tool_calls`` always stays together with its tool results. The
        API rejects a tool message that does not follow such a turn.
        (Walking forward instead could swallow the whole tail when the
        newest messages are one long run of tool results.) Folding is
        skipped when 2 or fewer messages would be summarized.
        """
        n = len(body) - keep_recent
        if n <= 2:
            return 0
        while 0 < n < len(body) and body[n].get("role") == "tool":
            n -= 1
        return n if n > 2 else 0

    @staticmethod
    def _render_for_summary(old_messages: list[dict]) -> str:
        """Render messages as ``[role]: content`` blocks for the summarizer.

        Assistant tool calls are rendered as ``called: name1, name2``. Each
        message is capped at 1000 chars and the whole text at 15000 chars.
        """
        parts: list[str] = []
        for msg in old_messages:
            role = msg.get("role", "?")
            content = msg.get("content") or ""
            if not isinstance(content, str):
                content = str(content)
            if role == "assistant" and msg.get("tool_calls"):
                tc_names = [
                    (tc.get("function") or {}).get("name") or "?"
                    for tc in msg["tool_calls"]
                ]
                content = (content + " " if content else "") + (
                    "called: " + ", ".join(tc_names)
                )
            if len(content) > 1000:
                content = content[:1000] + "..."
            parts.append(f"[{role}]: {content}")
        old_text = "\n\n".join(parts)
        if len(old_text) > 15000:
            old_text = old_text[:15000] + "\n\n... [further messages omitted for brevity]"
        return old_text

    async def _fold_old_messages(
        self, messages: list[dict],
    ) -> list[dict]:
        """Fold old messages into an LLM-generated summary (Stage B).

        Keeps the leading system message (if any), keeps the most recent
        messages intact (about _FOLD_KEEP_RECENT, extended back to a complete
        assistant/tool turn), and replaces the older middle slice with a
        single summary user message. If there is too little to fold, or the
        summarization request fails, ``messages`` is returned unchanged.
        """
        # Pin the system message — it must NEVER be summarized away.
        system_msgs: list[dict] = []
        body = messages
        if messages and messages[0].get("role") == "system":
            system_msgs = [messages[0]]
            body = messages[1:]

        n_to_fold = self._fold_boundary(body, _FOLD_KEEP_RECENT)
        if n_to_fold == 0:
            return messages
        old_messages = body[:n_to_fold]
        recent_messages = body[n_to_fold:]

        old_text = self._render_for_summary(old_messages)
        logger.info(
            "Context folding: summarizing %d old messages (%d chars) into summary",
            n_to_fold, len(old_text),
        )
        try:
            summary = await self.chat(
                messages=[{"role": "user", "content": old_text}],
                system=_FOLD_SUMMARY_SYSTEM,
            )
        except Exception as e:
            # Best effort: folding failure must not abort the tool loop.
            logger.warning("Context folding failed: %s — keeping original messages", e)
            return messages

        summary_message = {
            "role": "user",
            "content": (
                f"[Context summary — the following summarizes {n_to_fold} earlier "
                f"messages in this conversation]\n\n{summary}"
            ),
        }
        return system_msgs + [summary_message] + recent_messages