# PR Description
Consider this PR the result of a full pass through of this repository.
## Add helper for adding tools to an `MCPApp`
You can now add all of the tools in a module to an `MCPApp` via
`app.add_tools_from_module(...)`
## Edit what `arcade new` generates
First, I updated the backend to use hatchling.
Second, the structure generated before this PR was simple, but did not
create a proper Python module.
This hindered developers in the following ways:
1. Difficult to add the tools in your server to an evaluation suite
2. Difficult to add more than one tool to an MCPApp at a time
3. All other niceties that come with being able to import modules
```
# Before
server/
├── .env.example
├── server.py
└── pyproject.toml
```
This PR updates the structure generated such that a valid Python module
is generated:
```
# After
server/
├── pyproject.toml
└── src/
    └── server/
        ├── __init__.py
        ├── .env.example
        └── server.py
```
## Fix Tool Chaining
`self._ctx.server.executor.run(...)` was being called, but `MCPServer`
does not have an instance of `ToolExecutor` (and it's not intended to be
an instance anyways). I updated `Tool.call_raw` to pass the programmatic
tool call through the `MCPServer._handle_call_tool`. This means that the
programmatic tool calls now go through the same steps that a typical
tool call (initiated by the MCP client) would.
This means that **toolA**, which specifies **requirementsA**, is
permitted to call **toolB**, which specifies **requirementsB**, without
needing to explicitly declare or satisfy **requirementsB**. I believe
this is acceptable because the secrets and/or auth token associated with
**toolB's** `Context` are not exposed to **toolA**, and the secrets
and/or auth token associated with **toolA's** `Context` are not exposed
to **toolB**.
## Fix User Elicitation
1. The read & write streams were created with a maximum queue size of 0.
I increased this to 100.
2. I updated `ServerSession`'s run loop to read messages from the
stream and process them concurrently. This enables server-initiated
requests (like user elicitation and progress reporting) to be handled
while tools are being executed. Otherwise, the server-initiated request
would wait for the tool to finish executing while the tool execution
waited for the server-initiated request to finish (a deadlock).
## Fix Progress Reporting
Progress tokens sent by the client were not being stored. Therefore
there was no way to notify a client with progress updates. I am now
storing the `progressToken`, along with other `_meta` sent from the
client, in the `ServerSession`'s `_request_meta`. I am setting
`_request_meta` whenever the `MCPServer` is handling an incoming message
from a client.
## Fix handling of server names with spaces
Before:
Server name: "The simple server name"
Tool name: whisper_secret
Name seen by client: "The_simple_server_name_WhisperSecret"
After:
Server name: "The simple server name"
Tool name: whisper_secret
Name seen by client: "TheSimpleServerName_WhisperSecret"
## Add Integration Tests
The stdio integration test is much more comprehensive than the http
integration test. These tests will let me sleep a bit more at night
## Add Example MCP Servers
Example servers for sampling, user-elicitation, progress reporting,
logging, tool chaining, combining prebuilt tools with custom tools, tool
secrets, tool auth, evaluations, and more!
## Add Docker template
Added a Docker template for running an MCP server in Docker (and removed
the old docker stuff)
652 lines
21 KiB
Python
652 lines
21 KiB
Python
"""
MCP Server Session

Manages per-session state and provides session-level operations.
"""

from __future__ import annotations

import asyncio
import json
import logging
import uuid
from enum import Enum
from types import SimpleNamespace
from typing import Any

import anyio

from arcade_mcp_server.context import Context
from arcade_mcp_server.exceptions import RequestError, SessionError
from arcade_mcp_server.types import (
    CancelledNotification,
    CancelledParams,
    ClientCapabilities,
    CompleteResult,
    CreateMessageResult,
    ElicitResult,
    InitializeParams,
    JSONRPCError,
    JSONRPCMessage,
    JSONRPCRequest,
    ListRootsResult,
    LoggingLevel,
    LoggingMessageNotification,
    LoggingMessageParams,
    ProgressNotification,
    ProgressNotificationParams,
    PromptListChangedNotification,
    ResourceListChangedNotification,
    ToolListChangedNotification,
)

# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)


class InitializationState(Enum):
    """Session initialization states."""

    # No client parameters have been recorded yet.
    NOT_INITIALIZED = 1
    # Client parameters were stored via ``ServerSession.set_client_params``.
    INITIALIZING = 2
    # ``ServerSession.mark_initialized`` has been called.
    INITIALIZED = 3


