refactor: native tool calling + generic forced-retry + terminal exit
- llm_client: switch tool_call_loop from text-based <tool_call> regex to OpenAI-native tools=[...] / structured tool_calls field; accumulate delta.reasoning_content for DeepSeek thinking-mode echo-back; fold preserves system msg and aligns boundary to never orphan role:tool
- base_agent: generic forced-retry via mandatory_record_tools class attr (filesystem -> add_phenomenon, timeline -> add_temporal_edge, hypothesis -> add_hypothesis, report -> save_report); count via executor wrapper
- terminal_tools class attr + loop short-circuit: when a terminal tool is called, loop exits with its raw return as final_text. ReportAgent declares save_report as terminal — replaces the <answer>-tag stop signal that native tool calling broke
- _execute_*: return (raw, formatted) — terminal exit uses untruncated raw, conversation history uses 3000-char-capped formatted
- evidence_graph + orchestrator: LLM-derived InvestigationArea support (hypothesis-driven coverage check, replaces hardcoded _AREA_KEYWORDS / _AREA_TOOLS); manual yaml block kept as optional seed
- strip <answer> references from agent prompts (no longer load-bearing)

Verified on CFReDS image across 4 smoke runs: 0 JSON parse failures (was 3); 22 temporal edges from Phase 4 (was 0); ReportAgent exits via save_report (was max_iterations regression). 78/78 unit tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -14,7 +14,6 @@ from evidence_graph import (
|
||||
from llm_client import (
|
||||
_truncate_tool_result, _partition_tool_calls, _ToolBatch, READ_ONLY_TOOLS,
|
||||
_apply_progressive_decay, _FOLD_THRESHOLD, _FOLD_KEEP_RECENT,
|
||||
TOOL_RESULT_TAG, TOOL_RESULT_END,
|
||||
)
|
||||
from tool_registry import (
|
||||
_tool_result_cache, _cache_key, _make_cached, CACHEABLE_TOOLS,
|
||||
@@ -598,7 +597,8 @@ class TestParallelToolExecution:
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert len(results) == 3
|
||||
assert all("ok" in r for r in results)
|
||||
# results are (raw, formatted) tuples; both contain "ok"
|
||||
assert all("ok" in raw and "ok" in formatted for raw, formatted in results)
|
||||
# 3 tasks × 50ms each should take ~50ms parallel, not ~150ms serial
|
||||
assert elapsed < 0.12, f"Expected parallel execution but took {elapsed:.3f}s"
|
||||
|
||||
@@ -627,8 +627,10 @@ class TestParallelToolExecution:
|
||||
|
||||
results = await client._execute_tool_batch_parallel(calls, tool_executor)
|
||||
assert len(results) == 2
|
||||
assert "success" in results[0]
|
||||
assert "Error" in results[1]
|
||||
raw0, formatted0 = results[0]
|
||||
raw1, formatted1 = results[1]
|
||||
assert "success" in raw0 and "success" in formatted0
|
||||
assert "Error" in raw1 and "Error" in formatted1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -843,21 +845,27 @@ class TestToolResultCache:
|
||||
|
||||
class TestProgressiveDecay:
|
||||
def _make_messages(self, n_rounds: int) -> list[dict]:
|
||||
"""Build a synthetic message list with n_rounds of (assistant, user) pairs."""
|
||||
"""Build a synthetic message list shaped like native tool calling:
|
||||
user → (assistant w/ tool_calls → tool result)+
|
||||
"""
|
||||
messages = [{"role": "user", "content": "Start task"}]
|
||||
for i in range(n_rounds):
|
||||
tc_id = f"call_{i}"
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>{{'name': 'tool_{i}'}}</tool_call>",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc_id,
|
||||
"type": "function",
|
||||
"function": {"name": f"tool_{i}", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
})
|
||||
# Tool result message with substantial content
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[tool_{i}] {'x' * 2500}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
),
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_id,
|
||||
"content": f"[tool_{i}] {'x' * 2500}",
|
||||
})
|
||||
return messages
|
||||
|
||||
@@ -865,21 +873,16 @@ class TestProgressiveDecay:
|
||||
msgs = self._make_messages(3)
|
||||
result = _apply_progressive_decay(msgs)
|
||||
assert len(result) == len(msgs)
|
||||
# Content should be identical for short conversations
|
||||
for orig, decayed in zip(msgs, result):
|
||||
assert orig["content"] == decayed["content"]
|
||||
assert orig.get("content") == decayed.get("content")
|
||||
|
||||
def test_old_messages_truncated(self):
    """Decay must shorten the oldest tool results but leave recent ones whole.

    Uses a 20-round native tool-calling conversation and inspects only the
    ``role: "tool"`` messages, which are the ones carrying the bulky payload.
    """
    msgs = self._make_messages(20)
    result = _apply_progressive_decay(msgs)

    tool_msgs = [m for m in result if m.get("role") == "tool"]
    # Recent tool results should be full length
    assert len(tool_msgs[-1]["content"]) > 2000
    # Oldest tool results should be truncated
    assert len(tool_msgs[0]["content"]) < 500
|
||||
|
||||
def test_message_count_preserved(self):
|
||||
msgs = self._make_messages(20)
|
||||
@@ -915,7 +918,7 @@ class TestMessageFolding:
|
||||
messages.append({"role": "assistant", "content": f"thinking step {i}"})
|
||||
messages.append({"role": "user", "content": f"tool result {i}: {'data ' * 50}"})
|
||||
|
||||
result = await client._fold_old_messages(messages, "system prompt")
|
||||
result = await client._fold_old_messages(messages)
|
||||
|
||||
# Should be significantly shorter
|
||||
assert len(result) < len(messages)
|
||||
@@ -937,7 +940,7 @@ class TestMessageFolding:
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
|
||||
result = await client._fold_old_messages(messages, "system")
|
||||
result = await client._fold_old_messages(messages)
|
||||
# Should return original (n_to_fold = 2 - 10 = negative, so no folding)
|
||||
assert result == messages
|
||||
client.chat.assert_not_called()
|
||||
@@ -951,7 +954,555 @@ class TestMessageFolding:
|
||||
client.chat = AsyncMock(side_effect=Exception("API error"))
|
||||
|
||||
messages = [{"role": "user", "content": f"msg {i}"} for i in range(40)]
|
||||
result = await client._fold_old_messages(messages, "system")
|
||||
result = await client._fold_old_messages(messages)
|
||||
|
||||
# On failure, should return original messages unchanged
|
||||
assert len(result) == 40
|
||||
|
||||
@pytest.mark.asyncio
async def test_fold_boundary_never_orphans_tool_message(self):
    """Folding must never strand a ``role: "tool"`` message.

    If the natural fold boundary would leave a tool message at the head of
    the retained (recent) messages, the fold has to move the boundary
    forward until the head is a non-tool message. The API rejects orphan
    tool messages with HTTP 400.
    """
    from llm_client import LLMClient
    from unittest.mock import AsyncMock

    client = LLMClient.__new__(LLMClient)
    client.chat = AsyncMock(return_value="summary")

    # Long conversation of (assistant{tool_calls}, tool) pairs, so that an
    # assistant message can land exactly on the fold boundary and its
    # paired tool message would otherwise open the retained tail.
    msgs: list[dict] = [{"role": "user", "content": "task"}]
    for round_no in range(30):
        call_id = f"call_{round_no}"
        assistant_turn = {
            "role": "assistant",
            "content": None,
            "tool_calls": [{
                "id": call_id,
                "type": "function",
                "function": {"name": f"t_{round_no}", "arguments": "{}"},
            }],
        }
        tool_turn = {"role": "tool", "tool_call_id": call_id, "content": "ok"}
        msgs.extend((assistant_turn, tool_turn))

    result = await client._fold_old_messages(msgs)

    # Every tool message must directly follow an assistant message that
    # carries tool_calls.
    for idx, msg in enumerate(result):
        if msg.get("role") != "tool":
            continue
        assert idx > 0, "tool message cannot be first"
        before = result[idx - 1]
        assert before.get("role") == "assistant" and before.get("tool_calls"), (
            f"tool at index {idx} preceded by {before.get('role')} "
            f"(tool_calls={bool(before.get('tool_calls'))})"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Investigation areas: dataclass + derivation + coverage + dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInvestigationAreaDerivation:
|
||||
@pytest.fixture
def graph(self):
    """Provide a newly constructed EvidenceGraph to each test that asks for it."""
    fresh_graph = EvidenceGraph()
    return fresh_graph
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_investigation_area_dedupes_and_merges_lists(self, graph):
    """Adding the same area twice yields one record with merged list fields.

    The second call must report the area as already existing, return the
    same id, keep the first description/agent, and union the keyword, tool
    and motivating-hypothesis lists.
    """
    initial = dict(
        area="password_hashes",
        description="SAM dump",
        suggested_agent="filesystem",
        expected_keywords=["sam", "pwdump"],
        expected_tools=["search_strings"],
        motivating_hypothesis_ids=["hyp-a"],
        created_by="llm_derive",
    )
    repeat = dict(
        area="password_hashes",
        description="Other description",
        suggested_agent="registry",
        expected_keywords=["sam", "hashdump"],  # "hashdump" is new
        expected_tools=["search_strings", "parse_registry_key"],  # one new tool
        motivating_hypothesis_ids=["hyp-b"],  # new hypothesis
        created_by="manual",
    )

    aid1, existed1 = await graph.add_investigation_area(**initial)
    aid2, existed2 = await graph.add_investigation_area(**repeat)

    assert not existed1
    assert existed2
    assert aid1 == aid2

    merged = graph.investigation_areas[aid1]
    # First write wins for the scalar fields.
    assert merged.description == "SAM dump"
    assert merged.suggested_agent == "filesystem"
    # The three list fields are unioned across both calls.
    assert set(merged.expected_keywords) == {"sam", "pwdump", "hashdump"}
    assert set(merged.expected_tools) == {"search_strings", "parse_registry_key"}
    assert set(merged.motivating_hypothesis_ids) == {"hyp-a", "hyp-b"}
|
||||
|
||||
@pytest.mark.asyncio
async def test_check_coverage_keyword_layer(self, graph):
    """An area counts as covered when a phenomenon's text matches a keyword.

    The area's only expected tool never ran, so coverage here can come only
    from the ``sam.lst`` keyword appearing in the recorded phenomenon.
    """
    from orchestrator import Orchestrator
    from agent_factory import AgentFactory
    from unittest.mock import AsyncMock

    await graph.add_phenomenon(
        "fs",
        "filesystem",
        "Cain SAM dump artifact",
        "Found sam.lst in the Cain folder",
        source_tool="list_directory",
    )
    await graph.add_investigation_area(
        area="password_hashes",
        description="SAM dump",
        suggested_agent="filesystem",
        expected_keywords=["sam.lst", "pwdump"],
        expected_tools=["nonexistent_tool"],
    )

    orchestrator_llm = AsyncMock()
    factory = AgentFactory(AsyncMock(), graph)
    orch = Orchestrator(orchestrator_llm, graph, factory)

    assert "password_hashes" in orch._check_coverage()
|
||||
|
||||
@pytest.mark.asyncio
async def test_check_coverage_tool_layer(self, graph):
    """An area counts as covered when one of its expected tools produced a phenomenon.

    The area's keyword does not occur in the phenomenon text, so coverage
    here can come only from the matching ``source_tool``.
    """
    from orchestrator import Orchestrator
    from agent_factory import AgentFactory
    from unittest.mock import AsyncMock

    await graph.add_phenomenon(
        "reg",
        "registry",
        "User accounts",
        "Found accounts",
        source_tool="enumerate_users",
    )
    await graph.add_investigation_area(
        area="user_accounts",
        description="Enum users",
        suggested_agent="registry",
        expected_keywords=["irrelevant"],
        expected_tools=["enumerate_users"],
    )

    orchestrator_llm = AsyncMock()
    factory = AgentFactory(AsyncMock(), graph)
    orch = Orchestrator(orchestrator_llm, graph, factory)

    assert "user_accounts" in orch._check_coverage()
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_state_round_trip_areas(self, graph, tmp_path):
    """Investigation areas survive a save_state / load_state round trip."""
    await graph.add_investigation_area(
        area="x",
        description="d",
        suggested_agent="filesystem",
        expected_keywords=["k1"],
        expected_tools=["t1"],
        priority=2,
        motivating_hypothesis_ids=["hyp-a"],
        created_by="manual",
    )

    state_file = tmp_path / "state.json"
    graph.save_state(state_file)
    restored = EvidenceGraph.load_state(state_file)

    assert len(restored.investigation_areas) == 1
    # Exactly one area exists (asserted above), so take it directly.
    area = next(iter(restored.investigation_areas.values()))
    assert area.area == "x"
    assert area.expected_keywords == ["k1"]
    assert area.priority == 2
    assert area.motivating_hypothesis_ids == ["hyp-a"]
    assert area.created_by == "manual"
|
||||
|
||||
@pytest.mark.asyncio
async def test_derive_no_op_when_areas_already_populated(self, graph):
    """Derivation is skipped when the graph already holds areas.

    Resume safety: with an area already present (manual seed or restored
    from disk), ``_derive_investigation_areas`` must neither call the LLM
    nor add areas.
    """
    from unittest.mock import AsyncMock
    from orchestrator import Orchestrator
    from agent_factory import AgentFactory

    await graph.add_hypothesis("test", "desc", created_by="t")
    await graph.add_investigation_area(
        area="pre_existing",
        description="d",
        suggested_agent="filesystem",
        created_by="manual",
    )

    llm = AsyncMock()
    factory = AgentFactory(llm, graph)
    orch = Orchestrator(llm, graph, factory)
    await orch._derive_investigation_areas()

    # The LLM was never consulted ...
    assert llm.chat.call_count == 0
    # ... and the pre-existing area is still the only one.
    assert len(graph.investigation_areas) == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_fallback_when_llm_returns_empty_list(self, graph):
    """An empty JSON list from the LLM triggers the per-hypothesis fallback.

    With a single hypothesis in the graph, exactly one area is expected and
    it must be tagged ``created_by == "fallback"``.
    """
    from unittest.mock import AsyncMock
    from orchestrator import Orchestrator
    from agent_factory import AgentFactory

    await graph.add_hypothesis("Some compromise", "desc", created_by="t")

    llm = AsyncMock()
    llm.chat.return_value = "[]"
    factory = AgentFactory(llm, graph)
    orch = Orchestrator(llm, graph, factory)
    await orch._derive_investigation_areas()

    # Fallback creates one area per hypothesis.
    assert len(graph.investigation_areas) == 1
    only_area = next(iter(graph.investigation_areas.values()))
    assert only_area.created_by == "fallback"
|
||||
|
||||
@pytest.mark.asyncio
async def test_unknown_tool_filtered_kept_keywords(self, graph):
    """Unknown tool names are dropped while the area and its keywords remain.

    The LLM response names a tool that is not in TOOL_CATALOG; the derived
    area must still be stored, with its keywords intact and an empty
    ``expected_tools`` list.
    """
    from unittest.mock import AsyncMock
    from orchestrator import Orchestrator
    from agent_factory import AgentFactory

    h = await graph.add_hypothesis("h", "desc", created_by="t")

    llm_reply = (
        '[{"area":"foo","description":"desc","suggested_agent":"filesystem",'
        '"expected_keywords":["kw1","kw2"],"expected_tools":["nonexistent_tool"],'
        f'"priority":2,"motivating_hypothesis_ids":["{h}"]}}]'
    )
    llm = AsyncMock()
    llm.chat.return_value = llm_reply

    factory = AgentFactory(llm, graph)
    orch = Orchestrator(llm, graph, factory)
    await orch._derive_investigation_areas()

    assert "area-foo" in graph.investigation_areas
    derived = graph.investigation_areas["area-foo"]
    assert derived.expected_keywords == ["kw1", "kw2"]
    assert derived.expected_tools == []  # unknown tool filtered out
|
||||
|
||||
@pytest.mark.asyncio
async def test_unknown_agent_resolved_via_AGENT_ALIASES(self, graph):
    """An aliased agent name from the LLM is stored under its resolved name.

    The LLM emits ``'chat'`` (an AGENT_ALIASES entry for
    ``'communication'``); the stored area must carry the resolved name.
    """
    from unittest.mock import AsyncMock
    from orchestrator import Orchestrator
    from agent_factory import AgentFactory

    h = await graph.add_hypothesis("h", "desc", created_by="t")

    llm_reply = (
        '[{"area":"chat_stuff","description":"d","suggested_agent":"chat",'
        '"expected_keywords":["irc"],"expected_tools":[],'
        f'"priority":3,"motivating_hypothesis_ids":["{h}"]}}]'
    )
    llm = AsyncMock()
    llm.chat.return_value = llm_reply

    factory = AgentFactory(llm, graph)
    orch = Orchestrator(llm, graph, factory)
    await orch._derive_investigation_areas()

    derived = graph.investigation_areas["area-chat_stuff"]
    assert derived.suggested_agent == "communication"
|
||||
|
||||
@staticmethod
def _agent_with_executor(graph, llm, tool_name: str, real_executor):
    """Create a BaseAgent exposing exactly one tool, registered the normal way.

    The tool goes through ``register_tool`` so that the mandatory-record
    wrapper is engaged around ``real_executor``.

    Args:
        graph: Evidence graph handed to the agent.
        llm: LLM client (typically a mock) handed to the agent.
        tool_name: Name under which the executor is registered.
        real_executor: Callable that performs the tool's work.

    Returns:
        The configured BaseAgent instance.
    """
    from base_agent import BaseAgent

    built = BaseAgent(llm, graph)
    built.name = "test_agent"
    # Neutralise _register_graph_tools so run() adds nothing beyond the
    # single tool this test registers.
    built._register_graph_tools = lambda: None
    built.register_tool(
        name=tool_name,
        description="",
        input_schema={},
        executor=real_executor,
    )
    return built
|
||||
|
||||
@pytest.mark.asyncio
async def test_forced_record_retry_fires_when_zero_phenomena(self):
    """BaseAgent.run retries once when no mandatory recording tool was called.

    The fake loop answers without recording on the first pass. Once a
    message containing "STOP." is present (the retry prompt), it calls
    ``add_phenomenon`` through the agent's wrapped executor.
    """
    from unittest.mock import AsyncMock

    graph = EvidenceGraph()
    llm = AsyncMock()

    async def real_add(**kw):
        await graph.add_phenomenon(
            source_agent="test", category="filesystem",
            title="Forced retry record",
            description="Recorded after STOP prompt",
            source_tool="forced_retry",
        )

    agent = self._agent_with_executor(graph, llm, "add_phenomenon", real_add)

    async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()):
        # `or ""` guards messages whose content is explicitly None (the
        # shape used by assistant tool-call messages); `in None` would
        # raise TypeError.
        already_retrying = any(
            isinstance(m, dict) and "STOP." in (m.get("content") or "")
            for m in messages
        )
        if not already_retrying:
            return "Final answer without recording.", list(messages) + [
                {"role": "assistant", "content": "Final answer without recording."}
            ]
        await tool_executor["add_phenomenon"]()  # goes through wrapper
        return "Recorded.", []

    llm.tool_call_loop = fake_tool_call_loop
    await agent.run("test task")

    assert len(graph.phenomena) == 1
    assert agent._record_call_counts["add_phenomenon"] == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_retry_when_mandatory_tool_was_called(self):
    """No forced retry happens once a mandatory recording tool has run.

    The fake loop invokes ``add_phenomenon`` on its only pass, so
    ``agent.run`` must call the loop exactly once.
    """
    from unittest.mock import AsyncMock

    graph = EvidenceGraph()
    llm = AsyncMock()
    call_count = {"n": 0}

    async def real_add(**kw):
        await graph.add_phenomenon(
            source_agent="test",
            category="filesystem",
            title="x",
            description="y",
            source_tool="t",
        )

    agent = self._agent_with_executor(graph, llm, "add_phenomenon", real_add)

    async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()):
        call_count["n"] += 1
        # Calling via the executor map goes through the counting wrapper.
        await tool_executor["add_phenomenon"]()
        return "done.", list(messages)

    llm.tool_call_loop = fake_tool_call_loop
    await agent.run("test task")

    # A single pass only: the retry did not fire.
    assert call_count["n"] == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_retry_when_mandatory_tools_empty(self):
    """An agent with ``mandatory_record_tools = ()`` never gets a forced retry.

    With no mandatory recording tools declared, the loop must run exactly
    once even though nothing was written to the graph; the final text is
    the agent's output.
    """
    from unittest.mock import AsyncMock
    from base_agent import BaseAgent

    graph = EvidenceGraph()
    llm = AsyncMock()
    call_count = {"n": 0}

    async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()):
        call_count["n"] += 1
        return "report body here", list(messages)

    llm.tool_call_loop = fake_tool_call_loop

    class ReportLike(BaseAgent):
        # An empty tuple opts this agent out of the forced-retry check.
        mandatory_record_tools = ()

    agent = ReportLike(llm, graph)
    agent.name = "report_like"
    agent._register_graph_tools = lambda: None
    await agent.run("test task")

    assert call_count["n"] == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_forced_retry_fires_for_timeline_agent(self):
    """Forced retry keys off the agent's own mandatory tool, not add_phenomenon.

    The agent declares ``mandatory_record_tools = ("add_temporal_edge",)``
    and has no ``add_phenomenon`` tool. Finishing the first pass without
    creating a temporal edge must trigger one retry, during which the fake
    loop records an edge through the wrapped executor.
    """
    from unittest.mock import AsyncMock
    from base_agent import BaseAgent

    graph = EvidenceGraph()
    llm = AsyncMock()
    call_count = {"n": 0}
    edge_added = {"n": 0}

    async def real_add_edge(**kw):
        edge_added["n"] += 1

    class TimelineLike(BaseAgent):
        mandatory_record_tools = ("add_temporal_edge",)

    agent = TimelineLike(llm, graph)
    agent.name = "timeline_like"
    agent._register_graph_tools = lambda: None
    agent.register_tool("add_temporal_edge", "", {}, real_add_edge)

    async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()):
        call_count["n"] += 1
        # `or ""` guards messages whose content is explicitly None (the
        # shape used by assistant tool-call messages); `in None` would
        # raise TypeError.
        already_retrying = any(
            isinstance(m, dict) and "STOP." in (m.get("content") or "")
            for m in messages
        )
        if not already_retrying:
            return "answer", list(messages)
        await tool_executor["add_temporal_edge"]()
        return "recorded.", []

    llm.tool_call_loop = fake_tool_call_loop
    await agent.run("build timeline")

    assert call_count["n"] == 2  # first + retry
    assert edge_added["n"] == 1
    assert agent._record_call_counts["add_temporal_edge"] == 1
|
||||
|
||||
# ---- terminal_tools: real LLMClient.tool_call_loop short-circuit -----
|
||||
|
||||
@pytest.mark.asyncio
async def test_terminal_tool_exits_loop_immediately(self):
    """Calling a terminal tool ends tool_call_loop with that tool's result.

    The model's first turn calls a read tool and then ``save_report``,
    which is declared terminal. The loop must return ``save_report``'s
    return value as ``final_text`` without requesting a second LLM round.
    """
    from llm_client import LLMClient

    # Build the client without __init__ and set only the attributes this
    # test assigns explicitly.
    client = LLMClient.__new__(LLMClient)
    client.max_tokens = 4096
    client.reasoning_effort = None
    client.thinking_enabled = False
    client.model = "test"
    client._client = None

    call_count = {"n": 0}

    async def fake_chat_with_tools(messages, openai_tools):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First turn: model calls a read tool then the terminal tool.
            return "thinking aloud", None, [
                {"id": "tc1", "name": "read_tool", "arguments": "{}"},
                {"id": "tc2", "name": "save_report",
                 "arguments": '{"content":"FINAL REPORT BODY","output_path":"r.md"}'},
            ]
        raise AssertionError("loop should have exited after terminal tool")

    client._chat_with_tools = fake_chat_with_tools

    async def read_tool():
        return "some data"

    async def save_report(content, output_path):
        return content  # terminal tool returns content as final_text

    tools = [
        {"name": "read_tool", "description": "",
         "input_schema": {"type": "object", "properties": {}}},
        {"name": "save_report", "description": "",
         "input_schema": {"type": "object", "properties": {
             "content": {"type": "string"}, "output_path": {"type": "string"}}}},
    ]
    executors = {"read_tool": read_tool, "save_report": save_report}

    final_text, _ = await client.tool_call_loop(
        messages=[{"role": "user", "content": "go"}],
        tools=tools, tool_executor=executors,
        system="sys", terminal_tools=("save_report",),
    )

    assert final_text == "FINAL REPORT BODY"
    assert call_count["n"] == 1  # never called a 2nd round
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_terminal_short_circuit_when_not_declared(self):
    """With no terminal tools declared, a tool call does not end the loop.

    The model calls ``add_phenomenon`` on the first turn. Because
    ``terminal_tools`` is empty, the loop must run the tool and then ask
    the model again; it finishes only when the model returns no tool
    calls, with that turn's text as ``final_text``.
    """
    from llm_client import LLMClient

    # Build the client without __init__ and set only the attributes this
    # test assigns explicitly.
    client = LLMClient.__new__(LLMClient)
    client.max_tokens = 4096
    client.reasoning_effort = None
    client.thinking_enabled = False
    client.model = "test"
    client._client = None

    call_count = {"n": 0}

    async def fake_chat_with_tools(messages, openai_tools):
        call_count["n"] += 1
        if call_count["n"] == 1:
            return "", None, [
                {"id": "tc1", "name": "add_phenomenon",
                 "arguments": '{"title":"x","description":"y"}'},
            ]
        return "all done", None, []  # model stops calling tools

    client._chat_with_tools = fake_chat_with_tools

    async def add_phenomenon(title, description):
        return f"recorded {title}"

    tools = [
        {"name": "add_phenomenon", "description": "",
         "input_schema": {"type": "object", "properties": {
             "title": {"type": "string"}, "description": {"type": "string"}}}},
    ]
    executors = {"add_phenomenon": add_phenomenon}

    final_text, _ = await client.tool_call_loop(
        messages=[{"role": "user", "content": "go"}],
        tools=tools, tool_executor=executors,
        system="sys", terminal_tools=(),  # NOT terminal
    )

    assert final_text == "all done"
    assert call_count["n"] == 2  # 2 rounds — terminal_tools empty, loop continues
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_agent_terminal_tool_declared(self):
    """ReportAgent lists save_report as both its terminal and mandatory tool."""
    from agents.report import ReportAgent

    expected = ("save_report",)
    assert ReportAgent.terminal_tools == expected
    assert ReportAgent.mandatory_record_tools == expected
