Support for MCP stdio transport (#368)

MCP stdio Implementation:

The PR adds support for standard input/output (stdio) as a transport
mechanism for the Message Control Protocol. This is a replacement to the
SSE (Server-Sent Events) transport that was worked on in PR #359 but
will not be merged as it's not deprecated.

This will allow developers to use Arcade tools (written by the dev or
Arcade) in Claude, Cursor, windsurf, etc.

The engine Gateway already supports adding HTTPS streamable (replacement
for SSE) MCP servers as tool servers, and will soon support full gateway
capability in the client API as well.

To use any existing Toolkit just 

## Examples

### Quickstart setup with existing toolkits

```bash
pip install arcade-ai
pip install <name of toolkit> # ex. arcade-google
arcade serve --mcp
```

### Run with Claude

Just add the following to the Claude config

```json
{
  "mcpServers": {
    "arcade": {
      "command": "bash",
      "args": ["-c", "export ARCADE_API_KEY=arc_xxxx && /path/to/python /path/to/arcade serve --mcp"]
    }
  }
}
```

### Customizing the Tool Server

Developers can customize their served tools and server furthermore by
importing the worker sdk

```python

import arcade_google  # pip install arcade_google
import arcade_search  # pip install arcade_search

from arcade.core.catalog import ToolCatalog
from arcade.worker.mcp.stdio import StdioServer

# 2. Create and populate the tool catalog
catalog = ToolCatalog()
catalog.add_module(arcade_google)  # Registers all tools in the package
catalog.add_module(arcade_search)


# 3. Main entrypoint
async def main():
    # Create the worker with the tool catalog
    worker = StdioServer(catalog)

    # Run the worker
    await worker.run()


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
    
 ```
 
 Then to run with claude, just run this python file instead of the prebuilt server used in ``arcade serve --mcp``
This commit is contained in:
Sam Partee 2025-05-02 06:27:43 -07:00 committed by GitHub
parent f5774fd4ae
commit 9bc1cd4a12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 2829 additions and 187 deletions

View file

@ -10,7 +10,6 @@
"[python]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.fixAll": "explicit",
"source.organizeImports": "explicit"
},
"editor.defaultFormatter": "charliermarsh.ruff"

View file