class RequestManager:
    """
    Manages server-initiated requests to the client.

    Handles request/response correlation for bidirectional communication:
    each outgoing request is registered by id in a table of pending futures,
    and the matching response (or a session shutdown) completes that future.
    """

    def __init__(self, write_stream: Any):
        """Initialize request manager.

        Args:
            write_stream: Object exposing an awaitable ``send(str)``; used for
                outgoing requests and cancellation notifications.
        """
        self._write_stream = write_stream
        # request id -> future completed by handle_response() / cancel_all()
        self._pending_requests: dict[str, asyncio.Future[Any]] = {}
        # Guards every read/write of _pending_requests
        self._lock = asyncio.Lock()
        # Set once by cancel_all(); never cleared
        self._closed = asyncio.Event()

    def is_closed(self) -> bool:
        """Return True if the manager has been closed/cancelled."""
        return self._closed.is_set()

    async def send_request(
        self,
        method: str,
        params: dict[str, Any] | None = None,
        timeout: float = 400.0,
    ) -> Any:
        """
        Send a request to the client and wait for response.

        Args:
            method: Request method
            params: Request parameters (an empty dict is sent when omitted)
            timeout: Request timeout in seconds

        Returns:
            The ``result`` member of the client's response

        Raises:
            SessionError: If the manager is closed, or is closed while waiting
            RequestError: If the client's response contains an ``error`` member
            asyncio.TimeoutError: If no response arrives within ``timeout``
        """
        if self._closed.is_set():
            raise SessionError("Session closed")
        request_id = str(uuid.uuid4())

        # Create request
        request = JSONRPCRequest(
            id=request_id,
            method=method,
            params=params or {},
        )

        # Create future for response (bound to the loop running this coroutine)
        future: asyncio.Future[Any] = asyncio.get_running_loop().create_future()
        async with self._lock:
            # Re-check under the lock: cancel_all() may have run in between
            if self._closed.is_set():
                raise SessionError("Session closed")
            self._pending_requests[request_id] = future

        try:
            # Send request
            message = request.model_dump_json(exclude_none=True) + "\n"
            logger.debug(f"Sending server->client request method={method} id={request_id}")
            await self._write_stream.send(message)

            # Wait for response
            result = await asyncio.wait_for(future, timeout=timeout)
            logger.debug(f"Received response for id={request_id} method={method}")
            return result

        finally:
            # Clean up on every exit path (success, timeout, error, cancellation)
            async with self._lock:
                self._pending_requests.pop(request_id, None)

    async def handle_response(self, message: dict[str, Any]) -> None:
        """
        Handle a response message from the client.

        Completes the pending future registered under the message's ``id``
        with either its ``result`` or a ``RequestError`` built from ``error``.
        Responses with no usable id, or no matching pending request, are
        logged and ignored.

        Args:
            message: Response message
        """
        if self._closed.is_set():
            # Drop any late responses after closure
            return
        request_id = message.get("id")
        if not request_id:
            logger.debug("Received response without id; ignoring")
            return

        async with self._lock:
            # Ids are stored as strings, so normalise before the lookup
            future = self._pending_requests.get(str(request_id))
            if future and not future.done():
                if "error" in message:
                    logger.debug(f"Response id={request_id} contains error; propagating")
                    future.set_exception(RequestError(f"Request failed: {message['error']}"))
                else:
                    logger.debug(f"Correlated response id={request_id} -> completing future")
                    future.set_result(message.get("result"))
            else:
                logger.debug(
                    f"No pending future for response id={request_id}; possibly late or mismatched"
                )

    async def cancel_all(self, reason: str | None = None) -> None:
        """Cancel all pending requests and notify the client.

        Sends a CancelledNotification for each in-flight request and
        completes their futures with SessionError so awaiters unblock.

        Args:
            reason: Optional text included in each cancellation notification
        """
        # Mark closed first to prevent new requests
        if not self._closed.is_set():
            self._closed.set()

        # Snapshot current pending ids and futures
        async with self._lock:
            pending_items = list(self._pending_requests.items())
            # Clear the map eagerly to prevent races with late responses
            self._pending_requests.clear()

        if not pending_items:
            return

        try:
            # Best-effort notify client of cancellations
            for request_id, _future in pending_items:
                notification = CancelledNotification(
                    params=CancelledParams(requestId=request_id, reason=reason)
                )
                message = notification.model_dump_json(exclude_none=True) + "\n"
                await self._write_stream.send(message)
        except Exception:
            # Swallow transport errors during shutdown; proceed to cancel futures
            logger.debug(
                "Failed to send cancellation notifications during shutdown", exc_info=True
            )
        finally:
            # Release waiters even if notifying was interrupted: the futures are
            # already out of the pending map, so nothing else can complete them.
            for _request_id, future in pending_items:
                if not future.done():
                    future.set_exception(SessionError("Session closed"))