|
||||
|
||||
@pytest.mark.asyncio
async def test_terminal_tool_result_not_truncated(self):
    """A terminal tool's raw return reaches the caller at full length.

    The 3000-char cap applies only to tool result messages placed in the
    LLM context. A markdown report of more than 10K chars returned by
    ``save_report`` must come back from ``tool_call_loop`` unchanged.
    """
    from llm_client import LLMClient

    client = LLMClient.__new__(LLMClient)
    client.max_tokens = 4096
    client.reasoning_effort = None
    client.thinking_enabled = False
    client.model = "test"
    client._client = None

    bullet = "- finding " + "x" * 100 + "\n"
    long_report = "# Report\n" + bullet * 200
    assert len(long_report) > 10000

    rounds = {"n": 0}

    async def fake_chat_with_tools(messages, openai_tools):
        rounds["n"] += 1
        if rounds["n"] != 1:
            raise AssertionError("loop should have exited")
        return "", None, [
            {
                "id": "tc1",
                "name": "save_report",
                "arguments": '{"content":"placeholder","output_path":"r.md"}',
            },
        ]

    client._chat_with_tools = fake_chat_with_tools

    async def save_report(content, output_path):
        # The content argument is deliberately ignored; the long report is
        # what the terminal tool hands back.
        return long_report

    save_report_schema = {
        "type": "object",
        "properties": {
            "content": {"type": "string"},
            "output_path": {"type": "string"},
        },
    }
    tools = [{"name": "save_report", "description": "",
              "input_schema": save_report_schema}]
    executors = {"save_report": save_report}

    final_text, _ = await client.tool_call_loop(
        messages=[{"role": "user", "content": "go"}],
        tools=tools,
        tool_executor=executors,
        system="sys",
        terminal_tools=("save_report",),
    )

    assert final_text == long_report
    assert len(final_text) > 10000  # not truncated to 3000
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_uses_hypothesis_id_when_motivating_ids_present(self, graph):
    """A lead for an uncovered area carries the area's first motivating hypothesis.

    No agents are actually dispatched. After confirming the area is
    uncovered, the test adds the lead itself, replicating by hand what
    ``_run_gap_analysis`` does, and checks the resulting lead's fields.
    """
    from unittest.mock import AsyncMock
    from orchestrator import Orchestrator
    from agent_factory import AgentFactory

    h = await graph.add_hypothesis("h", "desc", created_by="t")
    await graph.add_investigation_area(
        area="uncovered_area",
        description="d",
        suggested_agent="registry",
        expected_keywords=["xyz_no_match"],
        expected_tools=[],
        priority=2,
        motivating_hypothesis_ids=[h],
        created_by="llm_derive",
    )

    factory = AgentFactory(AsyncMock(), graph)
    orch = Orchestrator(AsyncMock(), graph, factory)

    covered = orch._check_coverage()
    assert "uncovered_area" not in covered

    # Simulate dispatch: one lead per area that coverage did not report.
    for area in graph.investigation_areas.values():
        if area.area in covered:
            continue
        if area.motivating_hypothesis_ids:
            lead_hypothesis = area.motivating_hypothesis_ids[0]
        else:
            lead_hypothesis = None
        await graph.add_lead(
            target_agent=area.suggested_agent,
            description=area.description,
            priority=area.priority,
            hypothesis_id=lead_hypothesis,
        )

    assert len(graph.leads) == 1
    lead = graph.leads[0]
    assert lead.hypothesis_id == h
    assert lead.target_agent == "registry"
    assert lead.priority == 2
|
||||
|
||||
Reference in New Issue
Block a user