"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API). Tool calling is text-based (ReAct pattern): tool schemas are embedded in the system prompt and tool calls are parsed as JSON blocks from model output. This keeps the protocol independent of whether the underlying API supports native function/tool calling. """ from __future__ import annotations import asyncio import json import logging import os import re import time from collections import Counter from dataclasses import dataclass, field from typing import Any import httpx from openai import APIConnectionError, APIError, APITimeoutError, AsyncOpenAI logger = logging.getLogger(__name__) class LLMAPIError(Exception): """Raised when the LLM API is unreachable after all retries.""" def __init__(self, message: str, attempts: int) -> None: super().__init__(message) self.attempts = attempts # Markers the model uses to signal tool calls and final answers TOOL_CALL_TAG = "" TOOL_CALL_END = "" TOOL_RESULT_TAG = "" TOOL_RESULT_END = "" ANSWER_TAG = "" ANSWER_END = "" def _build_tools_prompt(tools: list[dict]) -> str: """Format tool definitions for inclusion in the system prompt.""" lines = ["You have access to the following tools:\n"] for t in tools: schema = t.get("input_schema", {}) props = schema.get("properties", {}) required = schema.get("required", []) params = [] for pname, pdef in props.items(): req = " (required)" if pname in required else "" desc = pdef.get("description", "") ptype = pdef.get("type", "string") enum_vals = pdef.get("enum") if enum_vals: allowed = ", ".join(f'"{v}"' for v in enum_vals) params.append(f" - {pname}: {ptype}{req} — {desc} Allowed values: [{allowed}]") else: params.append(f" - {pname}: {ptype}{req} — {desc}") param_block = "\n".join(params) if params else " (no parameters)" lines.append(f"## {t['name']}\n{t['description']}\nParameters:\n{param_block}\n") lines.append( "## How to use tools\n" "To call a tool, output a JSON block wrapped in XML tags like this:\n" 
f"{TOOL_CALL_TAG}\n" '{"name": "tool_name", "arguments": {"param1": "value1"}}\n' f"{TOOL_CALL_END}\n\n" "You can call multiple tools in sequence. After each tool call, you will receive the result in:\n" f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n" "When you have finished your analysis and have a final answer, wrap it in:\n" f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n" "Think step by step. Call tools to gather evidence before drawing conclusions.\n" "You MUST call at least one tool before giving your final answer." ) return "\n".join(lines) def _extract_tool_calls(text: str) -> list[dict]: """Extract tool call JSON blocks from model output.""" pattern = re.compile( re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END), re.DOTALL, ) calls = [] for match in pattern.finditer(text): raw = match.group(1).strip() try: parsed = json.loads(raw) calls.append(parsed) except json.JSONDecodeError: logger.warning("Failed to parse tool call JSON: %s", raw[:200]) return calls def _extract_answer(text: str) -> str | None: """Extract the final answer from model output.""" pattern = re.compile( re.escape(ANSWER_TAG) + r"\s*(.*?)\s*" + re.escape(ANSWER_END), re.DOTALL, ) match = pattern.search(text) if match: return match.group(1).strip() return None def _truncate_tool_result(result_text: str, max_chars: int = 3000) -> str: """Truncate a tool result if it exceeds max_chars.""" if len(result_text) <= max_chars: return result_text return result_text[: max_chars - 200] + f"\n... [truncated, {len(result_text)} total chars]" # Tools that only read and never mutate state — safe to run concurrently. 
READ_ONLY_TOOLS: set[str] = {
    # Graph queries
    "list_phenomena",
    "get_phenomenon",
    "search_graph",
    "get_related",
    "get_hypothesis_status",
    "list_assets",
    "find_extracted_file",
    # Sleuth Kit reads
    "partition_info",
    "filesystem_info",
    "list_directory",
    "find_file",
    "search_strings",
    "count_deleted_files",
    "build_filesystem_timeline",
    # Registry reads (without auto-record wrappers)
    "parse_registry_key",
    "search_registry",
    "get_user_activity",
    # Parser reads
    "read_text_file",
    "read_binary_preview",
    "search_text_file",
    "read_text_file_section",
    "list_extracted_dir",
    "parse_pcap_strings",
}


