test: track tests/ directory in version control
tests/test_optimizations.py — 60 pytest cases covering: - EvidenceGraph: quality scoring, Jaccard merge, async safety, hypothesis confidence updates, asset library - llm_client: tool-result truncation, parallel batch execution, progressive context decay, message folding - orchestrator: parallel dispatch, batched lead generation, batched judging - tool_registry: result cache key derivation FakeAgent.run signatures updated to BaseAgent.run(task, lead_id=None). Previously listed in .gitignore (which is itself untracked, so the ignore rule lives only locally). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
957
tests/test_optimizations.py
Normal file
957
tests/test_optimizations.py
Normal file
@@ -0,0 +1,957 @@
|
||||
"""Tests for dedup/quality scoring, context management, parallel dispatch, and evidence graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from evidence_graph import (
|
||||
EvidenceGraph, Phenomenon, Lead,
|
||||
_compute_quality_score, _jaccard_similarity,
|
||||
)
|
||||
from llm_client import (
|
||||
_truncate_tool_result, _partition_tool_calls, _ToolBatch, READ_ONLY_TOOLS,
|
||||
_apply_progressive_decay, _FOLD_THRESHOLD, _FOLD_KEEP_RECENT,
|
||||
TOOL_RESULT_TAG, TOOL_RESULT_END,
|
||||
)
|
||||
from tool_registry import (
|
||||
_tool_result_cache, _cache_key, _make_cached, CACHEABLE_TOOLS,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quality scoring & similarity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQualityScore:
|
||||
def test_full_score(self):
|
||||
score = _compute_quality_score(
|
||||
source_tool="list_dir",
|
||||
timestamp="2024-01-01",
|
||||
raw_data={"key": "val"},
|
||||
description="A" * 50,
|
||||
related_ids=["ph-0"],
|
||||
)
|
||||
assert score == pytest.approx(1.0)
|
||||
|
||||
def test_minimal_score(self):
|
||||
score = _compute_quality_score(
|
||||
source_tool="",
|
||||
timestamp=None,
|
||||
raw_data={},
|
||||
description="short",
|
||||
related_ids=[],
|
||||
)
|
||||
assert score == pytest.approx(0.0)
|
||||
|
||||
def test_partial_score(self):
|
||||
score = _compute_quality_score(
|
||||
source_tool="parse_registry_key",
|
||||
timestamp=None,
|
||||
raw_data={},
|
||||
description="A" * 50,
|
||||
related_ids=[],
|
||||
)
|
||||
assert score == pytest.approx(0.40)
|
||||
|
||||
|
||||
class TestJaccardSimilarity:
|
||||
def test_identical(self):
|
||||
assert _jaccard_similarity("hello world", "hello world") == pytest.approx(1.0)
|
||||
|
||||
def test_no_overlap(self):
|
||||
assert _jaccard_similarity("hello world", "foo bar") == pytest.approx(0.0)
|
||||
|
||||
def test_partial_overlap(self):
|
||||
sim = _jaccard_similarity("User account Mr. Evil", "User account Mr. Evil enumeration")
|
||||
assert 0.5 < sim < 1.0
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _jaccard_similarity("", "hello") == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Evidence graph: phenomenon dedup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPhenomenonDedup:
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
return EvidenceGraph()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_phenomenon_gets_quality_score(self, graph):
|
||||
pid, merged = await graph.add_phenomenon(
|
||||
source_agent="fs", category="filesystem",
|
||||
title="Directory listing",
|
||||
description="Found important files in the root directory of the disk image",
|
||||
source_tool="list_directory",
|
||||
)
|
||||
assert not merged
|
||||
ph = graph.phenomena[pid]
|
||||
assert ph.confidence == pytest.approx(0.40)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identical_phenomenon_merges(self, graph):
|
||||
pid1, m1 = await graph.add_phenomenon(
|
||||
source_agent="fs", category="registry",
|
||||
title="User account Mr. Evil",
|
||||
description="Found user account Mr. Evil with RID 1003 in the SAM hive",
|
||||
source_tool="enumerate_users",
|
||||
)
|
||||
pid2, m2 = await graph.add_phenomenon(
|
||||
source_agent="reg", category="registry",
|
||||
title="User account Mr. Evil",
|
||||
description="User account Mr. Evil discovered with RID 1003 in SAM registry hive",
|
||||
source_tool="enumerate_users",
|
||||
)
|
||||
assert not m1
|
||||
assert m2
|
||||
assert pid1 == pid2
|
||||
ph = graph.phenomena[pid1]
|
||||
assert "reg" in ph.corroborating_agents
|
||||
assert ph.confidence > 0.40
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_category_no_merge(self, graph):
|
||||
pid1, _ = await graph.add_phenomenon(
|
||||
source_agent="fs", category="filesystem",
|
||||
title="User account Mr. Evil",
|
||||
description="Found user account Mr. Evil with RID 1003",
|
||||
source_tool="list_directory",
|
||||
)
|
||||
pid2, m2 = await graph.add_phenomenon(
|
||||
source_agent="reg", category="registry",
|
||||
title="User account Mr. Evil",
|
||||
description="Found user account Mr. Evil with RID 1003",
|
||||
source_tool="enumerate_users",
|
||||
)
|
||||
assert not m2
|
||||
assert pid1 != pid2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dissimilar_titles_no_merge(self, graph):
|
||||
await graph.add_phenomenon(
|
||||
source_agent="fs", category="registry",
|
||||
title="Installed software list",
|
||||
description="Found 25 installed programs",
|
||||
source_tool="list_installed_software",
|
||||
)
|
||||
pid2, m2 = await graph.add_phenomenon(
|
||||
source_agent="reg", category="registry",
|
||||
title="Last shutdown time",
|
||||
description="System last shut down at 2004-08-27",
|
||||
source_tool="get_shutdown_time",
|
||||
)
|
||||
assert not m2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_adds_raw_data(self, graph):
|
||||
pid1, _ = await graph.add_phenomenon(
|
||||
source_agent="fs", category="registry",
|
||||
title="Network config",
|
||||
description="Network interface configuration with IP address and MAC address details",
|
||||
raw_data={"ip": "192.168.1.1"},
|
||||
source_tool="get_network_interfaces",
|
||||
)
|
||||
pid2, m2 = await graph.add_phenomenon(
|
||||
source_agent="reg", category="registry",
|
||||
title="Network config",
|
||||
description="Network interface configuration including DHCP and MAC address info",
|
||||
raw_data={"mac": "00:11:22:33:44:55"},
|
||||
source_tool="get_network_interfaces",
|
||||
)
|
||||
assert m2
|
||||
ph = graph.phenomena[pid1]
|
||||
assert "ip" in ph.raw_data
|
||||
assert "mac" in ph.raw_data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hypothesis confidence updates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHypothesisConfidence:
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
return EvidenceGraph()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supports_increases_confidence(self, graph):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
|
||||
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
|
||||
old = graph.hypotheses[hid].confidence
|
||||
new = await graph.update_hypothesis_confidence(hid, pid, "supports", "reason")
|
||||
assert new > old
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contradicts_decreases_confidence(self, graph):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
|
||||
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
|
||||
old = graph.hypotheses[hid].confidence
|
||||
new = await graph.update_hypothesis_confidence(hid, pid, "contradicts", "reason")
|
||||
assert new < old
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_evidence_has_largest_effect(self, graph):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
|
||||
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
|
||||
conf = await graph.update_hypothesis_confidence(hid, pid, "direct_evidence", "reason")
|
||||
# direct_evidence weight is +0.25 * (1-0.5) = +0.125
|
||||
assert conf == pytest.approx(0.625)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confidence_log_tracked(self, graph):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
|
||||
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
|
||||
await graph.update_hypothesis_confidence(hid, pid, "supports", "because reasons")
|
||||
log = graph.hypotheses[hid].confidence_log
|
||||
assert len(log) == 1
|
||||
assert log[0]["edge_type"] == "supports"
|
||||
assert log[0]["reason"] == "because reasons"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_status_at_threshold(self, graph):
|
||||
hid = await graph.add_hypothesis("test", "desc", created_by="test")
|
||||
# Push confidence up with multiple direct_evidence links
|
||||
for i in range(5):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", f"test {i}", f"desc {i}", source_tool="t")
|
||||
await graph.update_hypothesis_confidence(hid, pid, "direct_evidence", "proof")
|
||||
assert graph.hypotheses[hid].status == "supported"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context window management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTruncateToolResult:
|
||||
def test_short_text_unchanged(self):
|
||||
text = "short result"
|
||||
assert _truncate_tool_result(text) == text
|
||||
|
||||
def test_long_text_truncated(self):
|
||||
text = "x" * 5000
|
||||
result = _truncate_tool_result(text, max_chars=3000)
|
||||
assert len(result) < 3100
|
||||
assert "truncated" in result
|
||||
assert "5000" in result
|
||||
|
||||
def test_exact_boundary(self):
|
||||
text = "x" * 3000
|
||||
assert _truncate_tool_result(text, max_chars=3000) == text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parallel dispatch & async safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphAsync:
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_add_phenomenon(self):
|
||||
graph = EvidenceGraph()
|
||||
|
||||
async def add_batch(agent_name: str, count: int):
|
||||
for i in range(count):
|
||||
await graph.add_phenomenon(
|
||||
source_agent=agent_name,
|
||||
category=f"cat-{agent_name}",
|
||||
title=f"Phenomenon {agent_name}-{i}",
|
||||
description=f"Description from {agent_name} number {i} with enough detail to be meaningful",
|
||||
source_tool="test_tool",
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
add_batch("agent_a", 10),
|
||||
add_batch("agent_b", 10),
|
||||
add_batch("agent_c", 10),
|
||||
)
|
||||
assert len(graph.phenomena) == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_add_lead(self):
|
||||
graph = EvidenceGraph()
|
||||
|
||||
async def add_leads(count: int):
|
||||
for i in range(count):
|
||||
await graph.add_lead(target_agent="network", description=f"Lead {i}")
|
||||
|
||||
await asyncio.gather(add_leads(10), add_leads(10))
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_lead_completed(self):
|
||||
graph = EvidenceGraph()
|
||||
lid = await graph.add_lead(target_agent="fs", description="test lead")
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 1
|
||||
|
||||
await graph.mark_lead_completed(lid)
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 0
|
||||
|
||||
|
||||
class TestParallelDispatch:
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_agent_types_run_concurrently(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
execution_log: list[tuple[str, float, float]] = []
|
||||
|
||||
class FakeAgent:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
async def run(self, task, lead_id=None):
|
||||
start = time.monotonic()
|
||||
await asyncio.sleep(0.05)
|
||||
end = time.monotonic()
|
||||
execution_log.append((self.name, start, end))
|
||||
|
||||
factory._cache = {
|
||||
"filesystem": FakeAgent("filesystem"),
|
||||
"registry": FakeAgent("registry"),
|
||||
"network": FakeAgent("network"),
|
||||
}
|
||||
|
||||
leads = [
|
||||
Lead(id="l1", target_agent="filesystem", description="fs task"),
|
||||
Lead(id="l2", target_agent="registry", description="reg task"),
|
||||
Lead(id="l3", target_agent="network", description="net task"),
|
||||
]
|
||||
graph.leads = leads
|
||||
|
||||
start_time = time.monotonic()
|
||||
await orch._dispatch_leads_parallel(leads)
|
||||
total_time = time.monotonic() - start_time
|
||||
|
||||
assert total_time < 0.12, f"Expected parallel execution but took {total_time:.3f}s"
|
||||
assert len(execution_log) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_agent_type_runs_serially(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FakeAgent:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
async def run(self, task, lead_id=None):
|
||||
execution_order.append(task)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
factory._cache = {"filesystem": FakeAgent("filesystem")}
|
||||
|
||||
leads = [
|
||||
Lead(id="l1", target_agent="filesystem", description="task 1"),
|
||||
Lead(id="l2", target_agent="filesystem", description="task 2"),
|
||||
]
|
||||
graph.leads = leads
|
||||
|
||||
await orch._dispatch_leads_parallel(leads)
|
||||
|
||||
assert len(execution_order) == 2
|
||||
assert "task 1" in execution_order[0]
|
||||
assert "task 2" in execution_order[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Asset library
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAssetLibrary:
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
return EvidenceGraph()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_asset(self, graph):
|
||||
aid, existed = await graph.register_asset(
|
||||
inode="334-128-4",
|
||||
original_path="WINDOWS/system32/config/SYSTEM",
|
||||
local_path="extracted/system",
|
||||
category="registry_hive",
|
||||
filename="system",
|
||||
size_bytes=262144,
|
||||
extracted_by="filesystem",
|
||||
)
|
||||
assert not existed
|
||||
assert aid.startswith("asset-")
|
||||
assert len(graph.asset_library) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedup_by_inode(self, graph):
|
||||
aid1, _ = await graph.register_asset(
|
||||
inode="334-128-4",
|
||||
original_path="WINDOWS/system32/config/SYSTEM",
|
||||
local_path="extracted/system",
|
||||
category="registry_hive",
|
||||
filename="system",
|
||||
size_bytes=262144,
|
||||
extracted_by="filesystem",
|
||||
)
|
||||
aid2, existed = await graph.register_asset(
|
||||
inode="334-128-4",
|
||||
original_path="WINDOWS/system32/config/SYSTEM",
|
||||
local_path="extracted/system",
|
||||
category="registry_hive",
|
||||
filename="system",
|
||||
size_bytes=262144,
|
||||
extracted_by="registry",
|
||||
)
|
||||
assert existed
|
||||
assert aid1 == aid2
|
||||
assert len(graph.asset_library) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_by_category(self, graph):
|
||||
await graph.register_asset(
|
||||
inode="334-128-4", original_path="config/SYSTEM",
|
||||
local_path="extracted/system", category="registry_hive",
|
||||
filename="system", size_bytes=100, extracted_by="fs",
|
||||
)
|
||||
await graph.register_asset(
|
||||
inode="10090-128-1", original_path="mIRC/logs/hackers.log",
|
||||
local_path="extracted/hackers.log", category="chat_log",
|
||||
filename="hackers.log", size_bytes=200, extracted_by="comm",
|
||||
)
|
||||
results = graph.query_assets(category="registry_hive")
|
||||
assert len(results) == 1
|
||||
assert results[0].filename == "system"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_by_inode(self, graph):
|
||||
await graph.register_asset(
|
||||
inode="334-128-4", original_path="config/SYSTEM",
|
||||
local_path="extracted/system", category="registry_hive",
|
||||
filename="system", size_bytes=100, extracted_by="fs",
|
||||
)
|
||||
asset = graph.lookup_asset_by_inode("334-128-4")
|
||||
assert asset is not None
|
||||
assert asset.local_path == "extracted/system"
|
||||
|
||||
assert graph.lookup_asset_by_inode("999-128-4") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_round_trip(self, graph, tmp_path):
|
||||
await graph.register_asset(
|
||||
inode="334-128-4",
|
||||
original_path="WINDOWS/system32/config/SYSTEM",
|
||||
local_path="extracted/system",
|
||||
category="registry_hive",
|
||||
filename="system",
|
||||
size_bytes=262144,
|
||||
extracted_by="filesystem",
|
||||
)
|
||||
path = tmp_path / "test_state.json"
|
||||
graph.save_state(path)
|
||||
|
||||
loaded = EvidenceGraph.load_state(path)
|
||||
assert len(loaded.asset_library) == 1
|
||||
assert loaded.lookup_asset_by_inode("334-128-4") is not None
|
||||
assert loaded.lookup_asset_by_inode("334-128-4").category == "registry_hive"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets(self, graph):
|
||||
await graph.register_asset(
|
||||
inode="334-128-4", original_path="config/SYSTEM",
|
||||
local_path="extracted/system", category="registry_hive",
|
||||
filename="system", size_bytes=8192, extracted_by="fs",
|
||||
)
|
||||
summaries = graph.list_assets()
|
||||
assert len(summaries) == 1
|
||||
assert "system" in summaries[0]
|
||||
assert "registry_hive" in summaries[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_by_filename(self, graph):
|
||||
await graph.register_asset(
|
||||
inode="10080-128-3", original_path="mIRC/mirc.ini",
|
||||
local_path="extracted/mirc.ini", category="config_file",
|
||||
filename="mirc.ini", size_bytes=500, extracted_by="comm",
|
||||
)
|
||||
await graph.register_asset(
|
||||
inode="334-128-4", original_path="config/SYSTEM",
|
||||
local_path="extracted/system", category="registry_hive",
|
||||
filename="system", size_bytes=100, extracted_by="fs",
|
||||
)
|
||||
results = graph.query_assets(filename_pattern="mirc")
|
||||
assert len(results) == 1
|
||||
assert results[0].filename == "mirc.ini"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool call partitioning (P0: parallel tool execution)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPartitionToolCalls:
|
||||
def test_consecutive_read_only_grouped(self):
|
||||
calls = [
|
||||
{"name": "list_directory", "arguments": {"inode": "33"}},
|
||||
{"name": "list_directory", "arguments": {"inode": "46"}},
|
||||
{"name": "list_directory", "arguments": {"inode": "50"}},
|
||||
]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 1
|
||||
assert batches[0].is_read_only
|
||||
assert len(batches[0].calls) == 3
|
||||
|
||||
def test_write_tools_isolated(self):
|
||||
calls = [
|
||||
{"name": "extract_file", "arguments": {"inode": "40"}},
|
||||
{"name": "extract_file", "arguments": {"inode": "41"}},
|
||||
]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 2
|
||||
assert not batches[0].is_read_only
|
||||
assert not batches[1].is_read_only
|
||||
|
||||
def test_mixed_read_write_partitioned(self):
|
||||
calls = [
|
||||
{"name": "list_directory", "arguments": {}},
|
||||
{"name": "search_graph", "arguments": {"keyword": "test"}},
|
||||
{"name": "add_phenomenon", "arguments": {"title": "x"}},
|
||||
{"name": "list_phenomena", "arguments": {}},
|
||||
]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 3
|
||||
assert batches[0].is_read_only and len(batches[0].calls) == 2
|
||||
assert not batches[1].is_read_only and len(batches[1].calls) == 1
|
||||
assert batches[2].is_read_only and len(batches[2].calls) == 1
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _partition_tool_calls([]) == []
|
||||
|
||||
def test_single_read_only(self):
|
||||
calls = [{"name": "partition_info", "arguments": {}}]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 1
|
||||
assert batches[0].is_read_only
|
||||
|
||||
def test_single_write(self):
|
||||
calls = [{"name": "add_lead", "arguments": {}}]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 1
|
||||
assert not batches[0].is_read_only
|
||||
|
||||
def test_unknown_tool_treated_as_write(self):
|
||||
calls = [{"name": "unknown_tool", "arguments": {}}]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 1
|
||||
assert not batches[0].is_read_only
|
||||
|
||||
def test_custom_read_only_set(self):
|
||||
calls = [
|
||||
{"name": "foo", "arguments": {}},
|
||||
{"name": "bar", "arguments": {}},
|
||||
]
|
||||
batches = _partition_tool_calls(calls, read_only={"foo", "bar"})
|
||||
assert len(batches) == 1
|
||||
assert batches[0].is_read_only
|
||||
|
||||
|
||||
class TestParallelToolExecution:
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_faster_than_serial(self):
|
||||
"""Read-only tools should execute concurrently."""
|
||||
from llm_client import LLMClient
|
||||
|
||||
async def slow_tool(**kwargs):
|
||||
await asyncio.sleep(0.05)
|
||||
return "ok"
|
||||
|
||||
tool_executor = {
|
||||
"list_directory": slow_tool,
|
||||
"search_graph": slow_tool,
|
||||
"list_phenomena": slow_tool,
|
||||
}
|
||||
|
||||
client = LLMClient.__new__(LLMClient) # skip __init__
|
||||
|
||||
calls = [
|
||||
{"name": "list_directory", "arguments": {}},
|
||||
{"name": "search_graph", "arguments": {"keyword": "x"}},
|
||||
{"name": "list_phenomena", "arguments": {}},
|
||||
]
|
||||
|
||||
start = time.monotonic()
|
||||
results = await client._execute_tool_batch_parallel(calls, tool_executor)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert len(results) == 3
|
||||
assert all("ok" in r for r in results)
|
||||
# 3 tasks × 50ms each should take ~50ms parallel, not ~150ms serial
|
||||
assert elapsed < 0.12, f"Expected parallel execution but took {elapsed:.3f}s"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_in_parallel_batch_does_not_crash(self):
|
||||
"""A failing tool in a parallel batch should produce an error result, not crash."""
|
||||
from llm_client import LLMClient
|
||||
|
||||
async def ok_tool(**kwargs):
|
||||
return "success"
|
||||
|
||||
async def bad_tool(**kwargs):
|
||||
raise ValueError("broken")
|
||||
|
||||
tool_executor = {
|
||||
"list_directory": ok_tool,
|
||||
"search_graph": bad_tool,
|
||||
}
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
calls = [
|
||||
{"name": "list_directory", "arguments": {}},
|
||||
{"name": "search_graph", "arguments": {"keyword": "x"}},
|
||||
]
|
||||
|
||||
results = await client._execute_tool_batch_parallel(calls, tool_executor)
|
||||
assert len(results) == 2
|
||||
assert "success" in results[0]
|
||||
assert "Error" in results[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batched orchestrator LLM calls (P0: batch lead gen + judging)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBatchLeadGeneration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_batched_leads_created_with_hypothesis_id(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
# Set up 2 hypotheses
|
||||
hid1 = await graph.add_hypothesis("Wardriving", "Wireless scanning", created_by="test")
|
||||
hid2 = await graph.add_hypothesis("Data theft", "Stolen files", created_by="test")
|
||||
|
||||
# Mock LLM to return batched JSON
|
||||
llm.chat.return_value = (
|
||||
f'[{{"hypothesis_id": "{hid1}", "agent": "network", "task": "analyze pcap", "priority": 3}},'
|
||||
f' {{"hypothesis_id": "{hid2}", "agent": "filesystem", "task": "check deleted files", "priority": 5}}]'
|
||||
)
|
||||
|
||||
await orch._generate_hypothesis_leads()
|
||||
|
||||
# Should have made exactly 1 LLM call
|
||||
assert llm.chat.call_count == 1
|
||||
# Should have created 2 leads
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 2
|
||||
agents = {l.target_agent for l in pending}
|
||||
assert "network" in agents
|
||||
assert "filesystem" in agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_parse_error(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
hid = await graph.add_hypothesis("Test", "Test desc", created_by="test")
|
||||
|
||||
# First call (batched) returns invalid JSON; fallback calls succeed
|
||||
llm.chat.side_effect = [
|
||||
"not valid json at all", # batched call fails
|
||||
'[{"agent": "registry", "task": "check SAM", "priority": 5}]', # fallback
|
||||
]
|
||||
|
||||
await orch._generate_hypothesis_leads()
|
||||
|
||||
# Should have called LLM twice: 1 batched (failed) + 1 fallback
|
||||
assert llm.chat.call_count == 2
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 1
|
||||
|
||||
|
||||
class TestBatchJudging:
|
||||
@pytest.mark.asyncio
|
||||
async def test_batched_judgments_applied(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
hid1 = await graph.add_hypothesis("Hyp A", "Desc A", created_by="test")
|
||||
hid2 = await graph.add_hypothesis("Hyp B", "Desc B", created_by="test")
|
||||
pid1, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 1", "details 1", source_tool="t")
|
||||
pid2, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 2", "details 2", source_tool="t")
|
||||
|
||||
old_conf1 = graph.hypotheses[hid1].confidence
|
||||
old_conf2 = graph.hypotheses[hid2].confidence
|
||||
|
||||
llm.chat.return_value = (
|
||||
f'[{{"hypothesis_id": "{hid1}", "phenomenon_id": "{pid1}", "edge_type": "supports", "reason": "evidence for A"}},'
|
||||
f' {{"hypothesis_id": "{hid2}", "phenomenon_id": "{pid2}", "edge_type": "contradicts", "reason": "disproves B"}}]'
|
||||
)
|
||||
|
||||
await orch._judge_new_phenomena()
|
||||
|
||||
assert llm.chat.call_count == 1
|
||||
assert graph.hypotheses[hid1].confidence > old_conf1
|
||||
assert graph.hypotheses[hid2].confidence < old_conf2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_ids_skipped(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
hid = await graph.add_hypothesis("Hyp", "Desc", created_by="test")
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "Finding", "details", source_tool="t")
|
||||
|
||||
old_conf = graph.hypotheses[hid].confidence
|
||||
|
||||
llm.chat.return_value = (
|
||||
f'[{{"hypothesis_id": "hyp-nonexistent", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "bad ref"}},'
|
||||
f' {{"hypothesis_id": "{hid}", "phenomenon_id": "ph-nonexistent", "edge_type": "supports", "reason": "bad ref"}},'
|
||||
f' {{"hypothesis_id": "{hid}", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "good"}}]'
|
||||
)
|
||||
|
||||
await orch._judge_new_phenomena()
|
||||
|
||||
# Only the valid judgment should have been applied
|
||||
assert graph.hypotheses[hid].confidence > old_conf
|
||||
# Only 1 edge created (the valid one)
|
||||
hyp_edges = [e for e in graph.edges if e.target_id == hid and e.source_id == pid]
|
||||
assert len(hyp_edges) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool result caching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolResultCache:
|
||||
def setup_method(self):
|
||||
_tool_result_cache.clear()
|
||||
|
||||
def test_cache_key_deterministic(self):
|
||||
k1 = _cache_key("list_directory", {"inode": "33"})
|
||||
k2 = _cache_key("list_directory", {"inode": "33"})
|
||||
assert k1 == k2
|
||||
|
||||
def test_cache_key_differs_by_args(self):
|
||||
k1 = _cache_key("list_directory", {"inode": "33"})
|
||||
k2 = _cache_key("list_directory", {"inode": "46"})
|
||||
assert k1 != k2
|
||||
|
||||
def test_cache_key_differs_by_tool(self):
|
||||
k1 = _cache_key("list_directory", {"inode": "33"})
|
||||
k2 = _cache_key("find_file", {"inode": "33"})
|
||||
assert k1 != k2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_executor_returns_cached_result(self):
|
||||
call_count = 0
|
||||
|
||||
async def real_executor(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result for {kwargs}"
|
||||
|
||||
cached = _make_cached("test_tool", real_executor)
|
||||
|
||||
r1 = await cached(inode="33")
|
||||
r2 = await cached(inode="33")
|
||||
assert r1 == r2
|
||||
assert call_count == 1 # only called once
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_args_not_cached(self):
|
||||
call_count = 0
|
||||
|
||||
async def real_executor(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result-{call_count}"
|
||||
|
||||
cached = _make_cached("test_tool", real_executor)
|
||||
|
||||
r1 = await cached(inode="33")
|
||||
r2 = await cached(inode="46")
|
||||
assert r1 != r2
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_results_not_cached(self):
|
||||
call_count = 0
|
||||
|
||||
async def flaky(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return "Error: something went wrong"
|
||||
return "success"
|
||||
|
||||
cached = _make_cached("test_tool", flaky)
|
||||
|
||||
r1 = await cached(x=1)
|
||||
assert "Error" in r1
|
||||
r2 = await cached(x=1)
|
||||
assert r2 == "success"
|
||||
assert call_count == 2 # called twice because error wasn't cached
|
||||
|
||||
def test_cacheable_tools_are_read_only(self):
|
||||
"""Every cacheable tool should also be in READ_ONLY_TOOLS."""
|
||||
for t in CACHEABLE_TOOLS:
|
||||
assert t in READ_ONLY_TOOLS, f"{t} is cacheable but not in READ_ONLY_TOOLS"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Progressive context decay (Stage A)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProgressiveDecay:
|
||||
def _make_messages(self, n_rounds: int) -> list[dict]:
|
||||
"""Build a synthetic message list with n_rounds of (assistant, user) pairs."""
|
||||
messages = [{"role": "user", "content": "Start task"}]
|
||||
for i in range(n_rounds):
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>{{'name': 'tool_{i}'}}</tool_call>",
|
||||
})
|
||||
# Tool result message with substantial content
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[tool_{i}] {'x' * 2500}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
),
|
||||
})
|
||||
return messages
|
||||
|
||||
def test_short_conversation_unchanged(self):
|
||||
msgs = self._make_messages(3)
|
||||
result = _apply_progressive_decay(msgs)
|
||||
assert len(result) == len(msgs)
|
||||
# Content should be identical for short conversations
|
||||
for orig, decayed in zip(msgs, result):
|
||||
assert orig["content"] == decayed["content"]
|
||||
|
||||
def test_old_messages_truncated(self):
|
||||
msgs = self._make_messages(20)
|
||||
result = _apply_progressive_decay(msgs)
|
||||
|
||||
# Recent tool results should be full length
|
||||
last_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][-1]
|
||||
assert len(last_tool_result["content"]) > 2000
|
||||
|
||||
# Oldest tool results should be truncated
|
||||
first_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][0]
|
||||
assert len(first_tool_result["content"]) < 500
|
||||
|
||||
def test_message_count_preserved(self):
|
||||
msgs = self._make_messages(20)
|
||||
result = _apply_progressive_decay(msgs)
|
||||
assert len(result) == len(msgs)
|
||||
|
||||
def test_non_tool_messages_unchanged(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
result = _apply_progressive_decay(msgs)
|
||||
assert result[0]["content"] == "Hello"
|
||||
assert result[1]["content"] == "Hi there"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM message folding (Stage B)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMessageFolding:
|
||||
@pytest.mark.asyncio
|
||||
async def test_fold_replaces_old_messages_with_summary(self):
|
||||
from llm_client import LLMClient
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
client.chat = AsyncMock(return_value="Summary: found important files")
|
||||
|
||||
# Build messages exceeding fold threshold
|
||||
messages = [{"role": "user", "content": "Start task"}]
|
||||
for i in range(30):
|
||||
messages.append({"role": "assistant", "content": f"thinking step {i}"})
|
||||
messages.append({"role": "user", "content": f"tool result {i}: {'data ' * 50}"})
|
||||
|
||||
result = await client._fold_old_messages(messages, "system prompt")
|
||||
|
||||
# Should be significantly shorter
|
||||
assert len(result) < len(messages)
|
||||
# First message should be the summary
|
||||
assert "Context summary" in result[0]["content"]
|
||||
# Recent messages preserved
|
||||
assert len(result) == _FOLD_KEEP_RECENT + 1 # +1 for summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fold_not_triggered_below_threshold(self):
|
||||
from llm_client import LLMClient
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
client.chat = AsyncMock()
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
|
||||
result = await client._fold_old_messages(messages, "system")
|
||||
# Should return original (n_to_fold = 2 - 10 = negative, so no folding)
|
||||
assert result == messages
|
||||
client.chat.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fold_graceful_on_llm_failure(self):
|
||||
from llm_client import LLMClient
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
client.chat = AsyncMock(side_effect=Exception("API error"))
|
||||
|
||||
messages = [{"role": "user", "content": f"msg {i}"} for i in range(40)]
|
||||
result = await client._fold_old_messages(messages, "system")
|
||||
|
||||
# On failure, should return original messages unchanged
|
||||
assert len(result) == 40
|
||||
Reference in New Issue
Block a user