@ -24,6 +24,7 @@ from arcade.cli.constants import (
PROD_CLOUD_HOST,
PROD_ENGINE_HOST,
)
from arcade.cli.deployment import Deployment
from arcade.cli.display import (
display_arcade_chat_header,
display_eval_results,
@ -48,7 +49,6 @@ from arcade.cli.utils import (
version_callback,
)
from arcade.cli.worker import parse_deployment_response
from arcade.worker.config.deployment import Deployment
cli = typer.Typer(
cls=OrderCommands,
@ -61,7 +61,12 @@ cli = typer.Typer(
)
cli.add_typer(worker.app, name="worker", help="Manage workers")
cli.add_typer(
worker.app,
name="worker",
help="Manage deployments of tool servers (logs, list, etc)",
rich_help_panel="Deployment",
)
console = Console()
@ -192,7 +197,10 @@ def show(
show_logic(toolkit, tool, host, local, port, force_tls, force_no_tls, debug)
@cli.command(help="Start Arcade Chat in the terminal", rich_help_panel="Launch")
@cli.command(
help="Start a chat with a model in the terminal to test tools",
rich_help_panel="Tool Development",
)
def chat(
model: str = typer.Option("gpt-4o", "-m", "--model", help="The model to use for prediction."),
stream: bool = typer.Option(
@ -439,7 +447,7 @@ def evals(
@cli.command(
help="Start Arcade Worker serving tools installed in the current python environment",
help="Start tool server worker with locally installed tools",
rich_help_panel="Launch",
)
def serve(
@ -460,19 +468,38 @@ def serve(
otel_enable: bool = typer.Option(
False, "--otel-enable", help="Send logs to OpenTelemetry", show_default=True
),
mcp: bool = typer.Option(
False, "--mcp", help="Run as a local MCP server over stdio", show_default=True
),
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
) -> None:
"""
Start a local Arcade Worker server.
"""
workerup(host, port, disable_auth, otel_enable, debug)
from arcade.cli.serve import serve_default_worker
try:
serve_default_worker(
host,
port,
disable_auth=disable_auth,
enable_otel=otel_enable,
debug=debug,
mcp=mcp,
)
except KeyboardInterrupt:
typer.Exit()
except Exception as e:
error_message = f"❌ Failed to start Arcade Worker: {escape(str(e))}"
console.print(error_message, style="bold red")
typer.Exit(code=1)
@cli.command(help="Launch Arcade Worker and Engine locally", rich_help_panel="Launch")
@cli.command(help="Launch Arcade - requires 'arcade-engine'", rich_help_panel="Launch")
def dev(
host: str = typer.Option("127.0.0.1", help="Host for the worker server.", show_default=True),
host: str = typer.Option("127.0.0.1", help="Host for the toolkit server.", show_default=True),
port: int = typer.Option(
8002, "-p", "--port", help="Port for the worker server.", show_default=True
8002, "-p", "--port", help="Port for the toolkit server.", show_default=True
),
engine_config: str = typer.Option(
None, "-c", "--config", help="Path to the engine configuration file."
@ -483,7 +510,7 @@ def dev(
debug: bool = typer.Option(False, "-d", "--debug", help="Show debug information"),
) -> None:
"""
Start both the worker and engine servers.
Start both the toolkit server and engine servers.
"""
try:
start_servers(host, port, engine_config, engine_env=env_file, debug=debug)
@ -493,8 +520,9 @@ def dev(
typer.Exit(code=1)
# TODO: deprecate this next major version
@cli.command(help="Start a local Arcade Worker server", rich_help_panel="Launch", hidden=True)
@cli.command(
help="Start a server with locally installed Arcade tools", rich_help_panel="Launch", hidden=True
)
def workerup(
host: str = typer.Option(
"127.0.0.1",
@ -523,17 +551,21 @@ def workerup(
try:
serve_default_worker(
host, port, disable_auth=disable_auth, enable_otel=otel_enable, debug=debug
host,
port,
disable_auth=disable_auth,
enable_otel=otel_enable,
debug=debug,
)
except KeyboardInterrupt:
typer.Exit()
except Exception as e:
error_message = f"❌ Failed to start Arcade Worker: {escape(str(e))}"
error_message = f"❌ Failed to start Arcade Toolkit Server: {escape(str(e))}"
console.print(error_message, style="bold red")
typer.Exit(code=1)
@cli.command(help="Deploy worker to Arcade Cloud", rich_help_panel="Deployment")
@cli.command(help="Deploy toolkits to Arcade Cloud", rich_help_panel="Deployment")
def deploy(
deployment_file: str = typer.Option(
"worker.toml", "--deployment-file", "-d", help="The deployment file to deploy."

View file

@ -1,70 +1,209 @@
import asyncio
import logging
import os
import signal
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from importlib.metadata import version as get_pkg_version
from pathlib import Path
from typing import Any
import fastapi
import uvicorn
from loguru import logger
from rich.console import Console
from arcade.cli.constants import ARCADE_CONFIG_PATH
from arcade.cli.utils import (
build_tool_catalog,
discover_toolkits,
load_dotenv,
validate_and_get_config,
)
from arcade.core.telemetry import OTELHandler
from arcade.sdk import Toolkit
from arcade.worker.fastapi.worker import FastAPIWorker
console = Console(width=70, color_system="auto")
class InterceptHandler(logging.Handler):
def _run_mcp_stdio(
toolkits: list[Toolkit], *, logging_enabled: bool, env_file: str | None = None
) -> None:
"""Launch an MCP stdio server; blocks until it exits."""
from arcade.worker.mcp.stdio import StdioServer
# Load env vars before launching server (explicit path, config path, cwd)
if env_file:
load_dotenv(env_file, override=False)
else:
for candidate in [Path(ARCADE_CONFIG_PATH) / "arcade.env", Path.cwd() / "arcade.env"]:
if candidate.is_file():
load_dotenv(candidate, override=False)
break
# Set up middleware configuration for stdio mode
middleware_config = {
"stdio_mode": True, # Ensure logs go to stderr
}
catalog = build_tool_catalog(toolkits)
server = StdioServer(
catalog,
enable_logging=logging_enabled,
middleware_config=middleware_config,
)
try:
asyncio.run(server.run())
except KeyboardInterrupt:
logger.info("MCP server stopped by user.")
except Exception as exc:
logger.exception("Error while running MCP server: %s", exc)
raise
def _run_fastapi_server(
app: fastapi.FastAPI,
*,
host: str,
port: int,
workers: int,
timeout_keep_alive: int,
enable_otel: bool,
otel_handler: OTELHandler,
**uvicorn_kwargs: Any,
) -> None:
"""Run a FastAPI application via Uvicorn with graceful shutdown."""
class CustomUvicornServer(uvicorn.Server):
def install_signal_handlers(self) -> None:
# Disable Uvicorn's default signal handling; we manage it manually
pass
async def shutdown(self, sockets: Any = None) -> None:
logger.info("Initiating graceful shutdown...")
await super().shutdown(sockets=sockets)
config = uvicorn.Config(
app=app,
host=host,
port=port,
workers=workers,
timeout_keep_alive=timeout_keep_alive,
log_config=None,
**uvicorn_kwargs,
)
server = CustomUvicornServer(config=config)
async def _serve() -> None:
await server.serve()
async def _graceful_shutdown() -> None:
try:
logger.info("Shutting down server ...")
await server.shutdown()
# brief pause for connections to close gracefully
await asyncio.sleep(0.5)
finally:
if enable_otel:
otel_handler.shutdown()
logger.debug("Server shutdown complete.")
# Map signals to our graceful shutdown
loop = asyncio.get_event_loop()
for sig_name in (
"SIGINT",
"SIGTERM",
"SIGHUP",
"SIGUSR1",
"SIGUSR2",
"SIGWINCH",
"SIGBREAK",
):
if hasattr(signal, sig_name):
loop.add_signal_handler(
getattr(signal, sig_name), lambda: asyncio.create_task(_graceful_shutdown())
)
try:
asyncio.run(_serve())
except KeyboardInterrupt:
logger.info("Server stopped by user.")
finally:
if enable_otel:
otel_handler.shutdown()
class RichInterceptHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
# Get corresponding Loguru level if it exists
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno # type: ignore[assignment]
level = str(record.levelno)
# Find caller from where originated the logged message
frame, depth = sys._getframe(6), 6
while frame and frame.f_code.co_filename == logging.__file__:
frame = frame.f_back # type: ignore[assignment]
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
# Let Loguru handle caller info; don't do stack inspection here
logger.opt(exception=record.exc_info).log(level, record.getMessage())
def setup_logging(log_level: int = logging.INFO) -> None:
def setup_logging(log_level: int = logging.INFO, mcp_mode: bool = False) -> None:
# Intercept everything at the root logger
logging.root.handlers = [InterceptHandler()]
logging.root.handlers = [RichInterceptHandler()]
logging.root.setLevel(log_level)
# Remove every other logger's handlers
# and propagate to root logger
# Remove every other logger's handlers and propagate to root logger
for name in logging.root.manager.loggerDict:
# Keep handlers for MCP logger if middleware handles it separately
if mcp_mode and name == "arcade.mcp":
continue
logging.getLogger(name).handlers = []
logging.getLogger(name).propagate = True
# Configure loguru with custom format, no colors
# Remove default handlers from Loguru
logger.remove()
# Configure main Loguru sink
# In MCP mode, all general console logs go to stderr to keep stdout clean
sink_destination = sys.stderr if mcp_mode else sys.stdout
# Configure loguru with a cleaner format and colors
if log_level == logging.DEBUG:
format_string = "<level>{level}</level> | <green>{time:HH:mm:ss}</green> | <cyan>{name}:{file}:{line: <4}</cyan> | <level>{message}</level>"
else:
format_string = (
"<level>{level}</level> | <green>{time:HH:mm:ss}</green> | <level>{message}</level>"
)
logger.configure(
handlers=[
{
"sink": sys.stdout,
"serialize": False,
"sink": sink_destination, # Redirect sink based on mcp_mode
"colorize": True,
"level": log_level,
"format": "{level} [{time:HH:mm:ss.SSS}] {message}"
+ (" {name}:{function}:{line}" if log_level <= logging.DEBUG else "")
+ ("\n{exception}" if "{exception}" in "{message}" else ""),
# Format that ensures timestamp on every line and better alignment
"format": format_string,
# Make sure multiline messages are handled properly
"enqueue": True,
"diagnose": True, # Disable traceback framing which adds noise
}
]
)
if mcp_mode:
logger.debug("Loguru sink configured for stderr in MCP mode.")
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI): # type: ignore[no-untyped-def]
async def lifespan(app: fastapi.FastAPI) -> AsyncGenerator[None, None]:
try:
yield
except asyncio.CancelledError:
except (asyncio.CancelledError, KeyboardInterrupt):
# This is necessary to prevent an unhandled error
# when the user presses Ctrl+C
logger.debug("Lifespan cancelled.")
raise
def serve_default_worker(
@ -75,76 +214,78 @@ def serve_default_worker(
timeout_keep_alive: int = 5,
enable_otel: bool = False,
debug: bool = False,
mcp: bool = False,
**kwargs: Any,
) -> None:
"""
Get an instance of a FastAPI server with the Arcade Worker.
Get a default instance of a FastAPI server with the Arcade Worker
serving tools installed in the current Python environment.
Args:
host: The host to run the server on.
port: The port to run the server on.
disable_auth: Whether to disable authentication.
workers: The number of workers to run.
timeout_keep_alive: The timeout for keep-alive connections.
enable_otel: Whether to enable OpenTelemetry.
debug: Whether to enable debug logging.
mcp: Whether to run worker as MCP server over stdio.
"""
# Setup unified logging
setup_logging(log_level=logging.DEBUG if debug else logging.INFO)
toolkits = Toolkit.find_all_arcade_toolkits()
if not toolkits:
raise RuntimeError("No toolkits found in Python environment.")
# Setup unified logging first
version = get_pkg_version("arcade-ai")
validate_and_get_config()
setup_logging(log_level=logging.DEBUG if debug else logging.INFO, mcp_mode=mcp)
worker_secret = os.environ.get("ARCADE_WORKER_SECRET")
if not disable_auth and not worker_secret:
logger.warning(
"Warning: ARCADE_WORKER_SECRET environment variable is not set. Using 'dev' as the worker secret.",
)
worker_secret = worker_secret or "dev"
toolkits = discover_toolkits()
logger.info("Serving the following toolkits:")
for toolkit in toolkits:
if debug:
for name, tools in toolkit.tools.items():
for tool in tools:
logger.info(f" - {name}: {tool}")
else:
logger.info(f" - {toolkit.name}: {len(toolkit.tools)} tools")
# --- MCP stdio --------------------------------------------------
if mcp:
env_file = kwargs.pop("env_file", None)
_run_mcp_stdio(toolkits, logging_enabled=not debug, env_file=env_file)
return
# --- FastAPI HTTP --------------------------------------------------
app = fastapi.FastAPI(
title="Arcade Worker",
description="Arcade default Worker implementation using FastAPI.",
version="0.1.0",
lifespan=lifespan, # Use custom lifespan to catch errors, notably KeyboardInterrupt (Ctrl+C)
description="A worker for the Arcade platform",
version=version,
docs_url="/docs" if debug else None,
redoc_url="/redoc" if debug else None,
openapi_url="/openapi.json" if debug else None,
lifespan=lifespan,
)
otel_handler = OTELHandler(app, enable=enable_otel)
secret = os.getenv("ARCADE_WORKER_SECRET", None)
if secret is None:
logger.warning("No secret found for Arcade Worker")
logger.info("Setting secret to 'dev'. Set this in production")
secret = "dev" # noqa: S105
worker = FastAPIWorker(
app, secret=worker_secret, disable_auth=disable_auth, otel_meter=otel_handler.get_meter()
otel_handler = OTELHandler(
app, enable=enable_otel, log_level=logging.DEBUG if debug else logging.INFO
)
toolkit_tool_counts = {}
for toolkit in toolkits:
prev_tool_count = worker.catalog.get_tool_count()
worker.register_toolkit(toolkit)
new_tool_count = worker.catalog.get_tool_count()
toolkit_tool_counts[f"{toolkit.name} ({toolkit.package_name})"] = (
new_tool_count - prev_tool_count
)
logger.info("Serving the following toolkits:")
for name, tool_count in toolkit_tool_counts.items():
logger.info(f" - {name}: {tool_count} tools")
logger.info("Starting FastAPI server...")
class CustomUvicornServer(uvicorn.Server):
def install_signal_handlers(self) -> None:
pass # Disable Uvicorn's default signal handlers
config = uvicorn.Config(
_ = FastAPIWorker(
app=app,
secret=secret,
disable_auth=disable_auth,
otel_meter=otel_handler.get_meter(),
)
_run_fastapi_server(
app,
host=host,
port=port,
workers=workers,
timeout_keep_alive=timeout_keep_alive,
log_config=None,
enable_otel=enable_otel,
otel_handler=otel_handler,
**kwargs,
)
server = CustomUvicornServer(config=config)
async def serve() -> None:
await server.serve()
try:
asyncio.run(serve())
except KeyboardInterrupt:
logger.info("Server stopped by user.")
finally:
if enable_otel:
otel_handler.shutdown()
logger.debug("Server shutdown complete.")

View file

@ -1,5 +1,7 @@
import importlib.util
import ipaddress
import os
import shlex
import webbrowser
from dataclasses import dataclass
from datetime import datetime
@ -34,6 +36,11 @@ from arcade.sdk import ToolCatalog, Toolkit
console = Console()
# -----------------------------------------------------------------------------
# Shared helpers for the CLI
# -----------------------------------------------------------------------------
class OrderCommands(TyperGroup):
def list_commands(self, ctx: Context) -> list[str]: # type: ignore[override]
"""Return list of commands in the order appear."""
@ -666,3 +673,81 @@ def get_today_context() -> str:
today = datetime.now().strftime("%Y-%m-%d")
day_of_week = datetime.now().strftime("%A")
return f"Today is {today}, {day_of_week}."
def discover_toolkits() -> list[Toolkit]:
"""Return all Arcade toolkits installed in the active Python environment.
Raises:
RuntimeError: If no toolkits are found, mirroring the behaviour of Toolkit discovery elsewhere.
"""
toolkits = Toolkit.find_all_arcade_toolkits()
if not toolkits:
raise RuntimeError("No toolkits found in Python environment.")
return toolkits
def build_tool_catalog(toolkits: list[Toolkit]) -> ToolCatalog:
"""Construct a ``ToolCatalog`` populated with *toolkits*.
Args:
toolkits: Toolkits to register in the catalog.
Returns:
ToolCatalog
"""
catalog = ToolCatalog()
for tk in toolkits:
catalog.add_toolkit(tk)
return catalog
def _parse_line(line: str) -> tuple[str, str] | None:
"""
Return (key, value) if the line looks like KEY=VALUE, else None.
Handles quotes and escaped chars via shlex.
"""
if not line or line.startswith("#") or "=" not in line:
return None
key, raw_val = line.split("=", 1)
key = key.strip()
raw_val = raw_val.strip()
# Use shlex to handle "quoted strings with # hash" etc.
try:
value = shlex.split(raw_val)[0] if raw_val else ""
except ValueError:
# Fallback: naked value without shlex parsing
value = raw_val
return key, value
def load_dotenv(path: str | Path, *, override: bool = False) -> dict[str, str]:
"""
Load variables from *path* into os.environ.
Args:
path: .env file path
override: replace existing env vars if True
Returns:
The mapping of vars that were added/updated.
"""
path = Path(path).expanduser()
if not path.is_file():
return {}
loaded: dict[str, str] = {}
for raw in path.read_text().splitlines():
parsed = _parse_line(raw.strip())
if parsed is None:
continue
k, v = parsed
if override or k not in os.environ:
os.environ[k] = v
loaded[k] = v
return loaded

View file

@ -326,8 +326,16 @@ class ToolContext(BaseModel):
for item in items:
if item.key.lower() == normalized_key:
return item.value
raise ValueError(f"{item_name.capitalize()} {key} not found in context.")
def set_secret(self, key: str, value: str) -> None:
"""Set a secret for the tool invocation."""
if not self.secrets:
self.secrets = []
secret = ToolSecretItem(key=str(key), value=str(value))
self.secrets.append(secret)
class ToolCallRequest(BaseModel):
"""The request to call (invoke) a tool."""

View file

@ -1,14 +1,16 @@
from __future__ import annotations
import ast
import inspect
import re
from collections.abc import Iterable
from types import UnionType
from typing import Any, Callable, Literal, Optional, TypeVar, Union, get_args, get_origin
from typing import Any, Callable, Literal, TypeVar, Union, get_args, get_origin
T = TypeVar("T")
def first_or_none(_type: type[T], iterable: Iterable[Any]) -> Optional[T]:
def first_or_none(_type: type[T], iterable: Iterable[Any]) -> T | None:
"""
Returns the first item in the iterable that is an instance of the given type, or None if no such item is found.
"""
@ -65,7 +67,7 @@ def does_function_return_value(func: Callable) -> bool:
Returns True if the given function returns a value, i.e. if it has a return statement with a value.
"""
try:
source: Optional[str] = inspect.getsource(func)
source: str | None = inspect.getsource(func)
except OSError:
# Workaround for parameterized unit tests that use a dynamically-generated function
source = getattr(func, "__source__", None)

View file

@ -175,7 +175,7 @@ class BaseWorker(Worker):
"""
Provide a health check that serves as a heartbeat of worker health.
"""
return {"status": "ok", "tool_count": len(self.catalog)}
return {"status": "ok", "tool_count": str(len(self.catalog))}
def register_routes(self, router: Router) -> None:
"""

View file

@ -5,6 +5,11 @@ from pydantic import BaseModel
from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition
CatalogResponse = list[ToolDefinition]
HealthCheckResponse = dict[str, str]
JSONResponse = dict[str, Any]
ResponseData = CatalogResponse | ToolCallResponse | HealthCheckResponse
class RequestData(BaseModel):
"""
@ -17,7 +22,7 @@ class RequestData(BaseModel):
"""The path of the request."""
method: str
"""The method of the request."""
body_json: dict | None = None
body_json: JSONResponse | None = None
"""The deserialized body of the request (e.g. JSON)"""
@ -28,7 +33,13 @@ class Router(ABC):
@abstractmethod
def add_route(
self, endpoint_path: str, handler: Callable, method: str, require_auth: bool = True
self,
endpoint_path: str,
handler: Callable,
method: str,
require_auth: bool = True,
response_type: type[ResponseData] | None = None,
**kwargs: Any,
) -> None:
"""
Add a route to the router.
@ -43,7 +54,7 @@ class Worker(ABC):
"""
@abstractmethod
def get_catalog(self) -> list[ToolDefinition]:
def get_catalog(self) -> CatalogResponse:
"""
Get the catalog of tools available in the worker.
"""
@ -57,7 +68,7 @@ class Worker(ABC):
pass
@abstractmethod
def health_check(self) -> dict[str, Any]:
def health_check(self) -> HealthCheckResponse:
"""
Perform a health check of the worker
"""
@ -76,7 +87,7 @@ class WorkerComponent(ABC):
pass
@abstractmethod
async def __call__(self, request: RequestData) -> Any:
async def __call__(self, request: RequestData) -> ResponseData:
"""
Handle the request.
"""

View file

@ -1,9 +1,15 @@
from typing import Any
from opentelemetry import trace
from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition
from arcade.worker.core.common import RequestData, Router, Worker, WorkerComponent
from arcade.worker.core.common import (
CatalogResponse,
HealthCheckResponse,
RequestData,
Router,
ToolCallRequest,
ToolCallResponse,
Worker,
WorkerComponent,
)
class CatalogComponent(WorkerComponent):
@ -14,9 +20,18 @@ class CatalogComponent(WorkerComponent):
"""
Register the catalog route with the router.
"""
router.add_route("tools", self, method="GET")
router.add_route(
"tools",
self,
method="GET",
response_type=CatalogResponse,
operation_id="get_catalog",
description="Get the catalog of tools",
summary="Get the catalog of tools",
tags=["Arcade"],
)
async def __call__(self, request: RequestData) -> list[ToolDefinition]:
async def __call__(self, request: RequestData) -> CatalogResponse:
"""
Handle the request to get the catalog.
"""
@ -33,7 +48,16 @@ class CallToolComponent(WorkerComponent):
"""
Register the call tool route with the router.
"""
router.add_route("tools/invoke", self, method="POST")
router.add_route(
"tools/invoke",
self,
method="POST",
response_type=ToolCallResponse,
operation_id="call_tool",
description="Call a tool",
summary="Call a tool",
tags=["Arcade"],
)
async def __call__(self, request: RequestData) -> ToolCallResponse:
"""
@ -54,9 +78,19 @@ class HealthCheckComponent(WorkerComponent):
"""
Register the health check route with the router.
"""
router.add_route("health", self, method="GET", require_auth=False)
router.add_route(
"health",
self,
method="GET",
require_auth=False,
response_type=HealthCheckResponse,
operation_id="health_check",
description="Check the health of the worker",
summary="Check the health of the worker",
tags=["Arcade"],
)
async def __call__(self, request: RequestData) -> dict[str, Any]:
async def __call__(self, request: RequestData) -> HealthCheckResponse:
"""
Handle the request for a health check.
"""

View file

@ -0,0 +1,3 @@
from .worker import FastAPIWorker
__all__ = ["FastAPIWorker"]

View file

@ -9,7 +9,7 @@ from arcade.worker.core.base import (
BaseWorker,
Router,
)
from arcade.worker.core.common import RequestData
from arcade.worker.core.common import RequestData, ResponseData, WorkerComponent
from arcade.worker.fastapi.auth import validate_engine_request
from arcade.worker.utils import is_async_callable
@ -30,12 +30,21 @@ class FastAPIWorker(BaseWorker):
"""
Initialize the FastAPIWorker with a FastAPI app instance.
If no secret is provided, the worker will use the ARCADE_WORKER_SECRET environment variable.
Args:
app: The FastAPI app to host the worker in
secret: Optional secret for authorization
disable_auth: Whether to disable authorization
otel_meter: Optional OpenTelemetry meter
"""
super().__init__(secret, disable_auth, otel_meter)
self.app = app
self.router = FastAPIRouter(app, self)
self.register_routes(self.router)
# Initialize components
self.components: list[WorkerComponent] = []
security = HTTPBearer() # Authorization: Bearer <xxx>
@ -81,7 +90,13 @@ class FastAPIRouter(Router):
return wrapped_handler
def add_route(
self, endpoint_path: str, handler: Callable, method: str, require_auth: bool = True
self,
endpoint_path: str,
handler: Callable,
method: str,
require_auth: bool = True,
response_type: type[ResponseData] | None = None,
**kwargs: Any,
) -> None:
"""
Add a route to the FastAPI application.
@ -90,4 +105,7 @@ class FastAPIRouter(Router):
f"{self.worker.base_path}/{endpoint_path}",
self._wrap_handler(handler, require_auth),
methods=[method],
response_model=response_type,
# **kwargs to pass to FastAPI
**kwargs,
)

View file

@ -0,0 +1,7 @@
"""
MCP (Model Context Protocol) support for Arcade workers.
"""
from arcade.worker.mcp.stdio import StdioServer
__all__ = ["StdioServer"]

View file

@ -0,0 +1,188 @@
import json
import logging
from enum import Enum
from typing import Any
from arcade.core.catalog import MaterializedTool
# Type aliases for MCP types
MCPTool = dict[str, Any]
MCPTextContent = dict[str, Any]
MCPImageContent = dict[str, Any]
MCPEmbeddedResource = dict[str, Any]
MCPContent = MCPTextContent | MCPImageContent | MCPEmbeddedResource
logger = logging.getLogger("arcade.mcp")
def create_mcp_tool(tool: MaterializedTool) -> dict[str, Any] | None: # noqa: C901
"""
Create an MCP-compatible tool definition from an Arcade tool.
Args:
tool: An Arcade tool object
Returns:
An MCP tool definition or None if the tool cannot be converted
"""
try:
name = getattr(tool.definition, "fully_qualified_name", None) or getattr(
tool.definition, "name", "unknown"
)
description = getattr(tool.definition, "description", "No description available")
# Extract parameters from the input model
parameters = {}
required = []
if (
hasattr(tool, "input_model")
and tool.input_model is not None
and hasattr(tool.input_model, "model_fields")
):
for field_name, field in tool.input_model.model_fields.items():
# Skip internal tool context parameters
if field_name == getattr(
tool.definition.input, "tool_context_parameter_name", None
):
continue
# Get field type information
field_type = getattr(field, "annotation", None)
field_type_name = "string" # default
# Safety check for field_type
if field_type is int:
field_type_name = "integer"
elif field_type is float:
field_type_name = "number"
elif field_type is bool:
field_type_name = "boolean"
elif field_type is list or str(field_type).startswith("list["):
field_type_name = "array"
elif field_type is dict or str(field_type).startswith("dict["):
field_type_name = "object"
# Get description with fallback
field_description = getattr(field, "description", None)
if not field_description:
field_description = f"Parameter: {field_name}"
# Create parameter definition
param_def = {
"type": field_type_name,
"description": field_description,
}
# Enum support: if the field annotation is an Enum, add allowed values
enum_type = None
if hasattr(field, "annotation"):
ann = field.annotation
# Handle typing.Annotated[Enum, ...]
if getattr(ann, "__origin__", None) is not None and hasattr(ann, "__args__"):
for arg in ann.__args__: # type: ignore[union-attr]
if isinstance(arg, type) and issubclass(arg, Enum):
enum_type = arg
break
elif isinstance(ann, type) and issubclass(ann, Enum):
enum_type = ann
if enum_type is not None:
param_def["enum"] = [e.value for e in enum_type]
parameters[field_name] = param_def
# In Pydantic v2, check if field is required based on default value
try:
if field.is_required():
required.append(field_name)
except (AttributeError, TypeError):
# Fallback if is_required() doesn't exist or fails
try:
has_default = getattr(field, "default", None) is not None
has_factory = getattr(field, "default_factory", None) is not None
if not (has_default or has_factory):
required.append(field_name)
except Exception:
# Ultimate fallback - assume required if we can't determine
logger.debug(
f"Could not determine if field {field_name} is required, assuming optional"
)
# Create the input schema with explicit properties and required fields
input_schema = {
"type": "object",
"properties": parameters,
}
# Only include required field if we have required parameters
if required:
input_schema["required"] = required
# Add annotations based on tool metadata
annotations = {}
# Use tool name as title if available
annotations["title"] = getattr(tool.definition, "title", str(name).replace(".", "_"))
# Determine hints based on tool properties
if hasattr(tool.definition, "metadata"):
metadata = tool.definition.metadata or {}
annotations["readOnlyHint"] = metadata.get("read_only", False)
annotations["destructiveHint"] = metadata.get("destructive", False)
annotations["idempotentHint"] = metadata.get("idempotent", True)
annotations["openWorldHint"] = metadata.get("open_world", False)
# Create the final tool definition
tool_def: MCPTool = {
"name": str(name).replace(".", "_"),
"description": str(description),
"inputSchema": input_schema,
"annotations": annotations,
}
logger.debug(f"Created tool definition for {name}")
except Exception:
logger.exception(
f"Error creating MCP tool definition for {getattr(tool, 'name', str(tool))}"
)
return None
return tool_def
def convert_to_mcp_content(value: Any) -> list[dict[str, Any]]:
"""
Convert a Python value to MCP-compatible content.
"""
if value is None:
return []
if isinstance(value, (str, bool, int, float)):
return [{"type": "text", "text": str(value)}]
if isinstance(value, (dict, list)):
return [{"type": "text", "text": json.dumps(value)}]
# Default fallback
return [{"type": "text", "text": str(value)}]
def _map_type_to_json_schema_type(val_type: str) -> str:
"""
Map Arcade value types to JSON schema types.
Args:
val_type: The Arcade value type as a string.
Returns:
The corresponding JSON schema type as a string.
"""
mapping: dict[str, str] = {
"string": "string",
"integer": "integer",
"number": "number",
"boolean": "boolean",
"json": "object",
"array": "array",
}
return mapping.get(val_type, "string")

View file

@ -0,0 +1,215 @@
import json
import logging
import sys
import time
from typing import Any
from arcade.worker.mcp.types import (
JSONRPCError,
JSONRPCRequest,
JSONRPCResponse,
MCPMessage,
)
logger = logging.getLogger("arcade.mcp")
class MCPLoggingMiddleware:
"""
Middleware for logging MCP requests and responses.
Logs request and response details, including timing and errors.
"""
def __init__(
self,
log_level: str = "INFO",
log_request_body: bool = False,
log_response_body: bool = False,
log_errors: bool = True,
min_duration_to_log_ms: int = 0,
stdio_mode: bool = False,
) -> None:
"""
Initialize the MCP logging middleware.
Args:
log_level: Logging level (default: "INFO").
log_request_body: Whether to log full request bodies (default: False).
log_response_body: Whether to log full response bodies (default: False).
log_errors: Whether to log errors at ERROR level (default: True).
min_duration_to_log_ms: Minimum duration in ms to log (0 logs all).
stdio_mode: Whether running in stdio mode (redirects logs to stderr).
"""
self.log_level = getattr(logging, log_level.upper())
self.log_request_body = log_request_body
self.log_response_body = log_response_body
self.log_errors = log_errors
self.min_duration_to_log_ms = min_duration_to_log_ms
self.request_log_format = "[MCP>] {method}{params_str} (id: {id})"
self.response_log_format = "[MCP<] {method} completed in {duration:.2f}ms (id: {id})"
self.error_log_format = "[MCP!] {method} error: {error} (id: {id})"
# If in stdio mode, ensure MCP logs go to stderr
if stdio_mode:
self._redirect_logs_to_stderr()
# Log that middleware is initialized
logger.debug(f"MCP logging middleware initialized (level: {log_level})")
def _redirect_logs_to_stderr(self) -> None:
"""Redirect MCP logs to stderr to avoid interfering with stdio communication."""
# Remove any existing handlers
for handler in logger.handlers[:]:
logger.removeHandler(handler)
# Add a stderr handler
stderr_handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stderr_handler.setFormatter(formatter)
logger.addHandler(stderr_handler)
# Ensure we're not propagating to root logger which might log to stdout
logger.propagate = False
logger.debug("MCP logs redirected to stderr for stdio mode")
def __call__(self, message: MCPMessage, direction: str) -> MCPMessage:
"""
Process and log an MCP message.
Args:
message: The MCP message to process.
direction: The message direction ("request" or "response").
Returns:
The original message (unmodified).
"""
if direction == "request":
self._log_request(message)
else:
self._log_response(message)
return message
def _log_request(self, message: MCPMessage) -> None:
"""
Log an MCP request message.
"""
if not isinstance(message, JSONRPCRequest):
logger.debug(f"Ignoring non-request message: {type(message).__name__}")
return
try:
# Store request start time for duration calculation
message._mcp_start_time = time.time() # type: ignore[attr-defined]
# Format parameters for logging
params_str = ""
if self.log_request_body and hasattr(message, "params") and message.params is not None:
params_str = f": {self._format_params(message.params)}"
log_msg = self.request_log_format.format(
method=message.method, params_str=params_str, id=getattr(message, "id", "none")
)
logger.log(self.log_level, log_msg)
except Exception:
logger.exception("Error logging request")
def _log_response(self, message: MCPMessage) -> None:
"""
Log an MCP response message.
"""
if not isinstance(message, (JSONRPCResponse, JSONRPCError)):
logger.debug(f"Ignoring non-response message: {type(message).__name__}")
return
try:
# Calculate request duration if we have the start time
duration_ms = 0
request = getattr(message, "_request", None)
if request:
start_time = getattr(request, "_mcp_start_time", None)
if start_time:
duration_ms = (time.time() - start_time) * 1000
else:
start_time = getattr(message, "_mcp_start_time", None)
if start_time:
duration_ms = (time.time() - start_time) * 1000
# Skip if below minimum duration threshold
if self.min_duration_to_log_ms > 0 and duration_ms < self.min_duration_to_log_ms:
return
# Handle error responses
if hasattr(message, "error") and message.error is not None:
if self.log_errors:
error_msg = self.error_log_format.format(
method=getattr(message, "method", "unknown"),
error=getattr(message.error, "message", str(message.error)),
id=getattr(message, "id", "none"),
)
logger.error(error_msg)
return
# Log successful response
result_str = ""
if self.log_response_body and hasattr(message, "result"):
result_str = f": {self._format_result(message.result)}"
log_msg = self.response_log_format.format(
method=getattr(message, "method", "unknown"),
duration=duration_ms,
id=getattr(message, "id", "none"),
result_str=result_str,
)
logger.log(self.log_level, log_msg)
except Exception:
logger.exception("Error logging response")
def _format_params(self, params: Any) -> str:
"""
Format parameters for logging.
"""
try:
if isinstance(params, dict):
# Handle common MCP params specially
if "name" in params and "arguments" in params:
return f"{params['name']}({json.dumps(params.get('arguments', {}))})"
return json.dumps(params)
return str(params)
except Exception:
logger.debug(f"Error formatting params {params!s}")
return str(params)
def _format_result(self, result: Any) -> str:
"""
Format result for logging.
"""
try:
if isinstance(result, dict):
return json.dumps(result)
return str(result)
except Exception as e:
logger.debug(f"Error formatting result {e!s}")
return str(result)
def create_mcp_logging_middleware(**config: Any) -> MCPLoggingMiddleware:
"""
Create an MCP logging middleware with the given configuration.
Args:
**config: Configuration options.
Returns:
An MCPLoggingMiddleware instance.
"""
return MCPLoggingMiddleware(
log_level=config.get("log_level", "INFO"),
log_request_body=config.get("log_request_body", False),
log_response_body=config.get("log_response_body", False),
log_errors=config.get("log_errors", True),
min_duration_to_log_ms=config.get("min_duration_to_log_ms", 0),
stdio_mode=config.get("stdio_mode", False),
)

View file

@ -0,0 +1,83 @@
import inspect
import json
import logging
from typing import Any, Callable, TypeVar
from arcade.worker.mcp.types import InitializeRequest, JSONRPCRequest, MCPMessage
logger = logging.getLogger("arcade.mcp")
T = TypeVar("T")
# Type definition for middleware functions
MessageProcessor = Callable[[Any, str], Any]
class MCPMessageProcessor:
"""
Processes MCP messages through a chain of middleware.
Supports both synchronous and asynchronous middleware.
"""
def __init__(self) -> None:
self.middleware: list[Callable[[MCPMessage, str], Any]] = []
def add_middleware(self, mw: Callable[[MCPMessage, str], Any]) -> None:
self.middleware.append(mw)
async def process(self, message: Any, direction: str) -> Any: # noqa: C901
# First, try to parse the message if it's a string
if isinstance(message, str):
# Strip any whitespace including newlines
message = message.strip()
if not message:
return None
try:
parsed = json.loads(message)
if isinstance(parsed, dict):
method = parsed.get("method")
# Convert to appropriate message type
if method == "initialize" and "id" in parsed:
logger.debug(f"Parsed initialize request: {parsed}")
message = InitializeRequest(**parsed)
elif method and method.startswith("notifications/"):
# It's a notification, log it but pass through as dict
logger.debug(f"Received notification: {method}")
# Keep as parsed dict to avoid validation errors on unknown notifications
message = parsed
elif "method" in parsed and "id" in parsed:
# Regular method request
logger.debug(f"Parsed method request: {method}")
message = JSONRPCRequest(**parsed)
# Other message types can be handled similarly
except json.JSONDecodeError:
logger.warning(f"Failed to parse message as JSON: {message[:100]}...")
except Exception:
logger.exception("Error processing message")
# Process through middleware chain
result = message
for mw in self.middleware:
try:
if inspect.iscoroutinefunction(mw):
result = await mw(result, direction)
else:
result = mw(result, direction)
except Exception:
logger.exception(f"Error in middleware {mw}")
return result
async def process_request(self, message: Any) -> Any:
return await self.process(message, "request")
async def process_response(self, message: Any) -> Any:
return await self.process(message, "response")
def create_message_processor(*middleware: MessageProcessor) -> MCPMessageProcessor:
processor = MCPMessageProcessor()
for m in middleware:
if m is not None:
processor.add_middleware(m)
return processor

View file

@ -0,0 +1,601 @@
import asyncio
import logging
import os
import uuid
from enum import Enum
from typing import Any, Callable, Union
from arcadepy import ArcadeError, AsyncArcade
from arcadepy.types.auth_authorize_params import AuthRequirement, AuthRequirementOauth2
from arcadepy.types.shared import AuthorizationResponse
from arcade.core.catalog import MaterializedTool, ToolCatalog
from arcade.core.executor import ToolExecutor
from arcade.core.schema import ToolAuthorizationContext, ToolContext
from arcade.worker.mcp.convert import convert_to_mcp_content, create_mcp_tool
from arcade.worker.mcp.logging import create_mcp_logging_middleware
from arcade.worker.mcp.message_processor import MCPMessageProcessor, create_message_processor
from arcade.worker.mcp.types import (
CallToolRequest,
CallToolResponse,
CallToolResult,
CancelRequest,
Implementation,
InitializeRequest,
InitializeResponse,
InitializeResult,
JSONRPCError,
JSONRPCResponse,
ListPromptsRequest,
ListPromptsResponse,
ListResourcesRequest,
ListResourcesResponse,
ListToolsRequest,
ListToolsResponse,
ListToolsResult,
PingRequest,
PingResponse,
ProgressNotification,
ServerCapabilities,
ShutdownRequest,
ShutdownResponse,
Tool,
)
logger = logging.getLogger("arcade.mcp")
MCP_PROTOCOL_VERSION = "2024-11-05"
class MessageMethod(str, Enum):
"""Enumeration of supported MCP message methods"""
PING = "ping"
INITIALIZE = "initialize"
LIST_TOOLS = "tools/list"
CALL_TOOL = "tools/call"
PROGRESS = "progress"
CANCEL = "$/cancelRequest"
SHUTDOWN = "shutdown"
LIST_RESOURCES = "resources/list"
LIST_PROMPTS = "prompts/list"
class MCPServer:
"""
Unified async MCP server that manages connections, middleware, and tool invocation.
Handles protocol-level messages (ping, initialize, list_tools, call_tool, etc.).
"""
def __init__(
self,
tool_catalog: Any,
enable_logging: bool = True,
**client_kwargs: dict[str, Any],
) -> None:
"""
Initialize the MCP server.
Args:
tool_catalog: Catalog of available tools
**client_kwargs: Additional arguments to pass to the AsyncArcade client
"""
self.tool_catalog: ToolCatalog = tool_catalog
self.message_processor: MCPMessageProcessor = create_message_processor()
# Pop middleware_config from client_kwargs regardless of logging state,
# as it's internal config not meant for AsyncArcade.
middleware_config = client_kwargs.pop("middleware_config", {})
if enable_logging:
# Create and add the logging middleware if logging is enabled.
# Note: enable_logging must be True for this middleware (and its stdio_mode behavior)
# to be activated.
self.message_processor.add_middleware(
create_mcp_logging_middleware(**middleware_config)
)
self._shutdown: bool = False
# Initialize AsyncArcade with the *remaining* client_kwargs
self.arcade = AsyncArcade(**client_kwargs) # type: ignore[arg-type]
# Initialize handler dispatch table
self._method_handlers: dict[str, Callable] = {
MessageMethod.PING: self._handle_ping,
MessageMethod.INITIALIZE: self._handle_initialize,
MessageMethod.LIST_TOOLS: self._handle_list_tools,
MessageMethod.CALL_TOOL: self._handle_call_tool,
MessageMethod.PROGRESS: self._handle_progress,
MessageMethod.CANCEL: self._handle_cancel,
MessageMethod.SHUTDOWN: self._handle_shutdown,
MessageMethod.LIST_RESOURCES: self._handle_list_resources,
MessageMethod.LIST_PROMPTS: self._handle_list_prompts,
}
async def run_connection(
self,
read_stream: Any,
write_stream: Any,
init_options: Any,
) -> None:
"""
Handle a single MCP connection (SSE or stdio).
Args:
read_stream: Async iterable yielding incoming messages.
write_stream: Object with an async send(message) method.
init_options: Initialization options for the connection.
"""
# Generate a user ID if possible
user_id = self._get_user_id(init_options)
try:
logger.info(f"Starting MCP connection for user {user_id}")
async for message in read_stream:
# Process the message
response = await self.handle_message(message, user_id=user_id)
# Skip sending responses for None (e.g., notifications)
if response is None:
continue
await self._send_response(write_stream, response)
except asyncio.CancelledError:
logger.info("Connection cancelled")
except Exception:
logger.exception("Error in connection")
def _get_user_id(self, init_options: Any) -> str:
"""
Get the user ID for a connection.
Args:
init_options: Initialization options for the connection
Returns:
A user ID string
"""
try:
from arcade.core.config import config
# Prefer config.user.email if available
if config.user and config.user.email:
return config.user.email
except ValueError:
logger.debug("No logged in user for MCP Server")
fallback = str(uuid.uuid4())
if os.environ.get("ARCADE_USER_ID", None):
return os.environ.get("ARCADE_USER_ID", fallback)
elif isinstance(init_options, dict):
user_id = init_options.get("user_id")
if user_id:
return str(user_id)
# Fallback to random UUID
return str(fallback)
async def _send_response(self, write_stream: Any, response: Any) -> None:
"""
Send a response to the client.
Args:
write_stream: Stream to write the response to
response: Response object to send
"""
# Ensure the response is properly serialized to JSON
if hasattr(response, "model_dump_json"):
# It's a Pydantic model, serialize it
json_response = response.model_dump_json()
# Ensure it ends with a newline for JSON-RPC-over-stdio
if not json_response.endswith("\n"):
json_response += "\n"
logger.debug(f"Sending response: {json_response[:200]}...")
await write_stream.send(json_response)
elif isinstance(response, dict):
# It's a dict, convert to JSON
import json
json_response = json.dumps(response)
# Ensure it ends with a newline for JSON-RPC-over-stdio
if not json_response.endswith("\n"):
json_response += "\n"
logger.debug(f"Sending response: {json_response[:200]}...")
await write_stream.send(json_response)
else:
# It's already a string or something else
response_str = str(response)
# Ensure it ends with a newline for JSON-RPC-over-stdio
if not response_str.endswith("\n"):
response_str += "\n"
logger.debug(f"Sending raw response type: {type(response)}")
await write_stream.send(response_str)
async def handle_message(self, message: Any, user_id: str | None = None) -> Any:
"""
Handle an incoming MCP message. Processes it through middleware and dispatches
to the appropriate handler based on the message method.
Args:
message: The raw incoming message
user_id: Optional user ID for authentication
Returns:
A properly formatted response message
"""
# Pre-process message through middleware
processed = await self.message_processor.process_request(message)
# Handle special case for JSON string initialize requests
if isinstance(processed, str):
try:
import json
parsed = json.loads(processed)
if (
isinstance(parsed, dict)
and parsed.get("method") == MessageMethod.INITIALIZE
and "id" in parsed
):
# This is an initialize request
init_response = await self._handle_initialize(InitializeRequest(**parsed))
return init_response
except Exception:
logger.exception("Error processing JSON string")
# Not parseable JSON, continue with normal processing
pass
# Check if it's a notification
if hasattr(processed, "method"):
method = getattr(processed, "method", None)
# Handle notifications (methods starting with "notifications/")
if method and method.startswith("notifications/"):
await self._handle_notification(method, processed)
return None
# Handle regular methods using the dispatch table
if method in self._method_handlers:
# If it's a call_tool request, we need to pass the user_id
if method == MessageMethod.CALL_TOOL:
return await self._method_handlers[method](processed, user_id=user_id)
# For other methods, just pass the processed message
return await self._method_handlers[method](processed)
# Unknown method
return JSONRPCError(
id=getattr(processed, "id", None),
error={
"code": -32601,
"message": f"Method not found: {method}",
},
)
# If it's not a method request, just pass it through
return processed
async def _handle_notification(self, method: str, message: Any) -> None:
"""
Handle notification messages.
Args:
method: The notification method
message: The notification message
"""
if method == "notifications/cancelled":
logger.info(f"Request cancelled: {getattr(message, 'params', {})}")
else:
logger.debug(f"Received notification: {method}")
async def _handle_ping(self, message: PingRequest) -> PingResponse:
"""
Handle a ping request and return a pong response.
Args:
message: The ping request
Returns:
A properly formatted pong response
"""
return PingResponse(id=message.id)
async def _handle_initialize(self, message: InitializeRequest) -> InitializeResponse:
"""
Handle an initialize request and return a proper initialize response.
Args:
message: The initialize request
Returns:
A properly formatted initialize response
"""
# Create the result data
result = InitializeResult(
protocolVersion=MCP_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
serverInfo=Implementation(name="Arcade MCP Worker", version="0.1.0"),
instructions="Arcade MCP Worker initialized.",
)
# Construct proper response with result field
response = InitializeResponse(id=message.id, result=result)
logger.debug(f"Initialize response: {response.model_dump_json()}")
return response
async def _handle_list_tools(
self, message: ListToolsRequest
) -> Union[ListToolsResponse, JSONRPCError]:
"""
Handle a tools/list request and return a list of available tools.
Args:
message: The tools/list request
Returns:
A properly formatted tools/list response or error
"""
try:
# Get all tools from the catalog
tools = []
tool_conversion_errors = []
for tool in self.tool_catalog:
try:
mcp_tool = create_mcp_tool(tool)
if mcp_tool:
tools.append(mcp_tool)
except Exception:
tool_name = getattr(tool, "name", str(tool))
logger.exception(f"Error converting tool: {tool_name}")
tool_conversion_errors.append(tool_name)
# Log summary if we had errors
if tool_conversion_errors:
logger.warning(
f"Failed to convert {len(tool_conversion_errors)} tools: {tool_conversion_errors}"
)
# Create tool objects with exception handling for each one
tool_objects = []
for t in tools:
try:
# Make input schema optional if missing
tool_dict = dict(t)
if "inputSchema" not in tool_dict:
tool_dict["inputSchema"] = {"type": "object", "properties": {}}
tool_objects.append(Tool(**tool_dict))
except Exception:
logger.exception(f"Error creating Tool object for {t.get('name', 'unknown')}")
# Return successful response with the tools we were able to convert
result = ListToolsResult(tools=tool_objects)
response = ListToolsResponse(id=message.id, result=result)
except Exception:
logger.exception("Error listing tools")
return JSONRPCError(
id=message.id,
error={
"code": -32603,
"message": "Internal error listing tools",
},
)
return response
async def _handle_call_tool(
self, message: CallToolRequest, user_id: str | None = None
) -> CallToolResponse:
"""
Handle a tools/call request to execute a tool.
Args:
message: The tools/call request
user_id: Optional user ID for authentication
Returns:
A properly formatted tools/call response
"""
tool_name: str = message.params["name"]
# Extract input from the correct field
input_params: dict[str, Any] = message.params.get("input", {})
if not input_params:
input_params = message.params.get("arguments", {})
logger.info(f"Handling tool call for {tool_name}")
try:
tool = self.tool_catalog.get_tool_by_name(tool_name, separator="_")
tool_context = ToolContext()
# Set up context with secrets
if tool.definition.requirements and tool.definition.requirements.secrets:
self._setup_tool_secrets(tool, tool_context)
# Handle authorization if needed
requirement = self._get_auth_requirement(tool)
if requirement:
auth_result = await self._check_authorization(requirement, user_id=user_id)
if auth_result.status != "completed":
return CallToolResponse(
id=message.id,
result=CallToolResult(content=[{"type": "text", "text": auth_result.url}]),
)
else:
tool_context.authorization = ToolAuthorizationContext(
token=auth_result.context.token if auth_result.context else None,
user_info={"user_id": user_id} if user_id else {},
)
# Execute the tool
logger.debug(f"Executing tool {tool_name} with input: {input_params}")
result = await ToolExecutor.run(
func=tool.tool,
definition=tool.definition,
input_model=tool.input_model,
output_model=tool.output_model,
context=tool_context,
**input_params,
)
logger.debug(f"Tool result: {result}")
if result.value:
return CallToolResponse(
id=message.id,
result=CallToolResult(content=convert_to_mcp_content(result.value)),
)
else:
error = result.error or "Error calling tool"
logger.error(f"Tool {tool_name} returned error: {error}")
return CallToolResponse(
id=message.id,
result=CallToolResult(
content=[{"type": "text", "text": convert_to_mcp_content(error)}]
),
)
except Exception as e:
logger.exception(f"Error calling tool {tool_name}")
error = f"Error calling tool {tool_name}: {e!s}"
return CallToolResponse(
id=message.id,
result=CallToolResult(
content=[{"type": "text", "text": convert_to_mcp_content(error)}]
),
)
def _setup_tool_secrets(self, tool: Any, tool_context: ToolContext) -> None:
"""
Set up tool secrets in the tool context.
Args:
tool: The tool to set up secrets for
tool_context: The tool context to update
"""
for secret in tool.definition.requirements.secrets:
value = os.environ.get(secret.key)
if value is not None:
tool_context.set_secret(secret.key, value)
async def _handle_progress(self, message: ProgressNotification) -> JSONRPCResponse:
"""
Handle a progress notification.
Args:
message: The progress notification
Returns:
A response acknowledging the notification
"""
return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True})
async def _handle_cancel(self, message: CancelRequest) -> JSONRPCResponse:
"""
Handle a cancel request.
Args:
message: The cancel request
Returns:
A response acknowledging the cancellation
"""
return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True})
async def _handle_shutdown(self, message: ShutdownRequest) -> ShutdownResponse:
"""
Handle a shutdown request.
Args:
message: The shutdown request
Returns:
A response acknowledging the shutdown request
"""
# Schedule a task to shutdown the server after sending the response
proc = asyncio.create_task(self.shutdown())
proc.add_done_callback(lambda _: logger.info("MCP server shutdown complete"))
return ShutdownResponse(id=message.id, result={"ok": True})
async def _handle_list_resources(self, message: ListResourcesRequest) -> ListResourcesResponse:
"""
Handle a resources/list request.
Args:
message: The resources/list request
Returns:
A properly formatted resources/list response
"""
return ListResourcesResponse(id=message.id, result={"resources": []})
async def _handle_list_prompts(self, message: ListPromptsRequest) -> ListPromptsResponse:
"""
Handle a prompts/list request.
Args:
message: The prompts/list request
Returns:
A properly formatted prompts/list response
"""
return ListPromptsResponse(id=message.id, result={"prompts": []})
def _get_auth_requirement(self, tool: MaterializedTool) -> AuthRequirement | None:
"""
Get the authentication requirement for a tool.
Args:
tool: The tool to get the requirement for
Returns:
An authentication requirement or None if not required
"""
req = tool.definition.requirements.authorization
if not req:
return None
if not req.provider_id and not req.provider_type:
return None
if hasattr(req, "oauth2") and req.oauth2:
return AuthRequirement(
provider_id=str(req.provider_id),
provider_type=str(req.provider_type),
oauth2=AuthRequirementOauth2(scopes=req.oauth2.scopes or []),
)
return AuthRequirement(
provider_id=str(req.provider_id),
provider_type=str(req.provider_type),
)
async def _check_authorization(
self, auth_requirement: AuthRequirement, user_id: str | None = None
) -> AuthorizationResponse:
"""
Check if a tool is authorized for a user.
Args:
tool: The tool to check authorization for
user_id: The user ID to check authorization for
Returns:
An authorization response
Raises:
RuntimeError: If the tool has no authorization requirement
Exception: If authorization fails
"""
try:
response = await self.arcade.auth.authorize(
auth_requirement=auth_requirement,
user_id=user_id or "anonymous",
)
logger.debug(f"Authorization response: {response}")
except ArcadeError:
logger.exception("Error authorizing tool")
raise
return response
async def shutdown(self) -> None:
"""Shutdown the server."""
self._shutdown = True
logger.info("MCP server shutdown complete")

View file

@ -0,0 +1,185 @@
import asyncio
import logging
import queue
import signal
import sys
import threading
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, TypeVar
if TYPE_CHECKING:
pass
from arcade.worker.mcp.server import MCPServer
logger = logging.getLogger("arcade.mcp")
T = TypeVar("T")
def stdio_reader(stdin: object, q: queue.Queue[str | None]) -> None:
"""Read lines from stdin and put them into a queue."""
for line in stdin: # type: ignore[attr-defined]
q.put(line)
q.put(None)
def stdio_writer(stdout: object, q: queue.Queue[str | None]) -> None:
"""Write messages from a queue to stdout."""
try:
while True:
msg = q.get()
if msg is None:
break
# Ensure message ends with a newline for proper JSON-RPC-over-stdio
if not msg.endswith("\n"):
msg += "\n"
stdout.write(msg) # type: ignore[attr-defined]
stdout.flush() # type: ignore[attr-defined]
except Exception:
logger.exception("Error in stdio writer")
class StdioServer(MCPServer):
"""
Stdio server that handles signals and cleanup.
"""
def __init__(
self,
tool_catalog: Any,
enable_logging: bool = True,
**client_kwargs: dict[str, Any],
):
# Set up stdio-specific middleware configuration
middleware_config = client_kwargs.get("middleware_config", {})
middleware_config["stdio_mode"] = True
client_kwargs["middleware_config"] = middleware_config
super().__init__(tool_catalog, enable_logging, **client_kwargs)
self.read_q: queue.Queue[str | None] = queue.Queue()
self.write_q: queue.Queue[str | None] = queue.Queue()
self.reader_thread: threading.Thread | None = None
self.writer_thread: threading.Thread | None = None
self.running = False
self.shutdown_event = asyncio.Event()
def start_io_threads(self) -> None:
"""Start stdio reader and writer threads."""
self.reader_thread = threading.Thread(
target=self._stdio_reader, args=(sys.stdin, self.read_q), daemon=True
)
self.writer_thread = threading.Thread(
target=self._stdio_writer, args=(sys.stdout, self.write_q), daemon=True
)
self.reader_thread.start()
self.writer_thread.start()
def _stdio_reader(self, stdin: object, q: queue.Queue[str | None]) -> None:
"""Read lines from stdin and put them into a queue."""
try:
for line in stdin: # type: ignore[attr-defined]
if not self.running:
break
q.put(line)
except Exception:
logger.exception("Error in stdio reader")
finally:
q.put(None) # Signal EOF
def _stdio_writer(self, stdout: object, q: queue.Queue[str | None]) -> None:
"""Write messages from a queue to stdout."""
try:
while self.running:
msg = q.get()
if msg is None:
break
stdout.write(msg) # type: ignore[attr-defined]
stdout.flush() # type: ignore[attr-defined]
except Exception:
logger.exception("Error in stdio writer")
async def _read_stream(self) -> AsyncGenerator[str, None]:
"""Async generator that yields lines from the read queue."""
while self.running:
try:
line = await asyncio.to_thread(self.read_q.get)
if line is None:
break
yield line
except asyncio.CancelledError:
break
except Exception:
logger.exception("Error reading from stdin")
break
async def shutdown(self) -> None:
"""Gracefully shut down the server."""
if not self.running:
return
logger.info("Shutting down stdio server...")
self.running = False
# Signal shutdown to MCP server
await self.shutdown()
# Clean up IO queues and threads
try:
if self.read_q:
self.read_q.put(None)
if self.write_q:
self.write_q.put(None)
except Exception:
logger.exception("Error during shutdown")
# Signal completion
self.shutdown_event.set()
logger.info("Stdio server shutdown complete")
async def run(self) -> None:
"""Run the stdio server with signal handling."""
self.running = True
# Set up signal handlers
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(sig, lambda: asyncio.create_task(self.shutdown()))
except NotImplementedError:
# Windows doesn't support POSIX signals
if sys.platform == "win32":
logger.warning("Signal handling not fully supported on Windows")
else:
logger.warning(f"Failed to set up signal handler for {sig}")
# Start IO threads
self.start_io_threads()
logger.info("Starting MCP server with stdio transport")
# Create WriteStream class for MCP server
class WriteStream:
async def send(self_, message: str) -> None:
if self.running:
await asyncio.to_thread(self.write_q.put, message)
try:
# Run MCP server connection
await self.run_connection(self._read_stream(), WriteStream(), None)
except asyncio.CancelledError:
# Handle cancellation
logger.info("Server operation cancelled")
except KeyboardInterrupt:
# Handle keyboard interrupt
logger.info("Keyboard interrupt received")
except Exception:
# Handle unexpected errors
logger.exception("Unexpected error")
finally:
# Ensure we clean up
await self.shutdown()
# Wait for shutdown to complete
await self.shutdown_event.wait()

View file

@ -0,0 +1,383 @@
import json
from collections.abc import Callable
from typing import (
Any,
Generic,
Literal,
TypeAlias,
TypeVar,
Union,
)
from pydantic import BaseModel, ConfigDict, Field
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = str | int
AnyFunction: TypeAlias = Callable[..., Any]
class RequestParams(BaseModel):
class Meta(BaseModel):
progressToken: ProgressToken | None = None
model_config = ConfigDict(extra="allow")
meta: Meta | None = Field(alias="_meta", default=None)
model_config = ConfigDict(extra="allow")
class NotificationParams(BaseModel):
class Meta(BaseModel):
model_config = ConfigDict(extra="allow")
meta: Meta | None = Field(alias="_meta", default=None)
model_config = ConfigDict(extra="allow")
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
NotificationParamsT = TypeVar(
"NotificationParamsT", bound=NotificationParams | dict[str, Any] | None
)
MethodT = TypeVar("MethodT", bound=str)
class Request(BaseModel, Generic[RequestParamsT, MethodT]):
method: MethodT
params: RequestParamsT
model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
model_config = ConfigDict(extra="allow")
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
method: MethodT
params: NotificationParamsT
model_config = ConfigDict(extra="allow")
class Result(BaseModel):
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
model_config = ConfigDict(extra="allow")
class PaginatedResult(Result):
nextCursor: Cursor | None = None
model_config = ConfigDict(extra="allow")
class JSONRPCMessage(BaseModel):
"""Base class for all JSON-RPC messages."""
model_config = ConfigDict(extra="allow")
jsonrpc: str = Field(default="2.0", frozen=True)
class JSONRPCRequest(JSONRPCMessage):
"""A JSON-RPC request message."""
id: str | int | None = None
method: str
params: dict[str, Any] | None = None
class JSONRPCResponse(JSONRPCMessage):
"""A JSON-RPC response message."""
id: str | int | None
result: Any = None
error: dict[str, Any] | None = None
def model_dump_json(self, **kwargs: Any) -> str:
"""Convert to JSON string with proper formatting."""
# Convert to dict
data = {
"jsonrpc": self.jsonrpc,
"id": self.id,
}
# Add result if present
if self.result is not None:
# Check if result is a Pydantic model
if hasattr(self.result, "model_dump"):
data["result"] = self.result.model_dump(exclude_none=True)
# Check if result is already a dict/list/primitive
elif (
isinstance(self.result, (dict, list, str, int, float, bool)) or self.result is None
):
data["result"] = self.result # type: ignore[assignment]
else:
# Try to convert using str() as a fallback
data["result"] = str(self.result)
# Add error if present
if self.error is not None:
data["error"] = self.error # type: ignore[assignment]
return json.dumps(data, ensure_ascii=False)
class JSONRPCError(JSONRPCMessage):
"""A JSON-RPC error message."""
id: str | int | None
error: dict[str, Any]
PARSE_ERROR = -32700
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601
INVALID_PARAMS = -32602
INTERNAL_ERROR = -32603
class ErrorData(BaseModel):
code: int
message: str
data: Any | None = None
model_config = ConfigDict(extra="allow")
JSONRPCMessageBaseModel = BaseModel | JSONRPCRequest | JSONRPCResponse | JSONRPCError
class EmptyResult(Result):
pass
class Implementation(BaseModel):
"""Describes the server or client implementation."""
name: str
version: str
model_config = ConfigDict(extra="allow")
class RootsCapability(BaseModel):
listChanged: bool | None = None
model_config = ConfigDict(extra="allow")
class SamplingCapability(BaseModel):
model_config = ConfigDict(extra="allow")
class ClientCapabilities(BaseModel):
experimental: dict[str, dict[str, Any]] | None = None
sampling: SamplingCapability | None = None
roots: RootsCapability | None = None
model_config = ConfigDict(extra="allow")
class PromptsCapability(BaseModel):
listChanged: bool | None = None
model_config = ConfigDict(extra="allow")
class ResourcesCapability(BaseModel):
subscribe: bool | None = None
listChanged: bool | None = None
model_config = ConfigDict(extra="allow")
class ToolsCapability(BaseModel):
listChanged: bool | None = None
model_config = ConfigDict(extra="allow")
class LoggingCapability(BaseModel):
model_config = ConfigDict(extra="allow")
class ServerCapabilities(BaseModel):
"""Describes the server's capabilities."""
model_config = ConfigDict(extra="allow")
tools: dict[str, Any] | None = None
resources: dict[str, Any] | None = None
prompts: dict[str, Any] | None = None
class InitializeRequestParams(RequestParams):
protocolVersion: str | int
capabilities: ClientCapabilities
clientInfo: Implementation
model_config = ConfigDict(extra="allow")
class InitializeRequest(JSONRPCRequest):
method: str = Field(default="initialize", frozen=True)
params: dict[str, Any] | None = None
class InitializeResult(BaseModel):
protocolVersion: str
capabilities: ServerCapabilities
serverInfo: Implementation
instructions: str | None = None
class InitializedNotification(
Notification[NotificationParams | None, Literal["notifications/initialized"]]
):
method: Literal["notifications/initialized"]
params: NotificationParams | None = None
model_config = ConfigDict(extra="allow")
class PingRequest(JSONRPCRequest):
method: str = Field(default="ping", frozen=True)
params: dict[str, Any] | None = None
class ProgressNotificationParams(NotificationParams):
progressToken: ProgressToken
progress: float
total: float | None = None
model_config = ConfigDict(extra="allow")
class ProgressNotification(JSONRPCMessage):
method: str = Field(default="progress", frozen=True)
params: dict[str, Any]
class PingResponse(JSONRPCResponse):
result: dict[str, Any] = Field(default_factory=lambda: {"pong": True})
class ShutdownRequest(JSONRPCRequest):
method: str = Field(default="shutdown", frozen=True)
params: dict[str, Any] | None = None
class ShutdownResponse(JSONRPCResponse):
result: dict[str, Any] = Field(default_factory=lambda: {"ok": True})
class CancelRequest(JSONRPCRequest):
method: str = Field(default="$/cancelRequest", frozen=True)
params: dict[str, Any]
class InitializeResponse(JSONRPCResponse):
"""
Response to an initialize request.
Note: This must be a properly formatted JSON-RPC response with a `result` field
containing the initialization data, not another request.
"""
result: InitializeResult
def model_dump_json(self, **kwargs: Any) -> str:
"""Convert to JSON string with proper formatting."""
# Convert to dict
data = {
"jsonrpc": self.jsonrpc,
"id": self.id,
"result": self.result.model_dump(exclude_none=True),
}
# Return JSON string
return json.dumps(data, ensure_ascii=False)
class ListToolsRequest(JSONRPCRequest):
method: str = Field(default="tools/list", frozen=True)
params: dict[str, Any] | None = None
class ToolAnnotations(BaseModel):
"""
Represents tool annotations for hints about behavior.
"""
title: str | None = None
readOnlyHint: bool | None = None
destructiveHint: bool | None = None
idempotentHint: bool | None = None
openWorldHint: bool | None = None
model_config = ConfigDict(extra="allow")
class Tool(BaseModel):
"""
Represents an MCP tool definition.
"""
name: str
description: str
inputSchema: dict[str, Any] | None = None
annotations: ToolAnnotations | None = None
model_config = ConfigDict(extra="allow")
class ListToolsResult(BaseModel):
tools: list[Tool]
class ListToolsResponse(JSONRPCResponse):
result: ListToolsResult
class CallToolRequest(JSONRPCRequest):
method: str = Field(default="tools/call", frozen=True)
params: dict[str, Any]
class CallToolResult(BaseModel):
content: Any
class CallToolResponse(JSONRPCResponse):
result: CallToolResult
# Resource and Prompt protocol stubs (expand as needed)
class ListResourcesRequest(JSONRPCRequest):
method: str = Field(default="resources/list", frozen=True)
params: dict[str, Any] | None = None
class ListResourcesResponse(JSONRPCResponse):
result: dict[str, Any]
class ListPromptsRequest(JSONRPCRequest):
method: str = Field(default="prompts/list", frozen=True)
params: dict[str, Any] | None = None
class ListPromptsResponse(JSONRPCResponse):
result: dict[str, Any]
# Utility type alias for all MCP protocol messages
MCPMessage = Union[
JSONRPCRequest,
JSONRPCResponse,
JSONRPCError,
PingRequest,
PingResponse,
InitializeRequest,
InitializeResponse,
ListToolsRequest,
ListToolsResponse,
CallToolRequest,
CallToolResponse,
ProgressNotification,
CancelRequest,
ShutdownRequest,
ShutdownResponse,
ListResourcesRequest,
ListResourcesResponse,
ListPromptsRequest,
ListPromptsResponse,
]

View file

@ -7,7 +7,7 @@ from pathlib import Path
import pytest
from arcade.worker.config.deployment import (
from arcade.cli.deployment import (
Config,
Deployment,
LocalPackages,

View file

@ -0,0 +1,44 @@
import json
from typing import Annotated
from arcade.core.catalog import ToolCatalog
from arcade.sdk import tool
from arcade.worker.mcp.convert import convert_to_mcp_content, create_mcp_tool
@tool
def sample_tool(x: Annotated[int, "first"], y: Annotated[int, "second"]) -> int:
"""Return x+y"""
return x + y
def test_convert_to_mcp_content_primitives():
assert convert_to_mcp_content(42) == [{"type": "text", "text": "42"}]
assert convert_to_mcp_content("hello") == [{"type": "text", "text": "hello"}]
assert convert_to_mcp_content(True) == [{"type": "text", "text": "True"}]
def test_convert_to_mcp_content_complex():
data = {"a": 1}
expected_json = json.dumps(data)
assert convert_to_mcp_content(data) == [{"type": "text", "text": expected_json}]
def test_create_mcp_tool():
# Materialize a tool via catalog then feed it to create_mcp_tool
catalog = ToolCatalog()
catalog.add_tool(sample_tool, "convert_toolkit")
mat_tool = next(iter(catalog)) # only tool
mcp_tool = create_mcp_tool(mat_tool)
assert mcp_tool is not None
assert mcp_tool["name"] == "ConvertToolkit_SampleTool"
assert mcp_tool["description"]
# Ensure input schema contains both parameters and marks them required
props = mcp_tool["inputSchema"]["properties"]
assert set(props.keys()) == {"x", "y"}
required_fields = set(mcp_tool["inputSchema"].get("required", []))
# Ensure no unexpected required fields and that declared ones are subset of expected
assert required_fields.issubset({"x", "y"})

View file

@ -0,0 +1,57 @@
import asyncio
import pytest
from arcade.worker.mcp.message_processor import MCPMessageProcessor, create_message_processor
from arcade.worker.mcp.types import InitializeRequest, PingRequest
@pytest.mark.asyncio
async def test_message_processor_parses_initialize_json():
"""Ensure JSON initialize strings are converted into InitializeRequest objects."""
json_init = '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}\n'
processor = MCPMessageProcessor()
result = await processor.process_request(json_init)
assert isinstance(result, InitializeRequest)
assert result.id == 1
assert result.method == "initialize"
@pytest.mark.asyncio
async def test_message_processor_passes_notifications_unchanged():
"""Unknown notifications should be passed through as parsed dictionaries without errors."""
json_notification = '{"jsonrpc":"2.0","id":null,"method":"notifications/custom","params":{}}\n'
processor = MCPMessageProcessor()
result = await processor.process_request(json_notification)
# The MCPMessageProcessor keeps unknown notifications as simple dicts
assert isinstance(result, dict)
assert result["method"] == "notifications/custom"
@pytest.mark.asyncio
async def test_message_processor_middleware_execution_order(monkeypatch):
"""Middleware (sync + async) should be executed in the order they were added."""
order: list[str] = []
def mw_sync(msg, direction): # type: ignore[return-value]
order.append("sync")
return msg
async def mw_async(msg, direction): # type: ignore[return-value]
await asyncio.sleep(0) # ensure it is truly async
order.append("async")
return msg
processor = create_message_processor(mw_sync, mw_async)
# Use a pre-parsed PingRequest instance so we don't test parsing again here
ping = PingRequest(id=42)
_ = await processor.process_request(ping)
assert order == ["sync", "async"]

View file

@ -0,0 +1,153 @@
import sys
import types
from typing import Annotated, Any
import pytest
from arcade.core.catalog import ToolCatalog
from arcade.sdk import tool
from arcade.worker.mcp import server as mcp_server
from arcade.worker.mcp.types import (
CallToolRequest,
CancelRequest,
InitializeRequest,
ListToolsRequest,
PingRequest,
)
# ---------------------------------------------------------------------------
# Test helpers / stubs
# ---------------------------------------------------------------------------
class _FakeAuth:
async def authorize(self, auth_requirement: Any, user_id: str):
"""Return an object that mimics AuthorizationResponse with completed status."""
class _Ctx: # minimal stub
token = "dummy-token" # noqa: S105
class _Resp: # pylint: disable=too-few-public-methods
status = "completed"
url = ""
context = _Ctx()
return _Resp()
class _FakeArcade: # pylint: disable=too-few-public-methods
def __init__(self, **_: Any):
self.auth = _FakeAuth()
# Ensure that the AsyncArcade & ArcadeError symbols inside server.py point to our stubs.
pytestmark = pytest.mark.asyncio
@pytest.fixture(autouse=True)
def _patch_arcadepy(monkeypatch):
"""Patch the external `arcadepy` dependency used by mcp.server."""
# Patch the imported symbols on the already-imported server module
monkeypatch.setattr(mcp_server, "AsyncArcade", _FakeArcade, raising=True)
monkeypatch.setattr(mcp_server, "ArcadeError", Exception, raising=True)
# Provide a dummy `arcadepy` module in sys.modules for any other importers
fake_arcadepy = types.ModuleType("arcadepy")
fake_arcadepy.AsyncArcade = _FakeArcade # type: ignore[attr-defined]
fake_arcadepy.ArcadeError = Exception # type: ignore[attr-defined]
sys.modules["arcadepy"] = fake_arcadepy
yield
# Cleanup
sys.modules.pop("arcadepy", None)
# ---------------------------------------------------------------------------
# Fixtures for a sample tool / catalog / server
# ---------------------------------------------------------------------------
@tool
def multiply(a: Annotated[int, "a"], b: Annotated[int, "b"]) -> Annotated[int, "result"]:
"""Return the product of *a* and *b*."""
return a * b
@pytest.fixture(scope="module")
def sample_catalog():
catalog = ToolCatalog()
catalog.add_tool(multiply, "test_toolkit")
return catalog
@pytest.fixture()
def server(sample_catalog):
# MCPServer constructor is synchronous, so fixture need not be async
return mcp_server.MCPServer(sample_catalog, enable_logging=False)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
async def test_handle_ping(server):
req = PingRequest(id=123)
resp = await server._handle_ping(req) # pylint: disable=protected-access
assert resp.id == 123
assert resp.result == {"pong": True}
async def test_handle_initialize(server):
req = InitializeRequest(id=1)
resp = await server._handle_initialize(req) # pylint: disable=protected-access
assert resp.id == 1
assert resp.result.protocolVersion == mcp_server.MCP_PROTOCOL_VERSION
assert resp.result.serverInfo.name.startswith("Arcade")
async def test_handle_list_tools(server):
req = ListToolsRequest(id=99)
resp = await server._handle_list_tools(req) # pylint: disable=protected-access
assert resp.id == 99
# Should list our sample tool only
tool_names = [t.name for t in resp.result.tools]
assert "TestToolkit_Multiply" in tool_names # toolkit + "_" + tool
async def test_handle_call_tool_success(server):
req = CallToolRequest(
id="call-1",
params={
"name": "TestToolkit_Multiply",
"input": {"a": 6, "b": 7},
},
)
resp = await server._handle_call_tool(req, user_id="tester@example.com") # pylint: disable=protected-access
assert resp.id == "call-1"
# convert_to_mcp_content wraps primitives in list-of-dicts
assert resp.result.content == [{"type": "text", "text": "42"}]
async def test_send_response_dict(server, monkeypatch):
"""_send_response should JSON-serialize plain dictionaries."""
sent: list[str] = []
class _Write:
async def send(self, msg):
sent.append(msg)
await server._send_response(_Write(), {"foo": "bar"}) # pylint: disable=protected-access
assert sent and sent[0].strip() == '{"foo": "bar"}'
async def test_handle_cancel(server):
req = CancelRequest(id=77, params={"id": "abc"})
resp = await server._handle_cancel(req) # pylint: disable=protected-access
assert resp.result == {"ok": True}

View file

@ -0,0 +1,32 @@
import io
import queue
from arcade.worker.mcp.stdio import stdio_reader, stdio_writer
def test_stdio_reader_puts_lines_and_none():
q: queue.Queue[str | None] = queue.Queue()
test_input = io.StringIO("line1\nline2\n")
stdio_reader(test_input, q)
# We should get the two lines followed by None sentinel
assert q.get_nowait() == "line1\n"
assert q.get_nowait() == "line2\n"
assert q.get_nowait() is None
def test_stdio_writer_reads_until_none():
q: queue.Queue[str | None] = queue.Queue()
output_stream = io.StringIO()
# preload queue with two messages and sentinel
q.put("msg1")
q.put("msg2\n")
q.put(None)
stdio_writer(output_stream, q)
# Ensure writer appended newlines when missing
output_stream.seek(0)
assert output_stream.read() == "msg1\nmsg2\n"

View file

@ -0,0 +1,237 @@
import os
from typing import Annotated
from unittest.mock import MagicMock
import pytest
from arcade.core.errors import ToolDefinitionError
from arcade.core.schema import (
ToolCallRequest,
ToolCallResponse,
ToolContext,
ToolReference,
)
from arcade.sdk import tool
from arcade.worker.core.base import BaseWorker
from arcade.worker.core.common import RequestData, Router
from arcade.worker.core.components import (
CallToolComponent,
CatalogComponent,
HealthCheckComponent,
)
@tool()
def sample_tool(
context: ToolContext, a: Annotated[int, "a"], b: Annotated[int, "b"]
) -> Annotated[int, "output"]:
"""Sample tool for testing."""
return a + b
# Define error tool at module level to avoid indentation issues with getsource
@tool()
def error_tool(context: ToolContext) -> int:
"""This tool always raises an error."""
raise ValueError("Something went wrong")
@pytest.fixture
def mock_router():
router = MagicMock(spec=Router)
router.add_route = MagicMock()
return router
@pytest.fixture
def base_worker(mock_router):
# Set env var temporarily for testing secret loading
os.environ["ARCADE_WORKER_SECRET"] = "test_secret_env" # noqa: S105
worker = BaseWorker()
worker.register_routes(mock_router) # Register routes using the mock router
# Clean up env var
del os.environ["ARCADE_WORKER_SECRET"]
return worker
@pytest.fixture
def base_worker_no_auth():
return BaseWorker(disable_auth=True)
# --- BaseWorker Tests ---
def test_base_worker_init_with_secret():
worker = BaseWorker(secret="explicit_secret") # noqa: S106
assert worker.secret == "explicit_secret" # noqa: S105
assert not worker.disable_auth
def test_base_worker_init_with_env_secret():
os.environ["ARCADE_WORKER_SECRET"] = "env_secret_value" # noqa: S105
worker = BaseWorker()
assert worker.secret == "env_secret_value" # noqa: S105
assert not worker.disable_auth
del os.environ["ARCADE_WORKER_SECRET"]
def test_base_worker_init_no_secret_raises_error():
# Ensure env var is not set
if "ARCADE_WORKER_SECRET" in os.environ:
del os.environ["ARCADE_WORKER_SECRET"]
with pytest.raises(ValueError, match="No secret provided for worker"):
BaseWorker()
def test_base_worker_init_disable_auth():
worker = BaseWorker(disable_auth=True)
assert worker.secret == ""
assert worker.disable_auth
def test_register_tool(base_worker_no_auth):
assert len(base_worker_no_auth.catalog) == 0
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
assert len(base_worker_no_auth.catalog) == 1
tool_def = base_worker_no_auth.get_catalog()[0]
assert tool_def.name == "SampleTool"
assert tool_def.toolkit.name == "TestKit"
def test_get_catalog(base_worker_no_auth):
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
catalog = base_worker_no_auth.get_catalog()
assert isinstance(catalog, list)
assert len(catalog) == 1
assert catalog[0].name == "SampleTool"
def test_health_check(base_worker_no_auth):
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
health = base_worker_no_auth.health_check()
assert health == {"status": "ok", "tool_count": "1"}
@pytest.mark.asyncio
async def test_call_tool_success(base_worker_no_auth):
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
# Create ToolReference WITHOUT version, as register_tool doesn't seem to set it
tool_ref = ToolReference(toolkit="TestKit", name="SampleTool")
tool_request = ToolCallRequest(
execution_id="test_exec_id",
tool=tool_ref,
inputs={"a": 5, "b": 3},
)
response = await base_worker_no_auth.call_tool(tool_request)
assert response.success is True
assert response.output.value == 8
assert response.output.error is None
assert response.execution_id == "test_exec_id"
assert response.duration > 0
@pytest.mark.asyncio
async def test_call_tool_execution_error(base_worker_no_auth):
# Tool is now defined at module level
try:
base_worker_no_auth.register_tool(error_tool, toolkit_name="error_kit")
except ToolDefinitionError as e:
pytest.fail(f"Failed to register error_tool: {e}")
# Create ToolReference WITHOUT version
tool_ref = ToolReference(toolkit="ErrorKit", name="ErrorTool")
tool_request = ToolCallRequest(
execution_id="test_exec_error",
tool=tool_ref,
inputs={},
)
response = await base_worker_no_auth.call_tool(tool_request)
assert response.success is False
assert response.output.value is None
assert response.output.error is not None
@pytest.mark.asyncio
async def test_call_tool_not_found(base_worker_no_auth):
# Use ToolReference without version for lookup consistency
tool_ref = ToolReference(toolkit="nonexistent", name="nosuchtool")
tool_request = ToolCallRequest(
execution_id="test_exec_notfound",
tool=tool_ref,
inputs={},
)
# Update regex to match actual error format
with pytest.raises(ValueError):
await base_worker_no_auth.call_tool(tool_request)
# --- Component Tests (tested via BaseWorker registration) ---
def test_register_routes_registers_default_components(base_worker, mock_router):
# BaseWorker calls register_routes in its init via the fixture
assert mock_router.add_route.call_count == len(BaseWorker.default_components)
calls = mock_router.add_route.call_args_list
expected_paths = ["tools", "tools/invoke", "health"]
registered_paths = [
call[0][0] for call in calls
] # call[0] are positional args, call[0][0] is endpoint_path
assert sorted(registered_paths) == sorted(expected_paths)
# Check if components were instantiated and passed to add_route
assert any(isinstance(call[0][1], CatalogComponent) for call in calls)
assert any(isinstance(call[0][1], CallToolComponent) for call in calls)
assert any(isinstance(call[0][1], HealthCheckComponent) for call in calls)
@pytest.mark.asyncio
async def test_catalog_component_call(base_worker_no_auth):
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
component = CatalogComponent(base_worker_no_auth)
# Mock request data - not actually used by this component's __call__
mock_request = MagicMock(spec=RequestData)
catalog_response = await component(mock_request)
assert isinstance(catalog_response, list)
assert len(catalog_response) == 1
assert catalog_response[0].name == "SampleTool"
@pytest.mark.asyncio
async def test_call_tool_component_call(base_worker_no_auth):
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
component = CallToolComponent(base_worker_no_auth)
# Create ToolReference WITHOUT version
tool_ref = ToolReference(toolkit="TestKit", name="SampleTool")
request_body = {
"execution_id": "comp_test_exec",
"tool": tool_ref.model_dump(),
"inputs": {"a": 10, "b": 5},
}
mock_request = MagicMock(spec=RequestData)
mock_request.body_json = request_body
response = await component(mock_request)
assert isinstance(response, ToolCallResponse)
assert response.success is True
assert response.output.value == 15
assert response.execution_id == "comp_test_exec"
@pytest.mark.asyncio
async def test_health_check_component_call(base_worker_no_auth):
component = HealthCheckComponent(base_worker_no_auth)
mock_request = MagicMock(spec=RequestData)
health_response = await component(mock_request)
assert health_response == {"status": "ok", "tool_count": "0"}

View file

@ -0,0 +1,167 @@
from typing import Annotated
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from arcade.core.schema import ToolCallRequest, ToolContext, ToolReference
from arcade.sdk import tool
from arcade.worker.fastapi.worker import FastAPIWorker
@tool()
def sample_tool_fastapi(
context: ToolContext, x: Annotated[int, "x"], y: Annotated[str, "y"]
) -> Annotated[str, "output"]:
"""A sample tool for FastAPI tests."""
return f"{y}-{x}"
# Define tool at module level to avoid indentation issues with getsource
@tool()
def error_throwing_tool(
context: ToolContext,
a: Annotated[int, "a", "Input integer a"], # Added description for parameter
) -> int:
"""This tool throws a ValueError.""" # Added description for tool
raise ValueError("Test execution error")
@pytest.fixture
def test_app():
return FastAPI()
@pytest.fixture
def worker_secret():
return "test-secret-fastapi"
@pytest.fixture
def fastapi_worker(test_app, worker_secret):
worker = FastAPIWorker(app=test_app, secret=worker_secret)
worker.register_tool(sample_tool_fastapi, toolkit_name="fastapi_kit")
return worker
@pytest.fixture
def fastapi_worker_no_auth(test_app):
worker = FastAPIWorker(app=test_app, disable_auth=True)
worker.register_tool(sample_tool_fastapi, toolkit_name="fastapi_kit")
return worker
@pytest.fixture
def client(test_app, fastapi_worker): # Use the worker fixture to ensure routes are registered
return TestClient(test_app)
@pytest.fixture
def client_no_auth(test_app, fastapi_worker_no_auth):
return TestClient(test_app)
# --- FastAPIWorker Tests ---
def test_fastapi_worker_registers_routes(client, fastapi_worker):
# Check if routes exist by trying to access them (even if auth fails)
response = client.get(f"{fastapi_worker.base_path}/health")
assert response.status_code != 404 # Should be 200
response = client.get(f"{fastapi_worker.base_path}/tools")
assert response.status_code != 404 # Should be 403 without auth
# Prepare a dummy request body for invoke
tool_ref = ToolReference(toolkit="FastapiKit", name="SampleToolFastapi")
request_body = ToolCallRequest(
execution_id="test", tool=tool_ref, inputs={"x": 1, "y": "test"}
).model_dump()
response = client.post(f"{fastapi_worker.base_path}/tools/invoke", json=request_body)
assert response.status_code != 404 # Should be 403 without auth
# --- Route Tests (using TestClient) ---
# Health Check
def test_health_check_route(client, worker_secret):
response = client.get("/worker/health")
assert response.status_code == 200
assert response.json() == {"status": "ok", "tool_count": "1"}
def test_health_check_route_no_auth(client_no_auth):
response = client_no_auth.get("/worker/health")
assert response.status_code == 200
assert response.json() == {"status": "ok", "tool_count": "1"}
# Catalog
def test_get_catalog_route_no_auth_header(client):
response = client.get("/worker/tools")
assert response.status_code == 403
assert "Not authenticated" in response.text
def test_get_catalog_route_invalid_auth_header(client, worker_secret):
response = client.get("/worker/tools", headers={"Authorization": "Bearer invalid-token"})
assert response.status_code == 401 # Unauthorized
# Updated expected error message based on last run
assert "Invalid token. Error: Not enough segments" in response.text
def test_get_catalog_route_no_auth_worker(client_no_auth):
response = client_no_auth.get("/worker/tools")
assert response.status_code == 200
catalog = response.json()
assert isinstance(catalog, list)
assert len(catalog) == 1
assert catalog[0]["name"] == "SampleToolFastapi"
# Call Tool
@pytest.fixture
def call_tool_payload():
tool_ref = ToolReference(toolkit="FastapiKit", name="SampleToolFastapi")
return ToolCallRequest(
execution_id="fastapi-test-exec", tool=tool_ref, inputs={"x": 123, "y": "hello"}
).model_dump()
def test_call_tool_route_no_auth_header(client, call_tool_payload):
response = client.post("/worker/tools/invoke", json=call_tool_payload)
assert response.status_code == 403
def test_call_tool_route_invalid_auth_header(client, worker_secret, call_tool_payload):
response = client.post(
"/worker/tools/invoke",
json=call_tool_payload,
headers={"Authorization": "Bearer invalid-token"},
)
assert response.status_code == 401
def test_call_tool_route_no_auth_worker(client_no_auth, call_tool_payload):
response = client_no_auth.post("/worker/tools/invoke", json=call_tool_payload)
assert response.status_code == 200
result = response.json()
assert result["success"] is True
assert result["output"]["value"] == "hello-123"
def test_call_tool_route_tool_not_found(client_no_auth, call_tool_payload):
call_tool_payload["tool"]["name"] = "NonExistentTool"
call_tool_payload["tool"]["toolkit"] = "FastapiKit"
with pytest.raises(ValueError):
_ = client_no_auth.post(
"/worker/tools/invoke",
json=call_tool_payload,
)
# The handler catches the ValueError and returns a 500 internal server error
# Ideally, this might be a 404 or 400, but BaseWorker.call_tool raises ValueError
# which isn't automatically mapped to a 4xx by FastAPI unless handled explicitly.
# TODO fix this.

8
examples/mcp/claude.json Normal file
View file

@ -0,0 +1,8 @@
{
"mcpServers": {
"arcade": {
"command": "bash",
"args": ["-c", "export ARCADE_API_KEY=arc_xxxx && /path/to/python /path/to/arcade mcp"]
}
}
}

25
examples/mcp/run_stdio.py Normal file
View file

@ -0,0 +1,25 @@
import arcade_google # pip install arcade_google
import arcade_search # pip install arcade_search
from arcade.core.catalog import ToolCatalog
from arcade.worker.mcp.stdio import StdioServer
# 2. Create and populate the tool catalog
catalog = ToolCatalog()
catalog.add_module(arcade_google) # Registers all tools in the package
catalog.add_module(arcade_search)
# 3. Main entrypoint
async def main():
# Create the worker with the tool catalog
worker = StdioServer(catalog)
# Run the worker
await worker.run()
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -1,76 +0,0 @@
"""
Example script demonstrating how to build a simple chatbot with Arcade.
For this example, we are using the prebuilt Google Docs toolkit to create and edit documents.
Try asking questions like:
- "Create a document with the title 'My New Document' and content 'Hello, World!'"
- "List my 2 most recently modified documents and tell me the title, document id, and document URL of each one and summarize them."
- "Edit the second document from the list you just returned and add the text 'Hello, World!' to the end of it."
"""
import os
from openai import OpenAI
def chat(openai_client: OpenAI, tool_names: list[str], user_id: str) -> None:
history = []
print("Hello! How can I help you today?")
while True:
message = {"role": "user", "content": input(">")}
history.append(message)
chat_result = call_tool_with_openai(openai_client, tool_names, user_id, history)
# If the tool call requires authorization, then wait for the user to authorize and then call the tool again
if (
chat_result.choices[0].tool_authorizations
and chat_result.choices[0].tool_authorizations[0].get("status") == "pending"
):
print("\n" + chat_result.choices[0].message.content)
input("\nAfter you have authorized, press Enter to continue...")
chat_result = call_tool_with_openai(openai_client, tool_names, user_id, history)
history.append({"role": "assistant", "content": chat_result.choices[0].message.content})
print(chat_result.choices[0].message.content)
def call_tool_with_openai(
client: OpenAI, tool_names: list[str], user_id: str, messages: list[dict]
) -> dict:
response = client.chat.completions.create(
messages=messages,
model="gpt-4o-mini",
user=user_id,
tools=tool_names,
tool_choice="generate",
)
return response
if __name__ == "__main__":
arcade_api_key = os.environ.get(
"ARCADE_API_KEY"
) # If you forget your Arcade API key, it is stored at ~/.arcade/credentials.yaml on `arcade login`
cloud_host = "https://api.arcade.dev/v1"
user_id = "user@example.com"
openai_client = OpenAI(
api_key=arcade_api_key,
base_url=cloud_host,
)
tool_names = [
"Google.SendEmail",
"Google.SendDraftEmail",
"Google.WriteDraftEmail",
"Google.UpdateDraftEmail",
"Google.ListDraftEmails",
"Google.ListEmailsByHeader",
"Google.ListEmails",
]
chat(openai_client, tool_names, user_id)