test: track tests/ directory in version control

tests/test_optimizations.py — 60 pytest cases covering:
- EvidenceGraph: quality scoring, Jaccard merge, async safety,
  hypothesis confidence updates, asset library
- llm_client: tool-result truncation, parallel batch execution,
  progressive context decay, message folding
- orchestrator: parallel dispatch, batched lead generation,
  batched judging
- tool_registry: result cache key derivation

FakeAgent.run signatures updated to match BaseAgent.run(task, lead_id=None).

The tests/ directory was previously listed in .gitignore (which is itself
untracked, so the ignore rule exists only locally).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
BattleTag
2026-05-12 14:10:31 +08:00
parent 74e6bde13a
commit 31812a72ee

957
tests/test_optimizations.py Normal file
View File

@@ -0,0 +1,957 @@
"""Tests for dedup/quality scoring, context management, parallel dispatch, and evidence graph."""
from __future__ import annotations
import asyncio
import time
import pytest
from evidence_graph import (
EvidenceGraph, Phenomenon, Lead,
_compute_quality_score, _jaccard_similarity,
)
from llm_client import (
_truncate_tool_result, _partition_tool_calls, _ToolBatch, READ_ONLY_TOOLS,
_apply_progressive_decay, _FOLD_THRESHOLD, _FOLD_KEEP_RECENT,
TOOL_RESULT_TAG, TOOL_RESULT_END,
)
from tool_registry import (
_tool_result_cache, _cache_key, _make_cached, CACHEABLE_TOOLS,
)
# ---------------------------------------------------------------------------
# Quality scoring & similarity
# ---------------------------------------------------------------------------
class TestQualityScore:
    """Pin the value of ``_compute_quality_score`` for three input profiles."""

    def test_full_score(self):
        """Every input populated: the expected score is 1.0."""
        inputs = dict(
            source_tool="list_dir",
            timestamp="2024-01-01",
            raw_data={"key": "val"},
            description="A" * 50,
            related_ids=["ph-0"],
        )
        assert _compute_quality_score(**inputs) == pytest.approx(1.0)

    def test_minimal_score(self):
        """Empty tool, no timestamp/data/relations, short text: expected 0.0."""
        inputs = dict(
            source_tool="",
            timestamp=None,
            raw_data={},
            description="short",
            related_ids=[],
        )
        assert _compute_quality_score(**inputs) == pytest.approx(0.0)

    def test_partial_score(self):
        """Only a tool name and a 50-character description: expected 0.40."""
        inputs = dict(
            source_tool="parse_registry_key",
            timestamp=None,
            raw_data={},
            description="A" * 50,
            related_ids=[],
        )
        assert _compute_quality_score(**inputs) == pytest.approx(0.40)
class TestJaccardSimilarity:
    """Boundary and mid-range checks for ``_jaccard_similarity``."""

    def test_identical(self):
        """Two equal strings score 1.0."""
        text = "hello world"
        assert _jaccard_similarity(text, text) == pytest.approx(1.0)

    def test_no_overlap(self):
        """Strings with no common words score 0.0."""
        similarity = _jaccard_similarity("hello world", "foo bar")
        assert similarity == pytest.approx(0.0)

    def test_partial_overlap(self):
        """A string extended by one extra word scores strictly between 0.5 and 1.0."""
        shorter = "User account Mr. Evil"
        longer = "User account Mr. Evil enumeration"
        similarity = _jaccard_similarity(shorter, longer)
        assert similarity > 0.5
        assert similarity < 1.0

    def test_empty_string(self):
        """An empty first argument scores 0.0."""
        similarity = _jaccard_similarity("", "hello")
        assert similarity == pytest.approx(0.0)
