diff --git a/tests/test_optimizations.py b/tests/test_optimizations.py new file mode 100644 index 0000000..e9a7530 --- /dev/null +++ b/tests/test_optimizations.py @@ -0,0 +1,957 @@ +"""Tests for dedup/quality scoring, context management, parallel dispatch, and evidence graph.""" + +from __future__ import annotations + +import asyncio +import time + +import pytest + +from evidence_graph import ( + EvidenceGraph, Phenomenon, Lead, + _compute_quality_score, _jaccard_similarity, +) +from llm_client import ( + _truncate_tool_result, _partition_tool_calls, _ToolBatch, READ_ONLY_TOOLS, + _apply_progressive_decay, _FOLD_THRESHOLD, _FOLD_KEEP_RECENT, + TOOL_RESULT_TAG, TOOL_RESULT_END, +) +from tool_registry import ( + _tool_result_cache, _cache_key, _make_cached, CACHEABLE_TOOLS, +) + + +# --------------------------------------------------------------------------- +# Quality scoring & similarity +# --------------------------------------------------------------------------- + +class TestQualityScore: + def test_full_score(self): + score = _compute_quality_score( + source_tool="list_dir", + timestamp="2024-01-01", + raw_data={"key": "val"}, + description="A" * 50, + related_ids=["ph-0"], + ) + assert score == pytest.approx(1.0) + + def test_minimal_score(self): + score = _compute_quality_score( + source_tool="", + timestamp=None, + raw_data={}, + description="short", + related_ids=[], + ) + assert score == pytest.approx(0.0) + + def test_partial_score(self): + score = _compute_quality_score( + source_tool="parse_registry_key", + timestamp=None, + raw_data={}, + description="A" * 50, + related_ids=[], + ) + assert score == pytest.approx(0.40) + + +class TestJaccardSimilarity: + def test_identical(self): + assert _jaccard_similarity("hello world", "hello world") == pytest.approx(1.0) + + def test_no_overlap(self): + assert _jaccard_similarity("hello world", "foo bar") == pytest.approx(0.0) + + def test_partial_overlap(self): + sim = _jaccard_similarity("User account Mr. Evil", "User account Mr. Evil enumeration") + assert 0.5 < sim < 1.0 + + def test_empty_string(self): + assert _jaccard_similarity("", "hello") == pytest.approx(0.0) + + +# --------------------------------------------------------------------------- +# Evidence graph: phenomenon dedup +# --------------------------------------------------------------------------- + +class TestPhenomenonDedup: + @pytest.fixture + def graph(self): + return EvidenceGraph() + + @pytest.mark.asyncio + async def test_new_phenomenon_gets_quality_score(self, graph): + pid, merged = await graph.add_phenomenon( + source_agent="fs", category="filesystem", + title="Directory listing", + description="Found important files in the root directory of the disk image", + source_tool="list_directory", + ) + assert not merged + ph = graph.phenomena[pid] + assert ph.confidence == pytest.approx(0.40) + + @pytest.mark.asyncio + async def test_identical_phenomenon_merges(self, graph): + pid1, m1 = await graph.add_phenomenon( + source_agent="fs", category="registry", + title="User account Mr. Evil", + description="Found user account Mr. Evil with RID 1003 in the SAM hive", + source_tool="enumerate_users", + ) + pid2, m2 = await graph.add_phenomenon( + source_agent="reg", category="registry", + title="User account Mr. Evil", + description="User account Mr. Evil discovered with RID 1003 in SAM registry hive", + source_tool="enumerate_users", + ) + assert not m1 + assert m2 + assert pid1 == pid2 + ph = graph.phenomena[pid1] + assert "reg" in ph.corroborating_agents + assert ph.confidence > 0.40 + + @pytest.mark.asyncio + async def test_different_category_no_merge(self, graph): + pid1, _ = await graph.add_phenomenon( + source_agent="fs", category="filesystem", + title="User account Mr. Evil", + description="Found user account Mr. Evil with RID 1003", + source_tool="list_directory", + ) + pid2, m2 = await graph.add_phenomenon( + source_agent="reg", category="registry", + title="User account Mr. Evil", + description="Found user account Mr. Evil with RID 1003", + source_tool="enumerate_users", + ) + assert not m2 + assert pid1 != pid2 + + @pytest.mark.asyncio + async def test_dissimilar_titles_no_merge(self, graph): + await graph.add_phenomenon( + source_agent="fs", category="registry", + title="Installed software list", + description="Found 25 installed programs", + source_tool="list_installed_software", + ) + pid2, m2 = await graph.add_phenomenon( + source_agent="reg", category="registry", + title="Last shutdown time", + description="System last shut down at 2004-08-27", + source_tool="get_shutdown_time", + ) + assert not m2 + + @pytest.mark.asyncio + async def test_merge_adds_raw_data(self, graph): + pid1, _ = await graph.add_phenomenon( + source_agent="fs", category="registry", + title="Network config", + description="Network interface configuration with IP address and MAC address details", + raw_data={"ip": "192.168.1.1"}, + source_tool="get_network_interfaces", + ) + pid2, m2 = await graph.add_phenomenon( + source_agent="reg", category="registry", + title="Network config", + description="Network interface configuration including DHCP and MAC address info", + raw_data={"mac": "00:11:22:33:44:55"}, + source_tool="get_network_interfaces", + ) + assert m2 + ph = graph.phenomena[pid1] + assert "ip" in ph.raw_data + assert "mac" in ph.raw_data + + +# --------------------------------------------------------------------------- +# Hypothesis confidence updates +# --------------------------------------------------------------------------- + +class TestHypothesisConfidence: + @pytest.fixture + def graph(self): + return EvidenceGraph() + + @pytest.mark.asyncio + async def test_supports_increases_confidence(self, graph): + pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t") + hid = await graph.add_hypothesis("test hyp", "desc", created_by="test") + old = graph.hypotheses[hid].confidence + new = await graph.update_hypothesis_confidence(hid, pid, "supports", "reason") + assert new > old + + @pytest.mark.asyncio + async def test_contradicts_decreases_confidence(self, graph): + pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t") + hid = await graph.add_hypothesis("test hyp", "desc", created_by="test") + old = graph.hypotheses[hid].confidence + new = await graph.update_hypothesis_confidence(hid, pid, "contradicts", "reason") + assert new < old + + @pytest.mark.asyncio + async def test_direct_evidence_has_largest_effect(self, graph): + pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t") + hid = await graph.add_hypothesis("test hyp", "desc", created_by="test") + conf = await graph.update_hypothesis_confidence(hid, pid, "direct_evidence", "reason") + # direct_evidence weight is +0.25 * (1-0.5) = +0.125 + assert conf == pytest.approx(0.625) + + @pytest.mark.asyncio + async def test_confidence_log_tracked(self, graph): + pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t") + hid = await graph.add_hypothesis("test hyp", "desc", created_by="test") + await graph.update_hypothesis_confidence(hid, pid, "supports", "because reasons") + log = graph.hypotheses[hid].confidence_log + assert len(log) == 1 + assert log[0]["edge_type"] == "supports" + assert log[0]["reason"] == "because reasons" + + @pytest.mark.asyncio + async def test_supported_status_at_threshold(self, graph): + hid = await graph.add_hypothesis("test", "desc", created_by="test") + # Push confidence up with multiple direct_evidence links + for i in range(5): + pid, _ = await graph.add_phenomenon("fs", "filesystem", f"test {i}", f"desc {i}", source_tool="t") + await graph.update_hypothesis_confidence(hid, pid, "direct_evidence", "proof") + assert graph.hypotheses[hid].status == "supported" + + +# --------------------------------------------------------------------------- +# Context window management +# --------------------------------------------------------------------------- + +class TestTruncateToolResult: + def test_short_text_unchanged(self): + text = "short result" + assert _truncate_tool_result(text) == text + + def test_long_text_truncated(self): + text = "x" * 5000 + result = _truncate_tool_result(text, max_chars=3000) + assert len(result) < 3100 + assert "truncated" in result + assert "5000" in result + + def test_exact_boundary(self): + text = "x" * 3000 + assert _truncate_tool_result(text, max_chars=3000) == text + + +# --------------------------------------------------------------------------- +# Parallel dispatch & async safety +# --------------------------------------------------------------------------- + +class TestGraphAsync: + @pytest.mark.asyncio + async def test_concurrent_add_phenomenon(self): + graph = EvidenceGraph() + + async def add_batch(agent_name: str, count: int): + for i in range(count): + await graph.add_phenomenon( + source_agent=agent_name, + category=f"cat-{agent_name}", + title=f"Phenomenon {agent_name}-{i}", + description=f"Description from {agent_name} number {i} with enough detail to be meaningful", + source_tool="test_tool", + ) + + await asyncio.gather( + add_batch("agent_a", 10), + add_batch("agent_b", 10), + add_batch("agent_c", 10), + ) + assert len(graph.phenomena) == 30 + + @pytest.mark.asyncio + async def test_concurrent_add_lead(self): + graph = EvidenceGraph() + + async def add_leads(count: int): + for i in range(count): + await graph.add_lead(target_agent="network", description=f"Lead {i}") + + await asyncio.gather(add_leads(10), add_leads(10)) + pending = await graph.get_pending_leads() + assert len(pending) == 20 + + @pytest.mark.asyncio + async def test_mark_lead_completed(self): + graph = EvidenceGraph() + lid = await graph.add_lead(target_agent="fs", description="test lead") + pending = await graph.get_pending_leads() + assert len(pending) == 1 + + await graph.mark_lead_completed(lid) + pending = await graph.get_pending_leads() + assert len(pending) == 0 + + +class TestParallelDispatch: + @pytest.mark.asyncio + async def test_different_agent_types_run_concurrently(self): + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + graph = EvidenceGraph() + llm = AsyncMock() + factory = AgentFactory(llm, graph) + + orch = Orchestrator(llm, graph, factory) + + execution_log: list[tuple[str, float, float]] = [] + + class FakeAgent: + def __init__(self, name): + self.name = name + + async def run(self, task, lead_id=None): + start = time.monotonic() + await asyncio.sleep(0.05) + end = time.monotonic() + execution_log.append((self.name, start, end)) + + factory._cache = { + "filesystem": FakeAgent("filesystem"), + "registry": FakeAgent("registry"), + "network": FakeAgent("network"), + } + + leads = [ + Lead(id="l1", target_agent="filesystem", description="fs task"), + Lead(id="l2", target_agent="registry", description="reg task"), + Lead(id="l3", target_agent="network", description="net task"), + ] + graph.leads = leads + + start_time = time.monotonic() + await orch._dispatch_leads_parallel(leads) + total_time = time.monotonic() - start_time + + assert total_time < 0.12, f"Expected parallel execution but took {total_time:.3f}s" + assert len(execution_log) == 3 + + @pytest.mark.asyncio + async def test_same_agent_type_runs_serially(self): + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + graph = EvidenceGraph() + llm = AsyncMock() + factory = AgentFactory(llm, graph) + orch = Orchestrator(llm, graph, factory) + + execution_order: list[str] = [] + + class FakeAgent: + def __init__(self, name): + self.name = name + + async def run(self, task, lead_id=None): + execution_order.append(task) + await asyncio.sleep(0.01) + + factory._cache = {"filesystem": FakeAgent("filesystem")} + + leads = [ + Lead(id="l1", target_agent="filesystem", description="task 1"), + Lead(id="l2", target_agent="filesystem", description="task 2"), + ] + graph.leads = leads + + await orch._dispatch_leads_parallel(leads) + + assert len(execution_order) == 2 + assert "task 1" in execution_order[0] + assert "task 2" in execution_order[1] + + +# --------------------------------------------------------------------------- +# Asset library +# --------------------------------------------------------------------------- + +class TestAssetLibrary: + @pytest.fixture + def graph(self): + return EvidenceGraph() + + @pytest.mark.asyncio + async def test_register_asset(self, graph): + aid, existed = await graph.register_asset( + inode="334-128-4", + original_path="WINDOWS/system32/config/SYSTEM", + local_path="extracted/system", + category="registry_hive", + filename="system", + size_bytes=262144, + extracted_by="filesystem", + ) + assert not existed + assert aid.startswith("asset-") + assert len(graph.asset_library) == 1 + + @pytest.mark.asyncio + async def test_dedup_by_inode(self, graph): + aid1, _ = await graph.register_asset( + inode="334-128-4", + original_path="WINDOWS/system32/config/SYSTEM", + local_path="extracted/system", + category="registry_hive", + filename="system", + size_bytes=262144, + extracted_by="filesystem", + ) + aid2, existed = await graph.register_asset( + inode="334-128-4", + original_path="WINDOWS/system32/config/SYSTEM", + local_path="extracted/system", + category="registry_hive", + filename="system", + size_bytes=262144, + extracted_by="registry", + ) + assert existed + assert aid1 == aid2 + assert len(graph.asset_library) == 1 + + @pytest.mark.asyncio + async def test_query_by_category(self, graph): + await graph.register_asset( + inode="334-128-4", original_path="config/SYSTEM", + local_path="extracted/system", category="registry_hive", + filename="system", size_bytes=100, extracted_by="fs", + ) + await graph.register_asset( + inode="10090-128-1", original_path="mIRC/logs/hackers.log", + local_path="extracted/hackers.log", category="chat_log", + filename="hackers.log", size_bytes=200, extracted_by="comm", + ) + results = graph.query_assets(category="registry_hive") + assert len(results) == 1 + assert results[0].filename == "system" + + @pytest.mark.asyncio + async def test_lookup_by_inode(self, graph): + await graph.register_asset( + inode="334-128-4", original_path="config/SYSTEM", + local_path="extracted/system", category="registry_hive", + filename="system", size_bytes=100, extracted_by="fs", + ) + asset = graph.lookup_asset_by_inode("334-128-4") + assert asset is not None + assert asset.local_path == "extracted/system" + + assert graph.lookup_asset_by_inode("999-128-4") is None + + @pytest.mark.asyncio + async def test_persistence_round_trip(self, graph, tmp_path): + await graph.register_asset( + inode="334-128-4", + original_path="WINDOWS/system32/config/SYSTEM", + local_path="extracted/system", + category="registry_hive", + filename="system", + size_bytes=262144, + extracted_by="filesystem", + ) + path = tmp_path / "test_state.json" + graph.save_state(path) + + loaded = EvidenceGraph.load_state(path) + assert len(loaded.asset_library) == 1 + assert loaded.lookup_asset_by_inode("334-128-4") is not None + assert loaded.lookup_asset_by_inode("334-128-4").category == "registry_hive" + + @pytest.mark.asyncio + async def test_list_assets(self, graph): + await graph.register_asset( + inode="334-128-4", original_path="config/SYSTEM", + local_path="extracted/system", category="registry_hive", + filename="system", size_bytes=8192, extracted_by="fs", + ) + summaries = graph.list_assets() + assert len(summaries) == 1 + assert "system" in summaries[0] + assert "registry_hive" in summaries[0] + + @pytest.mark.asyncio + async def test_query_by_filename(self, graph): + await graph.register_asset( + inode="10080-128-3", original_path="mIRC/mirc.ini", + local_path="extracted/mirc.ini", category="config_file", + filename="mirc.ini", size_bytes=500, extracted_by="comm", + ) + await graph.register_asset( + inode="334-128-4", original_path="config/SYSTEM", + local_path="extracted/system", category="registry_hive", + filename="system", size_bytes=100, extracted_by="fs", + ) + results = graph.query_assets(filename_pattern="mirc") + assert len(results) == 1 + assert results[0].filename == "mirc.ini" + + +# --------------------------------------------------------------------------- +# Tool call partitioning (P0: parallel tool execution) +# --------------------------------------------------------------------------- + +class TestPartitionToolCalls: + def test_consecutive_read_only_grouped(self): + calls = [ + {"name": "list_directory", "arguments": {"inode": "33"}}, + {"name": "list_directory", "arguments": {"inode": "46"}}, + {"name": "list_directory", "arguments": {"inode": "50"}}, + ] + batches = _partition_tool_calls(calls) + assert len(batches) == 1 + assert batches[0].is_read_only + assert len(batches[0].calls) == 3 + + def test_write_tools_isolated(self): + calls = [ + {"name": "extract_file", "arguments": {"inode": "40"}}, + {"name": "extract_file", "arguments": {"inode": "41"}}, + ] + batches = _partition_tool_calls(calls) + assert len(batches) == 2 + assert not batches[0].is_read_only + assert not batches[1].is_read_only + + def test_mixed_read_write_partitioned(self): + calls = [ + {"name": "list_directory", "arguments": {}}, + {"name": "search_graph", "arguments": {"keyword": "test"}}, + {"name": "add_phenomenon", "arguments": {"title": "x"}}, + {"name": "list_phenomena", "arguments": {}}, + ] + batches = _partition_tool_calls(calls) + assert len(batches) == 3 + assert batches[0].is_read_only and len(batches[0].calls) == 2 + assert not batches[1].is_read_only and len(batches[1].calls) == 1 + assert batches[2].is_read_only and len(batches[2].calls) == 1 + + def test_empty_list(self): + assert _partition_tool_calls([]) == [] + + def test_single_read_only(self): + calls = [{"name": "partition_info", "arguments": {}}] + batches = _partition_tool_calls(calls) + assert len(batches) == 1 + assert batches[0].is_read_only + + def test_single_write(self): + calls = [{"name": "add_lead", "arguments": {}}] + batches = _partition_tool_calls(calls) + assert len(batches) == 1 + assert not batches[0].is_read_only + + def test_unknown_tool_treated_as_write(self): + calls = [{"name": "unknown_tool", "arguments": {}}] + batches = _partition_tool_calls(calls) + assert len(batches) == 1 + assert not batches[0].is_read_only + + def test_custom_read_only_set(self): + calls = [ + {"name": "foo", "arguments": {}}, + {"name": "bar", "arguments": {}}, + ] + batches = _partition_tool_calls(calls, read_only={"foo", "bar"}) + assert len(batches) == 1 + assert batches[0].is_read_only + + +class TestParallelToolExecution: + @pytest.mark.asyncio + async def test_parallel_faster_than_serial(self): + """Read-only tools should execute concurrently.""" + from llm_client import LLMClient + + async def slow_tool(**kwargs): + await asyncio.sleep(0.05) + return "ok" + + tool_executor = { + "list_directory": slow_tool, + "search_graph": slow_tool, + "list_phenomena": slow_tool, + } + + client = LLMClient.__new__(LLMClient) # skip __init__ + + calls = [ + {"name": "list_directory", "arguments": {}}, + {"name": "search_graph", "arguments": {"keyword": "x"}}, + {"name": "list_phenomena", "arguments": {}}, + ] + + start = time.monotonic() + results = await client._execute_tool_batch_parallel(calls, tool_executor) + elapsed = time.monotonic() - start + + assert len(results) == 3 + assert all("ok" in r for r in results) + # 3 tasks × 50ms each should take ~50ms parallel, not ~150ms serial + assert elapsed < 0.12, f"Expected parallel execution but took {elapsed:.3f}s" + + @pytest.mark.asyncio + async def test_error_in_parallel_batch_does_not_crash(self): + """A failing tool in a parallel batch should produce an error result, not crash.""" + from llm_client import LLMClient + + async def ok_tool(**kwargs): + return "success" + + async def bad_tool(**kwargs): + raise ValueError("broken") + + tool_executor = { + "list_directory": ok_tool, + "search_graph": bad_tool, + } + + client = LLMClient.__new__(LLMClient) + + calls = [ + {"name": "list_directory", "arguments": {}}, + {"name": "search_graph", "arguments": {"keyword": "x"}}, + ] + + results = await client._execute_tool_batch_parallel(calls, tool_executor) + assert len(results) == 2 + assert "success" in results[0] + assert "Error" in results[1] + + +# --------------------------------------------------------------------------- +# Batched orchestrator LLM calls (P0: batch lead gen + judging) +# --------------------------------------------------------------------------- + +class TestBatchLeadGeneration: + @pytest.mark.asyncio + async def test_batched_leads_created_with_hypothesis_id(self): + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + graph = EvidenceGraph() + llm = AsyncMock() + factory = AgentFactory(llm, graph) + orch = Orchestrator(llm, graph, factory) + + # Set up 2 hypotheses + hid1 = await graph.add_hypothesis("Wardriving", "Wireless scanning", created_by="test") + hid2 = await graph.add_hypothesis("Data theft", "Stolen files", created_by="test") + + # Mock LLM to return batched JSON + llm.chat.return_value = ( + f'[{{"hypothesis_id": "{hid1}", "agent": "network", "task": "analyze pcap", "priority": 3}},' + f' {{"hypothesis_id": "{hid2}", "agent": "filesystem", "task": "check deleted files", "priority": 5}}]' + ) + + await orch._generate_hypothesis_leads() + + # Should have made exactly 1 LLM call + assert llm.chat.call_count == 1 + # Should have created 2 leads + pending = await graph.get_pending_leads() + assert len(pending) == 2 + agents = {l.target_agent for l in pending} + assert "network" in agents + assert "filesystem" in agents + + @pytest.mark.asyncio + async def test_fallback_on_parse_error(self): + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + graph = EvidenceGraph() + llm = AsyncMock() + factory = AgentFactory(llm, graph) + orch = Orchestrator(llm, graph, factory) + + hid = await graph.add_hypothesis("Test", "Test desc", created_by="test") + + # First call (batched) returns invalid JSON; fallback calls succeed + llm.chat.side_effect = [ + "not valid json at all", # batched call fails + '[{"agent": "registry", "task": "check SAM", "priority": 5}]', # fallback + ] + + await orch._generate_hypothesis_leads() + + # Should have called LLM twice: 1 batched (failed) + 1 fallback + assert llm.chat.call_count == 2 + pending = await graph.get_pending_leads() + assert len(pending) == 1 + + +class TestBatchJudging: + @pytest.mark.asyncio + async def test_batched_judgments_applied(self): + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + graph = EvidenceGraph() + llm = AsyncMock() + factory = AgentFactory(llm, graph) + orch = Orchestrator(llm, graph, factory) + + hid1 = await graph.add_hypothesis("Hyp A", "Desc A", created_by="test") + hid2 = await graph.add_hypothesis("Hyp B", "Desc B", created_by="test") + pid1, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 1", "details 1", source_tool="t") + pid2, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 2", "details 2", source_tool="t") + + old_conf1 = graph.hypotheses[hid1].confidence + old_conf2 = graph.hypotheses[hid2].confidence + + llm.chat.return_value = ( + f'[{{"hypothesis_id": "{hid1}", "phenomenon_id": "{pid1}", "edge_type": "supports", "reason": "evidence for A"}},' + f' {{"hypothesis_id": "{hid2}", "phenomenon_id": "{pid2}", "edge_type": "contradicts", "reason": "disproves B"}}]' + ) + + await orch._judge_new_phenomena() + + assert llm.chat.call_count == 1 + assert graph.hypotheses[hid1].confidence > old_conf1 + assert graph.hypotheses[hid2].confidence < old_conf2 + + @pytest.mark.asyncio + async def test_invalid_ids_skipped(self): + from unittest.mock import AsyncMock + from orchestrator import Orchestrator + from agent_factory import AgentFactory + + graph = EvidenceGraph() + llm = AsyncMock() + factory = AgentFactory(llm, graph) + orch = Orchestrator(llm, graph, factory) + + hid = await graph.add_hypothesis("Hyp", "Desc", created_by="test") + pid, _ = await graph.add_phenomenon("fs", "filesystem", "Finding", "details", source_tool="t") + + old_conf = graph.hypotheses[hid].confidence + + llm.chat.return_value = ( + f'[{{"hypothesis_id": "hyp-nonexistent", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "bad ref"}},' + f' {{"hypothesis_id": "{hid}", "phenomenon_id": "ph-nonexistent", "edge_type": "supports", "reason": "bad ref"}},' + f' {{"hypothesis_id": "{hid}", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "good"}}]' + ) + + await orch._judge_new_phenomena() + + # Only the valid judgment should have been applied + assert graph.hypotheses[hid].confidence > old_conf + # Only 1 edge created (the valid one) + hyp_edges = [e for e in graph.edges if e.target_id == hid and e.source_id == pid] + assert len(hyp_edges) == 1 + + +# --------------------------------------------------------------------------- +# Tool result caching +# --------------------------------------------------------------------------- + +class TestToolResultCache: + def setup_method(self): + _tool_result_cache.clear() + + def test_cache_key_deterministic(self): + k1 = _cache_key("list_directory", {"inode": "33"}) + k2 = _cache_key("list_directory", {"inode": "33"}) + assert k1 == k2 + + def test_cache_key_differs_by_args(self): + k1 = _cache_key("list_directory", {"inode": "33"}) + k2 = _cache_key("list_directory", {"inode": "46"}) + assert k1 != k2 + + def test_cache_key_differs_by_tool(self): + k1 = _cache_key("list_directory", {"inode": "33"}) + k2 = _cache_key("find_file", {"inode": "33"}) + assert k1 != k2 + + @pytest.mark.asyncio + async def test_cached_executor_returns_cached_result(self): + call_count = 0 + + async def real_executor(**kwargs): + nonlocal call_count + call_count += 1 + return f"result for {kwargs}" + + cached = _make_cached("test_tool", real_executor) + + r1 = await cached(inode="33") + r2 = await cached(inode="33") + assert r1 == r2 + assert call_count == 1 # only called once + + @pytest.mark.asyncio + async def test_different_args_not_cached(self): + call_count = 0 + + async def real_executor(**kwargs): + nonlocal call_count + call_count += 1 + return f"result-{call_count}" + + cached = _make_cached("test_tool", real_executor) + + r1 = await cached(inode="33") + r2 = await cached(inode="46") + assert r1 != r2 + assert call_count == 2 + + @pytest.mark.asyncio + async def test_error_results_not_cached(self): + call_count = 0 + + async def flaky(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return "Error: something went wrong" + return "success" + + cached = _make_cached("test_tool", flaky) + + r1 = await cached(x=1) + assert "Error" in r1 + r2 = await cached(x=1) + assert r2 == "success" + assert call_count == 2 # called twice because error wasn't cached + + def test_cacheable_tools_are_read_only(self): + """Every cacheable tool should also be in READ_ONLY_TOOLS.""" + for t in CACHEABLE_TOOLS: + assert t in READ_ONLY_TOOLS, f"{t} is cacheable but not in READ_ONLY_TOOLS" + + +# --------------------------------------------------------------------------- +# Progressive context decay (Stage A) +# --------------------------------------------------------------------------- + +class TestProgressiveDecay: + def _make_messages(self, n_rounds: int) -> list[dict]: + """Build a synthetic message list with n_rounds of (assistant, user) pairs.""" + messages = [{"role": "user", "content": "Start task"}] + for i in range(n_rounds): + messages.append({ + "role": "assistant", + "content": f"{{'name': 'tool_{i}'}}", + }) + # Tool result message with substantial content + messages.append({ + "role": "user", + "content": ( + f"{TOOL_RESULT_TAG}\n" + f"[tool_{i}] {'x' * 2500}\n" + f"{TOOL_RESULT_END}" + ), + }) + return messages + + def test_short_conversation_unchanged(self): + msgs = self._make_messages(3) + result = _apply_progressive_decay(msgs) + assert len(result) == len(msgs) + # Content should be identical for short conversations + for orig, decayed in zip(msgs, result): + assert orig["content"] == decayed["content"] + + def test_old_messages_truncated(self): + msgs = self._make_messages(20) + result = _apply_progressive_decay(msgs) + + # Recent tool results should be full length + last_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][-1] + assert len(last_tool_result["content"]) > 2000 + + # Oldest tool results should be truncated + first_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][0] + assert len(first_tool_result["content"]) < 500 + + def test_message_count_preserved(self): + msgs = self._make_messages(20) + result = _apply_progressive_decay(msgs) + assert len(result) == len(msgs) + + def test_non_tool_messages_unchanged(self): + msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + result = _apply_progressive_decay(msgs) + assert result[0]["content"] == "Hello" + assert result[1]["content"] == "Hi there" + + +# --------------------------------------------------------------------------- +# LLM message folding (Stage B) +# --------------------------------------------------------------------------- + +class TestMessageFolding: + @pytest.mark.asyncio + async def test_fold_replaces_old_messages_with_summary(self): + from llm_client import LLMClient + from unittest.mock import AsyncMock + + client = LLMClient.__new__(LLMClient) + client.chat = AsyncMock(return_value="Summary: found important files") + + # Build messages exceeding fold threshold + messages = [{"role": "user", "content": "Start task"}] + for i in range(30): + messages.append({"role": "assistant", "content": f"thinking step {i}"}) + messages.append({"role": "user", "content": f"tool result {i}: {'data ' * 50}"}) + + result = await client._fold_old_messages(messages, "system prompt") + + # Should be significantly shorter + assert len(result) < len(messages) + # First message should be the summary + assert "Context summary" in result[0]["content"] + # Recent messages preserved + assert len(result) == _FOLD_KEEP_RECENT + 1 # +1 for summary + + @pytest.mark.asyncio + async def test_fold_not_triggered_below_threshold(self): + from llm_client import LLMClient + from unittest.mock import AsyncMock + + client = LLMClient.__new__(LLMClient) + client.chat = AsyncMock() + + messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ] + + result = await client._fold_old_messages(messages, "system") + # Should return original (n_to_fold = 2 - 10 = negative, so no folding) + assert result == messages + client.chat.assert_not_called() + + @pytest.mark.asyncio + async def test_fold_graceful_on_llm_failure(self): + from llm_client import LLMClient + from unittest.mock import AsyncMock + + client = LLMClient.__new__(LLMClient) + client.chat = AsyncMock(side_effect=Exception("API error")) + + messages = [{"role": "user", "content": f"msg {i}"} for i in range(40)] + result = await client._fold_old_messages(messages, "system") + + # On failure, should return original messages unchanged + assert len(result) == 40