Don't cache agent tools during a run (#803)
### Summary: Towards #767. We were caching the list of tools for an agent, so if you did `agent.tools.append(...)` from a tool call, the next call to the model wouldn't include the new tool. THis is a bug. ### Test Plan: Unit tests. Note that now MCP tools are listed each time the agent runs (users can still cache the `list_tools` however).
This commit is contained in:
parent
775d3e237e
commit
d4c7a23e1d
11 changed files with 143 additions and 48 deletions
|
|
@ -15,6 +15,7 @@ from .util._pretty_print import pretty_print_run_error_details
|
|||
@dataclass
|
||||
class RunErrorDetails:
|
||||
"""Data collected from an agent run when an exception occurs."""
|
||||
|
||||
input: str | list[TResponseInputItem]
|
||||
new_items: list[RunItem]
|
||||
raw_responses: list[ModelResponse]
|
||||
|
|
@ -29,6 +30,7 @@ class RunErrorDetails:
|
|||
|
||||
class AgentsException(Exception):
|
||||
"""Base class for all exceptions in the Agents SDK."""
|
||||
|
||||
run_data: RunErrorDetails | None
|
||||
|
||||
def __init__(self, *args: object) -> None:
|
||||
|
|
|
|||
|
|
@ -110,12 +110,14 @@ class LitellmModel(Model):
|
|||
input_tokens_details=InputTokensDetails(
|
||||
cached_tokens=getattr(
|
||||
response_usage.prompt_tokens_details, "cached_tokens", 0
|
||||
) or 0
|
||||
)
|
||||
or 0
|
||||
),
|
||||
output_tokens_details=OutputTokensDetails(
|
||||
reasoning_tokens=getattr(
|
||||
response_usage.completion_tokens_details, "reasoning_tokens", 0
|
||||
) or 0
|
||||
)
|
||||
or 0
|
||||
),
|
||||
)
|
||||
if response.usage
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|||
tuple[
|
||||
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
MemoryObjectSendStream[SessionMessage],
|
||||
GetSessionIdCallback | None
|
||||
GetSessionIdCallback | None,
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
|
|
@ -243,7 +243,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
|||
tuple[
|
||||
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
MemoryObjectSendStream[SessionMessage],
|
||||
GetSessionIdCallback | None
|
||||
GetSessionIdCallback | None,
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
|
|
@ -314,7 +314,7 @@ class MCPServerSse(_MCPServerWithClientSession):
|
|||
tuple[
|
||||
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
MemoryObjectSendStream[SessionMessage],
|
||||
GetSessionIdCallback | None
|
||||
GetSessionIdCallback | None,
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
|
|
@ -394,7 +394,7 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
|
|||
tuple[
|
||||
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
MemoryObjectSendStream[SessionMessage],
|
||||
GetSessionIdCallback | None
|
||||
GetSessionIdCallback | None,
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
|
|
@ -403,7 +403,7 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
|
|||
headers=self.params.get("headers", None),
|
||||
timeout=self.params.get("timeout", timedelta(seconds=30)),
|
||||
sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)),
|
||||
terminate_on_close=self.params.get("terminate_on_close", True)
|
||||
terminate_on_close=self.params.get("terminate_on_close", True),
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -274,4 +274,3 @@ class RunResultStreaming(RunResultBase):
|
|||
|
||||
def __str__(self) -> str:
|
||||
return pretty_print_run_result_streaming(self)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -182,6 +181,8 @@ class Runner:
|
|||
|
||||
try:
|
||||
while True:
|
||||
all_tools = await cls._get_all_tools(current_agent)
|
||||
|
||||
# Start an agent span if we don't have one. This span is ended if the current
|
||||
# agent changes, or if the agent loop ends.
|
||||
if current_span is None:
|
||||
|
|
@ -197,8 +198,6 @@ class Runner:
|
|||
output_type=output_type_name,
|
||||
)
|
||||
current_span.start(mark_as_current=True)
|
||||
|
||||
all_tools = await cls._get_all_tools(current_agent)
|
||||
current_span.span_data.tools = [t.name for t in all_tools]
|
||||
|
||||
current_turn += 1
|
||||
|
|
@ -210,9 +209,7 @@ class Runner:
|
|||
data={"max_turns": max_turns},
|
||||
),
|
||||
)
|
||||
raise MaxTurnsExceeded(
|
||||
f"Max turns ({max_turns}) exceeded"
|
||||
)
|
||||
raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
|
||||
|
||||
logger.debug(
|
||||
f"Running agent {current_agent.name} (turn {current_turn})",
|
||||
|
|
@ -295,7 +292,7 @@ class Runner:
|
|||
last_agent=current_agent,
|
||||
context_wrapper=context_wrapper,
|
||||
input_guardrail_results=input_guardrail_results,
|
||||
output_guardrail_results=[]
|
||||
output_guardrail_results=[],
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
|
|
@ -528,6 +525,8 @@ class Runner:
|
|||
if streamed_result.is_complete:
|
||||
break
|
||||
|
||||
all_tools = await cls._get_all_tools(current_agent)
|
||||
|
||||
# Start an agent span if we don't have one. This span is ended if the current
|
||||
# agent changes, or if the agent loop ends.
|
||||
if current_span is None:
|
||||
|
|
@ -543,8 +542,6 @@ class Runner:
|
|||
output_type=output_type_name,
|
||||
)
|
||||
current_span.start(mark_as_current=True)
|
||||
|
||||
all_tools = await cls._get_all_tools(current_agent)
|
||||
tool_names = [t.name for t in all_tools]
|
||||
current_span.span_data.tools = tool_names
|
||||
current_turn += 1
|
||||
|
|
|
|||
|
|
@ -17,9 +17,11 @@ DEFAULT_TTS_BUFFER_SIZE = 120
|
|||
TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
|
||||
"""Exportable type for the TTSModelSettings voice enum"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSModelSettings:
|
||||
"""Settings for a TTS model."""
|
||||
|
||||
voice: TTSVoice | None = None
|
||||
"""
|
||||
The voice to use for the TTS model. If not provided, the default voice for the respective model
|
||||
|
|
|
|||
|
|
@ -44,6 +44,10 @@ async def test_mcp_tracing():
|
|||
{
|
||||
"workflow_name": "Agent workflow",
|
||||
"children": [
|
||||
{
|
||||
"type": "mcp_tools",
|
||||
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
|
||||
},
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
|
|
@ -53,10 +57,6 @@ async def test_mcp_tracing():
|
|||
"output_type": "str",
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "mcp_tools",
|
||||
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"data": {
|
||||
|
|
@ -66,8 +66,12 @@ async def test_mcp_tracing():
|
|||
"mcp_data": {"server": "fake_mcp_server"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "mcp_tools",
|
||||
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
|
@ -100,6 +104,13 @@ async def test_mcp_tracing():
|
|||
{
|
||||
"workflow_name": "Agent workflow",
|
||||
"children": [
|
||||
{
|
||||
"type": "mcp_tools",
|
||||
"data": {
|
||||
"server": "fake_mcp_server",
|
||||
"result": ["test_tool_1", "test_tool_2"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
|
|
@ -109,13 +120,6 @@ async def test_mcp_tracing():
|
|||
"output_type": "str",
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "mcp_tools",
|
||||
"data": {
|
||||
"server": "fake_mcp_server",
|
||||
"result": ["test_tool_1", "test_tool_2"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"data": {
|
||||
|
|
@ -133,8 +137,15 @@ async def test_mcp_tracing():
|
|||
"mcp_data": {"server": "fake_mcp_server"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "mcp_tools",
|
||||
"data": {
|
||||
"server": "fake_mcp_server",
|
||||
"result": ["test_tool_1", "test_tool_2"],
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
|
@ -165,6 +176,13 @@ async def test_mcp_tracing():
|
|||
{
|
||||
"workflow_name": "Agent workflow",
|
||||
"children": [
|
||||
{
|
||||
"type": "mcp_tools",
|
||||
"data": {
|
||||
"server": "fake_mcp_server",
|
||||
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
|
|
@ -174,13 +192,6 @@ async def test_mcp_tracing():
|
|||
"output_type": "str",
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "mcp_tools",
|
||||
"data": {
|
||||
"server": "fake_mcp_server",
|
||||
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"data": {
|
||||
|
|
@ -190,8 +201,15 @@ async def test_mcp_tracing():
|
|||
"mcp_data": {"server": "fake_mcp_server"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "mcp_tools",
|
||||
"data": {
|
||||
"server": "fake_mcp_server",
|
||||
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
|
|
|||
|
|
@ -26,8 +26,7 @@ async def test_extra_body_is_forwarded(monkeypatch):
|
|||
|
||||
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
|
||||
settings = ModelSettings(
|
||||
temperature=0.1,
|
||||
extra_body={"cached_content": "some_cache", "foo": 123}
|
||||
temperature=0.1, extra_body={"cached_content": "some_cache", "foo": 123}
|
||||
)
|
||||
model = LitellmModel(model="test-model")
|
||||
|
||||
|
|
|
|||
|
|
@ -745,3 +745,38 @@ async def test_previous_response_id_passed_between_runs_streamed_multi_turn():
|
|||
pass
|
||||
|
||||
assert model.last_turn_args.get("previous_response_id") == "resp-stream-test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_tool_addition_run() -> None:
|
||||
"""Test that tools can be added to an agent during a run."""
|
||||
model = FakeModel()
|
||||
|
||||
executed: dict[str, bool] = {"called": False}
|
||||
|
||||
agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
|
||||
|
||||
@function_tool(name_override="tool2")
|
||||
def tool2() -> str:
|
||||
executed["called"] = True
|
||||
return "result2"
|
||||
|
||||
@function_tool(name_override="add_tool")
|
||||
async def add_tool() -> str:
|
||||
agent.tools.append(tool2)
|
||||
return "added"
|
||||
|
||||
agent.tools.append(add_tool)
|
||||
|
||||
model.add_multiple_turn_outputs(
|
||||
[
|
||||
[get_function_tool_call("add_tool", json.dumps({}))],
|
||||
[get_function_tool_call("tool2", json.dumps({}))],
|
||||
[get_text_message("done")],
|
||||
]
|
||||
)
|
||||
|
||||
result = await Runner.run(agent, input="start")
|
||||
|
||||
assert executed["called"] is True
|
||||
assert result.final_output == "done"
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from agents import (
|
|||
RunContextWrapper,
|
||||
Runner,
|
||||
UserError,
|
||||
function_tool,
|
||||
handoff,
|
||||
)
|
||||
from agents.items import RunItem
|
||||
|
|
@ -684,3 +685,39 @@ async def test_streaming_events():
|
|||
assert len(agent_data) == 2, "should have 2 agent updated events"
|
||||
assert agent_data[0].new_agent == agent_2, "should have started with agent_2"
|
||||
assert agent_data[1].new_agent == agent_1, "should have handed off to agent_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_tool_addition_run_streamed() -> None:
|
||||
model = FakeModel()
|
||||
|
||||
executed: dict[str, bool] = {"called": False}
|
||||
|
||||
agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
|
||||
|
||||
@function_tool(name_override="tool2")
|
||||
def tool2() -> str:
|
||||
executed["called"] = True
|
||||
return "result2"
|
||||
|
||||
@function_tool(name_override="add_tool")
|
||||
async def add_tool() -> str:
|
||||
agent.tools.append(tool2)
|
||||
return "added"
|
||||
|
||||
agent.tools.append(add_tool)
|
||||
|
||||
model.add_multiple_turn_outputs(
|
||||
[
|
||||
[get_function_tool_call("add_tool", json.dumps({}))],
|
||||
[get_function_tool_call("tool2", json.dumps({}))],
|
||||
[get_text_message("done")],
|
||||
]
|
||||
)
|
||||
|
||||
result = Runner.run_streamed(agent, input="start")
|
||||
async for _ in result.stream_events():
|
||||
pass
|
||||
|
||||
assert executed["called"] is True
|
||||
assert result.final_output == "done"
|
||||
|
|
|
|||
|
|
@ -12,10 +12,12 @@ from .test_responses import get_function_tool, get_function_tool_call, get_text_
|
|||
async def test_run_error_includes_data():
|
||||
model = FakeModel()
|
||||
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
|
||||
model.add_multiple_turn_outputs([
|
||||
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
|
||||
[get_text_message("done")],
|
||||
])
|
||||
model.add_multiple_turn_outputs(
|
||||
[
|
||||
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
|
||||
[get_text_message("done")],
|
||||
]
|
||||
)
|
||||
with pytest.raises(MaxTurnsExceeded) as exc:
|
||||
await Runner.run(agent, input="hello", max_turns=1)
|
||||
data = exc.value.run_data
|
||||
|
|
@ -29,10 +31,12 @@ async def test_run_error_includes_data():
|
|||
async def test_streamed_run_error_includes_data():
|
||||
model = FakeModel()
|
||||
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
|
||||
model.add_multiple_turn_outputs([
|
||||
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
|
||||
[get_text_message("done")],
|
||||
])
|
||||
model.add_multiple_turn_outputs(
|
||||
[
|
||||
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
|
||||
[get_text_message("done")],
|
||||
]
|
||||
)
|
||||
result = Runner.run_streamed(agent, input="hello", max_turns=1)
|
||||
with pytest.raises(MaxTurnsExceeded) as exc:
|
||||
async for _ in result.stream_events():
|
||||
|
|
|
|||
Loading…
Reference in a new issue