[2/n] Add MCP support to Runner

### Summary:
This enables users to **use** MCP inside the SDK.
1. You add a list of MCP servers to `Agent`, via `mcp_server=[...]`
2. When an agent runs, we look up its MCP tools and add them to the list of tools.
3. When a tool call occurs, we call the relevant MCP server.

Notes:
1. There's some refactoring to make sure we send the full list of tools to the Runner/Model etc.
2. Right now, you could have a locally defined tool that conflicts with an MCP defined tool. I didn't add errors for that, will do in a followup.

### Test Plan:
See unit tests. Also has an end to end example next PR.
This commit is contained in:
Rohan Mehta 2025-03-24 15:08:02 -04:00
parent 300e12c198
commit 68c800d2a3
14 changed files with 662 additions and 35 deletions

View file

@ -50,7 +50,7 @@ from .logger import logger
from .models.interface import ModelTracing
from .run_context import RunContextWrapper, TContext
from .stream_events import RunItemStreamEvent, StreamEvent
from .tool import ComputerTool, FunctionTool, FunctionToolResult
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
from .tracing import (
SpanError,
Trace,
@ -301,6 +301,7 @@ class RunImpl:
cls,
*,
agent: Agent[Any],
all_tools: list[Tool],
response: ModelResponse,
output_schema: AgentOutputSchema | None,
handoffs: list[Handoff],
@ -312,8 +313,8 @@ class RunImpl:
computer_actions = []
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)}
computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None)
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
for output in response.output:
if isinstance(output, ResponseOutputMessage):

View file

@ -12,6 +12,7 @@ from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
from .items import ItemHelpers
from .logger import logger
from .mcp import MCPUtil
from .model_settings import ModelSettings
from .models.interface import Model
from .run_context import RunContextWrapper, TContext
@ -21,6 +22,7 @@ from .util._types import MaybeAwaitable
if TYPE_CHECKING:
from .lifecycle import AgentHooks
from .mcp import MCPServer
from .result import RunResult
@ -107,6 +109,16 @@ class Agent(Generic[TContext]):
tools: list[Tool] = field(default_factory=list)
"""A list of tools that the agent can use."""
mcp_servers: list[MCPServer] = field(default_factory=list)
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
the agent can use. Every time the agent runs, it will include tools from these servers in the
list of available tools.
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
longer needed.
"""
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
"""A list of checks that run in parallel to the agent's execution, before generating a
response. Runs only if the agent is the first agent in the chain.
@ -205,3 +217,11 @@ class Agent(Generic[TContext]):
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
return None
async def get_mcp_tools(self) -> list[Tool]:
"""Fetches the available tools from the MCP servers."""
return await MCPUtil.get_all_function_tools(self.mcp_servers)
async def get_all_tools(self) -> list[Tool]:
"""All agent tools, including MCP tools and function tools."""
return await MCPUtil.get_all_function_tools(self.mcp_servers) + self.tools

View file

@ -7,6 +7,8 @@ from typing import Any, cast
from openai.types.responses import ResponseCompletedEvent
from agents.tool import Tool
from ._run_impl import (
NextStepFinalOutput,
NextStepHandoff,
@ -177,7 +179,8 @@ class Runner:
# agent changes, or if the agent loop ends.
if current_span is None:
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
tool_names = [t.name for t in current_agent.tools]
all_tools = await cls._get_all_tools(current_agent)
tool_names = [t.name for t in all_tools]
if output_schema := cls._get_output_schema(current_agent):
output_type_name = output_schema.output_type_name()
else:
@ -217,6 +220,7 @@ class Runner:
),
cls._run_single_turn(
agent=current_agent,
all_tools=all_tools,
original_input=original_input,
generated_items=generated_items,
hooks=hooks,
@ -228,6 +232,7 @@ class Runner:
else:
turn_result = await cls._run_single_turn(
agent=current_agent,
all_tools=all_tools,
original_input=original_input,
generated_items=generated_items,
hooks=hooks,
@ -627,7 +632,7 @@ class Runner:
system_prompt = await agent.get_system_prompt(context_wrapper)
handoffs = cls._get_handoffs(agent)
all_tools = await cls._get_all_tools(agent)
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
final_response: ModelResponse | None = None
@ -640,7 +645,7 @@ class Runner:
system_prompt,
input,
model_settings,
agent.tools,
all_tools,
output_schema,
handoffs,
get_model_tracing_impl(
@ -677,6 +682,7 @@ class Runner:
pre_step_items=streamed_result.new_items,
new_response=final_response,
output_schema=output_schema,
all_tools=all_tools,
handoffs=handoffs,
hooks=hooks,
context_wrapper=context_wrapper,
@ -691,6 +697,7 @@ class Runner:
cls,
*,
agent: Agent[TContext],
all_tools: list[Tool],
original_input: str | list[TResponseInputItem],
generated_items: list[RunItem],
hooks: RunHooks[TContext],
@ -721,6 +728,7 @@ class Runner:
system_prompt,
input,
output_schema,
all_tools,
handoffs,
context_wrapper,
run_config,
@ -732,6 +740,7 @@ class Runner:
pre_step_items=generated_items,
new_response=new_response,
output_schema=output_schema,
all_tools=all_tools,
handoffs=handoffs,
hooks=hooks,
context_wrapper=context_wrapper,
@ -743,6 +752,7 @@ class Runner:
cls,
*,
agent: Agent[TContext],
all_tools: list[Tool],
original_input: str | list[TResponseInputItem],
pre_step_items: list[RunItem],
new_response: ModelResponse,
@ -754,6 +764,7 @@ class Runner:
) -> SingleStepResult:
processed_response = RunImpl.process_model_response(
agent=agent,
all_tools=all_tools,
response=new_response,
output_schema=output_schema,
handoffs=handoffs,
@ -853,6 +864,7 @@ class Runner:
system_prompt: str | None,
input: list[TResponseInputItem],
output_schema: AgentOutputSchema | None,
all_tools: list[Tool],
handoffs: list[Handoff],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
@ -863,7 +875,7 @@ class Runner:
system_instructions=system_prompt,
input=input,
model_settings=model_settings,
tools=agent.tools,
tools=all_tools,
output_schema=output_schema,
handoffs=handoffs,
tracing=get_model_tracing_impl(
@ -892,6 +904,10 @@ class Runner:
handoffs.append(handoff(handoff_item))
return handoffs
@classmethod
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
return await agent.get_all_tools()
@classmethod
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
if isinstance(run_config.model, Model):

0
tests/mcp/__init__.py Normal file
View file

11
tests/mcp/conftest.py Normal file
View file

@ -0,0 +1,11 @@
import os
import sys
# Skip MCP tests on Python 3.9
def pytest_ignore_collect(collection_path, config):
if sys.version_info[:2] == (3, 9):
this_dir = os.path.dirname(__file__)
if str(collection_path).startswith(this_dir):
return True

54
tests/mcp/helpers.py Normal file
View file

@ -0,0 +1,54 @@
import json
import shutil
from typing import Any
from mcp import Tool as MCPTool
from mcp.types import CallToolResult, TextContent
from agents.mcp import MCPServer
tee = shutil.which("tee") or ""
assert tee, "tee not found"
# Added dummy stream classes for patching stdio_client to avoid real I/O during tests
class DummyStream:
async def send(self, msg):
pass
async def receive(self):
raise Exception("Dummy receive not implemented")
class DummyStreamsContextManager:
async def __aenter__(self):
return (DummyStream(), DummyStream())
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
class FakeMCPServer(MCPServer):
def __init__(self, tools: list[MCPTool] | None = None):
self.tools: list[MCPTool] = tools or []
self.tool_calls: list[str] = []
self.tool_results: list[str] = []
def add_tool(self, name: str, input_schema: dict[str, Any]):
self.tools.append(MCPTool(name=name, inputSchema=input_schema))
async def connect(self):
pass
async def cleanup(self):
pass
async def list_tools(self):
return self.tools
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
self.tool_calls.append(tool_name)
self.tool_results.append(f"result_{tool_name}_{json.dumps(arguments)}")
return CallToolResult(
content=[TextContent(text=self.tool_results[-1], type="text")],
)

57
tests/mcp/test_caching.py Normal file
View file

@ -0,0 +1,57 @@
from unittest.mock import AsyncMock, patch
import pytest
from mcp.types import ListToolsResult, Tool as MCPTool
from agents.mcp import MCPServerStdio
from .helpers import DummyStreamsContextManager, tee
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_server_caching_works(
mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
"""Test that if we turn caching on, the list of tools is cached and not fetched from the server
on each call to `list_tools()`.
"""
server = MCPServerStdio(
params={
"command": tee,
},
cache_tools_list=True,
)
tools = [
MCPTool(name="tool1", inputSchema={}),
MCPTool(name="tool2", inputSchema={}),
]
mock_list_tools.return_value = ListToolsResult(tools=tools)
async with server:
# Call list_tools() multiple times
tools = await server.list_tools()
assert tools == tools
assert mock_list_tools.call_count == 1, "list_tools() should have been called once"
# Call list_tools() again, should return the cached value
tools = await server.list_tools()
assert tools == tools
assert mock_list_tools.call_count == 1, "list_tools() should not have been called again"
# Invalidate the cache and call list_tools() again
server.invalidate_tools_cache()
tools = await server.list_tools()
assert tools == tools
assert mock_list_tools.call_count == 2, "list_tools() should be called again"
# Without invalidating the cache, calling list_tools() again should return the cached value
tools = await server.list_tools()
assert tools == tools

View file

@ -0,0 +1,69 @@
from unittest.mock import AsyncMock, patch
import pytest
from mcp.types import ListToolsResult, Tool as MCPTool
from agents.mcp import MCPServerStdio
from .helpers import DummyStreamsContextManager, tee
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_async_ctx_manager_works(
mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
"""Test that the async context manager works."""
server = MCPServerStdio(
params={
"command": tee,
},
cache_tools_list=True,
)
tools = [
MCPTool(name="tool1", inputSchema={}),
MCPTool(name="tool2", inputSchema={}),
]
mock_list_tools.return_value = ListToolsResult(tools=tools)
assert server.session is None, "Server should not be connected"
async with server:
assert server.session is not None, "Server should be connected"
assert server.session is None, "Server should be disconnected"
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_manual_connect_disconnect_works(
mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
"""Test that the async context manager works."""
server = MCPServerStdio(
params={
"command": tee,
},
cache_tools_list=True,
)
tools = [
MCPTool(name="tool1", inputSchema={}),
MCPTool(name="tool2", inputSchema={}),
]
mock_list_tools.return_value = ListToolsResult(tools=tools)
assert server.session is None, "Server should not be connected"
await server.connect()
assert server.session is not None, "Server should be connected"
await server.cleanup()
assert server.session is None, "Server should be disconnected"

109
tests/mcp/test_mcp_util.py Normal file
View file

@ -0,0 +1,109 @@
import logging
from typing import Any
import pytest
from mcp.types import Tool as MCPTool
from pydantic import BaseModel
from agents import FunctionTool, RunContextWrapper
from agents.exceptions import AgentsException, ModelBehaviorError
from agents.mcp import MCPServer, MCPUtil
from .helpers import FakeMCPServer
class Foo(BaseModel):
bar: str
baz: int
class Bar(BaseModel):
qux: str
@pytest.mark.asyncio
async def test_get_all_function_tools():
"""Test that the get_all_function_tools function returns all function tools from a list of MCP
servers.
"""
names = ["test_tool_1", "test_tool_2", "test_tool_3", "test_tool_4", "test_tool_5"]
schemas = [
{},
{},
{},
Foo.model_json_schema(),
Bar.model_json_schema(),
]
server1 = FakeMCPServer()
server1.add_tool(names[0], schemas[0])
server1.add_tool(names[1], schemas[1])
server2 = FakeMCPServer()
server2.add_tool(names[2], schemas[2])
server2.add_tool(names[3], schemas[3])
server3 = FakeMCPServer()
server3.add_tool(names[4], schemas[4])
servers: list[MCPServer] = [server1, server2, server3]
tools = await MCPUtil.get_all_function_tools(servers)
assert len(tools) == 5
assert all(tool.name in names for tool in tools)
for idx, tool in enumerate(tools):
assert isinstance(tool, FunctionTool)
assert tool.params_json_schema == schemas[idx]
assert tool.name == names[idx]
@pytest.mark.asyncio
async def test_invoke_mcp_tool():
"""Test that the invoke_mcp_tool function invokes an MCP tool and returns the result."""
server = FakeMCPServer()
server.add_tool("test_tool_1", {})
ctx = RunContextWrapper(context=None)
tool = MCPTool(name="test_tool_1", inputSchema={})
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")
# Just making sure it doesn't crash
@pytest.mark.asyncio
async def test_mcp_invoke_bad_json_errors(caplog: pytest.LogCaptureFixture):
caplog.set_level(logging.DEBUG)
"""Test that bad JSON input errors are logged and re-raised."""
server = FakeMCPServer()
server.add_tool("test_tool_1", {})
ctx = RunContextWrapper(context=None)
tool = MCPTool(name="test_tool_1", inputSchema={})
with pytest.raises(ModelBehaviorError):
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "not_json")
assert "Invalid JSON input for tool test_tool_1" in caplog.text
class CrashingFakeMCPServer(FakeMCPServer):
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None):
raise Exception("Crash!")
@pytest.mark.asyncio
async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixture):
caplog.set_level(logging.DEBUG)
"""Test that bad JSON input errors are logged and re-raised."""
server = CrashingFakeMCPServer()
server.add_tool("test_tool_1", {})
ctx = RunContextWrapper(context=None)
tool = MCPTool(name="test_tool_1", inputSchema={})
with pytest.raises(AgentsException):
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")
assert "Error invoking MCP tool test_tool_1" in caplog.text

View file

@ -0,0 +1,197 @@
import json
import pytest
from pydantic import BaseModel
from agents import Agent, ModelBehaviorError, Runner, UserError
from ..fake_model import FakeModel
from ..test_responses import get_function_tool_call, get_text_message
from .helpers import FakeMCPServer
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_calls_mcp_tool(streaming: bool):
"""Test that the runner calls an MCP tool when the model produces a tool call."""
server = FakeMCPServer()
server.add_tool("test_tool_1", {})
server.add_tool("test_tool_2", {})
server.add_tool("test_tool_3", {})
model = FakeModel()
agent = Agent(
name="test",
model=model,
mcp_servers=[server],
)
model.add_multiple_turn_outputs(
[
# First turn: a message and tool call
[get_text_message("a_message"), get_function_tool_call("test_tool_2", "")],
# Second turn: text message
[get_text_message("done")],
]
)
if streaming:
result = Runner.run_streamed(agent, input="user_message")
async for _ in result.stream_events():
pass
else:
await Runner.run(agent, input="user_message")
assert server.tool_calls == ["test_tool_2"]
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_asserts_when_mcp_tool_not_found(streaming: bool):
"""Test that the runner asserts when an MCP tool is not found."""
server = FakeMCPServer()
server.add_tool("test_tool_1", {})
server.add_tool("test_tool_2", {})
server.add_tool("test_tool_3", {})
model = FakeModel()
agent = Agent(
name="test",
model=model,
mcp_servers=[server],
)
model.add_multiple_turn_outputs(
[
# First turn: a message and tool call
[get_text_message("a_message"), get_function_tool_call("test_tool_doesnt_exist", "")],
# Second turn: text message
[get_text_message("done")],
]
)
with pytest.raises(ModelBehaviorError):
if streaming:
result = Runner.run_streamed(agent, input="user_message")
async for _ in result.stream_events():
pass
else:
await Runner.run(agent, input="user_message")
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_works_with_multiple_mcp_servers(streaming: bool):
"""Test that the runner works with multiple MCP servers."""
server1 = FakeMCPServer()
server1.add_tool("test_tool_1", {})
server2 = FakeMCPServer()
server2.add_tool("test_tool_2", {})
server2.add_tool("test_tool_3", {})
model = FakeModel()
agent = Agent(
name="test",
model=model,
mcp_servers=[server1, server2],
)
model.add_multiple_turn_outputs(
[
# First turn: a message and tool call
[get_text_message("a_message"), get_function_tool_call("test_tool_2", "")],
# Second turn: text message
[get_text_message("done")],
]
)
if streaming:
result = Runner.run_streamed(agent, input="user_message")
async for _ in result.stream_events():
pass
else:
await Runner.run(agent, input="user_message")
assert server1.tool_calls == []
assert server2.tool_calls == ["test_tool_2"]
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_errors_when_mcp_tools_clash(streaming: bool):
"""Test that the runner errors when multiple servers have the same tool name."""
server1 = FakeMCPServer()
server1.add_tool("test_tool_1", {})
server1.add_tool("test_tool_2", {})
server2 = FakeMCPServer()
server2.add_tool("test_tool_2", {})
server2.add_tool("test_tool_3", {})
model = FakeModel()
agent = Agent(
name="test",
model=model,
mcp_servers=[server1, server2],
)
model.add_multiple_turn_outputs(
[
# First turn: a message and tool call
[get_text_message("a_message"), get_function_tool_call("test_tool_3", "")],
# Second turn: text message
[get_text_message("done")],
]
)
with pytest.raises(UserError):
if streaming:
result = Runner.run_streamed(agent, input="user_message")
async for _ in result.stream_events():
pass
else:
await Runner.run(agent, input="user_message")
class Foo(BaseModel):
bar: str
baz: int
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_calls_mcp_tool_with_args(streaming: bool):
"""Test that the runner calls an MCP tool when the model produces a tool call."""
server = FakeMCPServer()
await server.connect()
server.add_tool("test_tool_1", {})
server.add_tool("test_tool_2", Foo.model_json_schema())
server.add_tool("test_tool_3", {})
model = FakeModel()
agent = Agent(
name="test",
model=model,
mcp_servers=[server],
)
json_args = json.dumps(Foo(bar="baz", baz=1).model_dump())
model.add_multiple_turn_outputs(
[
# First turn: a message and tool call
[get_text_message("a_message"), get_function_tool_call("test_tool_2", json_args)],
# Second turn: text message
[get_text_message("done")],
]
)
if streaming:
result = Runner.run_streamed(agent, input="user_message")
async for _ in result.stream_events():
pass
else:
await Runner.run(agent, input="user_message")
assert server.tool_calls == ["test_tool_2"]
assert server.tool_results == [f"result_test_tool_2_{json_args}"]
await server.cleanup()

View file

@ -0,0 +1,38 @@
import pytest
from agents.exceptions import UserError
from agents.mcp.server import _MCPServerWithClientSession
class CrashingClientSessionServer(_MCPServerWithClientSession):
def __init__(self):
super().__init__(cache_tools_list=False)
self.cleanup_called = False
def create_streams(self):
raise ValueError("Crash!")
async def cleanup(self):
self.cleanup_called = True
await super().cleanup()
@pytest.mark.asyncio
async def test_server_errors_cause_error_and_cleanup_called():
server = CrashingClientSessionServer()
with pytest.raises(ValueError):
await server.connect()
assert server.cleanup_called
@pytest.mark.asyncio
async def test_not_calling_connect_causes_error():
server = CrashingClientSessionServer()
with pytest.raises(UserError):
await server.list_tools()
with pytest.raises(UserError):
await server.call_tool("foo", {})

View file

@ -290,6 +290,7 @@ async def get_execute_result(
processed_response = RunImpl.process_model_response(
agent=agent,
all_tools=await agent.get_all_tools(),
response=response,
output_schema=output_schema,
handoffs=handoffs,

View file

@ -43,7 +43,11 @@ def test_empty_response():
)
result = RunImpl.process_model_response(
agent=agent, response=response, output_schema=None, handoffs=[]
agent=agent,
response=response,
output_schema=None,
handoffs=[],
all_tools=[],
)
assert not result.handoffs
assert not result.functions
@ -57,13 +61,14 @@ def test_no_tool_calls():
referenceable_id=None,
)
result = RunImpl.process_model_response(
agent=agent, response=response, output_schema=None, handoffs=[]
agent=agent, response=response, output_schema=None, handoffs=[], all_tools=[]
)
assert not result.handoffs
assert not result.functions
def test_single_tool_call():
@pytest.mark.asyncio
async def test_single_tool_call():
agent = Agent(name="test", tools=[get_function_tool(name="test")])
response = ModelResponse(
output=[
@ -74,7 +79,11 @@ def test_single_tool_call():
referenceable_id=None,
)
result = RunImpl.process_model_response(
agent=agent, response=response, output_schema=None, handoffs=[]
agent=agent,
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
)
assert not result.handoffs
assert result.functions and len(result.functions) == 1
@ -84,7 +93,8 @@ def test_single_tool_call():
assert func.tool_call.arguments == ""
def test_missing_tool_call_raises_error():
@pytest.mark.asyncio
async def test_missing_tool_call_raises_error():
agent = Agent(name="test", tools=[get_function_tool(name="test")])
response = ModelResponse(
output=[
@ -97,11 +107,16 @@ def test_missing_tool_call_raises_error():
with pytest.raises(ModelBehaviorError):
RunImpl.process_model_response(
agent=agent, response=response, output_schema=None, handoffs=[]
agent=agent,
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
)
def test_multiple_tool_calls():
@pytest.mark.asyncio
async def test_multiple_tool_calls():
agent = Agent(
name="test",
tools=[
@ -121,7 +136,11 @@ def test_multiple_tool_calls():
)
result = RunImpl.process_model_response(
agent=agent, response=response, output_schema=None, handoffs=[]
agent=agent,
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
)
assert not result.handoffs
assert result.functions and len(result.functions) == 2
@ -146,7 +165,11 @@ async def test_handoffs_parsed_correctly():
referenceable_id=None,
)
result = RunImpl.process_model_response(
agent=agent_3, response=response, output_schema=None, handoffs=[]
agent=agent_3,
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent_3.get_all_tools(),
)
assert not result.handoffs, "Shouldn't have a handoff here"
@ -160,6 +183,7 @@ async def test_handoffs_parsed_correctly():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
)
assert len(result.handoffs) == 1, "Should have a handoff here"
handoff = result.handoffs[0]
@ -189,10 +213,12 @@ async def test_missing_handoff_fails():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
)
def test_multiple_handoffs_doesnt_error():
@pytest.mark.asyncio
async def test_multiple_handoffs_doesnt_error():
agent_1 = Agent(name="test_1")
agent_2 = Agent(name="test_2")
agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2])
@ -210,6 +236,7 @@ def test_multiple_handoffs_doesnt_error():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
)
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
@ -218,7 +245,8 @@ class Foo(BaseModel):
bar: str
def test_final_output_parsed_correctly():
@pytest.mark.asyncio
async def test_final_output_parsed_correctly():
agent = Agent(name="test", output_type=Foo)
response = ModelResponse(
output=[
@ -234,10 +262,12 @@ def test_final_output_parsed_correctly():
response=response,
output_schema=Runner._get_output_schema(agent),
handoffs=[],
all_tools=await agent.get_all_tools(),
)
def test_file_search_tool_call_parsed_correctly():
@pytest.mark.asyncio
async def test_file_search_tool_call_parsed_correctly():
# Ensure that a ResponseFileSearchToolCall output is parsed into a ToolCallItem and that no tool
# runs are scheduled.
@ -254,7 +284,11 @@ def test_file_search_tool_call_parsed_correctly():
referenceable_id=None,
)
result = RunImpl.process_model_response(
agent=agent, response=response, output_schema=None, handoffs=[]
agent=agent,
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
)
# The final item should be a ToolCallItem for the file search call
assert any(
@ -265,7 +299,8 @@ def test_file_search_tool_call_parsed_correctly():
assert not result.handoffs
def test_function_web_search_tool_call_parsed_correctly():
@pytest.mark.asyncio
async def test_function_web_search_tool_call_parsed_correctly():
agent = Agent(name="test")
web_search_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call")
response = ModelResponse(
@ -274,7 +309,11 @@ def test_function_web_search_tool_call_parsed_correctly():
referenceable_id=None,
)
result = RunImpl.process_model_response(
agent=agent, response=response, output_schema=None, handoffs=[]
agent=agent,
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
)
assert any(
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
@ -284,7 +323,8 @@ def test_function_web_search_tool_call_parsed_correctly():
assert not result.handoffs
def test_reasoning_item_parsed_correctly():
@pytest.mark.asyncio
async def test_reasoning_item_parsed_correctly():
# Verify that a Reasoning output item is converted into a ReasoningItem.
reasoning = ResponseReasoningItem(
@ -296,7 +336,11 @@ def test_reasoning_item_parsed_correctly():
referenceable_id=None,
)
result = RunImpl.process_model_response(
agent=Agent(name="test"), response=response, output_schema=None, handoffs=[]
agent=Agent(name="test"),
response=response,
output_schema=None,
handoffs=[],
all_tools=await Agent(name="test").get_all_tools(),
)
assert any(
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
@ -342,7 +386,8 @@ class DummyComputer(Computer):
return None # pragma: no cover
def test_computer_tool_call_without_computer_tool_raises_error():
@pytest.mark.asyncio
async def test_computer_tool_call_without_computer_tool_raises_error():
# If the agent has no ComputerTool in its tools, process_model_response should raise a
# ModelBehaviorError when encountering a ResponseComputerToolCall.
computer_call = ResponseComputerToolCall(
@ -360,11 +405,16 @@ def test_computer_tool_call_without_computer_tool_raises_error():
)
with pytest.raises(ModelBehaviorError):
RunImpl.process_model_response(
agent=Agent(name="test"), response=response, output_schema=None, handoffs=[]
agent=Agent(name="test"),
response=response,
output_schema=None,
handoffs=[],
all_tools=await Agent(name="test").get_all_tools(),
)
def test_computer_tool_call_with_computer_tool_parsed_correctly():
@pytest.mark.asyncio
async def test_computer_tool_call_with_computer_tool_parsed_correctly():
# If the agent contains a ComputerTool, ensure that a ResponseComputerToolCall is parsed into a
# ToolCallItem and scheduled to run in computer_actions.
dummy_computer = DummyComputer()
@ -383,7 +433,11 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly():
referenceable_id=None,
)
result = RunImpl.process_model_response(
agent=agent, response=response, output_schema=None, handoffs=[]
agent=agent,
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
)
assert any(
isinstance(item, ToolCallItem) and item.raw_item is computer_call
@ -392,7 +446,8 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly():
assert result.computer_actions and result.computer_actions[0].tool_call == computer_call
def test_tool_and_handoff_parsed_correctly():
@pytest.mark.asyncio
async def test_tool_and_handoff_parsed_correctly():
agent_1 = Agent(name="test_1")
agent_2 = Agent(name="test_2")
agent_3 = Agent(
@ -413,6 +468,7 @@ def test_tool_and_handoff_parsed_correctly():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
)
assert result.functions and len(result.functions) == 1
assert len(result.handoffs) == 1, "Should have a handoff here"

View file

@ -302,13 +302,11 @@ async def test_session_error_event():
trace_include_sensitive_audio_data=False,
)
with pytest.raises(STTWebsocketConnectionError) as exc_info:
with pytest.raises(STTWebsocketConnectionError):
turns = session.transcribe_turns()
async for _ in turns:
pass
assert "Simulated server error!" in str(exc_info.value)
await session.close()