# ---------------------------------------------------------------------------
# Evidence graph: phenomenon dedup
# ---------------------------------------------------------------------------
class TestPhenomenonDedup:
    """Exercise ``EvidenceGraph.add_phenomenon`` and the ``merged`` flag it returns."""

    @pytest.fixture
    def graph(self):
        """Provide a fresh ``EvidenceGraph`` for each test."""
        return EvidenceGraph()

    @pytest.mark.asyncio
    async def test_new_phenomenon_gets_quality_score(self, graph):
        """A first insertion is not a merge and is stored with confidence 0.40."""
        new_id, was_merged = await graph.add_phenomenon(
            source_agent="fs",
            category="filesystem",
            title="Directory listing",
            description="Found important files in the root directory of the disk image",
            source_tool="list_directory",
        )
        assert not was_merged
        assert graph.phenomena[new_id].confidence == pytest.approx(0.40)

    @pytest.mark.asyncio
    async def test_identical_phenomenon_merges(self, graph):
        """Same category and title from a second agent merges into the first entry."""
        common = {
            "category": "registry",
            "title": "User account Mr. Evil",
            "source_tool": "enumerate_users",
        }
        first_id, first_merged = await graph.add_phenomenon(
            source_agent="fs",
            description="Found user account Mr. Evil with RID 1003 in the SAM hive",
            **common,
        )
        second_id, second_merged = await graph.add_phenomenon(
            source_agent="reg",
            description="User account Mr. Evil discovered with RID 1003 in SAM registry hive",
            **common,
        )
        assert not first_merged
        assert second_merged
        assert first_id == second_id
        stored = graph.phenomena[first_id]
        # The second reporter is recorded and the confidence rises above 0.40.
        assert "reg" in stored.corroborating_agents
        assert stored.confidence > 0.40

    @pytest.mark.asyncio
    async def test_different_category_no_merge(self, graph):
        """An identical title/description under another category stays separate."""
        common = {
            "title": "User account Mr. Evil",
            "description": "Found user account Mr. Evil with RID 1003",
        }
        first_id, _ = await graph.add_phenomenon(
            source_agent="fs",
            category="filesystem",
            source_tool="list_directory",
            **common,
        )
        second_id, second_merged = await graph.add_phenomenon(
            source_agent="reg",
            category="registry",
            source_tool="enumerate_users",
            **common,
        )
        assert not second_merged
        assert first_id != second_id

    @pytest.mark.asyncio
    async def test_dissimilar_titles_no_merge(self, graph):
        """Unrelated titles in the same category are not merged."""
        await graph.add_phenomenon(
            source_agent="fs",
            category="registry",
            title="Installed software list",
            description="Found 25 installed programs",
            source_tool="list_installed_software",
        )
        _, second_merged = await graph.add_phenomenon(
            source_agent="reg",
            category="registry",
            title="Last shutdown time",
            description="System last shut down at 2004-08-27",
            source_tool="get_shutdown_time",
        )
        assert not second_merged

    @pytest.mark.asyncio
    async def test_merge_adds_raw_data(self, graph):
        """After a merge the stored ``raw_data`` holds keys from both reports."""
        common = {
            "category": "registry",
            "title": "Network config",
            "source_tool": "get_network_interfaces",
        }
        first_id, _ = await graph.add_phenomenon(
            source_agent="fs",
            description="Network interface configuration with IP address and MAC address details",
            raw_data={"ip": "192.168.1.1"},
            **common,
        )
        _, second_merged = await graph.add_phenomenon(
            source_agent="reg",
            description="Network interface configuration including DHCP and MAC address info",
            raw_data={"mac": "00:11:22:33:44:55"},
            **common,
        )
        assert second_merged
        combined = graph.phenomena[first_id].raw_data
        assert "ip" in combined
        assert "mac" in combined
# ---------------------------------------------------------------------------
# Hypothesis confidence updates
# ---------------------------------------------------------------------------
class TestHypothesisConfidence:
    """Check how ``update_hypothesis_confidence`` moves a hypothesis' confidence."""

    @pytest.fixture
    def graph(self):
        """Provide a fresh ``EvidenceGraph`` for each test."""
        return EvidenceGraph()

    @staticmethod
    async def _seed(graph):
        """Add one phenomenon, then one hypothesis; return ``(pid, hid)``."""
        pid, _ = await graph.add_phenomenon(
            "fs", "filesystem", "test", "test desc", source_tool="t"
        )
        hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
        return pid, hid

    @pytest.mark.asyncio
    async def test_supports_increases_confidence(self, graph):
        """A ``supports`` link raises the confidence."""
        pid, hid = await self._seed(graph)
        before = graph.hypotheses[hid].confidence
        after = await graph.update_hypothesis_confidence(hid, pid, "supports", "reason")
        assert after > before

    @pytest.mark.asyncio
    async def test_contradicts_decreases_confidence(self, graph):
        """A ``contradicts`` link lowers the confidence."""
        pid, hid = await self._seed(graph)
        before = graph.hypotheses[hid].confidence
        after = await graph.update_hypothesis_confidence(hid, pid, "contradicts", "reason")
        assert after < before

    @pytest.mark.asyncio
    async def test_direct_evidence_has_largest_effect(self, graph):
        """A single ``direct_evidence`` link yields a confidence of 0.625."""
        pid, hid = await self._seed(graph)
        updated = await graph.update_hypothesis_confidence(
            hid, pid, "direct_evidence", "reason"
        )
        # Expected: 0.5 + 0.25 * (1 - 0.5) = 0.625. Assumes a 0.5 starting
        # confidence and a 0.25 weight -- TODO confirm against evidence_graph.
        assert updated == pytest.approx(0.625)

    @pytest.mark.asyncio
    async def test_confidence_log_tracked(self, graph):
        """Each update appends an entry carrying the edge type and the reason."""
        pid, hid = await self._seed(graph)
        await graph.update_hypothesis_confidence(hid, pid, "supports", "because reasons")
        entries = graph.hypotheses[hid].confidence_log
        assert len(entries) == 1
        only_entry = entries[0]
        assert only_entry["edge_type"] == "supports"
        assert only_entry["reason"] == "because reasons"

    @pytest.mark.asyncio
    async def test_supported_status_at_threshold(self, graph):
        """Five ``direct_evidence`` links flip the status to ``supported``."""
        hid = await graph.add_hypothesis("test", "desc", created_by="test")
        # Each round adds a distinct phenomenon and links it as direct evidence.
        for round_no in range(5):
            pid, _ = await graph.add_phenomenon(
                "fs", "filesystem", f"test {round_no}", f"desc {round_no}", source_tool="t"
            )
            await graph.update_hypothesis_confidence(hid, pid, "direct_evidence", "proof")
        assert graph.hypotheses[hid].status == "supported"
