# PR Description
Consider this PR the result of a full pass through this repository.
## Add helper for adding tools to an `MCPApp`
You can now add all of the tools in a module to an `MCPApp` via
`app.add_tools_from_module(...)`.
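
The PR does not show how the helper finds the tools in a module. One common way to do this is module introspection: use `inspect.getmembers` to list a module's functions and keep the ones a decorator has marked. The sketch below shows that technique under stated assumptions. The `tool` decorator and the `__is_tool__` marker are hypothetical stand-ins, not Arcade's actual API.

```python
import inspect
import types
from collections.abc import Callable
from typing import Any


def tool(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical decorator that marks a function as a tool."""
    func.__is_tool__ = True  # hypothetical marker attribute
    return func


def collect_tools(module: types.ModuleType) -> list[Callable[..., Any]]:
    """Return every marked function defined in `module`, in name order."""
    return [
        member
        for _, member in inspect.getmembers(module, inspect.isfunction)
        if getattr(member, "__is_tool__", False)
        # Skip functions that were merely imported into the module.
        and member.__module__ == module.__name__
    ]


# Build a throwaway module in memory so the example is self-contained.
demo = types.ModuleType("demo_tools")
demo.tool = tool  # make the decorator visible inside the module namespace
exec(
    '''
@tool
def greet(name: str) -> str:
    return f"Hello, {name}!"

def helper() -> None:
    pass  # not marked, so it is not collected
''',
    demo.__dict__,
)

print([f.__name__ for f in collect_tools(demo)])  # ['greet']
```

The `__module__` check matters with this approach, because a tools module often imports helpers or decorators that should not be registered.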
## Edit what `arcade new` generates
First, I updated the build backend to hatchling.
Second, the structure generated before this PR was simple but did not
create a proper Python package.
This hindered developers in the following ways:
1. It was difficult to add the tools in your server to an evaluation suite.
2. It was difficult to add more than one tool to an `MCPApp` at a time.
3. It ruled out all the other niceties that come with being able to import modules.
```
# Before
server/
├── .env.example
├── server.py
└── pyproject.toml
```
This PR updates the generated structure so that it forms a valid Python
package:
```
# After
server/
├── pyproject.toml
└── src/
    └── server/
        ├── __init__.py
        ├── .env.example
        └── server.py
```
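
With a `src/` layout, other code (for example an evaluation suite) can import the server's tools by package name. The sketch below is only an illustration. It writes a hypothetical, minimal version of the new layout into a temporary directory and then puts `src/` on `sys.path`. That path change stands in for installing the project, for example with `pip install -e .`. The file contents, including the `whisper_secret` function, are made up and are not what `arcade new` actually emits. The `[build-system]` table is the standard hatchling configuration.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def scaffold(root: Path, name: str) -> Path:
    """Write a hypothetical minimal version of the new layout; return its src/ dir."""
    package = root / name / "src" / name
    package.mkdir(parents=True)
    (root / name / "pyproject.toml").write_text(
        '[build-system]\nrequires = ["hatchling"]\nbuild-backend = "hatchling.build"\n\n'
        f'[project]\nname = "{name}"\nversion = "0.1.0"\n'
    )
    (package / "__init__.py").write_text("")
    (package / ".env.example").write_text("MY_SECRET=\n")
    (package / "server.py").write_text("def whisper_secret() -> str:\n    return 'shh'\n")
    return root / name / "src"


with tempfile.TemporaryDirectory() as tmp:
    src_dir = scaffold(Path(tmp), "demo_server")
    sys.path.insert(0, str(src_dir))  # stand-in for installing the package
    try:
        importlib.invalidate_caches()  # files were created after startup
        # Any other code (an eval suite, a second app) can now import the tools.
        server_module = importlib.import_module("demo_server.server")
        print(server_module.whisper_secret())  # shh
    finally:
        sys.path.remove(str(src_dir))
```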
## Fix Tool Chaining
`self._ctx.server.executor.run(...)` was being called, but `MCPServer`
does not have a `ToolExecutor` instance (and is not meant to have one).
I updated `Tool.call_raw` to pass programmatic tool calls through
`MCPServer._handle_call_tool`, so they now go through the same steps as
a typical tool call initiated by the MCP client.
As a result, **toolA**, which specifies **requirementsA**, can call
**toolB**, which specifies **requirementsB**, without explicitly
declaring or satisfying **requirementsB** itself. I believe this is
acceptable because the secrets and/or auth token associated with
**toolB's** `Context` are not exposed to **toolA**, and the secrets
and/or auth token associated with **toolA's** `Context` are not exposed
to **toolB**.
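
Here is a minimal, synchronous sketch of that isolation property. The `MiniServer`, `ToolContext`, and `RegisteredTool` classes are hypothetical stand-ins for `MCPServer` and `Context`, and the sketch models secrets only. A programmatic call from one tool re-enters the server's normal call handler. That handler builds a fresh context holding only the callee's declared secrets, so neither tool ever sees the other's.

```python
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any


@dataclass
class ToolContext:
    """Hypothetical per-call context holding only the callee's declared secrets."""

    secrets: dict[str, str]
    server: "MiniServer"

    def call_tool(self, name: str, **arguments: Any) -> Any:
        # Programmatic calls re-enter the same handler a client request would use.
        return self.server.handle_call_tool(name, arguments)


@dataclass
class RegisteredTool:
    func: Callable[..., Any]
    required_secrets: tuple[str, ...]


class MiniServer:
    def __init__(self, secret_store: dict[str, str]) -> None:
        self._secret_store = secret_store
        self._tools: dict[str, RegisteredTool] = {}

    def add_tool(
        self, name: str, func: Callable[..., Any], required_secrets: tuple[str, ...] = ()
    ) -> None:
        self._tools[name] = RegisteredTool(func, required_secrets)

    def handle_call_tool(self, name: str, arguments: dict[str, Any]) -> Any:
        tool = self._tools[name]
        # Each call gets a brand-new context scoped to that tool's requirements.
        context = ToolContext(
            secrets={key: self._secret_store[key] for key in tool.required_secrets},
            server=self,
        )
        return tool.func(context, **arguments)


def tool_b(context: ToolContext) -> list[str]:
    return sorted(context.secrets)


def tool_a(context: ToolContext) -> dict[str, list[str]]:
    return {"a_sees": sorted(context.secrets), "b_saw": context.call_tool("tool_b")}


server = MiniServer({"A_KEY": "secret-a", "B_KEY": "secret-b"})
server.add_tool("tool_a", tool_a, required_secrets=("A_KEY",))
server.add_tool("tool_b", tool_b, required_secrets=("B_KEY",))
print(server.handle_call_tool("tool_a", {}))
# {'a_sees': ['A_KEY'], 'b_saw': ['B_KEY']}
```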
## Fix User Elicitation
1. The read and write streams were created with a maximum queue size of
0 (i.e., unbuffered). I increased this to 100.
2. I updated `ServerSession`'s run loop to read messages from the stream
and process them concurrently. This lets server-initiated messages (like
elicitation requests and progress notifications) be handled while tools
are executing. Previously the two deadlocked: the server-initiated
request waited for the tool to finish executing, while the tool
execution waited for the server-initiated request to finish.
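
The concurrency fix can be sketched as a read loop that spawns a task per incoming request instead of awaiting it. The loop then stays free to receive the client's reply to a server-initiated request. The sketch uses the standard library's `asyncio`, not anyio. It also sets the queue size to 100 explicitly, because `asyncio.Queue(maxsize=0)` means *unbounded*, whereas an anyio memory stream with a buffer of 0 is unbuffered. `ElicitingSession`, `fake_client`, and the message shapes are hypothetical simplifications, not the real `ServerSession` API.

```python
import asyncio
import itertools
from typing import Any

Message = dict[str, Any]


class ElicitingSession:
    """Hypothetical stand-in for ServerSession: each request runs in its own task."""

    def __init__(self, inbound: asyncio.Queue, outbound: asyncio.Queue) -> None:
        self.inbound = inbound
        self.outbound = outbound
        self._pending: dict[int, asyncio.Future] = {}
        self._ids = itertools.count(1)
        self._tasks: set[asyncio.Task] = set()

    async def elicit(self, prompt: str) -> Any:
        # Server-initiated request: send it, then wait for the client's reply.
        request_id = next(self._ids)
        reply = asyncio.get_running_loop().create_future()
        self._pending[request_id] = reply
        await self.outbound.put(
            {"id": request_id, "method": "elicitation/create", "prompt": prompt}
        )
        return await reply

    async def handle_tool_call(self, message: Message) -> None:
        answer = await self.elicit("Proceed?")
        await self.outbound.put({"id": message["id"], "result": f"user said {answer}"})

    async def run(self) -> None:
        while (message := await self.inbound.get()) is not None:
            if "method" in message:
                # Don't await here: keep reading so the elicitation reply can arrive.
                task = asyncio.create_task(self.handle_tool_call(message))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)
            else:
                # A reply to one of our server-initiated requests.
                self._pending.pop(message["id"]).set_result(message["result"])


async def fake_client(inbound: asyncio.Queue, outbound: asyncio.Queue) -> list[Message]:
    await inbound.put({"id": "call-1", "method": "tools/call"})
    seen: list[Message] = []
    while True:
        message = await outbound.get()
        seen.append(message)
        if message.get("method") == "elicitation/create":
            await inbound.put({"id": message["id"], "result": "yes"})
        else:
            await inbound.put(None)  # tool result received; shut the session down
            return seen


async def main() -> list[Message]:
    inbound: asyncio.Queue = asyncio.Queue(maxsize=100)
    outbound: asyncio.Queue = asyncio.Queue(maxsize=100)
    session = ElicitingSession(inbound, outbound)
    seen, _ = await asyncio.wait_for(
        asyncio.gather(fake_client(inbound, outbound), session.run()), timeout=5
    )
    return seen


print(asyncio.run(main()))
```

If `run` awaited `handle_tool_call` directly instead of spawning a task, the loop would never read the elicitation reply and the call would hang. That is the deadlock described above.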
## Fix Progress Reporting
Progress tokens sent by the client were not being stored, so there was
no way to send the client progress updates. I now store the
`progressToken`, along with the rest of the `_meta` sent by the client,
in the `ServerSession`'s `_request_meta`. `_request_meta` is set
whenever the `MCPServer` handles an incoming message from a client.
## Fix handling of server names with spaces
```
# Before
Server name:         "The simple server name"
Tool name:           whisper_secret
Name seen by client: "The_simple_server_name_WhisperSecret"

# After
Server name:         "The simple server name"
Tool name:           whisper_secret
Name seen by client: "TheSimpleServerName_WhisperSecret"
```
## Add Integration Tests
The stdio integration test is much more comprehensive than the HTTP
integration test. These tests will let me sleep a bit more at night.
## Add Example MCP Servers
Example servers for sampling, user elicitation, progress reporting,
logging, tool chaining, combining prebuilt tools with custom tools, tool
secrets, tool auth, evaluations, and more!
## Add Docker template
Added a Docker template for running an MCP server in Docker (and removed
the old Docker stuff).
```python
"""HTTP Streamable Transport for MCP servers.

This module implements HTTP transport with Server-Sent Events (SSE) streaming support,
following the patterns from the sample library.

Design overview
- The transport provides a duplex, in-process message channel between the HTTP layer
  and the MCP session using anyio memory streams:
  - read side (transport -> session):
    - `_read_stream_writer` (SendStream[str | Exception])
    - `_read_stream` (ReceiveStream[str | Exception])
  - write side (session -> transport):
    - `_write_stream` (SendStream[str | SessionMessage])
    - `_write_stream_reader` (ReceiveStream[str | SessionMessage])

- The transport validates inbound client messages from HTTP request bodies and writes
  them, as newline-terminated JSON strings, to `_read_stream_writer`; the session
  consumes them from `_read_stream`.

- The session writes outbound server messages to `_write_stream`; the transport's
  `message_router` task consumes them from `_write_stream_reader` and fans them out
  to the correct per-request stream maintained in `_request_streams[request_id]`.

- Response modes:
  - JSON response mode: a single HTTP JSON response is returned by awaiting the
    first terminal message (JSONRPCResponse or JSONRPCError) for the request.
  - SSE response mode: a long-lived stream of events is sent as SSE; the stream
    is closed when a terminal message is observed for the request.

- A standalone GET SSE stream uses the special key `GET_STREAM_KEY` to deliver
  server-initiated events without a preceding POST.

- Optional resumability can be enabled by providing an `EventStore` implementation.
"""

import json
import logging
import re
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from http import HTTPStatus
from typing import cast

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel, TypeAdapter
from sse_starlette import EventSourceResponse
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

from arcade_mcp_server.session import ServerSession
from arcade_mcp_server.types import (
    INTERNAL_ERROR,
    INVALID_REQUEST,
    PARSE_ERROR,
    ErrorData,
    JSONRPCError,
    JSONRPCMessage,
    JSONRPCRequest,
    JSONRPCResponse,
    MCPMessage,
    RequestId,
    SessionMessage,
)

logger = logging.getLogger(__name__)

# Header names
MCP_SESSION_ID_HEADER = "Mcp-Session-Id"
MCP_PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version"
LAST_EVENT_ID_HEADER = "Last-Event-ID"

# Content types
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_SSE = "text/event-stream"

# Special key for the standalone GET stream
GET_STREAM_KEY = "_GET_stream"

# Session ID validation pattern (visible ASCII characters)
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")

# Type aliases
StreamId = str
EventId = str


@dataclass
class EventMessage:
    """A JSONRPCMessage with an optional event ID for stream resumability."""

    message: MCPMessage
    event_id: str | None = None


EventCallback = Callable[[EventMessage], Awaitable[None]]


class EventStore:
    """Interface for resumability support via event storage."""

    async def store_event(self, stream_id: StreamId, message: MCPMessage) -> EventId:
        """Store an event for later retrieval."""
        raise NotImplementedError

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> StreamId | None:
        """Replay events after the specified event ID."""
        raise NotImplementedError


class HTTPStreamableTransport:
    """HTTP transport with SSE streaming support for MCP.

    Responsibilities
    - Validate HTTP request bodies as JSON-RPC messages and enqueue them on the
      transport→session read stream (via `_read_stream_writer`).
    - Consume session→transport messages from `_write_stream_reader` in a
      background `message_router`, routing them to per-request streams in
      `_request_streams` keyed by the JSON-RPC request id (or `GET_STREAM_KEY`
      for the standalone GET SSE stream).
    - Serve responses back to the HTTP client:
      - JSON response mode: wait for the first terminal response and return a
        single `application/json` body.
      - SSE mode: stream each outbound `SessionMessage` as an SSE event with
        appropriate headers and close on terminal response.

    Streams created in `connect()`
    - `_read_stream_writer` / `_read_stream`: transport→session channel for inbound
      client messages.
    - `_write_stream` / `_write_stream_reader`: session→transport channel for outbound
      server messages, consumed by the `message_router`.

    These in-memory channels provide backpressure and decouple HTTP from the session
    loop while keeping the implementation fully async.
    """

    def __init__(
        self,
        mcp_session_id: str | None,
        session: ServerSession | None = None,
        is_json_response_enabled: bool = False,
        event_store: EventStore | None = None,
    ):
        """Initialize HTTP streamable transport.

        Args:
            mcp_session_id: Session identifier (must be visible ASCII)
            session: Server session for handling requests
            is_json_response_enabled: If True, return JSON responses instead of SSE
            event_store: Optional event store for resumability
        """
        if mcp_session_id and not SESSION_ID_PATTERN.fullmatch(mcp_session_id):
            raise ValueError("Session ID must only contain visible ASCII characters")

        self.mcp_session_id = mcp_session_id
        self.session = session
        self.is_json_response_enabled = is_json_response_enabled
        self._event_store = event_store
        self._request_streams: dict[
            RequestId,
            tuple[MemoryObjectSendStream[EventMessage], MemoryObjectReceiveStream[EventMessage]],
        ] = {}
        self._terminated = False

        # Streams for connection
        self._read_stream_writer: MemoryObjectSendStream[str | Exception] | None = None
        self._read_stream: MemoryObjectReceiveStream[str | Exception] | None = None
        self._write_stream: MemoryObjectSendStream[str | SessionMessage] | None = None
        self._write_stream_reader: MemoryObjectReceiveStream[str | SessionMessage] | None = None

    def _parse_mcp_message(self, obj: str | dict[str, object] | MCPMessage) -> MCPMessage:
        """Parse incoming data into a typed MCPMessage.

        Accepts a raw JSON string, already-parsed dict, or an existing MCPMessage.
        """
        if isinstance(obj, BaseModel):
            # Already a pydantic model; trust caller and cast to MCPMessage
            return cast(MCPMessage, obj)

        parsed: dict[str, object]
        if isinstance(obj, str):
            try:
                maybe = json.loads(obj)
            except Exception as exc:  # parse error - treat as invalid request
                raise ValueError(f"Invalid JSON: {exc}") from exc
            if not isinstance(maybe, dict):
                raise TypeError("JSON must be an object")
            parsed = maybe
        elif isinstance(obj, dict):
            parsed = obj
        else:
            raise TypeError("Unsupported message type")

        try:
            return TypeAdapter(MCPMessage).validate_python(parsed)
        except Exception:
            # Fallback: report the message as invalid (-32600 is JSON-RPC "Invalid Request")
            return JSONRPCError(
                id=str(parsed.get("id", "null")),
                error={"code": -32600, "message": "Invalid message"},
            )

    @property
    def is_terminated(self) -> bool:
        """Check if transport has been terminated."""
        return self._terminated

    def _create_error_response(
        self,
        error_message: str,
        status_code: HTTPStatus,
        error_code: int = INVALID_REQUEST,
        headers: dict[str, str] | None = None,
    ) -> Response:
        """Create an error response."""
        response_headers = {"Content-Type": CONTENT_TYPE_JSON}
        if headers:
            response_headers.update(headers)

        if self.mcp_session_id:
            response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

        error_response = JSONRPCError(
            jsonrpc="2.0",
            id="server-error",
            error=ErrorData(code=error_code, message=error_message).model_dump(exclude_none=True),
        )

        return Response(
            error_response.model_dump_json(by_alias=True, exclude_none=True),
            status_code=status_code,
            headers=response_headers,
        )

    def _create_json_response(
        self,
        response_message: JSONRPCMessage | None,
        status_code: HTTPStatus = HTTPStatus.OK,
        headers: dict[str, str] | None = None,
    ) -> Response:
        """Create a JSON response."""
        response_headers = {"Content-Type": CONTENT_TYPE_JSON}
        if headers:
            response_headers.update(headers)

        if self.mcp_session_id:
            response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

        return Response(
            response_message.model_dump_json(by_alias=True, exclude_none=True)
            if response_message
            else None,
            status_code=status_code,
            headers=response_headers,
        )

    def _get_session_id(self, request: Request) -> str | None:
        """Extract session ID from request headers."""
        return request.headers.get(MCP_SESSION_ID_HEADER)

    def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
        """Create event data dictionary from EventMessage."""
        event_data = {
            "event": "message",
            "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True),
        }

        if event_message.event_id:
            event_data["id"] = event_message.event_id

        return event_data

    async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
        """Clean up memory streams for a request."""
        if request_id in self._request_streams:
            try:
                await self._request_streams[request_id][0].aclose()
                await self._request_streams[request_id][1].aclose()
            except Exception:
                logger.debug("Error closing memory streams - may already be closed")
            finally:
                self._request_streams.pop(request_id, None)

    async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle incoming HTTP requests."""
        request = Request(scope, receive)

        if self._terminated:
            response = self._create_error_response(
                "Not Found: Session has been terminated",
                HTTPStatus.NOT_FOUND,
            )
            await response(scope, receive, send)
            return

        if request.method == "POST":
            await self._handle_post_request(scope, request, receive, send)
        elif request.method == "GET":
            await self._handle_get_request(request, send)
        elif request.method == "DELETE":
            await self._handle_delete_request(request, send)
        else:
            await self._handle_unsupported_request(request, send)

    def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
        """Check if request accepts required media types."""
        accept_header = request.headers.get("accept", "")
        accept_types = [media_type.strip() for media_type in accept_header.split(",")]

        has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types)
        has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types)

        return has_json, has_sse

    def _check_content_type(self, request: Request) -> bool:
        """Check if request has correct Content-Type."""
        content_type = request.headers.get("content-type", "")
        content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")]

        return any(part == CONTENT_TYPE_JSON for part in content_type_parts)

    async def _handle_post_request(
        self, scope: Scope, request: Request, receive: Receive, send: Send
    ) -> None:
        """Handle POST requests containing JSON-RPC messages."""
        writer = self._read_stream_writer
        if writer is None:
            raise ValueError("No read stream writer available. Ensure connect() is called first.")

        try:
            # Check Accept headers
            has_json, has_sse = self._check_accept_headers(request)
            if self.is_json_response_enabled:
                if not has_json:
                    response = self._create_error_response(
                        "Not Acceptable: Client must accept application/json",
                        HTTPStatus.NOT_ACCEPTABLE,
                    )
                    await response(scope, receive, send)
                    return
            else:
                if not has_sse:
                    response = self._create_error_response(
                        "Not Acceptable: Client must accept text/event-stream",
                        HTTPStatus.NOT_ACCEPTABLE,
                    )
                    await response(scope, receive, send)
                    return

            # Validate Content-Type for POST payloads only when JSON mode
            if self.is_json_response_enabled and not self._check_content_type(request):
                response = self._create_error_response(
                    "Unsupported Media Type: Content-Type must be application/json",
                    HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
                )
                await response(scope, receive, send)
                return

            # Parse the body
            body = await request.body()
            body_str = body.decode("utf-8") if isinstance(body, (bytes, bytearray)) else str(body)

            try:
                raw_message = json.loads(body)
            except json.JSONDecodeError as e:
                response = self._create_error_response(
                    f"Parse error: {e!s}", HTTPStatus.BAD_REQUEST, PARSE_ERROR
                )
                await response(scope, receive, send)
                return

            # Accept either well-typed messages or raw dicts
            message_dict = raw_message if isinstance(raw_message, dict) else {}
            try:
                message = self._parse_mcp_message(message_dict or body_str)
            except Exception as exc:
                response = self._create_error_response(
                    f"Invalid request: {exc}",
                    HTTPStatus.BAD_REQUEST,
                    INVALID_REQUEST,
                )
                await response(scope, receive, send)
                return

            # Check if this is an initialization request: only a validated
            # JSONRPCRequest whose method is "initialize" counts
            is_initialization_request = (
                isinstance(message, JSONRPCRequest) and message.method == "initialize"
            )

            if is_initialization_request:
                if self.mcp_session_id:
                    request_session_id = self._get_session_id(request)
                    if request_session_id and request_session_id != self.mcp_session_id:
                        response = self._create_error_response(
                            "Not Found: Invalid or expired session ID",
                            HTTPStatus.NOT_FOUND,
                        )
                        await response(scope, receive, send)
                        return
            elif not await self._validate_request_headers(request, send):
                return

            # For notifications and responses, return 202 Accepted
            if not isinstance(message, JSONRPCRequest):
                response = self._create_json_response(None, HTTPStatus.ACCEPTED)
                await response(scope, receive, send)

                # Process the message
                await writer.send(body_str if body_str.endswith("\n") else body_str + "\n")
                return

            # Handle requests
            request_id = str(message.id)
            self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
            request_stream_reader = self._request_streams[request_id][1]

            if self.is_json_response_enabled:
                # JSON response mode
                await writer.send(body_str if body_str.endswith("\n") else body_str + "\n")

                try:
                    response_message = None
                    async for event_message in request_stream_reader:
                        if isinstance(event_message.message, (JSONRPCResponse, JSONRPCError)):
                            response_message = event_message.message
                            break

                    if response_message is not None:
                        response = self._create_json_response(response_message)
                        await response(scope, receive, send)
                    else:
                        logger.error("No response received before stream closed")
                        response = self._create_error_response(
                            "Error processing request: No response received",
                            HTTPStatus.INTERNAL_SERVER_ERROR,
                        )
                        await response(scope, receive, send)
                except Exception:
                    logger.exception("Error processing JSON response")
                    response = self._create_error_response(
                        "Error processing request",
                        HTTPStatus.INTERNAL_SERVER_ERROR,
                        INTERNAL_ERROR,
                    )
                    await response(scope, receive, send)
                finally:
                    await self._clean_up_memory_streams(request_id)
            else:
                # SSE response mode
                sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
                    dict[str, str]
                ](0)

                async def sse_writer() -> None:
                    try:
                        async with sse_stream_writer, request_stream_reader:
                            async for event_message in request_stream_reader:
                                event_data = self._create_event_data(event_message)
                                await sse_stream_writer.send(event_data)

                                if isinstance(
                                    event_message.message, (JSONRPCResponse, JSONRPCError)
                                ):
                                    break
                    except Exception:
                        logger.exception("Error in SSE writer")
                    finally:
                        logger.debug("Closing SSE writer")
                        await self._clean_up_memory_streams(request_id)

                headers = {
                    "Cache-Control": "no-cache, no-transform",
                    "Connection": "keep-alive",
                    "Content-Type": CONTENT_TYPE_SSE,
                    **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
                }

                response = EventSourceResponse(
                    content=sse_stream_reader,
                    data_sender_callable=sse_writer,
                    headers=headers,
                )

                try:
                    async with anyio.create_task_group() as tg:
                        tg.start_soon(response, scope, receive, send)
                        await writer.send(body_str if body_str.endswith("\n") else body_str + "\n")
                except Exception:
                    logger.exception("SSE response error")
                    await sse_stream_writer.aclose()
                    await sse_stream_reader.aclose()
                    await self._clean_up_memory_streams(request_id)

        except Exception as err:
            logger.exception("Error handling POST request")
            response = self._create_error_response(
                f"Error handling POST request: {err}",
                HTTPStatus.INTERNAL_SERVER_ERROR,
                INTERNAL_ERROR,
            )
            await response(scope, receive, send)
            if writer:
                await writer.send(Exception(err))

    async def _handle_get_request(self, request: Request, send: Send) -> None:
        """Handle GET request to establish SSE."""
        writer = self._read_stream_writer
        if writer is None:
            raise ValueError("No read stream writer available. Ensure connect() is called first.")

        # Validate Accept header
        _, has_sse = self._check_accept_headers(request)

        if not has_sse:
            error_response = self._create_error_response(
                "Not Acceptable: Client must accept text/event-stream",
                HTTPStatus.NOT_ACCEPTABLE,
            )
            await error_response(request.scope, request.receive, send)
            return

        if not await self._validate_request_headers(request, send):
            return

        # Handle resumability. Replay only when an event store is configured;
        # otherwise fall through and open a fresh standalone stream, because
        # _replay_events returns without sending any HTTP response when there is
        # no event store.
        if self._event_store and (last_event_id := request.headers.get(LAST_EVENT_ID_HEADER)):
            await self._replay_events(last_event_id, request, send)
            return

        headers = {
            "Cache-Control": "no-cache, no-transform",
            "Connection": "keep-alive",
            "Content-Type": CONTENT_TYPE_SSE,
        }

        if self.mcp_session_id:
            headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

        # Check if we already have an active GET stream
        if GET_STREAM_KEY in self._request_streams:
            error_response = self._create_error_response(
                "Conflict: Only one SSE stream is allowed per session",
                HTTPStatus.CONFLICT,
            )
            await error_response(request.scope, request.receive, send)
            return

        # Create SSE stream
        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)

        async def standalone_sse_writer() -> None:
            try:
                self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[
                    EventMessage
                ](0)
                standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]

                async with sse_stream_writer, standalone_stream_reader:
                    async for event_message in standalone_stream_reader:
                        event_data = self._create_event_data(event_message)
                        await sse_stream_writer.send(event_data)
            except Exception:
                logger.exception("Error in standalone SSE writer")
            finally:
                logger.debug("Closing standalone SSE writer")
                await self._clean_up_memory_streams(GET_STREAM_KEY)

        sse_response: EventSourceResponse = EventSourceResponse(
            content=sse_stream_reader,
            data_sender_callable=standalone_sse_writer,
            headers=headers,
        )

        try:
            await sse_response(request.scope, request.receive, send)
        except Exception:
            logger.exception("Error in standalone SSE response")
            await sse_stream_writer.aclose()
            await sse_stream_reader.aclose()
            await self._clean_up_memory_streams(GET_STREAM_KEY)

    async def _handle_delete_request(self, request: Request, send: Send) -> None:
        """Handle DELETE requests for session termination."""
        if not self.mcp_session_id:
            response = self._create_error_response(
                "Method Not Allowed: Session termination not supported",
                HTTPStatus.METHOD_NOT_ALLOWED,
            )
            await response(request.scope, request.receive, send)
            return

        if not await self._validate_request_headers(request, send):
            return

        await self.terminate()

        response = self._create_json_response(None, HTTPStatus.OK)
        await response(request.scope, request.receive, send)

    async def terminate(self) -> None:
        """Terminate the current session."""
        self._terminated = True
        logger.info(f"Terminating session: {self.mcp_session_id}")

        # Close all request streams
        request_stream_keys = list(self._request_streams.keys())
        for key in request_stream_keys:
            await self._clean_up_memory_streams(key)
        self._request_streams.clear()

        try:
            if self._read_stream_writer:
                await self._read_stream_writer.aclose()
            if self._read_stream:
                await self._read_stream.aclose()
            if self._write_stream_reader:
                await self._write_stream_reader.aclose()
            if self._write_stream:
                await self._write_stream.aclose()
        except Exception as e:
            logger.debug(f"Error closing streams: {e}")

    async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
        """Handle unsupported HTTP methods."""
        headers = {
            "Content-Type": CONTENT_TYPE_JSON,
            "Allow": "GET, POST, DELETE",
        }
        if self.mcp_session_id:
            headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

        response = self._create_error_response(
            "Method Not Allowed",
            HTTPStatus.METHOD_NOT_ALLOWED,
            headers=headers,
        )
        await response(request.scope, request.receive, send)

    async def _validate_request_headers(self, request: Request, send: Send) -> bool:
        """Validate request headers."""
        return await self._validate_session(request, send)

    async def _validate_session(self, request: Request, send: Send) -> bool:
        """Validate session ID in request."""
        if not self.mcp_session_id:
            return True

        request_session_id = self._get_session_id(request)

        if not request_session_id:
            response = self._create_error_response(
                "Bad Request: Missing session ID",
                HTTPStatus.BAD_REQUEST,
            )
            await response(request.scope, request.receive, send)
            return False

        if request_session_id != self.mcp_session_id:
            response = self._create_error_response(
                "Not Found: Invalid or expired session ID",
                HTTPStatus.NOT_FOUND,
            )
            await response(request.scope, request.receive, send)
            return False

        return True

    async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
        """Replay events after the specified event ID."""
        event_store = self._event_store
        if not event_store:
            return

        try:
            headers = {
                "Cache-Control": "no-cache, no-transform",
                "Connection": "keep-alive",
                "Content-Type": CONTENT_TYPE_SSE,
            }

            if self.mcp_session_id:
                headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

            sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
                dict[str, str]
            ](0)

            async def replay_sender() -> None:
                try:
                    async with sse_stream_writer:

                        async def send_event(event_message: EventMessage) -> None:
                            event_data = self._create_event_data(event_message)
                            await sse_stream_writer.send(event_data)

                        stream_id = await event_store.replay_events_after(last_event_id, send_event)

                        if stream_id and stream_id not in self._request_streams:
                            self._request_streams[stream_id] = anyio.create_memory_object_stream[
                                EventMessage
                            ](0)
                            msg_reader = self._request_streams[stream_id][1]

                            async with msg_reader:
                                async for event_message in msg_reader:
                                    event_data = self._create_event_data(event_message)
                                    await sse_stream_writer.send(event_data)
                except Exception:
                    logger.exception("Error in replay sender")

            sse_response: EventSourceResponse = EventSourceResponse(
                content=sse_stream_reader,
                data_sender_callable=replay_sender,
                headers=headers,
            )

            try:
                await sse_response(request.scope, request.receive, send)
            except Exception:
                logger.exception("Error in replay response")
            finally:
                await sse_stream_writer.aclose()
                await sse_stream_reader.aclose()

        except Exception:
            logger.exception("Error replaying events")
            error_response = self._create_error_response(
                "Error replaying events",
                HTTPStatus.INTERNAL_SERVER_ERROR,
                INTERNAL_ERROR,
            )
            await error_response(request.scope, request.receive, send)

    @asynccontextmanager
    async def connect(
        self,
    ) -> AsyncIterator[
        tuple[
            MemoryObjectReceiveStream[str | Exception],
            MemoryObjectSendStream[str | SessionMessage],
        ]
    ]:
        """Context manager providing read and write streams for connection.

        Creates the in-memory channels used by the transport and starts the
        `message_router` task responsible for routing outbound messages from
        the session to the correct per-request stream (or the standalone GET
        stream identified by `GET_STREAM_KEY`).
        """
        # Create memory streams with buffer
        read_stream_writer, read_stream = anyio.create_memory_object_stream[str | Exception](100)
        write_stream, write_stream_reader = anyio.create_memory_object_stream[str | SessionMessage](
            100
        )

        # Store the streams
        self._read_stream_writer = read_stream_writer
        self._read_stream = read_stream
        self._write_stream_reader = write_stream_reader
        self._write_stream = write_stream

        # Start message router
        async with anyio.create_task_group() as tg:

            async def message_router() -> None:
                try:
                    async for session_message in write_stream_reader:
                        # Accept either a SessionMessage wrapper or a raw JSON string
                        try:
                            if isinstance(session_message, SessionMessage):
                                message = session_message.message
                            elif isinstance(session_message, str):
                                message = self._parse_mcp_message(session_message)
                            elif isinstance(session_message, BaseModel):
                                message = cast(JSONRPCMessage, session_message)
                            else:
                                logger.error(
                                    f"Unsupported outbound message type: {type(session_message)}"
                                )
                                continue
                        except Exception:
                            logger.exception("Failed to parse outbound message from session")
                            continue
                        target_request_id = None

                        # Check if this is a response
                        if isinstance(message, (JSONRPCResponse, JSONRPCError)):
                            target_request_id = str(message.id)

                        request_stream_id = (
                            target_request_id if target_request_id else GET_STREAM_KEY
                        )

                        # Store event if we have an event store
                        event_id = None
                        if self._event_store:
                            event_id = await self._event_store.store_event(
                                request_stream_id,
                                message,  # type: ignore[arg-type]
                            )
                            logger.debug(f"Stored {event_id} from {request_stream_id}")

                        if request_stream_id in self._request_streams:
                            try:
                                await self._request_streams[request_stream_id][0].send(
                                    EventMessage(message, event_id)  # type: ignore[arg-type]
                                )
                            except (anyio.BrokenResourceError, anyio.ClosedResourceError):
                                self._request_streams.pop(request_stream_id, None)
                except Exception:
                    logger.exception("Error in message router")

            tg.start_soon(message_router)

            try:
                yield read_stream, write_stream
            finally:
                for stream_id in list(self._request_streams.keys()):
                    await self._clean_up_memory_streams(stream_id)
                self._request_streams.clear()

                try:
                    await read_stream_writer.aclose()
                    await read_stream.aclose()
                    await write_stream_reader.aclose()
                    await write_stream.aclose()
                except Exception as e:
                    logger.debug(f"Error closing streams: {e}")
```
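
The `message_router` in `connect()` chooses a destination with one rule. A JSON-RPC response or error goes to the per-request stream keyed by `str(id)`. Everything else, including server-initiated requests and notifications, goes to the standalone GET stream. A message whose destination stream is not open is dropped. The sketch below restates that rule over plain dicts. It recognizes a response by its `result` or `error` key, which is an assumption standing in for the `isinstance` checks on the typed models. Read this way, an elicitation request or progress notification emitted while a POST is in flight reaches the client only if the client has also opened the GET stream.

```python
from typing import Any

GET_STREAM_KEY = "_GET_stream"  # same value as the constant in the transport module


def route_key(message: dict[str, Any]) -> str:
    """Mirror of the router's choice of destination stream."""
    is_terminal = "id" in message and ("result" in message or "error" in message)
    return str(message["id"]) if is_terminal else GET_STREAM_KEY


def deliver(message: dict[str, Any], open_streams: dict[str, list[dict[str, Any]]]) -> bool:
    """Append to the destination stream if it is open; otherwise drop, like the router."""
    key = route_key(message)
    if key not in open_streams:
        return False
    open_streams[key].append(message)
    return True


streams: dict[str, list[dict[str, Any]]] = {"42": []}  # one POST in flight, no GET stream
print(deliver({"jsonrpc": "2.0", "id": 42, "result": {}}, streams))  # True
print(deliver({"jsonrpc": "2.0", "id": 1, "method": "elicitation/create"}, streams))  # False
print(deliver({"jsonrpc": "2.0", "method": "notifications/progress"}, streams))  # False
```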