From 68c800d2a3a7ea101d144a69b53a0b6658291e16 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Mon, 24 Mar 2025 15:08:02 -0400 Subject: [PATCH] [2/n] Add MCP support to Runner ### Summary: This enables users to **use** MCP inside the SDK. 1. You add a list of MCP servers to `Agent`, via `mcp_server=[...]` 2. When an agent runs, we look up its MCP tools and add them to the list of tools. 3. When a tool call occurs, we call the relevant MCP server. Notes: 1. There's some refactoring to make sure we send the full list of tools to the Runner/Model etc. 2. Right now, you could have a locally defined tool that conflicts with an MCP defined tool. I didn't add errors for that, will do in a followup. ### Test Plan: See unit tests. Also has an end to end example next PR. --- src/agents/_run_impl.py | 7 +- src/agents/agent.py | 20 +++ src/agents/run.py | 24 +++- tests/mcp/__init__.py | 0 tests/mcp/conftest.py | 11 ++ tests/mcp/helpers.py | 54 ++++++++ tests/mcp/test_caching.py | 57 ++++++++ tests/mcp/test_connect_disconnect.py | 69 ++++++++++ tests/mcp/test_mcp_util.py | 109 +++++++++++++++ tests/mcp/test_runner_calls_mcp.py | 197 +++++++++++++++++++++++++++ tests/mcp/test_server_errors.py | 38 ++++++ tests/test_run_step_execution.py | 1 + tests/test_run_step_processing.py | 100 +++++++++++--- tests/voice/test_openai_stt.py | 10 +- 14 files changed, 662 insertions(+), 35 deletions(-) create mode 100644 tests/mcp/__init__.py create mode 100644 tests/mcp/conftest.py create mode 100644 tests/mcp/helpers.py create mode 100644 tests/mcp/test_caching.py create mode 100644 tests/mcp/test_connect_disconnect.py create mode 100644 tests/mcp/test_mcp_util.py create mode 100644 tests/mcp/test_runner_calls_mcp.py create mode 100644 tests/mcp/test_server_errors.py diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 2849538..02e3bf5 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -50,7 +50,7 @@ from .logger import logger from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult +from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool from .tracing import ( SpanError, Trace, @@ -301,6 +301,7 @@ class RunImpl: cls, *, agent: Agent[Any], + all_tools: list[Tool], response: ModelResponse, output_schema: AgentOutputSchema | None, handoffs: list[Handoff], @@ -312,8 +313,8 @@ class RunImpl: computer_actions = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} - function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)} - computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None) + function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} + computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) for output in response.output: if isinstance(output, ResponseOutputMessage): diff --git a/src/agents/agent.py b/src/agents/agent.py index 2723e67..3258e15 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -12,6 +12,7 @@ from .guardrail import InputGuardrail, OutputGuardrail from .handoffs import Handoff from .items import ItemHelpers from .logger import logger +from .mcp import MCPUtil from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext @@ -21,6 +22,7 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: from .lifecycle import AgentHooks + from .mcp import MCPServer from .result import RunResult @@ -107,6 +109,16 @@ class Agent(Generic[TContext]): tools: list[Tool] = field(default_factory=list) """A list of tools that the agent can use.""" + mcp_servers: list[MCPServer] = field(default_factory=list) + """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that + the agent can use. Every time the agent runs, it will include tools from these servers in the + list of available tools. + + NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call + `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no + longer needed. + """ + input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list) """A list of checks that run in parallel to the agent's execution, before generating a response. Runs only if the agent is the first agent in the chain. @@ -205,3 +217,11 @@ class Agent(Generic[TContext]): logger.error(f"Instructions must be a string or a function, got {self.instructions}") return None + + async def get_mcp_tools(self) -> list[Tool]: + """Fetches the available tools from the MCP servers.""" + return await MCPUtil.get_all_function_tools(self.mcp_servers) + + async def get_all_tools(self) -> list[Tool]: + """All agent tools, including MCP tools and function tools.""" + return await MCPUtil.get_all_function_tools(self.mcp_servers) + self.tools diff --git a/src/agents/run.py b/src/agents/run.py index 934400f..b7ac85f 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -7,6 +7,8 @@ from typing import Any, cast from openai.types.responses import ResponseCompletedEvent +from agents.tool import Tool + from ._run_impl import ( NextStepFinalOutput, NextStepHandoff, @@ -177,7 +179,8 @@ class Runner: # agent changes, or if the agent loop ends. if current_span is None: handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] - tool_names = [t.name for t in current_agent.tools] + all_tools = await cls._get_all_tools(current_agent) + tool_names = [t.name for t in all_tools] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.output_type_name() else: @@ -217,6 +220,7 @@ class Runner: ), cls._run_single_turn( agent=current_agent, + all_tools=all_tools, original_input=original_input, generated_items=generated_items, hooks=hooks, @@ -228,6 +232,7 @@ class Runner: else: turn_result = await cls._run_single_turn( agent=current_agent, + all_tools=all_tools, original_input=original_input, generated_items=generated_items, hooks=hooks, @@ -627,7 +632,7 @@ class Runner: system_prompt = await agent.get_system_prompt(context_wrapper) handoffs = cls._get_handoffs(agent) - + all_tools = await cls._get_all_tools(agent) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) final_response: ModelResponse | None = None @@ -640,7 +645,7 @@ class Runner: system_prompt, input, model_settings, - agent.tools, + all_tools, output_schema, handoffs, get_model_tracing_impl( @@ -677,6 +682,7 @@ class Runner: pre_step_items=streamed_result.new_items, new_response=final_response, output_schema=output_schema, + all_tools=all_tools, handoffs=handoffs, hooks=hooks, context_wrapper=context_wrapper, @@ -691,6 +697,7 @@ class Runner: cls, *, agent: Agent[TContext], + all_tools: list[Tool], original_input: str | list[TResponseInputItem], generated_items: list[RunItem], hooks: RunHooks[TContext], @@ -721,6 +728,7 @@ class Runner: system_prompt, input, output_schema, + all_tools, handoffs, context_wrapper, run_config, @@ -732,6 +740,7 @@ class Runner: pre_step_items=generated_items, new_response=new_response, output_schema=output_schema, + all_tools=all_tools, handoffs=handoffs, hooks=hooks, context_wrapper=context_wrapper, @@ -743,6 +752,7 @@ class Runner: cls, *, agent: Agent[TContext], + all_tools: list[Tool], original_input: str | list[TResponseInputItem], pre_step_items: list[RunItem], new_response: ModelResponse, @@ -754,6 +764,7 @@ class Runner: ) -> SingleStepResult: processed_response = RunImpl.process_model_response( agent=agent, + all_tools=all_tools, response=new_response, output_schema=output_schema, handoffs=handoffs, @@ -853,6 +864,7 @@ class Runner: system_prompt: str | None, input: list[TResponseInputItem], output_schema: AgentOutputSchema | None, + all_tools: list[Tool], handoffs: list[Handoff], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, @@ -863,7 +875,7 @@ class Runner: system_instructions=system_prompt, input=input, model_settings=model_settings, - tools=agent.tools, + tools=all_tools, output_schema=output_schema, handoffs=handoffs, tracing=get_model_tracing_impl( @@ -892,6 +904,10 @@ class Runner: handoffs.append(handoff(handoff_item)) return handoffs + @classmethod + async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]: + return await agent.get_all_tools() + @classmethod def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: if isinstance(run_config.model, Model): diff --git a/tests/mcp/__init__.py b/tests/mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/mcp/conftest.py b/tests/mcp/conftest.py new file mode 100644 index 0000000..80fd15e --- /dev/null +++ b/tests/mcp/conftest.py @@ -0,0 +1,11 @@ +import os +import sys + + +# Skip MCP tests on Python 3.9 +def pytest_ignore_collect(collection_path, config): + if sys.version_info[:2] == (3, 9): + this_dir = os.path.dirname(__file__) + + if str(collection_path).startswith(this_dir): + return True diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py new file mode 100644 index 0000000..952b3ea --- /dev/null +++ b/tests/mcp/helpers.py @@ -0,0 +1,54 @@ +import json +import shutil +from typing import Any + +from mcp import Tool as MCPTool +from mcp.types import CallToolResult, TextContent + +from agents.mcp import MCPServer + +tee = shutil.which("tee") or "" +assert tee, "tee not found" + + +# Added dummy stream classes for patching stdio_client to avoid real I/O during tests +class DummyStream: + async def send(self, msg): + pass + + async def receive(self): + raise Exception("Dummy receive not implemented") + + +class DummyStreamsContextManager: + async def __aenter__(self): + return (DummyStream(), DummyStream()) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class FakeMCPServer(MCPServer): + def __init__(self, tools: list[MCPTool] | None = None): + self.tools: list[MCPTool] = tools or [] + self.tool_calls: list[str] = [] + self.tool_results: list[str] = [] + + def add_tool(self, name: str, input_schema: dict[str, Any]): + self.tools.append(MCPTool(name=name, inputSchema=input_schema)) + + async def connect(self): + pass + + async def cleanup(self): + pass + + async def list_tools(self): + return self.tools + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + self.tool_calls.append(tool_name) + self.tool_results.append(f"result_{tool_name}_{json.dumps(arguments)}") + return CallToolResult( + content=[TextContent(text=self.tool_results[-1], type="text")], + ) diff --git a/tests/mcp/test_caching.py b/tests/mcp/test_caching.py new file mode 100644 index 0000000..cac409e --- /dev/null +++ b/tests/mcp/test_caching.py @@ -0,0 +1,57 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from mcp.types import ListToolsResult, Tool as MCPTool + +from agents.mcp import MCPServerStdio + +from .helpers import DummyStreamsContextManager, tee + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_server_caching_works( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that if we turn caching on, the list of tools is cached and not fetched from the server + on each call to `list_tools()`. + """ + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + ) + + tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=tools) + + async with server: + # Call list_tools() multiple times + tools = await server.list_tools() + assert tools == tools + + assert mock_list_tools.call_count == 1, "list_tools() should have been called once" + + # Call list_tools() again, should return the cached value + tools = await server.list_tools() + assert tools == tools + + assert mock_list_tools.call_count == 1, "list_tools() should not have been called again" + + # Invalidate the cache and call list_tools() again + server.invalidate_tools_cache() + tools = await server.list_tools() + assert tools == tools + + assert mock_list_tools.call_count == 2, "list_tools() should be called again" + + # Without invalidating the cache, calling list_tools() again should return the cached value + tools = await server.list_tools() + assert tools == tools diff --git a/tests/mcp/test_connect_disconnect.py b/tests/mcp/test_connect_disconnect.py new file mode 100644 index 0000000..b001303 --- /dev/null +++ b/tests/mcp/test_connect_disconnect.py @@ -0,0 +1,69 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from mcp.types import ListToolsResult, Tool as MCPTool + +from agents.mcp import MCPServerStdio + +from .helpers import DummyStreamsContextManager, tee + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_async_ctx_manager_works( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that the async context manager works.""" + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + ) + + tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=tools) + + assert server.session is None, "Server should not be connected" + + async with server: + assert server.session is not None, "Server should be connected" + + assert server.session is None, "Server should be disconnected" + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_manual_connect_disconnect_works( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that the async context manager works.""" + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + ) + + tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=tools) + + assert server.session is None, "Server should not be connected" + + await server.connect() + assert server.session is not None, "Server should be connected" + + await server.cleanup() + assert server.session is None, "Server should be disconnected" diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py new file mode 100644 index 0000000..345df99 --- /dev/null +++ b/tests/mcp/test_mcp_util.py @@ -0,0 +1,109 @@ +import logging +from typing import Any + +import pytest +from mcp.types import Tool as MCPTool +from pydantic import BaseModel + +from agents import FunctionTool, RunContextWrapper +from agents.exceptions import AgentsException, ModelBehaviorError +from agents.mcp import MCPServer, MCPUtil + +from .helpers import FakeMCPServer + + +class Foo(BaseModel): + bar: str + baz: int + + +class Bar(BaseModel): + qux: str + + +@pytest.mark.asyncio +async def test_get_all_function_tools(): + """Test that the get_all_function_tools function returns all function tools from a list of MCP + servers. + """ + names = ["test_tool_1", "test_tool_2", "test_tool_3", "test_tool_4", "test_tool_5"] + schemas = [ + {}, + {}, + {}, + Foo.model_json_schema(), + Bar.model_json_schema(), + ] + + server1 = FakeMCPServer() + server1.add_tool(names[0], schemas[0]) + server1.add_tool(names[1], schemas[1]) + + server2 = FakeMCPServer() + server2.add_tool(names[2], schemas[2]) + server2.add_tool(names[3], schemas[3]) + + server3 = FakeMCPServer() + server3.add_tool(names[4], schemas[4]) + + servers: list[MCPServer] = [server1, server2, server3] + tools = await MCPUtil.get_all_function_tools(servers) + assert len(tools) == 5 + assert all(tool.name in names for tool in tools) + + for idx, tool in enumerate(tools): + assert isinstance(tool, FunctionTool) + assert tool.params_json_schema == schemas[idx] + assert tool.name == names[idx] + + +@pytest.mark.asyncio +async def test_invoke_mcp_tool(): + """Test that the invoke_mcp_tool function invokes an MCP tool and returns the result.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "") + # Just making sure it doesn't crash + + +@pytest.mark.asyncio +async def test_mcp_invoke_bad_json_errors(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG) + + """Test that bad JSON input errors are logged and re-raised.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + with pytest.raises(ModelBehaviorError): + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "not_json") + + assert "Invalid JSON input for tool test_tool_1" in caplog.text + + +class CrashingFakeMCPServer(FakeMCPServer): + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None): + raise Exception("Crash!") + + +@pytest.mark.asyncio +async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG) + + """Test that bad JSON input errors are logged and re-raised.""" + server = CrashingFakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + with pytest.raises(AgentsException): + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "") + + assert "Error invoking MCP tool test_tool_1" in caplog.text diff --git a/tests/mcp/test_runner_calls_mcp.py b/tests/mcp/test_runner_calls_mcp.py new file mode 100644 index 0000000..3319c09 --- /dev/null +++ b/tests/mcp/test_runner_calls_mcp.py @@ -0,0 +1,197 @@ +import json + +import pytest +from pydantic import BaseModel + +from agents import Agent, ModelBehaviorError, Runner, UserError + +from ..fake_model import FakeModel +from ..test_responses import get_function_tool_call, get_text_message +from .helpers import FakeMCPServer + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_calls_mcp_tool(streaming: bool): + """Test that the runner calls an MCP tool when the model produces a tool call.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", {}) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server.tool_calls == ["test_tool_2"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_asserts_when_mcp_tool_not_found(streaming: bool): + """Test that the runner asserts when an MCP tool is not found.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", {}) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_doesnt_exist", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + with pytest.raises(ModelBehaviorError): + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_works_with_multiple_mcp_servers(streaming: bool): + """Test that the runner works with multiple MCP servers.""" + server1 = FakeMCPServer() + server1.add_tool("test_tool_1", {}) + + server2 = FakeMCPServer() + server2.add_tool("test_tool_2", {}) + server2.add_tool("test_tool_3", {}) + + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server1, server2], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server1.tool_calls == [] + assert server2.tool_calls == ["test_tool_2"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_errors_when_mcp_tools_clash(streaming: bool): + """Test that the runner errors when multiple servers have the same tool name.""" + server1 = FakeMCPServer() + server1.add_tool("test_tool_1", {}) + server1.add_tool("test_tool_2", {}) + + server2 = FakeMCPServer() + server2.add_tool("test_tool_2", {}) + server2.add_tool("test_tool_3", {}) + + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server1, server2], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_3", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + with pytest.raises(UserError): + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + +class Foo(BaseModel): + bar: str + baz: int + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_calls_mcp_tool_with_args(streaming: bool): + """Test that the runner calls an MCP tool when the model produces a tool call.""" + server = FakeMCPServer() + await server.connect() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", Foo.model_json_schema()) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + json_args = json.dumps(Foo(bar="baz", baz=1).model_dump()) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", json_args)], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server.tool_calls == ["test_tool_2"] + assert server.tool_results == [f"result_test_tool_2_{json_args}"] + + await server.cleanup() diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py new file mode 100644 index 0000000..5c6432b --- /dev/null +++ b/tests/mcp/test_server_errors.py @@ -0,0 +1,38 @@ +import pytest + +from agents.exceptions import UserError +from agents.mcp.server import _MCPServerWithClientSession + + +class CrashingClientSessionServer(_MCPServerWithClientSession): + def __init__(self): + super().__init__(cache_tools_list=False) + self.cleanup_called = False + + def create_streams(self): + raise ValueError("Crash!") + + async def cleanup(self): + self.cleanup_called = True + await super().cleanup() + + +@pytest.mark.asyncio +async def test_server_errors_cause_error_and_cleanup_called(): + server = CrashingClientSessionServer() + + with pytest.raises(ValueError): + await server.connect() + + assert server.cleanup_called + + +@pytest.mark.asyncio +async def test_not_calling_connect_causes_error(): + server = CrashingClientSessionServer() + + with pytest.raises(UserError): + await server.list_tools() + + with pytest.raises(UserError): + await server.call_tool("foo", {}) diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 2d581bf..16c62c8 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -290,6 +290,7 @@ async def get_execute_result( processed_response = RunImpl.process_model_response( agent=agent, + all_tools=await agent.get_all_tools(), response=response, output_schema=output_schema, handoffs=handoffs, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 24f9e8e..2a6634a 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -43,7 +43,11 @@ def test_empty_response(): ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=[], ) assert not result.handoffs assert not result.functions @@ -57,13 +61,14 @@ def test_no_tool_calls(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, response=response, output_schema=None, handoffs=[], all_tools=[] ) assert not result.handoffs assert not result.functions -def test_single_tool_call(): +@pytest.mark.asyncio +async def test_single_tool_call(): agent = Agent(name="test", tools=[get_function_tool(name="test")]) response = ModelResponse( output=[ @@ -74,7 +79,11 @@ def test_single_tool_call(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert not result.handoffs assert result.functions and len(result.functions) == 1 @@ -84,7 +93,8 @@ def test_single_tool_call(): assert func.tool_call.arguments == "" -def test_missing_tool_call_raises_error(): +@pytest.mark.asyncio +async def test_missing_tool_call_raises_error(): agent = Agent(name="test", tools=[get_function_tool(name="test")]) response = ModelResponse( output=[ @@ -97,11 +107,16 @@ def test_missing_tool_call_raises_error(): with pytest.raises(ModelBehaviorError): RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) -def test_multiple_tool_calls(): +@pytest.mark.asyncio +async def test_multiple_tool_calls(): agent = Agent( name="test", tools=[ @@ -121,7 +136,11 @@ def test_multiple_tool_calls(): ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert not result.handoffs assert result.functions and len(result.functions) == 2 @@ -146,7 +165,11 @@ async def test_handoffs_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent_3, response=response, output_schema=None, handoffs=[] + agent=agent_3, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent_3.get_all_tools(), ) assert not result.handoffs, "Shouldn't have a handoff here" @@ -160,6 +183,7 @@ async def test_handoffs_parsed_correctly(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert len(result.handoffs) == 1, "Should have a handoff here" handoff = result.handoffs[0] @@ -189,10 +213,12 @@ async def test_missing_handoff_fails(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) -def test_multiple_handoffs_doesnt_error(): +@pytest.mark.asyncio +async def test_multiple_handoffs_doesnt_error(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2]) @@ -210,6 +236,7 @@ def test_multiple_handoffs_doesnt_error(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert len(result.handoffs) == 2, "Should have multiple handoffs here" @@ -218,7 +245,8 @@ class Foo(BaseModel): bar: str -def test_final_output_parsed_correctly(): +@pytest.mark.asyncio +async def test_final_output_parsed_correctly(): agent = Agent(name="test", output_type=Foo) response = ModelResponse( output=[ @@ -234,10 +262,12 @@ def test_final_output_parsed_correctly(): response=response, output_schema=Runner._get_output_schema(agent), handoffs=[], + all_tools=await agent.get_all_tools(), ) -def test_file_search_tool_call_parsed_correctly(): +@pytest.mark.asyncio +async def test_file_search_tool_call_parsed_correctly(): # Ensure that a ResponseFileSearchToolCall output is parsed into a ToolCallItem and that no tool # runs are scheduled. @@ -254,7 +284,11 @@ def test_file_search_tool_call_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) # The final item should be a ToolCallItem for the file search call assert any( @@ -265,7 +299,8 @@ def test_file_search_tool_call_parsed_correctly(): assert not result.handoffs -def test_function_web_search_tool_call_parsed_correctly(): +@pytest.mark.asyncio +async def test_function_web_search_tool_call_parsed_correctly(): agent = Agent(name="test") web_search_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call") response = ModelResponse( @@ -274,7 +309,11 @@ def test_function_web_search_tool_call_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert any( isinstance(item, ToolCallItem) and item.raw_item is web_search_call @@ -284,7 +323,8 @@ def test_function_web_search_tool_call_parsed_correctly(): assert not result.handoffs -def test_reasoning_item_parsed_correctly(): +@pytest.mark.asyncio +async def test_reasoning_item_parsed_correctly(): # Verify that a Reasoning output item is converted into a ReasoningItem. reasoning = ResponseReasoningItem( @@ -296,7 +336,11 @@ def test_reasoning_item_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=Agent(name="test"), response=response, output_schema=None, handoffs=[] + agent=Agent(name="test"), + response=response, + output_schema=None, + handoffs=[], + all_tools=await Agent(name="test").get_all_tools(), ) assert any( isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items @@ -342,7 +386,8 @@ class DummyComputer(Computer): return None # pragma: no cover -def test_computer_tool_call_without_computer_tool_raises_error(): +@pytest.mark.asyncio +async def test_computer_tool_call_without_computer_tool_raises_error(): # If the agent has no ComputerTool in its tools, process_model_response should raise a # ModelBehaviorError when encountering a ResponseComputerToolCall. computer_call = ResponseComputerToolCall( @@ -360,11 +405,16 @@ def test_computer_tool_call_without_computer_tool_raises_error(): ) with pytest.raises(ModelBehaviorError): RunImpl.process_model_response( - agent=Agent(name="test"), response=response, output_schema=None, handoffs=[] + agent=Agent(name="test"), + response=response, + output_schema=None, + handoffs=[], + all_tools=await Agent(name="test").get_all_tools(), ) -def test_computer_tool_call_with_computer_tool_parsed_correctly(): +@pytest.mark.asyncio +async def test_computer_tool_call_with_computer_tool_parsed_correctly(): # If the agent contains a ComputerTool, ensure that a ResponseComputerToolCall is parsed into a # ToolCallItem and scheduled to run in computer_actions. dummy_computer = DummyComputer() @@ -383,7 +433,11 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert any( isinstance(item, ToolCallItem) and item.raw_item is computer_call @@ -392,7 +446,8 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly(): assert result.computer_actions and result.computer_actions[0].tool_call == computer_call -def test_tool_and_handoff_parsed_correctly(): +@pytest.mark.asyncio +async def test_tool_and_handoff_parsed_correctly(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent( @@ -413,6 +468,7 @@ def test_tool_and_handoff_parsed_correctly(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert result.functions and len(result.functions) == 1 assert len(result.handoffs) == 1, "Should have a handoff here" diff --git a/tests/voice/test_openai_stt.py b/tests/voice/test_openai_stt.py index 7555923..89b5cca 100644 --- a/tests/voice/test_openai_stt.py +++ b/tests/voice/test_openai_stt.py @@ -269,7 +269,7 @@ async def test_timeout_waiting_for_created_event(monkeypatch): async for _ in turns: pass - assert "Timeout waiting for transcription_session.created event" in str(exc_info.value) + assert "Timeout waiting for transcription_session.created event" in str(exc_info.value) await session.close() @@ -302,13 +302,11 @@ async def test_session_error_event(): trace_include_sensitive_audio_data=False, ) - with pytest.raises(STTWebsocketConnectionError) as exc_info: + with pytest.raises(STTWebsocketConnectionError): turns = session.transcribe_turns() async for _ in turns: pass - assert "Simulated server error!" in str(exc_info.value) - await session.close() @@ -362,8 +360,8 @@ async def test_inactivity_timeout(): async for turn in session.transcribe_turns(): collected_turns.append(turn) - assert "Timeout waiting for transcription_session" in str(exc_info.value) + assert "Timeout waiting for transcription_session" in str(exc_info.value) - assert len(collected_turns) == 0, "No transcripts expected, but we got something?" + assert len(collected_turns) == 0, "No transcripts expected, but we got something?" await session.close()