# ---------------------------------------------------------------------------
# Context window management
# ---------------------------------------------------------------------------
class TestTruncateToolResult:
    """Check ``_truncate_tool_result`` below, at, and above the size limit."""

    def test_short_text_unchanged(self):
        """A short string comes back as-is under the default limit."""
        original = "short result"
        assert _truncate_tool_result(original) == original

    def test_long_text_truncated(self):
        """5000 chars with ``max_chars=3000`` is cut and annotated."""
        oversized = "x" * 5000
        shortened = _truncate_tool_result(oversized, max_chars=3000)
        assert len(shortened) < 3100
        # The output mentions that truncation happened and the original length.
        for marker in ("truncated", "5000"):
            assert marker in shortened

    def test_exact_boundary(self):
        """A string exactly ``max_chars`` long is returned unmodified."""
        at_limit = "x" * 3000
        assert _truncate_tool_result(at_limit, max_chars=3000) == at_limit
# ---------------------------------------------------------------------------
# Parallel dispatch & async safety
# ---------------------------------------------------------------------------
class TestGraphAsync:
    """Run ``EvidenceGraph`` mutations from several coroutines at once."""

    @pytest.mark.asyncio
    async def test_concurrent_add_phenomenon(self):
        """Three writers adding 10 distinct phenomena each leave 30 entries."""
        graph = EvidenceGraph()

        async def add_batch(agent_name: str, count: int):
            # Each agent writes into its own category with unique titles.
            for i in range(count):
                await graph.add_phenomenon(
                    source_agent=agent_name,
                    category=f"cat-{agent_name}",
                    title=f"Phenomenon {agent_name}-{i}",
                    description=f"Description from {agent_name} number {i} with enough detail to be meaningful",
                    source_tool="test_tool",
                )

        writers = [add_batch(name, 10) for name in ("agent_a", "agent_b", "agent_c")]
        await asyncio.gather(*writers)
        assert len(graph.phenomena) == 30

    @pytest.mark.asyncio
    async def test_concurrent_add_lead(self):
        """Two writers adding 10 leads each leave 20 pending leads."""
        graph = EvidenceGraph()

        async def add_leads(count: int):
            for i in range(count):
                await graph.add_lead(target_agent="network", description=f"Lead {i}")

        await asyncio.gather(add_leads(10), add_leads(10))
        assert len(await graph.get_pending_leads()) == 20

    @pytest.mark.asyncio
    async def test_mark_lead_completed(self):
        """A completed lead no longer shows up among the pending ones."""
        graph = EvidenceGraph()
        lead_id = await graph.add_lead(target_agent="fs", description="test lead")
        before = await graph.get_pending_leads()
        assert len(before) == 1
        await graph.mark_lead_completed(lead_id)
        after = await graph.get_pending_leads()
        assert len(after) == 0
class TestParallelDispatch:
    """Check how ``Orchestrator._dispatch_leads_parallel`` schedules agent runs.

    Real agents are replaced by ``FakeAgent`` stubs injected through
    ``factory._cache``; the LLM is an ``AsyncMock``.
    """

    @pytest.mark.asyncio
    async def test_different_agent_types_run_concurrently(self):
        """Leads for three different agent types must all run at the same time."""
        from unittest.mock import AsyncMock
        from orchestrator import Orchestrator
        from agent_factory import AgentFactory

        graph = EvidenceGraph()
        llm = AsyncMock()
        factory = AgentFactory(llm, graph)
        orch = Orchestrator(llm, graph, factory)
        # One (agent name, start, end) tuple per FakeAgent.run call, taken
        # from the monotonic clock.
        execution_log: list[tuple[str, float, float]] = []

        class FakeAgent:
            """Stub agent that sleeps 50 ms and records when it ran."""

            def __init__(self, name):
                self.name = name

            async def run(self, task, lead_id=None):
                start = time.monotonic()
                await asyncio.sleep(0.05)
                end = time.monotonic()
                execution_log.append((self.name, start, end))

        factory._cache = {
            "filesystem": FakeAgent("filesystem"),
            "registry": FakeAgent("registry"),
            "network": FakeAgent("network"),
        }
        leads = [
            Lead(id="l1", target_agent="filesystem", description="fs task"),
            Lead(id="l2", target_agent="registry", description="reg task"),
            Lead(id="l3", target_agent="network", description="net task"),
        ]
        graph.leads = leads

        start_time = time.monotonic()
        await orch._dispatch_leads_parallel(leads)
        total_time = time.monotonic() - start_time

        assert total_time < 0.12, f"Expected parallel execution but took {total_time:.3f}s"
        assert len(execution_log) == 3
        # The wall-clock bound alone would also pass if two agents overlapped
        # and the third ran afterwards (2 x 50 ms). Requiring every run to
        # start before any run finishes proves all three were in flight
        # together.
        latest_start = max(start for _, start, _ in execution_log)
        earliest_end = min(end for _, _, end in execution_log)
        assert latest_start < earliest_end, (
            "Expected all agent runs to overlap, but at least one started "
            "after another had already finished"
        )

    @pytest.mark.asyncio
    async def test_same_agent_type_runs_serially(self):
        """Two leads for the same agent type are run in the order given."""
        from unittest.mock import AsyncMock
        from orchestrator import Orchestrator
        from agent_factory import AgentFactory

        graph = EvidenceGraph()
        llm = AsyncMock()
        factory = AgentFactory(llm, graph)
        orch = Orchestrator(llm, graph, factory)
        # Task strings in the order FakeAgent.run received them.
        execution_order: list[str] = []

        class FakeAgent:
            """Stub agent that records the task text before yielding control."""

            def __init__(self, name):
                self.name = name

            async def run(self, task, lead_id=None):
                execution_order.append(task)
                await asyncio.sleep(0.01)

        factory._cache = {"filesystem": FakeAgent("filesystem")}
        leads = [
            Lead(id="l1", target_agent="filesystem", description="task 1"),
            Lead(id="l2", target_agent="filesystem", description="task 2"),
        ]
        graph.leads = leads

        await orch._dispatch_leads_parallel(leads)

        assert len(execution_order) == 2
        # Substring match: the task handed to the agent contains the lead
        # description, possibly with extra text around it.
        assert "task 1" in execution_order[0]
        assert "task 2" in execution_order[1]
# ---------------------------------------------------------------------------
# Asset library
# ---------------------------------------------------------------------------
class TestAssetLibrary:
    """Cover asset registration, inode dedup, querying, listing and persistence."""

    @pytest.fixture
    def graph(self):
        """Provide a fresh ``EvidenceGraph`` for each test."""
        return EvidenceGraph()

    @pytest.mark.asyncio
    async def test_register_asset(self, graph):
        """A first registration returns a new ``asset-`` id and is stored."""
        asset_id, already_known = await graph.register_asset(
            inode="334-128-4",
            original_path="WINDOWS/system32/config/SYSTEM",
            local_path="extracted/system",
            category="registry_hive",
            filename="system",
            size_bytes=262144,
            extracted_by="filesystem",
        )
        assert not already_known
        assert asset_id.startswith("asset-")
        assert len(graph.asset_library) == 1

    @pytest.mark.asyncio
    async def test_dedup_by_inode(self, graph):
        """Registering the same inode twice returns the first id and adds nothing."""
        hive = {
            "inode": "334-128-4",
            "original_path": "WINDOWS/system32/config/SYSTEM",
            "local_path": "extracted/system",
            "category": "registry_hive",
            "filename": "system",
            "size_bytes": 262144,
        }
        first_id, _ = await graph.register_asset(extracted_by="filesystem", **hive)
        second_id, already_known = await graph.register_asset(extracted_by="registry", **hive)
        assert already_known
        assert first_id == second_id
        assert len(graph.asset_library) == 1

    @pytest.mark.asyncio
    async def test_query_by_category(self, graph):
        """``query_assets(category=...)`` returns only assets of that category."""
        await graph.register_asset(
            inode="334-128-4",
            original_path="config/SYSTEM",
            local_path="extracted/system",
            category="registry_hive",
            filename="system",
            size_bytes=100,
            extracted_by="fs",
        )
        await graph.register_asset(
            inode="10090-128-1",
            original_path="mIRC/logs/hackers.log",
            local_path="extracted/hackers.log",
            category="chat_log",
            filename="hackers.log",
            size_bytes=200,
            extracted_by="comm",
        )
        matches = graph.query_assets(category="registry_hive")
        assert len(matches) == 1
        assert matches[0].filename == "system"

    @pytest.mark.asyncio
    async def test_lookup_by_inode(self, graph):
        """A known inode resolves to its asset; an unknown inode gives ``None``."""
        await graph.register_asset(
            inode="334-128-4",
            original_path="config/SYSTEM",
            local_path="extracted/system",
            category="registry_hive",
            filename="system",
            size_bytes=100,
            extracted_by="fs",
        )
        found = graph.lookup_asset_by_inode("334-128-4")
        assert found is not None
        assert found.local_path == "extracted/system"
        assert graph.lookup_asset_by_inode("999-128-4") is None

    @pytest.mark.asyncio
    async def test_persistence_round_trip(self, graph, tmp_path):
        """An asset survives ``save_state`` followed by ``load_state``."""
        await graph.register_asset(
            inode="334-128-4",
            original_path="WINDOWS/system32/config/SYSTEM",
            local_path="extracted/system",
            category="registry_hive",
            filename="system",
            size_bytes=262144,
            extracted_by="filesystem",
        )
        state_file = tmp_path / "test_state.json"
        graph.save_state(state_file)
        restored = EvidenceGraph.load_state(state_file)
        assert len(restored.asset_library) == 1
        assert restored.lookup_asset_by_inode("334-128-4") is not None
        assert restored.lookup_asset_by_inode("334-128-4").category == "registry_hive"

    @pytest.mark.asyncio
    async def test_list_assets(self, graph):
        """``list_assets`` yields one summary naming the file and its category."""
        await graph.register_asset(
            inode="334-128-4",
            original_path="config/SYSTEM",
            local_path="extracted/system",
            category="registry_hive",
            filename="system",
            size_bytes=8192,
            extracted_by="fs",
        )
        summaries = graph.list_assets()
        assert len(summaries) == 1
        summary = summaries[0]
        assert "system" in summary
        assert "registry_hive" in summary

    @pytest.mark.asyncio
    async def test_query_by_filename(self, graph):
        """``query_assets(filename_pattern=...)`` matches on part of the filename."""
        await graph.register_asset(
            inode="10080-128-3",
            original_path="mIRC/mirc.ini",
            local_path="extracted/mirc.ini",
            category="config_file",
            filename="mirc.ini",
            size_bytes=500,
            extracted_by="comm",
        )
        await graph.register_asset(
            inode="334-128-4",
            original_path="config/SYSTEM",
            local_path="extracted/system",
            category="registry_hive",
            filename="system",
            size_bytes=100,
            extracted_by="fs",
        )
        matches = graph.query_assets(filename_pattern="mirc")
        assert len(matches) == 1
        assert matches[0].filename == "mirc.ini"
