diff --git a/.vscode/settings.json b/.vscode/settings.json
index 791aba76..4e27f70a 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -10,7 +10,6 @@
"[python]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
- "source.fixAll": "explicit",
"source.organizeImports": "explicit"
},
"editor.defaultFormatter": "charliermarsh.ruff"
diff --git a/arcade/arcade/worker/config/deployment.py b/arcade/arcade/cli/deployment.py
similarity index 100%
rename from arcade/arcade/worker/config/deployment.py
rename to arcade/arcade/cli/deployment.py
diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py
index 3375385b..65a8906a 100644
--- a/arcade/arcade/cli/main.py
+++ b/arcade/arcade/cli/main.py
@@ -24,6 +24,7 @@ from arcade.cli.constants import (
PROD_CLOUD_HOST,
PROD_ENGINE_HOST,
)
+from arcade.cli.deployment import Deployment
from arcade.cli.display import (
display_arcade_chat_header,
display_eval_results,
@@ -48,7 +49,6 @@ from arcade.cli.utils import (
version_callback,
)
from arcade.cli.worker import parse_deployment_response
-from arcade.worker.config.deployment import Deployment
cli = typer.Typer(
cls=OrderCommands,
@@ -61,7 +61,12 @@ cli = typer.Typer(
)
-cli.add_typer(worker.app, name="worker", help="Manage workers")
+cli.add_typer(
+ worker.app,
+ name="worker",
+ help="Manage deployments of tool servers (logs, list, etc)",
+ rich_help_panel="Deployment",
+)
console = Console()
@@ -192,7 +197,10 @@ def show(
show_logic(toolkit, tool, host, local, port, force_tls, force_no_tls, debug)
-@cli.command(help="Start Arcade Chat in the terminal", rich_help_panel="Launch")
+@cli.command(
+ help="Start a chat with a model in the terminal to test tools",
+ rich_help_panel="Tool Development",
+)
def chat(
model: str = typer.Option("gpt-4o", "-m", "--model", help="The model to use for prediction."),
stream: bool = typer.Option(
@@ -439,7 +447,7 @@ def evals(
@cli.command(
- help="Start Arcade Worker serving tools installed in the current python environment",
+ help="Start tool server worker with locally installed tools",
rich_help_panel="Launch",
)
def serve(
@@ -460,19 +468,38 @@ def serve(
otel_enable: bool = typer.Option(
False, "--otel-enable", help="Send logs to OpenTelemetry", show_default=True
),
+ mcp: bool = typer.Option(
+ False, "--mcp", help="Run as a local MCP server over stdio", show_default=True
+ ),
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
) -> None:
"""
Start a local Arcade Worker server.
"""
- workerup(host, port, disable_auth, otel_enable, debug)
+ from arcade.cli.serve import serve_default_worker
+
+ try:
+ serve_default_worker(
+ host,
+ port,
+ disable_auth=disable_auth,
+ enable_otel=otel_enable,
+ debug=debug,
+ mcp=mcp,
+ )
+ except KeyboardInterrupt:
+ typer.Exit()
+ except Exception as e:
+ error_message = f"❌ Failed to start Arcade Worker: {escape(str(e))}"
+ console.print(error_message, style="bold red")
+ typer.Exit(code=1)
-@cli.command(help="Launch Arcade Worker and Engine locally", rich_help_panel="Launch")
+@cli.command(help="Launch Arcade - requires 'arcade-engine'", rich_help_panel="Launch")
def dev(
- host: str = typer.Option("127.0.0.1", help="Host for the worker server.", show_default=True),
+ host: str = typer.Option("127.0.0.1", help="Host for the toolkit server.", show_default=True),
port: int = typer.Option(
- 8002, "-p", "--port", help="Port for the worker server.", show_default=True
+ 8002, "-p", "--port", help="Port for the toolkit server.", show_default=True
),
engine_config: str = typer.Option(
None, "-c", "--config", help="Path to the engine configuration file."
@@ -483,7 +510,7 @@ def dev(
debug: bool = typer.Option(False, "-d", "--debug", help="Show debug information"),
) -> None:
"""
- Start both the worker and engine servers.
+ Start both the toolkit server and engine servers.
"""
try:
start_servers(host, port, engine_config, engine_env=env_file, debug=debug)
@@ -493,8 +520,9 @@ def dev(
typer.Exit(code=1)
-# TODO: deprecate this next major version
-@cli.command(help="Start a local Arcade Worker server", rich_help_panel="Launch", hidden=True)
+@cli.command(
+ help="Start a server with locally installed Arcade tools", rich_help_panel="Launch", hidden=True
+)
def workerup(
host: str = typer.Option(
"127.0.0.1",
@@ -523,17 +551,21 @@ def workerup(
try:
serve_default_worker(
- host, port, disable_auth=disable_auth, enable_otel=otel_enable, debug=debug
+ host,
+ port,
+ disable_auth=disable_auth,
+ enable_otel=otel_enable,
+ debug=debug,
)
except KeyboardInterrupt:
typer.Exit()
except Exception as e:
- error_message = f"❌ Failed to start Arcade Worker: {escape(str(e))}"
+ error_message = f"❌ Failed to start Arcade Toolkit Server: {escape(str(e))}"
console.print(error_message, style="bold red")
typer.Exit(code=1)
-@cli.command(help="Deploy worker to Arcade Cloud", rich_help_panel="Deployment")
+@cli.command(help="Deploy toolkits to Arcade Cloud", rich_help_panel="Deployment")
def deploy(
deployment_file: str = typer.Option(
"worker.toml", "--deployment-file", "-d", help="The deployment file to deploy."
diff --git a/arcade/arcade/cli/serve.py b/arcade/arcade/cli/serve.py
index 8100a4a9..4c378b2b 100644
--- a/arcade/arcade/cli/serve.py
+++ b/arcade/arcade/cli/serve.py
@@ -1,70 +1,209 @@
import asyncio
import logging
import os
+import signal
import sys
+from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
+from importlib.metadata import version as get_pkg_version
+from pathlib import Path
from typing import Any
import fastapi
import uvicorn
from loguru import logger
+from rich.console import Console
+from arcade.cli.constants import ARCADE_CONFIG_PATH
+from arcade.cli.utils import (
+ build_tool_catalog,
+ discover_toolkits,
+ load_dotenv,
+ validate_and_get_config,
+)
from arcade.core.telemetry import OTELHandler
from arcade.sdk import Toolkit
from arcade.worker.fastapi.worker import FastAPIWorker
+console = Console(width=70, color_system="auto")
-class InterceptHandler(logging.Handler):
+
+def _run_mcp_stdio(
+ toolkits: list[Toolkit], *, logging_enabled: bool, env_file: str | None = None
+) -> None:
+ """Launch an MCP stdio server; blocks until it exits."""
+
+ from arcade.worker.mcp.stdio import StdioServer
+
+ # Load env vars before launching server (explicit path, config path, cwd)
+ if env_file:
+ load_dotenv(env_file, override=False)
+ else:
+ for candidate in [Path(ARCADE_CONFIG_PATH) / "arcade.env", Path.cwd() / "arcade.env"]:
+ if candidate.is_file():
+ load_dotenv(candidate, override=False)
+ break
+
+ # Set up middleware configuration for stdio mode
+ middleware_config = {
+ "stdio_mode": True, # Ensure logs go to stderr
+ }
+
+ catalog = build_tool_catalog(toolkits)
+ server = StdioServer(
+ catalog,
+ enable_logging=logging_enabled,
+ middleware_config=middleware_config,
+ )
+
+ try:
+ asyncio.run(server.run())
+ except KeyboardInterrupt:
+ logger.info("MCP server stopped by user.")
+ except Exception as exc:
+ logger.exception("Error while running MCP server: %s", exc)
+ raise
+
+
+def _run_fastapi_server(
+ app: fastapi.FastAPI,
+ *,
+ host: str,
+ port: int,
+ workers: int,
+ timeout_keep_alive: int,
+ enable_otel: bool,
+ otel_handler: OTELHandler,
+ **uvicorn_kwargs: Any,
+) -> None:
+ """Run a FastAPI application via Uvicorn with graceful shutdown."""
+
+ class CustomUvicornServer(uvicorn.Server):
+ def install_signal_handlers(self) -> None:
+ # Disable Uvicorn's default signal handling; we manage it manually
+ pass
+
+ async def shutdown(self, sockets: Any = None) -> None:
+ logger.info("Initiating graceful shutdown...")
+ await super().shutdown(sockets=sockets)
+
+ config = uvicorn.Config(
+ app=app,
+ host=host,
+ port=port,
+ workers=workers,
+ timeout_keep_alive=timeout_keep_alive,
+ log_config=None,
+ **uvicorn_kwargs,
+ )
+
+ server = CustomUvicornServer(config=config)
+
+ async def _serve() -> None:
+ await server.serve()
+
+ async def _graceful_shutdown() -> None:
+ try:
+ logger.info("Shutting down server ...")
+ await server.shutdown()
+
+ # brief pause for connections to close gracefully
+ await asyncio.sleep(0.5)
+ finally:
+ if enable_otel:
+ otel_handler.shutdown()
+ logger.debug("Server shutdown complete.")
+
+ # Map signals to our graceful shutdown
+ loop = asyncio.get_event_loop()
+ for sig_name in (
+ "SIGINT",
+ "SIGTERM",
+ "SIGHUP",
+ "SIGUSR1",
+ "SIGUSR2",
+ "SIGWINCH",
+ "SIGBREAK",
+ ):
+ if hasattr(signal, sig_name):
+ loop.add_signal_handler(
+ getattr(signal, sig_name), lambda: asyncio.create_task(_graceful_shutdown())
+ )
+
+ try:
+ asyncio.run(_serve())
+ except KeyboardInterrupt:
+ logger.info("Server stopped by user.")
+ finally:
+ if enable_otel:
+ otel_handler.shutdown()
+
+
+class RichInterceptHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
- # Get corresponding Loguru level if it exists
try:
level = logger.level(record.levelname).name
except ValueError:
- level = record.levelno # type: ignore[assignment]
+ level = str(record.levelno)
- # Find caller from where originated the logged message
- frame, depth = sys._getframe(6), 6
- while frame and frame.f_code.co_filename == logging.__file__:
- frame = frame.f_back # type: ignore[assignment]
- depth += 1
-
- logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
+ # Let Loguru handle caller info; don't do stack inspection here
+ logger.opt(exception=record.exc_info).log(level, record.getMessage())
-def setup_logging(log_level: int = logging.INFO) -> None:
+def setup_logging(log_level: int = logging.INFO, mcp_mode: bool = False) -> None:
# Intercept everything at the root logger
- logging.root.handlers = [InterceptHandler()]
+ logging.root.handlers = [RichInterceptHandler()]
logging.root.setLevel(log_level)
- # Remove every other logger's handlers
- # and propagate to root logger
+ # Remove every other logger's handlers and propagate to root logger
for name in logging.root.manager.loggerDict:
+ # Keep handlers for MCP logger if middleware handles it separately
+ if mcp_mode and name == "arcade.mcp":
+ continue
logging.getLogger(name).handlers = []
logging.getLogger(name).propagate = True
- # Configure loguru with custom format, no colors
+ # Remove default handlers from Loguru
+ logger.remove()
+
+ # Configure main Loguru sink
+ # In MCP mode, all general console logs go to stderr to keep stdout clean
+ sink_destination = sys.stderr if mcp_mode else sys.stdout
+
+ # Configure loguru with a cleaner format and colors
+ if log_level == logging.DEBUG:
+ format_string = "{level} | {time:HH:mm:ss} | {name}:{file}:{line: <4} | {message}"
+ else:
+ format_string = (
+ "{level} | {time:HH:mm:ss} | {message}"
+ )
logger.configure(
handlers=[
{
- "sink": sys.stdout,
- "serialize": False,
+ "sink": sink_destination, # Redirect sink based on mcp_mode
+ "colorize": True,
"level": log_level,
- "format": "{level} [{time:HH:mm:ss.SSS}] {message}"
- + (" {name}:{function}:{line}" if log_level <= logging.DEBUG else "")
- + ("\n{exception}" if "{exception}" in "{message}" else ""),
+ # Format that ensures timestamp on every line and better alignment
+ "format": format_string,
+ # Make sure multiline messages are handled properly
+ "enqueue": True,
+ "diagnose": True, # Disable traceback framing which adds noise
}
]
)
+ if mcp_mode:
+ logger.debug("Loguru sink configured for stderr in MCP mode.")
@asynccontextmanager
-async def lifespan(app: fastapi.FastAPI): # type: ignore[no-untyped-def]
+async def lifespan(app: fastapi.FastAPI) -> AsyncGenerator[None, None]:
try:
yield
- except asyncio.CancelledError:
+ except (asyncio.CancelledError, KeyboardInterrupt):
# This is necessary to prevent an unhandled error
# when the user presses Ctrl+C
logger.debug("Lifespan cancelled.")
+ raise
def serve_default_worker(
@@ -75,76 +214,78 @@ def serve_default_worker(
timeout_keep_alive: int = 5,
enable_otel: bool = False,
debug: bool = False,
+ mcp: bool = False,
**kwargs: Any,
) -> None:
"""
- Get an instance of a FastAPI server with the Arcade Worker.
+ Get a default instance of a FastAPI server with the Arcade Worker
+ serving tools installed in the current Python environment.
+
+ Args:
+ host: The host to run the server on.
+ port: The port to run the server on.
+ disable_auth: Whether to disable authentication.
+ workers: The number of workers to run.
+ timeout_keep_alive: The timeout for keep-alive connections.
+ enable_otel: Whether to enable OpenTelemetry.
+ debug: Whether to enable debug logging.
+ mcp: Whether to run worker as MCP server over stdio.
"""
- # Setup unified logging
- setup_logging(log_level=logging.DEBUG if debug else logging.INFO)
- toolkits = Toolkit.find_all_arcade_toolkits()
- if not toolkits:
- raise RuntimeError("No toolkits found in Python environment.")
+ # Setup unified logging first
+ version = get_pkg_version("arcade-ai")
+ validate_and_get_config()
+ setup_logging(log_level=logging.DEBUG if debug else logging.INFO, mcp_mode=mcp)
- worker_secret = os.environ.get("ARCADE_WORKER_SECRET")
- if not disable_auth and not worker_secret:
- logger.warning(
- "Warning: ARCADE_WORKER_SECRET environment variable is not set. Using 'dev' as the worker secret.",
- )
- worker_secret = worker_secret or "dev"
+ toolkits = discover_toolkits()
+ logger.info("Serving the following toolkits:")
+ for toolkit in toolkits:
+ if debug:
+ for name, tools in toolkit.tools.items():
+ for tool in tools:
+ logger.info(f" - {name}: {tool}")
+ else:
+ logger.info(f" - {toolkit.name}: {len(toolkit.tools)} tools")
+ # --- MCP stdio --------------------------------------------------
+ if mcp:
+ env_file = kwargs.pop("env_file", None)
+ _run_mcp_stdio(toolkits, logging_enabled=not debug, env_file=env_file)
+ return
+
+ # --- FastAPI HTTP --------------------------------------------------
app = fastapi.FastAPI(
title="Arcade Worker",
- description="Arcade default Worker implementation using FastAPI.",
- version="0.1.0",
- lifespan=lifespan, # Use custom lifespan to catch errors, notably KeyboardInterrupt (Ctrl+C)
+ description="A worker for the Arcade platform",
+ version=version,
+ docs_url="/docs" if debug else None,
+ redoc_url="/redoc" if debug else None,
+ openapi_url="/openapi.json" if debug else None,
+ lifespan=lifespan,
)
- otel_handler = OTELHandler(app, enable=enable_otel)
+ secret = os.getenv("ARCADE_WORKER_SECRET", None)
+ if secret is None:
+ logger.warning("No secret found for Arcade Worker")
+ logger.info("Setting secret to 'dev'. Set this in production")
+ secret = "dev" # noqa: S105
- worker = FastAPIWorker(
- app, secret=worker_secret, disable_auth=disable_auth, otel_meter=otel_handler.get_meter()
+ otel_handler = OTELHandler(
+ app, enable=enable_otel, log_level=logging.DEBUG if debug else logging.INFO
)
-
- toolkit_tool_counts = {}
- for toolkit in toolkits:
- prev_tool_count = worker.catalog.get_tool_count()
- worker.register_toolkit(toolkit)
- new_tool_count = worker.catalog.get_tool_count()
- toolkit_tool_counts[f"{toolkit.name} ({toolkit.package_name})"] = (
- new_tool_count - prev_tool_count
- )
-
- logger.info("Serving the following toolkits:")
- for name, tool_count in toolkit_tool_counts.items():
- logger.info(f" - {name}: {tool_count} tools")
-
- logger.info("Starting FastAPI server...")
-
- class CustomUvicornServer(uvicorn.Server):
- def install_signal_handlers(self) -> None:
- pass # Disable Uvicorn's default signal handlers
-
- config = uvicorn.Config(
+ _ = FastAPIWorker(
app=app,
+ secret=secret,
+ disable_auth=disable_auth,
+ otel_meter=otel_handler.get_meter(),
+ )
+ _run_fastapi_server(
+ app,
host=host,
port=port,
workers=workers,
timeout_keep_alive=timeout_keep_alive,
- log_config=None,
+ enable_otel=enable_otel,
+ otel_handler=otel_handler,
**kwargs,
)
- server = CustomUvicornServer(config=config)
-
- async def serve() -> None:
- await server.serve()
-
- try:
- asyncio.run(serve())
- except KeyboardInterrupt:
- logger.info("Server stopped by user.")
- finally:
- if enable_otel:
- otel_handler.shutdown()
- logger.debug("Server shutdown complete.")
diff --git a/arcade/arcade/cli/utils.py b/arcade/arcade/cli/utils.py
index 524d220a..44dbdb41 100644
--- a/arcade/arcade/cli/utils.py
+++ b/arcade/arcade/cli/utils.py
@@ -1,5 +1,7 @@
import importlib.util
import ipaddress
+import os
+import shlex
import webbrowser
from dataclasses import dataclass
from datetime import datetime
@@ -34,6 +36,11 @@ from arcade.sdk import ToolCatalog, Toolkit
console = Console()
+# -----------------------------------------------------------------------------
+# Shared helpers for the CLI
+# -----------------------------------------------------------------------------
+
+
class OrderCommands(TyperGroup):
def list_commands(self, ctx: Context) -> list[str]: # type: ignore[override]
"""Return list of commands in the order appear."""
@@ -666,3 +673,81 @@ def get_today_context() -> str:
today = datetime.now().strftime("%Y-%m-%d")
day_of_week = datetime.now().strftime("%A")
return f"Today is {today}, {day_of_week}."
+
+
+def discover_toolkits() -> list[Toolkit]:
+ """Return all Arcade toolkits installed in the active Python environment.
+
+ Raises:
+ RuntimeError: If no toolkits are found, mirroring the behaviour of Toolkit discovery elsewhere.
+ """
+ toolkits = Toolkit.find_all_arcade_toolkits()
+ if not toolkits:
+ raise RuntimeError("No toolkits found in Python environment.")
+ return toolkits
+
+
+def build_tool_catalog(toolkits: list[Toolkit]) -> ToolCatalog:
+ """Construct a ``ToolCatalog`` populated with *toolkits*.
+
+
+ Args:
+ toolkits: Toolkits to register in the catalog.
+
+ Returns:
+ ToolCatalog
+ """
+ catalog = ToolCatalog()
+ for tk in toolkits:
+ catalog.add_toolkit(tk)
+ return catalog
+
+
+def _parse_line(line: str) -> tuple[str, str] | None:
+ """
+ Return (key, value) if the line looks like KEY=VALUE, else None.
+ Handles quotes and escaped chars via shlex.
+ """
+ if not line or line.startswith("#") or "=" not in line:
+ return None
+ key, raw_val = line.split("=", 1)
+ key = key.strip()
+ raw_val = raw_val.strip()
+
+ # Use shlex to handle "quoted strings with # hash" etc.
+ try:
+ value = shlex.split(raw_val)[0] if raw_val else ""
+ except ValueError:
+ # Fallback: naked value without shlex parsing
+ value = raw_val
+
+ return key, value
+
+
+def load_dotenv(path: str | Path, *, override: bool = False) -> dict[str, str]:
+ """
+ Load variables from *path* into os.environ.
+
+ Args:
+ path: .env file path
+ override: replace existing env vars if True
+
+ Returns:
+ The mapping of vars that were added/updated.
+ """
+ path = Path(path).expanduser()
+ if not path.is_file():
+ return {}
+
+ loaded: dict[str, str] = {}
+
+ for raw in path.read_text().splitlines():
+ parsed = _parse_line(raw.strip())
+ if parsed is None:
+ continue
+ k, v = parsed
+ if override or k not in os.environ:
+ os.environ[k] = v
+ loaded[k] = v
+
+ return loaded
diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py
index af21b3ba..5a27ee0e 100644
--- a/arcade/arcade/core/schema.py
+++ b/arcade/arcade/core/schema.py
@@ -326,8 +326,16 @@ class ToolContext(BaseModel):
for item in items:
if item.key.lower() == normalized_key:
return item.value
+
raise ValueError(f"{item_name.capitalize()} {key} not found in context.")
+ def set_secret(self, key: str, value: str) -> None:
+ """Set a secret for the tool invocation."""
+ if not self.secrets:
+ self.secrets = []
+ secret = ToolSecretItem(key=str(key), value=str(value))
+ self.secrets.append(secret)
+
class ToolCallRequest(BaseModel):
"""The request to call (invoke) a tool."""
diff --git a/arcade/arcade/core/utils.py b/arcade/arcade/core/utils.py
index a199f93d..d9c58696 100644
--- a/arcade/arcade/core/utils.py
+++ b/arcade/arcade/core/utils.py
@@ -1,14 +1,16 @@
+from __future__ import annotations
+
import ast
import inspect
import re
from collections.abc import Iterable
from types import UnionType
-from typing import Any, Callable, Literal, Optional, TypeVar, Union, get_args, get_origin
+from typing import Any, Callable, Literal, TypeVar, Union, get_args, get_origin
T = TypeVar("T")
-def first_or_none(_type: type[T], iterable: Iterable[Any]) -> Optional[T]:
+def first_or_none(_type: type[T], iterable: Iterable[Any]) -> T | None:
"""
Returns the first item in the iterable that is an instance of the given type, or None if no such item is found.
"""
@@ -65,7 +67,7 @@ def does_function_return_value(func: Callable) -> bool:
Returns True if the given function returns a value, i.e. if it has a return statement with a value.
"""
try:
- source: Optional[str] = inspect.getsource(func)
+ source: str | None = inspect.getsource(func)
except OSError:
# Workaround for parameterized unit tests that use a dynamically-generated function
source = getattr(func, "__source__", None)
diff --git a/arcade/arcade/worker/core/base.py b/arcade/arcade/worker/core/base.py
index 5c307350..4416c24b 100644
--- a/arcade/arcade/worker/core/base.py
+++ b/arcade/arcade/worker/core/base.py
@@ -175,7 +175,7 @@ class BaseWorker(Worker):
"""
Provide a health check that serves as a heartbeat of worker health.
"""
- return {"status": "ok", "tool_count": len(self.catalog)}
+ return {"status": "ok", "tool_count": str(len(self.catalog))}
def register_routes(self, router: Router) -> None:
"""
diff --git a/arcade/arcade/worker/core/common.py b/arcade/arcade/worker/core/common.py
index 328b1d70..85a73b48 100644
--- a/arcade/arcade/worker/core/common.py
+++ b/arcade/arcade/worker/core/common.py
@@ -5,6 +5,11 @@ from pydantic import BaseModel
from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition
+CatalogResponse = list[ToolDefinition]
+HealthCheckResponse = dict[str, str]
+JSONResponse = dict[str, Any]
+ResponseData = CatalogResponse | ToolCallResponse | HealthCheckResponse
+
class RequestData(BaseModel):
"""
@@ -17,7 +22,7 @@ class RequestData(BaseModel):
"""The path of the request."""
method: str
"""The method of the request."""
- body_json: dict | None = None
+ body_json: JSONResponse | None = None
"""The deserialized body of the request (e.g. JSON)"""
@@ -28,7 +33,13 @@ class Router(ABC):
@abstractmethod
def add_route(
- self, endpoint_path: str, handler: Callable, method: str, require_auth: bool = True
+ self,
+ endpoint_path: str,
+ handler: Callable,
+ method: str,
+ require_auth: bool = True,
+ response_type: type[ResponseData] | None = None,
+ **kwargs: Any,
) -> None:
"""
Add a route to the router.
@@ -43,7 +54,7 @@ class Worker(ABC):
"""
@abstractmethod
- def get_catalog(self) -> list[ToolDefinition]:
+ def get_catalog(self) -> CatalogResponse:
"""
Get the catalog of tools available in the worker.
"""
@@ -57,7 +68,7 @@ class Worker(ABC):
pass
@abstractmethod
- def health_check(self) -> dict[str, Any]:
+ def health_check(self) -> HealthCheckResponse:
"""
Perform a health check of the worker
"""
@@ -76,7 +87,7 @@ class WorkerComponent(ABC):
pass
@abstractmethod
- async def __call__(self, request: RequestData) -> Any:
+ async def __call__(self, request: RequestData) -> ResponseData:
"""
Handle the request.
"""
diff --git a/arcade/arcade/worker/core/components.py b/arcade/arcade/worker/core/components.py
index a9590f9e..490954a6 100644
--- a/arcade/arcade/worker/core/components.py
+++ b/arcade/arcade/worker/core/components.py
@@ -1,9 +1,15 @@
-from typing import Any
-
from opentelemetry import trace
-from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition
-from arcade.worker.core.common import RequestData, Router, Worker, WorkerComponent
+from arcade.worker.core.common import (
+ CatalogResponse,
+ HealthCheckResponse,
+ RequestData,
+ Router,
+ ToolCallRequest,
+ ToolCallResponse,
+ Worker,
+ WorkerComponent,
+)
class CatalogComponent(WorkerComponent):
@@ -14,9 +20,18 @@ class CatalogComponent(WorkerComponent):
"""
Register the catalog route with the router.
"""
- router.add_route("tools", self, method="GET")
+ router.add_route(
+ "tools",
+ self,
+ method="GET",
+ response_type=CatalogResponse,
+ operation_id="get_catalog",
+ description="Get the catalog of tools",
+ summary="Get the catalog of tools",
+ tags=["Arcade"],
+ )
- async def __call__(self, request: RequestData) -> list[ToolDefinition]:
+ async def __call__(self, request: RequestData) -> CatalogResponse:
"""
Handle the request to get the catalog.
"""
@@ -33,7 +48,16 @@ class CallToolComponent(WorkerComponent):
"""
Register the call tool route with the router.
"""
- router.add_route("tools/invoke", self, method="POST")
+ router.add_route(
+ "tools/invoke",
+ self,
+ method="POST",
+ response_type=ToolCallResponse,
+ operation_id="call_tool",
+ description="Call a tool",
+ summary="Call a tool",
+ tags=["Arcade"],
+ )
async def __call__(self, request: RequestData) -> ToolCallResponse:
"""
@@ -54,9 +78,19 @@ class HealthCheckComponent(WorkerComponent):
"""
Register the health check route with the router.
"""
- router.add_route("health", self, method="GET", require_auth=False)
+ router.add_route(
+ "health",
+ self,
+ method="GET",
+ require_auth=False,
+ response_type=HealthCheckResponse,
+ operation_id="health_check",
+ description="Check the health of the worker",
+ summary="Check the health of the worker",
+ tags=["Arcade"],
+ )
- async def __call__(self, request: RequestData) -> dict[str, Any]:
+ async def __call__(self, request: RequestData) -> HealthCheckResponse:
"""
Handle the request for a health check.
"""
diff --git a/arcade/arcade/worker/fastapi/__init__.py b/arcade/arcade/worker/fastapi/__init__.py
index e69de29b..f3c85996 100644
--- a/arcade/arcade/worker/fastapi/__init__.py
+++ b/arcade/arcade/worker/fastapi/__init__.py
@@ -0,0 +1,3 @@
+from .worker import FastAPIWorker
+
+__all__ = ["FastAPIWorker"]
diff --git a/arcade/arcade/worker/fastapi/worker.py b/arcade/arcade/worker/fastapi/worker.py
index 17e92cf3..c455941e 100644
--- a/arcade/arcade/worker/fastapi/worker.py
+++ b/arcade/arcade/worker/fastapi/worker.py
@@ -9,7 +9,7 @@ from arcade.worker.core.base import (
BaseWorker,
Router,
)
-from arcade.worker.core.common import RequestData
+from arcade.worker.core.common import RequestData, ResponseData, WorkerComponent
from arcade.worker.fastapi.auth import validate_engine_request
from arcade.worker.utils import is_async_callable
@@ -30,12 +30,21 @@ class FastAPIWorker(BaseWorker):
"""
Initialize the FastAPIWorker with a FastAPI app instance.
If no secret is provided, the worker will use the ARCADE_WORKER_SECRET environment variable.
+
+ Args:
+ app: The FastAPI app to host the worker in
+ secret: Optional secret for authorization
+ disable_auth: Whether to disable authorization
+ otel_meter: Optional OpenTelemetry meter
"""
super().__init__(secret, disable_auth, otel_meter)
self.app = app
self.router = FastAPIRouter(app, self)
self.register_routes(self.router)
+ # Initialize components
+ self.components: list[WorkerComponent] = []
+
security = HTTPBearer() # Authorization: Bearer
@@ -81,7 +90,13 @@ class FastAPIRouter(Router):
return wrapped_handler
def add_route(
- self, endpoint_path: str, handler: Callable, method: str, require_auth: bool = True
+ self,
+ endpoint_path: str,
+ handler: Callable,
+ method: str,
+ require_auth: bool = True,
+ response_type: type[ResponseData] | None = None,
+ **kwargs: Any,
) -> None:
"""
Add a route to the FastAPI application.
@@ -90,4 +105,7 @@ class FastAPIRouter(Router):
f"{self.worker.base_path}/{endpoint_path}",
self._wrap_handler(handler, require_auth),
methods=[method],
+ response_model=response_type,
+ # **kwargs to pass to FastAPI
+ **kwargs,
)
diff --git a/arcade/arcade/worker/mcp/__init__.py b/arcade/arcade/worker/mcp/__init__.py
new file mode 100644
index 00000000..9761f4f2
--- /dev/null
+++ b/arcade/arcade/worker/mcp/__init__.py
@@ -0,0 +1,7 @@
+"""
+MCP (Model Context Protocol) support for Arcade workers.
+"""
+
+from arcade.worker.mcp.stdio import StdioServer
+
+__all__ = ["StdioServer"]
diff --git a/arcade/arcade/worker/mcp/convert.py b/arcade/arcade/worker/mcp/convert.py
new file mode 100644
index 00000000..59e10642
--- /dev/null
+++ b/arcade/arcade/worker/mcp/convert.py
@@ -0,0 +1,188 @@
+import json
+import logging
+from enum import Enum
+from typing import Any
+
+from arcade.core.catalog import MaterializedTool
+
+# Type aliases for MCP types
+MCPTool = dict[str, Any]
+MCPTextContent = dict[str, Any]
+MCPImageContent = dict[str, Any]
+MCPEmbeddedResource = dict[str, Any]
+MCPContent = MCPTextContent | MCPImageContent | MCPEmbeddedResource
+
+logger = logging.getLogger("arcade.mcp")
+
+
+def create_mcp_tool(tool: MaterializedTool) -> dict[str, Any] | None: # noqa: C901
+ """
+ Create an MCP-compatible tool definition from an Arcade tool.
+
+ Args:
+ tool: An Arcade tool object
+
+ Returns:
+ An MCP tool definition or None if the tool cannot be converted
+ """
+ try:
+ name = getattr(tool.definition, "fully_qualified_name", None) or getattr(
+ tool.definition, "name", "unknown"
+ )
+ description = getattr(tool.definition, "description", "No description available")
+
+ # Extract parameters from the input model
+ parameters = {}
+ required = []
+
+ if (
+ hasattr(tool, "input_model")
+ and tool.input_model is not None
+ and hasattr(tool.input_model, "model_fields")
+ ):
+ for field_name, field in tool.input_model.model_fields.items():
+ # Skip internal tool context parameters
+ if field_name == getattr(
+ tool.definition.input, "tool_context_parameter_name", None
+ ):
+ continue
+
+ # Get field type information
+ field_type = getattr(field, "annotation", None)
+ field_type_name = "string" # default
+
+ # Safety check for field_type
+ if field_type is int:
+ field_type_name = "integer"
+ elif field_type is float:
+ field_type_name = "number"
+ elif field_type is bool:
+ field_type_name = "boolean"
+ elif field_type is list or str(field_type).startswith("list["):
+ field_type_name = "array"
+ elif field_type is dict or str(field_type).startswith("dict["):
+ field_type_name = "object"
+
+ # Get description with fallback
+ field_description = getattr(field, "description", None)
+ if not field_description:
+ field_description = f"Parameter: {field_name}"
+
+ # Create parameter definition
+ param_def = {
+ "type": field_type_name,
+ "description": field_description,
+ }
+
+ # Enum support: if the field annotation is an Enum, add allowed values
+ enum_type = None
+ if hasattr(field, "annotation"):
+ ann = field.annotation
+ # Handle typing.Annotated[Enum, ...]
+ if getattr(ann, "__origin__", None) is not None and hasattr(ann, "__args__"):
+ for arg in ann.__args__: # type: ignore[union-attr]
+ if isinstance(arg, type) and issubclass(arg, Enum):
+ enum_type = arg
+ break
+ elif isinstance(ann, type) and issubclass(ann, Enum):
+ enum_type = ann
+ if enum_type is not None:
+ param_def["enum"] = [e.value for e in enum_type]
+
+ parameters[field_name] = param_def
+
+ # In Pydantic v2, check if field is required based on default value
+ try:
+ if field.is_required():
+ required.append(field_name)
+ except (AttributeError, TypeError):
+ # Fallback if is_required() doesn't exist or fails
+ try:
+ has_default = getattr(field, "default", None) is not None
+ has_factory = getattr(field, "default_factory", None) is not None
+ if not (has_default or has_factory):
+ required.append(field_name)
+ except Exception:
+ # Ultimate fallback - assume required if we can't determine
+ logger.debug(
+ f"Could not determine if field {field_name} is required, assuming optional"
+ )
+
+ # Create the input schema with explicit properties and required fields
+ input_schema = {
+ "type": "object",
+ "properties": parameters,
+ }
+
+ # Only include required field if we have required parameters
+ if required:
+ input_schema["required"] = required
+
+ # Add annotations based on tool metadata
+ annotations = {}
+
+ # Use tool name as title if available
+ annotations["title"] = getattr(tool.definition, "title", str(name).replace(".", "_"))
+
+ # Determine hints based on tool properties
+ if hasattr(tool.definition, "metadata"):
+ metadata = tool.definition.metadata or {}
+ annotations["readOnlyHint"] = metadata.get("read_only", False)
+ annotations["destructiveHint"] = metadata.get("destructive", False)
+ annotations["idempotentHint"] = metadata.get("idempotent", True)
+ annotations["openWorldHint"] = metadata.get("open_world", False)
+
+ # Create the final tool definition
+ tool_def: MCPTool = {
+ "name": str(name).replace(".", "_"),
+ "description": str(description),
+ "inputSchema": input_schema,
+ "annotations": annotations,
+ }
+
+ logger.debug(f"Created tool definition for {name}")
+
+ except Exception:
+ logger.exception(
+ f"Error creating MCP tool definition for {getattr(tool, 'name', str(tool))}"
+ )
+ return None
+ return tool_def
+
+
+def convert_to_mcp_content(value: Any) -> list[dict[str, Any]]:
+ """
+ Convert a Python value to MCP-compatible content.
+ """
+ if value is None:
+ return []
+
+ if isinstance(value, (str, bool, int, float)):
+ return [{"type": "text", "text": str(value)}]
+
+ if isinstance(value, (dict, list)):
+ return [{"type": "text", "text": json.dumps(value)}]
+
+ # Default fallback
+ return [{"type": "text", "text": str(value)}]
+
+
+def _map_type_to_json_schema_type(val_type: str) -> str:
+ """
+ Map Arcade value types to JSON schema types.
+
+ Args:
+ val_type: The Arcade value type as a string.
+
+ Returns:
+ The corresponding JSON schema type as a string.
+ """
+ mapping: dict[str, str] = {
+ "string": "string",
+ "integer": "integer",
+ "number": "number",
+ "boolean": "boolean",
+ "json": "object",
+ "array": "array",
+ }
+ return mapping.get(val_type, "string")
diff --git a/arcade/arcade/worker/mcp/logging.py b/arcade/arcade/worker/mcp/logging.py
new file mode 100644
index 00000000..0299d5e1
--- /dev/null
+++ b/arcade/arcade/worker/mcp/logging.py
@@ -0,0 +1,215 @@
+import json
+import logging
+import sys
+import time
+from typing import Any
+
+from arcade.worker.mcp.types import (
+ JSONRPCError,
+ JSONRPCRequest,
+ JSONRPCResponse,
+ MCPMessage,
+)
+
+logger = logging.getLogger("arcade.mcp")
+
+
+class MCPLoggingMiddleware:
+ """
+ Middleware for logging MCP requests and responses.
+ Logs request and response details, including timing and errors.
+ """
+
+ def __init__(
+ self,
+ log_level: str = "INFO",
+ log_request_body: bool = False,
+ log_response_body: bool = False,
+ log_errors: bool = True,
+ min_duration_to_log_ms: int = 0,
+ stdio_mode: bool = False,
+ ) -> None:
+ """
+ Initialize the MCP logging middleware.
+
+ Args:
+ log_level: Logging level (default: "INFO").
+ log_request_body: Whether to log full request bodies (default: False).
+ log_response_body: Whether to log full response bodies (default: False).
+ log_errors: Whether to log errors at ERROR level (default: True).
+ min_duration_to_log_ms: Minimum duration in ms to log (0 logs all).
+ stdio_mode: Whether running in stdio mode (redirects logs to stderr).
+ """
+ self.log_level = getattr(logging, log_level.upper())
+ self.log_request_body = log_request_body
+ self.log_response_body = log_response_body
+ self.log_errors = log_errors
+ self.min_duration_to_log_ms = min_duration_to_log_ms
+ self.request_log_format = "[MCP>] {method}{params_str} (id: {id})"
+ self.response_log_format = "[MCP<] {method} completed in {duration:.2f}ms (id: {id})"
+ self.error_log_format = "[MCP!] {method} error: {error} (id: {id})"
+
+ # If in stdio mode, ensure MCP logs go to stderr
+ if stdio_mode:
+ self._redirect_logs_to_stderr()
+
+ # Log that middleware is initialized
+ logger.debug(f"MCP logging middleware initialized (level: {log_level})")
+
+ def _redirect_logs_to_stderr(self) -> None:
+ """Redirect MCP logs to stderr to avoid interfering with stdio communication."""
+ # Remove any existing handlers
+ for handler in logger.handlers[:]:
+ logger.removeHandler(handler)
+
+ # Add a stderr handler
+ stderr_handler = logging.StreamHandler(sys.stderr)
+ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ stderr_handler.setFormatter(formatter)
+ logger.addHandler(stderr_handler)
+
+ # Ensure we're not propagating to root logger which might log to stdout
+ logger.propagate = False
+
+ logger.debug("MCP logs redirected to stderr for stdio mode")
+
+ def __call__(self, message: MCPMessage, direction: str) -> MCPMessage:
+ """
+ Process and log an MCP message.
+
+ Args:
+ message: The MCP message to process.
+ direction: The message direction ("request" or "response").
+
+ Returns:
+ The original message (unmodified).
+ """
+ if direction == "request":
+ self._log_request(message)
+ else:
+ self._log_response(message)
+ return message
+
+ def _log_request(self, message: MCPMessage) -> None:
+ """
+ Log an MCP request message.
+ """
+ if not isinstance(message, JSONRPCRequest):
+ logger.debug(f"Ignoring non-request message: {type(message).__name__}")
+ return
+
+ try:
+ # Store request start time for duration calculation
+ message._mcp_start_time = time.time() # type: ignore[attr-defined]
+
+ # Format parameters for logging
+ params_str = ""
+ if self.log_request_body and hasattr(message, "params") and message.params is not None:
+ params_str = f": {self._format_params(message.params)}"
+
+ log_msg = self.request_log_format.format(
+ method=message.method, params_str=params_str, id=getattr(message, "id", "none")
+ )
+
+ logger.log(self.log_level, log_msg)
+ except Exception:
+ logger.exception("Error logging request")
+
+ def _log_response(self, message: MCPMessage) -> None:
+ """
+ Log an MCP response message.
+ """
+ if not isinstance(message, (JSONRPCResponse, JSONRPCError)):
+ logger.debug(f"Ignoring non-response message: {type(message).__name__}")
+ return
+
+ try:
+ # Calculate request duration if we have the start time
+ duration_ms = 0
+ request = getattr(message, "_request", None)
+ if request:
+ start_time = getattr(request, "_mcp_start_time", None)
+ if start_time:
+ duration_ms = (time.time() - start_time) * 1000
+ else:
+ start_time = getattr(message, "_mcp_start_time", None)
+ if start_time:
+ duration_ms = (time.time() - start_time) * 1000
+
+ # Skip if below minimum duration threshold
+ if self.min_duration_to_log_ms > 0 and duration_ms < self.min_duration_to_log_ms:
+ return
+
+ # Handle error responses
+ if hasattr(message, "error") and message.error is not None:
+ if self.log_errors:
+ error_msg = self.error_log_format.format(
+ method=getattr(message, "method", "unknown"),
+ error=getattr(message.error, "message", str(message.error)),
+ id=getattr(message, "id", "none"),
+ )
+ logger.error(error_msg)
+ return
+
+ # Log successful response
+ result_str = ""
+ if self.log_response_body and hasattr(message, "result"):
+ result_str = f": {self._format_result(message.result)}"
+
+ log_msg = self.response_log_format.format(
+ method=getattr(message, "method", "unknown"),
+ duration=duration_ms,
+ id=getattr(message, "id", "none"),
+ result_str=result_str,
+ )
+
+ logger.log(self.log_level, log_msg)
+ except Exception:
+ logger.exception("Error logging response")
+
+ def _format_params(self, params: Any) -> str:
+ """
+ Format parameters for logging.
+ """
+ try:
+ if isinstance(params, dict):
+ # Handle common MCP params specially
+ if "name" in params and "arguments" in params:
+ return f"{params['name']}({json.dumps(params.get('arguments', {}))})"
+ return json.dumps(params)
+ return str(params)
+ except Exception:
+ logger.debug(f"Error formatting params {params!s}")
+ return str(params)
+
+ def _format_result(self, result: Any) -> str:
+ """
+ Format result for logging.
+ """
+ try:
+ if isinstance(result, dict):
+ return json.dumps(result)
+ return str(result)
+ except Exception as e:
+ logger.debug(f"Error formatting result {e!s}")
+ return str(result)
+
+
+def create_mcp_logging_middleware(**config: Any) -> MCPLoggingMiddleware:
+ """
+ Create an MCP logging middleware with the given configuration.
+
+ Args:
+ **config: Configuration options.
+
+ Returns:
+ An MCPLoggingMiddleware instance.
+ """
+ return MCPLoggingMiddleware(
+ log_level=config.get("log_level", "INFO"),
+ log_request_body=config.get("log_request_body", False),
+ log_response_body=config.get("log_response_body", False),
+ log_errors=config.get("log_errors", True),
+ min_duration_to_log_ms=config.get("min_duration_to_log_ms", 0),
+ stdio_mode=config.get("stdio_mode", False),
+ )
diff --git a/arcade/arcade/worker/mcp/message_processor.py b/arcade/arcade/worker/mcp/message_processor.py
new file mode 100644
index 00000000..a2199bee
--- /dev/null
+++ b/arcade/arcade/worker/mcp/message_processor.py
@@ -0,0 +1,83 @@
+import inspect
+import json
+import logging
+from typing import Any, Callable, TypeVar
+
+from arcade.worker.mcp.types import InitializeRequest, JSONRPCRequest, MCPMessage
+
+logger = logging.getLogger("arcade.mcp")
+
+T = TypeVar("T")
+
+# Type definition for middleware functions
+MessageProcessor = Callable[[Any, str], Any]
+
+
+class MCPMessageProcessor:
+ """
+ Processes MCP messages through a chain of middleware.
+ Supports both synchronous and asynchronous middleware.
+ """
+
+ def __init__(self) -> None:
+ self.middleware: list[Callable[[MCPMessage, str], Any]] = []
+
+ def add_middleware(self, mw: Callable[[MCPMessage, str], Any]) -> None:
+ self.middleware.append(mw)
+
+ async def process(self, message: Any, direction: str) -> Any: # noqa: C901
+ # First, try to parse the message if it's a string
+ if isinstance(message, str):
+ # Strip any whitespace including newlines
+ message = message.strip()
+ if not message:
+ return None
+
+ try:
+ parsed = json.loads(message)
+ if isinstance(parsed, dict):
+ method = parsed.get("method")
+ # Convert to appropriate message type
+ if method == "initialize" and "id" in parsed:
+ logger.debug(f"Parsed initialize request: {parsed}")
+ message = InitializeRequest(**parsed)
+ elif method and method.startswith("notifications/"):
+ # It's a notification, log it but pass through as dict
+ logger.debug(f"Received notification: {method}")
+ # Keep as parsed dict to avoid validation errors on unknown notifications
+ message = parsed
+ elif "method" in parsed and "id" in parsed:
+ # Regular method request
+ logger.debug(f"Parsed method request: {method}")
+ message = JSONRPCRequest(**parsed)
+ # Other message types can be handled similarly
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse message as JSON: {message[:100]}...")
+ except Exception:
+ logger.exception("Error processing message")
+
+ # Process through middleware chain
+ result = message
+ for mw in self.middleware:
+ try:
+ if inspect.iscoroutinefunction(mw):
+ result = await mw(result, direction)
+ else:
+ result = mw(result, direction)
+ except Exception:
+ logger.exception(f"Error in middleware {mw}")
+ return result
+
+ async def process_request(self, message: Any) -> Any:
+ return await self.process(message, "request")
+
+ async def process_response(self, message: Any) -> Any:
+ return await self.process(message, "response")
+
+
+def create_message_processor(*middleware: MessageProcessor) -> MCPMessageProcessor:
+ processor = MCPMessageProcessor()
+ for m in middleware:
+ if m is not None:
+ processor.add_middleware(m)
+ return processor
diff --git a/arcade/arcade/worker/mcp/server.py b/arcade/arcade/worker/mcp/server.py
new file mode 100644
index 00000000..ae255b14
--- /dev/null
+++ b/arcade/arcade/worker/mcp/server.py
@@ -0,0 +1,601 @@
+import asyncio
+import logging
+import os
+import uuid
+from enum import Enum
+from typing import Any, Callable, Union
+
+from arcadepy import ArcadeError, AsyncArcade
+from arcadepy.types.auth_authorize_params import AuthRequirement, AuthRequirementOauth2
+from arcadepy.types.shared import AuthorizationResponse
+
+from arcade.core.catalog import MaterializedTool, ToolCatalog
+from arcade.core.executor import ToolExecutor
+from arcade.core.schema import ToolAuthorizationContext, ToolContext
+from arcade.worker.mcp.convert import convert_to_mcp_content, create_mcp_tool
+from arcade.worker.mcp.logging import create_mcp_logging_middleware
+from arcade.worker.mcp.message_processor import MCPMessageProcessor, create_message_processor
+from arcade.worker.mcp.types import (
+ CallToolRequest,
+ CallToolResponse,
+ CallToolResult,
+ CancelRequest,
+ Implementation,
+ InitializeRequest,
+ InitializeResponse,
+ InitializeResult,
+ JSONRPCError,
+ JSONRPCResponse,
+ ListPromptsRequest,
+ ListPromptsResponse,
+ ListResourcesRequest,
+ ListResourcesResponse,
+ ListToolsRequest,
+ ListToolsResponse,
+ ListToolsResult,
+ PingRequest,
+ PingResponse,
+ ProgressNotification,
+ ServerCapabilities,
+ ShutdownRequest,
+ ShutdownResponse,
+ Tool,
+)
+
+logger = logging.getLogger("arcade.mcp")
+
+MCP_PROTOCOL_VERSION = "2024-11-05"
+
+
+class MessageMethod(str, Enum):
+ """Enumeration of supported MCP message methods"""
+
+ PING = "ping"
+ INITIALIZE = "initialize"
+ LIST_TOOLS = "tools/list"
+ CALL_TOOL = "tools/call"
+ PROGRESS = "progress"
+ CANCEL = "$/cancelRequest"
+ SHUTDOWN = "shutdown"
+ LIST_RESOURCES = "resources/list"
+ LIST_PROMPTS = "prompts/list"
+
+
+class MCPServer:
+ """
+ Unified async MCP server that manages connections, middleware, and tool invocation.
+ Handles protocol-level messages (ping, initialize, list_tools, call_tool, etc.).
+ """
+
+ def __init__(
+ self,
+ tool_catalog: Any,
+ enable_logging: bool = True,
+ **client_kwargs: dict[str, Any],
+ ) -> None:
+ """
+ Initialize the MCP server.
+
+ Args:
+ tool_catalog: Catalog of available tools
+ **client_kwargs: Additional arguments to pass to the AsyncArcade client
+ """
+ self.tool_catalog: ToolCatalog = tool_catalog
+ self.message_processor: MCPMessageProcessor = create_message_processor()
+
+ # Pop middleware_config from client_kwargs regardless of logging state,
+ # as it's internal config not meant for AsyncArcade.
+ middleware_config = client_kwargs.pop("middleware_config", {})
+
+ if enable_logging:
+ # Create and add the logging middleware if logging is enabled.
+ # Note: enable_logging must be True for this middleware (and its stdio_mode behavior)
+ # to be activated.
+ self.message_processor.add_middleware(
+ create_mcp_logging_middleware(**middleware_config)
+ )
+
+ self._shutdown: bool = False
+ # Initialize AsyncArcade with the *remaining* client_kwargs
+ self.arcade = AsyncArcade(**client_kwargs) # type: ignore[arg-type]
+
+ # Initialize handler dispatch table
+ self._method_handlers: dict[str, Callable] = {
+ MessageMethod.PING: self._handle_ping,
+ MessageMethod.INITIALIZE: self._handle_initialize,
+ MessageMethod.LIST_TOOLS: self._handle_list_tools,
+ MessageMethod.CALL_TOOL: self._handle_call_tool,
+ MessageMethod.PROGRESS: self._handle_progress,
+ MessageMethod.CANCEL: self._handle_cancel,
+ MessageMethod.SHUTDOWN: self._handle_shutdown,
+ MessageMethod.LIST_RESOURCES: self._handle_list_resources,
+ MessageMethod.LIST_PROMPTS: self._handle_list_prompts,
+ }
+
+ async def run_connection(
+ self,
+ read_stream: Any,
+ write_stream: Any,
+ init_options: Any,
+ ) -> None:
+ """
+ Handle a single MCP connection (SSE or stdio).
+
+ Args:
+ read_stream: Async iterable yielding incoming messages.
+ write_stream: Object with an async send(message) method.
+ init_options: Initialization options for the connection.
+ """
+ # Generate a user ID if possible
+ user_id = self._get_user_id(init_options)
+
+ try:
+ logger.info(f"Starting MCP connection for user {user_id}")
+
+ async for message in read_stream:
+ # Process the message
+ response = await self.handle_message(message, user_id=user_id)
+
+ # Skip sending responses for None (e.g., notifications)
+ if response is None:
+ continue
+
+ await self._send_response(write_stream, response)
+
+ except asyncio.CancelledError:
+ logger.info("Connection cancelled")
+ except Exception:
+ logger.exception("Error in connection")
+
+ def _get_user_id(self, init_options: Any) -> str:
+ """
+ Get the user ID for a connection.
+
+ Args:
+ init_options: Initialization options for the connection
+
+ Returns:
+ A user ID string
+ """
+ try:
+ from arcade.core.config import config
+
+ # Prefer config.user.email if available
+ if config.user and config.user.email:
+ return config.user.email
+ except ValueError:
+ logger.debug("No logged in user for MCP Server")
+
+ fallback = str(uuid.uuid4())
+ if os.environ.get("ARCADE_USER_ID", None):
+ return os.environ.get("ARCADE_USER_ID", fallback)
+ elif isinstance(init_options, dict):
+ user_id = init_options.get("user_id")
+ if user_id:
+ return str(user_id)
+ # Fallback to random UUID
+ return str(fallback)
+
+ async def _send_response(self, write_stream: Any, response: Any) -> None:
+ """
+ Send a response to the client.
+
+ Args:
+ write_stream: Stream to write the response to
+ response: Response object to send
+ """
+ # Ensure the response is properly serialized to JSON
+ if hasattr(response, "model_dump_json"):
+ # It's a Pydantic model, serialize it
+ json_response = response.model_dump_json()
+ # Ensure it ends with a newline for JSON-RPC-over-stdio
+ if not json_response.endswith("\n"):
+ json_response += "\n"
+ logger.debug(f"Sending response: {json_response[:200]}...")
+ await write_stream.send(json_response)
+ elif isinstance(response, dict):
+ # It's a dict, convert to JSON
+ import json
+
+ json_response = json.dumps(response)
+ # Ensure it ends with a newline for JSON-RPC-over-stdio
+ if not json_response.endswith("\n"):
+ json_response += "\n"
+ logger.debug(f"Sending response: {json_response[:200]}...")
+ await write_stream.send(json_response)
+ else:
+ # It's already a string or something else
+ response_str = str(response)
+ # Ensure it ends with a newline for JSON-RPC-over-stdio
+ if not response_str.endswith("\n"):
+ response_str += "\n"
+ logger.debug(f"Sending raw response type: {type(response)}")
+ await write_stream.send(response_str)
+
+ async def handle_message(self, message: Any, user_id: str | None = None) -> Any:
+ """
+ Handle an incoming MCP message. Processes it through middleware and dispatches
+ to the appropriate handler based on the message method.
+
+ Args:
+ message: The raw incoming message
+ user_id: Optional user ID for authentication
+
+ Returns:
+ A properly formatted response message
+ """
+ # Pre-process message through middleware
+ processed = await self.message_processor.process_request(message)
+
+ # Handle special case for JSON string initialize requests
+ if isinstance(processed, str):
+ try:
+ import json
+
+ parsed = json.loads(processed)
+ if (
+ isinstance(parsed, dict)
+ and parsed.get("method") == MessageMethod.INITIALIZE
+ and "id" in parsed
+ ):
+ # This is an initialize request
+ init_response = await self._handle_initialize(InitializeRequest(**parsed))
+ return init_response
+ except Exception:
+ logger.exception("Error processing JSON string")
+ # Not parseable JSON, continue with normal processing
+ pass
+
+ # Check if it's a notification
+ if hasattr(processed, "method"):
+ method = getattr(processed, "method", None)
+
+ # Handle notifications (methods starting with "notifications/")
+ if method and method.startswith("notifications/"):
+ await self._handle_notification(method, processed)
+ return None
+
+ # Handle regular methods using the dispatch table
+ if method in self._method_handlers:
+ # If it's a call_tool request, we need to pass the user_id
+ if method == MessageMethod.CALL_TOOL:
+ return await self._method_handlers[method](processed, user_id=user_id)
+ # For other methods, just pass the processed message
+ return await self._method_handlers[method](processed)
+
+ # Unknown method
+ return JSONRPCError(
+ id=getattr(processed, "id", None),
+ error={
+ "code": -32601,
+ "message": f"Method not found: {method}",
+ },
+ )
+
+ # If it's not a method request, just pass it through
+ return processed
+
+ async def _handle_notification(self, method: str, message: Any) -> None:
+ """
+ Handle notification messages.
+
+ Args:
+ method: The notification method
+ message: The notification message
+ """
+ if method == "notifications/cancelled":
+ logger.info(f"Request cancelled: {getattr(message, 'params', {})}")
+ else:
+ logger.debug(f"Received notification: {method}")
+
+ async def _handle_ping(self, message: PingRequest) -> PingResponse:
+ """
+ Handle a ping request and return a pong response.
+
+ Args:
+ message: The ping request
+
+ Returns:
+ A properly formatted pong response
+ """
+ return PingResponse(id=message.id)
+
+ async def _handle_initialize(self, message: InitializeRequest) -> InitializeResponse:
+ """
+ Handle an initialize request and return a proper initialize response.
+
+ Args:
+ message: The initialize request
+
+ Returns:
+ A properly formatted initialize response
+ """
+ # Create the result data
+ result = InitializeResult(
+ protocolVersion=MCP_PROTOCOL_VERSION,
+ capabilities=ServerCapabilities(),
+ serverInfo=Implementation(name="Arcade MCP Worker", version="0.1.0"),
+ instructions="Arcade MCP Worker initialized.",
+ )
+
+ # Construct proper response with result field
+ response = InitializeResponse(id=message.id, result=result)
+
+ logger.debug(f"Initialize response: {response.model_dump_json()}")
+ return response
+
+ async def _handle_list_tools(
+ self, message: ListToolsRequest
+ ) -> Union[ListToolsResponse, JSONRPCError]:
+ """
+ Handle a tools/list request and return a list of available tools.
+
+ Args:
+ message: The tools/list request
+
+ Returns:
+ A properly formatted tools/list response or error
+ """
+ try:
+ # Get all tools from the catalog
+ tools = []
+ tool_conversion_errors = []
+
+ for tool in self.tool_catalog:
+ try:
+ mcp_tool = create_mcp_tool(tool)
+ if mcp_tool:
+ tools.append(mcp_tool)
+ except Exception:
+ tool_name = getattr(tool, "name", str(tool))
+ logger.exception(f"Error converting tool: {tool_name}")
+ tool_conversion_errors.append(tool_name)
+
+ # Log summary if we had errors
+ if tool_conversion_errors:
+ logger.warning(
+ f"Failed to convert {len(tool_conversion_errors)} tools: {tool_conversion_errors}"
+ )
+
+ # Create tool objects with exception handling for each one
+ tool_objects = []
+ for t in tools:
+ try:
+ # Make input schema optional if missing
+ tool_dict = dict(t)
+ if "inputSchema" not in tool_dict:
+ tool_dict["inputSchema"] = {"type": "object", "properties": {}}
+
+ tool_objects.append(Tool(**tool_dict))
+ except Exception:
+ logger.exception(f"Error creating Tool object for {t.get('name', 'unknown')}")
+
+ # Return successful response with the tools we were able to convert
+ result = ListToolsResult(tools=tool_objects)
+ response = ListToolsResponse(id=message.id, result=result)
+
+ except Exception:
+ logger.exception("Error listing tools")
+ return JSONRPCError(
+ id=message.id,
+ error={
+ "code": -32603,
+ "message": "Internal error listing tools",
+ },
+ )
+ return response
+
+ async def _handle_call_tool(
+ self, message: CallToolRequest, user_id: str | None = None
+ ) -> CallToolResponse:
+ """
+ Handle a tools/call request to execute a tool.
+
+ Args:
+ message: The tools/call request
+ user_id: Optional user ID for authentication
+
+ Returns:
+ A properly formatted tools/call response
+ """
+ tool_name: str = message.params["name"]
+ # Extract input from the correct field
+ input_params: dict[str, Any] = message.params.get("input", {})
+ if not input_params:
+ input_params = message.params.get("arguments", {})
+
+ logger.info(f"Handling tool call for {tool_name}")
+
+ try:
+ tool = self.tool_catalog.get_tool_by_name(tool_name, separator="_")
+ tool_context = ToolContext()
+
+ # Set up context with secrets
+ if tool.definition.requirements and tool.definition.requirements.secrets:
+ self._setup_tool_secrets(tool, tool_context)
+
+ # Handle authorization if needed
+ requirement = self._get_auth_requirement(tool)
+ if requirement:
+ auth_result = await self._check_authorization(requirement, user_id=user_id)
+ if auth_result.status != "completed":
+ return CallToolResponse(
+ id=message.id,
+ result=CallToolResult(content=[{"type": "text", "text": auth_result.url}]),
+ )
+ else:
+ tool_context.authorization = ToolAuthorizationContext(
+ token=auth_result.context.token if auth_result.context else None,
+ user_info={"user_id": user_id} if user_id else {},
+ )
+
+ # Execute the tool
+ logger.debug(f"Executing tool {tool_name} with input: {input_params}")
+ result = await ToolExecutor.run(
+ func=tool.tool,
+ definition=tool.definition,
+ input_model=tool.input_model,
+ output_model=tool.output_model,
+ context=tool_context,
+ **input_params,
+ )
+ logger.debug(f"Tool result: {result}")
+ if result.value:
+ return CallToolResponse(
+ id=message.id,
+ result=CallToolResult(content=convert_to_mcp_content(result.value)),
+ )
+ else:
+ error = result.error or "Error calling tool"
+ logger.error(f"Tool {tool_name} returned error: {error}")
+ return CallToolResponse(
+ id=message.id,
+ result=CallToolResult(
+ content=[{"type": "text", "text": convert_to_mcp_content(error)}]
+ ),
+ )
+ except Exception as e:
+ logger.exception(f"Error calling tool {tool_name}")
+ error = f"Error calling tool {tool_name}: {e!s}"
+ return CallToolResponse(
+ id=message.id,
+ result=CallToolResult(
+ content=[{"type": "text", "text": convert_to_mcp_content(error)}]
+ ),
+ )
+
+ def _setup_tool_secrets(self, tool: Any, tool_context: ToolContext) -> None:
+ """
+ Set up tool secrets in the tool context.
+
+ Args:
+ tool: The tool to set up secrets for
+ tool_context: The tool context to update
+ """
+ for secret in tool.definition.requirements.secrets:
+ value = os.environ.get(secret.key)
+ if value is not None:
+ tool_context.set_secret(secret.key, value)
+
+ async def _handle_progress(self, message: ProgressNotification) -> JSONRPCResponse:
+ """
+ Handle a progress notification.
+
+ Args:
+ message: The progress notification
+
+ Returns:
+ A response acknowledging the notification
+ """
+ return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True})
+
+ async def _handle_cancel(self, message: CancelRequest) -> JSONRPCResponse:
+ """
+ Handle a cancel request.
+
+ Args:
+ message: The cancel request
+
+ Returns:
+ A response acknowledging the cancellation
+ """
+ return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True})
+
+ async def _handle_shutdown(self, message: ShutdownRequest) -> ShutdownResponse:
+ """
+ Handle a shutdown request.
+
+ Args:
+ message: The shutdown request
+
+ Returns:
+ A response acknowledging the shutdown request
+ """
+ # Schedule a task to shutdown the server after sending the response
+ proc = asyncio.create_task(self.shutdown())
+ proc.add_done_callback(lambda _: logger.info("MCP server shutdown complete"))
+ return ShutdownResponse(id=message.id, result={"ok": True})
+
+ async def _handle_list_resources(self, message: ListResourcesRequest) -> ListResourcesResponse:
+ """
+ Handle a resources/list request.
+
+ Args:
+ message: The resources/list request
+
+ Returns:
+ A properly formatted resources/list response
+ """
+ return ListResourcesResponse(id=message.id, result={"resources": []})
+
+ async def _handle_list_prompts(self, message: ListPromptsRequest) -> ListPromptsResponse:
+ """
+ Handle a prompts/list request.
+
+ Args:
+ message: The prompts/list request
+
+ Returns:
+ A properly formatted prompts/list response
+ """
+ return ListPromptsResponse(id=message.id, result={"prompts": []})
+
+ def _get_auth_requirement(self, tool: MaterializedTool) -> AuthRequirement | None:
+ """
+ Get the authentication requirement for a tool.
+
+ Args:
+ tool: The tool to get the requirement for
+
+ Returns:
+ An authentication requirement or None if not required
+ """
+ req = tool.definition.requirements.authorization
+ if not req:
+ return None
+ if not req.provider_id and not req.provider_type:
+ return None
+ if hasattr(req, "oauth2") and req.oauth2:
+ return AuthRequirement(
+ provider_id=str(req.provider_id),
+ provider_type=str(req.provider_type),
+ oauth2=AuthRequirementOauth2(scopes=req.oauth2.scopes or []),
+ )
+ return AuthRequirement(
+ provider_id=str(req.provider_id),
+ provider_type=str(req.provider_type),
+ )
+
+ async def _check_authorization(
+ self, auth_requirement: AuthRequirement, user_id: str | None = None
+ ) -> AuthorizationResponse:
+ """
+ Check if a tool is authorized for a user.
+
+ Args:
+ tool: The tool to check authorization for
+ user_id: The user ID to check authorization for
+
+ Returns:
+ An authorization response
+
+ Raises:
+ RuntimeError: If the tool has no authorization requirement
+ Exception: If authorization fails
+ """
+ try:
+ response = await self.arcade.auth.authorize(
+ auth_requirement=auth_requirement,
+ user_id=user_id or "anonymous",
+ )
+ logger.debug(f"Authorization response: {response}")
+
+ except ArcadeError:
+ logger.exception("Error authorizing tool")
+ raise
+ return response
+
+ async def shutdown(self) -> None:
+ """Shutdown the server."""
+ self._shutdown = True
+ logger.info("MCP server shutdown complete")
diff --git a/arcade/arcade/worker/mcp/stdio.py b/arcade/arcade/worker/mcp/stdio.py
new file mode 100644
index 00000000..4c637005
--- /dev/null
+++ b/arcade/arcade/worker/mcp/stdio.py
@@ -0,0 +1,185 @@
+import asyncio
+import logging
+import queue
+import signal
+import sys
+import threading
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, Any, TypeVar
+
+if TYPE_CHECKING:
+ pass
+
+from arcade.worker.mcp.server import MCPServer
+
+logger = logging.getLogger("arcade.mcp")
+
+T = TypeVar("T")
+
+
+def stdio_reader(stdin: object, q: queue.Queue[str | None]) -> None:
+ """Read lines from stdin and put them into a queue."""
+ for line in stdin: # type: ignore[attr-defined]
+ q.put(line)
+ q.put(None)
+
+
+def stdio_writer(stdout: object, q: queue.Queue[str | None]) -> None:
+ """Write messages from a queue to stdout."""
+ try:
+ while True:
+ msg = q.get()
+ if msg is None:
+ break
+
+ # Ensure message ends with a newline for proper JSON-RPC-over-stdio
+ if not msg.endswith("\n"):
+ msg += "\n"
+
+ stdout.write(msg) # type: ignore[attr-defined]
+ stdout.flush() # type: ignore[attr-defined]
+ except Exception:
+ logger.exception("Error in stdio writer")
+
+
+class StdioServer(MCPServer):
+ """
+ Stdio server that handles signals and cleanup.
+ """
+
+ def __init__(
+ self,
+ tool_catalog: Any,
+ enable_logging: bool = True,
+ **client_kwargs: dict[str, Any],
+ ):
+ # Set up stdio-specific middleware configuration
+ middleware_config = client_kwargs.get("middleware_config", {})
+ middleware_config["stdio_mode"] = True
+ client_kwargs["middleware_config"] = middleware_config
+
+ super().__init__(tool_catalog, enable_logging, **client_kwargs)
+ self.read_q: queue.Queue[str | None] = queue.Queue()
+ self.write_q: queue.Queue[str | None] = queue.Queue()
+ self.reader_thread: threading.Thread | None = None
+ self.writer_thread: threading.Thread | None = None
+ self.running = False
+ self.shutdown_event = asyncio.Event()
+
+ def start_io_threads(self) -> None:
+ """Start stdio reader and writer threads."""
+ self.reader_thread = threading.Thread(
+ target=self._stdio_reader, args=(sys.stdin, self.read_q), daemon=True
+ )
+ self.writer_thread = threading.Thread(
+ target=self._stdio_writer, args=(sys.stdout, self.write_q), daemon=True
+ )
+ self.reader_thread.start()
+ self.writer_thread.start()
+
+ def _stdio_reader(self, stdin: object, q: queue.Queue[str | None]) -> None:
+ """Read lines from stdin and put them into a queue."""
+ try:
+ for line in stdin: # type: ignore[attr-defined]
+ if not self.running:
+ break
+ q.put(line)
+ except Exception:
+ logger.exception("Error in stdio reader")
+ finally:
+ q.put(None) # Signal EOF
+
+ def _stdio_writer(self, stdout: object, q: queue.Queue[str | None]) -> None:
+ """Write messages from a queue to stdout."""
+ try:
+ while self.running:
+ msg = q.get()
+ if msg is None:
+ break
+ stdout.write(msg) # type: ignore[attr-defined]
+ stdout.flush() # type: ignore[attr-defined]
+ except Exception:
+ logger.exception("Error in stdio writer")
+
+ async def _read_stream(self) -> AsyncGenerator[str, None]:
+ """Async generator that yields lines from the read queue."""
+ while self.running:
+ try:
+ line = await asyncio.to_thread(self.read_q.get)
+ if line is None:
+ break
+ yield line
+ except asyncio.CancelledError:
+ break
+ except Exception:
+ logger.exception("Error reading from stdin")
+ break
+
+ async def shutdown(self) -> None:
+ """Gracefully shut down the server."""
+ if not self.running:
+ return
+
+ logger.info("Shutting down stdio server...")
+ self.running = False
+
+ # Signal shutdown to MCP server
+ await self.shutdown()
+
+ # Clean up IO queues and threads
+ try:
+ if self.read_q:
+ self.read_q.put(None)
+ if self.write_q:
+ self.write_q.put(None)
+ except Exception:
+ logger.exception("Error during shutdown")
+
+ # Signal completion
+ self.shutdown_event.set()
+ logger.info("Stdio server shutdown complete")
+
+ async def run(self) -> None:
+ """Run the stdio server with signal handling."""
+ self.running = True
+
+ # Set up signal handlers
+ loop = asyncio.get_running_loop()
+ for sig in (signal.SIGINT, signal.SIGTERM):
+ try:
+ loop.add_signal_handler(sig, lambda: asyncio.create_task(self.shutdown()))
+ except NotImplementedError:
+ # Windows doesn't support POSIX signals
+ if sys.platform == "win32":
+ logger.warning("Signal handling not fully supported on Windows")
+ else:
+ logger.warning(f"Failed to set up signal handler for {sig}")
+
+ # Start IO threads
+ self.start_io_threads()
+
+ logger.info("Starting MCP server with stdio transport")
+
+ # Create WriteStream class for MCP server
+ class WriteStream:
+ async def send(self_, message: str) -> None:
+ if self.running:
+ await asyncio.to_thread(self.write_q.put, message)
+
+ try:
+ # Run MCP server connection
+ await self.run_connection(self._read_stream(), WriteStream(), None)
+ except asyncio.CancelledError:
+ # Handle cancellation
+ logger.info("Server operation cancelled")
+ except KeyboardInterrupt:
+ # Handle keyboard interrupt
+ logger.info("Keyboard interrupt received")
+ except Exception:
+ # Handle unexpected errors
+ logger.exception("Unexpected error")
+ finally:
+ # Ensure we clean up
+ await self.shutdown()
+ # Wait for shutdown to complete
+ await self.shutdown_event.wait()
diff --git a/arcade/arcade/worker/mcp/types.py b/arcade/arcade/worker/mcp/types.py
new file mode 100644
index 00000000..4354241b
--- /dev/null
+++ b/arcade/arcade/worker/mcp/types.py
@@ -0,0 +1,383 @@
+import json
+from collections.abc import Callable
+from typing import (
+ Any,
+ Generic,
+ Literal,
+ TypeAlias,
+ TypeVar,
+ Union,
+)
+
+from pydantic import BaseModel, ConfigDict, Field
+
+ProgressToken = str | int
+Cursor = str
+Role = Literal["user", "assistant"]
+RequestId = str | int
+AnyFunction: TypeAlias = Callable[..., Any]
+
+
+class RequestParams(BaseModel):
+ class Meta(BaseModel):
+ progressToken: ProgressToken | None = None
+ model_config = ConfigDict(extra="allow")
+
+ meta: Meta | None = Field(alias="_meta", default=None)
+
+ model_config = ConfigDict(extra="allow")
+
+
+class NotificationParams(BaseModel):
+ class Meta(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ meta: Meta | None = Field(alias="_meta", default=None)
+ model_config = ConfigDict(extra="allow")
+
+
+RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
+NotificationParamsT = TypeVar(
+ "NotificationParamsT", bound=NotificationParams | dict[str, Any] | None
+)
+MethodT = TypeVar("MethodT", bound=str)
+
+
+class Request(BaseModel, Generic[RequestParamsT, MethodT]):
+ method: MethodT
+ params: RequestParamsT
+ model_config = ConfigDict(extra="allow")
+
+
+class PaginatedRequest(Request[RequestParamsT, MethodT]):
+ cursor: Cursor | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
+ method: MethodT
+ params: NotificationParamsT
+ model_config = ConfigDict(extra="allow")
+
+
+class Result(BaseModel):
+ meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+ model_config = ConfigDict(extra="allow")
+
+
+class PaginatedResult(Result):
+ nextCursor: Cursor | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class JSONRPCMessage(BaseModel):
+ """Base class for all JSON-RPC messages."""
+
+ model_config = ConfigDict(extra="allow")
+ jsonrpc: str = Field(default="2.0", frozen=True)
+
+
+class JSONRPCRequest(JSONRPCMessage):
+ """A JSON-RPC request message."""
+
+ id: str | int | None = None
+ method: str
+ params: dict[str, Any] | None = None
+
+
+class JSONRPCResponse(JSONRPCMessage):
+ """A JSON-RPC response message."""
+
+ id: str | int | None
+ result: Any = None
+ error: dict[str, Any] | None = None
+
+ def model_dump_json(self, **kwargs: Any) -> str:
+ """Convert to JSON string with proper formatting."""
+
+ # Convert to dict
+ data = {
+ "jsonrpc": self.jsonrpc,
+ "id": self.id,
+ }
+
+ # Add result if present
+ if self.result is not None:
+ # Check if result is a Pydantic model
+ if hasattr(self.result, "model_dump"):
+ data["result"] = self.result.model_dump(exclude_none=True)
+ # Check if result is already a dict/list/primitive
+ elif (
+ isinstance(self.result, (dict, list, str, int, float, bool)) or self.result is None
+ ):
+ data["result"] = self.result # type: ignore[assignment]
+ else:
+ # Try to convert using str() as a fallback
+ data["result"] = str(self.result)
+
+ # Add error if present
+ if self.error is not None:
+ data["error"] = self.error # type: ignore[assignment]
+
+ return json.dumps(data, ensure_ascii=False)
+
+
+class JSONRPCError(JSONRPCMessage):
+ """A JSON-RPC error message."""
+
+ id: str | int | None
+ error: dict[str, Any]
+
+
+PARSE_ERROR = -32700
+INVALID_REQUEST = -32600
+METHOD_NOT_FOUND = -32601
+INVALID_PARAMS = -32602
+INTERNAL_ERROR = -32603
+
+
+class ErrorData(BaseModel):
+ code: int
+ message: str
+ data: Any | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+JSONRPCMessageBaseModel = BaseModel | JSONRPCRequest | JSONRPCResponse | JSONRPCError
+
+
+class EmptyResult(Result):
+ pass
+
+
+class Implementation(BaseModel):
+ """Describes the server or client implementation."""
+
+ name: str
+ version: str
+ model_config = ConfigDict(extra="allow")
+
+
+class RootsCapability(BaseModel):
+ listChanged: bool | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class SamplingCapability(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+
+class ClientCapabilities(BaseModel):
+ experimental: dict[str, dict[str, Any]] | None = None
+ sampling: SamplingCapability | None = None
+ roots: RootsCapability | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class PromptsCapability(BaseModel):
+ listChanged: bool | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class ResourcesCapability(BaseModel):
+ subscribe: bool | None = None
+ listChanged: bool | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class ToolsCapability(BaseModel):
+ listChanged: bool | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class LoggingCapability(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+
+class ServerCapabilities(BaseModel):
+ """Describes the server's capabilities."""
+
+ model_config = ConfigDict(extra="allow")
+ tools: dict[str, Any] | None = None
+ resources: dict[str, Any] | None = None
+ prompts: dict[str, Any] | None = None
+
+
+class InitializeRequestParams(RequestParams):
+ protocolVersion: str | int
+ capabilities: ClientCapabilities
+ clientInfo: Implementation
+ model_config = ConfigDict(extra="allow")
+
+
+class InitializeRequest(JSONRPCRequest):
+ method: str = Field(default="initialize", frozen=True)
+ params: dict[str, Any] | None = None
+
+
+class InitializeResult(BaseModel):
+ protocolVersion: str
+ capabilities: ServerCapabilities
+ serverInfo: Implementation
+ instructions: str | None = None
+
+
+class InitializedNotification(
+ Notification[NotificationParams | None, Literal["notifications/initialized"]]
+):
+ method: Literal["notifications/initialized"]
+ params: NotificationParams | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class PingRequest(JSONRPCRequest):
+ method: str = Field(default="ping", frozen=True)
+ params: dict[str, Any] | None = None
+
+
+class ProgressNotificationParams(NotificationParams):
+ progressToken: ProgressToken
+ progress: float
+ total: float | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class ProgressNotification(JSONRPCMessage):
+ method: str = Field(default="progress", frozen=True)
+ params: dict[str, Any]
+
+
+class PingResponse(JSONRPCResponse):
+ result: dict[str, Any] = Field(default_factory=lambda: {"pong": True})
+
+
+class ShutdownRequest(JSONRPCRequest):
+ method: str = Field(default="shutdown", frozen=True)
+ params: dict[str, Any] | None = None
+
+
+class ShutdownResponse(JSONRPCResponse):
+ result: dict[str, Any] = Field(default_factory=lambda: {"ok": True})
+
+
+class CancelRequest(JSONRPCRequest):
+ method: str = Field(default="$/cancelRequest", frozen=True)
+ params: dict[str, Any]
+
+
+class InitializeResponse(JSONRPCResponse):
+ """
+ Response to an initialize request.
+
+ Note: This must be a properly formatted JSON-RPC response with a `result` field
+ containing the initialization data, not another request.
+ """
+
+ result: InitializeResult
+
+ def model_dump_json(self, **kwargs: Any) -> str:
+ """Convert to JSON string with proper formatting."""
+ # Convert to dict
+ data = {
+ "jsonrpc": self.jsonrpc,
+ "id": self.id,
+ "result": self.result.model_dump(exclude_none=True),
+ }
+
+ # Return JSON string
+ return json.dumps(data, ensure_ascii=False)
+
+
+class ListToolsRequest(JSONRPCRequest):
+ method: str = Field(default="tools/list", frozen=True)
+ params: dict[str, Any] | None = None
+
+
+class ToolAnnotations(BaseModel):
+ """
+ Represents tool annotations for hints about behavior.
+ """
+
+ title: str | None = None
+ readOnlyHint: bool | None = None
+ destructiveHint: bool | None = None
+ idempotentHint: bool | None = None
+ openWorldHint: bool | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class Tool(BaseModel):
+ """
+ Represents an MCP tool definition.
+ """
+
+ name: str
+ description: str
+ inputSchema: dict[str, Any] | None = None
+ annotations: ToolAnnotations | None = None
+
+ model_config = ConfigDict(extra="allow")
+
+
+class ListToolsResult(BaseModel):
+ tools: list[Tool]
+
+
+class ListToolsResponse(JSONRPCResponse):
+ result: ListToolsResult
+
+
+class CallToolRequest(JSONRPCRequest):
+ method: str = Field(default="tools/call", frozen=True)
+ params: dict[str, Any]
+
+
+class CallToolResult(BaseModel):
+ content: Any
+
+
+class CallToolResponse(JSONRPCResponse):
+ result: CallToolResult
+
+
+# Resource and Prompt protocol stubs (expand as needed)
+class ListResourcesRequest(JSONRPCRequest):
+ method: str = Field(default="resources/list", frozen=True)
+ params: dict[str, Any] | None = None
+
+
+class ListResourcesResponse(JSONRPCResponse):
+ result: dict[str, Any]
+
+
+class ListPromptsRequest(JSONRPCRequest):
+ method: str = Field(default="prompts/list", frozen=True)
+ params: dict[str, Any] | None = None
+
+
+class ListPromptsResponse(JSONRPCResponse):
+ result: dict[str, Any]
+
+
+# Utility type alias for all MCP protocol messages
+MCPMessage = Union[
+ JSONRPCRequest,
+ JSONRPCResponse,
+ JSONRPCError,
+ PingRequest,
+ PingResponse,
+ InitializeRequest,
+ InitializeResponse,
+ ListToolsRequest,
+ ListToolsResponse,
+ CallToolRequest,
+ CallToolResponse,
+ ProgressNotification,
+ CancelRequest,
+ ShutdownRequest,
+ ShutdownResponse,
+ ListResourcesRequest,
+ ListResourcesResponse,
+ ListPromptsRequest,
+ ListPromptsResponse,
+]
diff --git a/arcade/tests/deployment/test_config.py b/arcade/tests/deployment/test_config.py
index 531bfaac..7d62c461 100644
--- a/arcade/tests/deployment/test_config.py
+++ b/arcade/tests/deployment/test_config.py
@@ -7,7 +7,7 @@ from pathlib import Path
import pytest
-from arcade.worker.config.deployment import (
+from arcade.cli.deployment import (
Config,
Deployment,
LocalPackages,
diff --git a/arcade/tests/mcp/test_convert.py b/arcade/tests/mcp/test_convert.py
new file mode 100644
index 00000000..a4bcf0c2
--- /dev/null
+++ b/arcade/tests/mcp/test_convert.py
@@ -0,0 +1,44 @@
+import json
+from typing import Annotated
+
+from arcade.core.catalog import ToolCatalog
+from arcade.sdk import tool
+from arcade.worker.mcp.convert import convert_to_mcp_content, create_mcp_tool
+
+
+@tool
+def sample_tool(x: Annotated[int, "first"], y: Annotated[int, "second"]) -> int:
+ """Return x+y"""
+
+ return x + y
+
+
+def test_convert_to_mcp_content_primitives():
+ assert convert_to_mcp_content(42) == [{"type": "text", "text": "42"}]
+ assert convert_to_mcp_content("hello") == [{"type": "text", "text": "hello"}]
+ assert convert_to_mcp_content(True) == [{"type": "text", "text": "True"}]
+
+
+def test_convert_to_mcp_content_complex():
+ data = {"a": 1}
+ expected_json = json.dumps(data)
+ assert convert_to_mcp_content(data) == [{"type": "text", "text": expected_json}]
+
+
+def test_create_mcp_tool():
+ # Materialize a tool via catalog then feed it to create_mcp_tool
+ catalog = ToolCatalog()
+ catalog.add_tool(sample_tool, "convert_toolkit")
+ mat_tool = next(iter(catalog)) # only tool
+ mcp_tool = create_mcp_tool(mat_tool)
+
+ assert mcp_tool is not None
+ assert mcp_tool["name"] == "ConvertToolkit_SampleTool"
+ assert mcp_tool["description"]
+ # Ensure input schema contains both parameters and marks them required
+ props = mcp_tool["inputSchema"]["properties"]
+ assert set(props.keys()) == {"x", "y"}
+
+ required_fields = set(mcp_tool["inputSchema"].get("required", []))
+ # Ensure no unexpected required fields and that declared ones are subset of expected
+ assert required_fields.issubset({"x", "y"})
diff --git a/arcade/tests/mcp/test_message_processor.py b/arcade/tests/mcp/test_message_processor.py
new file mode 100644
index 00000000..fac0d91a
--- /dev/null
+++ b/arcade/tests/mcp/test_message_processor.py
@@ -0,0 +1,57 @@
+import asyncio
+
+import pytest
+
+from arcade.worker.mcp.message_processor import MCPMessageProcessor, create_message_processor
+from arcade.worker.mcp.types import InitializeRequest, PingRequest
+
+
+@pytest.mark.asyncio
+async def test_message_processor_parses_initialize_json():
+ """Ensure JSON initialize strings are converted into InitializeRequest objects."""
+ json_init = '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}\n'
+ processor = MCPMessageProcessor()
+
+ result = await processor.process_request(json_init)
+
+ assert isinstance(result, InitializeRequest)
+ assert result.id == 1
+ assert result.method == "initialize"
+
+
+@pytest.mark.asyncio
+async def test_message_processor_passes_notifications_unchanged():
+ """Unknown notifications should be passed through as parsed dictionaries without errors."""
+ json_notification = '{"jsonrpc":"2.0","id":null,"method":"notifications/custom","params":{}}\n'
+ processor = MCPMessageProcessor()
+
+ result = await processor.process_request(json_notification)
+
+ # The MCPMessageProcessor keeps unknown notifications as simple dicts
+ assert isinstance(result, dict)
+ assert result["method"] == "notifications/custom"
+
+
+@pytest.mark.asyncio
+async def test_message_processor_middleware_execution_order(monkeypatch):
+ """Middleware (sync + async) should be executed in the order they were added."""
+
+ order: list[str] = []
+
+ def mw_sync(msg, direction): # type: ignore[return-value]
+ order.append("sync")
+ return msg
+
+ async def mw_async(msg, direction): # type: ignore[return-value]
+ await asyncio.sleep(0) # ensure it is truly async
+ order.append("async")
+ return msg
+
+ processor = create_message_processor(mw_sync, mw_async)
+
+ # Use a pre-parsed PingRequest instance so we don't test parsing again here
+ ping = PingRequest(id=42)
+
+ _ = await processor.process_request(ping)
+
+ assert order == ["sync", "async"]
diff --git a/arcade/tests/mcp/test_server.py b/arcade/tests/mcp/test_server.py
new file mode 100644
index 00000000..bfb75eeb
--- /dev/null
+++ b/arcade/tests/mcp/test_server.py
@@ -0,0 +1,153 @@
+import sys
+import types
+from typing import Annotated, Any
+
+import pytest
+
+from arcade.core.catalog import ToolCatalog
+from arcade.sdk import tool
+from arcade.worker.mcp import server as mcp_server
+from arcade.worker.mcp.types import (
+ CallToolRequest,
+ CancelRequest,
+ InitializeRequest,
+ ListToolsRequest,
+ PingRequest,
+)
+
+# ---------------------------------------------------------------------------
+# Test helpers / stubs
+# ---------------------------------------------------------------------------
+
+
+class _FakeAuth:
+ async def authorize(self, auth_requirement: Any, user_id: str):
+ """Return an object that mimics AuthorizationResponse with completed status."""
+
+ class _Ctx: # minimal stub
+ token = "dummy-token" # noqa: S105
+
+ class _Resp: # pylint: disable=too-few-public-methods
+ status = "completed"
+ url = ""
+ context = _Ctx()
+
+ return _Resp()
+
+
+class _FakeArcade: # pylint: disable=too-few-public-methods
+ def __init__(self, **_: Any):
+ self.auth = _FakeAuth()
+
+
+# Ensure that the AsyncArcade & ArcadeError symbols inside server.py point to our stubs.
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture(autouse=True)
+def _patch_arcadepy(monkeypatch):
+ """Patch the external `arcadepy` dependency used by mcp.server."""
+
+ # Patch the imported symbols on the already-imported server module
+ monkeypatch.setattr(mcp_server, "AsyncArcade", _FakeArcade, raising=True)
+ monkeypatch.setattr(mcp_server, "ArcadeError", Exception, raising=True)
+
+ # Provide a dummy `arcadepy` module in sys.modules for any other importers
+ fake_arcadepy = types.ModuleType("arcadepy")
+ fake_arcadepy.AsyncArcade = _FakeArcade # type: ignore[attr-defined]
+ fake_arcadepy.ArcadeError = Exception # type: ignore[attr-defined]
+ sys.modules["arcadepy"] = fake_arcadepy
+
+ yield
+
+ # Cleanup
+ sys.modules.pop("arcadepy", None)
+
+
+# ---------------------------------------------------------------------------
+# Fixtures for a sample tool / catalog / server
+# ---------------------------------------------------------------------------
+
+
+@tool
+def multiply(a: Annotated[int, "a"], b: Annotated[int, "b"]) -> Annotated[int, "result"]:
+ """Return the product of *a* and *b*."""
+
+ return a * b
+
+
+@pytest.fixture(scope="module")
+def sample_catalog():
+ catalog = ToolCatalog()
+ catalog.add_tool(multiply, "test_toolkit")
+ return catalog
+
+
+@pytest.fixture()
+def server(sample_catalog):
+ # MCPServer constructor is synchronous, so fixture need not be async
+ return mcp_server.MCPServer(sample_catalog, enable_logging=False)
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+async def test_handle_ping(server):
+ req = PingRequest(id=123)
+ resp = await server._handle_ping(req) # pylint: disable=protected-access
+ assert resp.id == 123
+ assert resp.result == {"pong": True}
+
+
+async def test_handle_initialize(server):
+ req = InitializeRequest(id=1)
+ resp = await server._handle_initialize(req) # pylint: disable=protected-access
+ assert resp.id == 1
+ assert resp.result.protocolVersion == mcp_server.MCP_PROTOCOL_VERSION
+ assert resp.result.serverInfo.name.startswith("Arcade")
+
+
+async def test_handle_list_tools(server):
+ req = ListToolsRequest(id=99)
+ resp = await server._handle_list_tools(req) # pylint: disable=protected-access
+ assert resp.id == 99
+ # Should list our sample tool only
+ tool_names = [t.name for t in resp.result.tools]
+ assert "TestToolkit_Multiply" in tool_names # toolkit + "_" + tool
+
+
+async def test_handle_call_tool_success(server):
+ req = CallToolRequest(
+ id="call-1",
+ params={
+ "name": "TestToolkit_Multiply",
+ "input": {"a": 6, "b": 7},
+ },
+ )
+ resp = await server._handle_call_tool(req, user_id="tester@example.com") # pylint: disable=protected-access
+
+ assert resp.id == "call-1"
+ # convert_to_mcp_content wraps primitives in list-of-dicts
+ assert resp.result.content == [{"type": "text", "text": "42"}]
+
+
+async def test_send_response_dict(server, monkeypatch):
+ """_send_response should JSON-serialize plain dictionaries."""
+
+ sent: list[str] = []
+
+ class _Write:
+ async def send(self, msg):
+ sent.append(msg)
+
+ await server._send_response(_Write(), {"foo": "bar"}) # pylint: disable=protected-access
+
+ assert sent and sent[0].strip() == '{"foo": "bar"}'
+
+
+async def test_handle_cancel(server):
+ req = CancelRequest(id=77, params={"id": "abc"})
+ resp = await server._handle_cancel(req) # pylint: disable=protected-access
+ assert resp.result == {"ok": True}
diff --git a/arcade/tests/mcp/test_stdio.py b/arcade/tests/mcp/test_stdio.py
new file mode 100644
index 00000000..6d512157
--- /dev/null
+++ b/arcade/tests/mcp/test_stdio.py
@@ -0,0 +1,32 @@
+import io
+import queue
+
+from arcade.worker.mcp.stdio import stdio_reader, stdio_writer
+
+
+def test_stdio_reader_puts_lines_and_none():
+ q: queue.Queue[str | None] = queue.Queue()
+ test_input = io.StringIO("line1\nline2\n")
+
+ stdio_reader(test_input, q)
+
+ # We should get the two lines followed by None sentinel
+ assert q.get_nowait() == "line1\n"
+ assert q.get_nowait() == "line2\n"
+ assert q.get_nowait() is None
+
+
+def test_stdio_writer_reads_until_none():
+ q: queue.Queue[str | None] = queue.Queue()
+ output_stream = io.StringIO()
+
+ # preload queue with two messages and sentinel
+ q.put("msg1")
+ q.put("msg2\n")
+ q.put(None)
+
+ stdio_writer(output_stream, q)
+
+ # Ensure writer appended newlines when missing
+ output_stream.seek(0)
+ assert output_stream.read() == "msg1\nmsg2\n"
diff --git a/arcade/tests/worker/test_worker_base.py b/arcade/tests/worker/test_worker_base.py
new file mode 100644
index 00000000..f50eb871
--- /dev/null
+++ b/arcade/tests/worker/test_worker_base.py
@@ -0,0 +1,237 @@
+import os
+from typing import Annotated
+from unittest.mock import MagicMock
+
+import pytest
+
+from arcade.core.errors import ToolDefinitionError
+from arcade.core.schema import (
+ ToolCallRequest,
+ ToolCallResponse,
+ ToolContext,
+ ToolReference,
+)
+from arcade.sdk import tool
+from arcade.worker.core.base import BaseWorker
+from arcade.worker.core.common import RequestData, Router
+from arcade.worker.core.components import (
+ CallToolComponent,
+ CatalogComponent,
+ HealthCheckComponent,
+)
+
+
+@tool()
+def sample_tool(
+ context: ToolContext, a: Annotated[int, "a"], b: Annotated[int, "b"]
+) -> Annotated[int, "output"]:
+ """Sample tool for testing."""
+ return a + b
+
+
+# Define error tool at module level to avoid indentation issues with getsource
+@tool()
+def error_tool(context: ToolContext) -> int:
+ """This tool always raises an error."""
+ raise ValueError("Something went wrong")
+
+
+@pytest.fixture
+def mock_router():
+ router = MagicMock(spec=Router)
+ router.add_route = MagicMock()
+ return router
+
+
+@pytest.fixture
+def base_worker(mock_router):
+ # Set env var temporarily for testing secret loading
+ os.environ["ARCADE_WORKER_SECRET"] = "test_secret_env" # noqa: S105
+ worker = BaseWorker()
+ worker.register_routes(mock_router) # Register routes using the mock router
+ # Clean up env var
+ del os.environ["ARCADE_WORKER_SECRET"]
+ return worker
+
+
+@pytest.fixture
+def base_worker_no_auth():
+ return BaseWorker(disable_auth=True)
+
+
+# --- BaseWorker Tests ---
+
+
+def test_base_worker_init_with_secret():
+ worker = BaseWorker(secret="explicit_secret") # noqa: S106
+ assert worker.secret == "explicit_secret" # noqa: S105
+ assert not worker.disable_auth
+
+
+def test_base_worker_init_with_env_secret():
+ os.environ["ARCADE_WORKER_SECRET"] = "env_secret_value" # noqa: S105
+ worker = BaseWorker()
+ assert worker.secret == "env_secret_value" # noqa: S105
+ assert not worker.disable_auth
+ del os.environ["ARCADE_WORKER_SECRET"]
+
+
+def test_base_worker_init_no_secret_raises_error():
+ # Ensure env var is not set
+ if "ARCADE_WORKER_SECRET" in os.environ:
+ del os.environ["ARCADE_WORKER_SECRET"]
+ with pytest.raises(ValueError, match="No secret provided for worker"):
+ BaseWorker()
+
+
+def test_base_worker_init_disable_auth():
+ worker = BaseWorker(disable_auth=True)
+ assert worker.secret == ""
+ assert worker.disable_auth
+
+
+def test_register_tool(base_worker_no_auth):
+ assert len(base_worker_no_auth.catalog) == 0
+ base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
+ assert len(base_worker_no_auth.catalog) == 1
+ tool_def = base_worker_no_auth.get_catalog()[0]
+ assert tool_def.name == "SampleTool"
+ assert tool_def.toolkit.name == "TestKit"
+
+
+def test_get_catalog(base_worker_no_auth):
+ base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
+ catalog = base_worker_no_auth.get_catalog()
+ assert isinstance(catalog, list)
+ assert len(catalog) == 1
+ assert catalog[0].name == "SampleTool"
+
+
+def test_health_check(base_worker_no_auth):
+ base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
+ health = base_worker_no_auth.health_check()
+ assert health == {"status": "ok", "tool_count": "1"}
+
+
+@pytest.mark.asyncio
+async def test_call_tool_success(base_worker_no_auth):
+ base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
+ # Create ToolReference WITHOUT version, as register_tool doesn't seem to set it
+ tool_ref = ToolReference(toolkit="TestKit", name="SampleTool")
+ tool_request = ToolCallRequest(
+ execution_id="test_exec_id",
+ tool=tool_ref,
+ inputs={"a": 5, "b": 3},
+ )
+
+ response = await base_worker_no_auth.call_tool(tool_request)
+
+ assert response.success is True
+ assert response.output.value == 8
+ assert response.output.error is None
+ assert response.execution_id == "test_exec_id"
+ assert response.duration > 0
+
+
+@pytest.mark.asyncio
+async def test_call_tool_execution_error(base_worker_no_auth):
+ # Tool is now defined at module level
+ try:
+ base_worker_no_auth.register_tool(error_tool, toolkit_name="error_kit")
+ except ToolDefinitionError as e:
+ pytest.fail(f"Failed to register error_tool: {e}")
+
+ # Create ToolReference WITHOUT version
+ tool_ref = ToolReference(toolkit="ErrorKit", name="ErrorTool")
+ tool_request = ToolCallRequest(
+ execution_id="test_exec_error",
+ tool=tool_ref,
+ inputs={},
+ )
+
+ response = await base_worker_no_auth.call_tool(tool_request)
+
+ assert response.success is False
+ assert response.output.value is None
+ assert response.output.error is not None
+
+
+@pytest.mark.asyncio
+async def test_call_tool_not_found(base_worker_no_auth):
+ # Use ToolReference without version for lookup consistency
+ tool_ref = ToolReference(toolkit="nonexistent", name="nosuchtool")
+ tool_request = ToolCallRequest(
+ execution_id="test_exec_notfound",
+ tool=tool_ref,
+ inputs={},
+ )
+
+ # Update regex to match actual error format
+ with pytest.raises(ValueError):
+ await base_worker_no_auth.call_tool(tool_request)
+
+
+# --- Component Tests (tested via BaseWorker registration) ---
+
+
+def test_register_routes_registers_default_components(base_worker, mock_router):
+ # BaseWorker calls register_routes in its init via the fixture
+ assert mock_router.add_route.call_count == len(BaseWorker.default_components)
+
+ calls = mock_router.add_route.call_args_list
+ expected_paths = ["tools", "tools/invoke", "health"]
+ registered_paths = [
+ call[0][0] for call in calls
+ ] # call[0] are positional args, call[0][0] is endpoint_path
+
+ assert sorted(registered_paths) == sorted(expected_paths)
+
+ # Check if components were instantiated and passed to add_route
+ assert any(isinstance(call[0][1], CatalogComponent) for call in calls)
+ assert any(isinstance(call[0][1], CallToolComponent) for call in calls)
+ assert any(isinstance(call[0][1], HealthCheckComponent) for call in calls)
+
+
+@pytest.mark.asyncio
+async def test_catalog_component_call(base_worker_no_auth):
+ base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
+ component = CatalogComponent(base_worker_no_auth)
+ # Mock request data - not actually used by this component's __call__
+ mock_request = MagicMock(spec=RequestData)
+ catalog_response = await component(mock_request)
+
+ assert isinstance(catalog_response, list)
+ assert len(catalog_response) == 1
+ assert catalog_response[0].name == "SampleTool"
+
+
+@pytest.mark.asyncio
+async def test_call_tool_component_call(base_worker_no_auth):
+ base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
+ component = CallToolComponent(base_worker_no_auth)
+
+ # Create ToolReference WITHOUT version
+ tool_ref = ToolReference(toolkit="TestKit", name="SampleTool")
+ request_body = {
+ "execution_id": "comp_test_exec",
+ "tool": tool_ref.model_dump(),
+ "inputs": {"a": 10, "b": 5},
+ }
+ mock_request = MagicMock(spec=RequestData)
+ mock_request.body_json = request_body
+
+ response = await component(mock_request)
+
+ assert isinstance(response, ToolCallResponse)
+ assert response.success is True
+ assert response.output.value == 15
+ assert response.execution_id == "comp_test_exec"
+
+
+@pytest.mark.asyncio
+async def test_health_check_component_call(base_worker_no_auth):
+ component = HealthCheckComponent(base_worker_no_auth)
+ mock_request = MagicMock(spec=RequestData)
+ health_response = await component(mock_request)
+
+ assert health_response == {"status": "ok", "tool_count": "0"}
diff --git a/arcade/tests/worker/test_worker_fastapi.py b/arcade/tests/worker/test_worker_fastapi.py
new file mode 100644
index 00000000..83c6f105
--- /dev/null
+++ b/arcade/tests/worker/test_worker_fastapi.py
@@ -0,0 +1,167 @@
+from typing import Annotated
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+from arcade.core.schema import ToolCallRequest, ToolContext, ToolReference
+from arcade.sdk import tool
+from arcade.worker.fastapi.worker import FastAPIWorker
+
+
+@tool()
+def sample_tool_fastapi(
+ context: ToolContext, x: Annotated[int, "x"], y: Annotated[str, "y"]
+) -> Annotated[str, "output"]:
+ """A sample tool for FastAPI tests."""
+ return f"{y}-{x}"
+
+
+# Define tool at module level to avoid indentation issues with getsource
+@tool()
+def error_throwing_tool(
+ context: ToolContext,
+ a: Annotated[int, "a", "Input integer a"], # Added description for parameter
+) -> int:
+ """This tool throws a ValueError.""" # Added description for tool
+ raise ValueError("Test execution error")
+
+
+@pytest.fixture
+def test_app():
+ return FastAPI()
+
+
+@pytest.fixture
+def worker_secret():
+ return "test-secret-fastapi"
+
+
+@pytest.fixture
+def fastapi_worker(test_app, worker_secret):
+ worker = FastAPIWorker(app=test_app, secret=worker_secret)
+ worker.register_tool(sample_tool_fastapi, toolkit_name="fastapi_kit")
+ return worker
+
+
+@pytest.fixture
+def fastapi_worker_no_auth(test_app):
+ worker = FastAPIWorker(app=test_app, disable_auth=True)
+ worker.register_tool(sample_tool_fastapi, toolkit_name="fastapi_kit")
+ return worker
+
+
+@pytest.fixture
+def client(test_app, fastapi_worker): # Use the worker fixture to ensure routes are registered
+ return TestClient(test_app)
+
+
+@pytest.fixture
+def client_no_auth(test_app, fastapi_worker_no_auth):
+ return TestClient(test_app)
+
+
+# --- FastAPIWorker Tests ---
+
+
+def test_fastapi_worker_registers_routes(client, fastapi_worker):
+ # Check if routes exist by trying to access them (even if auth fails)
+ response = client.get(f"{fastapi_worker.base_path}/health")
+ assert response.status_code != 404 # Should be 200
+
+ response = client.get(f"{fastapi_worker.base_path}/tools")
+ assert response.status_code != 404 # Should be 403 without auth
+
+ # Prepare a dummy request body for invoke
+ tool_ref = ToolReference(toolkit="FastapiKit", name="SampleToolFastapi")
+ request_body = ToolCallRequest(
+ execution_id="test", tool=tool_ref, inputs={"x": 1, "y": "test"}
+ ).model_dump()
+
+ response = client.post(f"{fastapi_worker.base_path}/tools/invoke", json=request_body)
+ assert response.status_code != 404 # Should be 403 without auth
+
+
+# --- Route Tests (using TestClient) ---
+
+
+# Health Check
+def test_health_check_route(client, worker_secret):
+ response = client.get("/worker/health")
+ assert response.status_code == 200
+ assert response.json() == {"status": "ok", "tool_count": "1"}
+
+
+def test_health_check_route_no_auth(client_no_auth):
+ response = client_no_auth.get("/worker/health")
+ assert response.status_code == 200
+ assert response.json() == {"status": "ok", "tool_count": "1"}
+
+
+# Catalog
+def test_get_catalog_route_no_auth_header(client):
+ response = client.get("/worker/tools")
+ assert response.status_code == 403
+ assert "Not authenticated" in response.text
+
+
+def test_get_catalog_route_invalid_auth_header(client, worker_secret):
+ response = client.get("/worker/tools", headers={"Authorization": "Bearer invalid-token"})
+ assert response.status_code == 401 # Unauthorized
+ # Updated expected error message based on last run
+ assert "Invalid token. Error: Not enough segments" in response.text
+
+
+def test_get_catalog_route_no_auth_worker(client_no_auth):
+ response = client_no_auth.get("/worker/tools")
+ assert response.status_code == 200
+ catalog = response.json()
+ assert isinstance(catalog, list)
+ assert len(catalog) == 1
+ assert catalog[0]["name"] == "SampleToolFastapi"
+
+
+# Call Tool
+@pytest.fixture
+def call_tool_payload():
+ tool_ref = ToolReference(toolkit="FastapiKit", name="SampleToolFastapi")
+ return ToolCallRequest(
+ execution_id="fastapi-test-exec", tool=tool_ref, inputs={"x": 123, "y": "hello"}
+ ).model_dump()
+
+
+def test_call_tool_route_no_auth_header(client, call_tool_payload):
+ response = client.post("/worker/tools/invoke", json=call_tool_payload)
+ assert response.status_code == 403
+
+
+def test_call_tool_route_invalid_auth_header(client, worker_secret, call_tool_payload):
+ response = client.post(
+ "/worker/tools/invoke",
+ json=call_tool_payload,
+ headers={"Authorization": "Bearer invalid-token"},
+ )
+ assert response.status_code == 401
+
+
+def test_call_tool_route_no_auth_worker(client_no_auth, call_tool_payload):
+ response = client_no_auth.post("/worker/tools/invoke", json=call_tool_payload)
+ assert response.status_code == 200
+ result = response.json()
+ assert result["success"] is True
+ assert result["output"]["value"] == "hello-123"
+
+
+def test_call_tool_route_tool_not_found(client_no_auth, call_tool_payload):
+ call_tool_payload["tool"]["name"] = "NonExistentTool"
+ call_tool_payload["tool"]["toolkit"] = "FastapiKit"
+
+ with pytest.raises(ValueError):
+ _ = client_no_auth.post(
+ "/worker/tools/invoke",
+ json=call_tool_payload,
+ )
+ # The handler catches the ValueError and returns a 500 internal server error
+ # Ideally, this might be a 404 or 400, but BaseWorker.call_tool raises ValueError
+ # which isn't automatically mapped to a 4xx by FastAPI unless handled explicitly.
+ # TODO fix this.
diff --git a/examples/mcp/claude.json b/examples/mcp/claude.json
new file mode 100644
index 00000000..d3f9832f
--- /dev/null
+++ b/examples/mcp/claude.json
@@ -0,0 +1,8 @@
+{
+ "mcpServers": {
+ "arcade": {
+ "command": "bash",
+ "args": ["-c", "export ARCADE_API_KEY=arc_xxxx && /path/to/python /path/to/arcade mcp"]
+ }
+ }
+}
diff --git a/examples/mcp/run_stdio.py b/examples/mcp/run_stdio.py
new file mode 100644
index 00000000..3011640e
--- /dev/null
+++ b/examples/mcp/run_stdio.py
@@ -0,0 +1,25 @@
+import arcade_google # pip install arcade_google
+import arcade_search # pip install arcade_search
+
+from arcade.core.catalog import ToolCatalog
+from arcade.worker.mcp.stdio import StdioServer
+
+# 2. Create and populate the tool catalog
+catalog = ToolCatalog()
+catalog.add_module(arcade_google) # Registers all tools in the package
+catalog.add_module(arcade_search)
+
+
+# 3. Main entrypoint
+async def main():
+ # Create the worker with the tool catalog
+ worker = StdioServer(catalog)
+
+ # Run the worker
+ await worker.run()
+
+
+if __name__ == "__main__":
+ import asyncio
+
+ asyncio.run(main())
diff --git a/examples/simple_chatbot.py b/examples/simple_chatbot.py
deleted file mode 100644
index 859d914e..00000000
--- a/examples/simple_chatbot.py
+++ /dev/null
@@ -1,76 +0,0 @@
-"""
-Example script demonstrating how to build a simple chatbot with Arcade.
-
-For this example, we are using the prebuilt Google Docs toolkit to create and edit documents.
-
-Try asking questions like:
-- "Create a document with the title 'My New Document' and content 'Hello, World!'"
-- "List my 2 most recently modified documents and tell me the title, document id, and document URL of each one and summarize them."
-- "Edit the second document from the list you just returned and add the text 'Hello, World!' to the end of it."
-"""
-
-import os
-
-from openai import OpenAI
-
-
-def chat(openai_client: OpenAI, tool_names: list[str], user_id: str) -> None:
- history = []
-
- print("Hello! How can I help you today?")
- while True:
- message = {"role": "user", "content": input(">")}
- history.append(message)
- chat_result = call_tool_with_openai(openai_client, tool_names, user_id, history)
-
- # If the tool call requires authorization, then wait for the user to authorize and then call the tool again
- if (
- chat_result.choices[0].tool_authorizations
- and chat_result.choices[0].tool_authorizations[0].get("status") == "pending"
- ):
- print("\n" + chat_result.choices[0].message.content)
- input("\nAfter you have authorized, press Enter to continue...")
- chat_result = call_tool_with_openai(openai_client, tool_names, user_id, history)
-
- history.append({"role": "assistant", "content": chat_result.choices[0].message.content})
-
- print(chat_result.choices[0].message.content)
-
-
-def call_tool_with_openai(
- client: OpenAI, tool_names: list[str], user_id: str, messages: list[dict]
-) -> dict:
- response = client.chat.completions.create(
- messages=messages,
- model="gpt-4o-mini",
- user=user_id,
- tools=tool_names,
- tool_choice="generate",
- )
-
- return response
-
-
-if __name__ == "__main__":
- arcade_api_key = os.environ.get(
- "ARCADE_API_KEY"
- ) # If you forget your Arcade API key, it is stored at ~/.arcade/credentials.yaml on `arcade login`
- cloud_host = "https://api.arcade.dev/v1"
- user_id = "user@example.com"
-
- openai_client = OpenAI(
- api_key=arcade_api_key,
- base_url=cloud_host,
- )
-
- tool_names = [
- "Google.SendEmail",
- "Google.SendDraftEmail",
- "Google.WriteDraftEmail",
- "Google.UpdateDraftEmail",
- "Google.ListDraftEmails",
- "Google.ListEmailsByHeader",
- "Google.ListEmails",
- ]
-
- chat(openai_client, tool_names, user_id)