def _fix_tool_args(tool_name: str, tool_args: dict, tools: list[dict]) -> dict:
    """Try to fix misnamed tool arguments from LLM hallucination.

    The LLM sometimes confuses parameter names across tools (e.g. passing
    `key_path` to search_registry which expects `pattern`). Arguments whose
    names are not in the tool's schema are reassigned, in the order they were
    given, to the expected parameters that are still missing. Candidates are
    the schema's ``required`` list in its declared order, or every declared
    property when the schema has no required parameters. Unknown arguments
    left over once the candidates run out are dropped.

    Args:
        tool_name: Name of the tool being called.
        tool_args: Arguments supplied by the model.
        tools: Tool definitions (``name`` + ``input_schema``).

    Returns:
        ``tool_args`` itself when the tool is unknown, the arguments are not a
        dict, or every argument name is valid; otherwise a new dict with the
        repaired arguments.
    """
    # Schema of the first tool definition with a matching name.
    schema = next(
        (t.get("input_schema", {}) for t in tools if t.get("name") == tool_name),
        None,
    )
    # Nothing to repair for unknown tools or non-mapping arguments; the caller
    # decides how to report those.
    if schema is None or not isinstance(tool_args, dict):
        return tool_args

    props = schema.get("properties", {})
    unknown = [k for k in tool_args if k not in props]
    if not unknown:
        return tool_args  # all args are valid, nothing to fix

    # Start from the arguments that are already valid.
    fixed = {k: v for k, v in tool_args.items() if k in props}

    # Keep the schema's declared order (dict.fromkeys de-duplicates without
    # reordering). Iterating a set here would make the mapping depend on hash
    # order, which varies between interpreter runs for strings.
    required = list(dict.fromkeys(schema.get("required", [])))
    candidates = required or list(props)
    missing = [p for p in candidates if p not in fixed]

    # Pair unknown args with missing params positionally; zip stops as soon as
    # either side is exhausted.
    for wrong_name, best in zip(unknown, missing):
        logger.warning(
            "Auto-fixing tool arg: %s(%s=...) -> %s(%s=...)",
            tool_name, wrong_name, tool_name, best,
        )
        fixed[best] = tool_args[wrong_name]
    return fixed


def _emit_tool_call_summary(tool_calls: list[dict], elapsed: float) -> None:
    """Emit a folded tool-call summary line for the terminal formatter.

    Instead of logging each tool call individually, calls are grouped by name,
    most frequent first:
        "list_directory x27, extract_file x3, read_text_file x3"

    Args:
        tool_calls: Tool-call dicts; a missing ``name`` is reported as "?".
        elapsed: Seconds spent executing the calls, attached to the log record
            as the ``elapsed`` extra next to ``event="tool_calls"``.
    """
    counts = Counter(tc.get("name", "?") for tc in tool_calls)
    summary = ", ".join(
        f"{name} x{count}" if count > 1 else name
        for name, count in counts.most_common()
    )
    logger.info(summary, extra={"event": "tool_calls", "elapsed": elapsed})


@dataclass
class _ToolBatch:
    """A batch of tool calls with the same read/write classification."""

    # True when every call in the batch names a read-only tool.
    is_read_only: bool
    # Tool-call dicts in the order the model emitted them.
    calls: list[dict] = field(default_factory=list)


def _partition_tool_calls(
    tool_calls: list[dict],
    read_only: set[str] | None = None,
) -> list[_ToolBatch]:
    """Partition tool calls into batches, preserving their order.

    Consecutive read-only tools are grouped together (will run in parallel);
    every other tool call gets a batch of its own.

    Args:
        tool_calls: Tool-call dicts in emission order.
        read_only: Names treated as read-only; defaults to READ_ONLY_TOOLS.

    Returns:
        The batches, in call order.
    """
    if read_only is None:
        read_only = READ_ONLY_TOOLS

    batches: list[_ToolBatch] = []
    for tc in tool_calls:
        is_ro = tc.get("name", "") in read_only
        # Only extend the previous batch when both it and this call are
        # read-only; a write call never shares a batch.
        if is_ro and batches and batches[-1].is_read_only:
            batches[-1].calls.append(tc)
        else:
            batches.append(_ToolBatch(is_read_only=is_ro, calls=[tc]))
    return batches


# ---------------------------------------------------------------------------
# Context compression — keeps the message list from growing unboundedly.
# ---------------------------------------------------------------------------