# ---------------------------------------------------------------------------
# Tool call partitioning (P0: parallel tool execution)
# ---------------------------------------------------------------------------
class TestPartitionToolCalls:
    """Check how ``_partition_tool_calls`` groups calls into batches."""

    def test_consecutive_read_only_grouped(self):
        """Three consecutive ``list_directory`` calls form one read-only batch."""
        calls = [
            {"name": "list_directory", "arguments": {"inode": "33"}},
            {"name": "list_directory", "arguments": {"inode": "46"}},
            {"name": "list_directory", "arguments": {"inode": "50"}},
        ]
        batches = _partition_tool_calls(calls)
        assert len(batches) == 1
        only_batch = batches[0]
        assert only_batch.is_read_only
        assert len(only_batch.calls) == 3

    def test_write_tools_isolated(self):
        """Two ``extract_file`` calls each get their own non-read-only batch."""
        calls = [
            {"name": "extract_file", "arguments": {"inode": "40"}},
            {"name": "extract_file", "arguments": {"inode": "41"}},
        ]
        batches = _partition_tool_calls(calls)
        assert len(batches) == 2
        first, second = batches
        assert not first.is_read_only
        assert not second.is_read_only

    def test_mixed_read_write_partitioned(self):
        """A write call between reads splits them into read / write / read."""
        calls = [
            {"name": "list_directory", "arguments": {}},
            {"name": "search_graph", "arguments": {"keyword": "test"}},
            {"name": "add_phenomenon", "arguments": {"title": "x"}},
            {"name": "list_phenomena", "arguments": {}},
        ]
        batches = _partition_tool_calls(calls)
        assert len(batches) == 3
        leading, middle, trailing = batches
        assert leading.is_read_only
        assert len(leading.calls) == 2
        assert not middle.is_read_only
        assert len(middle.calls) == 1
        assert trailing.is_read_only
        assert len(trailing.calls) == 1

    def test_empty_list(self):
        """No calls in, no batches out."""
        assert _partition_tool_calls([]) == []

    def test_single_read_only(self):
        """A lone ``partition_info`` call yields one read-only batch."""
        batches = _partition_tool_calls([{"name": "partition_info", "arguments": {}}])
        assert len(batches) == 1
        assert batches[0].is_read_only

    def test_single_write(self):
        """A lone ``add_lead`` call yields one non-read-only batch."""
        batches = _partition_tool_calls([{"name": "add_lead", "arguments": {}}])
        assert len(batches) == 1
        assert not batches[0].is_read_only

    def test_unknown_tool_treated_as_write(self):
        """A tool name outside the read-only set yields a non-read-only batch."""
        batches = _partition_tool_calls([{"name": "unknown_tool", "arguments": {}}])
        assert len(batches) == 1
        assert not batches[0].is_read_only

    def test_custom_read_only_set(self):
        """The ``read_only`` argument overrides which names count as read-only."""
        calls = [
            {"name": "foo", "arguments": {}},
            {"name": "bar", "arguments": {}},
        ]
        batches = _partition_tool_calls(calls, read_only={"foo", "bar"})
        assert len(batches) == 1
        assert batches[0].is_read_only
