The previous LLMClient used raw httpx + Claude Messages API (/v1/messages, x-api-key, Anthropic SSE event types). Incompatible with DeepSeek. Rewrite LLMClient.__init__/chat/close to use openai.AsyncOpenAI: - /v1/chat/completions endpoint, OpenAI message format - Bearer auth, native SDK error types - Stream chunks via async for + chunk.choices[0].delta.content Tool calling protocol (ReAct text-based tags) and all surrounding helpers (_apply_progressive_decay, _fold_old_messages, _partition_tool_calls, tool_call_loop, etc.) are unchanged — endpoint-agnostic by design. New optional config params surfaced to config.yaml.agent: - reasoning_effort: "high" | "medium" | "low" — DeepSeek/o1-style depth - thinking_enabled: bool — DeepSeek extra_body.thinking switch main.py and regenerate_report.py pass these through to LLMClient. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
603 lines
22 KiB
Python
603 lines
22 KiB
Python
"""LLM client via the OpenAI SDK (works with DeepSeek's OpenAI-compatible API).
|
|
|
|
Tool calling is text-based (ReAct pattern): tool schemas are embedded in
|
|
the system prompt and tool calls are parsed as <tool_call> JSON blocks
|
|
from model output. This keeps the protocol independent of whether the
|
|
underlying API supports native function/tool calling.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import time
|
|
from collections import Counter
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from openai import APIConnectionError, APIError, APITimeoutError, AsyncOpenAI
|
|
|
|
# Module-level logger, named after this module, shared by every helper below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMAPIError(Exception):
    """Signals that the LLM API could not be reached after all retries.

    Attributes:
        attempts: Number of attempts reported by the code that raised.
    """

    def __init__(self, message: str, attempts: int) -> None:
        # Record the attempt count first, then let Exception store the message.
        self.attempts = attempts
        super().__init__(message)
|
|
|
|
|
|
# Text markers exchanged with the model. Tool invocations, tool outputs and
# the final answer are each wrapped in a matching open/close tag pair.
TOOL_CALL_TAG, TOOL_CALL_END = "<tool_call>", "</tool_call>"
TOOL_RESULT_TAG, TOOL_RESULT_END = "<tool_result>", "</tool_result>"
ANSWER_TAG, ANSWER_END = "<answer>", "</answer>"
|
|
|
|
|
|
def _build_tools_prompt(tools: list[dict]) -> str:
    """Render the tool catalogue and usage protocol for the system prompt.

    Every entry in ``tools`` is read for ``name`` and ``description`` (both
    mandatory keys) and for an optional ``input_schema`` whose ``properties``
    and ``required`` entries describe the parameters. The returned text lists
    each tool and then explains the tool-call, tool-result and answer tags.
    """

    def _param_line(pname: str, pdef: dict, required) -> str:
        # One bullet per parameter: name, type, required marker, description,
        # plus the allowed values when the schema declares an enum.
        marker = " (required)" if pname in required else ""
        line = f" - {pname}: {pdef.get('type', 'string')}{marker} — {pdef.get('description', '')}"
        choices = pdef.get("enum")
        if choices:
            quoted = ", ".join(f'"{v}"' for v in choices)
            line += f" Allowed values: [{quoted}]"
        return line

    sections = ["You have access to the following tools:\n"]

    for tool in tools:
        schema = tool.get("input_schema", {})
        required = schema.get("required", [])
        rendered = [
            _param_line(pname, pdef, required)
            for pname, pdef in schema.get("properties", {}).items()
        ]
        # Every rendered line is non-empty, so an empty join means "no params".
        param_block = "\n".join(rendered) or " (no parameters)"
        sections.append(
            f"## {tool['name']}\n{tool['description']}\nParameters:\n{param_block}\n"
        )

    usage = "".join(
        (
            "## How to use tools\n",
            "To call a tool, output a JSON block wrapped in XML tags like this:\n",
            f"{TOOL_CALL_TAG}\n",
            '{"name": "tool_name", "arguments": {"param1": "value1"}}\n',
            f"{TOOL_CALL_END}\n\n",
            "You can call multiple tools in sequence. After each tool call, you will receive the result in:\n",
            f"{TOOL_RESULT_TAG}\n...result...\n{TOOL_RESULT_END}\n\n",
            "When you have finished your analysis and have a final answer, wrap it in:\n",
            f"{ANSWER_TAG}\nyour final answer here\n{ANSWER_END}\n\n",
            "Think step by step. Call tools to gather evidence before drawing conclusions.\n",
            "You MUST call at least one tool before giving your final answer.",
        )
    )
    sections.append(usage)
    return "\n".join(sections)
|
|
|
|
|
|
def _extract_tool_calls(text: str) -> list[dict]:
    """Extract tool call JSON objects from model output.

    Scans ``text`` for ``TOOL_CALL_TAG ... TOOL_CALL_END`` blocks and decodes
    the payload of each one as JSON.

    Args:
        text: Raw model output.

    Returns:
        The decoded payloads, in order of appearance. Blocks whose payload is
        not valid JSON, or is valid JSON but not an object, are logged and
        skipped so that callers can rely on every element being a ``dict``.
    """
    pattern = re.compile(
        re.escape(TOOL_CALL_TAG) + r"\s*(.*?)\s*" + re.escape(TOOL_CALL_END),
        re.DOTALL,  # payloads usually span several lines
    )
    calls: list[dict] = []
    for match in pattern.finditer(text):
        raw = match.group(1).strip()
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("Failed to parse tool call JSON: %s", raw[:200])
            continue
        # A bare list/string/number is valid JSON but cannot carry "name" and
        # "arguments"; passing it on would break the later ``tc.get(...)`` calls.
        if not isinstance(parsed, dict):
            logger.warning("Ignoring tool call that is not a JSON object: %s", raw[:200])
            continue
        calls.append(parsed)
    return calls
|
|
|
|
|
|
def _extract_answer(text: str) -> str | None:
    """Return the stripped body of the first answer block, or ``None``.

    The answer is whatever sits between ``ANSWER_TAG`` and ``ANSWER_END``;
    the match may span multiple lines.
    """
    found = re.search(
        re.escape(ANSWER_TAG) + r"\s*(.*?)\s*" + re.escape(ANSWER_END),
        text,
        flags=re.DOTALL,
    )
    return found.group(1).strip() if found else None
|
|
|
|
|
|
def _truncate_tool_result(result_text: str, max_chars: int = 3000) -> str:
|
|
"""Truncate a tool result if it exceeds max_chars."""
|
|
if len(result_text) <= max_chars:
|
|
return result_text
|
|
return result_text[: max_chars - 200] + f"\n... [truncated, {len(result_text)} total chars]"
|
|
|
|
|
|
# Names of tools that only read state and never mutate it; calls to these may
# be executed concurrently.
READ_ONLY_TOOLS: set[str] = {
    # Graph queries
    "list_phenomena",
    "get_phenomenon",
    "search_graph",
    "get_related",
    "get_hypothesis_status",
    "list_assets",
    "find_extracted_file",
    # Sleuth Kit reads
    "partition_info",
    "filesystem_info",
    "list_directory",
    "find_file",
    "search_strings",
    "count_deleted_files",
    "build_filesystem_timeline",
    # Registry reads (without auto-record wrappers)
    "parse_registry_key",
    "search_registry",
    "get_user_activity",
    # Parser reads
    "read_text_file",
    "read_binary_preview",
    "search_text_file",
    "read_text_file_section",
    "list_extracted_dir",
    "parse_pcap_strings",
}
|
|
|
|
|
|
def _fix_tool_args(tool_name: str, tool_args: dict, tools: list[dict]) -> dict:
|
|
"""Try to fix misnamed tool arguments from LLM hallucination.
|
|
|
|
The LLM sometimes confuses parameter names across tools (e.g. passing
|
|
`key_path` to search_registry which expects `pattern`). This function
|
|
maps unknown kwargs to missing expected params by position/best-effort.
|
|
"""
|
|
# Find the schema for this tool
|
|
schema = None
|
|
for t in tools:
|
|
if t.get("name") == tool_name:
|
|
schema = t.get("input_schema", {})
|
|
break
|
|
if schema is None:
|
|
return tool_args
|
|
|
|
props = schema.get("properties", {})
|
|
required = set(schema.get("required", []))
|
|
|
|
unknown = [k for k in tool_args if k not in props]
|
|
if not unknown:
|
|
return tool_args # all args are valid, nothing to fix
|
|
|
|
# Build the fixed args: start with valid args
|
|
fixed = {k: v for k, v in tool_args.items() if k in props}
|
|
|
|
# Find which expected params are still missing
|
|
missing = [p for p in (required or props.keys()) if p not in fixed]
|
|
|
|
# Try to map unknown args to missing params, in order
|
|
unknown_values = [(k, tool_args[k]) for k in unknown]
|
|
|
|
for wrong_name, value in unknown_values:
|
|
if not missing:
|
|
break
|
|
# Pick the best match from missing params
|
|
best = missing.pop(0)
|
|
logger.warning(
|
|
"Auto-fixing tool arg: %s(%s=...) -> %s(%s=...)",
|
|
tool_name, wrong_name, tool_name, best,
|
|
)
|
|
fixed[best] = value
|
|
|
|
return fixed
|
|
|
|
|
|
def _emit_tool_call_summary(tool_calls: list[dict], elapsed: float) -> None:
    """Log one condensed line describing a round of tool calls.

    Calls are grouped by tool name, most frequent first, e.g.
        "list_directory x27, extract_file x3, read_text_file x3"
    A name that occurs once is printed without a multiplier. The record
    carries ``event="tool_calls"`` and ``elapsed`` as logging extras.
    """
    tally = Counter(call.get("name", "?") for call in tool_calls)
    summary = ", ".join(
        f"{name} x{n}" if n > 1 else name
        for name, n in tally.most_common()
    )
    logger.info(summary, extra={"event": "tool_calls", "elapsed": elapsed})
|
|
|
|
|
|
@dataclass
|
|
class _ToolBatch:
|
|
"""A batch of tool calls with the same read/write classification."""
|
|
is_read_only: bool
|
|
calls: list[dict] = field(default_factory=list)
|
|
|
|
|
|
def _partition_tool_calls(
    tool_calls: list[dict],
    read_only: set[str] | None = None,
) -> list[_ToolBatch]:
    """Split tool calls into ordered batches by read/write classification.

    Adjacent read-only calls are merged into a single batch; every other call
    gets a batch of its own. ``read_only`` defaults to ``READ_ONLY_TOOLS``.
    """
    ro_names = READ_ONLY_TOOLS if read_only is None else read_only

    batches: list[_ToolBatch] = []
    for call in tool_calls:
        call_is_ro = call.get("name", "") in ro_names
        tail = batches[-1] if batches else None
        # Extend the previous batch only when both it and this call are read-only.
        if call_is_ro and tail is not None and tail.is_read_only:
            tail.calls.append(call)
            continue
        batches.append(_ToolBatch(is_read_only=call_is_ro, calls=[call]))
    return batches
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Context compression — keeps the message list from growing unboundedly.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Stage A: Progressive tool result decay thresholds.
|
|
# Messages are counted in (assistant, user) pairs from the END of the list.
|
|
# "Round" = one pair of (assistant tool-calling msg, user tool-result msg).
|
|
_DECAY_TIERS: list[tuple[int, int]] = [
|
|
# (rounds_ago_threshold, max_chars_for_tool_results)
|
|
(5, 3000), # recent 5 rounds: keep full (3000 chars per tool result)
|
|
(15, 500), # 5-15 rounds ago: aggressive truncation
|
|
(999, 100), # older than 15 rounds: minimal stub
|
|
]
|
|
|
|
|
|
def _apply_progressive_decay(messages: list[dict]) -> list[dict]:
    """Truncate tool results in older messages to save context space.

    Only user messages whose content contains ``TOOL_RESULT_TAG`` are
    considered. They are ranked from the end of the conversation (0 = most
    recent) and the first ``_DECAY_TIERS`` entry whose bound exceeds the rank
    provides the character budget.

    Args:
        messages: Conversation so far; neither the list nor its dicts are
            mutated.

    Returns:
        ``messages`` itself when it has 10 or fewer entries. Otherwise a new
        list in which over-budget tool-result messages are replaced by
        ``{"role", "content"}`` dicts holding the first ``max_chars``
        characters plus a compression notice. The cut is a plain prefix, so a
        trailing tool-result block may lose its closing tag.
    """
    if len(messages) <= 10:  # not enough messages to bother
        return messages

    # ``or ""`` tolerates messages whose content is missing or None.
    tool_result_indices = [
        i for i, m in enumerate(messages)
        if m["role"] == "user" and TOOL_RESULT_TAG in (m.get("content") or "")
    ]

    # Map message index -> character budget, ranking from the newest message.
    decay_map: dict[int, int] = {}
    for rounds_ago, idx in enumerate(reversed(tool_result_indices)):
        for threshold, max_chars in _DECAY_TIERS:
            if rounds_ago < threshold:
                decay_map[idx] = max_chars
                break

    result = []
    for i, msg in enumerate(messages):
        max_chars = decay_map.get(i)
        # 200 chars of slack: rewriting is pointless when the appended notice
        # would use up most of what truncation saves.
        if max_chars is None or len(msg["content"]) <= max_chars + 200:
            result.append(msg)
            continue

        content = msg["content"]
        n_results = content.count(TOOL_RESULT_TAG)
        truncated = content[:max_chars] + (
            f"\n... [context compressed: {len(content)} -> {max_chars} chars, "
            f"{n_results} tool result(s)]"
        )
        result.append({"role": msg["role"], "content": truncated})
    return result
|
|
|
|
|
|
# Stage B: LLM-powered message folding.
|
|
# When messages exceed this count, fold the oldest ones into a summary.
|
|
_FOLD_THRESHOLD = 24 # trigger folding when messages exceed this count
|
|
_FOLD_KEEP_RECENT = 10 # always keep the most recent N messages intact
|
|
_FOLD_SUMMARY_SYSTEM = (
|
|
"You are a concise summarizer for an ongoing forensic investigation conversation. "
|
|
"Summarize the following early conversation between a forensic analysis agent and its "
|
|
"tool execution results. Preserve:\n"
|
|
"- Key findings and evidence discovered (file paths, inode numbers, timestamps)\n"
|
|
"- Tools that were called and their important results\n"
|
|
"- Decisions made and current investigation direction\n"
|
|
"Keep the summary under 800 words. Use bullet points."
|
|
)
|
|
|
|
|
|
class LLMClient:
    """Async LLM client via the OpenAI SDK.

    Works with any OpenAI-compatible endpoint (OpenAI, DeepSeek, ...).
    Tool calling is text-based (ReAct): tool schemas are placed in the system
    prompt and calls are parsed from ``<tool_call>`` blocks in the reply.
    """

    # HTTP statuses below 500 that are still worth retrying (request timeout,
    # conflict, rate limit). Any other 4xx response fails fast, because
    # repeating an identical request cannot succeed.
    _RETRYABLE_STATUS = frozenset({408, 409, 429})

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "deepseek-v4-pro",
        max_tokens: int = 4096,
        proxy: str | None = "auto",
        reasoning_effort: str | None = None,
        thinking_enabled: bool = False,
    ) -> None:
        """Create the underlying ``AsyncOpenAI`` client.

        Args:
            base_url: API base URL; trailing slashes are stripped.
            api_key: Key handed to the SDK.
            model: Model name sent with every request.
            max_tokens: ``max_tokens`` sent with every request.
            proxy: ``"auto"`` reads ``https_proxy`` / ``HTTPS_PROXY`` from the
                environment; ``None``, ``""`` or ``"none"`` disables the
                proxy; any other value is used as the proxy URL.
            reasoning_effort: When set, forwarded as ``reasoning_effort``.
            thinking_enabled: When true, ``extra_body`` carries
                ``{"thinking": {"type": "enabled"}}``.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self.reasoning_effort = reasoning_effort
        self.thinking_enabled = thinking_enabled

        # proxy="auto": read from env; proxy=None/""/"none": no proxy
        if proxy == "auto":
            proxy_url = os.environ.get("https_proxy") or os.environ.get("HTTPS_PROXY")
        elif proxy and proxy.lower() != "none":
            proxy_url = proxy
        else:
            proxy_url = None

        # A custom httpx client is only needed to route through a proxy;
        # otherwise the SDK builds its own.
        http_client = (
            httpx.AsyncClient(proxy=proxy_url, timeout=300.0)
            if proxy_url else None
        )

        self._client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=300.0,
            http_client=http_client,
        )

    async def close(self) -> None:
        """Close the underlying SDK client."""
        await self._client.close()

    async def chat(
        self,
        messages: list[dict],
        system: str | None = None,
        max_retries: int = 5,
    ) -> str:
        """Send a streaming chat completion and return the assembled text.

        Args:
            messages: Conversation in OpenAI message format.
            system: Optional system prompt, sent as the first message.
            max_retries: Maximum number of attempts; values below 1 are
                treated as 1 so that a request is always made.

        Returns:
            The concatenated ``delta.content`` of all streamed chunks.

        Raises:
            LLMAPIError: When every attempt failed, or immediately when the
                API answered with a non-retryable 4xx status.
        """
        full_messages: list[dict] = []
        if system:
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": full_messages,
            "max_tokens": self.max_tokens,
            "stream": True,
        }
        if self.reasoning_effort:
            kwargs["reasoning_effort"] = self.reasoning_effort
        if self.thinking_enabled:
            kwargs["extra_body"] = {"thinking": {"type": "enabled"}}

        attempts = max(1, max_retries)
        for attempt in range(attempts):
            logger.debug(
                "LLM request (stream): %d messages (attempt %d)",
                len(messages), attempt + 1,
            )
            # Reset per attempt so a stream that dies midway leaves no residue.
            text_parts: list[str] = []
            try:
                stream = await self._client.chat.completions.create(**kwargs)
                async for chunk in stream:
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta
                    if delta.content:
                        text_parts.append(delta.content)
            except (APIConnectionError, APITimeoutError, APIError) as e:
                # Errors without a status code (connection, timeout) are
                # retried, as are 5xx and the few retryable 4xx statuses.
                status = getattr(e, "status_code", None)
                retryable = (
                    status is None
                    or status in self._RETRYABLE_STATUS
                    or status >= 500
                )
                if not retryable:
                    raise LLMAPIError(
                        f"LLM API rejected the request (HTTP {status}): {e}",
                        attempts=attempt + 1,
                    ) from e
                if attempt >= attempts - 1:
                    raise LLMAPIError(
                        f"LLM API unreachable after {attempts} attempts: {e}",
                        attempts=attempts,
                    ) from e
                wait = 2 ** attempt * 10  # 10s, 20s, 40s, ...
                logger.warning("Request failed (%s), retrying in %ds...", e, wait)
                await asyncio.sleep(wait)
            else:
                text = "".join(text_parts)
                logger.debug("LLM response (stream): %d chars", len(text))
                return text

        # Not reachable: the final attempt either returns or raises.
        return ""

    async def tool_call_loop(
        self,
        messages: list[dict],
        tools: list[dict],
        tool_executor: dict[str, Any],
        system: str | None = None,
        max_iterations: int = 40,
    ) -> tuple[str, list[dict]]:
        """Run a ReAct-style tool-calling loop.

        The model outputs <tool_call> blocks which we parse and execute,
        feeding results back as <tool_result> blocks until the model
        outputs an <answer> block. A reply with neither tool calls nor an
        answer block is returned verbatim as the answer.

        Args:
            messages: Initial conversation; the caller's list is not mutated.
            tools: Tool definitions rendered into the system prompt.
            tool_executor: Maps tool names to awaitable callables.
            system: Optional system prompt placed before the tool section.
            max_iterations: Upper bound on model round-trips.

        Returns:
            (final_text, all_messages)
        """
        tools_prompt = _build_tools_prompt(tools)
        full_system = f"{system}\n\n{tools_prompt}" if system else tools_prompt

        messages = list(messages)  # don't mutate caller's list
        folded = False  # becomes True after the first fold attempt

        for _ in range(max_iterations):
            # ── Context compression before each API call ──────────────
            # Stage A: progressively decay old tool results
            messages = _apply_progressive_decay(messages)

            # Stage B: fold the oldest messages into a summary. After the
            # first fold the bar is raised by _FOLD_KEEP_RECENT, so a further
            # fold only happens once the list has grown back significantly.
            fold_limit = _FOLD_THRESHOLD + (_FOLD_KEEP_RECENT if folded else 0)
            if len(messages) > fold_limit:
                messages = await self._fold_old_messages(messages, full_system)
                folded = True

            text = await self.chat(messages, system=full_system)
            messages.append({"role": "assistant", "content": text})

            # A final answer ends the loop.
            answer = _extract_answer(text)
            if answer is not None:
                return answer, messages

            tool_calls = _extract_tool_calls(text)
            if not tool_calls:
                # No tool calls and no answer tag — treat entire text as answer
                return text, messages

            # Execute tool calls — runs of read-only tools go in parallel.
            result_parts: list[str] = []
            t_batch_start = time.monotonic()
            for batch in _partition_tool_calls(tool_calls):
                if batch.is_read_only and len(batch.calls) > 1:
                    result_parts.extend(
                        await self._execute_tool_batch_parallel(
                            batch.calls, tool_executor, tools,
                        )
                    )
                else:
                    for tc in batch.calls:
                        result_parts.append(
                            await self._execute_single_tool(tc, tool_executor, tools)
                        )

            # Emit folded tool-call summary for the terminal
            _emit_tool_call_summary(tool_calls, time.monotonic() - t_batch_start)

            # Feed results back as a user message
            messages.append({"role": "user", "content": "\n\n".join(result_parts)})

        logger.warning("Tool call loop hit max iterations (%d)", max_iterations)
        return "[Max tool call iterations reached]", messages

    @staticmethod
    def _format_tool_result(tool_name: str, result_text: str) -> str:
        """Wrap a (possibly truncated) tool result in tool-result tags."""
        return (
            f"{TOOL_RESULT_TAG}\n"
            f"[{tool_name}] {_truncate_tool_result(result_text)}\n"
            f"{TOOL_RESULT_END}"
        )

    async def _execute_single_tool(
        self, tc: dict, tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
        *,
        parallel: bool = False,
    ) -> str:
        """Execute a single tool call and return the formatted result.

        Args:
            tc: Parsed tool call with ``name`` and ``arguments``.
            tool_executor: Maps tool names to awaitable callables.
            tools: Tool definitions; when given, misnamed arguments are
                repaired with ``_fix_tool_args`` first.
            parallel: Only selects the "(parallel)" wording of the log line.

        Returns:
            The tool output, or an error description, wrapped in tool-result
            tags. Unknown tools, malformed arguments and executor exceptions
            are reported to the model instead of being raised.
        """
        tool_name = tc.get("name", "")
        tool_args = tc.get("arguments", {})

        # Models sometimes emit `"arguments": null` for parameterless tools.
        if tool_args is None:
            tool_args = {}
        if not isinstance(tool_args, dict):
            # Anything else cannot be splatted as keyword arguments; report it
            # to the model rather than letting the whole loop crash.
            logger.error(
                "Tool %s called with non-object arguments: %r", tool_name, tool_args,
            )
            return self._format_tool_result(
                tool_name,
                f"Error: arguments for {tool_name} must be a JSON object, "
                f"got {type(tool_args).__name__}",
            )

        if tools:
            tool_args = _fix_tool_args(tool_name, tool_args, tools)

        log_fmt = "Calling tool (parallel): %s(%s)" if parallel else "Calling tool: %s(%s)"
        logger.info(log_fmt, tool_name, json.dumps(tool_args, ensure_ascii=False))

        executor = tool_executor.get(tool_name)
        if executor is None:
            result_text = f"Error: unknown tool '{tool_name}'"
        else:
            try:
                result_text = await executor(**tool_args)
            except Exception as e:
                # Deliberately broad: any tool failure is fed back to the model.
                logger.error("Tool %s failed: %s", tool_name, e)
                result_text = f"Error executing {tool_name}: {e}"

        return self._format_tool_result(tool_name, result_text)

    async def _execute_tool_batch_parallel(
        self, calls: list[dict], tool_executor: dict[str, Any],
        tools: list[dict] | None = None,
    ) -> list[str]:
        """Execute multiple read-only tool calls concurrently.

        Each call goes through ``_execute_single_tool``; results are returned
        in the same order as ``calls``.
        """
        logger.info("Executing %d read-only tools in parallel", len(calls))
        results = await asyncio.gather(
            *(
                self._execute_single_tool(tc, tool_executor, tools, parallel=True)
                for tc in calls
            )
        )
        return list(results)

    async def _fold_old_messages(
        self, messages: list[dict], system: str,
    ) -> list[dict]:
        """Fold old messages into an LLM-generated summary (Stage B).

        Keeps the most recent _FOLD_KEEP_RECENT messages intact and
        replaces earlier ones with a single summary message.

        Args:
            messages: Full conversation so far.
            system: Currently unused; kept so the signature stays stable.

        Returns:
            ``[summary_message] + recent`` on success. ``messages`` unchanged
            when two or fewer messages would be folded, when the summary
            request fails, or when the summary comes back empty.
        """
        n_to_fold = len(messages) - _FOLD_KEEP_RECENT
        if n_to_fold <= 2:
            return messages

        old_messages = messages[:n_to_fold]
        recent_messages = messages[n_to_fold:]

        # Build a text dump of old messages for summarization
        old_text_parts = []
        for msg in old_messages:
            content = msg.get("content") or ""  # tolerate missing/None content
            # Truncate each message for the summary prompt to avoid overload
            if len(content) > 1000:
                content = content[:1000] + "..."
            old_text_parts.append(f"[{msg['role']}]: {content}")
        old_text = "\n\n".join(old_text_parts)

        # Cap total size sent to summarizer
        if len(old_text) > 15000:
            old_text = old_text[:15000] + "\n\n... [further messages omitted for brevity]"

        logger.info(
            "Context folding: summarizing %d old messages (%d chars) into summary",
            n_to_fold, len(old_text),
        )

        try:
            summary = await self.chat(
                messages=[{"role": "user", "content": old_text}],
                system=_FOLD_SUMMARY_SYSTEM,
            )
        except Exception as e:
            # Folding is best-effort: on failure carry on with the full history.
            logger.warning("Context folding failed: %s — keeping original messages", e)
            return messages

        if not summary.strip():
            # An empty summary would silently discard all earlier context.
            logger.warning("Context folding returned an empty summary — keeping original messages")
            return messages

        # Replace old messages with a single summary
        summary_message = {
            "role": "user",
            "content": (
                f"[Context summary — the following summarizes {n_to_fold} earlier "
                f"messages in this conversation]\n\n{summary}"
            ),
        }
        return [summary_message] + recent_messages
|