Support for MCP stdio transport (#368)
MCP stdio Implementation: The PR adds support for standard input/output (stdio) as a transport mechanism for the Message Control Protocol. This is a replacement to the SSE (Server-Sent Events) transport that was worked on in PR #359 but will not be merged as it's not deprecated. This will allow developers to use Arcade tools (written by the dev or Arcade) in Claude, Cursor, windsurf, etc. The engine Gateway already supports adding HTTPS streamable (replacement for SSE) MCP servers as tool servers, and will soon support full gateway capability in the client API as well. To use any existing Toolkit just ## Examples ### Quickstart setup with existing toolkits ```bash pip install arcade-ai pip install <name of toolkit> # ex. arcade-google arcade serve --mcp ``` ### Run with Claude Just add the following to the Claude config ```json { "mcpServers": { "arcade": { "command": "bash", "args": ["-c", "export ARCADE_API_KEY=arc_xxxx && /path/to/python /path/to/arcade serve --mcp"] } } } ``` ### Customizing the Tool Server Developers can customize their served tools and server furthermore by importing the worker sdk ```python import arcade_google # pip install arcade_google import arcade_search # pip install arcade_search from arcade.core.catalog import ToolCatalog from arcade.worker.mcp.stdio import StdioServer # 2. Create and populate the tool catalog catalog = ToolCatalog() catalog.add_module(arcade_google) # Registers all tools in the package catalog.add_module(arcade_search) # 3. Main entrypoint async def main(): # Create the worker with the tool catalog worker = StdioServer(catalog) # Run the worker await worker.run() if __name__ == "__main__": import asyncio asyncio.run(main()) ``` Then to run with claude, just run this python file instead of the prebuilt server used in ``arcade serve --mcp``
This commit is contained in:
parent
f5774fd4ae
commit
9bc1cd4a12
29 changed files with 2829 additions and 187 deletions
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
|
|
@ -10,7 +10,6 @@
|
|||
"[python]": {
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.fixAll": "explicit",
|
||||
"source.organizeImports": "explicit"
|
||||
},
|
||||
"editor.defaultFormatter": "charliermarsh.ruff"
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from arcade.cli.constants import (
|
|||
PROD_CLOUD_HOST,
|
||||
PROD_ENGINE_HOST,
|
||||
)
|
||||
from arcade.cli.deployment import Deployment
|
||||
from arcade.cli.display import (
|
||||
display_arcade_chat_header,
|
||||
display_eval_results,
|
||||
|
|
@ -48,7 +49,6 @@ from arcade.cli.utils import (
|
|||
version_callback,
|
||||
)
|
||||
from arcade.cli.worker import parse_deployment_response
|
||||
from arcade.worker.config.deployment import Deployment
|
||||
|
||||
cli = typer.Typer(
|
||||
cls=OrderCommands,
|
||||
|
|
@ -61,7 +61,12 @@ cli = typer.Typer(
|
|||
)
|
||||
|
||||
|
||||
cli.add_typer(worker.app, name="worker", help="Manage workers")
|
||||
cli.add_typer(
|
||||
worker.app,
|
||||
name="worker",
|
||||
help="Manage deployments of tool servers (logs, list, etc)",
|
||||
rich_help_panel="Deployment",
|
||||
)
|
||||
console = Console()
|
||||
|
||||
|
||||
|
|
@ -192,7 +197,10 @@ def show(
|
|||
show_logic(toolkit, tool, host, local, port, force_tls, force_no_tls, debug)
|
||||
|
||||
|
||||
@cli.command(help="Start Arcade Chat in the terminal", rich_help_panel="Launch")
|
||||
@cli.command(
|
||||
help="Start a chat with a model in the terminal to test tools",
|
||||
rich_help_panel="Tool Development",
|
||||
)
|
||||
def chat(
|
||||
model: str = typer.Option("gpt-4o", "-m", "--model", help="The model to use for prediction."),
|
||||
stream: bool = typer.Option(
|
||||
|
|
@ -439,7 +447,7 @@ def evals(
|
|||
|
||||
|
||||
@cli.command(
|
||||
help="Start Arcade Worker serving tools installed in the current python environment",
|
||||
help="Start tool server worker with locally installed tools",
|
||||
rich_help_panel="Launch",
|
||||
)
|
||||
def serve(
|
||||
|
|
@ -460,19 +468,38 @@ def serve(
|
|||
otel_enable: bool = typer.Option(
|
||||
False, "--otel-enable", help="Send logs to OpenTelemetry", show_default=True
|
||||
),
|
||||
mcp: bool = typer.Option(
|
||||
False, "--mcp", help="Run as a local MCP server over stdio", show_default=True
|
||||
),
|
||||
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
|
||||
) -> None:
|
||||
"""
|
||||
Start a local Arcade Worker server.
|
||||
"""
|
||||
workerup(host, port, disable_auth, otel_enable, debug)
|
||||
from arcade.cli.serve import serve_default_worker
|
||||
|
||||
try:
|
||||
serve_default_worker(
|
||||
host,
|
||||
port,
|
||||
disable_auth=disable_auth,
|
||||
enable_otel=otel_enable,
|
||||
debug=debug,
|
||||
mcp=mcp,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
typer.Exit()
|
||||
except Exception as e:
|
||||
error_message = f"❌ Failed to start Arcade Worker: {escape(str(e))}"
|
||||
console.print(error_message, style="bold red")
|
||||
typer.Exit(code=1)
|
||||
|
||||
|
||||
@cli.command(help="Launch Arcade Worker and Engine locally", rich_help_panel="Launch")
|
||||
@cli.command(help="Launch Arcade - requires 'arcade-engine'", rich_help_panel="Launch")
|
||||
def dev(
|
||||
host: str = typer.Option("127.0.0.1", help="Host for the worker server.", show_default=True),
|
||||
host: str = typer.Option("127.0.0.1", help="Host for the toolkit server.", show_default=True),
|
||||
port: int = typer.Option(
|
||||
8002, "-p", "--port", help="Port for the worker server.", show_default=True
|
||||
8002, "-p", "--port", help="Port for the toolkit server.", show_default=True
|
||||
),
|
||||
engine_config: str = typer.Option(
|
||||
None, "-c", "--config", help="Path to the engine configuration file."
|
||||
|
|
@ -483,7 +510,7 @@ def dev(
|
|||
debug: bool = typer.Option(False, "-d", "--debug", help="Show debug information"),
|
||||
) -> None:
|
||||
"""
|
||||
Start both the worker and engine servers.
|
||||
Start both the toolkit server and engine servers.
|
||||
"""
|
||||
try:
|
||||
start_servers(host, port, engine_config, engine_env=env_file, debug=debug)
|
||||
|
|
@ -493,8 +520,9 @@ def dev(
|
|||
typer.Exit(code=1)
|
||||
|
||||
|
||||
# TODO: deprecate this next major version
|
||||
@cli.command(help="Start a local Arcade Worker server", rich_help_panel="Launch", hidden=True)
|
||||
@cli.command(
|
||||
help="Start a server with locally installed Arcade tools", rich_help_panel="Launch", hidden=True
|
||||
)
|
||||
def workerup(
|
||||
host: str = typer.Option(
|
||||
"127.0.0.1",
|
||||
|
|
@ -523,17 +551,21 @@ def workerup(
|
|||
|
||||
try:
|
||||
serve_default_worker(
|
||||
host, port, disable_auth=disable_auth, enable_otel=otel_enable, debug=debug
|
||||
host,
|
||||
port,
|
||||
disable_auth=disable_auth,
|
||||
enable_otel=otel_enable,
|
||||
debug=debug,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
typer.Exit()
|
||||
except Exception as e:
|
||||
error_message = f"❌ Failed to start Arcade Worker: {escape(str(e))}"
|
||||
error_message = f"❌ Failed to start Arcade Toolkit Server: {escape(str(e))}"
|
||||
console.print(error_message, style="bold red")
|
||||
typer.Exit(code=1)
|
||||
|
||||
|
||||
@cli.command(help="Deploy worker to Arcade Cloud", rich_help_panel="Deployment")
|
||||
@cli.command(help="Deploy toolkits to Arcade Cloud", rich_help_panel="Deployment")
|
||||
def deploy(
|
||||
deployment_file: str = typer.Option(
|
||||
"worker.toml", "--deployment-file", "-d", help="The deployment file to deploy."
|
||||
|
|
|
|||
|
|
@ -1,70 +1,209 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from importlib.metadata import version as get_pkg_version
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import fastapi
|
||||
import uvicorn
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
|
||||
from arcade.cli.constants import ARCADE_CONFIG_PATH
|
||||
from arcade.cli.utils import (
|
||||
build_tool_catalog,
|
||||
discover_toolkits,
|
||||
load_dotenv,
|
||||
validate_and_get_config,
|
||||
)
|
||||
from arcade.core.telemetry import OTELHandler
|
||||
from arcade.sdk import Toolkit
|
||||
from arcade.worker.fastapi.worker import FastAPIWorker
|
||||
|
||||
console = Console(width=70, color_system="auto")
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
|
||||
def _run_mcp_stdio(
|
||||
toolkits: list[Toolkit], *, logging_enabled: bool, env_file: str | None = None
|
||||
) -> None:
|
||||
"""Launch an MCP stdio server; blocks until it exits."""
|
||||
|
||||
from arcade.worker.mcp.stdio import StdioServer
|
||||
|
||||
# Load env vars before launching server (explicit path, config path, cwd)
|
||||
if env_file:
|
||||
load_dotenv(env_file, override=False)
|
||||
else:
|
||||
for candidate in [Path(ARCADE_CONFIG_PATH) / "arcade.env", Path.cwd() / "arcade.env"]:
|
||||
if candidate.is_file():
|
||||
load_dotenv(candidate, override=False)
|
||||
break
|
||||
|
||||
# Set up middleware configuration for stdio mode
|
||||
middleware_config = {
|
||||
"stdio_mode": True, # Ensure logs go to stderr
|
||||
}
|
||||
|
||||
catalog = build_tool_catalog(toolkits)
|
||||
server = StdioServer(
|
||||
catalog,
|
||||
enable_logging=logging_enabled,
|
||||
middleware_config=middleware_config,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(server.run())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("MCP server stopped by user.")
|
||||
except Exception as exc:
|
||||
logger.exception("Error while running MCP server: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def _run_fastapi_server(
|
||||
app: fastapi.FastAPI,
|
||||
*,
|
||||
host: str,
|
||||
port: int,
|
||||
workers: int,
|
||||
timeout_keep_alive: int,
|
||||
enable_otel: bool,
|
||||
otel_handler: OTELHandler,
|
||||
**uvicorn_kwargs: Any,
|
||||
) -> None:
|
||||
"""Run a FastAPI application via Uvicorn with graceful shutdown."""
|
||||
|
||||
class CustomUvicornServer(uvicorn.Server):
|
||||
def install_signal_handlers(self) -> None:
|
||||
# Disable Uvicorn's default signal handling; we manage it manually
|
||||
pass
|
||||
|
||||
async def shutdown(self, sockets: Any = None) -> None:
|
||||
logger.info("Initiating graceful shutdown...")
|
||||
await super().shutdown(sockets=sockets)
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=host,
|
||||
port=port,
|
||||
workers=workers,
|
||||
timeout_keep_alive=timeout_keep_alive,
|
||||
log_config=None,
|
||||
**uvicorn_kwargs,
|
||||
)
|
||||
|
||||
server = CustomUvicornServer(config=config)
|
||||
|
||||
async def _serve() -> None:
|
||||
await server.serve()
|
||||
|
||||
async def _graceful_shutdown() -> None:
|
||||
try:
|
||||
logger.info("Shutting down server ...")
|
||||
await server.shutdown()
|
||||
|
||||
# brief pause for connections to close gracefully
|
||||
await asyncio.sleep(0.5)
|
||||
finally:
|
||||
if enable_otel:
|
||||
otel_handler.shutdown()
|
||||
logger.debug("Server shutdown complete.")
|
||||
|
||||
# Map signals to our graceful shutdown
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig_name in (
|
||||
"SIGINT",
|
||||
"SIGTERM",
|
||||
"SIGHUP",
|
||||
"SIGUSR1",
|
||||
"SIGUSR2",
|
||||
"SIGWINCH",
|
||||
"SIGBREAK",
|
||||
):
|
||||
if hasattr(signal, sig_name):
|
||||
loop.add_signal_handler(
|
||||
getattr(signal, sig_name), lambda: asyncio.create_task(_graceful_shutdown())
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(_serve())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server stopped by user.")
|
||||
finally:
|
||||
if enable_otel:
|
||||
otel_handler.shutdown()
|
||||
|
||||
|
||||
class RichInterceptHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# Get corresponding Loguru level if it exists
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno # type: ignore[assignment]
|
||||
level = str(record.levelno)
|
||||
|
||||
# Find caller from where originated the logged message
|
||||
frame, depth = sys._getframe(6), 6
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back # type: ignore[assignment]
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
# Let Loguru handle caller info; don't do stack inspection here
|
||||
logger.opt(exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
|
||||
def setup_logging(log_level: int = logging.INFO) -> None:
|
||||
def setup_logging(log_level: int = logging.INFO, mcp_mode: bool = False) -> None:
|
||||
# Intercept everything at the root logger
|
||||
logging.root.handlers = [InterceptHandler()]
|
||||
logging.root.handlers = [RichInterceptHandler()]
|
||||
logging.root.setLevel(log_level)
|
||||
|
||||
# Remove every other logger's handlers
|
||||
# and propagate to root logger
|
||||
# Remove every other logger's handlers and propagate to root logger
|
||||
for name in logging.root.manager.loggerDict:
|
||||
# Keep handlers for MCP logger if middleware handles it separately
|
||||
if mcp_mode and name == "arcade.mcp":
|
||||
continue
|
||||
logging.getLogger(name).handlers = []
|
||||
logging.getLogger(name).propagate = True
|
||||
|
||||
# Configure loguru with custom format, no colors
|
||||
# Remove default handlers from Loguru
|
||||
logger.remove()
|
||||
|
||||
# Configure main Loguru sink
|
||||
# In MCP mode, all general console logs go to stderr to keep stdout clean
|
||||
sink_destination = sys.stderr if mcp_mode else sys.stdout
|
||||
|
||||
# Configure loguru with a cleaner format and colors
|
||||
if log_level == logging.DEBUG:
|
||||
format_string = "<level>{level}</level> | <green>{time:HH:mm:ss}</green> | <cyan>{name}:{file}:{line: <4}</cyan> | <level>{message}</level>"
|
||||
else:
|
||||
format_string = (
|
||||
"<level>{level}</level> | <green>{time:HH:mm:ss}</green> | <level>{message}</level>"
|
||||
)
|
||||
logger.configure(
|
||||
handlers=[
|
||||
{
|
||||
"sink": sys.stdout,
|
||||
"serialize": False,
|
||||
"sink": sink_destination, # Redirect sink based on mcp_mode
|
||||
"colorize": True,
|
||||
"level": log_level,
|
||||
"format": "{level} [{time:HH:mm:ss.SSS}] {message}"
|
||||
+ (" {name}:{function}:{line}" if log_level <= logging.DEBUG else "")
|
||||
+ ("\n{exception}" if "{exception}" in "{message}" else ""),
|
||||
# Format that ensures timestamp on every line and better alignment
|
||||
"format": format_string,
|
||||
# Make sure multiline messages are handled properly
|
||||
"enqueue": True,
|
||||
"diagnose": True, # Disable traceback framing which adds noise
|
||||
}
|
||||
]
|
||||
)
|
||||
if mcp_mode:
|
||||
logger.debug("Loguru sink configured for stderr in MCP mode.")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: fastapi.FastAPI): # type: ignore[no-untyped-def]
|
||||
async def lifespan(app: fastapi.FastAPI) -> AsyncGenerator[None, None]:
|
||||
try:
|
||||
yield
|
||||
except asyncio.CancelledError:
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
# This is necessary to prevent an unhandled error
|
||||
# when the user presses Ctrl+C
|
||||
logger.debug("Lifespan cancelled.")
|
||||
raise
|
||||
|
||||
|
||||
def serve_default_worker(
|
||||
|
|
@ -75,76 +214,78 @@ def serve_default_worker(
|
|||
timeout_keep_alive: int = 5,
|
||||
enable_otel: bool = False,
|
||||
debug: bool = False,
|
||||
mcp: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Get an instance of a FastAPI server with the Arcade Worker.
|
||||
Get a default instance of a FastAPI server with the Arcade Worker
|
||||
serving tools installed in the current Python environment.
|
||||
|
||||
Args:
|
||||
host: The host to run the server on.
|
||||
port: The port to run the server on.
|
||||
disable_auth: Whether to disable authentication.
|
||||
workers: The number of workers to run.
|
||||
timeout_keep_alive: The timeout for keep-alive connections.
|
||||
enable_otel: Whether to enable OpenTelemetry.
|
||||
debug: Whether to enable debug logging.
|
||||
mcp: Whether to run worker as MCP server over stdio.
|
||||
"""
|
||||
# Setup unified logging
|
||||
setup_logging(log_level=logging.DEBUG if debug else logging.INFO)
|
||||
|
||||
toolkits = Toolkit.find_all_arcade_toolkits()
|
||||
if not toolkits:
|
||||
raise RuntimeError("No toolkits found in Python environment.")
|
||||
# Setup unified logging first
|
||||
version = get_pkg_version("arcade-ai")
|
||||
validate_and_get_config()
|
||||
setup_logging(log_level=logging.DEBUG if debug else logging.INFO, mcp_mode=mcp)
|
||||
|
||||
worker_secret = os.environ.get("ARCADE_WORKER_SECRET")
|
||||
if not disable_auth and not worker_secret:
|
||||
logger.warning(
|
||||
"Warning: ARCADE_WORKER_SECRET environment variable is not set. Using 'dev' as the worker secret.",
|
||||
)
|
||||
worker_secret = worker_secret or "dev"
|
||||
toolkits = discover_toolkits()
|
||||
logger.info("Serving the following toolkits:")
|
||||
for toolkit in toolkits:
|
||||
if debug:
|
||||
for name, tools in toolkit.tools.items():
|
||||
for tool in tools:
|
||||
logger.info(f" - {name}: {tool}")
|
||||
else:
|
||||
logger.info(f" - {toolkit.name}: {len(toolkit.tools)} tools")
|
||||
|
||||
# --- MCP stdio --------------------------------------------------
|
||||
if mcp:
|
||||
env_file = kwargs.pop("env_file", None)
|
||||
_run_mcp_stdio(toolkits, logging_enabled=not debug, env_file=env_file)
|
||||
return
|
||||
|
||||
# --- FastAPI HTTP --------------------------------------------------
|
||||
app = fastapi.FastAPI(
|
||||
title="Arcade Worker",
|
||||
description="Arcade default Worker implementation using FastAPI.",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan, # Use custom lifespan to catch errors, notably KeyboardInterrupt (Ctrl+C)
|
||||
description="A worker for the Arcade platform",
|
||||
version=version,
|
||||
docs_url="/docs" if debug else None,
|
||||
redoc_url="/redoc" if debug else None,
|
||||
openapi_url="/openapi.json" if debug else None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
otel_handler = OTELHandler(app, enable=enable_otel)
|
||||
secret = os.getenv("ARCADE_WORKER_SECRET", None)
|
||||
if secret is None:
|
||||
logger.warning("No secret found for Arcade Worker")
|
||||
logger.info("Setting secret to 'dev'. Set this in production")
|
||||
secret = "dev" # noqa: S105
|
||||
|
||||
worker = FastAPIWorker(
|
||||
app, secret=worker_secret, disable_auth=disable_auth, otel_meter=otel_handler.get_meter()
|
||||
otel_handler = OTELHandler(
|
||||
app, enable=enable_otel, log_level=logging.DEBUG if debug else logging.INFO
|
||||
)
|
||||
|
||||
toolkit_tool_counts = {}
|
||||
for toolkit in toolkits:
|
||||
prev_tool_count = worker.catalog.get_tool_count()
|
||||
worker.register_toolkit(toolkit)
|
||||
new_tool_count = worker.catalog.get_tool_count()
|
||||
toolkit_tool_counts[f"{toolkit.name} ({toolkit.package_name})"] = (
|
||||
new_tool_count - prev_tool_count
|
||||
)
|
||||
|
||||
logger.info("Serving the following toolkits:")
|
||||
for name, tool_count in toolkit_tool_counts.items():
|
||||
logger.info(f" - {name}: {tool_count} tools")
|
||||
|
||||
logger.info("Starting FastAPI server...")
|
||||
|
||||
class CustomUvicornServer(uvicorn.Server):
|
||||
def install_signal_handlers(self) -> None:
|
||||
pass # Disable Uvicorn's default signal handlers
|
||||
|
||||
config = uvicorn.Config(
|
||||
_ = FastAPIWorker(
|
||||
app=app,
|
||||
secret=secret,
|
||||
disable_auth=disable_auth,
|
||||
otel_meter=otel_handler.get_meter(),
|
||||
)
|
||||
_run_fastapi_server(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
workers=workers,
|
||||
timeout_keep_alive=timeout_keep_alive,
|
||||
log_config=None,
|
||||
enable_otel=enable_otel,
|
||||
otel_handler=otel_handler,
|
||||
**kwargs,
|
||||
)
|
||||
server = CustomUvicornServer(config=config)
|
||||
|
||||
async def serve() -> None:
|
||||
await server.serve()
|
||||
|
||||
try:
|
||||
asyncio.run(serve())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server stopped by user.")
|
||||
finally:
|
||||
if enable_otel:
|
||||
otel_handler.shutdown()
|
||||
logger.debug("Server shutdown complete.")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import importlib.util
|
||||
import ipaddress
|
||||
import os
|
||||
import shlex
|
||||
import webbrowser
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
|
@ -34,6 +36,11 @@ from arcade.sdk import ToolCatalog, Toolkit
|
|||
console = Console()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Shared helpers for the CLI
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class OrderCommands(TyperGroup):
|
||||
def list_commands(self, ctx: Context) -> list[str]: # type: ignore[override]
|
||||
"""Return list of commands in the order appear."""
|
||||
|
|
@ -666,3 +673,81 @@ def get_today_context() -> str:
|
|||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
day_of_week = datetime.now().strftime("%A")
|
||||
return f"Today is {today}, {day_of_week}."
|
||||
|
||||
|
||||
def discover_toolkits() -> list[Toolkit]:
|
||||
"""Return all Arcade toolkits installed in the active Python environment.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no toolkits are found, mirroring the behaviour of Toolkit discovery elsewhere.
|
||||
"""
|
||||
toolkits = Toolkit.find_all_arcade_toolkits()
|
||||
if not toolkits:
|
||||
raise RuntimeError("No toolkits found in Python environment.")
|
||||
return toolkits
|
||||
|
||||
|
||||
def build_tool_catalog(toolkits: list[Toolkit]) -> ToolCatalog:
|
||||
"""Construct a ``ToolCatalog`` populated with *toolkits*.
|
||||
|
||||
|
||||
Args:
|
||||
toolkits: Toolkits to register in the catalog.
|
||||
|
||||
Returns:
|
||||
ToolCatalog
|
||||
"""
|
||||
catalog = ToolCatalog()
|
||||
for tk in toolkits:
|
||||
catalog.add_toolkit(tk)
|
||||
return catalog
|
||||
|
||||
|
||||
def _parse_line(line: str) -> tuple[str, str] | None:
|
||||
"""
|
||||
Return (key, value) if the line looks like KEY=VALUE, else None.
|
||||
Handles quotes and escaped chars via shlex.
|
||||
"""
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
return None
|
||||
key, raw_val = line.split("=", 1)
|
||||
key = key.strip()
|
||||
raw_val = raw_val.strip()
|
||||
|
||||
# Use shlex to handle "quoted strings with # hash" etc.
|
||||
try:
|
||||
value = shlex.split(raw_val)[0] if raw_val else ""
|
||||
except ValueError:
|
||||
# Fallback: naked value without shlex parsing
|
||||
value = raw_val
|
||||
|
||||
return key, value
|
||||
|
||||
|
||||
def load_dotenv(path: str | Path, *, override: bool = False) -> dict[str, str]:
|
||||
"""
|
||||
Load variables from *path* into os.environ.
|
||||
|
||||
Args:
|
||||
path: .env file path
|
||||
override: replace existing env vars if True
|
||||
|
||||
Returns:
|
||||
The mapping of vars that were added/updated.
|
||||
"""
|
||||
path = Path(path).expanduser()
|
||||
if not path.is_file():
|
||||
return {}
|
||||
|
||||
loaded: dict[str, str] = {}
|
||||
|
||||
for raw in path.read_text().splitlines():
|
||||
parsed = _parse_line(raw.strip())
|
||||
if parsed is None:
|
||||
continue
|
||||
k, v = parsed
|
||||
if override or k not in os.environ:
|
||||
os.environ[k] = v
|
||||
loaded[k] = v
|
||||
|
||||
return loaded
|
||||
|
|
|
|||
|
|
@ -326,8 +326,16 @@ class ToolContext(BaseModel):
|
|||
for item in items:
|
||||
if item.key.lower() == normalized_key:
|
||||
return item.value
|
||||
|
||||
raise ValueError(f"{item_name.capitalize()} {key} not found in context.")
|
||||
|
||||
def set_secret(self, key: str, value: str) -> None:
|
||||
"""Set a secret for the tool invocation."""
|
||||
if not self.secrets:
|
||||
self.secrets = []
|
||||
secret = ToolSecretItem(key=str(key), value=str(value))
|
||||
self.secrets.append(secret)
|
||||
|
||||
|
||||
class ToolCallRequest(BaseModel):
|
||||
"""The request to call (invoke) a tool."""
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from types import UnionType
|
||||
from typing import Any, Callable, Literal, Optional, TypeVar, Union, get_args, get_origin
|
||||
from typing import Any, Callable, Literal, TypeVar, Union, get_args, get_origin
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def first_or_none(_type: type[T], iterable: Iterable[Any]) -> Optional[T]:
|
||||
def first_or_none(_type: type[T], iterable: Iterable[Any]) -> T | None:
|
||||
"""
|
||||
Returns the first item in the iterable that is an instance of the given type, or None if no such item is found.
|
||||
"""
|
||||
|
|
@ -65,7 +67,7 @@ def does_function_return_value(func: Callable) -> bool:
|
|||
Returns True if the given function returns a value, i.e. if it has a return statement with a value.
|
||||
"""
|
||||
try:
|
||||
source: Optional[str] = inspect.getsource(func)
|
||||
source: str | None = inspect.getsource(func)
|
||||
except OSError:
|
||||
# Workaround for parameterized unit tests that use a dynamically-generated function
|
||||
source = getattr(func, "__source__", None)
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ class BaseWorker(Worker):
|
|||
"""
|
||||
Provide a health check that serves as a heartbeat of worker health.
|
||||
"""
|
||||
return {"status": "ok", "tool_count": len(self.catalog)}
|
||||
return {"status": "ok", "tool_count": str(len(self.catalog))}
|
||||
|
||||
def register_routes(self, router: Router) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -5,6 +5,11 @@ from pydantic import BaseModel
|
|||
|
||||
from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition
|
||||
|
||||
CatalogResponse = list[ToolDefinition]
|
||||
HealthCheckResponse = dict[str, str]
|
||||
JSONResponse = dict[str, Any]
|
||||
ResponseData = CatalogResponse | ToolCallResponse | HealthCheckResponse
|
||||
|
||||
|
||||
class RequestData(BaseModel):
|
||||
"""
|
||||
|
|
@ -17,7 +22,7 @@ class RequestData(BaseModel):
|
|||
"""The path of the request."""
|
||||
method: str
|
||||
"""The method of the request."""
|
||||
body_json: dict | None = None
|
||||
body_json: JSONResponse | None = None
|
||||
"""The deserialized body of the request (e.g. JSON)"""
|
||||
|
||||
|
||||
|
|
@ -28,7 +33,13 @@ class Router(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def add_route(
|
||||
self, endpoint_path: str, handler: Callable, method: str, require_auth: bool = True
|
||||
self,
|
||||
endpoint_path: str,
|
||||
handler: Callable,
|
||||
method: str,
|
||||
require_auth: bool = True,
|
||||
response_type: type[ResponseData] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Add a route to the router.
|
||||
|
|
@ -43,7 +54,7 @@ class Worker(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_catalog(self) -> list[ToolDefinition]:
|
||||
def get_catalog(self) -> CatalogResponse:
|
||||
"""
|
||||
Get the catalog of tools available in the worker.
|
||||
"""
|
||||
|
|
@ -57,7 +68,7 @@ class Worker(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def health_check(self) -> dict[str, Any]:
|
||||
def health_check(self) -> HealthCheckResponse:
|
||||
"""
|
||||
Perform a health check of the worker
|
||||
"""
|
||||
|
|
@ -76,7 +87,7 @@ class WorkerComponent(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, request: RequestData) -> Any:
|
||||
async def __call__(self, request: RequestData) -> ResponseData:
|
||||
"""
|
||||
Handle the request.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,15 @@
|
|||
from typing import Any
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition
|
||||
from arcade.worker.core.common import RequestData, Router, Worker, WorkerComponent
|
||||
from arcade.worker.core.common import (
|
||||
CatalogResponse,
|
||||
HealthCheckResponse,
|
||||
RequestData,
|
||||
Router,
|
||||
ToolCallRequest,
|
||||
ToolCallResponse,
|
||||
Worker,
|
||||
WorkerComponent,
|
||||
)
|
||||
|
||||
|
||||
class CatalogComponent(WorkerComponent):
|
||||
|
|
@ -14,9 +20,18 @@ class CatalogComponent(WorkerComponent):
|
|||
"""
|
||||
Register the catalog route with the router.
|
||||
"""
|
||||
router.add_route("tools", self, method="GET")
|
||||
router.add_route(
|
||||
"tools",
|
||||
self,
|
||||
method="GET",
|
||||
response_type=CatalogResponse,
|
||||
operation_id="get_catalog",
|
||||
description="Get the catalog of tools",
|
||||
summary="Get the catalog of tools",
|
||||
tags=["Arcade"],
|
||||
)
|
||||
|
||||
async def __call__(self, request: RequestData) -> list[ToolDefinition]:
|
||||
async def __call__(self, request: RequestData) -> CatalogResponse:
|
||||
"""
|
||||
Handle the request to get the catalog.
|
||||
"""
|
||||
|
|
@ -33,7 +48,16 @@ class CallToolComponent(WorkerComponent):
|
|||
"""
|
||||
Register the call tool route with the router.
|
||||
"""
|
||||
router.add_route("tools/invoke", self, method="POST")
|
||||
router.add_route(
|
||||
"tools/invoke",
|
||||
self,
|
||||
method="POST",
|
||||
response_type=ToolCallResponse,
|
||||
operation_id="call_tool",
|
||||
description="Call a tool",
|
||||
summary="Call a tool",
|
||||
tags=["Arcade"],
|
||||
)
|
||||
|
||||
async def __call__(self, request: RequestData) -> ToolCallResponse:
|
||||
"""
|
||||
|
|
@ -54,9 +78,19 @@ class HealthCheckComponent(WorkerComponent):
|
|||
"""
|
||||
Register the health check route with the router.
|
||||
"""
|
||||
router.add_route("health", self, method="GET", require_auth=False)
|
||||
router.add_route(
|
||||
"health",
|
||||
self,
|
||||
method="GET",
|
||||
require_auth=False,
|
||||
response_type=HealthCheckResponse,
|
||||
operation_id="health_check",
|
||||
description="Check the health of the worker",
|
||||
summary="Check the health of the worker",
|
||||
tags=["Arcade"],
|
||||
)
|
||||
|
||||
async def __call__(self, request: RequestData) -> dict[str, Any]:
|
||||
async def __call__(self, request: RequestData) -> HealthCheckResponse:
|
||||
"""
|
||||
Handle the request for a health check.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .worker import FastAPIWorker
|
||||
|
||||
__all__ = ["FastAPIWorker"]
|
||||
|
|
@ -9,7 +9,7 @@ from arcade.worker.core.base import (
|
|||
BaseWorker,
|
||||
Router,
|
||||
)
|
||||
from arcade.worker.core.common import RequestData
|
||||
from arcade.worker.core.common import RequestData, ResponseData, WorkerComponent
|
||||
from arcade.worker.fastapi.auth import validate_engine_request
|
||||
from arcade.worker.utils import is_async_callable
|
||||
|
||||
|
|
@ -30,12 +30,21 @@ class FastAPIWorker(BaseWorker):
|
|||
"""
|
||||
Initialize the FastAPIWorker with a FastAPI app instance.
|
||||
If no secret is provided, the worker will use the ARCADE_WORKER_SECRET environment variable.
|
||||
|
||||
Args:
|
||||
app: The FastAPI app to host the worker in
|
||||
secret: Optional secret for authorization
|
||||
disable_auth: Whether to disable authorization
|
||||
otel_meter: Optional OpenTelemetry meter
|
||||
"""
|
||||
super().__init__(secret, disable_auth, otel_meter)
|
||||
self.app = app
|
||||
self.router = FastAPIRouter(app, self)
|
||||
self.register_routes(self.router)
|
||||
|
||||
# Initialize components
|
||||
self.components: list[WorkerComponent] = []
|
||||
|
||||
|
||||
security = HTTPBearer() # Authorization: Bearer <xxx>
|
||||
|
||||
|
|
@ -81,7 +90,13 @@ class FastAPIRouter(Router):
|
|||
return wrapped_handler
|
||||
|
||||
def add_route(
|
||||
self, endpoint_path: str, handler: Callable, method: str, require_auth: bool = True
|
||||
self,
|
||||
endpoint_path: str,
|
||||
handler: Callable,
|
||||
method: str,
|
||||
require_auth: bool = True,
|
||||
response_type: type[ResponseData] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Add a route to the FastAPI application.
|
||||
|
|
@ -90,4 +105,7 @@ class FastAPIRouter(Router):
|
|||
f"{self.worker.base_path}/{endpoint_path}",
|
||||
self._wrap_handler(handler, require_auth),
|
||||
methods=[method],
|
||||
response_model=response_type,
|
||||
# **kwargs to pass to FastAPI
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
7
arcade/arcade/worker/mcp/__init__.py
Normal file
7
arcade/arcade/worker/mcp/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""
|
||||
MCP (Model Context Protocol) support for Arcade workers.
|
||||
"""
|
||||
|
||||
from arcade.worker.mcp.stdio import StdioServer
|
||||
|
||||
__all__ = ["StdioServer"]
|
||||
188
arcade/arcade/worker/mcp/convert.py
Normal file
188
arcade/arcade/worker/mcp/convert.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from arcade.core.catalog import MaterializedTool
|
||||
|
||||
# Type aliases for MCP types
|
||||
MCPTool = dict[str, Any]
|
||||
MCPTextContent = dict[str, Any]
|
||||
MCPImageContent = dict[str, Any]
|
||||
MCPEmbeddedResource = dict[str, Any]
|
||||
MCPContent = MCPTextContent | MCPImageContent | MCPEmbeddedResource
|
||||
|
||||
logger = logging.getLogger("arcade.mcp")
|
||||
|
||||
|
||||
def create_mcp_tool(tool: MaterializedTool) -> dict[str, Any] | None: # noqa: C901
|
||||
"""
|
||||
Create an MCP-compatible tool definition from an Arcade tool.
|
||||
|
||||
Args:
|
||||
tool: An Arcade tool object
|
||||
|
||||
Returns:
|
||||
An MCP tool definition or None if the tool cannot be converted
|
||||
"""
|
||||
try:
|
||||
name = getattr(tool.definition, "fully_qualified_name", None) or getattr(
|
||||
tool.definition, "name", "unknown"
|
||||
)
|
||||
description = getattr(tool.definition, "description", "No description available")
|
||||
|
||||
# Extract parameters from the input model
|
||||
parameters = {}
|
||||
required = []
|
||||
|
||||
if (
|
||||
hasattr(tool, "input_model")
|
||||
and tool.input_model is not None
|
||||
and hasattr(tool.input_model, "model_fields")
|
||||
):
|
||||
for field_name, field in tool.input_model.model_fields.items():
|
||||
# Skip internal tool context parameters
|
||||
if field_name == getattr(
|
||||
tool.definition.input, "tool_context_parameter_name", None
|
||||
):
|
||||
continue
|
||||
|
||||
# Get field type information
|
||||
field_type = getattr(field, "annotation", None)
|
||||
field_type_name = "string" # default
|
||||
|
||||
# Safety check for field_type
|
||||
if field_type is int:
|
||||
field_type_name = "integer"
|
||||
elif field_type is float:
|
||||
field_type_name = "number"
|
||||
elif field_type is bool:
|
||||
field_type_name = "boolean"
|
||||
elif field_type is list or str(field_type).startswith("list["):
|
||||
field_type_name = "array"
|
||||
elif field_type is dict or str(field_type).startswith("dict["):
|
||||
field_type_name = "object"
|
||||
|
||||
# Get description with fallback
|
||||
field_description = getattr(field, "description", None)
|
||||
if not field_description:
|
||||
field_description = f"Parameter: {field_name}"
|
||||
|
||||
# Create parameter definition
|
||||
param_def = {
|
||||
"type": field_type_name,
|
||||
"description": field_description,
|
||||
}
|
||||
|
||||
# Enum support: if the field annotation is an Enum, add allowed values
|
||||
enum_type = None
|
||||
if hasattr(field, "annotation"):
|
||||
ann = field.annotation
|
||||
# Handle typing.Annotated[Enum, ...]
|
||||
if getattr(ann, "__origin__", None) is not None and hasattr(ann, "__args__"):
|
||||
for arg in ann.__args__: # type: ignore[union-attr]
|
||||
if isinstance(arg, type) and issubclass(arg, Enum):
|
||||
enum_type = arg
|
||||
break
|
||||
elif isinstance(ann, type) and issubclass(ann, Enum):
|
||||
enum_type = ann
|
||||
if enum_type is not None:
|
||||
param_def["enum"] = [e.value for e in enum_type]
|
||||
|
||||
parameters[field_name] = param_def
|
||||
|
||||
# In Pydantic v2, check if field is required based on default value
|
||||
try:
|
||||
if field.is_required():
|
||||
required.append(field_name)
|
||||
except (AttributeError, TypeError):
|
||||
# Fallback if is_required() doesn't exist or fails
|
||||
try:
|
||||
has_default = getattr(field, "default", None) is not None
|
||||
has_factory = getattr(field, "default_factory", None) is not None
|
||||
if not (has_default or has_factory):
|
||||
required.append(field_name)
|
||||
except Exception:
|
||||
# Ultimate fallback - assume required if we can't determine
|
||||
logger.debug(
|
||||
f"Could not determine if field {field_name} is required, assuming optional"
|
||||
)
|
||||
|
||||
# Create the input schema with explicit properties and required fields
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
}
|
||||
|
||||
# Only include required field if we have required parameters
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
|
||||
# Add annotations based on tool metadata
|
||||
annotations = {}
|
||||
|
||||
# Use tool name as title if available
|
||||
annotations["title"] = getattr(tool.definition, "title", str(name).replace(".", "_"))
|
||||
|
||||
# Determine hints based on tool properties
|
||||
if hasattr(tool.definition, "metadata"):
|
||||
metadata = tool.definition.metadata or {}
|
||||
annotations["readOnlyHint"] = metadata.get("read_only", False)
|
||||
annotations["destructiveHint"] = metadata.get("destructive", False)
|
||||
annotations["idempotentHint"] = metadata.get("idempotent", True)
|
||||
annotations["openWorldHint"] = metadata.get("open_world", False)
|
||||
|
||||
# Create the final tool definition
|
||||
tool_def: MCPTool = {
|
||||
"name": str(name).replace(".", "_"),
|
||||
"description": str(description),
|
||||
"inputSchema": input_schema,
|
||||
"annotations": annotations,
|
||||
}
|
||||
|
||||
logger.debug(f"Created tool definition for {name}")
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Error creating MCP tool definition for {getattr(tool, 'name', str(tool))}"
|
||||
)
|
||||
return None
|
||||
return tool_def
|
||||
|
||||
|
||||
def convert_to_mcp_content(value: Any) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert a Python value to MCP-compatible content.
|
||||
"""
|
||||
if value is None:
|
||||
return []
|
||||
|
||||
if isinstance(value, (str, bool, int, float)):
|
||||
return [{"type": "text", "text": str(value)}]
|
||||
|
||||
if isinstance(value, (dict, list)):
|
||||
return [{"type": "text", "text": json.dumps(value)}]
|
||||
|
||||
# Default fallback
|
||||
return [{"type": "text", "text": str(value)}]
|
||||
|
||||
|
||||
def _map_type_to_json_schema_type(val_type: str) -> str:
|
||||
"""
|
||||
Map Arcade value types to JSON schema types.
|
||||
|
||||
Args:
|
||||
val_type: The Arcade value type as a string.
|
||||
|
||||
Returns:
|
||||
The corresponding JSON schema type as a string.
|
||||
"""
|
||||
mapping: dict[str, str] = {
|
||||
"string": "string",
|
||||
"integer": "integer",
|
||||
"number": "number",
|
||||
"boolean": "boolean",
|
||||
"json": "object",
|
||||
"array": "array",
|
||||
}
|
||||
return mapping.get(val_type, "string")
|
||||
215
arcade/arcade/worker/mcp/logging.py
Normal file
215
arcade/arcade/worker/mcp/logging.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from arcade.worker.mcp.types import (
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
MCPMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("arcade.mcp")
|
||||
|
||||
|
||||
class MCPLoggingMiddleware:
|
||||
"""
|
||||
Middleware for logging MCP requests and responses.
|
||||
Logs request and response details, including timing and errors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_level: str = "INFO",
|
||||
log_request_body: bool = False,
|
||||
log_response_body: bool = False,
|
||||
log_errors: bool = True,
|
||||
min_duration_to_log_ms: int = 0,
|
||||
stdio_mode: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MCP logging middleware.
|
||||
|
||||
Args:
|
||||
log_level: Logging level (default: "INFO").
|
||||
log_request_body: Whether to log full request bodies (default: False).
|
||||
log_response_body: Whether to log full response bodies (default: False).
|
||||
log_errors: Whether to log errors at ERROR level (default: True).
|
||||
min_duration_to_log_ms: Minimum duration in ms to log (0 logs all).
|
||||
stdio_mode: Whether running in stdio mode (redirects logs to stderr).
|
||||
"""
|
||||
self.log_level = getattr(logging, log_level.upper())
|
||||
self.log_request_body = log_request_body
|
||||
self.log_response_body = log_response_body
|
||||
self.log_errors = log_errors
|
||||
self.min_duration_to_log_ms = min_duration_to_log_ms
|
||||
self.request_log_format = "[MCP>] {method}{params_str} (id: {id})"
|
||||
self.response_log_format = "[MCP<] {method} completed in {duration:.2f}ms (id: {id})"
|
||||
self.error_log_format = "[MCP!] {method} error: {error} (id: {id})"
|
||||
|
||||
# If in stdio mode, ensure MCP logs go to stderr
|
||||
if stdio_mode:
|
||||
self._redirect_logs_to_stderr()
|
||||
|
||||
# Log that middleware is initialized
|
||||
logger.debug(f"MCP logging middleware initialized (level: {log_level})")
|
||||
|
||||
def _redirect_logs_to_stderr(self) -> None:
|
||||
"""Redirect MCP logs to stderr to avoid interfering with stdio communication."""
|
||||
# Remove any existing handlers
|
||||
for handler in logger.handlers[:]:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
# Add a stderr handler
|
||||
stderr_handler = logging.StreamHandler(sys.stderr)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
stderr_handler.setFormatter(formatter)
|
||||
logger.addHandler(stderr_handler)
|
||||
|
||||
# Ensure we're not propagating to root logger which might log to stdout
|
||||
logger.propagate = False
|
||||
|
||||
logger.debug("MCP logs redirected to stderr for stdio mode")
|
||||
|
||||
def __call__(self, message: MCPMessage, direction: str) -> MCPMessage:
|
||||
"""
|
||||
Process and log an MCP message.
|
||||
|
||||
Args:
|
||||
message: The MCP message to process.
|
||||
direction: The message direction ("request" or "response").
|
||||
|
||||
Returns:
|
||||
The original message (unmodified).
|
||||
"""
|
||||
if direction == "request":
|
||||
self._log_request(message)
|
||||
else:
|
||||
self._log_response(message)
|
||||
return message
|
||||
|
||||
def _log_request(self, message: MCPMessage) -> None:
|
||||
"""
|
||||
Log an MCP request message.
|
||||
"""
|
||||
if not isinstance(message, JSONRPCRequest):
|
||||
logger.debug(f"Ignoring non-request message: {type(message).__name__}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Store request start time for duration calculation
|
||||
message._mcp_start_time = time.time() # type: ignore[attr-defined]
|
||||
|
||||
# Format parameters for logging
|
||||
params_str = ""
|
||||
if self.log_request_body and hasattr(message, "params") and message.params is not None:
|
||||
params_str = f": {self._format_params(message.params)}"
|
||||
|
||||
log_msg = self.request_log_format.format(
|
||||
method=message.method, params_str=params_str, id=getattr(message, "id", "none")
|
||||
)
|
||||
|
||||
logger.log(self.log_level, log_msg)
|
||||
except Exception:
|
||||
logger.exception("Error logging request")
|
||||
|
||||
def _log_response(self, message: MCPMessage) -> None:
|
||||
"""
|
||||
Log an MCP response message.
|
||||
"""
|
||||
if not isinstance(message, (JSONRPCResponse, JSONRPCError)):
|
||||
logger.debug(f"Ignoring non-response message: {type(message).__name__}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Calculate request duration if we have the start time
|
||||
duration_ms = 0
|
||||
request = getattr(message, "_request", None)
|
||||
if request:
|
||||
start_time = getattr(request, "_mcp_start_time", None)
|
||||
if start_time:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
else:
|
||||
start_time = getattr(message, "_mcp_start_time", None)
|
||||
if start_time:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Skip if below minimum duration threshold
|
||||
if self.min_duration_to_log_ms > 0 and duration_ms < self.min_duration_to_log_ms:
|
||||
return
|
||||
|
||||
# Handle error responses
|
||||
if hasattr(message, "error") and message.error is not None:
|
||||
if self.log_errors:
|
||||
error_msg = self.error_log_format.format(
|
||||
method=getattr(message, "method", "unknown"),
|
||||
error=getattr(message.error, "message", str(message.error)),
|
||||
id=getattr(message, "id", "none"),
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return
|
||||
|
||||
# Log successful response
|
||||
result_str = ""
|
||||
if self.log_response_body and hasattr(message, "result"):
|
||||
result_str = f": {self._format_result(message.result)}"
|
||||
|
||||
log_msg = self.response_log_format.format(
|
||||
method=getattr(message, "method", "unknown"),
|
||||
duration=duration_ms,
|
||||
id=getattr(message, "id", "none"),
|
||||
result_str=result_str,
|
||||
)
|
||||
|
||||
logger.log(self.log_level, log_msg)
|
||||
except Exception:
|
||||
logger.exception("Error logging response")
|
||||
|
||||
def _format_params(self, params: Any) -> str:
|
||||
"""
|
||||
Format parameters for logging.
|
||||
"""
|
||||
try:
|
||||
if isinstance(params, dict):
|
||||
# Handle common MCP params specially
|
||||
if "name" in params and "arguments" in params:
|
||||
return f"{params['name']}({json.dumps(params.get('arguments', {}))})"
|
||||
return json.dumps(params)
|
||||
return str(params)
|
||||
except Exception:
|
||||
logger.debug(f"Error formatting params {params!s}")
|
||||
return str(params)
|
||||
|
||||
def _format_result(self, result: Any) -> str:
|
||||
"""
|
||||
Format result for logging.
|
||||
"""
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
return json.dumps(result)
|
||||
return str(result)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error formatting result {e!s}")
|
||||
return str(result)
|
||||
|
||||
|
||||
def create_mcp_logging_middleware(**config: Any) -> MCPLoggingMiddleware:
|
||||
"""
|
||||
Create an MCP logging middleware with the given configuration.
|
||||
|
||||
Args:
|
||||
**config: Configuration options.
|
||||
|
||||
Returns:
|
||||
An MCPLoggingMiddleware instance.
|
||||
"""
|
||||
return MCPLoggingMiddleware(
|
||||
log_level=config.get("log_level", "INFO"),
|
||||
log_request_body=config.get("log_request_body", False),
|
||||
log_response_body=config.get("log_response_body", False),
|
||||
log_errors=config.get("log_errors", True),
|
||||
min_duration_to_log_ms=config.get("min_duration_to_log_ms", 0),
|
||||
stdio_mode=config.get("stdio_mode", False),
|
||||
)
|
||||
83
arcade/arcade/worker/mcp/message_processor.py
Normal file
83
arcade/arcade/worker/mcp/message_processor.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from arcade.worker.mcp.types import InitializeRequest, JSONRPCRequest, MCPMessage
|
||||
|
||||
logger = logging.getLogger("arcade.mcp")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Type definition for middleware functions
|
||||
MessageProcessor = Callable[[Any, str], Any]
|
||||
|
||||
|
||||
class MCPMessageProcessor:
|
||||
"""
|
||||
Processes MCP messages through a chain of middleware.
|
||||
Supports both synchronous and asynchronous middleware.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.middleware: list[Callable[[MCPMessage, str], Any]] = []
|
||||
|
||||
def add_middleware(self, mw: Callable[[MCPMessage, str], Any]) -> None:
|
||||
self.middleware.append(mw)
|
||||
|
||||
async def process(self, message: Any, direction: str) -> Any: # noqa: C901
|
||||
# First, try to parse the message if it's a string
|
||||
if isinstance(message, str):
|
||||
# Strip any whitespace including newlines
|
||||
message = message.strip()
|
||||
if not message:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = json.loads(message)
|
||||
if isinstance(parsed, dict):
|
||||
method = parsed.get("method")
|
||||
# Convert to appropriate message type
|
||||
if method == "initialize" and "id" in parsed:
|
||||
logger.debug(f"Parsed initialize request: {parsed}")
|
||||
message = InitializeRequest(**parsed)
|
||||
elif method and method.startswith("notifications/"):
|
||||
# It's a notification, log it but pass through as dict
|
||||
logger.debug(f"Received notification: {method}")
|
||||
# Keep as parsed dict to avoid validation errors on unknown notifications
|
||||
message = parsed
|
||||
elif "method" in parsed and "id" in parsed:
|
||||
# Regular method request
|
||||
logger.debug(f"Parsed method request: {method}")
|
||||
message = JSONRPCRequest(**parsed)
|
||||
# Other message types can be handled similarly
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse message as JSON: {message[:100]}...")
|
||||
except Exception:
|
||||
logger.exception("Error processing message")
|
||||
|
||||
# Process through middleware chain
|
||||
result = message
|
||||
for mw in self.middleware:
|
||||
try:
|
||||
if inspect.iscoroutinefunction(mw):
|
||||
result = await mw(result, direction)
|
||||
else:
|
||||
result = mw(result, direction)
|
||||
except Exception:
|
||||
logger.exception(f"Error in middleware {mw}")
|
||||
return result
|
||||
|
||||
async def process_request(self, message: Any) -> Any:
|
||||
return await self.process(message, "request")
|
||||
|
||||
async def process_response(self, message: Any) -> Any:
|
||||
return await self.process(message, "response")
|
||||
|
||||
|
||||
def create_message_processor(*middleware: MessageProcessor) -> MCPMessageProcessor:
|
||||
processor = MCPMessageProcessor()
|
||||
for m in middleware:
|
||||
if m is not None:
|
||||
processor.add_middleware(m)
|
||||
return processor
|
||||
601
arcade/arcade/worker/mcp/server.py
Normal file
601
arcade/arcade/worker/mcp/server.py
Normal file
|
|
@ -0,0 +1,601 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from arcadepy import ArcadeError, AsyncArcade
|
||||
from arcadepy.types.auth_authorize_params import AuthRequirement, AuthRequirementOauth2
|
||||
from arcadepy.types.shared import AuthorizationResponse
|
||||
|
||||
from arcade.core.catalog import MaterializedTool, ToolCatalog
|
||||
from arcade.core.executor import ToolExecutor
|
||||
from arcade.core.schema import ToolAuthorizationContext, ToolContext
|
||||
from arcade.worker.mcp.convert import convert_to_mcp_content, create_mcp_tool
|
||||
from arcade.worker.mcp.logging import create_mcp_logging_middleware
|
||||
from arcade.worker.mcp.message_processor import MCPMessageProcessor, create_message_processor
|
||||
from arcade.worker.mcp.types import (
|
||||
CallToolRequest,
|
||||
CallToolResponse,
|
||||
CallToolResult,
|
||||
CancelRequest,
|
||||
Implementation,
|
||||
InitializeRequest,
|
||||
InitializeResponse,
|
||||
InitializeResult,
|
||||
JSONRPCError,
|
||||
JSONRPCResponse,
|
||||
ListPromptsRequest,
|
||||
ListPromptsResponse,
|
||||
ListResourcesRequest,
|
||||
ListResourcesResponse,
|
||||
ListToolsRequest,
|
||||
ListToolsResponse,
|
||||
ListToolsResult,
|
||||
PingRequest,
|
||||
PingResponse,
|
||||
ProgressNotification,
|
||||
ServerCapabilities,
|
||||
ShutdownRequest,
|
||||
ShutdownResponse,
|
||||
Tool,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("arcade.mcp")
|
||||
|
||||
MCP_PROTOCOL_VERSION = "2024-11-05"
|
||||
|
||||
|
||||
class MessageMethod(str, Enum):
|
||||
"""Enumeration of supported MCP message methods"""
|
||||
|
||||
PING = "ping"
|
||||
INITIALIZE = "initialize"
|
||||
LIST_TOOLS = "tools/list"
|
||||
CALL_TOOL = "tools/call"
|
||||
PROGRESS = "progress"
|
||||
CANCEL = "$/cancelRequest"
|
||||
SHUTDOWN = "shutdown"
|
||||
LIST_RESOURCES = "resources/list"
|
||||
LIST_PROMPTS = "prompts/list"
|
||||
|
||||
|
||||
class MCPServer:
|
||||
"""
|
||||
Unified async MCP server that manages connections, middleware, and tool invocation.
|
||||
Handles protocol-level messages (ping, initialize, list_tools, call_tool, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_catalog: Any,
|
||||
enable_logging: bool = True,
|
||||
**client_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MCP server.
|
||||
|
||||
Args:
|
||||
tool_catalog: Catalog of available tools
|
||||
**client_kwargs: Additional arguments to pass to the AsyncArcade client
|
||||
"""
|
||||
self.tool_catalog: ToolCatalog = tool_catalog
|
||||
self.message_processor: MCPMessageProcessor = create_message_processor()
|
||||
|
||||
# Pop middleware_config from client_kwargs regardless of logging state,
|
||||
# as it's internal config not meant for AsyncArcade.
|
||||
middleware_config = client_kwargs.pop("middleware_config", {})
|
||||
|
||||
if enable_logging:
|
||||
# Create and add the logging middleware if logging is enabled.
|
||||
# Note: enable_logging must be True for this middleware (and its stdio_mode behavior)
|
||||
# to be activated.
|
||||
self.message_processor.add_middleware(
|
||||
create_mcp_logging_middleware(**middleware_config)
|
||||
)
|
||||
|
||||
self._shutdown: bool = False
|
||||
# Initialize AsyncArcade with the *remaining* client_kwargs
|
||||
self.arcade = AsyncArcade(**client_kwargs) # type: ignore[arg-type]
|
||||
|
||||
# Initialize handler dispatch table
|
||||
self._method_handlers: dict[str, Callable] = {
|
||||
MessageMethod.PING: self._handle_ping,
|
||||
MessageMethod.INITIALIZE: self._handle_initialize,
|
||||
MessageMethod.LIST_TOOLS: self._handle_list_tools,
|
||||
MessageMethod.CALL_TOOL: self._handle_call_tool,
|
||||
MessageMethod.PROGRESS: self._handle_progress,
|
||||
MessageMethod.CANCEL: self._handle_cancel,
|
||||
MessageMethod.SHUTDOWN: self._handle_shutdown,
|
||||
MessageMethod.LIST_RESOURCES: self._handle_list_resources,
|
||||
MessageMethod.LIST_PROMPTS: self._handle_list_prompts,
|
||||
}
|
||||
|
||||
async def run_connection(
|
||||
self,
|
||||
read_stream: Any,
|
||||
write_stream: Any,
|
||||
init_options: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Handle a single MCP connection (SSE or stdio).
|
||||
|
||||
Args:
|
||||
read_stream: Async iterable yielding incoming messages.
|
||||
write_stream: Object with an async send(message) method.
|
||||
init_options: Initialization options for the connection.
|
||||
"""
|
||||
# Generate a user ID if possible
|
||||
user_id = self._get_user_id(init_options)
|
||||
|
||||
try:
|
||||
logger.info(f"Starting MCP connection for user {user_id}")
|
||||
|
||||
async for message in read_stream:
|
||||
# Process the message
|
||||
response = await self.handle_message(message, user_id=user_id)
|
||||
|
||||
# Skip sending responses for None (e.g., notifications)
|
||||
if response is None:
|
||||
continue
|
||||
|
||||
await self._send_response(write_stream, response)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Connection cancelled")
|
||||
except Exception:
|
||||
logger.exception("Error in connection")
|
||||
|
||||
def _get_user_id(self, init_options: Any) -> str:
|
||||
"""
|
||||
Get the user ID for a connection.
|
||||
|
||||
Args:
|
||||
init_options: Initialization options for the connection
|
||||
|
||||
Returns:
|
||||
A user ID string
|
||||
"""
|
||||
try:
|
||||
from arcade.core.config import config
|
||||
|
||||
# Prefer config.user.email if available
|
||||
if config.user and config.user.email:
|
||||
return config.user.email
|
||||
except ValueError:
|
||||
logger.debug("No logged in user for MCP Server")
|
||||
|
||||
fallback = str(uuid.uuid4())
|
||||
if os.environ.get("ARCADE_USER_ID", None):
|
||||
return os.environ.get("ARCADE_USER_ID", fallback)
|
||||
elif isinstance(init_options, dict):
|
||||
user_id = init_options.get("user_id")
|
||||
if user_id:
|
||||
return str(user_id)
|
||||
# Fallback to random UUID
|
||||
return str(fallback)
|
||||
|
||||
async def _send_response(self, write_stream: Any, response: Any) -> None:
|
||||
"""
|
||||
Send a response to the client.
|
||||
|
||||
Args:
|
||||
write_stream: Stream to write the response to
|
||||
response: Response object to send
|
||||
"""
|
||||
# Ensure the response is properly serialized to JSON
|
||||
if hasattr(response, "model_dump_json"):
|
||||
# It's a Pydantic model, serialize it
|
||||
json_response = response.model_dump_json()
|
||||
# Ensure it ends with a newline for JSON-RPC-over-stdio
|
||||
if not json_response.endswith("\n"):
|
||||
json_response += "\n"
|
||||
logger.debug(f"Sending response: {json_response[:200]}...")
|
||||
await write_stream.send(json_response)
|
||||
elif isinstance(response, dict):
|
||||
# It's a dict, convert to JSON
|
||||
import json
|
||||
|
||||
json_response = json.dumps(response)
|
||||
# Ensure it ends with a newline for JSON-RPC-over-stdio
|
||||
if not json_response.endswith("\n"):
|
||||
json_response += "\n"
|
||||
logger.debug(f"Sending response: {json_response[:200]}...")
|
||||
await write_stream.send(json_response)
|
||||
else:
|
||||
# It's already a string or something else
|
||||
response_str = str(response)
|
||||
# Ensure it ends with a newline for JSON-RPC-over-stdio
|
||||
if not response_str.endswith("\n"):
|
||||
response_str += "\n"
|
||||
logger.debug(f"Sending raw response type: {type(response)}")
|
||||
await write_stream.send(response_str)
|
||||
|
||||
async def handle_message(self, message: Any, user_id: str | None = None) -> Any:
|
||||
"""
|
||||
Handle an incoming MCP message. Processes it through middleware and dispatches
|
||||
to the appropriate handler based on the message method.
|
||||
|
||||
Args:
|
||||
message: The raw incoming message
|
||||
user_id: Optional user ID for authentication
|
||||
|
||||
Returns:
|
||||
A properly formatted response message
|
||||
"""
|
||||
# Pre-process message through middleware
|
||||
processed = await self.message_processor.process_request(message)
|
||||
|
||||
# Handle special case for JSON string initialize requests
|
||||
if isinstance(processed, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
parsed = json.loads(processed)
|
||||
if (
|
||||
isinstance(parsed, dict)
|
||||
and parsed.get("method") == MessageMethod.INITIALIZE
|
||||
and "id" in parsed
|
||||
):
|
||||
# This is an initialize request
|
||||
init_response = await self._handle_initialize(InitializeRequest(**parsed))
|
||||
return init_response
|
||||
except Exception:
|
||||
logger.exception("Error processing JSON string")
|
||||
# Not parseable JSON, continue with normal processing
|
||||
pass
|
||||
|
||||
# Check if it's a notification
|
||||
if hasattr(processed, "method"):
|
||||
method = getattr(processed, "method", None)
|
||||
|
||||
# Handle notifications (methods starting with "notifications/")
|
||||
if method and method.startswith("notifications/"):
|
||||
await self._handle_notification(method, processed)
|
||||
return None
|
||||
|
||||
# Handle regular methods using the dispatch table
|
||||
if method in self._method_handlers:
|
||||
# If it's a call_tool request, we need to pass the user_id
|
||||
if method == MessageMethod.CALL_TOOL:
|
||||
return await self._method_handlers[method](processed, user_id=user_id)
|
||||
# For other methods, just pass the processed message
|
||||
return await self._method_handlers[method](processed)
|
||||
|
||||
# Unknown method
|
||||
return JSONRPCError(
|
||||
id=getattr(processed, "id", None),
|
||||
error={
|
||||
"code": -32601,
|
||||
"message": f"Method not found: {method}",
|
||||
},
|
||||
)
|
||||
|
||||
# If it's not a method request, just pass it through
|
||||
return processed
|
||||
|
||||
async def _handle_notification(self, method: str, message: Any) -> None:
|
||||
"""
|
||||
Handle notification messages.
|
||||
|
||||
Args:
|
||||
method: The notification method
|
||||
message: The notification message
|
||||
"""
|
||||
if method == "notifications/cancelled":
|
||||
logger.info(f"Request cancelled: {getattr(message, 'params', {})}")
|
||||
else:
|
||||
logger.debug(f"Received notification: {method}")
|
||||
|
||||
async def _handle_ping(self, message: PingRequest) -> PingResponse:
|
||||
"""
|
||||
Handle a ping request and return a pong response.
|
||||
|
||||
Args:
|
||||
message: The ping request
|
||||
|
||||
Returns:
|
||||
A properly formatted pong response
|
||||
"""
|
||||
return PingResponse(id=message.id)
|
||||
|
||||
async def _handle_initialize(self, message: InitializeRequest) -> InitializeResponse:
|
||||
"""
|
||||
Handle an initialize request and return a proper initialize response.
|
||||
|
||||
Args:
|
||||
message: The initialize request
|
||||
|
||||
Returns:
|
||||
A properly formatted initialize response
|
||||
"""
|
||||
# Create the result data
|
||||
result = InitializeResult(
|
||||
protocolVersion=MCP_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(),
|
||||
serverInfo=Implementation(name="Arcade MCP Worker", version="0.1.0"),
|
||||
instructions="Arcade MCP Worker initialized.",
|
||||
)
|
||||
|
||||
# Construct proper response with result field
|
||||
response = InitializeResponse(id=message.id, result=result)
|
||||
|
||||
logger.debug(f"Initialize response: {response.model_dump_json()}")
|
||||
return response
|
||||
|
||||
async def _handle_list_tools(
|
||||
self, message: ListToolsRequest
|
||||
) -> Union[ListToolsResponse, JSONRPCError]:
|
||||
"""
|
||||
Handle a tools/list request and return a list of available tools.
|
||||
|
||||
Args:
|
||||
message: The tools/list request
|
||||
|
||||
Returns:
|
||||
A properly formatted tools/list response or error
|
||||
"""
|
||||
try:
|
||||
# Get all tools from the catalog
|
||||
tools = []
|
||||
tool_conversion_errors = []
|
||||
|
||||
for tool in self.tool_catalog:
|
||||
try:
|
||||
mcp_tool = create_mcp_tool(tool)
|
||||
if mcp_tool:
|
||||
tools.append(mcp_tool)
|
||||
except Exception:
|
||||
tool_name = getattr(tool, "name", str(tool))
|
||||
logger.exception(f"Error converting tool: {tool_name}")
|
||||
tool_conversion_errors.append(tool_name)
|
||||
|
||||
# Log summary if we had errors
|
||||
if tool_conversion_errors:
|
||||
logger.warning(
|
||||
f"Failed to convert {len(tool_conversion_errors)} tools: {tool_conversion_errors}"
|
||||
)
|
||||
|
||||
# Create tool objects with exception handling for each one
|
||||
tool_objects = []
|
||||
for t in tools:
|
||||
try:
|
||||
# Make input schema optional if missing
|
||||
tool_dict = dict(t)
|
||||
if "inputSchema" not in tool_dict:
|
||||
tool_dict["inputSchema"] = {"type": "object", "properties": {}}
|
||||
|
||||
tool_objects.append(Tool(**tool_dict))
|
||||
except Exception:
|
||||
logger.exception(f"Error creating Tool object for {t.get('name', 'unknown')}")
|
||||
|
||||
# Return successful response with the tools we were able to convert
|
||||
result = ListToolsResult(tools=tool_objects)
|
||||
response = ListToolsResponse(id=message.id, result=result)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error listing tools")
|
||||
return JSONRPCError(
|
||||
id=message.id,
|
||||
error={
|
||||
"code": -32603,
|
||||
"message": "Internal error listing tools",
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
async def _handle_call_tool(
|
||||
self, message: CallToolRequest, user_id: str | None = None
|
||||
) -> CallToolResponse:
|
||||
"""
|
||||
Handle a tools/call request to execute a tool.
|
||||
|
||||
Args:
|
||||
message: The tools/call request
|
||||
user_id: Optional user ID for authentication
|
||||
|
||||
Returns:
|
||||
A properly formatted tools/call response
|
||||
"""
|
||||
tool_name: str = message.params["name"]
|
||||
# Extract input from the correct field
|
||||
input_params: dict[str, Any] = message.params.get("input", {})
|
||||
if not input_params:
|
||||
input_params = message.params.get("arguments", {})
|
||||
|
||||
logger.info(f"Handling tool call for {tool_name}")
|
||||
|
||||
try:
|
||||
tool = self.tool_catalog.get_tool_by_name(tool_name, separator="_")
|
||||
tool_context = ToolContext()
|
||||
|
||||
# Set up context with secrets
|
||||
if tool.definition.requirements and tool.definition.requirements.secrets:
|
||||
self._setup_tool_secrets(tool, tool_context)
|
||||
|
||||
# Handle authorization if needed
|
||||
requirement = self._get_auth_requirement(tool)
|
||||
if requirement:
|
||||
auth_result = await self._check_authorization(requirement, user_id=user_id)
|
||||
if auth_result.status != "completed":
|
||||
return CallToolResponse(
|
||||
id=message.id,
|
||||
result=CallToolResult(content=[{"type": "text", "text": auth_result.url}]),
|
||||
)
|
||||
else:
|
||||
tool_context.authorization = ToolAuthorizationContext(
|
||||
token=auth_result.context.token if auth_result.context else None,
|
||||
user_info={"user_id": user_id} if user_id else {},
|
||||
)
|
||||
|
||||
# Execute the tool
|
||||
logger.debug(f"Executing tool {tool_name} with input: {input_params}")
|
||||
result = await ToolExecutor.run(
|
||||
func=tool.tool,
|
||||
definition=tool.definition,
|
||||
input_model=tool.input_model,
|
||||
output_model=tool.output_model,
|
||||
context=tool_context,
|
||||
**input_params,
|
||||
)
|
||||
logger.debug(f"Tool result: {result}")
|
||||
if result.value:
|
||||
return CallToolResponse(
|
||||
id=message.id,
|
||||
result=CallToolResult(content=convert_to_mcp_content(result.value)),
|
||||
)
|
||||
else:
|
||||
error = result.error or "Error calling tool"
|
||||
logger.error(f"Tool {tool_name} returned error: {error}")
|
||||
return CallToolResponse(
|
||||
id=message.id,
|
||||
result=CallToolResult(
|
||||
content=[{"type": "text", "text": convert_to_mcp_content(error)}]
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling tool {tool_name}")
|
||||
error = f"Error calling tool {tool_name}: {e!s}"
|
||||
return CallToolResponse(
|
||||
id=message.id,
|
||||
result=CallToolResult(
|
||||
content=[{"type": "text", "text": convert_to_mcp_content(error)}]
|
||||
),
|
||||
)
|
||||
|
||||
def _setup_tool_secrets(self, tool: Any, tool_context: ToolContext) -> None:
|
||||
"""
|
||||
Set up tool secrets in the tool context.
|
||||
|
||||
Args:
|
||||
tool: The tool to set up secrets for
|
||||
tool_context: The tool context to update
|
||||
"""
|
||||
for secret in tool.definition.requirements.secrets:
|
||||
value = os.environ.get(secret.key)
|
||||
if value is not None:
|
||||
tool_context.set_secret(secret.key, value)
|
||||
|
||||
async def _handle_progress(self, message: ProgressNotification) -> JSONRPCResponse:
|
||||
"""
|
||||
Handle a progress notification.
|
||||
|
||||
Args:
|
||||
message: The progress notification
|
||||
|
||||
Returns:
|
||||
A response acknowledging the notification
|
||||
"""
|
||||
return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True})
|
||||
|
||||
async def _handle_cancel(self, message: CancelRequest) -> JSONRPCResponse:
|
||||
"""
|
||||
Handle a cancel request.
|
||||
|
||||
Args:
|
||||
message: The cancel request
|
||||
|
||||
Returns:
|
||||
A response acknowledging the cancellation
|
||||
"""
|
||||
return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True})
|
||||
|
||||
async def _handle_shutdown(self, message: ShutdownRequest) -> ShutdownResponse:
|
||||
"""
|
||||
Handle a shutdown request.
|
||||
|
||||
Args:
|
||||
message: The shutdown request
|
||||
|
||||
Returns:
|
||||
A response acknowledging the shutdown request
|
||||
"""
|
||||
# Schedule a task to shutdown the server after sending the response
|
||||
proc = asyncio.create_task(self.shutdown())
|
||||
proc.add_done_callback(lambda _: logger.info("MCP server shutdown complete"))
|
||||
return ShutdownResponse(id=message.id, result={"ok": True})
|
||||
|
||||
async def _handle_list_resources(self, message: ListResourcesRequest) -> ListResourcesResponse:
|
||||
"""
|
||||
Handle a resources/list request.
|
||||
|
||||
Args:
|
||||
message: The resources/list request
|
||||
|
||||
Returns:
|
||||
A properly formatted resources/list response
|
||||
"""
|
||||
return ListResourcesResponse(id=message.id, result={"resources": []})
|
||||
|
||||
async def _handle_list_prompts(self, message: ListPromptsRequest) -> ListPromptsResponse:
|
||||
"""
|
||||
Handle a prompts/list request.
|
||||
|
||||
Args:
|
||||
message: The prompts/list request
|
||||
|
||||
Returns:
|
||||
A properly formatted prompts/list response
|
||||
"""
|
||||
return ListPromptsResponse(id=message.id, result={"prompts": []})
|
||||
|
||||
def _get_auth_requirement(self, tool: MaterializedTool) -> AuthRequirement | None:
|
||||
"""
|
||||
Get the authentication requirement for a tool.
|
||||
|
||||
Args:
|
||||
tool: The tool to get the requirement for
|
||||
|
||||
Returns:
|
||||
An authentication requirement or None if not required
|
||||
"""
|
||||
req = tool.definition.requirements.authorization
|
||||
if not req:
|
||||
return None
|
||||
if not req.provider_id and not req.provider_type:
|
||||
return None
|
||||
if hasattr(req, "oauth2") and req.oauth2:
|
||||
return AuthRequirement(
|
||||
provider_id=str(req.provider_id),
|
||||
provider_type=str(req.provider_type),
|
||||
oauth2=AuthRequirementOauth2(scopes=req.oauth2.scopes or []),
|
||||
)
|
||||
return AuthRequirement(
|
||||
provider_id=str(req.provider_id),
|
||||
provider_type=str(req.provider_type),
|
||||
)
|
||||
|
||||
async def _check_authorization(
|
||||
self, auth_requirement: AuthRequirement, user_id: str | None = None
|
||||
) -> AuthorizationResponse:
|
||||
"""
|
||||
Check if a tool is authorized for a user.
|
||||
|
||||
Args:
|
||||
tool: The tool to check authorization for
|
||||
user_id: The user ID to check authorization for
|
||||
|
||||
Returns:
|
||||
An authorization response
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the tool has no authorization requirement
|
||||
Exception: If authorization fails
|
||||
"""
|
||||
try:
|
||||
response = await self.arcade.auth.authorize(
|
||||
auth_requirement=auth_requirement,
|
||||
user_id=user_id or "anonymous",
|
||||
)
|
||||
logger.debug(f"Authorization response: {response}")
|
||||
|
||||
except ArcadeError:
|
||||
logger.exception("Error authorizing tool")
|
||||
raise
|
||||
return response
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the server."""
|
||||
self._shutdown = True
|
||||
logger.info("MCP server shutdown complete")
|
||||
185
arcade/arcade/worker/mcp/stdio.py
Normal file
185
arcade/arcade/worker/mcp/stdio.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from arcade.worker.mcp.server import MCPServer
|
||||
|
||||
logger = logging.getLogger("arcade.mcp")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def stdio_reader(stdin: object, q: queue.Queue[str | None]) -> None:
|
||||
"""Read lines from stdin and put them into a queue."""
|
||||
for line in stdin: # type: ignore[attr-defined]
|
||||
q.put(line)
|
||||
q.put(None)
|
||||
|
||||
|
||||
def stdio_writer(stdout: object, q: queue.Queue[str | None]) -> None:
|
||||
"""Write messages from a queue to stdout."""
|
||||
try:
|
||||
while True:
|
||||
msg = q.get()
|
||||
if msg is None:
|
||||
break
|
||||
|
||||
# Ensure message ends with a newline for proper JSON-RPC-over-stdio
|
||||
if not msg.endswith("\n"):
|
||||
msg += "\n"
|
||||
|
||||
stdout.write(msg) # type: ignore[attr-defined]
|
||||
stdout.flush() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.exception("Error in stdio writer")
|
||||
|
||||
|
||||
class StdioServer(MCPServer):
|
||||
"""
|
||||
Stdio server that handles signals and cleanup.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_catalog: Any,
|
||||
enable_logging: bool = True,
|
||||
**client_kwargs: dict[str, Any],
|
||||
):
|
||||
# Set up stdio-specific middleware configuration
|
||||
middleware_config = client_kwargs.get("middleware_config", {})
|
||||
middleware_config["stdio_mode"] = True
|
||||
client_kwargs["middleware_config"] = middleware_config
|
||||
|
||||
super().__init__(tool_catalog, enable_logging, **client_kwargs)
|
||||
self.read_q: queue.Queue[str | None] = queue.Queue()
|
||||
self.write_q: queue.Queue[str | None] = queue.Queue()
|
||||
self.reader_thread: threading.Thread | None = None
|
||||
self.writer_thread: threading.Thread | None = None
|
||||
self.running = False
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
def start_io_threads(self) -> None:
|
||||
"""Start stdio reader and writer threads."""
|
||||
self.reader_thread = threading.Thread(
|
||||
target=self._stdio_reader, args=(sys.stdin, self.read_q), daemon=True
|
||||
)
|
||||
self.writer_thread = threading.Thread(
|
||||
target=self._stdio_writer, args=(sys.stdout, self.write_q), daemon=True
|
||||
)
|
||||
self.reader_thread.start()
|
||||
self.writer_thread.start()
|
||||
|
||||
def _stdio_reader(self, stdin: object, q: queue.Queue[str | None]) -> None:
|
||||
"""Read lines from stdin and put them into a queue."""
|
||||
try:
|
||||
for line in stdin: # type: ignore[attr-defined]
|
||||
if not self.running:
|
||||
break
|
||||
q.put(line)
|
||||
except Exception:
|
||||
logger.exception("Error in stdio reader")
|
||||
finally:
|
||||
q.put(None) # Signal EOF
|
||||
|
||||
def _stdio_writer(self, stdout: object, q: queue.Queue[str | None]) -> None:
|
||||
"""Write messages from a queue to stdout."""
|
||||
try:
|
||||
while self.running:
|
||||
msg = q.get()
|
||||
if msg is None:
|
||||
break
|
||||
stdout.write(msg) # type: ignore[attr-defined]
|
||||
stdout.flush() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.exception("Error in stdio writer")
|
||||
|
||||
async def _read_stream(self) -> AsyncGenerator[str, None]:
|
||||
"""Async generator that yields lines from the read queue."""
|
||||
while self.running:
|
||||
try:
|
||||
line = await asyncio.to_thread(self.read_q.get)
|
||||
if line is None:
|
||||
break
|
||||
yield line
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
logger.exception("Error reading from stdin")
|
||||
break
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Gracefully shut down the server."""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
logger.info("Shutting down stdio server...")
|
||||
self.running = False
|
||||
|
||||
# Signal shutdown to MCP server
|
||||
await self.shutdown()
|
||||
|
||||
# Clean up IO queues and threads
|
||||
try:
|
||||
if self.read_q:
|
||||
self.read_q.put(None)
|
||||
if self.write_q:
|
||||
self.write_q.put(None)
|
||||
except Exception:
|
||||
logger.exception("Error during shutdown")
|
||||
|
||||
# Signal completion
|
||||
self.shutdown_event.set()
|
||||
logger.info("Stdio server shutdown complete")
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the stdio server with signal handling."""
|
||||
self.running = True
|
||||
|
||||
# Set up signal handlers
|
||||
loop = asyncio.get_running_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
try:
|
||||
loop.add_signal_handler(sig, lambda: asyncio.create_task(self.shutdown()))
|
||||
except NotImplementedError:
|
||||
# Windows doesn't support POSIX signals
|
||||
if sys.platform == "win32":
|
||||
logger.warning("Signal handling not fully supported on Windows")
|
||||
else:
|
||||
logger.warning(f"Failed to set up signal handler for {sig}")
|
||||
|
||||
# Start IO threads
|
||||
self.start_io_threads()
|
||||
|
||||
logger.info("Starting MCP server with stdio transport")
|
||||
|
||||
# Create WriteStream class for MCP server
|
||||
class WriteStream:
|
||||
async def send(self_, message: str) -> None:
|
||||
if self.running:
|
||||
await asyncio.to_thread(self.write_q.put, message)
|
||||
|
||||
try:
|
||||
# Run MCP server connection
|
||||
await self.run_connection(self._read_stream(), WriteStream(), None)
|
||||
except asyncio.CancelledError:
|
||||
# Handle cancellation
|
||||
logger.info("Server operation cancelled")
|
||||
except KeyboardInterrupt:
|
||||
# Handle keyboard interrupt
|
||||
logger.info("Keyboard interrupt received")
|
||||
except Exception:
|
||||
# Handle unexpected errors
|
||||
logger.exception("Unexpected error")
|
||||
finally:
|
||||
# Ensure we clean up
|
||||
await self.shutdown()
|
||||
# Wait for shutdown to complete
|
||||
await self.shutdown_event.wait()
|
||||
383
arcade/arcade/worker/mcp/types.py
Normal file
383
arcade/arcade/worker/mcp/types.py
Normal file
|
|
@ -0,0 +1,383 @@
|
|||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
ProgressToken = str | int
|
||||
Cursor = str
|
||||
Role = Literal["user", "assistant"]
|
||||
RequestId = str | int
|
||||
AnyFunction: TypeAlias = Callable[..., Any]
|
||||
|
||||
|
||||
class RequestParams(BaseModel):
|
||||
class Meta(BaseModel):
|
||||
progressToken: ProgressToken | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
meta: Meta | None = Field(alias="_meta", default=None)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class NotificationParams(BaseModel):
|
||||
class Meta(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
meta: Meta | None = Field(alias="_meta", default=None)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
|
||||
NotificationParamsT = TypeVar(
|
||||
"NotificationParamsT", bound=NotificationParams | dict[str, Any] | None
|
||||
)
|
||||
MethodT = TypeVar("MethodT", bound=str)
|
||||
|
||||
|
||||
class Request(BaseModel, Generic[RequestParamsT, MethodT]):
|
||||
method: MethodT
|
||||
params: RequestParamsT
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PaginatedRequest(Request[RequestParamsT, MethodT]):
|
||||
cursor: Cursor | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
||||
method: MethodT
|
||||
params: NotificationParamsT
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PaginatedResult(Result):
|
||||
nextCursor: Cursor | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class JSONRPCMessage(BaseModel):
|
||||
"""Base class for all JSON-RPC messages."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
jsonrpc: str = Field(default="2.0", frozen=True)
|
||||
|
||||
|
||||
class JSONRPCRequest(JSONRPCMessage):
|
||||
"""A JSON-RPC request message."""
|
||||
|
||||
id: str | int | None = None
|
||||
method: str
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class JSONRPCResponse(JSONRPCMessage):
|
||||
"""A JSON-RPC response message."""
|
||||
|
||||
id: str | int | None
|
||||
result: Any = None
|
||||
error: dict[str, Any] | None = None
|
||||
|
||||
def model_dump_json(self, **kwargs: Any) -> str:
|
||||
"""Convert to JSON string with proper formatting."""
|
||||
|
||||
# Convert to dict
|
||||
data = {
|
||||
"jsonrpc": self.jsonrpc,
|
||||
"id": self.id,
|
||||
}
|
||||
|
||||
# Add result if present
|
||||
if self.result is not None:
|
||||
# Check if result is a Pydantic model
|
||||
if hasattr(self.result, "model_dump"):
|
||||
data["result"] = self.result.model_dump(exclude_none=True)
|
||||
# Check if result is already a dict/list/primitive
|
||||
elif (
|
||||
isinstance(self.result, (dict, list, str, int, float, bool)) or self.result is None
|
||||
):
|
||||
data["result"] = self.result # type: ignore[assignment]
|
||||
else:
|
||||
# Try to convert using str() as a fallback
|
||||
data["result"] = str(self.result)
|
||||
|
||||
# Add error if present
|
||||
if self.error is not None:
|
||||
data["error"] = self.error # type: ignore[assignment]
|
||||
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
|
||||
|
||||
class JSONRPCError(JSONRPCMessage):
|
||||
"""A JSON-RPC error message."""
|
||||
|
||||
id: str | int | None
|
||||
error: dict[str, Any]
|
||||
|
||||
|
||||
PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_PARAMS = -32602
|
||||
INTERNAL_ERROR = -32603
|
||||
|
||||
|
||||
class ErrorData(BaseModel):
|
||||
code: int
|
||||
message: str
|
||||
data: Any | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
JSONRPCMessageBaseModel = BaseModel | JSONRPCRequest | JSONRPCResponse | JSONRPCError
|
||||
|
||||
|
||||
class EmptyResult(Result):
|
||||
pass
|
||||
|
||||
|
||||
class Implementation(BaseModel):
|
||||
"""Describes the server or client implementation."""
|
||||
|
||||
name: str
|
||||
version: str
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class RootsCapability(BaseModel):
|
||||
listChanged: bool | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class SamplingCapability(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ClientCapabilities(BaseModel):
|
||||
experimental: dict[str, dict[str, Any]] | None = None
|
||||
sampling: SamplingCapability | None = None
|
||||
roots: RootsCapability | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PromptsCapability(BaseModel):
|
||||
listChanged: bool | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ResourcesCapability(BaseModel):
|
||||
subscribe: bool | None = None
|
||||
listChanged: bool | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ToolsCapability(BaseModel):
|
||||
listChanged: bool | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class LoggingCapability(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ServerCapabilities(BaseModel):
|
||||
"""Describes the server's capabilities."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
tools: dict[str, Any] | None = None
|
||||
resources: dict[str, Any] | None = None
|
||||
prompts: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class InitializeRequestParams(RequestParams):
|
||||
protocolVersion: str | int
|
||||
capabilities: ClientCapabilities
|
||||
clientInfo: Implementation
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class InitializeRequest(JSONRPCRequest):
|
||||
method: str = Field(default="initialize", frozen=True)
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class InitializeResult(BaseModel):
|
||||
protocolVersion: str
|
||||
capabilities: ServerCapabilities
|
||||
serverInfo: Implementation
|
||||
instructions: str | None = None
|
||||
|
||||
|
||||
class InitializedNotification(
|
||||
Notification[NotificationParams | None, Literal["notifications/initialized"]]
|
||||
):
|
||||
method: Literal["notifications/initialized"]
|
||||
params: NotificationParams | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PingRequest(JSONRPCRequest):
|
||||
method: str = Field(default="ping", frozen=True)
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ProgressNotificationParams(NotificationParams):
|
||||
progressToken: ProgressToken
|
||||
progress: float
|
||||
total: float | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ProgressNotification(JSONRPCMessage):
|
||||
method: str = Field(default="progress", frozen=True)
|
||||
params: dict[str, Any]
|
||||
|
||||
|
||||
class PingResponse(JSONRPCResponse):
|
||||
result: dict[str, Any] = Field(default_factory=lambda: {"pong": True})
|
||||
|
||||
|
||||
class ShutdownRequest(JSONRPCRequest):
|
||||
method: str = Field(default="shutdown", frozen=True)
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ShutdownResponse(JSONRPCResponse):
|
||||
result: dict[str, Any] = Field(default_factory=lambda: {"ok": True})
|
||||
|
||||
|
||||
class CancelRequest(JSONRPCRequest):
|
||||
method: str = Field(default="$/cancelRequest", frozen=True)
|
||||
params: dict[str, Any]
|
||||
|
||||
|
||||
class InitializeResponse(JSONRPCResponse):
|
||||
"""
|
||||
Response to an initialize request.
|
||||
|
||||
Note: This must be a properly formatted JSON-RPC response with a `result` field
|
||||
containing the initialization data, not another request.
|
||||
"""
|
||||
|
||||
result: InitializeResult
|
||||
|
||||
def model_dump_json(self, **kwargs: Any) -> str:
|
||||
"""Convert to JSON string with proper formatting."""
|
||||
# Convert to dict
|
||||
data = {
|
||||
"jsonrpc": self.jsonrpc,
|
||||
"id": self.id,
|
||||
"result": self.result.model_dump(exclude_none=True),
|
||||
}
|
||||
|
||||
# Return JSON string
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
|
||||
|
||||
class ListToolsRequest(JSONRPCRequest):
|
||||
method: str = Field(default="tools/list", frozen=True)
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ToolAnnotations(BaseModel):
|
||||
"""
|
||||
Represents tool annotations for hints about behavior.
|
||||
"""
|
||||
|
||||
title: str | None = None
|
||||
readOnlyHint: bool | None = None
|
||||
destructiveHint: bool | None = None
|
||||
idempotentHint: bool | None = None
|
||||
openWorldHint: bool | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
"""
|
||||
Represents an MCP tool definition.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
inputSchema: dict[str, Any] | None = None
|
||||
annotations: ToolAnnotations | None = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ListToolsResult(BaseModel):
|
||||
tools: list[Tool]
|
||||
|
||||
|
||||
class ListToolsResponse(JSONRPCResponse):
|
||||
result: ListToolsResult
|
||||
|
||||
|
||||
class CallToolRequest(JSONRPCRequest):
|
||||
method: str = Field(default="tools/call", frozen=True)
|
||||
params: dict[str, Any]
|
||||
|
||||
|
||||
class CallToolResult(BaseModel):
|
||||
content: Any
|
||||
|
||||
|
||||
class CallToolResponse(JSONRPCResponse):
|
||||
result: CallToolResult
|
||||
|
||||
|
||||
# Resource and Prompt protocol stubs (expand as needed)
|
||||
class ListResourcesRequest(JSONRPCRequest):
|
||||
method: str = Field(default="resources/list", frozen=True)
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ListResourcesResponse(JSONRPCResponse):
|
||||
result: dict[str, Any]
|
||||
|
||||
|
||||
class ListPromptsRequest(JSONRPCRequest):
|
||||
method: str = Field(default="prompts/list", frozen=True)
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ListPromptsResponse(JSONRPCResponse):
|
||||
result: dict[str, Any]
|
||||
|
||||
|
||||
# Utility type alias for all MCP protocol messages
|
||||
MCPMessage = Union[
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
JSONRPCError,
|
||||
PingRequest,
|
||||
PingResponse,
|
||||
InitializeRequest,
|
||||
InitializeResponse,
|
||||
ListToolsRequest,
|
||||
ListToolsResponse,
|
||||
CallToolRequest,
|
||||
CallToolResponse,
|
||||
ProgressNotification,
|
||||
CancelRequest,
|
||||
ShutdownRequest,
|
||||
ShutdownResponse,
|
||||
ListResourcesRequest,
|
||||
ListResourcesResponse,
|
||||
ListPromptsRequest,
|
||||
ListPromptsResponse,
|
||||
]
|
||||
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from arcade.worker.config.deployment import (
|
||||
from arcade.cli.deployment import (
|
||||
Config,
|
||||
Deployment,
|
||||
LocalPackages,
|
||||
|
|
|
|||
44
arcade/tests/mcp/test_convert.py
Normal file
44
arcade/tests/mcp/test_convert.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.sdk import tool
|
||||
from arcade.worker.mcp.convert import convert_to_mcp_content, create_mcp_tool
|
||||
|
||||
|
||||
@tool
|
||||
def sample_tool(x: Annotated[int, "first"], y: Annotated[int, "second"]) -> int:
|
||||
"""Return x+y"""
|
||||
|
||||
return x + y
|
||||
|
||||
|
||||
def test_convert_to_mcp_content_primitives():
|
||||
assert convert_to_mcp_content(42) == [{"type": "text", "text": "42"}]
|
||||
assert convert_to_mcp_content("hello") == [{"type": "text", "text": "hello"}]
|
||||
assert convert_to_mcp_content(True) == [{"type": "text", "text": "True"}]
|
||||
|
||||
|
||||
def test_convert_to_mcp_content_complex():
|
||||
data = {"a": 1}
|
||||
expected_json = json.dumps(data)
|
||||
assert convert_to_mcp_content(data) == [{"type": "text", "text": expected_json}]
|
||||
|
||||
|
||||
def test_create_mcp_tool():
|
||||
# Materialize a tool via catalog then feed it to create_mcp_tool
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_tool(sample_tool, "convert_toolkit")
|
||||
mat_tool = next(iter(catalog)) # only tool
|
||||
mcp_tool = create_mcp_tool(mat_tool)
|
||||
|
||||
assert mcp_tool is not None
|
||||
assert mcp_tool["name"] == "ConvertToolkit_SampleTool"
|
||||
assert mcp_tool["description"]
|
||||
# Ensure input schema contains both parameters and marks them required
|
||||
props = mcp_tool["inputSchema"]["properties"]
|
||||
assert set(props.keys()) == {"x", "y"}
|
||||
|
||||
required_fields = set(mcp_tool["inputSchema"].get("required", []))
|
||||
# Ensure no unexpected required fields and that declared ones are subset of expected
|
||||
assert required_fields.issubset({"x", "y"})
|
||||
57
arcade/tests/mcp/test_message_processor.py
Normal file
57
arcade/tests/mcp/test_message_processor.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from arcade.worker.mcp.message_processor import MCPMessageProcessor, create_message_processor
|
||||
from arcade.worker.mcp.types import InitializeRequest, PingRequest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_processor_parses_initialize_json():
|
||||
"""Ensure JSON initialize strings are converted into InitializeRequest objects."""
|
||||
json_init = '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}\n'
|
||||
processor = MCPMessageProcessor()
|
||||
|
||||
result = await processor.process_request(json_init)
|
||||
|
||||
assert isinstance(result, InitializeRequest)
|
||||
assert result.id == 1
|
||||
assert result.method == "initialize"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_processor_passes_notifications_unchanged():
|
||||
"""Unknown notifications should be passed through as parsed dictionaries without errors."""
|
||||
json_notification = '{"jsonrpc":"2.0","id":null,"method":"notifications/custom","params":{}}\n'
|
||||
processor = MCPMessageProcessor()
|
||||
|
||||
result = await processor.process_request(json_notification)
|
||||
|
||||
# The MCPMessageProcessor keeps unknown notifications as simple dicts
|
||||
assert isinstance(result, dict)
|
||||
assert result["method"] == "notifications/custom"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_processor_middleware_execution_order(monkeypatch):
|
||||
"""Middleware (sync + async) should be executed in the order they were added."""
|
||||
|
||||
order: list[str] = []
|
||||
|
||||
def mw_sync(msg, direction): # type: ignore[return-value]
|
||||
order.append("sync")
|
||||
return msg
|
||||
|
||||
async def mw_async(msg, direction): # type: ignore[return-value]
|
||||
await asyncio.sleep(0) # ensure it is truly async
|
||||
order.append("async")
|
||||
return msg
|
||||
|
||||
processor = create_message_processor(mw_sync, mw_async)
|
||||
|
||||
# Use a pre-parsed PingRequest instance so we don't test parsing again here
|
||||
ping = PingRequest(id=42)
|
||||
|
||||
_ = await processor.process_request(ping)
|
||||
|
||||
assert order == ["sync", "async"]
|
||||
153
arcade/tests/mcp/test_server.py
Normal file
153
arcade/tests/mcp/test_server.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
import sys
|
||||
import types
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.sdk import tool
|
||||
from arcade.worker.mcp import server as mcp_server
|
||||
from arcade.worker.mcp.types import (
|
||||
CallToolRequest,
|
||||
CancelRequest,
|
||||
InitializeRequest,
|
||||
ListToolsRequest,
|
||||
PingRequest,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test helpers / stubs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeAuth:
|
||||
async def authorize(self, auth_requirement: Any, user_id: str):
|
||||
"""Return an object that mimics AuthorizationResponse with completed status."""
|
||||
|
||||
class _Ctx: # minimal stub
|
||||
token = "dummy-token" # noqa: S105
|
||||
|
||||
class _Resp: # pylint: disable=too-few-public-methods
|
||||
status = "completed"
|
||||
url = ""
|
||||
context = _Ctx()
|
||||
|
||||
return _Resp()
|
||||
|
||||
|
||||
class _FakeArcade: # pylint: disable=too-few-public-methods
|
||||
def __init__(self, **_: Any):
|
||||
self.auth = _FakeAuth()
|
||||
|
||||
|
||||
# Ensure that the AsyncArcade & ArcadeError symbols inside server.py point to our stubs.
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_arcadepy(monkeypatch):
|
||||
"""Patch the external `arcadepy` dependency used by mcp.server."""
|
||||
|
||||
# Patch the imported symbols on the already-imported server module
|
||||
monkeypatch.setattr(mcp_server, "AsyncArcade", _FakeArcade, raising=True)
|
||||
monkeypatch.setattr(mcp_server, "ArcadeError", Exception, raising=True)
|
||||
|
||||
# Provide a dummy `arcadepy` module in sys.modules for any other importers
|
||||
fake_arcadepy = types.ModuleType("arcadepy")
|
||||
fake_arcadepy.AsyncArcade = _FakeArcade # type: ignore[attr-defined]
|
||||
fake_arcadepy.ArcadeError = Exception # type: ignore[attr-defined]
|
||||
sys.modules["arcadepy"] = fake_arcadepy
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
sys.modules.pop("arcadepy", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures for a sample tool / catalog / server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@tool
|
||||
def multiply(a: Annotated[int, "a"], b: Annotated[int, "b"]) -> Annotated[int, "result"]:
|
||||
"""Return the product of *a* and *b*."""
|
||||
|
||||
return a * b
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sample_catalog():
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_tool(multiply, "test_toolkit")
|
||||
return catalog
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def server(sample_catalog):
|
||||
# MCPServer constructor is synchronous, so fixture need not be async
|
||||
return mcp_server.MCPServer(sample_catalog, enable_logging=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_handle_ping(server):
|
||||
req = PingRequest(id=123)
|
||||
resp = await server._handle_ping(req) # pylint: disable=protected-access
|
||||
assert resp.id == 123
|
||||
assert resp.result == {"pong": True}
|
||||
|
||||
|
||||
async def test_handle_initialize(server):
|
||||
req = InitializeRequest(id=1)
|
||||
resp = await server._handle_initialize(req) # pylint: disable=protected-access
|
||||
assert resp.id == 1
|
||||
assert resp.result.protocolVersion == mcp_server.MCP_PROTOCOL_VERSION
|
||||
assert resp.result.serverInfo.name.startswith("Arcade")
|
||||
|
||||
|
||||
async def test_handle_list_tools(server):
|
||||
req = ListToolsRequest(id=99)
|
||||
resp = await server._handle_list_tools(req) # pylint: disable=protected-access
|
||||
assert resp.id == 99
|
||||
# Should list our sample tool only
|
||||
tool_names = [t.name for t in resp.result.tools]
|
||||
assert "TestToolkit_Multiply" in tool_names # toolkit + "_" + tool
|
||||
|
||||
|
||||
async def test_handle_call_tool_success(server):
|
||||
req = CallToolRequest(
|
||||
id="call-1",
|
||||
params={
|
||||
"name": "TestToolkit_Multiply",
|
||||
"input": {"a": 6, "b": 7},
|
||||
},
|
||||
)
|
||||
resp = await server._handle_call_tool(req, user_id="tester@example.com") # pylint: disable=protected-access
|
||||
|
||||
assert resp.id == "call-1"
|
||||
# convert_to_mcp_content wraps primitives in list-of-dicts
|
||||
assert resp.result.content == [{"type": "text", "text": "42"}]
|
||||
|
||||
|
||||
async def test_send_response_dict(server, monkeypatch):
|
||||
"""_send_response should JSON-serialize plain dictionaries."""
|
||||
|
||||
sent: list[str] = []
|
||||
|
||||
class _Write:
|
||||
async def send(self, msg):
|
||||
sent.append(msg)
|
||||
|
||||
await server._send_response(_Write(), {"foo": "bar"}) # pylint: disable=protected-access
|
||||
|
||||
assert sent and sent[0].strip() == '{"foo": "bar"}'
|
||||
|
||||
|
||||
async def test_handle_cancel(server):
|
||||
req = CancelRequest(id=77, params={"id": "abc"})
|
||||
resp = await server._handle_cancel(req) # pylint: disable=protected-access
|
||||
assert resp.result == {"ok": True}
|
||||
32
arcade/tests/mcp/test_stdio.py
Normal file
32
arcade/tests/mcp/test_stdio.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import io
|
||||
import queue
|
||||
|
||||
from arcade.worker.mcp.stdio import stdio_reader, stdio_writer
|
||||
|
||||
|
||||
def test_stdio_reader_puts_lines_and_none():
|
||||
q: queue.Queue[str | None] = queue.Queue()
|
||||
test_input = io.StringIO("line1\nline2\n")
|
||||
|
||||
stdio_reader(test_input, q)
|
||||
|
||||
# We should get the two lines followed by None sentinel
|
||||
assert q.get_nowait() == "line1\n"
|
||||
assert q.get_nowait() == "line2\n"
|
||||
assert q.get_nowait() is None
|
||||
|
||||
|
||||
def test_stdio_writer_reads_until_none():
|
||||
q: queue.Queue[str | None] = queue.Queue()
|
||||
output_stream = io.StringIO()
|
||||
|
||||
# preload queue with two messages and sentinel
|
||||
q.put("msg1")
|
||||
q.put("msg2\n")
|
||||
q.put(None)
|
||||
|
||||
stdio_writer(output_stream, q)
|
||||
|
||||
# Ensure writer appended newlines when missing
|
||||
output_stream.seek(0)
|
||||
assert output_stream.read() == "msg1\nmsg2\n"
|
||||
237
arcade/tests/worker/test_worker_base.py
Normal file
237
arcade/tests/worker/test_worker_base.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
import os
|
||||
from typing import Annotated
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from arcade.core.errors import ToolDefinitionError
|
||||
from arcade.core.schema import (
|
||||
ToolCallRequest,
|
||||
ToolCallResponse,
|
||||
ToolContext,
|
||||
ToolReference,
|
||||
)
|
||||
from arcade.sdk import tool
|
||||
from arcade.worker.core.base import BaseWorker
|
||||
from arcade.worker.core.common import RequestData, Router
|
||||
from arcade.worker.core.components import (
|
||||
CallToolComponent,
|
||||
CatalogComponent,
|
||||
HealthCheckComponent,
|
||||
)
|
||||
|
||||
|
||||
@tool()
|
||||
def sample_tool(
|
||||
context: ToolContext, a: Annotated[int, "a"], b: Annotated[int, "b"]
|
||||
) -> Annotated[int, "output"]:
|
||||
"""Sample tool for testing."""
|
||||
return a + b
|
||||
|
||||
|
||||
# Define error tool at module level to avoid indentation issues with getsource
|
||||
@tool()
|
||||
def error_tool(context: ToolContext) -> int:
|
||||
"""This tool always raises an error."""
|
||||
raise ValueError("Something went wrong")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_router():
|
||||
router = MagicMock(spec=Router)
|
||||
router.add_route = MagicMock()
|
||||
return router
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_worker(mock_router):
|
||||
# Set env var temporarily for testing secret loading
|
||||
os.environ["ARCADE_WORKER_SECRET"] = "test_secret_env" # noqa: S105
|
||||
worker = BaseWorker()
|
||||
worker.register_routes(mock_router) # Register routes using the mock router
|
||||
# Clean up env var
|
||||
del os.environ["ARCADE_WORKER_SECRET"]
|
||||
return worker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_worker_no_auth():
|
||||
return BaseWorker(disable_auth=True)
|
||||
|
||||
|
||||
# --- BaseWorker Tests ---
|
||||
|
||||
|
||||
def test_base_worker_init_with_secret():
|
||||
worker = BaseWorker(secret="explicit_secret") # noqa: S106
|
||||
assert worker.secret == "explicit_secret" # noqa: S105
|
||||
assert not worker.disable_auth
|
||||
|
||||
|
||||
def test_base_worker_init_with_env_secret():
|
||||
os.environ["ARCADE_WORKER_SECRET"] = "env_secret_value" # noqa: S105
|
||||
worker = BaseWorker()
|
||||
assert worker.secret == "env_secret_value" # noqa: S105
|
||||
assert not worker.disable_auth
|
||||
del os.environ["ARCADE_WORKER_SECRET"]
|
||||
|
||||
|
||||
def test_base_worker_init_no_secret_raises_error():
|
||||
# Ensure env var is not set
|
||||
if "ARCADE_WORKER_SECRET" in os.environ:
|
||||
del os.environ["ARCADE_WORKER_SECRET"]
|
||||
with pytest.raises(ValueError, match="No secret provided for worker"):
|
||||
BaseWorker()
|
||||
|
||||
|
||||
def test_base_worker_init_disable_auth():
|
||||
worker = BaseWorker(disable_auth=True)
|
||||
assert worker.secret == ""
|
||||
assert worker.disable_auth
|
||||
|
||||
|
||||
def test_register_tool(base_worker_no_auth):
|
||||
assert len(base_worker_no_auth.catalog) == 0
|
||||
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
|
||||
assert len(base_worker_no_auth.catalog) == 1
|
||||
tool_def = base_worker_no_auth.get_catalog()[0]
|
||||
assert tool_def.name == "SampleTool"
|
||||
assert tool_def.toolkit.name == "TestKit"
|
||||
|
||||
|
||||
def test_get_catalog(base_worker_no_auth):
|
||||
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
|
||||
catalog = base_worker_no_auth.get_catalog()
|
||||
assert isinstance(catalog, list)
|
||||
assert len(catalog) == 1
|
||||
assert catalog[0].name == "SampleTool"
|
||||
|
||||
|
||||
def test_health_check(base_worker_no_auth):
|
||||
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
|
||||
health = base_worker_no_auth.health_check()
|
||||
assert health == {"status": "ok", "tool_count": "1"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_success(base_worker_no_auth):
|
||||
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
|
||||
# Create ToolReference WITHOUT version, as register_tool doesn't seem to set it
|
||||
tool_ref = ToolReference(toolkit="TestKit", name="SampleTool")
|
||||
tool_request = ToolCallRequest(
|
||||
execution_id="test_exec_id",
|
||||
tool=tool_ref,
|
||||
inputs={"a": 5, "b": 3},
|
||||
)
|
||||
|
||||
response = await base_worker_no_auth.call_tool(tool_request)
|
||||
|
||||
assert response.success is True
|
||||
assert response.output.value == 8
|
||||
assert response.output.error is None
|
||||
assert response.execution_id == "test_exec_id"
|
||||
assert response.duration > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_execution_error(base_worker_no_auth):
|
||||
# Tool is now defined at module level
|
||||
try:
|
||||
base_worker_no_auth.register_tool(error_tool, toolkit_name="error_kit")
|
||||
except ToolDefinitionError as e:
|
||||
pytest.fail(f"Failed to register error_tool: {e}")
|
||||
|
||||
# Create ToolReference WITHOUT version
|
||||
tool_ref = ToolReference(toolkit="ErrorKit", name="ErrorTool")
|
||||
tool_request = ToolCallRequest(
|
||||
execution_id="test_exec_error",
|
||||
tool=tool_ref,
|
||||
inputs={},
|
||||
)
|
||||
|
||||
response = await base_worker_no_auth.call_tool(tool_request)
|
||||
|
||||
assert response.success is False
|
||||
assert response.output.value is None
|
||||
assert response.output.error is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_not_found(base_worker_no_auth):
|
||||
# Use ToolReference without version for lookup consistency
|
||||
tool_ref = ToolReference(toolkit="nonexistent", name="nosuchtool")
|
||||
tool_request = ToolCallRequest(
|
||||
execution_id="test_exec_notfound",
|
||||
tool=tool_ref,
|
||||
inputs={},
|
||||
)
|
||||
|
||||
# Update regex to match actual error format
|
||||
with pytest.raises(ValueError):
|
||||
await base_worker_no_auth.call_tool(tool_request)
|
||||
|
||||
|
||||
# --- Component Tests (tested via BaseWorker registration) ---
|
||||
|
||||
|
||||
def test_register_routes_registers_default_components(base_worker, mock_router):
|
||||
# BaseWorker calls register_routes in its init via the fixture
|
||||
assert mock_router.add_route.call_count == len(BaseWorker.default_components)
|
||||
|
||||
calls = mock_router.add_route.call_args_list
|
||||
expected_paths = ["tools", "tools/invoke", "health"]
|
||||
registered_paths = [
|
||||
call[0][0] for call in calls
|
||||
] # call[0] are positional args, call[0][0] is endpoint_path
|
||||
|
||||
assert sorted(registered_paths) == sorted(expected_paths)
|
||||
|
||||
# Check if components were instantiated and passed to add_route
|
||||
assert any(isinstance(call[0][1], CatalogComponent) for call in calls)
|
||||
assert any(isinstance(call[0][1], CallToolComponent) for call in calls)
|
||||
assert any(isinstance(call[0][1], HealthCheckComponent) for call in calls)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_catalog_component_call(base_worker_no_auth):
|
||||
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
|
||||
component = CatalogComponent(base_worker_no_auth)
|
||||
# Mock request data - not actually used by this component's __call__
|
||||
mock_request = MagicMock(spec=RequestData)
|
||||
catalog_response = await component(mock_request)
|
||||
|
||||
assert isinstance(catalog_response, list)
|
||||
assert len(catalog_response) == 1
|
||||
assert catalog_response[0].name == "SampleTool"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_component_call(base_worker_no_auth):
|
||||
base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit")
|
||||
component = CallToolComponent(base_worker_no_auth)
|
||||
|
||||
# Create ToolReference WITHOUT version
|
||||
tool_ref = ToolReference(toolkit="TestKit", name="SampleTool")
|
||||
request_body = {
|
||||
"execution_id": "comp_test_exec",
|
||||
"tool": tool_ref.model_dump(),
|
||||
"inputs": {"a": 10, "b": 5},
|
||||
}
|
||||
mock_request = MagicMock(spec=RequestData)
|
||||
mock_request.body_json = request_body
|
||||
|
||||
response = await component(mock_request)
|
||||
|
||||
assert isinstance(response, ToolCallResponse)
|
||||
assert response.success is True
|
||||
assert response.output.value == 15
|
||||
assert response.execution_id == "comp_test_exec"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_component_call(base_worker_no_auth):
|
||||
component = HealthCheckComponent(base_worker_no_auth)
|
||||
mock_request = MagicMock(spec=RequestData)
|
||||
health_response = await component(mock_request)
|
||||
|
||||
assert health_response == {"status": "ok", "tool_count": "0"}
|
||||
167
arcade/tests/worker/test_worker_fastapi.py
Normal file
167
arcade/tests/worker/test_worker_fastapi.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from arcade.core.schema import ToolCallRequest, ToolContext, ToolReference
|
||||
from arcade.sdk import tool
|
||||
from arcade.worker.fastapi.worker import FastAPIWorker
|
||||
|
||||
|
||||
@tool()
|
||||
def sample_tool_fastapi(
|
||||
context: ToolContext, x: Annotated[int, "x"], y: Annotated[str, "y"]
|
||||
) -> Annotated[str, "output"]:
|
||||
"""A sample tool for FastAPI tests."""
|
||||
return f"{y}-{x}"
|
||||
|
||||
|
||||
# Define tool at module level to avoid indentation issues with getsource
|
||||
@tool()
|
||||
def error_throwing_tool(
|
||||
context: ToolContext,
|
||||
a: Annotated[int, "a", "Input integer a"], # Added description for parameter
|
||||
) -> int:
|
||||
"""This tool throws a ValueError.""" # Added description for tool
|
||||
raise ValueError("Test execution error")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app():
|
||||
return FastAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def worker_secret():
|
||||
return "test-secret-fastapi"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fastapi_worker(test_app, worker_secret):
|
||||
worker = FastAPIWorker(app=test_app, secret=worker_secret)
|
||||
worker.register_tool(sample_tool_fastapi, toolkit_name="fastapi_kit")
|
||||
return worker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fastapi_worker_no_auth(test_app):
|
||||
worker = FastAPIWorker(app=test_app, disable_auth=True)
|
||||
worker.register_tool(sample_tool_fastapi, toolkit_name="fastapi_kit")
|
||||
return worker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(test_app, fastapi_worker): # Use the worker fixture to ensure routes are registered
|
||||
return TestClient(test_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client_no_auth(test_app, fastapi_worker_no_auth):
|
||||
return TestClient(test_app)
|
||||
|
||||
|
||||
# --- FastAPIWorker Tests ---
|
||||
|
||||
|
||||
def test_fastapi_worker_registers_routes(client, fastapi_worker):
|
||||
# Check if routes exist by trying to access them (even if auth fails)
|
||||
response = client.get(f"{fastapi_worker.base_path}/health")
|
||||
assert response.status_code != 404 # Should be 200
|
||||
|
||||
response = client.get(f"{fastapi_worker.base_path}/tools")
|
||||
assert response.status_code != 404 # Should be 403 without auth
|
||||
|
||||
# Prepare a dummy request body for invoke
|
||||
tool_ref = ToolReference(toolkit="FastapiKit", name="SampleToolFastapi")
|
||||
request_body = ToolCallRequest(
|
||||
execution_id="test", tool=tool_ref, inputs={"x": 1, "y": "test"}
|
||||
).model_dump()
|
||||
|
||||
response = client.post(f"{fastapi_worker.base_path}/tools/invoke", json=request_body)
|
||||
assert response.status_code != 404 # Should be 403 without auth
|
||||
|
||||
|
||||
# --- Route Tests (using TestClient) ---
|
||||
|
||||
|
||||
# Health Check
|
||||
def test_health_check_route(client, worker_secret):
|
||||
response = client.get("/worker/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok", "tool_count": "1"}
|
||||
|
||||
|
||||
def test_health_check_route_no_auth(client_no_auth):
|
||||
response = client_no_auth.get("/worker/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok", "tool_count": "1"}
|
||||
|
||||
|
||||
# Catalog
|
||||
def test_get_catalog_route_no_auth_header(client):
|
||||
response = client.get("/worker/tools")
|
||||
assert response.status_code == 403
|
||||
assert "Not authenticated" in response.text
|
||||
|
||||
|
||||
def test_get_catalog_route_invalid_auth_header(client, worker_secret):
|
||||
response = client.get("/worker/tools", headers={"Authorization": "Bearer invalid-token"})
|
||||
assert response.status_code == 401 # Unauthorized
|
||||
# Updated expected error message based on last run
|
||||
assert "Invalid token. Error: Not enough segments" in response.text
|
||||
|
||||
|
||||
def test_get_catalog_route_no_auth_worker(client_no_auth):
|
||||
response = client_no_auth.get("/worker/tools")
|
||||
assert response.status_code == 200
|
||||
catalog = response.json()
|
||||
assert isinstance(catalog, list)
|
||||
assert len(catalog) == 1
|
||||
assert catalog[0]["name"] == "SampleToolFastapi"
|
||||
|
||||
|
||||
# Call Tool
|
||||
@pytest.fixture
|
||||
def call_tool_payload():
|
||||
tool_ref = ToolReference(toolkit="FastapiKit", name="SampleToolFastapi")
|
||||
return ToolCallRequest(
|
||||
execution_id="fastapi-test-exec", tool=tool_ref, inputs={"x": 123, "y": "hello"}
|
||||
).model_dump()
|
||||
|
||||
|
||||
def test_call_tool_route_no_auth_header(client, call_tool_payload):
|
||||
response = client.post("/worker/tools/invoke", json=call_tool_payload)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_call_tool_route_invalid_auth_header(client, worker_secret, call_tool_payload):
|
||||
response = client.post(
|
||||
"/worker/tools/invoke",
|
||||
json=call_tool_payload,
|
||||
headers={"Authorization": "Bearer invalid-token"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_call_tool_route_no_auth_worker(client_no_auth, call_tool_payload):
|
||||
response = client_no_auth.post("/worker/tools/invoke", json=call_tool_payload)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert result["success"] is True
|
||||
assert result["output"]["value"] == "hello-123"
|
||||
|
||||
|
||||
def test_call_tool_route_tool_not_found(client_no_auth, call_tool_payload):
|
||||
call_tool_payload["tool"]["name"] = "NonExistentTool"
|
||||
call_tool_payload["tool"]["toolkit"] = "FastapiKit"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = client_no_auth.post(
|
||||
"/worker/tools/invoke",
|
||||
json=call_tool_payload,
|
||||
)
|
||||
# The handler catches the ValueError and returns a 500 internal server error
|
||||
# Ideally, this might be a 404 or 400, but BaseWorker.call_tool raises ValueError
|
||||
# which isn't automatically mapped to a 4xx by FastAPI unless handled explicitly.
|
||||
# TODO fix this.
|
||||
8
examples/mcp/claude.json
Normal file
8
examples/mcp/claude.json
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"mcpServers": {
|
||||
"arcade": {
|
||||
"command": "bash",
|
||||
"args": ["-c", "export ARCADE_API_KEY=arc_xxxx && /path/to/python /path/to/arcade mcp"]
|
||||
}
|
||||
}
|
||||
}
|
||||
25
examples/mcp/run_stdio.py
Normal file
25
examples/mcp/run_stdio.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import arcade_google # pip install arcade_google
|
||||
import arcade_search # pip install arcade_search
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.worker.mcp.stdio import StdioServer
|
||||
|
||||
# 2. Create and populate the tool catalog
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_module(arcade_google) # Registers all tools in the package
|
||||
catalog.add_module(arcade_search)
|
||||
|
||||
|
||||
# 3. Main entrypoint
|
||||
async def main():
|
||||
# Create the worker with the tool catalog
|
||||
worker = StdioServer(catalog)
|
||||
|
||||
# Run the worker
|
||||
await worker.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
"""
|
||||
Example script demonstrating how to build a simple chatbot with Arcade.
|
||||
|
||||
For this example, we are using the prebuilt Google Docs toolkit to create and edit documents.
|
||||
|
||||
Try asking questions like:
|
||||
- "Create a document with the title 'My New Document' and content 'Hello, World!'"
|
||||
- "List my 2 most recently modified documents and tell me the title, document id, and document URL of each one and summarize them."
|
||||
- "Edit the second document from the list you just returned and add the text 'Hello, World!' to the end of it."
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
def chat(openai_client: OpenAI, tool_names: list[str], user_id: str) -> None:
|
||||
history = []
|
||||
|
||||
print("Hello! How can I help you today?")
|
||||
while True:
|
||||
message = {"role": "user", "content": input(">")}
|
||||
history.append(message)
|
||||
chat_result = call_tool_with_openai(openai_client, tool_names, user_id, history)
|
||||
|
||||
# If the tool call requires authorization, then wait for the user to authorize and then call the tool again
|
||||
if (
|
||||
chat_result.choices[0].tool_authorizations
|
||||
and chat_result.choices[0].tool_authorizations[0].get("status") == "pending"
|
||||
):
|
||||
print("\n" + chat_result.choices[0].message.content)
|
||||
input("\nAfter you have authorized, press Enter to continue...")
|
||||
chat_result = call_tool_with_openai(openai_client, tool_names, user_id, history)
|
||||
|
||||
history.append({"role": "assistant", "content": chat_result.choices[0].message.content})
|
||||
|
||||
print(chat_result.choices[0].message.content)
|
||||
|
||||
|
||||
def call_tool_with_openai(
|
||||
client: OpenAI, tool_names: list[str], user_id: str, messages: list[dict]
|
||||
) -> dict:
|
||||
response = client.chat.completions.create(
|
||||
messages=messages,
|
||||
model="gpt-4o-mini",
|
||||
user=user_id,
|
||||
tools=tool_names,
|
||||
tool_choice="generate",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
arcade_api_key = os.environ.get(
|
||||
"ARCADE_API_KEY"
|
||||
) # If you forget your Arcade API key, it is stored at ~/.arcade/credentials.yaml on `arcade login`
|
||||
cloud_host = "https://api.arcade.dev/v1"
|
||||
user_id = "user@example.com"
|
||||
|
||||
openai_client = OpenAI(
|
||||
api_key=arcade_api_key,
|
||||
base_url=cloud_host,
|
||||
)
|
||||
|
||||
tool_names = [
|
||||
"Google.SendEmail",
|
||||
"Google.SendDraftEmail",
|
||||
"Google.WriteDraftEmail",
|
||||
"Google.UpdateDraftEmail",
|
||||
"Google.ListDraftEmails",
|
||||
"Google.ListEmailsByHeader",
|
||||
"Google.ListEmails",
|
||||
]
|
||||
|
||||
chat(openai_client, tool_names, user_id)
|
||||
Loading…
Reference in a new issue