class TestParallelToolExecution:
    """Exercise ``LLMClient._execute_tool_batch_parallel`` with stub executors."""

    @pytest.mark.asyncio
    async def test_parallel_faster_than_serial(self):
        """Read-only tools should execute concurrently."""
        from llm_client import LLMClient

        async def slow_tool(**kwargs):
            await asyncio.sleep(0.05)
            return "ok"

        tool_names = ("list_directory", "search_graph", "list_phenomena")
        tool_executor = {name: slow_tool for name in tool_names}
        # Build the client without running __init__.
        client = LLMClient.__new__(LLMClient)
        calls = [
            {"name": "list_directory", "arguments": {}},
            {"name": "search_graph", "arguments": {"keyword": "x"}},
            {"name": "list_phenomena", "arguments": {}},
        ]

        began = time.monotonic()
        results = await client._execute_tool_batch_parallel(calls, tool_executor)
        elapsed = time.monotonic() - began

        assert len(results) == 3
        for outcome in results:
            assert "ok" in outcome
        # Three 50 ms tools: roughly 50 ms when parallel, roughly 150 ms when serial.
        assert elapsed < 0.12, f"Expected parallel execution but took {elapsed:.3f}s"

    @pytest.mark.asyncio
    async def test_error_in_parallel_batch_does_not_crash(self):
        """A failing tool in a parallel batch should produce an error result, not crash."""
        from llm_client import LLMClient

        async def ok_tool(**kwargs):
            return "success"

        async def bad_tool(**kwargs):
            raise ValueError("broken")

        tool_executor = {"list_directory": ok_tool, "search_graph": bad_tool}
        client = LLMClient.__new__(LLMClient)
        calls = [
            {"name": "list_directory", "arguments": {}},
            {"name": "search_graph", "arguments": {"keyword": "x"}},
        ]

        results = await client._execute_tool_batch_parallel(calls, tool_executor)

        assert len(results) == 2
        good_outcome, failed_outcome = results
        assert "success" in good_outcome
        assert "Error" in failed_outcome
# ---------------------------------------------------------------------------
# Batched orchestrator LLM calls (P0: batch lead gen + judging)
# ---------------------------------------------------------------------------
class TestBatchLeadGeneration:
    """Check ``Orchestrator._generate_hypothesis_leads`` against a mocked LLM."""

    @pytest.mark.asyncio
    async def test_batched_leads_created_with_hypothesis_id(self):
        """One batched LLM reply covering two hypotheses produces two leads."""
        from unittest.mock import AsyncMock
        from orchestrator import Orchestrator
        from agent_factory import AgentFactory

        graph = EvidenceGraph()
        llm = AsyncMock()
        factory = AgentFactory(llm, graph)
        orch = Orchestrator(llm, graph, factory)

        # Set up 2 hypotheses
        hid1 = await graph.add_hypothesis("Wardriving", "Wireless scanning", created_by="test")
        hid2 = await graph.add_hypothesis("Data theft", "Stolen files", created_by="test")

        # Mock LLM to return batched JSON
        llm.chat.return_value = (
            f'[{{"hypothesis_id": "{hid1}", "agent": "network", "task": "analyze pcap", "priority": 3}},'
            f' {{"hypothesis_id": "{hid2}", "agent": "filesystem", "task": "check deleted files", "priority": 5}}]'
        )

        await orch._generate_hypothesis_leads()

        # Should have made exactly 1 LLM call
        assert llm.chat.call_count == 1
        # Should have created 2 leads
        pending = await graph.get_pending_leads()
        assert len(pending) == 2
        agents = {lead.target_agent for lead in pending}
        assert "network" in agents
        assert "filesystem" in agents

    @pytest.mark.asyncio
    async def test_fallback_on_parse_error(self):
        """An unparseable batched reply triggers a second call that yields one lead."""
        from unittest.mock import AsyncMock
        from orchestrator import Orchestrator
        from agent_factory import AgentFactory

        graph = EvidenceGraph()
        llm = AsyncMock()
        factory = AgentFactory(llm, graph)
        orch = Orchestrator(llm, graph, factory)

        # The hypothesis only needs to exist in the graph; its id is never
        # used below, so the return value is deliberately not bound.
        await graph.add_hypothesis("Test", "Test desc", created_by="test")

        # First call (batched) returns invalid JSON; fallback calls succeed
        llm.chat.side_effect = [
            "not valid json at all",  # batched call fails
            '[{"agent": "registry", "task": "check SAM", "priority": 5}]',  # fallback
        ]

        await orch._generate_hypothesis_leads()

        # Should have called LLM twice: 1 batched (failed) + 1 fallback
        assert llm.chat.call_count == 2
        pending = await graph.get_pending_leads()
        assert len(pending) == 1
