From af80e3a97123a5a0ad0fba695fbc257163c23224 Mon Sep 17 00:00:00 2001 From: Nathan Brake <33383515+njbrake@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:12:46 -0400 Subject: [PATCH] Prevent MCP ClientSession hang (#580) Per https://modelcontextprotocol.io/specification/draft/basic/lifecycle#timeouts "Implementations SHOULD establish timeouts for all sent requests, to prevent hung connections and resource exhaustion. When the request has not received a success or error response within the timeout period, the sender SHOULD issue a cancellation notification for that request and stop waiting for a response. SDKs and other middleware SHOULD allow these timeouts to be configured on a per-request basis." I picked 5 seconds since that's the default for SSE --- src/agents/mcp/server.py | 26 ++++++++++++++++++++++---- tests/mcp/test_server_errors.py | 2 +- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 9a137bb..9916c92 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc import asyncio from contextlib import AbstractAsyncContextManager, AsyncExitStack +from datetime import timedelta from pathlib import Path from typing import Any, Literal @@ -54,7 +55,7 @@ class MCPServer(abc.ABC): class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool): + def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be @@ -63,12 +64,16 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC): by calling `invalidate_tools_cache()`. You should set this to `True` if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.cache_tools_list = cache_tools_list + self.client_session_timeout_seconds = client_session_timeout_seconds + # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True self._tools_list: list[MCPTool] | None = None @@ -101,7 +106,15 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC): try: transport = await self.exit_stack.enter_async_context(self.create_streams()) read, write = transport - session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + session = await self.exit_stack.enter_async_context( + ClientSession( + read, + write, + timedelta(seconds=self.client_session_timeout_seconds) + if self.client_session_timeout_seconds + else None, + ) + ) await session.initialize() self.session = session except Exception as e: @@ -183,6 +196,7 @@ class MCPServerStdio(_MCPServerWithClientSession): params: MCPServerStdioParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the stdio transport. @@ -199,8 +213,9 @@ class MCPServerStdio(_MCPServerWithClientSession): improve latency (by avoiding a round-trip to the server every time). name: A readable name for the server. If not provided, we'll create one from the command. + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = StdioServerParameters( command=params["command"], @@ -257,6 +272,7 @@ class MCPServerSse(_MCPServerWithClientSession): params: MCPServerSseParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -274,8 +290,10 @@ class MCPServerSse(_MCPServerWithClientSession): name: A readable name for the server. If not provided, we'll create one from the URL. + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = params self._name = name or f"sse: {self.params['url']}" diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index bdca7ce..fbd8db1 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -6,7 +6,7 @@ from agents.mcp.server import _MCPServerWithClientSession class CrashingClientSessionServer(_MCPServerWithClientSession): def __init__(self): - super().__init__(cache_tools_list=False) + super().__init__(cache_tools_list=False, client_session_timeout_seconds=5) self.cleanup_called = False def create_streams(self):