Prevent MCP ClientSession hang (#580)
Per https://modelcontextprotocol.io/specification/draft/basic/lifecycle#timeouts "Implementations SHOULD establish timeouts for all sent requests, to prevent hung connections and resource exhaustion. When the request has not received a success or error response within the timeout period, the sender SHOULD issue a cancellation notification for that request and stop waiting for a response. SDKs and other middleware SHOULD allow these timeouts to be configured on a per-request basis." I picked 5 seconds since that's the default for SSE
This commit is contained in:
parent
3755ea8658
commit
af80e3a971
2 changed files with 23 additions and 5 deletions
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
import abc
|
||||
import asyncio
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
|
|
@ -54,7 +55,7 @@ class MCPServer(abc.ABC):
|
|||
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
||||
|
||||
def __init__(self, cache_tools_list: bool):
|
||||
def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None):
|
||||
"""
|
||||
Args:
|
||||
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
||||
|
|
@ -63,12 +64,16 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|||
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
|
||||
server will not change its tools list, because it can drastically improve latency
|
||||
(by avoiding a round-trip to the server every time).
|
||||
|
||||
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
||||
"""
|
||||
self.session: ClientSession | None = None
|
||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
|
||||
self.cache_tools_list = cache_tools_list
|
||||
|
||||
self.client_session_timeout_seconds = client_session_timeout_seconds
|
||||
|
||||
# The cache is always dirty at startup, so that we fetch tools at least once
|
||||
self._cache_dirty = True
|
||||
self._tools_list: list[MCPTool] | None = None
|
||||
|
|
@ -101,7 +106,15 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|||
try:
|
||||
transport = await self.exit_stack.enter_async_context(self.create_streams())
|
||||
read, write = transport
|
||||
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||||
session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(
|
||||
read,
|
||||
write,
|
||||
timedelta(seconds=self.client_session_timeout_seconds)
|
||||
if self.client_session_timeout_seconds
|
||||
else None,
|
||||
)
|
||||
)
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
except Exception as e:
|
||||
|
|
@ -183,6 +196,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
|||
params: MCPServerStdioParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: str | None = None,
|
||||
client_session_timeout_seconds: float | None = 5,
|
||||
):
|
||||
"""Create a new MCP server based on the stdio transport.
|
||||
|
||||
|
|
@ -199,8 +213,9 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
|||
improve latency (by avoiding a round-trip to the server every time).
|
||||
name: A readable name for the server. If not provided, we'll create one from the
|
||||
command.
|
||||
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
||||
"""
|
||||
super().__init__(cache_tools_list)
|
||||
super().__init__(cache_tools_list, client_session_timeout_seconds)
|
||||
|
||||
self.params = StdioServerParameters(
|
||||
command=params["command"],
|
||||
|
|
@ -257,6 +272,7 @@ class MCPServerSse(_MCPServerWithClientSession):
|
|||
params: MCPServerSseParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: str | None = None,
|
||||
client_session_timeout_seconds: float | None = 5,
|
||||
):
|
||||
"""Create a new MCP server based on the HTTP with SSE transport.
|
||||
|
||||
|
|
@ -274,8 +290,10 @@ class MCPServerSse(_MCPServerWithClientSession):
|
|||
|
||||
name: A readable name for the server. If not provided, we'll create one from the
|
||||
URL.
|
||||
|
||||
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
||||
"""
|
||||
super().__init__(cache_tools_list)
|
||||
super().__init__(cache_tools_list, client_session_timeout_seconds)
|
||||
|
||||
self.params = params
|
||||
self._name = name or f"sse: {self.params['url']}"
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from agents.mcp.server import _MCPServerWithClientSession
|
|||
|
||||
class CrashingClientSessionServer(_MCPServerWithClientSession):
|
||||
def __init__(self):
|
||||
super().__init__(cache_tools_list=False)
|
||||
super().__init__(cache_tools_list=False, client_session_timeout_seconds=5)
|
||||
self.cleanup_called = False
|
||||
|
||||
def create_streams(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue