Don't cache agent tools during a run (#803)

### Summary:
Towards #767. We were caching the list of tools for an agent, so if you
did `agent.tools.append(...)` from a tool call, the next call to the
model wouldn't include the new tool. This is a bug.

### Test Plan:
Unit tests. Note that now MCP tools are listed each time the agent runs
(users can still cache the `list_tools` however).
This commit is contained in:
Rohan Mehta 2025-06-02 14:49:16 -04:00 committed by GitHub
parent 775d3e237e
commit d4c7a23e1d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 143 additions and 48 deletions

View file

@ -15,6 +15,7 @@ from .util._pretty_print import pretty_print_run_error_details
@dataclass @dataclass
class RunErrorDetails: class RunErrorDetails:
"""Data collected from an agent run when an exception occurs.""" """Data collected from an agent run when an exception occurs."""
input: str | list[TResponseInputItem] input: str | list[TResponseInputItem]
new_items: list[RunItem] new_items: list[RunItem]
raw_responses: list[ModelResponse] raw_responses: list[ModelResponse]
@ -29,6 +30,7 @@ class RunErrorDetails:
class AgentsException(Exception): class AgentsException(Exception):
"""Base class for all exceptions in the Agents SDK.""" """Base class for all exceptions in the Agents SDK."""
run_data: RunErrorDetails | None run_data: RunErrorDetails | None
def __init__(self, *args: object) -> None: def __init__(self, *args: object) -> None:

View file

@ -110,12 +110,14 @@ class LitellmModel(Model):
input_tokens_details=InputTokensDetails( input_tokens_details=InputTokensDetails(
cached_tokens=getattr( cached_tokens=getattr(
response_usage.prompt_tokens_details, "cached_tokens", 0 response_usage.prompt_tokens_details, "cached_tokens", 0
) or 0 )
or 0
), ),
output_tokens_details=OutputTokensDetails( output_tokens_details=OutputTokensDetails(
reasoning_tokens=getattr( reasoning_tokens=getattr(
response_usage.completion_tokens_details, "reasoning_tokens", 0 response_usage.completion_tokens_details, "reasoning_tokens", 0
) or 0 )
or 0
), ),
) )
if response.usage if response.usage

View file

@ -88,7 +88,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
tuple[ tuple[
MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage], MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback | None GetSessionIdCallback | None,
] ]
]: ]:
"""Create the streams for the server.""" """Create the streams for the server."""
@ -243,7 +243,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
tuple[ tuple[
MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage], MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback | None GetSessionIdCallback | None,
] ]
]: ]:
"""Create the streams for the server.""" """Create the streams for the server."""
@ -314,7 +314,7 @@ class MCPServerSse(_MCPServerWithClientSession):
tuple[ tuple[
MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage], MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback | None GetSessionIdCallback | None,
] ]
]: ]:
"""Create the streams for the server.""" """Create the streams for the server."""
@ -394,7 +394,7 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
tuple[ tuple[
MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage], MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback | None GetSessionIdCallback | None,
] ]
]: ]:
"""Create the streams for the server.""" """Create the streams for the server."""
@ -403,7 +403,7 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
headers=self.params.get("headers", None), headers=self.params.get("headers", None),
timeout=self.params.get("timeout", timedelta(seconds=30)), timeout=self.params.get("timeout", timedelta(seconds=30)),
sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)), sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)),
terminate_on_close=self.params.get("terminate_on_close", True) terminate_on_close=self.params.get("terminate_on_close", True),
) )
@property @property

View file

@ -274,4 +274,3 @@ class RunResultStreaming(RunResultBase):
def __str__(self) -> str: def __str__(self) -> str:
return pretty_print_run_result_streaming(self) return pretty_print_run_result_streaming(self)

View file

@ -1,4 +1,3 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -182,6 +181,8 @@ class Runner:
try: try:
while True: while True:
all_tools = await cls._get_all_tools(current_agent)
# Start an agent span if we don't have one. This span is ended if the current # Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends. # agent changes, or if the agent loop ends.
if current_span is None: if current_span is None:
@ -197,8 +198,6 @@ class Runner:
output_type=output_type_name, output_type=output_type_name,
) )
current_span.start(mark_as_current=True) current_span.start(mark_as_current=True)
all_tools = await cls._get_all_tools(current_agent)
current_span.span_data.tools = [t.name for t in all_tools] current_span.span_data.tools = [t.name for t in all_tools]
current_turn += 1 current_turn += 1
@ -210,9 +209,7 @@ class Runner:
data={"max_turns": max_turns}, data={"max_turns": max_turns},
), ),
) )
raise MaxTurnsExceeded( raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
f"Max turns ({max_turns}) exceeded"
)
logger.debug( logger.debug(
f"Running agent {current_agent.name} (turn {current_turn})", f"Running agent {current_agent.name} (turn {current_turn})",
@ -295,7 +292,7 @@ class Runner:
last_agent=current_agent, last_agent=current_agent,
context_wrapper=context_wrapper, context_wrapper=context_wrapper,
input_guardrail_results=input_guardrail_results, input_guardrail_results=input_guardrail_results,
output_guardrail_results=[] output_guardrail_results=[],
) )
raise raise
finally: finally:
@ -528,6 +525,8 @@ class Runner:
if streamed_result.is_complete: if streamed_result.is_complete:
break break
all_tools = await cls._get_all_tools(current_agent)
# Start an agent span if we don't have one. This span is ended if the current # Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends. # agent changes, or if the agent loop ends.
if current_span is None: if current_span is None:
@ -543,8 +542,6 @@ class Runner:
output_type=output_type_name, output_type=output_type_name,
) )
current_span.start(mark_as_current=True) current_span.start(mark_as_current=True)
all_tools = await cls._get_all_tools(current_agent)
tool_names = [t.name for t in all_tools] tool_names = [t.name for t in all_tools]
current_span.span_data.tools = tool_names current_span.span_data.tools = tool_names
current_turn += 1 current_turn += 1

