Don't cache agent tools during a run (#803)

### Summary:
Towards #767. We were caching the list of tools for an agent, so if you
did `agent.tools.append(...)` from a tool call, the next call to the
model wouldn't include the new tool. This is a bug.

### Test Plan:
Unit tests. Note that MCP tools are now listed each time the agent runs
(users can still cache the `list_tools` result, however).
This commit is contained in:
Rohan Mehta 2025-06-02 14:49:16 -04:00 committed by GitHub
parent 775d3e237e
commit d4c7a23e1d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 143 additions and 48 deletions

View file

@ -15,6 +15,7 @@ from .util._pretty_print import pretty_print_run_error_details
@dataclass
class RunErrorDetails:
"""Data collected from an agent run when an exception occurs."""
input: str | list[TResponseInputItem]
new_items: list[RunItem]
raw_responses: list[ModelResponse]
@ -29,6 +30,7 @@ class RunErrorDetails:
class AgentsException(Exception):
"""Base class for all exceptions in the Agents SDK."""
run_data: RunErrorDetails | None
def __init__(self, *args: object) -> None:

View file

@ -110,12 +110,14 @@ class LitellmModel(Model):
input_tokens_details=InputTokensDetails(
cached_tokens=getattr(
response_usage.prompt_tokens_details, "cached_tokens", 0
) or 0
)
or 0
),
output_tokens_details=OutputTokensDetails(
reasoning_tokens=getattr(
response_usage.completion_tokens_details, "reasoning_tokens", 0
) or 0
)
or 0
),
)
if response.usage

View file

@ -88,7 +88,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback | None
GetSessionIdCallback | None,
]
]:
"""Create the streams for the server."""
@ -243,7 +243,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback | None
GetSessionIdCallback | None,
]
]:
"""Create the streams for the server."""
@ -314,7 +314,7 @@ class MCPServerSse(_MCPServerWithClientSession):
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback | None
GetSessionIdCallback | None,
]
]:
"""Create the streams for the server."""
@ -394,7 +394,7 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback | None
GetSessionIdCallback | None,
]
]:
"""Create the streams for the server."""
@ -403,7 +403,7 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
headers=self.params.get("headers", None),
timeout=self.params.get("timeout", timedelta(seconds=30)),
sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)),
terminate_on_close=self.params.get("terminate_on_close", True)
terminate_on_close=self.params.get("terminate_on_close", True),
)
@property

View file

@ -274,4 +274,3 @@ class RunResultStreaming(RunResultBase):
def __str__(self) -> str:
return pretty_print_run_result_streaming(self)

View file

@ -1,4 +1,3 @@
from __future__ import annotations
import asyncio
@ -182,6 +181,8 @@ class Runner:
try:
while True:
all_tools = await cls._get_all_tools(current_agent)
# Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends.
if current_span is None:
@ -197,8 +198,6 @@ class Runner:
output_type=output_type_name,
)
current_span.start(mark_as_current=True)
all_tools = await cls._get_all_tools(current_agent)
current_span.span_data.tools = [t.name for t in all_tools]
current_turn += 1
@ -210,9 +209,7 @@ class Runner:
data={"max_turns": max_turns},
),
)
raise MaxTurnsExceeded(
f"Max turns ({max_turns}) exceeded"
)
raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
logger.debug(
f"Running agent {current_agent.name} (turn {current_turn})",
@ -295,7 +292,7 @@ class Runner:
last_agent=current_agent,
context_wrapper=context_wrapper,
input_guardrail_results=input_guardrail_results,
output_guardrail_results=[]
output_guardrail_results=[],
)
raise
finally:
@ -528,6 +525,8 @@ class Runner:
if streamed_result.is_complete:
break
all_tools = await cls._get_all_tools(current_agent)
# Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends.
if current_span is None:
@ -543,8 +542,6 @@ class Runner:
output_type=output_type_name,
)
current_span.start(mark_as_current=True)
all_tools = await cls._get_all_tools(current_agent)
tool_names = [t.name for t in all_tools]
current_span.span_data.tools = tool_names
current_turn += 1

View file

@ -17,9 +17,11 @@ DEFAULT_TTS_BUFFER_SIZE = 120
TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
"""Exportable type for the TTSModelSettings voice enum"""
@dataclass
class TTSModelSettings:
"""Settings for a TTS model."""
voice: TTSVoice | None = None
"""
The voice to use for the TTS model. If not provided, the default voice for the respective model

View file

@ -44,6 +44,10 @@ async def test_mcp_tracing():
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "mcp_tools",
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
},
{
"type": "agent",
"data": {
@ -53,10 +57,6 @@ async def test_mcp_tracing():
"output_type": "str",
},
"children": [
{
"type": "mcp_tools",
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
},
{
"type": "function",
"data": {
@ -66,8 +66,12 @@ async def test_mcp_tracing():
"mcp_data": {"server": "fake_mcp_server"},
},
},
{
"type": "mcp_tools",
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
},
],
}
},
],
}
]
@ -100,6 +104,13 @@ async def test_mcp_tracing():
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2"],
},
},
{
"type": "agent",
"data": {
@ -109,13 +120,6 @@ async def test_mcp_tracing():
"output_type": "str",
},
"children": [
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2"],
},
},
{
"type": "function",
"data": {
@ -133,8 +137,15 @@ async def test_mcp_tracing():
"mcp_data": {"server": "fake_mcp_server"},
},
},
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2"],
},
},
],
}
},
],
}
]
@ -165,6 +176,13 @@ async def test_mcp_tracing():
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
},
},
{
"type": "agent",
"data": {
@ -174,13 +192,6 @@ async def test_mcp_tracing():
"output_type": "str",
},
"children": [
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
},
},
{
"type": "function",
"data": {
@ -190,8 +201,15 @@ async def test_mcp_tracing():
"mcp_data": {"server": "fake_mcp_server"},
},
},
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
},
},
],
}
},
],
}
]

View file

@ -26,8 +26,7 @@ async def test_extra_body_is_forwarded(monkeypatch):
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
settings = ModelSettings(
temperature=0.1,
extra_body={"cached_content": "some_cache", "foo": 123}
temperature=0.1, extra_body={"cached_content": "some_cache", "foo": 123}
)
model = LitellmModel(model="test-model")

View file

@ -745,3 +745,38 @@ async def test_previous_response_id_passed_between_runs_streamed_multi_turn():
pass
assert model.last_turn_args.get("previous_response_id") == "resp-stream-test"
@pytest.mark.asyncio
async def test_dynamic_tool_addition_run() -> None:
"""Test that tools can be added to an agent during a run.

A tool appended to `agent.tools` from inside a tool call must be callable on a later turn
of the same `Runner.run` invocation.
"""
model = FakeModel()
# Mutable flag that lets the nested tool record that it actually ran.
executed: dict[str, bool] = {"called": False}
agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
@function_tool(name_override="tool2")
def tool2() -> str:
executed["called"] = True
return "result2"
# tool2 is not registered up front; it only joins agent.tools when add_tool is invoked.
@function_tool(name_override="add_tool")
async def add_tool() -> str:
agent.tools.append(tool2)
return "added"
agent.tools.append(add_tool)
# Scripted model turns: call add_tool, then call the newly added tool2, then finish with text.
model.add_multiple_turn_outputs(
[
[get_function_tool_call("add_tool", json.dumps({}))],
[get_function_tool_call("tool2", json.dumps({}))],
[get_text_message("done")],
]
)
result = await Runner.run(agent, input="start")
# tool2 must have executed even though it was added mid-run.
assert executed["called"] is True
assert result.final_output == "done"

View file

@ -18,6 +18,7 @@ from agents import (
RunContextWrapper,
Runner,
UserError,
function_tool,
handoff,
)
from agents.items import RunItem
@ -684,3 +685,39 @@ async def test_streaming_events():
assert len(agent_data) == 2, "should have 2 agent updated events"
assert agent_data[0].new_agent == agent_2, "should have started with agent_2"
assert agent_data[1].new_agent == agent_1, "should have handed off to agent_1"
@pytest.mark.asyncio
async def test_dynamic_tool_addition_run_streamed() -> None:
"""Test that tools can be added to an agent during a streamed run.

A tool appended to `agent.tools` from inside a tool call must be callable on a later turn
of the same `Runner.run_streamed` invocation.
"""
model = FakeModel()
# Mutable flag that lets the nested tool record that it actually ran.
executed: dict[str, bool] = {"called": False}
agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
@function_tool(name_override="tool2")
def tool2() -> str:
executed["called"] = True
return "result2"
# tool2 is not registered up front; it only joins agent.tools when add_tool is invoked.
@function_tool(name_override="add_tool")
async def add_tool() -> str:
agent.tools.append(tool2)
return "added"
agent.tools.append(add_tool)
# Scripted model turns: call add_tool, then call the newly added tool2, then finish with text.
model.add_multiple_turn_outputs(
[
[get_function_tool_call("add_tool", json.dumps({}))],
[get_function_tool_call("tool2", json.dumps({}))],
[get_text_message("done")],
]
)
result = Runner.run_streamed(agent, input="start")
# Consume every event from the stream before checking the outcome.
async for _ in result.stream_events():
pass
# tool2 must have executed even though it was added mid-run.
assert executed["called"] is True
assert result.final_output == "done"

View file

@ -12,10 +12,12 @@ from .test_responses import get_function_tool, get_function_tool_call, get_text_
async def test_run_error_includes_data():
model = FakeModel()
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
model.add_multiple_turn_outputs([
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
[get_text_message("done")],
])
model.add_multiple_turn_outputs(
[
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
[get_text_message("done")],
]
)
with pytest.raises(MaxTurnsExceeded) as exc:
await Runner.run(agent, input="hello", max_turns=1)
data = exc.value.run_data
@ -29,10 +31,12 @@ async def test_run_error_includes_data():
async def test_streamed_run_error_includes_data():
model = FakeModel()
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
model.add_multiple_turn_outputs([
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
[get_text_message("done")],
])
model.add_multiple_turn_outputs(
[
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
[get_text_message("done")],
]
)
result = Runner.run_streamed(agent, input="hello", max_turns=1)
with pytest.raises(MaxTurnsExceeded) as exc:
async for _ in result.stream_events():