"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API). Tool calling uses the OpenAI-native `tools=[...]` parameter. The model returns structured tool_calls via the streaming protocol; we accumulate them, dispatch to our executors, and feed results back as `role: "tool"` messages. This eliminates the fragile "model writes JSON inside free text" problem of the previous ReAct text mode. """ from __future__ import annotations import asyncio import json import logging import os import re import time from collections import Counter from dataclasses import dataclass, field from typing import Any import httpx from openai import APIConnectionError, APIError, APITimeoutError, AsyncOpenAI logger = logging.getLogger(__name__) class LLMAPIError(Exception): """Raised when the LLM API is unreachable after all retries.""" def __init__(self, message: str, attempts: int) -> None: super().__init__(message) self.attempts = attempts # Optional answer tags — kept for backward compat with prompts that wrap # their final response in .... Native tool calling does # not need these (no tool_calls = final), but if the model continues to # emit them, we strip the tags so callers see clean text. ANSWER_TAG = "" ANSWER_END = "" def _to_openai_tools(tools: list[dict]) -> list[dict]: """Convert internal tool definitions to OpenAI native function-tools format.""" return [ { "type": "function", "function": { "name": t["name"], "description": t["description"], "parameters": t.get("input_schema", {"type": "object", "properties": {}}), }, } for t in tools ] def _extract_first_balanced(text: str, open_char: str, close_char: str) -> str | None: """Return the first balanced [...] or {...} substring, or None if no balanced pair. Stack-based — handles nested brackets correctly (regex with .*? would truncate at the first inner closing bracket, regex with .* would over-eat trailing text). Brackets inside JSON string literals are ignored by callers because the caller passes the result through json.loads which re-parses with proper string handling. """ start = text.find(open_char) if start < 0: return None depth = 0 for i in range(start, len(text)): c = text[i] if c == open_char: depth += 1 elif c == close_char: depth -= 1 if depth == 0: return text[start:i + 1] return None def _safe_json_loads(text: str): """Parse JSON with progressive sanitization for LLM-produced output. Tries (0) as-is, (1) escape stray backslashes outside valid JSON escapes (\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX). On final failure, logs raw input (first 600 chars) so we can diagnose what the model emitted. Used by orchestrator JSON callsites (_call_llm_for_json) and by tool_call_loop when parsing tool_call arguments returned by the API. """ try: return json.loads(text) except json.JSONDecodeError: pass stage1 = re.sub( r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r'\\\\', text, ) try: return json.loads(stage1) except json.JSONDecodeError as e: logger.warning( "_safe_json_loads failed after sanitize (%s); raw head[:600]=%r", e, text[:600], ) raise def _extract_answer(text: str) -> str | None: """Extract the final answer from model output.""" pattern = re.compile( re.escape(ANSWER_TAG) + r"\s*(.*?)\s*" + re.escape(ANSWER_END), re.DOTALL, ) match = pattern.search(text) if match: return match.group(1).strip() return None def _truncate_tool_result(result_text: str, max_chars: int = 3000) -> str: """Truncate a tool result if it exceeds max_chars.""" if len(result_text) <= max_chars: return result_text return result_text[: max_chars - 200] + f"\n... [truncated, {len(result_text)} total chars]" # Tools that only read and never mutate state — safe to run concurrently. READ_ONLY_TOOLS: set[str] = { # Graph queries "list_phenomena", "get_phenomenon", "search_graph", "get_related", "get_hypothesis_status", "list_assets", "find_extracted_file", # Sleuth Kit reads "partition_info", "filesystem_info", "list_directory", "find_file", "search_strings", "count_deleted_files", "build_filesystem_timeline", # Registry reads (without auto-record wrappers) "parse_registry_key", "search_registry", "get_user_activity", # Parser reads "read_text_file", "read_binary_preview", "search_text_file", "read_text_file_section", "list_extracted_dir", "parse_pcap_strings", "find_files", # iOS plugin reads (S4) "parse_plist", "sqlite_tables", "sqlite_query", "parse_ios_keychain", "read_idevice_info", # Android + media reads (S6) — set_active_partition is NOT read-only. "probe_android_partitions", "ocr_image", # Strategist view tools (DESIGN_STRATEGIST.md §2) — pure renders. "graph_overview", "source_coverage", "marginal_yield", "budget_status", } def _fix_tool_args(tool_name: str, tool_args: dict, tools: list[dict]) -> dict: """Try to fix misnamed tool arguments from LLM hallucination. The LLM sometimes confuses parameter names across tools (e.g. passing `key_path` to search_registry which expects `pattern`). This function maps unknown kwargs to missing expected params by position/best-effort. """ # Find the schema for this tool schema = None for t in tools: if t.get("name") == tool_name: schema = t.get("input_schema", {}) break if schema is None: return tool_args props = schema.get("properties", {}) required = set(schema.get("required", [])) unknown = [k for k in tool_args if k not in props] if not unknown: return tool_args # all args are valid, nothing to fix # Build the fixed args: start with valid args fixed = {k: v for k, v in tool_args.items() if k in props} # Find which expected params are still missing missing = [p for p in (required or props.keys()) if p not in fixed] # Try to map unknown args to missing params, in order unknown_values = [(k, tool_args[k]) for k in unknown] for wrong_name, value in unknown_values: if not missing: break # Pick the best match from missing params best = missing.pop(0) logger.warning( "Auto-fixing tool arg: %s(%s=...) -> %s(%s=...)", tool_name, wrong_name, tool_name, best, ) fixed[best] = value return fixed def _emit_tool_call_summary(tool_calls: list[dict], elapsed: float) -> None: """Emit a folded tool-call summary line for the terminal formatter. Instead of logging each tool call individually, we group by name: "list_directory x27, extract_file x3, read_text_file x3" """ counts = Counter(tc.get("name", "?") for tc in tool_calls) parts = [] for name, count in counts.most_common(): if count > 1: parts.append(f"{name} x{count}") else: parts.append(name) summary = ", ".join(parts) logger.info(summary, extra={"event": "tool_calls", "elapsed": elapsed}) @dataclass class _ToolBatch: """A batch of tool calls with the same read/write classification.""" is_read_only: bool calls: list[dict] = field(default_factory=list) def _partition_tool_calls( tool_calls: list[dict], read_only: set[str] | None = None, ) -> list[_ToolBatch]: """Partition tool calls into batches: consecutive read-only tools are grouped together (will run in parallel), write tools are isolated.""" if read_only is None: read_only = READ_ONLY_TOOLS batches: list[_ToolBatch] = [] for tc in tool_calls: is_ro = tc.get("name", "") in read_only if batches and batches[-1].is_read_only and is_ro: batches[-1].calls.append(tc) else: batches.append(_ToolBatch(is_read_only=is_ro, calls=[tc])) return batches # --------------------------------------------------------------------------- # Context compression — keeps the message list from growing unboundedly. # --------------------------------------------------------------------------- # Stage A: Progressive tool result decay thresholds. # Messages are counted in (assistant, user) pairs from the END of the list. # "Round" = one pair of (assistant tool-calling msg, user tool-result msg). _DECAY_TIERS: list[tuple[int, int]] = [ # (rounds_ago_threshold, max_chars_for_tool_results) (5, 3000), # recent 5 rounds: keep full (3000 chars per tool result) (15, 500), # 5-15 rounds ago: aggressive truncation (999, 100), # older than 15 rounds: minimal stub ] def _apply_progressive_decay(messages: list[dict]) -> list[dict]: """Truncate the `content` of older `role: "tool"` messages to save context. Each `role: "tool"` message in the conversation corresponds to one tool call's result. We rank these messages by recency and progressively truncate older ones according to `_DECAY_TIERS`. """ total = len(messages) if total <= 10: return messages tool_msg_indices = [ i for i, m in enumerate(messages) if m.get("role") == "tool" ] decay_map: dict[int, int] = {} for rank, idx in enumerate(reversed(tool_msg_indices)): rounds_ago = rank for threshold, max_chars in _DECAY_TIERS: if rounds_ago < threshold: decay_map[idx] = max_chars break result = [] for i, msg in enumerate(messages): if i in decay_map: max_chars = decay_map[i] content = msg.get("content", "") or "" if len(content) > max_chars + 200: truncated = ( content[:max_chars] + f"\n... [context compressed: {len(content)} -> {max_chars} chars]" ) new_msg = dict(msg) new_msg["content"] = truncated result.append(new_msg) else: result.append(msg) else: result.append(msg) return result # Stage B: LLM-powered message folding. # When messages exceed this count, fold the oldest ones into a summary. _FOLD_THRESHOLD = 24 # trigger folding when messages exceed this count _FOLD_KEEP_RECENT = 10 # always keep the most recent N messages intact _FOLD_SUMMARY_SYSTEM = ( "You are a concise summarizer for an ongoing forensic investigation conversation. " "Summarize the following early conversation between a forensic analysis agent and its " "tool execution results. Preserve:\n" "- Key findings and evidence discovered (file paths, inode numbers, timestamps)\n" "- Tools that were called and their important results\n" "- Decisions made and current investigation direction\n" "Keep the summary under 800 words. Use bullet points." ) class LLMClient: """Async LLM client via the OpenAI SDK. Works with any OpenAI-compatible endpoint (OpenAI, DeepSeek, ...). Tool calling is text-based (ReAct) — see module docstring. """ def __init__( self, base_url: str, api_key: str, model: str = "deepseek-v4-pro", max_tokens: int = 4096, proxy: str | None = "auto", reasoning_effort: str | None = None, thinking_enabled: bool = False, ) -> None: self.base_url = base_url.rstrip("/") self.api_key = api_key self.model = model self.max_tokens = max_tokens self.reasoning_effort = reasoning_effort self.thinking_enabled = thinking_enabled # proxy="auto": read from env; proxy=None/""/"none": no proxy if proxy == "auto": proxy_url = os.environ.get("https_proxy") or os.environ.get("HTTPS_PROXY") elif proxy and proxy.lower() != "none": proxy_url = proxy else: proxy_url = None http_client = ( httpx.AsyncClient(proxy=proxy_url, timeout=300.0) if proxy_url else None ) self._client = AsyncOpenAI( api_key=self.api_key, base_url=self.base_url, timeout=300.0, http_client=http_client, ) async def close(self) -> None: await self._client.close() async def chat( self, messages: list[dict], system: str | None = None, max_retries: int = 5, ) -> str: """Send a streaming chat completion and return the assembled text.""" full_messages: list[dict] = [] if system: full_messages.append({"role": "system", "content": system}) full_messages.extend(messages) kwargs: dict[str, Any] = { "model": self.model, "messages": full_messages, "max_tokens": self.max_tokens, "stream": True, } if self.reasoning_effort: kwargs["reasoning_effort"] = self.reasoning_effort if self.thinking_enabled: kwargs["extra_body"] = {"thinking": {"type": "enabled"}} for attempt in range(max_retries): logger.debug( "LLM request (stream): %d messages (attempt %d)", len(messages), attempt + 1, ) text_parts: list[str] = [] try: stream = await self._client.chat.completions.create(**kwargs) async for chunk in stream: if not chunk.choices: continue delta = chunk.choices[0].delta if delta.content: text_parts.append(delta.content) text = "".join(text_parts) logger.debug("LLM response (stream): %d chars", len(text)) return text except (APIConnectionError, APITimeoutError, APIError) as e: if attempt < max_retries - 1: wait = 2 ** attempt * 10 logger.warning("Request failed (%s), retrying in %ds...", e, wait) await asyncio.sleep(wait) else: raise LLMAPIError( f"LLM API unreachable after {max_retries} attempts: {e}", attempts=max_retries, ) from e return "" async def _chat_with_tools( self, messages: list[dict], openai_tools: list[dict], max_retries: int = 5, ) -> tuple[str, str | None, list[dict]]: """Stream a chat completion with native tool calling enabled. Returns: (text_content, reasoning_content, raw_tool_calls). - reasoning_content is non-None when DeepSeek thinking mode is active; the caller MUST echo it back in the assistant message on subsequent requests, or the API returns HTTP 400. - raw_tool_calls is a list of {"id","name","arguments"} dicts; arguments is the raw JSON string returned by the API. """ kwargs: dict[str, Any] = { "model": self.model, "messages": messages, "max_tokens": self.max_tokens, "stream": True, "tools": openai_tools, } if self.reasoning_effort: kwargs["reasoning_effort"] = self.reasoning_effort if self.thinking_enabled: kwargs["extra_body"] = {"thinking": {"type": "enabled"}} for attempt in range(max_retries): logger.debug( "LLM request (stream+tools): %d messages, %d tools (attempt %d)", len(messages), len(openai_tools), attempt + 1, ) text_parts: list[str] = [] reasoning_parts: list[str] = [] tool_calls_acc: dict[int, dict] = {} # index -> {id, name, arguments} try: stream = await self._client.chat.completions.create(**kwargs) async for chunk in stream: if not chunk.choices: continue delta = chunk.choices[0].delta if delta.content: text_parts.append(delta.content) # DeepSeek thinking-mode: reasoning_content is returned # alongside content and MUST be echoed back on subsequent # requests, otherwise the API rejects with HTTP 400. rc = getattr(delta, "reasoning_content", None) if rc: reasoning_parts.append(rc) if delta.tool_calls: for tc_delta in delta.tool_calls: idx = tc_delta.index entry = tool_calls_acc.setdefault( idx, {"id": None, "name": None, "arguments": ""}, ) if tc_delta.id: entry["id"] = tc_delta.id fn = tc_delta.function if fn: if fn.name: entry["name"] = fn.name if fn.arguments: entry["arguments"] += fn.arguments text = "".join(text_parts) reasoning = "".join(reasoning_parts) or None ordered = [tool_calls_acc[i] for i in sorted(tool_calls_acc)] logger.debug( "LLM response (stream+tools): %d chars, %d reasoning chars, %d tool calls", len(text), len(reasoning or ""), len(ordered), ) return text, reasoning, ordered except (APIConnectionError, APITimeoutError, APIError) as e: if attempt < max_retries - 1: wait = 2 ** attempt * 10 logger.warning( "Tool-call request failed (%s), retrying in %ds...", e, wait, ) await asyncio.sleep(wait) else: raise LLMAPIError( f"LLM API unreachable after {max_retries} attempts: {e}", attempts=max_retries, ) from e return "", None, [] async def tool_call_loop( self, messages: list[dict], tools: list[dict], tool_executor: dict[str, Any], system: str | None = None, max_iterations: int = 60, terminal_tools: tuple[str, ...] = (), ) -> tuple[str, list[dict]]: """Run a tool-calling loop using OpenAI-native tool calls. The model returns structured `tool_calls` in its message; we dispatch them through our executor dict and feed each result back as a `role: "tool"` message with the matching `tool_call_id`. The loop ends when: - the model returns a message with no tool_calls (normal exit), or - any tool in `terminal_tools` is called — in that case, the loop short-circuits with that tool's result text as final_text. This gives agents (notably ReportAgent) an explicit completion signal that the old `` text tag used to provide. Returns: (final_text, full_message_history) """ terminal_set = set(terminal_tools) openai_tools = _to_openai_tools(tools) # The caller may pass `messages` either as raw conversation (no system) # together with `system=...`, OR as a complete history that already # starts with the system message (retry path). Accept both shapes. if messages and messages[0].get("role") == "system": full_messages: list[dict] = list(messages) else: full_messages = [] if system: full_messages.append({"role": "system", "content": system}) full_messages.extend(messages) _folded = False for _i in range(max_iterations): # ── Context compression before each API call ────────────── full_messages = _apply_progressive_decay(full_messages) if not _folded and len(full_messages) > _FOLD_THRESHOLD: full_messages = await self._fold_old_messages(full_messages) _folded = True elif _folded and len(full_messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT: full_messages = await self._fold_old_messages(full_messages) text, reasoning, raw_tool_calls = await self._chat_with_tools( full_messages, openai_tools, ) if not raw_tool_calls: # Model produced a final response. Strip optional # tags for backward compatibility with old prompts. final_msg: dict[str, Any] = {"role": "assistant", "content": text} if reasoning: final_msg["reasoning_content"] = reasoning full_messages.append(final_msg) answer = _extract_answer(text) return (answer if answer is not None else text), full_messages # Parse arguments + build internal call dicts parsed_calls: list[dict] = [] for rc in raw_tool_calls: args_str = rc.get("arguments", "") or "" try: args = _safe_json_loads(args_str) if args_str.strip() else {} except (json.JSONDecodeError, ValueError) as e: logger.warning( "Failed to parse arguments for tool %s: %s", rc.get("name"), e, ) args = {} parsed_calls.append({ "id": rc.get("id"), "name": rc.get("name", ""), "arguments": args, }) # Append the assistant turn with the raw tool_calls (and the # DeepSeek-mandated reasoning_content echo-back), then execute. asst_msg: dict[str, Any] = { "role": "assistant", "content": text or None, "tool_calls": [ { "id": rc.get("id"), "type": "function", "function": { "name": rc.get("name", ""), "arguments": rc.get("arguments", "") or "", }, } for rc in raw_tool_calls ], } if reasoning: asst_msg["reasoning_content"] = reasoning full_messages.append(asst_msg) batches = _partition_tool_calls(parsed_calls) t_batch_start = time.monotonic() # Each entry: (tool_call_dict, raw_result, formatted_for_llm) executed: list[tuple[dict, str, str]] = [] for batch in batches: if batch.is_read_only and len(batch.calls) > 1: results = await self._execute_tool_batch_parallel( batch.calls, tool_executor, tools, ) for tc, (raw, formatted) in zip(batch.calls, results): executed.append((tc, raw, formatted)) else: for tc in batch.calls: raw, formatted = await self._execute_single_tool( tc, tool_executor, tools, ) executed.append((tc, raw, formatted)) t_batch_elapsed = time.monotonic() - t_batch_start _emit_tool_call_summary(parsed_calls, t_batch_elapsed) # Append formatted tool results to the conversation (this is # what the LLM sees on subsequent rounds — truncated for context # economy). for tc, _raw, formatted in executed: full_messages.append({ "role": "tool", "tool_call_id": tc["id"], "content": formatted, }) # Terminal-tool short-circuit: if the model called any tool in # `terminal_tools`, end the loop immediately. The terminal tool's # RAW result (untruncated) becomes final_text — the LLM may have # produced a 20K-char report via save_report and we must not # truncate it just because the LLM-facing copy is truncated. if terminal_set: for tc, raw, _formatted in executed: name = tc.get("name", "") if name in terminal_set: logger.info( "Terminal tool %s called — exiting tool_call_loop", name, ) return raw, full_messages logger.warning("Tool call loop hit max iterations (%d)", max_iterations) return "[Max tool call iterations reached]", full_messages async def _execute_single_tool( self, tc: dict, tool_executor: dict[str, Any], tools: list[dict] | None = None, ) -> tuple[str, str]: """Execute a single tool call. Returns (raw_result, formatted_for_llm). `raw_result` is the unmodified executor return (used by terminal-tool short-circuit as final_text). `formatted_for_llm` is `[tool_name] {truncated}` and is what gets fed back to the model as the tool message content. """ tool_name = tc.get("name", "") tool_args = tc.get("arguments", {}) if tools: tool_args = _fix_tool_args(tool_name, tool_args, tools) logger.info("Calling tool: %s(%s)", tool_name, json.dumps(tool_args, ensure_ascii=False)) executor = tool_executor.get(tool_name) if executor is None: raw = f"Error: unknown tool '{tool_name}'" else: try: raw = await executor(**tool_args) except Exception as e: logger.error("Tool %s failed: %s", tool_name, e) raw = f"Error executing {tool_name}: {e}" formatted = f"[{tool_name}] {_truncate_tool_result(raw)}" return raw, formatted async def _execute_tool_batch_parallel( self, calls: list[dict], tool_executor: dict[str, Any], tools: list[dict] | None = None, ) -> list[tuple[str, str]]: """Execute multiple read-only tool calls concurrently. Returns a list of (raw_result, formatted_for_llm) tuples in the same order as `calls`. """ logger.info("Executing %d read-only tools in parallel", len(calls)) async def _run_one(tc: dict) -> tuple[str, str]: tool_name = tc.get("name", "") tool_args = tc.get("arguments", {}) if tools: tool_args = _fix_tool_args(tool_name, tool_args, tools) logger.info( "Calling tool (parallel): %s(%s)", tool_name, json.dumps(tool_args, ensure_ascii=False), ) executor = tool_executor.get(tool_name) if executor is None: raw = f"Error: unknown tool '{tool_name}'" else: try: raw = await executor(**tool_args) except Exception as e: logger.error("Tool %s failed: %s", tool_name, e) raw = f"Error executing {tool_name}: {e}" formatted = f"[{tool_name}] {_truncate_tool_result(raw)}" return raw, formatted results = await asyncio.gather(*[_run_one(tc) for tc in calls]) return list(results) async def _fold_old_messages( self, messages: list[dict], ) -> list[dict]: """Fold old messages into an LLM-generated summary (Stage B). Preserves the leading system message (if any), keeps the most recent _FOLD_KEEP_RECENT messages intact, and replaces the older middle slice with a single summary user message. """ # Pin the system message — it must NEVER be summarized away. system_msgs: list[dict] = [] body = messages if messages and messages[0].get("role") == "system": system_msgs = [messages[0]] body = messages[1:] n_to_fold = len(body) - _FOLD_KEEP_RECENT if n_to_fold <= 2: return messages # Pull the fold boundary forward so we never split an assistant turn # from its matching tool results. The API rejects (HTTP 400) any # `role: "tool"` message that does not immediately follow an # `assistant` message with `tool_calls`. We walk the boundary into # `recent_messages` while its head is a `role: "tool"` message, or # while the prior `recent` message is `assistant{tool_calls}` whose # paired tools span the boundary. while n_to_fold < len(body): head = body[n_to_fold] if head.get("role") == "tool": n_to_fold += 1 continue break if n_to_fold >= len(body): # Everything got folded — nothing recent to keep. return system_msgs + [body[0]] if system_msgs else messages old_messages = body[:n_to_fold] recent_messages = body[n_to_fold:] old_text_parts = [] for msg in old_messages: role = msg.get("role", "?") content = msg.get("content") or "" # Render tool_calls (assistant turn) compactly. if role == "assistant" and msg.get("tool_calls"): tc_names = [ tc.get("function", {}).get("name", "?") for tc in msg["tool_calls"] ] content = (content + " " if content else "") + ( "called: " + ", ".join(tc_names) ) if len(content) > 1000: content = content[:1000] + "..." old_text_parts.append(f"[{role}]: {content}") old_text = "\n\n".join(old_text_parts) # Cap total size sent to summarizer if len(old_text) > 15000: old_text = old_text[:15000] + "\n\n... [further messages omitted for brevity]" logger.info( "Context folding: summarizing %d old messages (%d chars) into summary", n_to_fold, len(old_text), ) try: summary = await self.chat( messages=[{"role": "user", "content": old_text}], system=_FOLD_SUMMARY_SYSTEM, ) except Exception as e: logger.warning("Context folding failed: %s — keeping original messages", e) return messages summary_message = { "role": "user", "content": ( f"[Context summary — the following summarizes {n_to_fold} earlier " f"messages in this conversation]\n\n{summary}" ), } return system_msgs + [summary_message] + recent_messages