View file

@ -17,9 +17,11 @@ DEFAULT_TTS_BUFFER_SIZE = 120
TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
"""Exportable type for the TTSModelSettings voice enum""" """Exportable type for the TTSModelSettings voice enum"""
@dataclass @dataclass
class TTSModelSettings: class TTSModelSettings:
"""Settings for a TTS model.""" """Settings for a TTS model."""
voice: TTSVoice | None = None voice: TTSVoice | None = None
""" """
The voice to use for the TTS model. If not provided, the default voice for the respective model The voice to use for the TTS model. If not provided, the default voice for the respective model

View file

@ -44,6 +44,10 @@ async def test_mcp_tracing():
{ {
"workflow_name": "Agent workflow", "workflow_name": "Agent workflow",
"children": [ "children": [
{
"type": "mcp_tools",
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
},
{ {
"type": "agent", "type": "agent",
"data": { "data": {
@ -53,10 +57,6 @@ async def test_mcp_tracing():
"output_type": "str", "output_type": "str",
}, },
"children": [ "children": [
{
"type": "mcp_tools",
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
},
{ {
"type": "function", "type": "function",
"data": { "data": {
@ -66,8 +66,12 @@ async def test_mcp_tracing():
"mcp_data": {"server": "fake_mcp_server"}, "mcp_data": {"server": "fake_mcp_server"},
}, },
}, },
{
"type": "mcp_tools",
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
},
], ],
} },
], ],
} }
] ]
@ -100,6 +104,13 @@ async def test_mcp_tracing():
{ {
"workflow_name": "Agent workflow", "workflow_name": "Agent workflow",
"children": [ "children": [
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2"],
},
},
{ {
"type": "agent", "type": "agent",
"data": { "data": {
@ -109,13 +120,6 @@ async def test_mcp_tracing():
"output_type": "str", "output_type": "str",
}, },
"children": [ "children": [
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2"],
},
},
{ {
"type": "function", "type": "function",
"data": { "data": {
@ -133,8 +137,15 @@ async def test_mcp_tracing():
"mcp_data": {"server": "fake_mcp_server"}, "mcp_data": {"server": "fake_mcp_server"},
}, },
}, },
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2"],
},
},
], ],
} },
], ],
} }
] ]
@ -165,6 +176,13 @@ async def test_mcp_tracing():
{ {
"workflow_name": "Agent workflow", "workflow_name": "Agent workflow",
"children": [ "children": [
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
},
},
{ {
"type": "agent", "type": "agent",
"data": { "data": {
@ -174,13 +192,6 @@ async def test_mcp_tracing():
"output_type": "str", "output_type": "str",
}, },
"children": [ "children": [
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
},
},
{ {
"type": "function", "type": "function",
"data": { "data": {
@ -190,8 +201,15 @@ async def test_mcp_tracing():
"mcp_data": {"server": "fake_mcp_server"}, "mcp_data": {"server": "fake_mcp_server"},
}, },
}, },
{
"type": "mcp_tools",
"data": {
"server": "fake_mcp_server",
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
},
},
], ],
} },
], ],
} }
] ]

View file

@ -26,8 +26,7 @@ async def test_extra_body_is_forwarded(monkeypatch):
monkeypatch.setattr(litellm, "acompletion", fake_acompletion) monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
settings = ModelSettings( settings = ModelSettings(
temperature=0.1, temperature=0.1, extra_body={"cached_content": "some_cache", "foo": 123}
extra_body={"cached_content": "some_cache", "foo": 123}
) )
model = LitellmModel(model="test-model") model = LitellmModel(model="test-model")

View file

