Initial commit
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
619
llm_client.py
Normal file
619
llm_client.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""Custom LLM client using httpx for Claude Messages API via third-party proxy.
|
||||
|
||||
The proxy does not support Claude's native tool_use format (it strips the `tools`
|
||||
field from requests). So we embed tool definitions in the system prompt and parse
|
||||
structured JSON tool calls from the model's text output (ReAct-style).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMAPIError(Exception):
|
||||
"""Raised when the LLM API is unreachable after all retries."""
|
||||
|
||||
def __init__(self, message: str, attempts: int) -> None:
|
||||
super().__init__(message)
|
||||
self.attempts = attempts
|
||||
|
||||
|
||||
# Markers the model uses to signal tool calls and final answers
|
||||
TOOL_CALL_TAG = "<tool_call>"
|
||||
TOOL_CALL_END = "</tool_call>"
|
||||
TOOL_RESULT_TAG = "<tool_result>"
|
||||
TOOL_RESULT_END = "</tool_result>"
|
||||
ANSWER_TAG = "<answer>"
|
||||
ANSWER_END = "</answer>"
|
||||
|
||||
|
||||
def _build_tools_prompt(tools: list[dict]) -> str:
|
||||
"""Format tool definitions for inclusion in the system prompt."""
|
||||
lines = ["You have access to the following tools:\n"]
|
||||
for t in tools:
|
||||
schema = t.get("input_schema", {})
|
||||
props = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
params = []
|
||||
for pname, pdef in props.items():
|
||||
req = " (required)" if pname in required else ""
|
||||
desc = pdef.get("description", "")
|
||||
ptype = pdef.get("type", "string")
|
||||
enum_vals = pdef.get("enum")
|
||||
if enum_vals:
|
||||
allowed = ", ".join(f'"{v}"' for v in enum_vals)
|
||||
params.append(f" - {pname}: {ptype}{req} — {desc} Allowed values: [{allowed}]")
|
||||
else:
|
||||
params.append(f" - {pname}: {ptype}{req} — {desc}")
|
||||
|
||||
param_block = "\n".join(params) if params else " (no parameters)"
|
||||
lines.append(f"## {t['name']}\n{t['description']}\nParameters:\n{param_block}\n")
|
||||
|
||||
lines.append(
|
||||
"## How to use tools\n"
|
||||
"To call a tool, output a JSON block wrapped in XML tags like this:\n"
|
||||
f"{TOOL_CALL_TAG}\n"
|
||||
'{"name": "tool_name", "arguments": {"param1": "value1"}}\n'
|
||||
f"{TOOL_CALL_END}\n\n"
|
||||
"You can call multiple tools in sequence. After each tool call, you will receive the result in:\n"
|
||||
f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n"
|
||||
"When you have finished your analysis and have a final answer, wrap it in:\n"
|
||||
f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n"
|
||||
"Think step by step. Call tools to gather evidence before drawing conclusions.\n"
|
||||
"You MUST call at least one tool before giving your final answer."
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_tool_calls(text: str) -> list[dict]:
|
||||
"""Extract tool call JSON blocks from model output."""
|
||||
pattern = re.compile(
|
||||
re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END),
|
||||
re.DOTALL,
|
||||
)
|
||||
calls = []
|
||||
for match in pattern.finditer(text):
|
||||
raw = match.group(1).strip()
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
calls.append(parsed)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse tool call JSON: %s", raw[:200])
|
||||
return calls
|
||||
|
||||
|
||||
def _extract_answer(text: str) -> str | None:
|
||||
"""Extract the final answer from model output."""
|
||||
pattern = re.compile(
|
||||
re.escape(ANSWER_TAG) + r"\s*(.*?)\s*" + re.escape(ANSWER_END),
|
||||
re.DOTALL,
|
||||
)
|
||||
match = pattern.search(text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return None
|
||||
|
||||
|
||||
def _truncate_tool_result(result_text: str, max_chars: int = 3000) -> str:
|
||||
"""Truncate a tool result if it exceeds max_chars."""
|
||||
if len(result_text) <= max_chars:
|
||||
return result_text
|
||||
return result_text[: max_chars - 200] + f"\n... [truncated, {len(result_text)} total chars]"
|
||||
|
||||
|
||||
# Tools that only read and never mutate state — safe to run concurrently.
|
||||
READ_ONLY_TOOLS: set[str] = {
|
||||
# Graph queries
|
||||
"list_phenomena", "get_phenomenon", "search_graph", "get_related",
|
||||
"get_hypothesis_status", "list_assets", "find_extracted_file",
|
||||
# Sleuth Kit reads
|
||||
"partition_info", "filesystem_info", "list_directory", "find_file",
|
||||
"search_strings", "count_deleted_files", "build_filesystem_timeline",
|
||||
# Registry reads (without auto-record wrappers)
|
||||
"parse_registry_key", "search_registry", "get_user_activity",
|
||||
# Parser reads
|
||||
"read_text_file", "read_binary_preview", "search_text_file",
|
||||
"read_text_file_section", "list_extracted_dir", "parse_pcap_strings",
|
||||
}
|
||||
|
||||
|
||||
def _fix_tool_args(tool_name: str, tool_args: dict, tools: list[dict]) -> dict:
|
||||
"""Try to fix misnamed tool arguments from LLM hallucination.
|
||||
|
||||
The LLM sometimes confuses parameter names across tools (e.g. passing
|
||||
`key_path` to search_registry which expects `pattern`). This function
|
||||
maps unknown kwargs to missing expected params by position/best-effort.
|
||||
"""
|
||||
# Find the schema for this tool
|
||||
schema = None
|
||||
for t in tools:
|
||||
if t.get("name") == tool_name:
|
||||
schema = t.get("input_schema", {})
|
||||
break
|
||||
if schema is None:
|
||||
return tool_args
|
||||
|
||||
props = schema.get("properties", {})
|
||||
required = set(schema.get("required", []))
|
||||
|
||||
unknown = [k for k in tool_args if k not in props]
|
||||
if not unknown:
|
||||
return tool_args # all args are valid, nothing to fix
|
||||
|
||||
# Build the fixed args: start with valid args
|
||||
fixed = {k: v for k, v in tool_args.items() if k in props}
|
||||
|
||||
# Find which expected params are still missing
|
||||
missing = [p for p in (required or props.keys()) if p not in fixed]
|
||||
|
||||
# Try to map unknown args to missing params, in order
|
||||
unknown_values = [(k, tool_args[k]) for k in unknown]
|
||||
|
||||
for wrong_name, value in unknown_values:
|
||||
if not missing:
|
||||
break
|
||||
# Pick the best match from missing params
|
||||
best = missing.pop(0)
|
||||
logger.warning(
|
||||
"Auto-fixing tool arg: %s(%s=...) -> %s(%s=...)",
|
||||
tool_name, wrong_name, tool_name, best,
|
||||
)
|
||||
fixed[best] = value
|
||||
|
||||
return fixed
|
||||
|
||||
|
||||
def _emit_tool_call_summary(tool_calls: list[dict], elapsed: float) -> None:
|
||||
"""Emit a folded tool-call summary line for the terminal formatter.
|
||||
|
||||
Instead of logging each tool call individually, we group by name:
|
||||
"list_directory x27, extract_file x3, read_text_file x3"
|
||||
"""
|
||||
counts = Counter(tc.get("name", "?") for tc in tool_calls)
|
||||
parts = []
|
||||
for name, count in counts.most_common():
|
||||
if count > 1:
|
||||
parts.append(f"{name} x{count}")
|
||||
else:
|
||||
parts.append(name)
|
||||
summary = ", ".join(parts)
|
||||
logger.info(summary, extra={"event": "tool_calls", "elapsed": elapsed})
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ToolBatch:
|
||||
"""A batch of tool calls with the same read/write classification."""
|
||||
is_read_only: bool
|
||||
calls: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
def _partition_tool_calls(
|
||||
tool_calls: list[dict],
|
||||
read_only: set[str] | None = None,
|
||||
) -> list[_ToolBatch]:
|
||||
"""Partition tool calls into batches: consecutive read-only tools are
|
||||
grouped together (will run in parallel), write tools are isolated."""
|
||||
if read_only is None:
|
||||
read_only = READ_ONLY_TOOLS
|
||||
batches: list[_ToolBatch] = []
|
||||
for tc in tool_calls:
|
||||
is_ro = tc.get("name", "") in read_only
|
||||
if batches and batches[-1].is_read_only and is_ro:
|
||||
batches[-1].calls.append(tc)
|
||||
else:
|
||||
batches.append(_ToolBatch(is_read_only=is_ro, calls=[tc]))
|
||||
return batches
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context compression — keeps the message list from growing unboundedly.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Stage A: Progressive tool result decay thresholds.
|
||||
# Messages are counted in (assistant, user) pairs from the END of the list.
|
||||
# "Round" = one pair of (assistant tool-calling msg, user tool-result msg).
|
||||
_DECAY_TIERS: list[tuple[int, int]] = [
|
||||
# (rounds_ago_threshold, max_chars_for_tool_results)
|
||||
(5, 3000), # recent 5 rounds: keep full (3000 chars per tool result)
|
||||
(15, 500), # 5-15 rounds ago: aggressive truncation
|
||||
(999, 100), # older than 15 rounds: minimal stub
|
||||
]
|
||||
|
||||
|
||||
def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
|
||||
"""Truncate tool results in older messages to save context space.
|
||||
|
||||
Operates in-place-style on a copy. Only touches user messages that
|
||||
contain <tool_result> blocks (these are the tool-result messages
|
||||
generated by tool_call_loop).
|
||||
"""
|
||||
# Count rounds from the end. A "round" is a (assistant, user) pair.
|
||||
# messages alternate: [user, assistant, user, assistant, user, ...]
|
||||
# The initial user message is index 0, then pairs start at index 1.
|
||||
total = len(messages)
|
||||
if total <= 10: # not enough messages to bother
|
||||
return messages
|
||||
|
||||
result = []
|
||||
# Count tool-result user messages from the end
|
||||
tool_result_indices = [
|
||||
i for i, m in enumerate(messages)
|
||||
if m["role"] == "user" and TOOL_RESULT_TAG in m.get("content", "")
|
||||
]
|
||||
|
||||
# Build a set of indices that need decay, mapped to their max_chars
|
||||
decay_map: dict[int, int] = {}
|
||||
n_tool_msgs = len(tool_result_indices)
|
||||
for rank, idx in enumerate(reversed(tool_result_indices)):
|
||||
rounds_ago = rank # 0 = most recent, 1 = second most recent, ...
|
||||
for threshold, max_chars in _DECAY_TIERS:
|
||||
if rounds_ago < threshold:
|
||||
decay_map[idx] = max_chars
|
||||
break
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
if i in decay_map:
|
||||
max_chars = decay_map[i]
|
||||
content = msg["content"]
|
||||
if len(content) > max_chars + 200:
|
||||
# Truncate but preserve the tool_result tags structure
|
||||
truncated = content[:max_chars]
|
||||
# Count how many tool results are in this message
|
||||
n_results = content.count(TOOL_RESULT_TAG)
|
||||
truncated += (
|
||||
f"\n... [context compressed: {len(content)} -> {max_chars} chars, "
|
||||
f"{n_results} tool result(s)]"
|
||||
)
|
||||
result.append({"role": msg["role"], "content": truncated})
|
||||
else:
|
||||
result.append(msg)
|
||||
else:
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
# Stage B: LLM-powered message folding.
|
||||
# When messages exceed this count, fold the oldest ones into a summary.
|
||||
_FOLD_THRESHOLD = 24 # trigger folding when messages exceed this count
|
||||
_FOLD_KEEP_RECENT = 10 # always keep the most recent N messages intact
|
||||
_FOLD_SUMMARY_SYSTEM = (
|
||||
"You are a concise summarizer for an ongoing forensic investigation conversation. "
|
||||
"Summarize the following early conversation between a forensic analysis agent and its "
|
||||
"tool execution results. Preserve:\n"
|
||||
"- Key findings and evidence discovered (file paths, inode numbers, timestamps)\n"
|
||||
"- Tools that were called and their important results\n"
|
||||
"- Decisions made and current investigation direction\n"
|
||||
"Keep the summary under 800 words. Use bullet points."
|
||||
)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""Calls Claude Messages API through a third-party proxy using raw httpx.
|
||||
|
||||
Uses prompt-based tool calling (ReAct pattern) since the proxy does not
|
||||
support Claude's native tool_use format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
model: str = "claude-sonnet-4-6",
|
||||
max_tokens: int = 4096,
|
||||
proxy: str | None = "auto",
|
||||
) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
# proxy="auto": read from env; proxy=None/""/"none": no proxy; proxy="http://...": use it
|
||||
if proxy == "auto":
|
||||
proxy_url = os.environ.get("https_proxy") or os.environ.get("HTTPS_PROXY")
|
||||
elif proxy and proxy.lower() != "none":
|
||||
proxy_url = proxy
|
||||
else:
|
||||
proxy_url = None
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
headers={
|
||||
"x-api-key": self.api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
timeout=300.0,
|
||||
proxy=proxy_url,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.aclose()
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
system: str | None = None,
|
||||
max_retries: int = 5,
|
||||
) -> str:
|
||||
"""Send a streaming chat request and return the assembled text response.
|
||||
|
||||
Uses SSE streaming to keep the connection alive and avoid gateway
|
||||
timeouts (504/524) on long-running completions.
|
||||
"""
|
||||
import asyncio as _asyncio
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"max_tokens": self.max_tokens,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
}
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
for attempt in range(max_retries):
|
||||
logger.debug("LLM request (stream): %d messages (attempt %d)", len(messages), attempt + 1)
|
||||
text_parts: list[str] = []
|
||||
try:
|
||||
async with self._client.stream(
|
||||
"POST", "/v1/messages", json=payload,
|
||||
) as resp:
|
||||
# Check for HTTP errors before consuming stream
|
||||
if resp.status_code >= 400:
|
||||
body = await resp.aread()
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Server error '{resp.status_code}' for url '{resp.url}'",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
|
||||
# Parse SSE events
|
||||
async for line in resp.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:] # strip "data: " prefix
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
event = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
event_type = event.get("type", "")
|
||||
if event_type == "content_block_delta":
|
||||
delta = event.get("delta", {})
|
||||
if delta.get("type") == "text_delta":
|
||||
text_parts.append(delta["text"])
|
||||
elif event_type == "message_stop":
|
||||
break
|
||||
elif event_type == "error":
|
||||
err_msg = event.get("error", {}).get("message", "Unknown streaming error")
|
||||
raise httpx.HTTPStatusError(
|
||||
err_msg, request=resp.request, response=resp,
|
||||
)
|
||||
|
||||
text = "".join(text_parts)
|
||||
logger.debug("LLM response (stream): %d chars", len(text))
|
||||
return text
|
||||
|
||||
except (httpx.HTTPStatusError, httpx.ConnectError, httpx.ReadTimeout, httpx.RemoteProtocolError) as e:
|
||||
if attempt < max_retries - 1:
|
||||
wait = 2 ** attempt * 10
|
||||
logger.warning("Request failed (%s), retrying in %ds...", e, wait)
|
||||
await _asyncio.sleep(wait)
|
||||
else:
|
||||
raise LLMAPIError(
|
||||
f"LLM API unreachable after {max_retries} attempts: {e}",
|
||||
attempts=max_retries,
|
||||
) from e
|
||||
|
||||
# Should not reach here, but just in case
|
||||
return ""
|
||||
|
||||
async def tool_call_loop(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict],
|
||||
tool_executor: dict[str, Any],
|
||||
system: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
) -> tuple[str, list[dict]]:
|
||||
"""Run a ReAct-style tool-calling loop.
|
||||
|
||||
The model outputs <tool_call> blocks which we parse and execute,
|
||||
feeding results back as <tool_result> blocks until the model
|
||||
outputs an <answer> block.
|
||||
|
||||
Returns:
|
||||
(final_text, all_messages)
|
||||
"""
|
||||
# Build system prompt with tool definitions
|
||||
tools_prompt = _build_tools_prompt(tools)
|
||||
full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt
|
||||
|
||||
messages = list(messages) # don't mutate caller's list
|
||||
_folded = False # Track whether we've already folded once this loop
|
||||
|
||||
for i in range(max_iterations):
|
||||
# ── Context compression before each API call ──────────────
|
||||
# Stage A: progressively decay old tool results
|
||||
messages = _apply_progressive_decay(messages)
|
||||
|
||||
# Stage B: fold oldest messages into LLM summary if too long
|
||||
if not _folded and len(messages) > _FOLD_THRESHOLD:
|
||||
messages = await self._fold_old_messages(messages, full_system)
|
||||
_folded = True
|
||||
elif _folded and len(messages) > _FOLD_THRESHOLD + _FOLD_KEEP_RECENT:
|
||||
# Allow a second fold if messages grew back significantly
|
||||
messages = await self._fold_old_messages(messages, full_system)
|
||||
|
||||
text = await self.chat(messages, system=full_system)
|
||||
|
||||
# Check for final answer
|
||||
answer = _extract_answer(text)
|
||||
if answer is not None:
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
return answer, messages
|
||||
|
||||
# Check for tool calls
|
||||
tool_calls = _extract_tool_calls(text)
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls and no answer tag — treat entire text as answer
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
return text, messages
|
||||
|
||||
# Execute tool calls — read-only tools run in parallel
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
|
||||
result_parts = []
|
||||
batches = _partition_tool_calls(tool_calls)
|
||||
t_batch_start = time.monotonic()
|
||||
|
||||
for batch in batches:
|
||||
if batch.is_read_only and len(batch.calls) > 1:
|
||||
batch_results = await self._execute_tool_batch_parallel(
|
||||
batch.calls, tool_executor, tools,
|
||||
)
|
||||
result_parts.extend(batch_results)
|
||||
else:
|
||||
for tc in batch.calls:
|
||||
result_parts.append(
|
||||
await self._execute_single_tool(tc, tool_executor, tools)
|
||||
)
|
||||
|
||||
# Emit folded tool-call summary for the terminal
|
||||
t_batch_elapsed = time.monotonic() - t_batch_start
|
||||
_emit_tool_call_summary(tool_calls, t_batch_elapsed)
|
||||
|
||||
# Feed results back as a user message
|
||||
result_message = "\n\n".join(result_parts)
|
||||
messages.append({"role": "user", "content": result_message})
|
||||
|
||||
logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
|
||||
return "[Max tool call iterations reached]", messages
|
||||
|
||||
async def _execute_single_tool(
|
||||
self, tc: dict, tool_executor: dict[str, Any],
|
||||
tools: list[dict] | None = None,
|
||||
) -> str:
|
||||
"""Execute a single tool call and return the formatted result."""
|
||||
tool_name = tc.get("name", "")
|
||||
tool_args = tc.get("arguments", {})
|
||||
|
||||
if tools:
|
||||
tool_args = _fix_tool_args(tool_name, tool_args, tools)
|
||||
|
||||
logger.info("Calling tool: %s(%s)", tool_name, json.dumps(tool_args, ensure_ascii=False))
|
||||
|
||||
executor = tool_executor.get(tool_name)
|
||||
if executor is None:
|
||||
result_text = f"Error: unknown tool '{tool_name}'"
|
||||
else:
|
||||
try:
|
||||
result_text = await executor(**tool_args)
|
||||
except Exception as e:
|
||||
logger.error("Tool %s failed: %s", tool_name, e)
|
||||
result_text = f"Error executing {tool_name}: {e}"
|
||||
|
||||
return (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
)
|
||||
|
||||
async def _execute_tool_batch_parallel(
|
||||
self, calls: list[dict], tool_executor: dict[str, Any],
|
||||
tools: list[dict] | None = None,
|
||||
) -> list[str]:
|
||||
"""Execute multiple read-only tool calls concurrently."""
|
||||
logger.info("Executing %d read-only tools in parallel", len(calls))
|
||||
|
||||
async def _run_one(tc: dict) -> str:
|
||||
tool_name = tc.get("name", "")
|
||||
tool_args = tc.get("arguments", {})
|
||||
if tools:
|
||||
tool_args = _fix_tool_args(tool_name, tool_args, tools)
|
||||
logger.info("Calling tool (parallel): %s(%s)", tool_name, json.dumps(tool_args, ensure_ascii=False))
|
||||
executor = tool_executor.get(tool_name)
|
||||
if executor is None:
|
||||
result_text = f"Error: unknown tool '{tool_name}'"
|
||||
else:
|
||||
try:
|
||||
result_text = await executor(**tool_args)
|
||||
except Exception as e:
|
||||
logger.error("Tool %s failed: %s", tool_name, e)
|
||||
result_text = f"Error executing {tool_name}: {e}"
|
||||
return (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*[_run_one(tc) for tc in calls])
|
||||
return list(results)
|
||||
|
||||
async def _fold_old_messages(
|
||||
self, messages: list[dict], system: str,
|
||||
) -> list[dict]:
|
||||
"""Fold old messages into an LLM-generated summary (Stage B).
|
||||
|
||||
Keeps the most recent _FOLD_KEEP_RECENT messages intact and
|
||||
replaces earlier ones with a single summary message.
|
||||
"""
|
||||
n_to_fold = len(messages) - _FOLD_KEEP_RECENT
|
||||
if n_to_fold <= 2:
|
||||
return messages
|
||||
|
||||
old_messages = messages[:n_to_fold]
|
||||
recent_messages = messages[n_to_fold:]
|
||||
|
||||
# Build a text dump of old messages for summarization
|
||||
old_text_parts = []
|
||||
for msg in old_messages:
|
||||
role = msg["role"]
|
||||
content = msg.get("content", "")
|
||||
# Truncate each message for the summary prompt to avoid overload
|
||||
if len(content) > 1000:
|
||||
content = content[:1000] + "..."
|
||||
old_text_parts.append(f"[{role}]: {content}")
|
||||
old_text = "\n\n".join(old_text_parts)
|
||||
|
||||
# Cap total size sent to summarizer
|
||||
if len(old_text) > 15000:
|
||||
old_text = old_text[:15000] + "\n\n... [further messages omitted for brevity]"
|
||||
|
||||
logger.info(
|
||||
"Context folding: summarizing %d old messages (%d chars) into summary",
|
||||
n_to_fold, len(old_text),
|
||||
)
|
||||
|
||||
try:
|
||||
summary = await self.chat(
|
||||
messages=[{"role": "user", "content": old_text}],
|
||||
system=_FOLD_SUMMARY_SYSTEM,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Context folding failed: %s — keeping original messages", e)
|
||||
return messages
|
||||
|
||||
# Replace old messages with a single summary
|
||||
summary_message = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"[Context summary — the following summarizes {n_to_fold} earlier "
|
||||
f"messages in this conversation]\n\n{summary}"
|
||||
),
|
||||
}
|
||||
return [summary_message] + recent_messages
|
||||
Reference in New Issue
Block a user