class NotificationManager:
    """Broadcasts server-initiated listChanged notifications to sessions."""

    def __init__(self, server: Any):
        """Store the server whose sessions receive the notifications.

        Args:
            server: Object providing ``_sessions`` (id -> session mapping) and
                ``_sessions_lock``; ``_started`` is read if present.
        """
        self._server = server

    async def _broadcast(
        self, notification: JSONRPCMessage, session_ids: list[str] | None = None
    ) -> None:
        """Send ``notification`` to all sessions, or only to ``session_ids``.

        Unknown session ids are skipped, and a failure to notify one session
        is logged without stopping delivery to the rest.

        Args:
            notification: Message passed to each session's ``send_notification``
            session_ids: Restrict delivery to these ids; None targets every session
        """
        # Do not broadcast before server is started
        if not getattr(self._server, "_started", False):
            return

        # Snapshot the targets under the lock, then send without holding it
        async with self._server._sessions_lock:
            if session_ids is None:
                sessions = list(self._server._sessions.values())
            else:
                sessions = [
                    self._server._sessions.get(sid)
                    for sid in session_ids
                    if sid in self._server._sessions
                ]

        for s in sessions:
            if s is None:
                continue
            try:
                await s.send_notification(notification)
            except Exception:
                logger.debug("Failed to notify a session", exc_info=True)

    async def notify_tool_list_changed(self, session_ids: list[str] | None = None) -> None:
        """Broadcast a ToolListChangedNotification."""
        await self._broadcast(ToolListChangedNotification(), session_ids)

    async def notify_resource_list_changed(self, session_ids: list[str] | None = None) -> None:
        """Broadcast a ResourceListChangedNotification."""
        await self._broadcast(ResourceListChangedNotification(), session_ids)

    async def notify_prompt_list_changed(self, session_ids: list[str] | None = None) -> None:
        """Broadcast a PromptListChangedNotification."""
        await self._broadcast(PromptListChangedNotification(), session_ids)


