Prevent MCP ClientSession hang (#580)
Per https://modelcontextprotocol.io/specification/draft/basic/lifecycle#timeouts "Implementations SHOULD establish timeouts for all sent requests, to prevent hung connections and resource exhaustion. When the request has not received a success or error response within the timeout period, the sender SHOULD issue a cancellation notification for that request and stop waiting for a response. SDKs and other middleware SHOULD allow these timeouts to be configured on a per-request basis." I picked 5 seconds since that's the default for SSE
This commit is contained in:
parent
3755ea8658
commit
af80e3a971
2 changed files with 23 additions and 5 deletions
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||||
|
from datetime import timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
|
@ -54,7 +55,7 @@ class MCPServer(abc.ABC):
|
||||||
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||||
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
||||||
|
|
||||||
def __init__(self, cache_tools_list: bool):
|
def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
||||||
|
|
@ -63,12 +64,16 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||||
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
|
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
|
||||||
server will not change its tools list, because it can drastically improve latency
|
server will not change its tools list, because it can drastically improve latency
|
||||||
(by avoiding a round-trip to the server every time).
|
(by avoiding a round-trip to the server every time).
|
||||||
|
|
||||||
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
||||||
"""
|
"""
|
||||||
self.session: ClientSession | None = None
|
self.session: ClientSession | None = None
|
||||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||||
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
|
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
|
||||||
self.cache_tools_list = cache_tools_list
|
self.cache_tools_list = cache_tools_list
|
||||||
|
|
||||||
|
self.client_session_timeout_seconds = client_session_timeout_seconds
|
||||||
|
|
||||||
# The cache is always dirty at startup, so that we fetch tools at least once
|
# The cache is always dirty at startup, so that we fetch tools at least once
|
||||||
self._cache_dirty = True
|
self._cache_dirty = True
|
||||||
self._tools_list: list[MCPTool] | None = None
|
self._tools_list: list[MCPTool] | None = None
|
||||||
|
|
@ -101,7 +106,15 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||||
try:
|
try:
|
||||||
transport = await self.exit_stack.enter_async_context(self.create_streams())
|
transport = await self.exit_stack.enter_async_context(self.create_streams())
|
||||||
read, write = transport
|
read, write = transport
|
||||||
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
session = await self.exit_stack.enter_async_context(
|
||||||
|
ClientSession(
|
||||||
|
read,
|
||||||
|
write,
|
||||||
|
timedelta(seconds=self.client_session_timeout_seconds)
|
||||||
|
if self.client_session_timeout_seconds
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
self.session = session
|
self.session = session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -183,6 +196,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
||||||
params: MCPServerStdioParams,
|
params: MCPServerStdioParams,
|
||||||
cache_tools_list: bool = False,
|
cache_tools_list: bool = False,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
|
client_session_timeout_seconds: float | None = 5,
|
||||||
):
|
):
|
||||||
"""Create a new MCP server based on the stdio transport.
|
"""Create a new MCP server based on the stdio transport.
|
||||||
|
|
||||||
|
|
@ -199,8 +213,9 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
||||||
improve latency (by avoiding a round-trip to the server every time).
|
improve latency (by avoiding a round-trip to the server every time).
|
||||||
name: A readable name for the server. If not provided, we'll create one from the
|
name: A readable name for the server. If not provided, we'll create one from the
|
||||||
command.
|
command.
|
||||||
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
||||||
"""
|
"""
|
||||||
super().__init__(cache_tools_list)
|
super().__init__(cache_tools_list, client_session_timeout_seconds)
|
||||||
|
|
||||||
self.params = StdioServerParameters(
|
self.params = StdioServerParameters(
|
||||||
command=params["command"],
|
command=params["command"],
|
||||||
|
|
@ -257,6 +272,7 @@ class MCPServerSse(_MCPServerWithClientSession):
|
||||||
params: MCPServerSseParams,
|
params: MCPServerSseParams,
|
||||||
cache_tools_list: bool = False,
|
cache_tools_list: bool = False,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
|
client_session_timeout_seconds: float | None = 5,
|
||||||
):
|
):
|
||||||
"""Create a new MCP server based on the HTTP with SSE transport.
|
"""Create a new MCP server based on the HTTP with SSE transport.
|
||||||
|
|
||||||
|
|
@ -274,8 +290,10 @@ class MCPServerSse(_MCPServerWithClientSession):
|
||||||
|
|
||||||
name: A readable name for the server. If not provided, we'll create one from the
|
name: A readable name for the server. If not provided, we'll create one from the
|
||||||
URL.
|
URL.
|
||||||
|
|
||||||
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
||||||
"""
|
"""
|
||||||
super().__init__(cache_tools_list)
|
super().__init__(cache_tools_list, client_session_timeout_seconds)
|
||||||
|
|
||||||
self.params = params
|
self.params = params
|
||||||
self._name = name or f"sse: {self.params['url']}"
|
self._name = name or f"sse: {self.params['url']}"
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from agents.mcp.server import _MCPServerWithClientSession
|
||||||
|
|
||||||
class CrashingClientSessionServer(_MCPServerWithClientSession):
|
class CrashingClientSessionServer(_MCPServerWithClientSession):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(cache_tools_list=False)
|
super().__init__(cache_tools_list=False, client_session_timeout_seconds=5)
|
||||||
self.cleanup_called = False
|
self.cleanup_called = False
|
||||||
|
|
||||||
def create_streams(self):
|
def create_streams(self):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue