[2/n] Add MCP support to Runner
### Summary: This enables users to **use** MCP inside the SDK. 1. You add a list of MCP servers to `Agent`, via `mcp_servers=[...]` 2. When an agent runs, we look up its MCP tools and add them to the list of tools. 3. When a tool call occurs, we call the relevant MCP server. Notes: 1. There's some refactoring to make sure we send the full list of tools to the Runner/Model etc. 2. Right now, you could have a locally defined tool that conflicts with an MCP-defined tool. I didn't add errors for that; I will do so in a follow-up. ### Test Plan: See unit tests. An end-to-end example follows in the next PR.
This commit is contained in:
parent
300e12c198
commit
68c800d2a3
14 changed files with 662 additions and 35 deletions
|
|
@ -50,7 +50,7 @@ from .logger import logger
|
||||||
from .models.interface import ModelTracing
|
from .models.interface import ModelTracing
|
||||||
from .run_context import RunContextWrapper, TContext
|
from .run_context import RunContextWrapper, TContext
|
||||||
from .stream_events import RunItemStreamEvent, StreamEvent
|
from .stream_events import RunItemStreamEvent, StreamEvent
|
||||||
from .tool import ComputerTool, FunctionTool, FunctionToolResult
|
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
|
||||||
from .tracing import (
|
from .tracing import (
|
||||||
SpanError,
|
SpanError,
|
||||||
Trace,
|
Trace,
|
||||||
|
|
@ -301,6 +301,7 @@ class RunImpl:
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
agent: Agent[Any],
|
agent: Agent[Any],
|
||||||
|
all_tools: list[Tool],
|
||||||
response: ModelResponse,
|
response: ModelResponse,
|
||||||
output_schema: AgentOutputSchema | None,
|
output_schema: AgentOutputSchema | None,
|
||||||
handoffs: list[Handoff],
|
handoffs: list[Handoff],
|
||||||
|
|
@ -312,8 +313,8 @@ class RunImpl:
|
||||||
computer_actions = []
|
computer_actions = []
|
||||||
|
|
||||||
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
|
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
|
||||||
function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)}
|
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
|
||||||
computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None)
|
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
|
||||||
|
|
||||||
for output in response.output:
|
for output in response.output:
|
||||||
if isinstance(output, ResponseOutputMessage):
|
if isinstance(output, ResponseOutputMessage):
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from .guardrail import InputGuardrail, OutputGuardrail
|
||||||
from .handoffs import Handoff
|
from .handoffs import Handoff
|
||||||
from .items import ItemHelpers
|
from .items import ItemHelpers
|
||||||
from .logger import logger
|
from .logger import logger
|
||||||
|
from .mcp import MCPUtil
|
||||||
from .model_settings import ModelSettings
|
from .model_settings import ModelSettings
|
||||||
from .models.interface import Model
|
from .models.interface import Model
|
||||||
from .run_context import RunContextWrapper, TContext
|
from .run_context import RunContextWrapper, TContext
|
||||||
|
|
@ -21,6 +22,7 @@ from .util._types import MaybeAwaitable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .lifecycle import AgentHooks
|
from .lifecycle import AgentHooks
|
||||||
|
from .mcp import MCPServer
|
||||||
from .result import RunResult
|
from .result import RunResult
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -107,6 +109,16 @@ class Agent(Generic[TContext]):
|
||||||
tools: list[Tool] = field(default_factory=list)
|
tools: list[Tool] = field(default_factory=list)
|
||||||
"""A list of tools that the agent can use."""
|
"""A list of tools that the agent can use."""
|
||||||
|
|
||||||
|
mcp_servers: list[MCPServer] = field(default_factory=list)
|
||||||
|
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
|
||||||
|
the agent can use. Every time the agent runs, it will include tools from these servers in the
|
||||||
|
list of available tools.
|
||||||
|
|
||||||
|
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
|
||||||
|
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
|
||||||
|
longer needed.
|
||||||
|
"""
|
||||||
|
|
||||||
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
|
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
|
||||||
"""A list of checks that run in parallel to the agent's execution, before generating a
|
"""A list of checks that run in parallel to the agent's execution, before generating a
|
||||||
response. Runs only if the agent is the first agent in the chain.
|
response. Runs only if the agent is the first agent in the chain.
|
||||||
|
|
@ -205,3 +217,11 @@ class Agent(Generic[TContext]):
|
||||||
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
|
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_mcp_tools(self) -> list[Tool]:
    """Return the tools provided by this agent's MCP servers."""
    servers = self.mcp_servers
    mcp_tools = await MCPUtil.get_all_function_tools(servers)
    return mcp_tools