# Stage A: Progressive tool result decay thresholds.
# Messages are counted in (assistant, user) pairs from the END of the list.
# "Round" = one pair of (assistant tool-calling msg, user tool-result msg).
_DECAY_TIERS: list[tuple[int, int]] = [
    # (rounds_ago_threshold, max_chars) — the cap applies to the whole
    # tool-result message, which may hold several tool results.
    (5, 3000),   # 5 most recent rounds: cap at 3000 chars
    (15, 500),   # rounds 5-14: aggressive truncation
    (999, 100),  # rounds 15 and older: minimal stub
]


def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
    """Truncate tool results in older messages to save context space.

    Only user messages whose content contains TOOL_RESULT_TAG are touched
    (these are the tool-result messages generated by tool_call_loop). They are
    ranked from the end of the list (0 = most recent) and capped at the size
    of the first tier in _DECAY_TIERS whose threshold exceeds their rank.

    The input list and its dicts are never modified: truncated messages are
    replaced by new dicts, untouched ones are reused as-is.

    Args:
        messages: Conversation messages (dicts with ``role`` and ``content``).

    Returns:
        ``messages`` itself when it has 10 entries or fewer, otherwise a new
        list of the same length.
    """
    if len(messages) <= 10:  # not enough messages to bother
        return messages

    # Indices of tool-result messages; non-string content is never a
    # tool-result message produced by the loop, so it is left alone.
    tool_result_indices = [
        i
        for i, m in enumerate(messages)
        if m.get("role") == "user"
        and isinstance(m.get("content"), str)
        and TOOL_RESULT_TAG in m["content"]
    ]

    # Map each tool-result index to its size cap. Ranks beyond the last tier's
    # threshold get no entry and are left as they are.
    decay_map: dict[int, int] = {}
    for rounds_ago, idx in enumerate(reversed(tool_result_indices)):
        for threshold, max_chars in _DECAY_TIERS:
            if rounds_ago < threshold:
                decay_map[idx] = max_chars
                break

    result: list[dict] = []
    for i, msg in enumerate(messages):
        max_chars = decay_map.get(i)
        if max_chars is None:
            result.append(msg)
            continue

        content = msg["content"]
        # 200 chars of slack: the appended note must not make a message that
        # was just truncated eligible for truncation again on the next pass.
        if len(content) <= max_chars + 200:
            result.append(msg)
            continue

        kept = content[:max_chars]
        n_results = content.count(TOOL_RESULT_TAG)
        truncated = kept + (
            f"\n... [context compressed: {len(content)} -> {max_chars} chars, "
            f"{n_results} tool result(s)]"
        )
        # Preserve the tag structure: if the cut removed a closing marker,
        # re-close the block that is now left open.
        if TOOL_RESULT_END and kept.count(TOOL_RESULT_TAG) > kept.count(TOOL_RESULT_END):
            truncated += f"\n{TOOL_RESULT_END}"
        result.append({"role": msg["role"], "content": truncated})
    return result


# Stage B: LLM-powered message folding.
# When messages exceed this count, fold the oldest ones into a summary.
_FOLD_THRESHOLD = 24  # trigger folding when messages exceed this count
_FOLD_KEEP_RECENT = 10  # always keep the most recent N messages intact

_FOLD_SUMMARY_SYSTEM = (
    "You are a concise summarizer for an ongoing forensic investigation conversation. "
    "Summarize the following early conversation between a forensic analysis agent and its "
    "tool execution results. Preserve:\n"
    "- Key findings and evidence discovered (file paths, inode numbers, timestamps)\n"
    "- Tools that were called and their important results\n"
    "- Decisions made and current investigation direction\n"
    "Keep the summary under 800 words. Use bullet points."
)

# HTTP statuses that signal a problem with the request itself (bad payload,
# credentials, balance, unknown route). Repeating the identical request cannot
# succeed, so these fail immediately instead of going through the back-off.
_NON_RETRYABLE_STATUS = frozenset({400, 401, 402, 403, 404, 422})


class LLMClient:
    """Async LLM client via the OpenAI SDK.

    Works with any OpenAI-compatible endpoint (OpenAI, DeepSeek, ...).
    Tool calling is text-based (ReAct) — see module docstring.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "deepseek-v4-pro",
        max_tokens: int = 4096,
        proxy: str | None = "auto",
        reasoning_effort: str | None = None,
        thinking_enabled: bool = False,
    ) -> None:
        """Create the underlying AsyncOpenAI client.

        Args:
            base_url: API base URL; trailing slashes are stripped.
            api_key: API key passed to the SDK.
            model: Model name sent with every request.
            max_tokens: ``max_tokens`` sent with every request.
            proxy: ``"auto"`` reads ``https_proxy`` / ``HTTPS_PROXY`` from the
                environment; ``None``, ``""`` or ``"none"`` (any case)
                disables the proxy; any other value is used as the proxy URL.
            reasoning_effort: When set, sent as ``reasoning_effort``.
            thinking_enabled: When true, ``{"thinking": {"type": "enabled"}}``
                is sent via ``extra_body``.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self.reasoning_effort = reasoning_effort
        self.thinking_enabled = thinking_enabled

        # proxy="auto": read from env; proxy=None/""/"none": no proxy
        if proxy == "auto":
            proxy_url = os.environ.get("https_proxy") or os.environ.get("HTTPS_PROXY")
        elif proxy and proxy.lower() != "none":
            proxy_url = proxy
        else:
            proxy_url = None

        # A custom httpx client is only needed to route through a proxy;
        # otherwise the SDK builds its own.
        http_client = (
            httpx.AsyncClient(proxy=proxy_url, timeout=300.0) if proxy_url else None
        )
        self._client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=300.0,
            http_client=http_client,
        )

    async def close(self) -> None:
        """Close the underlying SDK client."""
        await self._client.close()

    async def chat(
        self,
        messages: list[dict],
        system: str | None = None,
        max_retries: int = 5,
    ) -> str:
        """Send a streaming chat completion and return the assembled text.

        Args:
            messages: Conversation messages (without the system message).
            system: Optional system prompt, sent as the first message.
            max_retries: Maximum number of attempts; values below 1 are
                treated as 1 so the request is always sent at least once.

        Returns:
            The concatenated ``delta.content`` of all streamed chunks.

        Raises:
            LLMAPIError: When the last attempt fails, or immediately when the
                API answers with a status listed in _NON_RETRYABLE_STATUS.
        """
        full_messages: list[dict] = []
        if system:
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": full_messages,
            "max_tokens": self.max_tokens,
            "stream": True,
        }
        if self.reasoning_effort:
            kwargs["reasoning_effort"] = self.reasoning_effort
        if self.thinking_enabled:
            kwargs["extra_body"] = {"thinking": {"type": "enabled"}}

        attempts = max(1, max_retries)
        for attempt in range(attempts):
            logger.debug(
                "LLM request (stream): %d messages (attempt %d)",
                len(messages), attempt + 1,
            )
            # Fresh buffer per attempt so a stream that died half-way does not
            # leak partial text into the retry.
            text_parts: list[str] = []
            try:
                stream = await self._client.chat.completions.create(**kwargs)
                async for chunk in stream:
                    if not chunk.choices:
                        continue
                    piece = getattr(chunk.choices[0].delta, "content", None)
                    if piece:
                        text_parts.append(piece)
            except (APIConnectionError, APITimeoutError, APIError) as e:
                # Only HTTP-status errors carry status_code; connection and
                # timeout errors yield None here and stay retryable.
                status = getattr(e, "status_code", None)
                if status in _NON_RETRYABLE_STATUS:
                    raise LLMAPIError(
                        f"LLM API rejected the request (HTTP {status}): {e}",
                        attempts=attempt + 1,
                    ) from e
                if attempt >= attempts - 1:
                    raise LLMAPIError(
                        f"LLM API unreachable after {attempts} attempts: {e}",
                        attempts=attempts,
                    ) from e
                wait = 2 ** attempt * 10  # 10s, 20s, 40s, ...
                logger.warning("Request failed (%s), retrying in %ds...", e, wait)
                await asyncio.sleep(wait)
            else:
                text = "".join(text_parts)
                logger.debug("LLM response (stream): %d chars", len(text))
                return text
        return ""  # not reached: the final attempt returns or raises

    async def tool_call_loop(
        self,
        messages: list[dict],
        tools: list[dict],
        tool_executor: dict[str, Any],
        system: str | None = None,
        max_iterations: int = 40,
    ) -> tuple[str, list[dict]]:
        """Run a ReAct-style tool-calling loop.

        Each turn the model's output is scanned for an answer block
        (ANSWER_TAG ... ANSWER_END) and for tool-call blocks
        (TOOL_CALL_TAG ... TOOL_CALL_END). Tool calls are executed and their
        results fed back in one user message as TOOL_RESULT_TAG blocks. The
        loop ends when the model produces an answer block, produces no usable
        tool call (its whole text is then the answer), or max_iterations is
        reached. An answer block takes precedence over tool calls emitted in
        the same turn.

        Args:
            messages: Initial conversation; the caller's list is not modified.
            tools: Tool definitions embedded into the system prompt.
            tool_executor: Maps tool names to async callables invoked with the
                call's arguments as keyword arguments.
            system: Optional system prompt placed before the tool description.
            max_iterations: Maximum number of model turns.

        Returns:
            (final_text, all_messages)

        Raises:
            LLMAPIError: Propagated from chat().
        """
        # Build system prompt with tool definitions
        tools_prompt = _build_tools_prompt(tools)
        full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt

        messages = list(messages)  # don't mutate caller's list
        folded = False  # whether the first fold has already been attempted

        for _ in range(max_iterations):
            # ── Context compression before each API call ──────────────
            # Stage A: progressively decay old tool results
            messages = _apply_progressive_decay(messages)

            # Stage B: fold oldest messages into LLM summary if too long.
            # After the first fold the bar is raised by _FOLD_KEEP_RECENT so a
            # new fold only happens once the list has grown back significantly.
            if not folded and len(messages) > _FOLD_THRESHOLD:
                messages = await self._fold_old_messages(messages, full_system)
                folded = True
            elif folded and len(messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
                messages = await self._fold_old_messages(messages, full_system)

            text = await self.chat(messages, system=full_system)
            messages.append({"role": "assistant", "content": text})

            # Check for final answer
            answer = _extract_answer(text)
            if answer is not None:
                return answer, messages

            # Check for tool calls. Entries that are not JSON objects cannot
            # be dispatched (no name/arguments), so they are treated like
            # calls that failed to parse.
            tool_calls = [tc for tc in _extract_tool_calls(text) if isinstance(tc, dict)]
            if not tool_calls:
                # No tool calls and no answer tag — treat entire text as answer
                return text, messages

            # Execute tool calls — read-only tools run in parallel
            result_parts: list[str] = []
            t_batch_start = time.monotonic()
            for batch in _partition_tool_calls(tool_calls):
                if batch.is_read_only and len(batch.calls) > 1:
                    result_parts.extend(
                        await self._execute_tool_batch_parallel(
                            batch.calls, tool_executor, tools,
                        )
                    )
                else:
                    for tc in batch.calls:
                        result_parts.append(
                            await self._execute_single_tool(tc, tool_executor, tools)
                        )

            # Emit folded tool-call summary for the terminal
            _emit_tool_call_summary(tool_calls, time.monotonic() - t_batch_start)

            # Feed results back as a user message
            messages.append({"role": "user", "content": "\n\n".join(result_parts)})

        logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
        return "[Max tool call iterations reached]", messages

    @staticmethod
    def _format_tool_result(tool_name: str, result_text: str) -> str:
        """Wrap a (truncated) tool result in the tool-result markers."""
        return (
            f"{TOOL_RESULT_TAG}\n"
            f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
            f"{TOOL_RESULT_END}"
        )

    async def _execute_single_tool(
        self,
        tc: dict,
        tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
        *,
        parallel: bool = False,
    ) -> str:
        """Execute a single tool call and return the formatted result.

        Problems with the call are reported to the model as the tool result
        rather than raised: unknown tool, arguments that are not a JSON
        object, or an exception from the executor.

        Args:
            tc: Tool-call dict with ``name`` and ``arguments``.
            tool_executor: Maps tool names to async callables.
            tools: Tool definitions; when given, argument names are repaired
                with _fix_tool_args before the call.
            parallel: Only affects the log line, marking calls that run as
                part of a concurrent batch.

        Returns:
            The result wrapped in TOOL_RESULT_TAG / TOOL_RESULT_END.
        """
        tool_name = tc.get("name", "")
        tool_args = tc.get("arguments", {})
        if tool_args is None:  # "arguments": null — same as no arguments
            tool_args = {}
        if not isinstance(tool_args, dict):
            # Model output is untrusted; ** unpacking needs a mapping.
            return self._format_tool_result(
                tool_name,
                f"Error: arguments for '{tool_name}' must be a JSON object, "
                f"got {type(tool_args).__name__}",
            )
        if tools:
            tool_args = _fix_tool_args(tool_name, tool_args, tools)

        label = "Calling tool (parallel)" if parallel else "Calling tool"
        logger.info(
            "%s: %s(%s)", label, tool_name, json.dumps(tool_args, ensure_ascii=False),
        )

        executor = tool_executor.get(tool_name)
        if executor is None:
            result_text = f"Error: unknown tool '{tool_name}'"
        else:
            try:
                result = await executor(**tool_args)
            except Exception as e:
                # Boundary with arbitrary tool code: any failure becomes
                # feedback for the model instead of aborting the loop.
                logger.error("Tool %s failed: %s", tool_name, e)
                result_text = f"Error executing {tool_name}: {e}"
            else:
                # Truncation slices the text, so non-string results are
                # rendered to a string first.
                result_text = result if isinstance(result, str) else str(result)
        return self._format_tool_result(tool_name, result_text)

    async def _execute_tool_batch_parallel(
        self,
        calls: list[dict],
        tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
    ) -> list[str]:
        """Execute multiple read-only tool calls concurrently.

        Args:
            calls: Tool-call dicts to run.
            tool_executor: Maps tool names to async callables.
            tools: Tool definitions used for argument repair.

        Returns:
            Formatted results in the same order as ``calls``.
        """
        logger.info("Executing %d read-only tools in parallel", len(calls))
        results = await asyncio.gather(
            *(
                self._execute_single_tool(tc, tool_executor, tools, parallel=True)
                for tc in calls
            )
        )
        return list(results)

    async def _fold_old_messages(
        self,
        messages: list[dict],
        system: str,
    ) -> list[dict]:
        """Fold old messages into an LLM-generated summary (Stage B).

        Keeps the most recent _FOLD_KEEP_RECENT messages intact and replaces
        earlier ones with a single summary message. This is best-effort: if
        summarization fails or yields nothing, the original list is returned.

        Args:
            messages: Full conversation.
            system: Accepted for interface compatibility; not used by the
                summarization request, which has its own system prompt.

        Returns:
            ``[summary_message] + recent_messages`` on success, otherwise
            ``messages`` unchanged.
        """
        n_to_fold = len(messages) - _FOLD_KEEP_RECENT
        if n_to_fold <= 2:
            return messages

        old_messages = messages[:n_to_fold]
        recent_messages = messages[n_to_fold:]

        # Build a text dump of old messages for summarization
        old_text_parts = []
        for msg in old_messages:
            content = msg.get("content") or ""
            if not isinstance(content, str):
                content = str(content)
            # Truncate each message for the summary prompt to avoid overload
            if len(content) > 1000:
                content = content[:1000] + "..."
            old_text_parts.append(f"[{msg['role']}]: {content}")

        old_text = "\n\n".join(old_text_parts)
        # Cap total size sent to summarizer
        if len(old_text) > 15000:
            old_text = old_text[:15000] + "\n\n... [further messages omitted for brevity]"

        logger.info(
            "Context folding: summarizing %d old messages (%d chars) into summary",
            n_to_fold, len(old_text),
        )

        try:
            summary = await self.chat(
                messages=[{"role": "user", "content": old_text}],
                system=_FOLD_SUMMARY_SYSTEM,
            )
        except Exception as e:
            logger.warning("Context folding failed: %s — keeping original messages", e)
            return messages

        # An empty summary would discard the history and put nothing in its
        # place; keeping the original messages is the safer outcome.
        if not summary.strip():
            logger.warning("Context folding returned an empty summary — keeping original messages")
            return messages

        # Replace old messages with a single summary
        summary_message = {
            "role": "user",
            "content": (
                f"[Context summary — the following summarizes {n_to_fold} earlier "
                f"messages in this conversation]\n\n{summary}"
            ),
        }
        return [summary_message] + recent_messages