Prevent MCP ClientSession hang (#580)

Per the timeouts section of the MCP lifecycle specification
(https://modelcontextprotocol.io/specification/draft/basic/lifecycle#timeouts):

"Implementations SHOULD establish timeouts for all sent requests, to
prevent hung connections and resource exhaustion. When the request has
not received a success or error response within the timeout period, the
sender SHOULD issue a cancellation notification for that request and
stop waiting for a response.

SDKs and other middleware SHOULD allow these timeouts to be configured
on a per-request basis."

I picked 5 seconds as the default value of `client_session_timeout_seconds`, since that's already the default for SSE.
This commit is contained in:
Nathan Brake 2025-04-24 12:12:46 -04:00 committed by GitHub
parent 3755ea8658
commit af80e3a971
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 23 additions and 5 deletions

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import abc import abc
import asyncio import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack from contextlib import AbstractAsyncContextManager, AsyncExitStack
from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import Any, Literal from typing import Any, Literal
@ -54,7 +55,7 @@ class MCPServer(abc.ABC):
class _MCPServerWithClientSession(MCPServer, abc.ABC): class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server.""" """Base class for MCP servers that use a `ClientSession` to communicate with the server."""
def __init__(self, cache_tools_list: bool): def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None):
""" """
Args: Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
@ -63,12 +64,16 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time). (by avoiding a round-trip to the server every time).
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
""" """
self.session: ClientSession | None = None self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack() self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock() self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list self.cache_tools_list = cache_tools_list
self.client_session_timeout_seconds = client_session_timeout_seconds
# The cache is always dirty at startup, so that we fetch tools at least once # The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True self._cache_dirty = True
self._tools_list: list[MCPTool] | None = None self._tools_list: list[MCPTool] | None = None
@ -101,7 +106,15 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
try: try:
transport = await self.exit_stack.enter_async_context(self.create_streams()) transport = await self.exit_stack.enter_async_context(self.create_streams())
read, write = transport read, write = transport
session = await self.exit_stack.enter_async_context(ClientSession(read, write)) session = await self.exit_stack.enter_async_context(
ClientSession(
read,
write,
timedelta(seconds=self.client_session_timeout_seconds)
if self.client_session_timeout_seconds
else None,
)
)
await session.initialize() await session.initialize()
self.session = session self.session = session
except Exception as e: except Exception as e:
@ -183,6 +196,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
params: MCPServerStdioParams, params: MCPServerStdioParams,
cache_tools_list: bool = False, cache_tools_list: bool = False,
name: str | None = None, name: str | None = None,
client_session_timeout_seconds: float | None = 5,
): ):
"""Create a new MCP server based on the stdio transport. """Create a new MCP server based on the stdio transport.
@ -199,8 +213,9 @@ class MCPServerStdio(_MCPServerWithClientSession):
improve latency (by avoiding a round-trip to the server every time). improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the name: A readable name for the server. If not provided, we'll create one from the
command. command.
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
""" """
super().__init__(cache_tools_list) super().__init__(cache_tools_list, client_session_timeout_seconds)
self.params = StdioServerParameters( self.params = StdioServerParameters(
command=params["command"], command=params["command"],
@ -257,6 +272,7 @@ class MCPServerSse(_MCPServerWithClientSession):
params: MCPServerSseParams, params: MCPServerSseParams,
cache_tools_list: bool = False, cache_tools_list: bool = False,
name: str | None = None, name: str | None = None,
client_session_timeout_seconds: float | None = 5,
): ):
"""Create a new MCP server based on the HTTP with SSE transport. """Create a new MCP server based on the HTTP with SSE transport.
@ -274,8 +290,10 @@ class MCPServerSse(_MCPServerWithClientSession):
name: A readable name for the server. If not provided, we'll create one from the name: A readable name for the server. If not provided, we'll create one from the
URL. URL.
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
""" """
super().__init__(cache_tools_list) super().__init__(cache_tools_list, client_session_timeout_seconds)
self.params = params self.params = params
self._name = name or f"sse: {self.params['url']}" self._name = name or f"sse: {self.params['url']}"

View file

@ -6,7 +6,7 @@ from agents.mcp.server import _MCPServerWithClientSession
class CrashingClientSessionServer(_MCPServerWithClientSession): class CrashingClientSessionServer(_MCPServerWithClientSession):
def __init__(self): def __init__(self):
super().__init__(cache_tools_list=False) super().__init__(cache_tools_list=False, client_session_timeout_seconds=5)
self.cleanup_called = False self.cleanup_called = False
def create_streams(self): def create_streams(self):