|
||||||
|
|
||||||
|
async def get_all_tools(self) -> list[Tool]:
    """All agent tools, including MCP tools and function tools.

    Returns:
        A new list with the MCP-provided tools first, followed by the tools
        configured directly on the agent via `tools`.
    """
    # Delegate to get_mcp_tools() instead of repeating the MCPUtil call, so the
    # MCP lookup has a single home and the two methods cannot drift apart.
    mcp_tools = await self.get_mcp_tools()
    return mcp_tools + self.tools
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ from typing import Any, cast
|
||||||
|
|
||||||
from openai.types.responses import ResponseCompletedEvent
|
from openai.types.responses import ResponseCompletedEvent
|
||||||
|
|
||||||
|
from agents.tool import Tool
|
||||||
|
|
||||||
from ._run_impl import (
|
from ._run_impl import (
|
||||||
NextStepFinalOutput,
|
NextStepFinalOutput,
|
||||||
NextStepHandoff,
|
NextStepHandoff,
|
||||||
|
|
@ -177,7 +179,8 @@ class Runner:
|
||||||
# agent changes, or if the agent loop ends.
|
# agent changes, or if the agent loop ends.
|
||||||
if current_span is None:
|
if current_span is None:
|
||||||
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
|
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
|
||||||
tool_names = [t.name for t in current_agent.tools]
|
all_tools = await cls._get_all_tools(current_agent)
|
||||||
|
tool_names = [t.name for t in all_tools]
|
||||||
if output_schema := cls._get_output_schema(current_agent):
|
if output_schema := cls._get_output_schema(current_agent):
|
||||||
output_type_name = output_schema.output_type_name()
|
output_type_name = output_schema.output_type_name()
|
||||||
else:
|
else:
|
||||||
|
|
@ -217,6 +220,7 @@ class Runner:
|
||||||
),
|
),
|
||||||
cls._run_single_turn(
|
cls._run_single_turn(
|
||||||
agent=current_agent,
|
agent=current_agent,
|
||||||
|
all_tools=all_tools,
|
||||||
original_input=original_input,
|
original_input=original_input,
|
||||||
generated_items=generated_items,
|
generated_items=generated_items,
|
||||||
hooks=hooks,
|
hooks=hooks,
|
||||||
|
|
@ -228,6 +232,7 @@ class Runner:
|
||||||
else:
|
else:
|
||||||
turn_result = await cls._run_single_turn(
|
turn_result = await cls._run_single_turn(
|
||||||
agent=current_agent,
|
agent=current_agent,
|
||||||
|
all_tools=all_tools,
|
||||||
original_input=original_input,
|
original_input=original_input,
|
||||||
generated_items=generated_items,
|
generated_items=generated_items,
|
||||||
hooks=hooks,
|
hooks=hooks,
|
||||||
|
|
@ -627,7 +632,7 @@ class Runner:
|
||||||
system_prompt = await agent.get_system_prompt(context_wrapper)
|
system_prompt = await agent.get_system_prompt(context_wrapper)
|
||||||
|
|
||||||
handoffs = cls._get_handoffs(agent)
|
handoffs = cls._get_handoffs(agent)
|
||||||
|
all_tools = await cls._get_all_tools(agent)
|
||||||
model = cls._get_model(agent, run_config)
|
model = cls._get_model(agent, run_config)
|
||||||
model_settings = agent.model_settings.resolve(run_config.model_settings)
|
model_settings = agent.model_settings.resolve(run_config.model_settings)
|
||||||
final_response: ModelResponse | None = None
|
final_response: ModelResponse | None = None
|
||||||
|
|
@ -640,7 +645,7 @@ class Runner:
|
||||||
system_prompt,
|
system_prompt,
|
||||||
input,
|
input,
|
||||||
model_settings,
|
model_settings,
|
||||||
agent.tools,
|
all_tools,
|
||||||
output_schema,
|
output_schema,
|
||||||
handoffs,
|
handoffs,
|
||||||
get_model_tracing_impl(
|
get_model_tracing_impl(
|
||||||
|
|
@ -677,6 +682,7 @@ class Runner:
|
||||||
pre_step_items=streamed_result.new_items,
|
pre_step_items=streamed_result.new_items,
|
||||||
new_response=final_response,
|
new_response=final_response,
|
||||||
output_schema=output_schema,
|
output_schema=output_schema,
|
||||||
|
all_tools=all_tools,
|
||||||
handoffs=handoffs,
|
handoffs=handoffs,
|
||||||
hooks=hooks,
|
hooks=hooks,
|
||||||
context_wrapper=context_wrapper,
|
context_wrapper=context_wrapper,
|
||||||
|
|
@ -691,6 +697,7 @@ class Runner:
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
agent: Agent[TContext],
|
agent: Agent[TContext],
|
||||||
|
all_tools: list[Tool],
|
||||||
original_input: str | list[TResponseInputItem],
|
original_input: str | list[TResponseInputItem],
|
||||||
generated_items: list[RunItem],
|
generated_items: list[RunItem],
|
||||||
hooks: RunHooks[TContext],
|
hooks: RunHooks[TContext],
|
||||||
|
|
@ -721,6 +728,7 @@ class Runner:
|
||||||
system_prompt,
|
system_prompt,
|
||||||
input,
|
input,
|
||||||
output_schema,
|
output_schema,
|
||||||
|
all_tools,
|
||||||
handoffs,
|
handoffs,
|
||||||
context_wrapper,
|
context_wrapper,
|
||||||
run_config,
|
run_config,
|
||||||
|
|
@ -732,6 +740,7 @@ class Runner:
|
||||||
pre_step_items=generated_items,
|
pre_step_items=generated_items,
|
||||||
new_response=new_response,
|
new_response=new_response,
|
||||||
output_schema=output_schema,
|
output_schema=output_schema,
|
||||||
|
all_tools=all_tools,
|
||||||
handoffs=handoffs,
|
handoffs=handoffs,
|
||||||
hooks=hooks,
|
hooks=hooks,
|
||||||
context_wrapper=context_wrapper,
|
context_wrapper=context_wrapper,
|
||||||
|
|
@ -743,6 +752,7 @@ class Runner:
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
agent: Agent[TContext],
|
agent: Agent[TContext],
|
||||||
|
all_tools: list[Tool],
|
||||||
original_input: str | list[TResponseInputItem],
|
original_input: str | list[TResponseInputItem],
|
||||||
pre_step_items: list[RunItem],
|
pre_step_items: list[RunItem],
|
||||||
new_response: ModelResponse,
|
new_response: ModelResponse,
|
||||||
|
|
@ -754,6 +764,7 @@ class Runner:
|
||||||
) -> SingleStepResult:
|
) -> SingleStepResult:
|
||||||
processed_response = RunImpl.process_model_response(
|
processed_response = RunImpl.process_model_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
all_tools=all_tools,
|
||||||
response=new_response,
|
response=new_response,
|
||||||
output_schema=output_schema,
|
output_schema=output_schema,
|
||||||
handoffs=handoffs,
|
handoffs=handoffs,
|
||||||
|
|
@ -853,6 +864,7 @@ class Runner:
|
||||||
system_prompt: str | None,
|
system_prompt: str | None,
|
||||||
input: list[TResponseInputItem],
|
input: list[TResponseInputItem],
|
||||||
output_schema: AgentOutputSchema | None,
|
output_schema: AgentOutputSchema | None,
|
||||||
|
all_tools: list[Tool],
|
||||||
handoffs: list[Handoff],
|
handoffs: list[Handoff],
|
||||||
context_wrapper: RunContextWrapper[TContext],
|
context_wrapper: RunContextWrapper[TContext],
|
||||||
run_config: RunConfig,
|
run_config: RunConfig,
|
||||||
|
|
@ -863,7 +875,7 @@ class Runner:
|
||||||
system_instructions=system_prompt,
|
system_instructions=system_prompt,
|
||||||
input=input,
|
input=input,
|
||||||
model_settings=model_settings,
|
model_settings=model_settings,
|
||||||
tools=agent.tools,
|
tools=all_tools,
|
||||||
output_schema=output_schema,
|
output_schema=output_schema,
|
||||||
handoffs=handoffs,
|
handoffs=handoffs,
|
||||||
tracing=get_model_tracing_impl(
|
tracing=get_model_tracing_impl(
|
||||||
|
|
@ -892,6 +904,10 @@ class Runner:
|
||||||
handoffs.append(handoff(handoff_item))
|
handoffs.append(handoff(handoff_item))
|
||||||
return handoffs
|
return handoffs
|
||||||
|
|
||||||
|
@classmethod
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
    """Resolve the full tool list for `agent` by awaiting `agent.get_all_tools()`."""
    resolved_tools = await agent.get_all_tools()
    return resolved_tools
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
|
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
|
||||||
if isinstance(run_config.model, Model):
|
if isinstance(run_config.model, Model):
|
||||||
|
|
|
||||||
0
tests/mcp/__init__.py
Normal file
0
tests/mcp/__init__.py
Normal file
11
tests/mcp/conftest.py
Normal file
11
tests/mcp/conftest.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
# Skip MCP tests on Python 3.9
|
||||||
|
def pytest_ignore_collect(collection_path, config):
    """Skip collection of this directory's tests when running on Python 3.9.

    Args:
        collection_path: The path pytest is considering for collection.
        config: The pytest config object (unused).

    Returns:
        True when running on Python 3.9 and `collection_path` is this
        conftest's directory or lives beneath it; otherwise None, which leaves
        the decision to pytest's other plugins and default behavior.
    """
    if sys.version_info[:2] != (3, 9):
        return None

    this_dir = os.path.dirname(__file__)
    candidate = str(collection_path)

    # Compare on a path-component boundary: a bare startswith(this_dir) would
    # also swallow sibling directories that merely share the prefix
    # (e.g. ".../tests/mcp_extra" when this_dir is ".../tests/mcp").
    if candidate == this_dir or candidate.startswith(os.path.join(this_dir, "")):
        return True
    return None
|
||||||
54
tests/mcp/helpers.py
Normal file
54
tests/mcp/helpers.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from mcp import Tool as MCPTool
|
||||||
|
from mcp.types import CallToolResult, TextContent
|
||||||
|
|
||||||
|
from agents.mcp import MCPServer
|
||||||
|
|
||||||
|
# Absolute path of the system `tee` executable; the stdio server tests pass it
# as the server "command". Importing this module fails if it cannot be found.
_resolved_tee = shutil.which("tee")
tee = _resolved_tee or ""
assert tee, "tee not found"
|
||||||
|
|
||||||
|
|
||||||
|
# Added dummy stream classes for patching stdio_client to avoid real I/O during tests
|
||||||
|
class DummyStream:
    """Inert stream stand-in: sends are dropped and receives always fail."""

    async def receive(self):
        """Always raise, since the dummy stream never has data to deliver."""
        raise Exception("Dummy receive not implemented")

    async def send(self, msg):
        """Discard `msg` without performing any I/O."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class DummyStreamsContextManager:
    """Async context manager that hands out a pair of `DummyStream` objects."""

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Nothing to release; returns None so any exception propagates."""
        return None

    async def __aenter__(self):
        """Create and return a fresh 2-tuple of dummy streams."""
        first_stream = DummyStream()
        second_stream = DummyStream()
        return (first_stream, second_stream)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeMCPServer(MCPServer):
    """In-memory `MCPServer` for tests that records every tool call it receives.

    `tool_calls` collects the names of invoked tools in call order, and
    `tool_results` collects the text payload returned for each call.
    """

    def __init__(self, tools: list[MCPTool] | None = None):
        # A falsy argument (None or an empty list) is replaced by a fresh list.
        self.tools: list[MCPTool] = tools if tools else []
        self.tool_calls: list[str] = []
        self.tool_results: list[str] = []

    def add_tool(self, name: str, input_schema: dict[str, Any]):
        """Register a tool with the given name and input schema."""
        new_tool = MCPTool(name=name, inputSchema=input_schema)
        self.tools.append(new_tool)

    async def connect(self):
        """No-op: there is nothing to connect to."""
        return None

    async def cleanup(self):
        """No-op: there is nothing to release."""
        return None

    async def list_tools(self):
        """Return the registered tools."""
        return self.tools

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Record the call and reply with a text result derived from its inputs."""
        self.tool_calls.append(tool_name)
        result_text = f"result_{tool_name}_{json.dumps(arguments)}"
        self.tool_results.append(result_text)
        text_content = TextContent(text=self.tool_results[-1], type="text")
        return CallToolResult(
            content=[text_content],
        )
|
||||||
57
tests/mcp/test_caching.py
Normal file
57
tests/mcp/test_caching.py
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from mcp.types import ListToolsResult, Tool as MCPTool
|
||||||
|
|
||||||
|
from agents.mcp import MCPServerStdio
|
||||||
|
|
||||||
|
from .helpers import DummyStreamsContextManager, tee
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_server_caching_works(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Test that if we turn caching on, the list of tools is cached and not fetched from the server
    on each call to `list_tools()`.
    """
    server = MCPServerStdio(
        params={
            "command": tee,
        },
        cache_tools_list=True,
    )

    tools = [
        MCPTool(name="tool1", inputSchema={}),
        MCPTool(name="tool2", inputSchema={}),
    ]

    mock_list_tools.return_value = ListToolsResult(tools=tools)

    async with server:
        # The fetched list is kept in its own variable. Rebinding `tools` here
        # would turn every comparison below into `x == x`, which cannot fail.
        result_tools = await server.list_tools()
        assert result_tools == tools

        assert mock_list_tools.call_count == 1, "list_tools() should have been called once"

        # Call list_tools() again, should return the cached value
        result_tools = await server.list_tools()
        assert result_tools == tools

        assert mock_list_tools.call_count == 1, "list_tools() should not have been called again"

        # Invalidate the cache and call list_tools() again
        server.invalidate_tools_cache()
        result_tools = await server.list_tools()
        assert result_tools == tools

        assert mock_list_tools.call_count == 2, "list_tools() should be called again"

        # Without invalidating the cache, calling list_tools() again should return the cached value
        result_tools = await server.list_tools()
        assert result_tools == tools

        assert mock_list_tools.call_count == 2, "list_tools() should not have been called again"
|
||||||
69
tests/mcp/test_connect_disconnect.py
Normal file
69
tests/mcp/test_connect_disconnect.py
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from mcp.types import ListToolsResult, Tool as MCPTool
|
||||||
|
|
||||||
|
from agents.mcp import MCPServerStdio
|
||||||
|
|
||||||
|
from .helpers import DummyStreamsContextManager, tee
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_async_ctx_manager_works(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Entering `async with server` opens a session; leaving the block clears it."""
    stdio_params = {"command": tee}
    server = MCPServerStdio(params=stdio_params, cache_tools_list=True)

    listed_tools = [MCPTool(name=tool_name, inputSchema={}) for tool_name in ("tool1", "tool2")]
    mock_list_tools.return_value = ListToolsResult(tools=listed_tools)

    # Before, during, and after the context manager.
    assert server.session is None, "Server should not be connected"
    async with server:
        assert server.session is not None, "Server should be connected"
    assert server.session is None, "Server should be disconnected"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_manual_connect_disconnect_works(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Test that explicit `connect()` / `cleanup()` calls open and close the session."""
    server = MCPServerStdio(
        params={
            "command": tee,
        },
        cache_tools_list=True,
    )

    tools = [
        MCPTool(name="tool1", inputSchema={}),
        MCPTool(name="tool2", inputSchema={}),
    ]

    mock_list_tools.return_value = ListToolsResult(tools=tools)

    assert server.session is None, "Server should not be connected"

    await server.connect()
    try:
        assert server.session is not None, "Server should be connected"
    finally:
        # Always release the server, even if the assertion above fails, so a
        # failing test does not leave a connected server behind.
        await server.cleanup()

    assert server.session is None, "Server should be disconnected"
|
||||||
109
tests/mcp/test_mcp_util.py
Normal file
109
tests/mcp/test_mcp_util.py
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from mcp.types import Tool as MCPTool
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from agents import FunctionTool, RunContextWrapper
|
||||||
|
from agents.exceptions import AgentsException, ModelBehaviorError
|
||||||
|
from agents.mcp import MCPServer, MCPUtil
|
||||||
|
|
||||||
|
from .helpers import FakeMCPServer
|
||||||
|
|
||||||
|
|
||||||
|
# Sample model whose JSON schema serves as a tool input schema in these tests.
# Documented with comments rather than a docstring so the schema is unchanged.
class Foo(BaseModel):
    bar: str
    baz: int
|
||||||
|
|
||||||
|
|
||||||
|
# Second sample model, giving the tests a schema that differs from Foo's.
# Documented with comments rather than a docstring so the schema is unchanged.
class Bar(BaseModel):
    qux: str
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_all_function_tools():
    """`MCPUtil.get_all_function_tools` gathers the tools of every server, in server order,
    as `FunctionTool` objects carrying the original names and schemas.
    """
    names = ["test_tool_1", "test_tool_2", "test_tool_3", "test_tool_4", "test_tool_5"]
    schemas = [{}, {}, {}, Foo.model_json_schema(), Bar.model_json_schema()]

    # Spread the five tools over three servers: two, two, and one.
    server1 = FakeMCPServer()
    for position in (0, 1):
        server1.add_tool(names[position], schemas[position])

    server2 = FakeMCPServer()
    for position in (2, 3):
        server2.add_tool(names[position], schemas[position])

    server3 = FakeMCPServer()
    server3.add_tool(names[4], schemas[4])

    servers: list[MCPServer] = [server1, server2, server3]
    tools = await MCPUtil.get_all_function_tools(servers)
    assert len(tools) == 5
    assert all(tool.name in names for tool in tools)

    for tool, expected_name, expected_schema in zip(tools, names, schemas):
        assert isinstance(tool, FunctionTool)
        assert tool.params_json_schema == expected_schema
        assert tool.name == expected_name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_invoke_mcp_tool():
    """Smoke test: invoking an MCP tool through `MCPUtil.invoke_mcp_tool` completes without raising."""
    tool_name = "test_tool_1"

    server = FakeMCPServer()
    server.add_tool(tool_name, {})

    run_context = RunContextWrapper(context=None)
    mcp_tool = MCPTool(name=tool_name, inputSchema={})

    # No assertion on the result; the call simply must not raise.
    await MCPUtil.invoke_mcp_tool(server, mcp_tool, run_context, "")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mcp_invoke_bad_json_errors(caplog: pytest.LogCaptureFixture):
    """Test that bad JSON input errors are logged and re-raised."""
    # The docstring must be the first statement in the body; placed after this
    # call it was only a discarded string expression.
    caplog.set_level(logging.DEBUG)

    server = FakeMCPServer()
    server.add_tool("test_tool_1", {})

    ctx = RunContextWrapper(context=None)
    tool = MCPTool(name="test_tool_1", inputSchema={})

    with pytest.raises(ModelBehaviorError):
        await MCPUtil.invoke_mcp_tool(server, tool, ctx, "not_json")

    assert "Invalid JSON input for tool test_tool_1" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
class CrashingFakeMCPServer(FakeMCPServer):
    """`FakeMCPServer` variant whose `call_tool` always raises."""

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None):
        """Fail unconditionally, without recording the call."""
        crash = Exception("Crash!")
        raise crash
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixture):
    """Test that a crash inside the server's `call_tool` is logged and raised as an
    `AgentsException`.
    """
    # The docstring must be the first statement in the body; it previously sat
    # after this call and described the bad-JSON test instead of this one.
    caplog.set_level(logging.DEBUG)

    server = CrashingFakeMCPServer()
    server.add_tool("test_tool_1", {})

    ctx = RunContextWrapper(context=None)
    tool = MCPTool(name="test_tool_1", inputSchema={})

    with pytest.raises(AgentsException):
        await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")

    assert "Error invoking MCP tool test_tool_1" in caplog.text
|
||||||
197
tests/mcp/test_runner_calls_mcp.py
Normal file
197
tests/mcp/test_runner_calls_mcp.py
Normal file
|
|
@ -0,0 +1,197 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from agents import Agent, ModelBehaviorError, Runner, UserError
|
||||||
|
|
||||||
|
from ..fake_model import FakeModel
|
||||||
|
from ..test_responses import get_function_tool_call, get_text_message
|
||||||
|
from .helpers import FakeMCPServer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_calls_mcp_tool(streaming: bool):
    """A tool call emitted by the model is routed to the MCP server, in both run modes."""
    server = FakeMCPServer()
    for tool_name in ("test_tool_1", "test_tool_2", "test_tool_3"):
        server.add_tool(tool_name, {})

    model = FakeModel()
    agent = Agent(name="test", model=model, mcp_servers=[server])

    # Turn 1: a message plus a call to test_tool_2. Turn 2: a plain text message.
    first_turn = [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")]
    second_turn = [get_text_message("done")]
    model.add_multiple_turn_outputs([first_turn, second_turn])

    if not streaming:
        await Runner.run(agent, input="user_message")
    else:
        streamed = Runner.run_streamed(agent, input="user_message")
        async for _ in streamed.stream_events():
            pass

    assert server.tool_calls == ["test_tool_2"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_asserts_when_mcp_tool_not_found(streaming: bool):
    """Calling a tool that no MCP server provides raises `ModelBehaviorError`."""
    server = FakeMCPServer()
    for tool_name in ("test_tool_1", "test_tool_2", "test_tool_3"):
        server.add_tool(tool_name, {})

    model = FakeModel()
    agent = Agent(name="test", model=model, mcp_servers=[server])

    # Turn 1 calls a tool the server never registered. Turn 2 is a plain text message.
    first_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_doesnt_exist", ""),
    ]
    second_turn = [get_text_message("done")]
    model.add_multiple_turn_outputs([first_turn, second_turn])

    with pytest.raises(ModelBehaviorError):
        if not streaming:
            await Runner.run(agent, input="user_message")
        else:
            streamed = Runner.run_streamed(agent, input="user_message")
            async for _ in streamed.stream_events():
                pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_works_with_multiple_mcp_servers(streaming: bool):
    """With several MCP servers, a tool call reaches only the server that owns the tool."""
    server1 = FakeMCPServer()
    server1.add_tool("test_tool_1", {})

    server2 = FakeMCPServer()
    for tool_name in ("test_tool_2", "test_tool_3"):
        server2.add_tool(tool_name, {})

    model = FakeModel()
    agent = Agent(name="test", model=model, mcp_servers=[server1, server2])

    # Turn 1: a message plus a call to test_tool_2. Turn 2: a plain text message.
    first_turn = [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")]
    second_turn = [get_text_message("done")]
    model.add_multiple_turn_outputs([first_turn, second_turn])

    if not streaming:
        await Runner.run(agent, input="user_message")
    else:
        streamed = Runner.run_streamed(agent, input="user_message")
        async for _ in streamed.stream_events():
            pass

    assert server1.tool_calls == []
    assert server2.tool_calls == ["test_tool_2"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_errors_when_mcp_tools_clash(streaming: bool):
    """Two MCP servers exposing the same tool name make the run fail with `UserError`."""
    server1 = FakeMCPServer()
    for tool_name in ("test_tool_1", "test_tool_2"):
        server1.add_tool(tool_name, {})

    # test_tool_2 is deliberately registered on both servers.
    server2 = FakeMCPServer()
    for tool_name in ("test_tool_2", "test_tool_3"):
        server2.add_tool(tool_name, {})

    model = FakeModel()
    agent = Agent(name="test", model=model, mcp_servers=[server1, server2])

    # Turn 1: a message plus a call to test_tool_3. Turn 2: a plain text message.
    first_turn = [get_text_message("a_message"), get_function_tool_call("test_tool_3", "")]
    second_turn = [get_text_message("done")]
    model.add_multiple_turn_outputs([first_turn, second_turn])

    with pytest.raises(UserError):
        if not streaming:
            await Runner.run(agent, input="user_message")
        else:
            streamed = Runner.run_streamed(agent, input="user_message")
            async for _ in streamed.stream_events():
                pass
|
||||||
|
|
||||||
|
|
||||||
|
# Argument model for the tool-with-args test: its JSON schema is the tool's
# input schema and a dumped instance is the call's arguments.
# Documented with comments rather than a docstring so the schema is unchanged.
class Foo(BaseModel):
    bar: str
    baz: int
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_calls_mcp_tool_with_args(streaming: bool):
    """A tool call carrying JSON arguments is forwarded to the MCP server with those arguments."""
    server = FakeMCPServer()
    await server.connect()
    server.add_tool("test_tool_1", {})
    server.add_tool("test_tool_2", Foo.model_json_schema())
    server.add_tool("test_tool_3", {})

    model = FakeModel()
    agent = Agent(name="test", model=model, mcp_servers=[server])

    payload = Foo(bar="baz", baz=1)
    json_args = json.dumps(payload.model_dump())

    # Turn 1: a message plus a call to test_tool_2 with arguments. Turn 2: a plain text message.
    first_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_2", json_args),
    ]
    second_turn = [get_text_message("done")]
    model.add_multiple_turn_outputs([first_turn, second_turn])

    if not streaming:
        await Runner.run(agent, input="user_message")
    else:
        streamed = Runner.run_streamed(agent, input="user_message")
        async for _ in streamed.stream_events():
            pass

    assert server.tool_calls == ["test_tool_2"]
    assert server.tool_results == [f"result_test_tool_2_{json_args}"]

    await server.cleanup()
|
||||||
38
tests/mcp/test_server_errors.py
Normal file
38
tests/mcp/test_server_errors.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agents.exceptions import UserError
|
||||||
|
from agents.mcp.server import _MCPServerWithClientSession
|
||||||
|
|
||||||
|
|
||||||
|
class CrashingClientSessionServer(_MCPServerWithClientSession):
    """Server stub whose ``create_streams`` always fails and which records cleanup calls."""

    def __init__(self):
        """Initialise the base server without tool-list caching and reset the cleanup flag."""
        super().__init__(cache_tools_list=False)
        # Flipped to True the first time cleanup() is invoked.
        self.cleanup_called = False

    def create_streams(self):
        """Always raise ``ValueError`` instead of producing streams."""
        raise ValueError("Crash!")

    async def cleanup(self):
        """Record that cleanup was requested, then delegate to the base class."""
        self.cleanup_called = True
        await super().cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_server_errors_cause_error_and_cleanup_called():
    """A failing connect must surface the ValueError and still trigger cleanup."""
    crashing_server = CrashingClientSessionServer()

    # The stub's create_streams raises ValueError; connect is expected to propagate it.
    with pytest.raises(ValueError):
        await crashing_server.connect()

    assert crashing_server.cleanup_called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_not_calling_connect_causes_error():
    """Using a server that was never connected must raise ``UserError``."""
    unconnected_server = CrashingClientSessionServer()

    # Listing tools without a prior connect() is rejected.
    with pytest.raises(UserError):
        await unconnected_server.list_tools()

    # Calling a tool without a prior connect() is rejected as well.
    with pytest.raises(UserError):
        await unconnected_server.call_tool("foo", {})
|
||||||
|
|
@ -290,6 +290,7 @@ async def get_execute_result(
|
||||||
|
|
||||||
processed_response = RunImpl.process_model_response(
|
processed_response = RunImpl.process_model_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
all_tools=await agent.get_all_tools(),
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=output_schema,
|
output_schema=output_schema,
|
||||||
handoffs=handoffs,
|
handoffs=handoffs,
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,11 @@ def test_empty_response():
|
||||||
)
|
)
|
||||||
|
|
||||||
result = RunImpl.process_model_response(
|
result = RunImpl.process_model_response(
|
||||||
agent=agent, response=response, output_schema=None, handoffs=[]
|
agent=agent,
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=[],
|
||||||
)
|
)
|
||||||
assert not result.handoffs
|
assert not result.handoffs
|
||||||
assert not result.functions
|
assert not result.functions
|
||||||
|
|
@ -57,13 +61,14 @@ def test_no_tool_calls():
|
||||||
referenceable_id=None,
|
referenceable_id=None,
|
||||||
)
|
)
|
||||||
result = RunImpl.process_model_response(
|
result = RunImpl.process_model_response(
|
||||||
agent=agent, response=response, output_schema=None, handoffs=[]
|
agent=agent, response=response, output_schema=None, handoffs=[], all_tools=[]
|
||||||
)
|
)
|
||||||
assert not result.handoffs
|
assert not result.handoffs
|
||||||
assert not result.functions
|
assert not result.functions
|
||||||
|
|
||||||
|
|
||||||
def test_single_tool_call():
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_tool_call():
|
||||||
agent = Agent(name="test", tools=[get_function_tool(name="test")])
|
agent = Agent(name="test", tools=[get_function_tool(name="test")])
|
||||||
response = ModelResponse(
|
response = ModelResponse(
|
||||||
output=[
|
output=[
|
||||||
|
|
@ -74,7 +79,11 @@ def test_single_tool_call():
|
||||||
referenceable_id=None,
|
referenceable_id=None,
|
||||||
)
|
)
|
||||||
result = RunImpl.process_model_response(
|
result = RunImpl.process_model_response(
|
||||||
agent=agent, response=response, output_schema=None, handoffs=[]
|
agent=agent,
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=await agent.get_all_tools(),
|
||||||
)
|
)
|
||||||
assert not result.handoffs
|
assert not result.handoffs
|
||||||
assert result.functions and len(result.functions) == 1
|
assert result.functions and len(result.functions) == 1
|
||||||
|
|
@ -84,7 +93,8 @@ def test_single_tool_call():
|
||||||
assert func.tool_call.arguments == ""
|
assert func.tool_call.arguments == ""
|
||||||
|
|
||||||
|
|
||||||
def test_missing_tool_call_raises_error():
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_tool_call_raises_error():
|
||||||
agent = Agent(name="test", tools=[get_function_tool(name="test")])
|
agent = Agent(name="test", tools=[get_function_tool(name="test")])
|
||||||
response = ModelResponse(
|
response = ModelResponse(
|
||||||
output=[
|
output=[
|
||||||
|
|
@ -97,11 +107,16 @@ def test_missing_tool_call_raises_error():
|
||||||
|
|
||||||
with pytest.raises(ModelBehaviorError):
|
with pytest.raises(ModelBehaviorError):
|
||||||
RunImpl.process_model_response(
|
RunImpl.process_model_response(
|
||||||
agent=agent, response=response, output_schema=None, handoffs=[]
|
agent=agent,
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=await agent.get_all_tools(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_tool_calls():
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_tool_calls():
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
name="test",
|
name="test",
|
||||||
tools=[
|
tools=[
|
||||||
|
|
@ -121,7 +136,11 @@ def test_multiple_tool_calls():
|
||||||
)
|
)
|
||||||
|
|
||||||
result = RunImpl.process_model_response(
|
result = RunImpl.process_model_response(
|
||||||
agent=agent, response=response, output_schema=None, handoffs=[]
|
agent=agent,
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=await agent.get_all_tools(),
|
||||||
)
|
)
|
||||||
assert not result.handoffs
|
assert not result.handoffs
|
||||||
assert result.functions and len(result.functions) == 2
|
assert result.functions and len(result.functions) == 2
|
||||||
|
|
@ -146,7 +165,11 @@ async def test_handoffs_parsed_correctly():
|
||||||
referenceable_id=None,
|
referenceable_id=None,
|
||||||
)
|
)
|
||||||
result = RunImpl.process_model_response(
|
result = RunImpl.process_model_response(
|
||||||
agent=agent_3, response=response, output_schema=None, handoffs=[]
|
agent=agent_3,
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=await agent_3.get_all_tools(),
|
||||||
)
|
)
|
||||||
assert not result.handoffs, "Shouldn't have a handoff here"
|
assert not result.handoffs, "Shouldn't have a handoff here"
|
||||||
|
|
||||||
|
|
@ -160,6 +183,7 @@ async def test_handoffs_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=Runner._get_handoffs(agent_3),
|
||||||
|
all_tools=await agent_3.get_all_tools(),
|
||||||
)
|
)
|
||||||
assert len(result.handoffs) == 1, "Should have a handoff here"
|
assert len(result.handoffs) == 1, "Should have a handoff here"
|
||||||
handoff = result.handoffs[0]
|
handoff = result.handoffs[0]
|
||||||
|
|
@ -189,10 +213,12 @@ async def test_missing_handoff_fails():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=Runner._get_handoffs(agent_3),
|
||||||
|
all_tools=await agent_3.get_all_tools(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_handoffs_doesnt_error():
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_handoffs_doesnt_error():
|
||||||
agent_1 = Agent(name="test_1")
|
agent_1 = Agent(name="test_1")
|
||||||
agent_2 = Agent(name="test_2")
|
agent_2 = Agent(name="test_2")
|
||||||
agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2])
|
agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2])
|
||||||
|
|
@ -210,6 +236,7 @@ def test_multiple_handoffs_doesnt_error():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=Runner._get_handoffs(agent_3),
|
||||||
|
all_tools=await agent_3.get_all_tools(),
|
||||||
)
|
)
|
||||||
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
|
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
|
||||||
|
|
||||||
|
|
@ -218,7 +245,8 @@ class Foo(BaseModel):
|
||||||
bar: str
|
bar: str
|
||||||
|
|
||||||
|
|
||||||
def test_final_output_parsed_correctly():
|
@pytest.mark.asyncio
|
||||||
|
async def test_final_output_parsed_correctly():
|
||||||
agent = Agent(name="test", output_type=Foo)
|
agent = Agent(name="test", output_type=Foo)
|
||||||
response = ModelResponse(
|
response = ModelResponse(
|
||||||
output=[
|
output=[
|
||||||
|
|
@ -234,10 +262,12 @@ def test_final_output_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=Runner._get_output_schema(agent),
|
output_schema=Runner._get_output_schema(agent),
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
|
all_tools=await agent.get_all_tools(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_file_search_tool_call_parsed_correctly():
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_search_tool_call_parsed_correctly():
|
||||||
# Ensure that a ResponseFileSearchToolCall output is parsed into a ToolCallItem and that no tool
|
# Ensure that a ResponseFileSearchToolCall output is parsed into a ToolCallItem and that no tool
|
||||||
# runs are scheduled.
|
# runs are scheduled.
|
||||||
|
|
||||||
|
|
@ -254,7 +284,11 @@ def test_file_search_tool_call_parsed_correctly():
|
||||||
referenceable_id=None,
|
referenceable_id=None,
|
||||||
)
|
)
|
||||||
result = RunImpl.process_model_response(
|
result = RunImpl.process_model_response(
|
||||||
agent=agent, response=response, output_schema=None, handoffs=[]
|
agent=agent,
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=await agent.get_all_tools(),
|
||||||
)
|
)
|
||||||
# The final item should be a ToolCallItem for the file search call
|
# The final item should be a ToolCallItem for the file search call
|
||||||
assert any(
|
assert any(
|
||||||
|
|
@ -265,7 +299,8 @@ def test_file_search_tool_call_parsed_correctly():
|
||||||
assert not result.handoffs
|
assert not result.handoffs
|
||||||
|
|
||||||
|
|
||||||
def test_function_web_search_tool_call_parsed_correctly():
|
@pytest.mark.asyncio
|
||||||
|
async def test_function_web_search_tool_call_parsed_correctly():
|
||||||
agent = Agent(name="test")
|
agent = Agent(name="test")
|
||||||
web_search_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call")
|
web_search_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call")
|
||||||
response = ModelResponse(
|
response = ModelResponse(
|
||||||
|
|
@ -274,7 +309,11 @@ def test_function_web_search_tool_call_parsed_correctly():
|
||||||
referenceable_id=None,
|
referenceable_id=None,
|
||||||
)
|
)
|
||||||
result = RunImpl.process_model_response(
|
result = RunImpl.process_model_response(
|
||||||
agent=agent, response=response, output_schema=None, handoffs=[]
|
agent=agent,
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=await agent.get_all_tools(),
|
||||||
)
|
)
|
||||||
assert any(
|
assert any(
|
||||||
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
|
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
|
||||||
|
|
@ -284,7 +323,8 @@ def test_function_web_search_tool_call_parsed_correctly():
|
||||||
assert not result.handoffs
|
assert not result.handoffs
|
||||||
|
|
||||||
|
|
||||||
def test_reasoning_item_parsed_correctly():
|
@pytest.mark.asyncio
|
||||||
|
async def test_reasoning_item_parsed_correctly():
|
||||||
# Verify that a Reasoning output item is converted into a ReasoningItem.
|
# Verify that a Reasoning output item is converted into a ReasoningItem.
|
||||||
|
|
||||||
reasoning = ResponseReasoningItem(
|
reasoning = ResponseReasoningItem(
|
||||||
|
|
@ -296,7 +336,11 @@ def test_reasoning_item_parsed_correctly():
|
||||||
referenceable_id=None,
|
referenceable_id=None,
|
||||||
)
|
)
|
||||||
result = RunImpl.process_model_response(
|
result = RunImpl.process_model_response(
|
||||||
agent=Agent(name="test"), response=response, output_schema=None, handoffs=[]
|
agent=Agent(name="test"),
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=await Agent(name="test").get_all_tools(),
|
||||||
)
|
)
|
||||||
assert any(
|
assert any(
|
||||||
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
|
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
|
||||||
|
|
@ -342,7 +386,8 @@ class DummyComputer(Computer):
|
||||||
return None # pragma: no cover
|
return None # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
def test_computer_tool_call_without_computer_tool_raises_error():
|
@pytest.mark.asyncio
|
||||||
|
async def test_computer_tool_call_without_computer_tool_raises_error():
|
||||||
# If the agent has no ComputerTool in its tools, process_model_response should raise a
|
# If the agent has no ComputerTool in its tools, process_model_response should raise a
|
||||||
# ModelBehaviorError when encountering a ResponseComputerToolCall.
|
# ModelBehaviorError when encountering a ResponseComputerToolCall.
|
||||||
computer_call = ResponseComputerToolCall(
|
computer_call = ResponseComputerToolCall(
|
||||||
|
|
@ -360,11 +405,16 @@ def test_computer_tool_call_without_computer_tool_raises_error():
|
||||||
)
|
)
|
||||||
with pytest.raises(ModelBehaviorError):
|
with pytest.raises(ModelBehaviorError):
|
||||||
RunImpl.process_model_response(
|
RunImpl.process_model_response(
|
||||||
agent=Agent(name="test"), response=response, output_schema=None, handoffs=[]
|
agent=Agent(name="test"),
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=await Agent(name="test").get_all_tools(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_computer_tool_call_with_computer_tool_parsed_correctly():
|
@pytest.mark.asyncio
|
||||||
|
async def test_computer_tool_call_with_computer_tool_parsed_correctly():
|
||||||
# If the agent contains a ComputerTool, ensure that a ResponseComputerToolCall is parsed into a
|
# If the agent contains a ComputerTool, ensure that a ResponseComputerToolCall is parsed into a
|
||||||
# ToolCallItem and scheduled to run in computer_actions.
|
# ToolCallItem and scheduled to run in computer_actions.
|
||||||
dummy_computer = DummyComputer()
|
dummy_computer = DummyComputer()
|
||||||
|
|
@ -383,7 +433,11 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly():
|
||||||
referenceable_id=None,
|
referenceable_id=None,
|
||||||
)
|
)
|
||||||
result = RunImpl.process_model_response(
|
result = RunImpl.process_model_response(
|
||||||
agent=agent, response=response, output_schema=None, handoffs=[]
|
agent=agent,
|
||||||
|
response=response,
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
all_tools=await agent.get_all_tools(),
|
||||||
)
|
)
|
||||||
assert any(
|
assert any(
|
||||||
isinstance(item, ToolCallItem) and item.raw_item is computer_call
|
isinstance(item, ToolCallItem) and item.raw_item is computer_call
|
||||||
|
|
@ -392,7 +446,8 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly():
|
||||||
assert result.computer_actions and result.computer_actions[0].tool_call == computer_call
|
assert result.computer_actions and result.computer_actions[0].tool_call == computer_call
|
||||||
|
|
||||||
|
|
||||||
def test_tool_and_handoff_parsed_correctly():
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_and_handoff_parsed_correctly():
|
||||||
agent_1 = Agent(name="test_1")
|
agent_1 = Agent(name="test_1")
|
||||||
agent_2 = Agent(name="test_2")
|
agent_2 = Agent(name="test_2")
|
||||||
agent_3 = Agent(
|
agent_3 = Agent(
|
||||||
|
|
@ -413,6 +468,7 @@ def test_tool_and_handoff_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=Runner._get_handoffs(agent_3),
|
||||||
|
all_tools=await agent_3.get_all_tools(),
|
||||||
)
|
)
|
||||||
assert result.functions and len(result.functions) == 1
|
assert result.functions and len(result.functions) == 1
|
||||||
assert len(result.handoffs) == 1, "Should have a handoff here"
|
assert len(result.handoffs) == 1, "Should have a handoff here"
|
||||||
|
|
|
||||||
|
|
@ -269,7 +269,7 @@ async def test_timeout_waiting_for_created_event(monkeypatch):
|
||||||
async for _ in turns:
|
async for _ in turns:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert "Timeout waiting for transcription_session.created event" in str(exc_info.value)
|
assert "Timeout waiting for transcription_session.created event" in str(exc_info.value)
|
||||||
|
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
|
@ -302,13 +302,11 @@ async def test_session_error_event():
|
||||||
trace_include_sensitive_audio_data=False,
|
trace_include_sensitive_audio_data=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(STTWebsocketConnectionError) as exc_info:
|
with pytest.raises(STTWebsocketConnectionError):
|
||||||
turns = session.transcribe_turns()
|
turns = session.transcribe_turns()
|
||||||
async for _ in turns:
|
async for _ in turns:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert "Simulated server error!" in str(exc_info.value)
|
|
||||||
|
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -362,8 +360,8 @@ async def test_inactivity_timeout():
|
||||||
async for turn in session.transcribe_turns():
|
async for turn in session.transcribe_turns():
|
||||||
collected_turns.append(turn)
|
collected_turns.append(turn)
|
||||||
|
|
||||||
assert "Timeout waiting for transcription_session" in str(exc_info.value)
|
assert "Timeout waiting for transcription_session" in str(exc_info.value)
|
||||||
|
|
||||||
assert len(collected_turns) == 0, "No transcripts expected, but we got something?"
|
assert len(collected_turns) == 0, "No transcripts expected, but we got something?"
|
||||||
|
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue