diff --git a/base_agent.py b/base_agent.py index 307887d..82c4a54 100644 --- a/base_agent.py +++ b/base_agent.py @@ -228,12 +228,27 @@ class BaseAgent: f"what you already found. Then end." ), }) + # Narrow the retry tool surface so the agent can't wander off + # to investigate again — only RECORD and read-only graph + # query tools survive. Each grounding-rejected call burns one + # iteration, so the cap is 30 (not the original 10): a + # Timeline agent writing ~10 temporal edges with one rejection + # apiece needs ~20 turns under the rewritten gateway. + retry_tool_names = set(registered_mandatory) | { + "list_phenomena", "list_assets", "search_graph", + "add_temporal_edge", "link_to_entity", "add_lead", + "add_hypothesis", "save_report", + } + retry_tools = [ + td for td in self.get_tool_definitions() + if td["name"] in retry_tool_names + ] final_text, _ = await self.llm.tool_call_loop( messages=conversation, - tools=self.get_tool_definitions(), + tools=retry_tools, tool_executor=self._executors, system=system, - max_iterations=10, + max_iterations=30, terminal_tools=self.terminal_tools, ) diff --git a/tests/test_optimizations.py b/tests/test_optimizations.py index 2504285..33b6d2b 100644 --- a/tests/test_optimizations.py +++ b/tests/test_optimizations.py @@ -1332,6 +1332,69 @@ class TestInvestigationAreaDerivation: assert edge_added["n"] == 1 assert agent._record_call_counts["add_temporal_edge"] == 1 + @pytest.mark.asyncio + async def test_forced_retry_uses_higher_cap_and_narrowed_tools(self): + """The forced RECORD retry must (a) get a generous iter cap so that + grounding-rejected retries don't blow the budget, and (b) hand the + LLM a tool surface restricted to RECORD + read-only graph tools so + it can't wander back into investigation. + """ + from unittest.mock import AsyncMock + from base_agent import BaseAgent + + graph = EvidenceGraph() + llm = AsyncMock() + + # Capture per-call kwargs of tool_call_loop so we can assert what + # the retry round received. + call_kwargs: list[dict] = [] + + async def real_add_edge(**kw): + return None + + class TimelineLike(BaseAgent): + mandatory_record_tools = ("add_temporal_edge",) + + agent = TimelineLike(llm, graph) + agent.name = "timeline_like" + agent._register_graph_tools = lambda: None + agent.register_tool("add_temporal_edge", "", {}, real_add_edge) + # An investigation-style tool the retry must NOT expose. + async def real_inv(**kw): return "" + agent.register_tool("list_directory", "", {}, real_inv) + # A read-only graph query — should remain available in retry. + async def real_ro(**kw): return "" + agent.register_tool("list_phenomena", "", {}, real_ro) + + async def fake_tool_call_loop(messages, tools, tool_executor, system, max_iterations=40, terminal_tools=()): + call_kwargs.append({ + "tools": [t["name"] for t in tools], + "max_iterations": max_iterations, + }) + already_retrying = any( + "STOP." in (m.get("content", "") if isinstance(m, dict) else "") + for m in messages + ) + if not already_retrying: + return "no record", list(messages) + await tool_executor["add_temporal_edge"]() + return "recorded.", [] + + llm.tool_call_loop = fake_tool_call_loop + await agent.run("build timeline") + + assert len(call_kwargs) == 2 + first_call, retry_call = call_kwargs + # First call: full tool surface, default iter cap. + assert "list_directory" in first_call["tools"] + # Retry call: investigation tool dropped, mandatory + read-only kept. + assert "list_directory" not in retry_call["tools"] + assert "add_temporal_edge" in retry_call["tools"] + assert "list_phenomena" in retry_call["tools"] + # Iter cap on the retry is now generous — 10 was empirically too tight + # because grounding-rejected calls burn iterations. + assert retry_call["max_iterations"] >= 30 + # ---- terminal_tools: real LLMClient.tool_call_loop short-circuit ----- @pytest.mark.asyncio