Prevent MCP ClientSession hang (#580)

Per
https://modelcontextprotocol.io/specification/draft/basic/lifecycle#timeouts

"Implementations SHOULD establish timeouts for all sent requests, to
prevent hung connections and resource exhaustion. When the request has
not received a success or error response within the timeout period, the
sender SHOULD issue a cancellation notification for that request and
stop waiting for a response.

SDKs and other middleware SHOULD allow these timeouts to be configured
on a per-request basis."

I picked 5 seconds since that's the default for SSE
This commit is contained in:
Nathan Brake 2025-04-24 12:12:46 -04:00 committed by GitHub
parent 3755ea8658
commit af80e3a971
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 23 additions and 5 deletions

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import abc
import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from datetime import timedelta
from pathlib import Path
from typing import Any, Literal
@ -54,7 +55,7 @@ class MCPServer(abc.ABC):
class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
def __init__(self, cache_tools_list: bool):
def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
@ -63,12 +64,16 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
"""
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list
self.client_session_timeout_seconds = client_session_timeout_seconds
# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
self._tools_list: list[MCPTool] | None = None
@ -101,7 +106,15 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
try:
transport = await self.exit_stack.enter_async_context(self.create_streams())
read, write = transport
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
session = await self.exit_stack.enter_async_context(
ClientSession(
read,
write,
timedelta(seconds=self.client_session_timeout_seconds)
if self.client_session_timeout_seconds
else None,
)
)
await session.initialize()
self.session = session
except Exception as e:
@ -183,6 +196,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
params: MCPServerStdioParams,
cache_tools_list: bool = False,
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
):
"""Create a new MCP server based on the stdio transport.
@ -199,8 +213,9 @@ class MCPServerStdio(_MCPServerWithClientSession):
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
command.
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
"""
super().__init__(cache_tools_list)
super().__init__(cache_tools_list, client_session_timeout_seconds)
self.params = StdioServerParameters(
command=params["command"],
@ -257,6 +272,7 @@ class MCPServerSse(_MCPServerWithClientSession):
params: MCPServerSseParams,
cache_tools_list: bool = False,
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
):
"""Create a new MCP server based on the HTTP with SSE transport.
@ -274,8 +290,10 @@ class MCPServerSse(_MCPServerWithClientSession):
name: A readable name for the server. If not provided, we'll create one from the
URL.
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
"""
super().__init__(cache_tools_list)
super().__init__(cache_tools_list, client_session_timeout_seconds)
self.params = params
self._name = name or f"sse: {self.params['url']}"

View file

@ -6,7 +6,7 @@ from agents.mcp.server import _MCPServerWithClientSession
class CrashingClientSessionServer(_MCPServerWithClientSession):
def __init__(self):
super().__init__(cache_tools_list=False)
super().__init__(cache_tools_list=False, client_session_timeout_seconds=5)
self.cleanup_called = False
def create_streams(self):