class TestBatchJudging:
    """Check ``Orchestrator._judge_new_phenomena`` against a mocked LLM."""

    @pytest.mark.asyncio
    async def test_batched_judgments_applied(self):
        """One batched reply raises one hypothesis' confidence and lowers another's."""
        from unittest.mock import AsyncMock
        from orchestrator import Orchestrator
        from agent_factory import AgentFactory

        graph = EvidenceGraph()
        llm = AsyncMock()
        factory = AgentFactory(llm, graph)
        orch = Orchestrator(llm, graph, factory)

        hid1 = await graph.add_hypothesis("Hyp A", "Desc A", created_by="test")
        hid2 = await graph.add_hypothesis("Hyp B", "Desc B", created_by="test")
        pid1, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 1", "details 1", source_tool="t")
        pid2, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 2", "details 2", source_tool="t")
        old_conf1 = graph.hypotheses[hid1].confidence
        old_conf2 = graph.hypotheses[hid2].confidence

        llm.chat.return_value = (
            f'[{{"hypothesis_id": "{hid1}", "phenomenon_id": "{pid1}", "edge_type": "supports", "reason": "evidence for A"}},'
            f' {{"hypothesis_id": "{hid2}", "phenomenon_id": "{pid2}", "edge_type": "contradicts", "reason": "disproves B"}}]'
        )

        await orch._judge_new_phenomena()

        assert llm.chat.call_count == 1
        assert graph.hypotheses[hid1].confidence > old_conf1
        assert graph.hypotheses[hid2].confidence < old_conf2

    @pytest.mark.asyncio
    async def test_invalid_ids_skipped(self):
        """Judgments naming unknown ids are dropped; the valid one is applied."""
        from unittest.mock import AsyncMock
        from orchestrator import Orchestrator
        from agent_factory import AgentFactory

        graph = EvidenceGraph()
        llm = AsyncMock()
        factory = AgentFactory(llm, graph)
        orch = Orchestrator(llm, graph, factory)

        hid = await graph.add_hypothesis("Hyp", "Desc", created_by="test")
        pid, _ = await graph.add_phenomenon("fs", "filesystem", "Finding", "details", source_tool="t")
        old_conf = graph.hypotheses[hid].confidence

        # Two judgments with a bogus id each, followed by one fully valid judgment.
        llm.chat.return_value = (
            f'[{{"hypothesis_id": "hyp-nonexistent", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "bad ref"}},'
            f' {{"hypothesis_id": "{hid}", "phenomenon_id": "ph-nonexistent", "edge_type": "supports", "reason": "bad ref"}},'
            f' {{"hypothesis_id": "{hid}", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "good"}}]'
        )

        await orch._judge_new_phenomena()

        # Only the valid judgment should have been applied
        assert graph.hypotheses[hid].confidence > old_conf
        # Only 1 edge created (the valid one)
        hyp_edges = [e for e in graph.edges if e.target_id == hid and e.source_id == pid]
        assert len(hyp_edges) == 1
        # The filter above cannot see an edge built from a bogus id, so check
        # explicitly that no edge references either nonexistent id.
        bogus_ids = {"hyp-nonexistent", "ph-nonexistent"}
        assert not any(
            e.source_id in bogus_ids or e.target_id in bogus_ids
            for e in graph.edges
        ), "An edge was created for a judgment that referenced a nonexistent id"
# ---------------------------------------------------------------------------
# Tool result caching
# ---------------------------------------------------------------------------
class TestToolResultCache:
    """Check ``_cache_key`` derivation and the wrapper built by ``_make_cached``."""

    def setup_method(self):
        """Empty the module-level result cache before every test."""
        _tool_result_cache.clear()

    def test_cache_key_deterministic(self):
        """The same tool and arguments always give the same key."""
        args = {"inode": "33"}
        first_key = _cache_key("list_directory", args)
        second_key = _cache_key("list_directory", {"inode": "33"})
        assert first_key == second_key

    def test_cache_key_differs_by_args(self):
        """Different arguments for the same tool give different keys."""
        key_33 = _cache_key("list_directory", {"inode": "33"})
        key_46 = _cache_key("list_directory", {"inode": "46"})
        assert key_33 != key_46

    def test_cache_key_differs_by_tool(self):
        """The same arguments for different tools give different keys."""
        list_key = _cache_key("list_directory", {"inode": "33"})
        find_key = _cache_key("find_file", {"inode": "33"})
        assert list_key != find_key

    @pytest.mark.asyncio
    async def test_cached_executor_returns_cached_result(self):
        """A repeated call with equal arguments does not reach the executor again."""
        call_count = 0

        async def real_executor(**kwargs):
            nonlocal call_count
            call_count += 1
            return f"result for {kwargs}"

        cached = _make_cached("test_tool", real_executor)
        first = await cached(inode="33")
        second = await cached(inode="33")
        assert first == second
        # The underlying executor ran exactly once.
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_different_args_not_cached(self):
        """Calls with different arguments each reach the executor."""
        call_count = 0

        async def real_executor(**kwargs):
            nonlocal call_count
            call_count += 1
            return f"result-{call_count}"

        cached = _make_cached("test_tool", real_executor)
        first = await cached(inode="33")
        second = await cached(inode="46")
        assert first != second
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_error_results_not_cached(self):
        """A result containing ``Error`` is not stored, so the retry reaches the executor."""
        call_count = 0

        async def flaky(**kwargs):
            nonlocal call_count
            call_count += 1
            # Fail on the first invocation only.
            return "Error: something went wrong" if call_count == 1 else "success"

        cached = _make_cached("test_tool", flaky)
        first = await cached(x=1)
        assert "Error" in first
        second = await cached(x=1)
        assert second == "success"
        # Two real invocations: the error outcome was not served from cache.
        assert call_count == 2

    def test_cacheable_tools_are_read_only(self):
        """Every cacheable tool should also be in READ_ONLY_TOOLS."""
        for t in CACHEABLE_TOOLS:
            assert t in READ_ONLY_TOOLS, f"{t} is cacheable but not in READ_ONLY_TOOLS"
# ---------------------------------------------------------------------------
# Progressive context decay (Stage A)
# ---------------------------------------------------------------------------
class TestProgressiveDecay:
    """Check ``_apply_progressive_decay`` on synthetic tool-call conversations."""

    def _make_messages(self, n_rounds: int) -> list[dict]:
        """Build a synthetic message list with n_rounds of (assistant, user) pairs."""
        conversation = [{"role": "user", "content": "Start task"}]
        for i in range(n_rounds):
            tool_call = {
                "role": "assistant",
                "content": f"<tool_call>{{'name': 'tool_{i}'}}</tool_call>",
            }
            # The tool result carries 2500 filler characters between the tags.
            tool_result = {
                "role": "user",
                "content": (
                    f"{TOOL_RESULT_TAG}\n"
                    f"[tool_{i}] {'x' * 2500}\n"
                    f"{TOOL_RESULT_END}"
                ),
            }
            conversation.extend((tool_call, tool_result))
        return conversation

    def test_short_conversation_unchanged(self):
        """Three rounds come back with the same length and the same contents."""
        msgs = self._make_messages(3)
        result = _apply_progressive_decay(msgs)
        assert len(result) == len(msgs)
        # Content should be identical for short conversations
        for index in range(min(len(msgs), len(result))):
            assert msgs[index]["content"] == result[index]["content"]

    def test_old_messages_truncated(self):
        """After 20 rounds the newest tool result is intact and the oldest is cut."""
        result = _apply_progressive_decay(self._make_messages(20))
        tool_results = [
            m for m in result
            if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]
        ]
        newest = tool_results[-1]
        oldest = tool_results[0]
        # Recent tool results should be full length
        assert len(newest["content"]) > 2000
        # Oldest tool results should be truncated
        assert len(oldest["content"]) < 500

    def test_message_count_preserved(self):
        """The number of messages stays the same for a 20-round conversation."""
        msgs = self._make_messages(20)
        assert len(_apply_progressive_decay(msgs)) == len(msgs)

    def test_non_tool_messages_unchanged(self):
        """Plain chat messages without tool-result tags keep their content."""
        msgs = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
        ]
        first, second = _apply_progressive_decay(msgs)[:2]
        assert first["content"] == "Hello"
        assert second["content"] == "Hi there"
# ---------------------------------------------------------------------------
# LLM message folding (Stage B)
# ---------------------------------------------------------------------------
class TestMessageFolding:
    """Check ``LLMClient._fold_old_messages`` with ``chat`` replaced by a mock."""

    @pytest.mark.asyncio
    async def test_fold_replaces_old_messages_with_summary(self):
        """A 61-message history is folded into a summary plus the recent tail."""
        from llm_client import LLMClient
        from unittest.mock import AsyncMock

        # Build the client without running __init__, then stub its chat call.
        client = LLMClient.__new__(LLMClient)
        client.chat = AsyncMock(return_value="Summary: found important files")

        # Build messages exceeding fold threshold
        messages = [{"role": "user", "content": "Start task"}]
        for i in range(30):
            messages.extend(
                (
                    {"role": "assistant", "content": f"thinking step {i}"},
                    {"role": "user", "content": f"tool result {i}: {'data ' * 50}"},
                )
            )

        result = await client._fold_old_messages(messages, "system prompt")

        # Should be significantly shorter
        assert len(result) < len(messages)
        # First message should be the summary
        summary_message = result[0]
        assert "Context summary" in summary_message["content"]
        # The summary plus _FOLD_KEEP_RECENT retained messages.
        assert len(result) == _FOLD_KEEP_RECENT + 1

    @pytest.mark.asyncio
    async def test_fold_not_triggered_below_threshold(self):
        """A two-message history is returned unchanged without calling the LLM."""
        from llm_client import LLMClient
        from unittest.mock import AsyncMock

        client = LLMClient.__new__(LLMClient)
        client.chat = AsyncMock()
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]

        result = await client._fold_old_messages(messages, "system")

        # Presumably there is nothing to fold with only two messages -- the
        # exact cut-off lives in llm_client; verify there.
        assert result == messages
        client.chat.assert_not_called()

    @pytest.mark.asyncio
    async def test_fold_graceful_on_llm_failure(self):
        """If the summarising chat call raises, all 40 messages are still returned."""
        from llm_client import LLMClient
        from unittest.mock import AsyncMock

        client = LLMClient.__new__(LLMClient)
        client.chat = AsyncMock(side_effect=Exception("API error"))
        messages = [{"role": "user", "content": f"msg {i}"} for i in range(40)]

        result = await client._fold_old_messages(messages, "system")

        # On failure the message count must be unchanged.
        assert len(result) == 40