@ -745,3 +745,38 @@ async def test_previous_response_id_passed_between_runs_streamed_multi_turn():
pass pass
assert model.last_turn_args.get("previous_response_id") == "resp-stream-test" assert model.last_turn_args.get("previous_response_id") == "resp-stream-test"
@pytest.mark.asyncio
async def test_dynamic_tool_addition_run() -> None:
    """Check that a tool appended to ``agent.tools`` mid-run can be called afterwards.

    The fake model is scripted for three turns: it calls ``add_tool`` (which
    appends ``tool2`` to the agent's tool list), then calls ``tool2``, then
    emits the final text message ``"done"``.
    """
    fake_model = FakeModel()
    # Mutable flag flipped by ``tool2`` so the test can observe that it ran.
    tool2_state: dict[str, bool] = {"called": False}

    agent = Agent(name="test", model=fake_model, tool_use_behavior="run_llm_again")

    @function_tool(name_override="tool2")
    def tool2() -> str:
        tool2_state["called"] = True
        return "result2"

    @function_tool(name_override="add_tool")
    async def add_tool() -> str:
        # Registers ``tool2`` on the agent while the run is in progress.
        agent.tools.append(tool2)
        return "added"

    agent.tools.append(add_tool)

    no_args = json.dumps({})
    scripted_turns = [
        [get_function_tool_call("add_tool", no_args)],
        [get_function_tool_call("tool2", json.dumps({}))],
        [get_text_message("done")],
    ]
    fake_model.add_multiple_turn_outputs(scripted_turns)

    run_result = await Runner.run(agent, input="start")

    assert tool2_state["called"] is True
    assert run_result.final_output == "done"

View file

@ -18,6 +18,7 @@ from agents import (
RunContextWrapper, RunContextWrapper,
Runner, Runner,
UserError, UserError,
function_tool,
handoff, handoff,
) )
from agents.items import RunItem from agents.items import RunItem
@ -684,3 +685,39 @@ async def test_streaming_events():
assert len(agent_data) == 2, "should have 2 agent updated events" assert len(agent_data) == 2, "should have 2 agent updated events"
assert agent_data[0].new_agent == agent_2, "should have started with agent_2" assert agent_data[0].new_agent == agent_2, "should have started with agent_2"
assert agent_data[1].new_agent == agent_1, "should have handed off to agent_1" assert agent_data[1].new_agent == agent_1, "should have handed off to agent_1"
@pytest.mark.asyncio
async def test_dynamic_tool_addition_run_streamed() -> None:
    """Streamed check that a tool appended to ``agent.tools`` mid-run can be called.

    The fake model is scripted for three turns: it calls ``add_tool`` (which
    appends ``tool2`` to the agent's tool list), then calls ``tool2``, then
    emits the final text message ``"done"``. The event stream is drained
    before the assertions run.
    """
    fake_model = FakeModel()
    # Mutable flag flipped by ``tool2`` so the test can observe that it ran.
    tool2_state: dict[str, bool] = {"called": False}

    agent = Agent(name="test", model=fake_model, tool_use_behavior="run_llm_again")

    @function_tool(name_override="tool2")
    def tool2() -> str:
        tool2_state["called"] = True
        return "result2"

    @function_tool(name_override="add_tool")
    async def add_tool() -> str:
        # Registers ``tool2`` on the agent while the run is in progress.
        agent.tools.append(tool2)
        return "added"

    agent.tools.append(add_tool)

    no_args = json.dumps({})
    scripted_turns = [
        [get_function_tool_call("add_tool", no_args)],
        [get_function_tool_call("tool2", json.dumps({}))],
        [get_text_message("done")],
    ]
    fake_model.add_multiple_turn_outputs(scripted_turns)

    result = Runner.run_streamed(agent, input="start")
    # Consume every event; the individual events are not inspected here.
    async for _event in result.stream_events():
        pass

    assert tool2_state["called"] is True
    assert result.final_output == "done"

View file

@ -12,10 +12,12 @@ from .test_responses import get_function_tool, get_function_tool_call, get_text_
async def test_run_error_includes_data(): async def test_run_error_includes_data():
model = FakeModel() model = FakeModel()
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")]) agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
model.add_multiple_turn_outputs([ model.add_multiple_turn_outputs(
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))], [
[get_text_message("done")], [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
]) [get_text_message("done")],
]
)
with pytest.raises(MaxTurnsExceeded) as exc: with pytest.raises(MaxTurnsExceeded) as exc:
await Runner.run(agent, input="hello", max_turns=1) await Runner.run(agent, input="hello", max_turns=1)
data = exc.value.run_data data = exc.value.run_data
@ -29,10 +31,12 @@ async def test_run_error_includes_data():
async def test_streamed_run_error_includes_data(): async def test_streamed_run_error_includes_data():
model = FakeModel() model = FakeModel()
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")]) agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
model.add_multiple_turn_outputs([ model.add_multiple_turn_outputs(
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))], [
[get_text_message("done")], [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
]) [get_text_message("done")],
]
)
result = Runner.run_streamed(agent, input="hello", max_turns=1) result = Runner.run_streamed(agent, input="hello", max_turns=1)
with pytest.raises(MaxTurnsExceeded) as exc: with pytest.raises(MaxTurnsExceeded) as exc:
async for _ in result.stream_events(): async for _ in result.stream_events():