class ServerSession:
    """
    MCP server session handling a single client connection.

    Manages:
    - Session state and lifecycle
    - Client capabilities
    - Request/response handling
    - Notification sending
    """

    def __init__(
        self,
        server: Any,
        session_id: str | None = None,
        read_stream: Any | None = None,
        write_stream: Any | None = None,
        init_options: Any | None = None,
        stateless: bool = False,
    ):
        """
        Initialize server session.

        Args:
            server: Parent server instance
            session_id: Session identifier (generated if not provided)
            read_stream: Stream for reading messages
            write_stream: Stream for writing messages
            init_options: Initialization options
            stateless: Whether session is stateless
        """
        self.server = server
        self.session_id = session_id or str(uuid.uuid4())
        self.read_stream = read_stream
        self.write_stream = write_stream
        self.init_options = init_options or {}
        self.stateless = stateless

        # Session state
        self.initialization_state = InitializationState.NOT_INITIALIZED
        self.client_params: InitializeParams | None = None
        self._session_data: dict[str, Any] = {}
        self._request_meta: Any = None

        # Request management (server-initiated requests need a write stream)
        self._request_manager = RequestManager(write_stream) if write_stream else None

        # Context for current request
        self._current_context: Context | None = None

    def set_client_params(self, params: InitializeParams) -> None:
        """Set client initialization parameters and enter INITIALIZING."""
        self.client_params = params
        self.initialization_state = InitializationState.INITIALIZING

    def mark_initialized(self) -> None:
        """Mark session as initialized."""
        self.initialization_state = InitializationState.INITIALIZED

    def check_client_capability(self, capability: ClientCapabilities) -> bool:
        """
        Check if client has a specific capability.

        Every one of ``tools``, ``resources``, ``prompts`` and ``logging``
        that is truthy on ``capability`` must also be truthy on the
        capabilities the client declared.

        Args:
            capability: Capability to check

        Returns:
            True if client has capability; False when no client parameters or
            capabilities have been recorded
        """
        if not self.client_params or not self.client_params.capabilities:
            return False

        client_caps = self.client_params.capabilities

        # getattr with a default tolerates attributes that might only exist
        # in extra fields (or not at all) on either side.
        for name in ("tools", "resources", "prompts", "logging"):
            if getattr(capability, name, None) and not getattr(client_caps, name, None):
                return False
        return True

    async def run(self) -> None:
        """
        Run the session message loop.

        Reads messages from the stream and processes them concurrently
        to allow server-initiated requests to be handled while tools execute.

        Raises:
            SessionError: If there is no read stream, or reading fails
        """
        if not self.read_stream:
            raise SessionError("No read stream available")

        async with anyio.create_task_group() as tg:
            try:
                async for message in self.read_stream:
                    if message:
                        # Process messages concurrently so the loop can continue reading
                        tg.start_soon(self._process_message, message)
            except asyncio.CancelledError:
                # NOTE(review): cancellation is swallowed here so shutdown is
                # quiet; anyio generally expects it to be re-raised - confirm.
                pass
            except Exception as e:
                await self.server.logger.exception("Session error")
                raise SessionError(f"Session error: {e}") from e
            finally:
                # Cleanup
                if self._request_manager:
                    # Cancel any pending requests
                    await self._cleanup_pending_requests()

    async def _process_message(self, message: str) -> None:
        """Process a single message.

        Responses to server-initiated requests (an ``id`` but no ``method``)
        are routed to the request manager; everything else is handed to
        ``server.handle_message`` and any returned response is written back.
        Failures are reported to the client as JSON-RPC errors.

        Args:
            message: Raw JSON text received from the client
        """
        # Id of the incoming request, echoed in error responses when known
        request_id: Any = None
        try:
            # Parse message
            data = json.loads(message)

            # Check if it's a response to our request
            if "id" in data and "method" not in data:
                if self._request_manager:
                    logger.debug(
                        f"Session received response message id={data.get('id')} -> routing to RequestManager"
                    )
                    await self._request_manager.handle_response(data)
                return

            if isinstance(data, dict):
                request_id = data.get("id")

            # Otherwise, process as incoming request
            response = await self.server.handle_message(data, self)

            # Send response if any
            if response and self.write_stream:
                if hasattr(response, "model_dump_json"):
                    response_data = response.model_dump_json(exclude_none=True, by_alias=True)
                else:
                    response_data = json.dumps(response)

                # Messages on the stream are newline-delimited
                if not response_data.endswith("\n"):
                    response_data += "\n"

                await self.write_stream.send(response_data)

        except json.JSONDecodeError:
            # Nothing could be parsed, so no request id is available
            await self._send_error_response(
                None,
                -32700,
                "Parse error",
            )
        except Exception as e:
            await self._send_error_response(
                request_id,
                -32603,
                f"Internal error: {e!s}",
            )

    async def _send_error_response(
        self,
        request_id: Any,
        code: int,
        message: str,
    ) -> None:
        """Send an error response.

        Does nothing when the session has no write stream.

        Args:
            request_id: Id of the failed request, or None when unknown
            code: JSON-RPC error code
            message: Human-readable error message
        """
        if not self.write_stream:
            return

        # ``is not None`` keeps a legitimate id of 0 from being reported as
        # unknown. NOTE(review): ids are stringified and an unknown id is sent
        # as the string "null" - confirm against JSONRPCError's ``id`` type.
        error_response = JSONRPCError(
            id=str(request_id) if request_id is not None else "null",
            error={"code": code, "message": message},
        )

        response_data = error_response.model_dump_json() + "\n"
        await self.write_stream.send(response_data)

    async def _cleanup_pending_requests(self) -> None:
        """Clean up any pending requests."""
        if self._request_manager:
            # Cancel all pending futures and notify client
            await self._request_manager.cancel_all(reason="Session closed")

    # Notification methods
    async def send_notification(self, notification: JSONRPCMessage) -> None:
        """Send a notification to the client (no-op without a write stream)."""
        if not self.write_stream:
            return

        message = notification.model_dump_json(exclude_none=True) + "\n"
        await self.write_stream.send(message)

    async def send_progress_notification(
        self,
        progress_token: str | int,
        progress: float,
        total: float | None = None,
        message: str | None = None,
    ) -> None:
        """Send a progress notification.

        Args:
            progress_token: Token identifying the operation being reported on
            progress: Current progress value
            total: Optional total for the operation
            message: Optional progress message
        """
        notification = ProgressNotification(
            params=ProgressNotificationParams(
                progressToken=progress_token,
                progress=progress,
                total=total,
                message=message,
            )
        )
        await self.send_notification(notification)

    async def send_log_message(
        self,
        level: LoggingLevel,
        data: Any,
        logger: str | None = None,
    ) -> None:
        """Send a log message notification.

        Args:
            level: Level of the message
            data: Payload to log
            logger: Optional logger name reported to the client (this
                parameter shadows the module-level ``logger`` in this method)
        """
        notification = LoggingMessageNotification(
            params=LoggingMessageParams(
                level=level,
                data=data,
                logger=logger,
            )
        )
        await self.send_notification(notification)

    async def send_tool_list_changed(self) -> None:
        """Send tool list changed notification."""
        await self.send_notification(ToolListChangedNotification())

    async def send_resource_list_changed(self) -> None:
        """Send resource list changed notification."""
        await self.send_notification(ResourceListChangedNotification())

    async def send_prompt_list_changed(self) -> None:
        """Send prompt list changed notification."""
        await self.send_notification(PromptListChangedNotification())

    # Server-initiated requests
    async def create_message(
        self,
        messages: list[dict[str, Any]],
        max_tokens: int,
        system_prompt: str | None = None,
        include_context: str | None = None,
        temperature: float | None = None,
        model_preferences: dict[str, Any] | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        timeout: float = 60.0,
    ) -> CreateMessageResult:
        """
        Send a sampling request to the client.

        Args:
            messages: Messages to sample
            max_tokens: Maximum tokens to generate
            system_prompt: System prompt
            include_context: Context to include
            temperature: Sampling temperature
            model_preferences: Model preferences
            stop_sequences: Stop sequences
            metadata: Request metadata
            timeout: Request timeout

        Returns:
            Sampling result

        Raises:
            SessionError: If the session has no request manager
        """
        if not self._request_manager:
            raise SessionError("Cannot send requests without request manager")

        params: dict[str, Any] = {
            "messages": messages,
            "maxTokens": max_tokens,
        }

        # Add optional parameters
        if system_prompt is not None:
            params["systemPrompt"] = system_prompt
        if include_context is not None:
            params["includeContext"] = include_context
        if temperature is not None:
            params["temperature"] = temperature
        if model_preferences is not None:
            params["modelPreferences"] = model_preferences
        if stop_sequences is not None:
            params["stopSequences"] = stop_sequences
        if metadata is not None:
            params["metadata"] = metadata

        result = await self._request_manager.send_request(
            "sampling/createMessage",
            params,
            timeout,
        )

        return CreateMessageResult(**result)

    async def list_roots(self, timeout: float = 60.0) -> ListRootsResult:
        """
        Request roots list from the client.

        Args:
            timeout: Request timeout

        Returns:
            Roots list result

        Raises:
            SessionError: If the session has no request manager
        """
        if not self._request_manager:
            raise SessionError("Cannot send requests without request manager")

        result = await self._request_manager.send_request(
            "roots/list",
            None,
            timeout,
        )

        return ListRootsResult(**result)

    async def complete(
        self,
        ref: dict[str, Any],
        argument: dict[str, Any],
        timeout: float = 60.0,
    ) -> CompleteResult:
        """
        Request completion from the client.

        Args:
            ref: Completion reference
            argument: Completion argument
            timeout: Request timeout

        Returns:
            Completion result

        Raises:
            SessionError: If the session has no request manager
        """
        if not self._request_manager:
            raise SessionError("Cannot send requests without request manager")

        result = await self._request_manager.send_request(
            "completion/complete",
            {"ref": ref, "argument": argument},
            timeout,
        )

        return CompleteResult(**result)

    async def elicit(
        self,
        message: str,
        requested_schema: dict[str, Any] | None = None,
        timeout: float = 300.0,
    ) -> ElicitResult:
        """
        Send an elicitation request to the client.

        Args:
            message: Elicitation message to display
            requested_schema: JSON schema for the requested response
            timeout: Request timeout

        Returns:
            Elicitation result

        Raises:
            SessionError: If the session has no request manager
        """
        if not self._request_manager:
            raise SessionError("Cannot send requests without request manager")

        params: dict[str, Any] = {
            "message": message,
        }

        # Add schema if provided
        if requested_schema is not None:
            params["requestedSchema"] = requested_schema

        result = await self._request_manager.send_request(
            "elicitation/create",
            params,
            timeout,
        )

        return ElicitResult(**result)

    # Request metadata management
    def set_request_meta(self, meta: dict[str, Any] | None) -> None:
        """Store meta for the current request.

        A non-empty mapping is exposed as attributes of a ``SimpleNamespace``;
        None or an empty mapping stores None.
        """
        self._request_meta = SimpleNamespace(**meta) if meta else None

    def clear_request_meta(self) -> None:
        """Clear the request's meta after the request is complete"""
        self._request_meta = None

    # Context management
    async def create_request_context(self) -> Context:
        """Create a context for the current request and record it as current."""
        context = Context(self.server)
        context.set_session(self)
        self._current_context = context
        return context

    async def cleanup_request_context(self, context: Context) -> None:
        """Clean up request context.

        Args:
            context: Context whose pending notifications are flushed
        """
        try:
            # Flush any pending notifications
            await context._flush_notifications()
        finally:
            # Drop the reference even if flushing failed, so a finished
            # request's context never lingers on the session.
            self._current_context = None