### Summary:

This enables users to **use** MCP inside the SDK.

1. You add a list of MCP servers to `Agent`, via `mcp_server=[...]`.
2. When an agent runs, we look up its MCP tools and add them to the list of tools.
3. When a tool call occurs, we call the relevant MCP server.

Notes:

1. There's some refactoring to make sure we send the full list of tools to the Runner/Model etc.
2. Right now, you could have a locally defined tool that conflicts with an MCP-defined tool. I didn't add errors for that; I will do so in a follow-up.

### Test Plan:

See unit tests. An end-to-end example will also be added in the next PR.
57 lines
1.9 KiB
Python
57 lines
1.9 KiB
Python
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from mcp.types import ListToolsResult, Tool as MCPTool
|
|
|
|
from agents.mcp import MCPServerStdio
|
|
|
|
from .helpers import DummyStreamsContextManager, tee
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_server_caching_works(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Test that if we turn caching on, the list of tools is cached and not fetched from the server
    on each call to `list_tools()`.

    The mocks are injected bottom-up: the innermost ``@patch`` (``list_tools``) is the first
    argument, followed by ``initialize`` and then ``stdio_client``.
    """
    server = MCPServerStdio(
        params={
            "command": tee,
        },
        cache_tools_list=True,
    )

    # The expected tool list. Keep this name reserved for the expectation; results returned by
    # the server are bound to `result_tools` so the equality checks below are meaningful.
    tools = [
        MCPTool(name="tool1", inputSchema={}),
        MCPTool(name="tool2", inputSchema={}),
    ]

    mock_list_tools.return_value = ListToolsResult(tools=tools)

    async with server:
        # First call: nothing is cached yet, so the underlying session is queried.
        result_tools = await server.list_tools()
        assert result_tools == tools

        assert mock_list_tools.call_count == 1, "list_tools() should have been called once"

        # Call list_tools() again, should return the cached value
        result_tools = await server.list_tools()
        assert result_tools == tools

        assert mock_list_tools.call_count == 1, "list_tools() should not have been called again"

        # Invalidate the cache and call list_tools() again
        server.invalidate_tools_cache()
        result_tools = await server.list_tools()
        assert result_tools == tools

        assert mock_list_tools.call_count == 2, "list_tools() should be called again"

        # Without invalidating the cache, calling list_tools() again should return the cached value
        result_tools = await server.list_tools()
        assert result_tools == tools

        # Verify the cache was actually used: no additional fetch after the refresh above.
        assert mock_list_tools.call_count == 2, "list_tools() should not have been called again"
|