"""Custom LLM client using httpx for Claude Messages API via third-party proxy.

The proxy does not support Claude's native tool_use format (it strips the
`tools` field from requests). So we embed tool definitions in the system
prompt and parse structured JSON tool calls from the model's text output
(ReAct-style).
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import re
import time
from collections import Counter
from dataclasses import dataclass, field
from typing import Any

import httpx

logger = logging.getLogger(__name__)


class LLMAPIError(Exception):
    """Raised when the LLM API is unreachable after all retries.

    Attributes:
        attempts: Number of attempts made before giving up.
    """

    def __init__(self, message: str, attempts: int) -> None:
        super().__init__(message)
        self.attempts = attempts


# Markers the model uses to signal tool calls and final answers.
# NOTE(review): these must be non-empty. With "" every extraction regex below
# matches the empty string, so `_extract_answer` returns "" for any reply.
# The XML-style values match the prompt text ("wrapped in XML tags") — confirm
# against any other code that produces or parses these markers.
TOOL_CALL_TAG = "<tool_call>"
TOOL_CALL_END = "</tool_call>"
TOOL_RESULT_TAG = "<tool_result>"
TOOL_RESULT_END = "</tool_result>"
ANSWER_TAG = "<answer>"
ANSWER_END = "</answer>"

# Compiled once; both are applied to every model reply.
_TOOL_CALL_RE = re.compile(
    re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END),
    re.DOTALL,
)
_ANSWER_RE = re.compile(
    re.escape(ANSWER_TAG) + r"\s*(.*?)\s*" + re.escape(ANSWER_END),
    re.DOTALL,
)


def _build_tools_prompt(tools: list[dict]) -> str:
    """Format tool definitions for inclusion in the system prompt.

    Each tool dict is read for ``name``, ``description`` and ``input_schema``
    (``properties``, ``required`` and per-property ``type``/``description``/
    ``enum``). The text ends with instructions describing the tag protocol.
    """
    lines = ["You have access to the following tools:\n"]
    for t in tools:
        schema = t.get("input_schema", {})
        props = schema.get("properties", {})
        required = schema.get("required", [])
        params = []
        for pname, pdef in props.items():
            req = " (required)" if pname in required else ""
            desc = pdef.get("description", "")
            ptype = pdef.get("type", "string")
            enum_vals = pdef.get("enum")
            if enum_vals:
                allowed = ", ".join(f'"{v}"' for v in enum_vals)
                params.append(f" - {pname}: {ptype}{req} — {desc} Allowed values: [{allowed}]")
            else:
                params.append(f" - {pname}: {ptype}{req} — {desc}")
        param_block = "\n".join(params) if params else " (no parameters)"
        # A tool without a description should not abort prompt construction.
        lines.append(
            f"## {t['name']}\n{t.get('description', '')}\nParameters:\n{param_block}\n"
        )
    lines.append(
        "## How to use tools\n"
        "To call a tool, output a JSON block wrapped in XML tags like this:\n"
        f"{TOOL_CALL_TAG}\n"
        '{"name": "tool_name", "arguments": {"param1": "value1"}}\n'
        f"{TOOL_CALL_END}\n\n"
        "You can call multiple tools in sequence. After each tool call, you will receive the result in:\n"
        f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n"
        "When you have finished your analysis and have a final answer, wrap it in:\n"
        f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n"
        "Think step by step. Call tools to gather evidence before drawing conclusions.\n"
        "You MUST call at least one tool before giving your final answer."
    )
    return "\n".join(lines)


def _extract_tool_calls(text: str) -> list[dict]:
    """Extract tool call JSON blocks from model output.

    Returns the parsed JSON objects in order of appearance. Blocks that are not
    valid JSON, or that decode to something other than an object, are logged
    and skipped, so every returned item is a ``dict``.
    """
    calls: list[dict] = []
    for match in _TOOL_CALL_RE.finditer(text):
        raw = match.group(1).strip()
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("Failed to parse tool call JSON: %s", raw[:200])
            continue
        if isinstance(parsed, dict):
            calls.append(parsed)
        else:
            logger.warning("Ignoring tool call that is not a JSON object: %s", raw[:200])
    return calls


def _extract_answer(text: str) -> str | None:
    """Return the stripped text of the first ANSWER_TAG/ANSWER_END block, or None."""
    match = _ANSWER_RE.search(text)
    return match.group(1).strip() if match else None


def _truncate_tool_result(result_text: str, max_chars: int = 3000) -> str:
    """Truncate a tool result if it exceeds max_chars.

    Over-long text keeps its first ``max_chars - 200`` characters (room for
    the marker) followed by a marker giving the original length. Non-string
    results are converted with ``str()`` first.
    """
    if not isinstance(result_text, str):
        result_text = str(result_text)
    if len(result_text) <= max_chars:
        return result_text
    # Clamp at 0: for max_chars < 200 a negative slice bound would only drop
    # the last few characters instead of truncating.
    keep = max(max_chars - 200, 0)
    return result_text[:keep] + f"\n... [truncated, {len(result_text)} total chars]"


# Tools that only read and never mutate state — safe to run concurrently.
READ_ONLY_TOOLS: set[str] = {
    # Graph queries
    "list_phenomena",
    "get_phenomenon",
    "search_graph",
    "get_related",
    "get_hypothesis_status",
    "list_assets",
    "find_extracted_file",
    # Sleuth Kit reads
    "partition_info",
    "filesystem_info",
    "list_directory",
    "find_file",
    "search_strings",
    "count_deleted_files",
    "build_filesystem_timeline",
    # Registry reads (without auto-record wrappers)
    "parse_registry_key",
    "search_registry",
    "get_user_activity",
    # Parser reads
    "read_text_file",
    "read_binary_preview",
    "search_text_file",
    "read_text_file_section",
    "list_extracted_dir",
    "parse_pcap_strings",
}


def _fix_tool_args(tool_name: str, tool_args: dict, tools: list[dict]) -> dict:
    """Try to fix misnamed tool arguments from LLM hallucination.

    The LLM sometimes confuses parameter names across tools (e.g. passing
    `key_path` to search_registry which expects `pattern`). Unknown arguments
    are mapped, in the order given, onto the schema's still-missing
    parameters in schema order. The candidates are the required parameters
    when the schema lists any, otherwise all properties. Unknown arguments
    left over once no missing parameter remains are dropped with a warning.

    Returns:
        ``tool_args`` itself when the tool has no schema in *tools* or every
        name is valid; otherwise a new dict. ``None`` becomes ``{}``; any
        other non-dict value is replaced by ``{}`` with a warning.
    """
    if tool_args is None:
        return {}  # e.g. `"arguments": null` for a parameterless call
    if not isinstance(tool_args, dict):
        logger.warning("Discarding non-object arguments for %s: %r", tool_name, tool_args)
        return {}

    schema = next(
        (t.get("input_schema", {}) for t in tools if t.get("name") == tool_name),
        None,
    )
    if schema is None:
        return tool_args

    props = schema.get("properties", {})
    unknown = [k for k in tool_args if k not in props]
    if not unknown:
        return tool_args  # all args are valid, nothing to fix

    fixed = {k: v for k, v in tool_args.items() if k in props}
    # Keep the schema's list order. Iterating a set here would make the
    # mapping depend on per-process string-hash randomisation.
    candidates = schema.get("required") or list(props)
    missing = [p for p in candidates if p not in fixed]

    for wrong_name in unknown:
        if not missing:
            logger.warning("Dropping unknown tool arg: %s(%s=...)", tool_name, wrong_name)
            continue
        best = missing.pop(0)
        logger.warning(
            "Auto-fixing tool arg: %s(%s=...) -> %s(%s=...)",
            tool_name, wrong_name, tool_name, best,
        )
        fixed[best] = tool_args[wrong_name]
    return fixed


def _emit_tool_call_summary(tool_calls: list[dict], elapsed: float) -> None:
    """Emit a folded tool-call summary line for the terminal formatter.

    Instead of logging each tool call individually, we group by name, most
    frequent first: "list_directory x27, extract_file x3, read_text_file x3".
    The record carries ``event="tool_calls"`` and ``elapsed`` as extras.
    """
    counts = Counter(tc.get("name", "?") for tc in tool_calls)
    summary = ", ".join(
        f"{name} x{count}" if count > 1 else f"{name}"
        for name, count in counts.most_common()
    )
    logger.info(summary, extra={"event": "tool_calls", "elapsed": elapsed})


@dataclass
class _ToolBatch:
    """A batch of tool calls with the same read/write classification."""

    is_read_only: bool
    calls: list[dict] = field(default_factory=list)


def _partition_tool_calls(
    tool_calls: list[dict],
    read_only: set[str] | None = None,
) -> list[_ToolBatch]:
    """Partition tool calls into order-preserving batches.

    Consecutive read-only tools (names in *read_only*, default
    READ_ONLY_TOOLS) are grouped together so they can run in parallel; every
    other call is isolated in its own batch.
    """
    if read_only is None:
        read_only = READ_ONLY_TOOLS
    batches: list[_ToolBatch] = []
    for tc in tool_calls:
        is_ro = tc.get("name", "") in read_only
        if batches and batches[-1].is_read_only and is_ro:
            batches[-1].calls.append(tc)
        else:
            batches.append(_ToolBatch(is_read_only=is_ro, calls=[tc]))
    return batches


# ---------------------------------------------------------------------------
# Context compression — keeps the message list from growing unboundedly.
# ---------------------------------------------------------------------------

# Stage A: Progressive tool result decay thresholds.
# Messages are counted in (assistant, user) pairs from the END of the list.
# "Round" = one pair of (assistant tool-calling msg, user tool-result msg).
_DECAY_TIERS: list[tuple[int, int]] = [
    # (rounds_ago_threshold, max_chars_per_tool_result)
    (5, 3000),  # recent 5 rounds: keep full (3000 chars per tool result)
    (15, 500),  # 5-15 rounds ago: aggressive truncation
    (999, 100),  # older than 15 rounds: minimal stub
]

# Matches one tool-result block; group 1 is the body between the tags.
_TOOL_RESULT_BLOCK_RE = re.compile(
    re.escape(TOOL_RESULT_TAG) + r"(.*?)" + re.escape(TOOL_RESULT_END),
    re.DOTALL,
)


def _decay_limit(rounds_ago: int) -> int:
    """Return the per-tool-result character limit for a message *rounds_ago* rounds back."""
    for threshold, max_chars in _DECAY_TIERS:
        if rounds_ago < threshold:
            return max_chars
    # Past the last threshold: keep applying the smallest tier instead of
    # leaving the oldest messages at full size.
    return _DECAY_TIERS[-1][1]


def _compress_tool_results(content: str, max_chars: int) -> str:
    """Cut each tool-result body in *content* to *max_chars*, keeping its tags.

    Bodies no longer than ``max_chars + 200`` are left alone. This also stops
    a body already compressed at this limit from being compressed again on
    the next pass.
    """

    def _shrink(match: re.Match[str]) -> str:
        body = match.group(1)
        if len(body) <= max_chars + 200:
            return match.group(0)
        return (
            f"{TOOL_RESULT_TAG}{body[:max_chars]}\n"
            f"... [context compressed: {len(body)} -> {max_chars} chars]\n"
            f"{TOOL_RESULT_END}"
        )

    return _TOOL_RESULT_BLOCK_RE.sub(_shrink, content)


def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
    """Truncate tool results in older messages to save context space.

    Returns *messages* itself when it has 10 or fewer entries, otherwise a new
    list. Input message dicts are never modified; changed messages are
    replaced by new ``{"role", "content"}`` dicts.

    Only user messages whose string content contains TOOL_RESULT_TAG are
    touched; these are the tool-result messages generated by tool_call_loop.
    The n-th most recent such message is n "rounds ago". Each tool result
    inside it is cut to the limit of the matching _DECAY_TIERS entry, so the
    budget applies per tool result, not per message.
    """
    if len(messages) <= 10:  # not enough messages to bother
        return messages

    tool_result_indices = [
        i
        for i, m in enumerate(messages)
        if m["role"] == "user"
        and isinstance(m.get("content"), str)
        and TOOL_RESULT_TAG in m["content"]
    ]
    # rank 0 = most recent tool-result message, 1 = the one before, ...
    limits = {
        idx: _decay_limit(rank)
        for rank, idx in enumerate(reversed(tool_result_indices))
    }

    result = []
    for i, msg in enumerate(messages):
        max_chars = limits.get(i)
        if max_chars is None:
            result.append(msg)
            continue
        compressed = _compress_tool_results(msg["content"], max_chars)
        if compressed == msg["content"]:
            result.append(msg)
        else:
            result.append({"role": msg["role"], "content": compressed})
    return result


# Stage B: LLM-powered message folding.
# When messages exceed this count, fold the oldest ones into a summary.
_FOLD_THRESHOLD = 24  # trigger folding when messages exceed this count
_FOLD_KEEP_RECENT = 10  # always keep the most recent N messages intact
_FOLD_SUMMARY_SYSTEM = (
    "You are a concise summarizer for an ongoing forensic investigation conversation. "
    "Summarize the following early conversation between a forensic analysis agent and its "
    "tool execution results. Preserve:\n"
    "- Key findings and evidence discovered (file paths, inode numbers, timestamps)\n"
    "- Tools that were called and their important results\n"
    "- Decisions made and current investigation direction\n"
    "Keep the summary under 800 words. Use bullet points."
)


class LLMClient:
    """Calls Claude Messages API through a third-party proxy using raw httpx.

    Uses prompt-based tool calling (ReAct pattern) since the proxy does not
    support Claude's native tool_use format.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "claude-sonnet-4-6",
        max_tokens: int = 4096,
        proxy: str | None = "auto",
    ) -> None:
        """Create the client and its underlying ``httpx.AsyncClient``.

        Args:
            base_url: Proxy base URL; a trailing slash is stripped.
            api_key: Sent as the ``x-api-key`` header.
            model: Model name put in every request payload.
            max_tokens: ``max_tokens`` put in every request payload.
            proxy: ``"auto"`` reads ``https_proxy``/``HTTPS_PROXY`` from the
                environment. ``None``, ``""`` or ``"none"`` disables proxying.
                Any other string is used as the proxy URL.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens

        trust_env = True
        if proxy == "auto":
            proxy_url = os.environ.get("https_proxy") or os.environ.get("HTTPS_PROXY")
        elif proxy and proxy.lower() != "none":
            proxy_url = proxy
        else:
            proxy_url = None
            # httpx falls back to proxy environment variables whenever `proxy`
            # is None and `trust_env` is true, so "no proxy" must also stop it
            # from reading the environment. httpx then ignores its other
            # environment-based settings as well.
            trust_env = False

        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            headers={
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01",
                "content-type": "application/json",
            },
            timeout=300.0,
            proxy=proxy_url,
            trust_env=trust_env,
        )

    async def close(self) -> None:
        """Close the underlying httpx client."""
        await self._client.aclose()

    async def chat(
        self,
        messages: list[dict],
        system: str | None = None,
        max_retries: int = 5,
    ) -> str:
        """Send a streaming chat request and return the assembled text response.

        Uses SSE streaming to keep the connection alive and avoid gateway
        timeouts (504/524) on long-running completions.

        Args:
            messages: Conversation as ``{"role", "content"}`` dicts.
            system: Optional system prompt.
            max_retries: Total number of attempts; values below 1 count as 1.
                Waits 10s, 20s, 40s, ... between attempts.

        Returns:
            The concatenated ``text_delta`` text of the response.

        Raises:
            LLMAPIError: if every attempt failed with an HTTP error status, a
                streamed ``error`` event, or a transport error/timeout.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "messages": messages,
            "stream": True,
        }
        if system:
            payload["system"] = system

        attempts = max(1, max_retries)  # always make at least one request
        last_error: Exception | None = None
        for attempt in range(attempts):
            if last_error is not None:
                wait = 2 ** (attempt - 1) * 10
                logger.warning("Request failed (%s), retrying in %ds...", last_error, wait)
                await asyncio.sleep(wait)
            logger.debug(
                "LLM request (stream): %d messages (attempt %d)", len(messages), attempt + 1
            )
            try:
                text = await self._stream_completion(payload)
            except (
                httpx.HTTPStatusError,
                httpx.TimeoutException,  # connect/read/write/pool timeouts
                httpx.NetworkError,  # connect/read/write errors, e.g. a dropped stream
                httpx.RemoteProtocolError,
            ) as e:
                last_error = e
                continue
            logger.debug("LLM response (stream): %d chars", len(text))
            return text

        raise LLMAPIError(
            f"LLM API unreachable after {attempts} attempts: {last_error}",
            attempts=attempts,
        ) from last_error

    async def _stream_completion(self, payload: dict[str, Any]) -> str:
        """Make one streaming POST to /v1/messages and return the assembled text.

        Raises:
            httpx.HTTPStatusError: for an HTTP status >= 400, whose message
                includes up to 500 bytes of the response body, or for a
                streamed ``error`` event.
            Transport exceptions from httpx propagate unchanged.
        """
        text_parts: list[str] = []
        async with self._client.stream("POST", "/v1/messages", json=payload) as resp:
            # Check for HTTP errors before consuming the stream.
            if resp.status_code >= 400:
                body = await resp.aread()
                detail = body[:500].decode("utf-8", errors="replace")
                raise httpx.HTTPStatusError(
                    f"Server error '{resp.status_code}' for url '{resp.url}': {detail}",
                    request=resp.request,
                    response=resp,
                )
            # Parse SSE events.
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]  # strip "data: " prefix
                if data_str.strip() == "[DONE]":
                    break
                try:
                    event = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                if not isinstance(event, dict):
                    continue
                event_type = event.get("type", "")
                if event_type == "content_block_delta":
                    delta = event.get("delta", {})
                    if delta.get("type") == "text_delta":
                        text_parts.append(delta["text"])
                elif event_type == "message_stop":
                    break
                elif event_type == "error":
                    err_msg = event.get("error", {}).get("message", "Unknown streaming error")
                    raise httpx.HTTPStatusError(
                        err_msg,
                        request=resp.request,
                        response=resp,
                    )
        return "".join(text_parts)

    async def tool_call_loop(
        self,
        messages: list[dict],
        tools: list[dict],
        tool_executor: dict[str, Any],
        system: str | None = None,
        max_iterations: int = 40,
    ) -> tuple[str, list[dict]]:
        """Run a ReAct-style tool-calling loop.

        The model writes tool calls as JSON between TOOL_CALL_TAG and
        TOOL_CALL_END. We execute them and feed the results back between
        TOOL_RESULT_TAG and TOOL_RESULT_END. The loop stops when the model
        wraps a final answer in ANSWER_TAG/ANSWER_END.

        A reply with neither an answer nor a tool-call tag is returned as the
        final answer. A reply with a tool-call tag but no parseable call gets
        an error result instead, so the model can retry.

        Args:
            messages: Initial conversation; copied, never mutated.
            tools: Tool definitions rendered into the system prompt.
            tool_executor: Maps tool names to async callables.
            system: Optional system prompt placed before the tool instructions.
            max_iterations: Maximum loop iterations (one main model call each).

        Returns:
            (final_text, all_messages)
        """
        tools_prompt = _build_tools_prompt(tools)
        full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt
        messages = list(messages)  # don't mutate caller's list
        folded = False  # whether we've already folded once in this loop

        for _ in range(max_iterations):
            # ── Context compression before each API call ──────────────
            # Stage A: progressively decay old tool results.
            messages = _apply_progressive_decay(messages)
            # Stage B: fold the oldest messages into an LLM summary if too
            # long. Later folds wait until the list has grown back well past
            # the threshold.
            if not folded and len(messages) > _FOLD_THRESHOLD:
                messages = await self._fold_old_messages(messages, full_system)
                folded = True
            elif folded and len(messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
                messages = await self._fold_old_messages(messages, full_system)

            text = await self.chat(messages, system=full_system)
            messages.append({"role": "assistant", "content": text})

            answer = _extract_answer(text)
            if answer is not None:
                return answer, messages

            # Only JSON objects can be dispatched as tool calls.
            tool_calls = [tc for tc in _extract_tool_calls(text) if isinstance(tc, dict)]
            if not tool_calls:
                if TOOL_CALL_TAG and TOOL_CALL_TAG in text:
                    # The model tried to call a tool but nothing parsed (bad
                    # JSON, missing end tag, ...). Returning this text as the
                    # answer would end the run early, so ask for a retry.
                    logger.warning("Model emitted an unparseable tool call; requesting a retry")
                    messages.append({
                        "role": "user",
                        "content": (
                            f"{TOOL_RESULT_TAG}\n"
                            "[error] Could not parse your tool call. Re-issue it as one JSON "
                            'object such as {"name": "tool_name", "arguments": {...}} between '
                            f"{TOOL_CALL_TAG} and {TOOL_CALL_END}.\n"
                            f"{TOOL_RESULT_END}"
                        ),
                    })
                    continue
                # No tool calls and no answer tag — treat entire text as answer.
                return text, messages

            result_parts = await self._run_tool_calls(tool_calls, tool_executor, tools)
            # Feed results back as a user message.
            messages.append({"role": "user", "content": "\n\n".join(result_parts)})

        logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
        return "[Max tool call iterations reached]", messages

    async def _run_tool_calls(
        self,
        tool_calls: list[dict],
        tool_executor: dict[str, Any],
        tools: list[dict],
    ) -> list[str]:
        """Execute one round of tool calls, keeping the model's call order.

        Batches of two or more consecutive read-only calls run concurrently.
        The folded summary line is then emitted, and the formatted
        tool-result strings are returned in call order.
        """
        result_parts: list[str] = []
        started = time.monotonic()
        for batch in _partition_tool_calls(tool_calls):
            if batch.is_read_only and len(batch.calls) > 1:
                result_parts.extend(
                    await self._execute_tool_batch_parallel(batch.calls, tool_executor, tools)
                )
            else:
                for tc in batch.calls:
                    result_parts.append(
                        await self._execute_single_tool(tc, tool_executor, tools)
                    )
        # Emit folded tool-call summary for the terminal.
        _emit_tool_call_summary(tool_calls, time.monotonic() - started)
        return result_parts

    async def _execute_single_tool(
        self,
        tc: dict,
        tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
        *,
        parallel: bool = False,
    ) -> str:
        """Execute one tool call and return its result wrapped in tool-result tags.

        Tool-level problems do not raise. An unknown tool, non-object
        arguments or an exception from the executor each become an error
        string that is fed back to the model.

        Args:
            tc: Parsed tool call; its ``name`` and ``arguments`` keys are read.
            tool_executor: Maps tool names to async callables taking keyword
                arguments.
            tools: Tool definitions. When given, misnamed arguments are
                repaired with ``_fix_tool_args`` before the call.
            parallel: Only changes the log line, marking calls from a parallel
                batch.
        """
        tool_name = tc.get("name", "")
        tool_args = tc.get("arguments")
        if tool_args is None:
            tool_args = {}  # omitted or `null` arguments: call with none

        if not isinstance(tool_args, dict):
            logger.warning("Tool %s called with non-object arguments: %r", tool_name, tool_args)
            result_text = (
                f"Error executing {tool_name}: arguments must be a JSON object, "
                f"got {type(tool_args).__name__}"
            )
        else:
            if tools:
                tool_args = _fix_tool_args(tool_name, tool_args, tools)
            log_fmt = "Calling tool (parallel): %s(%s)" if parallel else "Calling tool: %s(%s)"
            logger.info(log_fmt, tool_name, json.dumps(tool_args, ensure_ascii=False))

            executor = tool_executor.get(tool_name)
            if executor is None:
                result_text = f"Error: unknown tool '{tool_name}'"
            else:
                try:
                    result = await executor(**tool_args)
                except Exception as e:
                    logger.error("Tool %s failed: %s", tool_name, e)
                    result_text = f"Error executing {tool_name}: {e}"
                else:
                    # The prompt protocol is text; executors may return other types.
                    result_text = result if isinstance(result, str) else str(result)

        return (
            f"{TOOL_RESULT_TAG}\n"
            f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
            f"{TOOL_RESULT_END}"
        )

    async def _execute_tool_batch_parallel(
        self,
        calls: list[dict],
        tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
    ) -> list[str]:
        """Execute multiple read-only tool calls concurrently.

        Results are returned in the same order as *calls*, since
        ``asyncio.gather`` preserves argument order.
        """
        logger.info("Executing %d read-only tools in parallel", len(calls))
        results = await asyncio.gather(
            *(
                self._execute_single_tool(tc, tool_executor, tools, parallel=True)
                for tc in calls
            )
        )
        return list(results)

    async def _fold_old_messages(
        self,
        messages: list[dict],
        system: str,
    ) -> list[dict]:
        """Fold old messages into an LLM-generated summary (Stage B).

        Keeps about the most recent _FOLD_KEEP_RECENT messages intact and
        replaces the earlier ones with a single user-role summary message.
        Returns *messages* unchanged when there are too few to fold or the
        summarization request fails.

        ``system`` is not used; the summarizer always runs with
        _FOLD_SUMMARY_SYSTEM.
        """
        n_to_fold = len(messages) - _FOLD_KEEP_RECENT
        if n_to_fold <= 2:
            return messages
        # The summary is a user message, so the kept tail should start with an
        # assistant turn. Otherwise fold one more message to keep roles alternating.
        if n_to_fold < len(messages) and messages[n_to_fold].get("role") == "user":
            n_to_fold += 1

        old_messages = messages[:n_to_fold]
        recent_messages = messages[n_to_fold:]

        # Build a text dump of old messages for summarization.
        old_text_parts = []
        for msg in old_messages:
            content = msg.get("content", "")
            if not isinstance(content, str):
                content = str(content)
            # Truncate each message for the summary prompt to avoid overload.
            if len(content) > 1000:
                content = content[:1000] + "..."
            old_text_parts.append(f"[{msg['role']}]: {content}")
        old_text = "\n\n".join(old_text_parts)

        # Cap total size sent to summarizer.
        if len(old_text) > 15000:
            old_text = old_text[:15000] + "\n\n... [further messages omitted for brevity]"

        logger.info(
            "Context folding: summarizing %d old messages (%d chars) into summary",
            n_to_fold,
            len(old_text),
        )

        try:
            summary = await self.chat(
                messages=[{"role": "user", "content": old_text}],
                system=_FOLD_SUMMARY_SYSTEM,
            )
        except Exception as e:
            logger.warning("Context folding failed: %s — keeping original messages", e)
            return messages

        # Replace old messages with a single summary.
        summary_message = {
            "role": "user",
            "content": (
                f"[Context summary — the following summarizes {n_to_fold} earlier "
                f"messages in this conversation]\n\n{summary}"
            ),
        }
        return [summary_message, *recent_messages]