diff --git a/.vscode/settings.json b/.vscode/settings.json index 791aba76..4e27f70a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,7 +10,6 @@ "[python]": { "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.fixAll": "explicit", "source.organizeImports": "explicit" }, "editor.defaultFormatter": "charliermarsh.ruff" diff --git a/arcade/arcade/worker/config/deployment.py b/arcade/arcade/cli/deployment.py similarity index 100% rename from arcade/arcade/worker/config/deployment.py rename to arcade/arcade/cli/deployment.py diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index 3375385b..65a8906a 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -24,6 +24,7 @@ from arcade.cli.constants import ( PROD_CLOUD_HOST, PROD_ENGINE_HOST, ) +from arcade.cli.deployment import Deployment from arcade.cli.display import ( display_arcade_chat_header, display_eval_results, @@ -48,7 +49,6 @@ from arcade.cli.utils import ( version_callback, ) from arcade.cli.worker import parse_deployment_response -from arcade.worker.config.deployment import Deployment cli = typer.Typer( cls=OrderCommands, @@ -61,7 +61,12 @@ cli = typer.Typer( ) -cli.add_typer(worker.app, name="worker", help="Manage workers") +cli.add_typer( + worker.app, + name="worker", + help="Manage deployments of tool servers (logs, list, etc)", + rich_help_panel="Deployment", +) console = Console() @@ -192,7 +197,10 @@ def show( show_logic(toolkit, tool, host, local, port, force_tls, force_no_tls, debug) -@cli.command(help="Start Arcade Chat in the terminal", rich_help_panel="Launch") +@cli.command( + help="Start a chat with a model in the terminal to test tools", + rich_help_panel="Tool Development", +) def chat( model: str = typer.Option("gpt-4o", "-m", "--model", help="The model to use for prediction."), stream: bool = typer.Option( @@ -439,7 +447,7 @@ def evals( @cli.command( - help="Start Arcade Worker serving tools installed in the current python environment", + help="Start tool server worker with locally installed tools", rich_help_panel="Launch", ) def serve( @@ -460,19 +468,38 @@ def serve( otel_enable: bool = typer.Option( False, "--otel-enable", help="Send logs to OpenTelemetry", show_default=True ), + mcp: bool = typer.Option( + False, "--mcp", help="Run as a local MCP server over stdio", show_default=True + ), debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"), ) -> None: """ Start a local Arcade Worker server. """ - workerup(host, port, disable_auth, otel_enable, debug) + from arcade.cli.serve import serve_default_worker + + try: + serve_default_worker( + host, + port, + disable_auth=disable_auth, + enable_otel=otel_enable, + debug=debug, + mcp=mcp, + ) + except KeyboardInterrupt: + typer.Exit() + except Exception as e: + error_message = f"❌ Failed to start Arcade Worker: {escape(str(e))}" + console.print(error_message, style="bold red") + typer.Exit(code=1) -@cli.command(help="Launch Arcade Worker and Engine locally", rich_help_panel="Launch") +@cli.command(help="Launch Arcade - requires 'arcade-engine'", rich_help_panel="Launch") def dev( - host: str = typer.Option("127.0.0.1", help="Host for the worker server.", show_default=True), + host: str = typer.Option("127.0.0.1", help="Host for the toolkit server.", show_default=True), port: int = typer.Option( - 8002, "-p", "--port", help="Port for the worker server.", show_default=True + 8002, "-p", "--port", help="Port for the toolkit server.", show_default=True ), engine_config: str = typer.Option( None, "-c", "--config", help="Path to the engine configuration file." @@ -483,7 +510,7 @@ def dev( debug: bool = typer.Option(False, "-d", "--debug", help="Show debug information"), ) -> None: """ - Start both the worker and engine servers. + Start both the toolkit server and engine servers. """ try: start_servers(host, port, engine_config, engine_env=env_file, debug=debug) @@ -493,8 +520,9 @@ def dev( typer.Exit(code=1) -# TODO: deprecate this next major version -@cli.command(help="Start a local Arcade Worker server", rich_help_panel="Launch", hidden=True) +@cli.command( + help="Start a server with locally installed Arcade tools", rich_help_panel="Launch", hidden=True +) def workerup( host: str = typer.Option( "127.0.0.1", @@ -523,17 +551,21 @@ def workerup( try: serve_default_worker( - host, port, disable_auth=disable_auth, enable_otel=otel_enable, debug=debug + host, + port, + disable_auth=disable_auth, + enable_otel=otel_enable, + debug=debug, ) except KeyboardInterrupt: typer.Exit() except Exception as e: - error_message = f"❌ Failed to start Arcade Worker: {escape(str(e))}" + error_message = f"❌ Failed to start Arcade Toolkit Server: {escape(str(e))}" console.print(error_message, style="bold red") typer.Exit(code=1) -@cli.command(help="Deploy worker to Arcade Cloud", rich_help_panel="Deployment") +@cli.command(help="Deploy toolkits to Arcade Cloud", rich_help_panel="Deployment") def deploy( deployment_file: str = typer.Option( "worker.toml", "--deployment-file", "-d", help="The deployment file to deploy." diff --git a/arcade/arcade/cli/serve.py b/arcade/arcade/cli/serve.py index 8100a4a9..4c378b2b 100644 --- a/arcade/arcade/cli/serve.py +++ b/arcade/arcade/cli/serve.py @@ -1,70 +1,209 @@ import asyncio import logging import os +import signal import sys +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from importlib.metadata import version as get_pkg_version +from pathlib import Path from typing import Any import fastapi import uvicorn from loguru import logger +from rich.console import Console +from arcade.cli.constants import ARCADE_CONFIG_PATH +from arcade.cli.utils import ( + build_tool_catalog, + discover_toolkits, + load_dotenv, + validate_and_get_config, +) from arcade.core.telemetry import OTELHandler from arcade.sdk import Toolkit from arcade.worker.fastapi.worker import FastAPIWorker +console = Console(width=70, color_system="auto") -class InterceptHandler(logging.Handler): + +def _run_mcp_stdio( + toolkits: list[Toolkit], *, logging_enabled: bool, env_file: str | None = None +) -> None: + """Launch an MCP stdio server; blocks until it exits.""" + + from arcade.worker.mcp.stdio import StdioServer + + # Load env vars before launching server (explicit path, config path, cwd) + if env_file: + load_dotenv(env_file, override=False) + else: + for candidate in [Path(ARCADE_CONFIG_PATH) / "arcade.env", Path.cwd() / "arcade.env"]: + if candidate.is_file(): + load_dotenv(candidate, override=False) + break + + # Set up middleware configuration for stdio mode + middleware_config = { + "stdio_mode": True, # Ensure logs go to stderr + } + + catalog = build_tool_catalog(toolkits) + server = StdioServer( + catalog, + enable_logging=logging_enabled, + middleware_config=middleware_config, + ) + + try: + asyncio.run(server.run()) + except KeyboardInterrupt: + logger.info("MCP server stopped by user.") + except Exception as exc: + logger.exception("Error while running MCP server: %s", exc) + raise + + +def _run_fastapi_server( + app: fastapi.FastAPI, + *, + host: str, + port: int, + workers: int, + timeout_keep_alive: int, + enable_otel: bool, + otel_handler: OTELHandler, + **uvicorn_kwargs: Any, +) -> None: + """Run a FastAPI application via Uvicorn with graceful shutdown.""" + + class CustomUvicornServer(uvicorn.Server): + def install_signal_handlers(self) -> None: + # Disable Uvicorn's default signal handling; we manage it manually + pass + + async def shutdown(self, sockets: Any = None) -> None: + logger.info("Initiating graceful shutdown...") + await super().shutdown(sockets=sockets) + + config = uvicorn.Config( + app=app, + host=host, + port=port, + workers=workers, + timeout_keep_alive=timeout_keep_alive, + log_config=None, + **uvicorn_kwargs, + ) + + server = CustomUvicornServer(config=config) + + async def _serve() -> None: + await server.serve() + + async def _graceful_shutdown() -> None: + try: + logger.info("Shutting down server ...") + await server.shutdown() + + # brief pause for connections to close gracefully + await asyncio.sleep(0.5) + finally: + if enable_otel: + otel_handler.shutdown() + logger.debug("Server shutdown complete.") + + # Map signals to our graceful shutdown + loop = asyncio.get_event_loop() + for sig_name in ( + "SIGINT", + "SIGTERM", + "SIGHUP", + "SIGUSR1", + "SIGUSR2", + "SIGWINCH", + "SIGBREAK", + ): + if hasattr(signal, sig_name): + loop.add_signal_handler( + getattr(signal, sig_name), lambda: asyncio.create_task(_graceful_shutdown()) + ) + + try: + asyncio.run(_serve()) + except KeyboardInterrupt: + logger.info("Server stopped by user.") + finally: + if enable_otel: + otel_handler.shutdown() + + +class RichInterceptHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: - # Get corresponding Loguru level if it exists try: level = logger.level(record.levelname).name except ValueError: - level = record.levelno # type: ignore[assignment] + level = str(record.levelno) - # Find caller from where originated the logged message - frame, depth = sys._getframe(6), 6 - while frame and frame.f_code.co_filename == logging.__file__: - frame = frame.f_back # type: ignore[assignment] - depth += 1 - - logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) + # Let Loguru handle caller info; don't do stack inspection here + logger.opt(exception=record.exc_info).log(level, record.getMessage()) -def setup_logging(log_level: int = logging.INFO) -> None: +def setup_logging(log_level: int = logging.INFO, mcp_mode: bool = False) -> None: # Intercept everything at the root logger - logging.root.handlers = [InterceptHandler()] + logging.root.handlers = [RichInterceptHandler()] logging.root.setLevel(log_level) - # Remove every other logger's handlers - # and propagate to root logger + # Remove every other logger's handlers and propagate to root logger for name in logging.root.manager.loggerDict: + # Keep handlers for MCP logger if middleware handles it separately + if mcp_mode and name == "arcade.mcp": + continue logging.getLogger(name).handlers = [] logging.getLogger(name).propagate = True - # Configure loguru with custom format, no colors + # Remove default handlers from Loguru + logger.remove() + + # Configure main Loguru sink + # In MCP mode, all general console logs go to stderr to keep stdout clean + sink_destination = sys.stderr if mcp_mode else sys.stdout + + # Configure loguru with a cleaner format and colors + if log_level == logging.DEBUG: + format_string = "{level} | {time:HH:mm:ss} | {name}:{file}:{line: <4} | {message}" + else: + format_string = ( + "{level} | {time:HH:mm:ss} | {message}" + ) logger.configure( handlers=[ { - "sink": sys.stdout, - "serialize": False, + "sink": sink_destination, # Redirect sink based on mcp_mode + "colorize": True, "level": log_level, - "format": "{level} [{time:HH:mm:ss.SSS}] {message}" - + (" {name}:{function}:{line}" if log_level <= logging.DEBUG else "") - + ("\n{exception}" if "{exception}" in "{message}" else ""), + # Format that ensures timestamp on every line and better alignment + "format": format_string, + # Make sure multiline messages are handled properly + "enqueue": True, + "diagnose": True, # Disable traceback framing which adds noise } ] ) + if mcp_mode: + logger.debug("Loguru sink configured for stderr in MCP mode.") @asynccontextmanager -async def lifespan(app: fastapi.FastAPI): # type: ignore[no-untyped-def] +async def lifespan(app: fastapi.FastAPI) -> AsyncGenerator[None, None]: try: yield - except asyncio.CancelledError: + except (asyncio.CancelledError, KeyboardInterrupt): # This is necessary to prevent an unhandled error # when the user presses Ctrl+C logger.debug("Lifespan cancelled.") + raise def serve_default_worker( @@ -75,76 +214,78 @@ def serve_default_worker( timeout_keep_alive: int = 5, enable_otel: bool = False, debug: bool = False, + mcp: bool = False, **kwargs: Any, ) -> None: """ - Get an instance of a FastAPI server with the Arcade Worker. + Get a default instance of a FastAPI server with the Arcade Worker + serving tools installed in the current Python environment. + + Args: + host: The host to run the server on. + port: The port to run the server on. + disable_auth: Whether to disable authentication. + workers: The number of workers to run. + timeout_keep_alive: The timeout for keep-alive connections. + enable_otel: Whether to enable OpenTelemetry. + debug: Whether to enable debug logging. + mcp: Whether to run worker as MCP server over stdio. """ - # Setup unified logging - setup_logging(log_level=logging.DEBUG if debug else logging.INFO) - toolkits = Toolkit.find_all_arcade_toolkits() - if not toolkits: - raise RuntimeError("No toolkits found in Python environment.") + # Setup unified logging first + version = get_pkg_version("arcade-ai") + validate_and_get_config() + setup_logging(log_level=logging.DEBUG if debug else logging.INFO, mcp_mode=mcp) - worker_secret = os.environ.get("ARCADE_WORKER_SECRET") - if not disable_auth and not worker_secret: - logger.warning( - "Warning: ARCADE_WORKER_SECRET environment variable is not set. Using 'dev' as the worker secret.", - ) - worker_secret = worker_secret or "dev" + toolkits = discover_toolkits() + logger.info("Serving the following toolkits:") + for toolkit in toolkits: + if debug: + for name, tools in toolkit.tools.items(): + for tool in tools: + logger.info(f" - {name}: {tool}") + else: + logger.info(f" - {toolkit.name}: {len(toolkit.tools)} tools") + # --- MCP stdio -------------------------------------------------- + if mcp: + env_file = kwargs.pop("env_file", None) + _run_mcp_stdio(toolkits, logging_enabled=not debug, env_file=env_file) + return + + # --- FastAPI HTTP -------------------------------------------------- app = fastapi.FastAPI( title="Arcade Worker", - description="Arcade default Worker implementation using FastAPI.", - version="0.1.0", - lifespan=lifespan, # Use custom lifespan to catch errors, notably KeyboardInterrupt (Ctrl+C) + description="A worker for the Arcade platform", + version=version, + docs_url="/docs" if debug else None, + redoc_url="/redoc" if debug else None, + openapi_url="/openapi.json" if debug else None, + lifespan=lifespan, ) - otel_handler = OTELHandler(app, enable=enable_otel) + secret = os.getenv("ARCADE_WORKER_SECRET", None) + if secret is None: + logger.warning("No secret found for Arcade Worker") + logger.info("Setting secret to 'dev'. Set this in production") + secret = "dev" # noqa: S105 - worker = FastAPIWorker( - app, secret=worker_secret, disable_auth=disable_auth, otel_meter=otel_handler.get_meter() + otel_handler = OTELHandler( + app, enable=enable_otel, log_level=logging.DEBUG if debug else logging.INFO ) - - toolkit_tool_counts = {} - for toolkit in toolkits: - prev_tool_count = worker.catalog.get_tool_count() - worker.register_toolkit(toolkit) - new_tool_count = worker.catalog.get_tool_count() - toolkit_tool_counts[f"{toolkit.name} ({toolkit.package_name})"] = ( - new_tool_count - prev_tool_count - ) - - logger.info("Serving the following toolkits:") - for name, tool_count in toolkit_tool_counts.items(): - logger.info(f" - {name}: {tool_count} tools") - - logger.info("Starting FastAPI server...") - - class CustomUvicornServer(uvicorn.Server): - def install_signal_handlers(self) -> None: - pass # Disable Uvicorn's default signal handlers - - config = uvicorn.Config( + _ = FastAPIWorker( app=app, + secret=secret, + disable_auth=disable_auth, + otel_meter=otel_handler.get_meter(), + ) + _run_fastapi_server( + app, host=host, port=port, workers=workers, timeout_keep_alive=timeout_keep_alive, - log_config=None, + enable_otel=enable_otel, + otel_handler=otel_handler, **kwargs, ) - server = CustomUvicornServer(config=config) - - async def serve() -> None: - await server.serve() - - try: - asyncio.run(serve()) - except KeyboardInterrupt: - logger.info("Server stopped by user.") - finally: - if enable_otel: - otel_handler.shutdown() - logger.debug("Server shutdown complete.") diff --git a/arcade/arcade/cli/utils.py b/arcade/arcade/cli/utils.py index 524d220a..44dbdb41 100644 --- a/arcade/arcade/cli/utils.py +++ b/arcade/arcade/cli/utils.py @@ -1,5 +1,7 @@ import importlib.util import ipaddress +import os +import shlex import webbrowser from dataclasses import dataclass from datetime import datetime @@ -34,6 +36,11 @@ from arcade.sdk import ToolCatalog, Toolkit console = Console() +# ----------------------------------------------------------------------------- +# Shared helpers for the CLI +# ----------------------------------------------------------------------------- + + class OrderCommands(TyperGroup): def list_commands(self, ctx: Context) -> list[str]: # type: ignore[override] """Return list of commands in the order appear.""" @@ -666,3 +673,81 @@ def get_today_context() -> str: today = datetime.now().strftime("%Y-%m-%d") day_of_week = datetime.now().strftime("%A") return f"Today is {today}, {day_of_week}." + + +def discover_toolkits() -> list[Toolkit]: + """Return all Arcade toolkits installed in the active Python environment. + + Raises: + RuntimeError: If no toolkits are found, mirroring the behaviour of Toolkit discovery elsewhere. + """ + toolkits = Toolkit.find_all_arcade_toolkits() + if not toolkits: + raise RuntimeError("No toolkits found in Python environment.") + return toolkits + + +def build_tool_catalog(toolkits: list[Toolkit]) -> ToolCatalog: + """Construct a ``ToolCatalog`` populated with *toolkits*. + + + Args: + toolkits: Toolkits to register in the catalog. + + Returns: + ToolCatalog + """ + catalog = ToolCatalog() + for tk in toolkits: + catalog.add_toolkit(tk) + return catalog + + +def _parse_line(line: str) -> tuple[str, str] | None: + """ + Return (key, value) if the line looks like KEY=VALUE, else None. + Handles quotes and escaped chars via shlex. + """ + if not line or line.startswith("#") or "=" not in line: + return None + key, raw_val = line.split("=", 1) + key = key.strip() + raw_val = raw_val.strip() + + # Use shlex to handle "quoted strings with # hash" etc. + try: + value = shlex.split(raw_val)[0] if raw_val else "" + except ValueError: + # Fallback: naked value without shlex parsing + value = raw_val + + return key, value + + +def load_dotenv(path: str | Path, *, override: bool = False) -> dict[str, str]: + """ + Load variables from *path* into os.environ. + + Args: + path: .env file path + override: replace existing env vars if True + + Returns: + The mapping of vars that were added/updated. + """ + path = Path(path).expanduser() + if not path.is_file(): + return {} + + loaded: dict[str, str] = {} + + for raw in path.read_text().splitlines(): + parsed = _parse_line(raw.strip()) + if parsed is None: + continue + k, v = parsed + if override or k not in os.environ: + os.environ[k] = v + loaded[k] = v + + return loaded diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index af21b3ba..5a27ee0e 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -326,8 +326,16 @@ class ToolContext(BaseModel): for item in items: if item.key.lower() == normalized_key: return item.value + raise ValueError(f"{item_name.capitalize()} {key} not found in context.") + def set_secret(self, key: str, value: str) -> None: + """Set a secret for the tool invocation.""" + if not self.secrets: + self.secrets = [] + secret = ToolSecretItem(key=str(key), value=str(value)) + self.secrets.append(secret) + class ToolCallRequest(BaseModel): """The request to call (invoke) a tool.""" diff --git a/arcade/arcade/core/utils.py b/arcade/arcade/core/utils.py index a199f93d..d9c58696 100644 --- a/arcade/arcade/core/utils.py +++ b/arcade/arcade/core/utils.py @@ -1,14 +1,16 @@ +from __future__ import annotations + import ast import inspect import re from collections.abc import Iterable from types import UnionType -from typing import Any, Callable, Literal, Optional, TypeVar, Union, get_args, get_origin +from typing import Any, Callable, Literal, TypeVar, Union, get_args, get_origin T = TypeVar("T") -def first_or_none(_type: type[T], iterable: Iterable[Any]) -> Optional[T]: +def first_or_none(_type: type[T], iterable: Iterable[Any]) -> T | None: """ Returns the first item in the iterable that is an instance of the given type, or None if no such item is found. """ @@ -65,7 +67,7 @@ def does_function_return_value(func: Callable) -> bool: Returns True if the given function returns a value, i.e. if it has a return statement with a value. """ try: - source: Optional[str] = inspect.getsource(func) + source: str | None = inspect.getsource(func) except OSError: # Workaround for parameterized unit tests that use a dynamically-generated function source = getattr(func, "__source__", None) diff --git a/arcade/arcade/worker/core/base.py b/arcade/arcade/worker/core/base.py index 5c307350..4416c24b 100644 --- a/arcade/arcade/worker/core/base.py +++ b/arcade/arcade/worker/core/base.py @@ -175,7 +175,7 @@ class BaseWorker(Worker): """ Provide a health check that serves as a heartbeat of worker health. """ - return {"status": "ok", "tool_count": len(self.catalog)} + return {"status": "ok", "tool_count": str(len(self.catalog))} def register_routes(self, router: Router) -> None: """ diff --git a/arcade/arcade/worker/core/common.py b/arcade/arcade/worker/core/common.py index 328b1d70..85a73b48 100644 --- a/arcade/arcade/worker/core/common.py +++ b/arcade/arcade/worker/core/common.py @@ -5,6 +5,11 @@ from pydantic import BaseModel from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition +CatalogResponse = list[ToolDefinition] +HealthCheckResponse = dict[str, str] +JSONResponse = dict[str, Any] +ResponseData = CatalogResponse | ToolCallResponse | HealthCheckResponse + class RequestData(BaseModel): """ @@ -17,7 +22,7 @@ class RequestData(BaseModel): """The path of the request.""" method: str """The method of the request.""" - body_json: dict | None = None + body_json: JSONResponse | None = None """The deserialized body of the request (e.g. JSON)""" @@ -28,7 +33,13 @@ class Router(ABC): @abstractmethod def add_route( - self, endpoint_path: str, handler: Callable, method: str, require_auth: bool = True + self, + endpoint_path: str, + handler: Callable, + method: str, + require_auth: bool = True, + response_type: type[ResponseData] | None = None, + **kwargs: Any, ) -> None: """ Add a route to the router. @@ -43,7 +54,7 @@ class Worker(ABC): """ @abstractmethod - def get_catalog(self) -> list[ToolDefinition]: + def get_catalog(self) -> CatalogResponse: """ Get the catalog of tools available in the worker. """ @@ -57,7 +68,7 @@ class Worker(ABC): pass @abstractmethod - def health_check(self) -> dict[str, Any]: + def health_check(self) -> HealthCheckResponse: """ Perform a health check of the worker """ @@ -76,7 +87,7 @@ class WorkerComponent(ABC): pass @abstractmethod - async def __call__(self, request: RequestData) -> Any: + async def __call__(self, request: RequestData) -> ResponseData: """ Handle the request. """ diff --git a/arcade/arcade/worker/core/components.py b/arcade/arcade/worker/core/components.py index a9590f9e..490954a6 100644 --- a/arcade/arcade/worker/core/components.py +++ b/arcade/arcade/worker/core/components.py @@ -1,9 +1,15 @@ -from typing import Any - from opentelemetry import trace -from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition -from arcade.worker.core.common import RequestData, Router, Worker, WorkerComponent +from arcade.worker.core.common import ( + CatalogResponse, + HealthCheckResponse, + RequestData, + Router, + ToolCallRequest, + ToolCallResponse, + Worker, + WorkerComponent, +) class CatalogComponent(WorkerComponent): @@ -14,9 +20,18 @@ class CatalogComponent(WorkerComponent): """ Register the catalog route with the router. """ - router.add_route("tools", self, method="GET") + router.add_route( + "tools", + self, + method="GET", + response_type=CatalogResponse, + operation_id="get_catalog", + description="Get the catalog of tools", + summary="Get the catalog of tools", + tags=["Arcade"], + ) - async def __call__(self, request: RequestData) -> list[ToolDefinition]: + async def __call__(self, request: RequestData) -> CatalogResponse: """ Handle the request to get the catalog. """ @@ -33,7 +48,16 @@ class CallToolComponent(WorkerComponent): """ Register the call tool route with the router. """ - router.add_route("tools/invoke", self, method="POST") + router.add_route( + "tools/invoke", + self, + method="POST", + response_type=ToolCallResponse, + operation_id="call_tool", + description="Call a tool", + summary="Call a tool", + tags=["Arcade"], + ) async def __call__(self, request: RequestData) -> ToolCallResponse: """ @@ -54,9 +78,19 @@ class HealthCheckComponent(WorkerComponent): """ Register the health check route with the router. """ - router.add_route("health", self, method="GET", require_auth=False) + router.add_route( + "health", + self, + method="GET", + require_auth=False, + response_type=HealthCheckResponse, + operation_id="health_check", + description="Check the health of the worker", + summary="Check the health of the worker", + tags=["Arcade"], + ) - async def __call__(self, request: RequestData) -> dict[str, Any]: + async def __call__(self, request: RequestData) -> HealthCheckResponse: """ Handle the request for a health check. """ diff --git a/arcade/arcade/worker/fastapi/__init__.py b/arcade/arcade/worker/fastapi/__init__.py index e69de29b..f3c85996 100644 --- a/arcade/arcade/worker/fastapi/__init__.py +++ b/arcade/arcade/worker/fastapi/__init__.py @@ -0,0 +1,3 @@ +from .worker import FastAPIWorker + +__all__ = ["FastAPIWorker"] diff --git a/arcade/arcade/worker/fastapi/worker.py b/arcade/arcade/worker/fastapi/worker.py index 17e92cf3..c455941e 100644 --- a/arcade/arcade/worker/fastapi/worker.py +++ b/arcade/arcade/worker/fastapi/worker.py @@ -9,7 +9,7 @@ from arcade.worker.core.base import ( BaseWorker, Router, ) -from arcade.worker.core.common import RequestData +from arcade.worker.core.common import RequestData, ResponseData, WorkerComponent from arcade.worker.fastapi.auth import validate_engine_request from arcade.worker.utils import is_async_callable @@ -30,12 +30,21 @@ class FastAPIWorker(BaseWorker): """ Initialize the FastAPIWorker with a FastAPI app instance. If no secret is provided, the worker will use the ARCADE_WORKER_SECRET environment variable. + + Args: + app: The FastAPI app to host the worker in + secret: Optional secret for authorization + disable_auth: Whether to disable authorization + otel_meter: Optional OpenTelemetry meter """ super().__init__(secret, disable_auth, otel_meter) self.app = app self.router = FastAPIRouter(app, self) self.register_routes(self.router) + # Initialize components + self.components: list[WorkerComponent] = [] + security = HTTPBearer() # Authorization: Bearer @@ -81,7 +90,13 @@ class FastAPIRouter(Router): return wrapped_handler def add_route( - self, endpoint_path: str, handler: Callable, method: str, require_auth: bool = True + self, + endpoint_path: str, + handler: Callable, + method: str, + require_auth: bool = True, + response_type: type[ResponseData] | None = None, + **kwargs: Any, ) -> None: """ Add a route to the FastAPI application. @@ -90,4 +105,7 @@ class FastAPIRouter(Router): f"{self.worker.base_path}/{endpoint_path}", self._wrap_handler(handler, require_auth), methods=[method], + response_model=response_type, + # **kwargs to pass to FastAPI + **kwargs, ) diff --git a/arcade/arcade/worker/mcp/__init__.py b/arcade/arcade/worker/mcp/__init__.py new file mode 100644 index 00000000..9761f4f2 --- /dev/null +++ b/arcade/arcade/worker/mcp/__init__.py @@ -0,0 +1,7 @@ +""" +MCP (Model Context Protocol) support for Arcade workers. +""" + +from arcade.worker.mcp.stdio import StdioServer + +__all__ = ["StdioServer"] diff --git a/arcade/arcade/worker/mcp/convert.py b/arcade/arcade/worker/mcp/convert.py new file mode 100644 index 00000000..59e10642 --- /dev/null +++ b/arcade/arcade/worker/mcp/convert.py @@ -0,0 +1,188 @@ +import json +import logging +from enum import Enum +from typing import Any + +from arcade.core.catalog import MaterializedTool + +# Type aliases for MCP types +MCPTool = dict[str, Any] +MCPTextContent = dict[str, Any] +MCPImageContent = dict[str, Any] +MCPEmbeddedResource = dict[str, Any] +MCPContent = MCPTextContent | MCPImageContent | MCPEmbeddedResource + +logger = logging.getLogger("arcade.mcp") + + +def create_mcp_tool(tool: MaterializedTool) -> dict[str, Any] | None: # noqa: C901 + """ + Create an MCP-compatible tool definition from an Arcade tool. + + Args: + tool: An Arcade tool object + + Returns: + An MCP tool definition or None if the tool cannot be converted + """ + try: + name = getattr(tool.definition, "fully_qualified_name", None) or getattr( + tool.definition, "name", "unknown" + ) + description = getattr(tool.definition, "description", "No description available") + + # Extract parameters from the input model + parameters = {} + required = [] + + if ( + hasattr(tool, "input_model") + and tool.input_model is not None + and hasattr(tool.input_model, "model_fields") + ): + for field_name, field in tool.input_model.model_fields.items(): + # Skip internal tool context parameters + if field_name == getattr( + tool.definition.input, "tool_context_parameter_name", None + ): + continue + + # Get field type information + field_type = getattr(field, "annotation", None) + field_type_name = "string" # default + + # Safety check for field_type + if field_type is int: + field_type_name = "integer" + elif field_type is float: + field_type_name = "number" + elif field_type is bool: + field_type_name = "boolean" + elif field_type is list or str(field_type).startswith("list["): + field_type_name = "array" + elif field_type is dict or str(field_type).startswith("dict["): + field_type_name = "object" + + # Get description with fallback + field_description = getattr(field, "description", None) + if not field_description: + field_description = f"Parameter: {field_name}" + + # Create parameter definition + param_def = { + "type": field_type_name, + "description": field_description, + } + + # Enum support: if the field annotation is an Enum, add allowed values + enum_type = None + if hasattr(field, "annotation"): + ann = field.annotation + # Handle typing.Annotated[Enum, ...] + if getattr(ann, "__origin__", None) is not None and hasattr(ann, "__args__"): + for arg in ann.__args__: # type: ignore[union-attr] + if isinstance(arg, type) and issubclass(arg, Enum): + enum_type = arg + break + elif isinstance(ann, type) and issubclass(ann, Enum): + enum_type = ann + if enum_type is not None: + param_def["enum"] = [e.value for e in enum_type] + + parameters[field_name] = param_def + + # In Pydantic v2, check if field is required based on default value + try: + if field.is_required(): + required.append(field_name) + except (AttributeError, TypeError): + # Fallback if is_required() doesn't exist or fails + try: + has_default = getattr(field, "default", None) is not None + has_factory = getattr(field, "default_factory", None) is not None + if not (has_default or has_factory): + required.append(field_name) + except Exception: + # Ultimate fallback - assume required if we can't determine + logger.debug( + f"Could not determine if field {field_name} is required, assuming optional" + ) + + # Create the input schema with explicit properties and required fields + input_schema = { + "type": "object", + "properties": parameters, + } + + # Only include required field if we have required parameters + if required: + input_schema["required"] = required + + # Add annotations based on tool metadata + annotations = {} + + # Use tool name as title if available + annotations["title"] = getattr(tool.definition, "title", str(name).replace(".", "_")) + + # Determine hints based on tool properties + if hasattr(tool.definition, "metadata"): + metadata = tool.definition.metadata or {} + annotations["readOnlyHint"] = metadata.get("read_only", False) + annotations["destructiveHint"] = metadata.get("destructive", False) + annotations["idempotentHint"] = metadata.get("idempotent", True) + annotations["openWorldHint"] = metadata.get("open_world", False) + + # Create the final tool definition + tool_def: MCPTool = { + "name": str(name).replace(".", "_"), + "description": str(description), + "inputSchema": input_schema, + "annotations": annotations, + } + + logger.debug(f"Created tool definition for {name}") + + except Exception: + logger.exception( + f"Error creating MCP tool definition for {getattr(tool, 'name', str(tool))}" + ) + return None + return tool_def + + +def convert_to_mcp_content(value: Any) -> list[dict[str, Any]]: + """ + Convert a Python value to MCP-compatible content. + """ + if value is None: + return [] + + if isinstance(value, (str, bool, int, float)): + return [{"type": "text", "text": str(value)}] + + if isinstance(value, (dict, list)): + return [{"type": "text", "text": json.dumps(value)}] + + # Default fallback + return [{"type": "text", "text": str(value)}] + + +def _map_type_to_json_schema_type(val_type: str) -> str: + """ + Map Arcade value types to JSON schema types. + + Args: + val_type: The Arcade value type as a string. + + Returns: + The corresponding JSON schema type as a string. + """ + mapping: dict[str, str] = { + "string": "string", + "integer": "integer", + "number": "number", + "boolean": "boolean", + "json": "object", + "array": "array", + } + return mapping.get(val_type, "string") diff --git a/arcade/arcade/worker/mcp/logging.py b/arcade/arcade/worker/mcp/logging.py new file mode 100644 index 00000000..0299d5e1 --- /dev/null +++ b/arcade/arcade/worker/mcp/logging.py @@ -0,0 +1,215 @@ +import json +import logging +import sys +import time +from typing import Any + +from arcade.worker.mcp.types import ( + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + MCPMessage, +) + +logger = logging.getLogger("arcade.mcp") + + +class MCPLoggingMiddleware: + """ + Middleware for logging MCP requests and responses. + Logs request and response details, including timing and errors. + """ + + def __init__( + self, + log_level: str = "INFO", + log_request_body: bool = False, + log_response_body: bool = False, + log_errors: bool = True, + min_duration_to_log_ms: int = 0, + stdio_mode: bool = False, + ) -> None: + """ + Initialize the MCP logging middleware. + + Args: + log_level: Logging level (default: "INFO"). + log_request_body: Whether to log full request bodies (default: False). + log_response_body: Whether to log full response bodies (default: False). + log_errors: Whether to log errors at ERROR level (default: True). + min_duration_to_log_ms: Minimum duration in ms to log (0 logs all). + stdio_mode: Whether running in stdio mode (redirects logs to stderr). + """ + self.log_level = getattr(logging, log_level.upper()) + self.log_request_body = log_request_body + self.log_response_body = log_response_body + self.log_errors = log_errors + self.min_duration_to_log_ms = min_duration_to_log_ms + self.request_log_format = "[MCP>] {method}{params_str} (id: {id})" + self.response_log_format = "[MCP<] {method} completed in {duration:.2f}ms (id: {id})" + self.error_log_format = "[MCP!] {method} error: {error} (id: {id})" + + # If in stdio mode, ensure MCP logs go to stderr + if stdio_mode: + self._redirect_logs_to_stderr() + + # Log that middleware is initialized + logger.debug(f"MCP logging middleware initialized (level: {log_level})") + + def _redirect_logs_to_stderr(self) -> None: + """Redirect MCP logs to stderr to avoid interfering with stdio communication.""" + # Remove any existing handlers + for handler in logger.handlers[:]: + logger.removeHandler(handler) + + # Add a stderr handler + stderr_handler = logging.StreamHandler(sys.stderr) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + stderr_handler.setFormatter(formatter) + logger.addHandler(stderr_handler) + + # Ensure we're not propagating to root logger which might log to stdout + logger.propagate = False + + logger.debug("MCP logs redirected to stderr for stdio mode") + + def __call__(self, message: MCPMessage, direction: str) -> MCPMessage: + """ + Process and log an MCP message. + + Args: + message: The MCP message to process. + direction: The message direction ("request" or "response"). + + Returns: + The original message (unmodified). + """ + if direction == "request": + self._log_request(message) + else: + self._log_response(message) + return message + + def _log_request(self, message: MCPMessage) -> None: + """ + Log an MCP request message. + """ + if not isinstance(message, JSONRPCRequest): + logger.debug(f"Ignoring non-request message: {type(message).__name__}") + return + + try: + # Store request start time for duration calculation + message._mcp_start_time = time.time() # type: ignore[attr-defined] + + # Format parameters for logging + params_str = "" + if self.log_request_body and hasattr(message, "params") and message.params is not None: + params_str = f": {self._format_params(message.params)}" + + log_msg = self.request_log_format.format( + method=message.method, params_str=params_str, id=getattr(message, "id", "none") + ) + + logger.log(self.log_level, log_msg) + except Exception: + logger.exception("Error logging request") + + def _log_response(self, message: MCPMessage) -> None: + """ + Log an MCP response message. + """ + if not isinstance(message, (JSONRPCResponse, JSONRPCError)): + logger.debug(f"Ignoring non-response message: {type(message).__name__}") + return + + try: + # Calculate request duration if we have the start time + duration_ms = 0 + request = getattr(message, "_request", None) + if request: + start_time = getattr(request, "_mcp_start_time", None) + if start_time: + duration_ms = (time.time() - start_time) * 1000 + else: + start_time = getattr(message, "_mcp_start_time", None) + if start_time: + duration_ms = (time.time() - start_time) * 1000 + + # Skip if below minimum duration threshold + if self.min_duration_to_log_ms > 0 and duration_ms < self.min_duration_to_log_ms: + return + + # Handle error responses + if hasattr(message, "error") and message.error is not None: + if self.log_errors: + error_msg = self.error_log_format.format( + method=getattr(message, "method", "unknown"), + error=getattr(message.error, "message", str(message.error)), + id=getattr(message, "id", "none"), + ) + logger.error(error_msg) + return + + # Log successful response + result_str = "" + if self.log_response_body and hasattr(message, "result"): + result_str = f": {self._format_result(message.result)}" + + log_msg = self.response_log_format.format( + method=getattr(message, "method", "unknown"), + duration=duration_ms, + id=getattr(message, "id", "none"), + result_str=result_str, + ) + + logger.log(self.log_level, log_msg) + except Exception: + logger.exception("Error logging response") + + def _format_params(self, params: Any) -> str: + """ + Format parameters for logging. + """ + try: + if isinstance(params, dict): + # Handle common MCP params specially + if "name" in params and "arguments" in params: + return f"{params['name']}({json.dumps(params.get('arguments', {}))})" + return json.dumps(params) + return str(params) + except Exception: + logger.debug(f"Error formatting params {params!s}") + return str(params) + + def _format_result(self, result: Any) -> str: + """ + Format result for logging. + """ + try: + if isinstance(result, dict): + return json.dumps(result) + return str(result) + except Exception as e: + logger.debug(f"Error formatting result {e!s}") + return str(result) + + +def create_mcp_logging_middleware(**config: Any) -> MCPLoggingMiddleware: + """ + Create an MCP logging middleware with the given configuration. + + Args: + **config: Configuration options. + + Returns: + An MCPLoggingMiddleware instance. + """ + return MCPLoggingMiddleware( + log_level=config.get("log_level", "INFO"), + log_request_body=config.get("log_request_body", False), + log_response_body=config.get("log_response_body", False), + log_errors=config.get("log_errors", True), + min_duration_to_log_ms=config.get("min_duration_to_log_ms", 0), + stdio_mode=config.get("stdio_mode", False), + ) diff --git a/arcade/arcade/worker/mcp/message_processor.py b/arcade/arcade/worker/mcp/message_processor.py new file mode 100644 index 00000000..a2199bee --- /dev/null +++ b/arcade/arcade/worker/mcp/message_processor.py @@ -0,0 +1,83 @@ +import inspect +import json +import logging +from typing import Any, Callable, TypeVar + +from arcade.worker.mcp.types import InitializeRequest, JSONRPCRequest, MCPMessage + +logger = logging.getLogger("arcade.mcp") + +T = TypeVar("T") + +# Type definition for middleware functions +MessageProcessor = Callable[[Any, str], Any] + + +class MCPMessageProcessor: + """ + Processes MCP messages through a chain of middleware. + Supports both synchronous and asynchronous middleware. + """ + + def __init__(self) -> None: + self.middleware: list[Callable[[MCPMessage, str], Any]] = [] + + def add_middleware(self, mw: Callable[[MCPMessage, str], Any]) -> None: + self.middleware.append(mw) + + async def process(self, message: Any, direction: str) -> Any: # noqa: C901 + # First, try to parse the message if it's a string + if isinstance(message, str): + # Strip any whitespace including newlines + message = message.strip() + if not message: + return None + + try: + parsed = json.loads(message) + if isinstance(parsed, dict): + method = parsed.get("method") + # Convert to appropriate message type + if method == "initialize" and "id" in parsed: + logger.debug(f"Parsed initialize request: {parsed}") + message = InitializeRequest(**parsed) + elif method and method.startswith("notifications/"): + # It's a notification, log it but pass through as dict + logger.debug(f"Received notification: {method}") + # Keep as parsed dict to avoid validation errors on unknown notifications + message = parsed + elif "method" in parsed and "id" in parsed: + # Regular method request + logger.debug(f"Parsed method request: {method}") + message = JSONRPCRequest(**parsed) + # Other message types can be handled similarly + except json.JSONDecodeError: + logger.warning(f"Failed to parse message as JSON: {message[:100]}...") + except Exception: + logger.exception("Error processing message") + + # Process through middleware chain + result = message + for mw in self.middleware: + try: + if inspect.iscoroutinefunction(mw): + result = await mw(result, direction) + else: + result = mw(result, direction) + except Exception: + logger.exception(f"Error in middleware {mw}") + return result + + async def process_request(self, message: Any) -> Any: + return await self.process(message, "request") + + async def process_response(self, message: Any) -> Any: + return await self.process(message, "response") + + +def create_message_processor(*middleware: MessageProcessor) -> MCPMessageProcessor: + processor = MCPMessageProcessor() + for m in middleware: + if m is not None: + processor.add_middleware(m) + return processor diff --git a/arcade/arcade/worker/mcp/server.py b/arcade/arcade/worker/mcp/server.py new file mode 100644 index 00000000..ae255b14 --- /dev/null +++ b/arcade/arcade/worker/mcp/server.py @@ -0,0 +1,601 @@ +import asyncio +import logging +import os +import uuid +from enum import Enum +from typing import Any, Callable, Union + +from arcadepy import ArcadeError, AsyncArcade +from arcadepy.types.auth_authorize_params import AuthRequirement, AuthRequirementOauth2 +from arcadepy.types.shared import AuthorizationResponse + +from arcade.core.catalog import MaterializedTool, ToolCatalog +from arcade.core.executor import ToolExecutor +from arcade.core.schema import ToolAuthorizationContext, ToolContext +from arcade.worker.mcp.convert import convert_to_mcp_content, create_mcp_tool +from arcade.worker.mcp.logging import create_mcp_logging_middleware +from arcade.worker.mcp.message_processor import MCPMessageProcessor, create_message_processor +from arcade.worker.mcp.types import ( + CallToolRequest, + CallToolResponse, + CallToolResult, + CancelRequest, + Implementation, + InitializeRequest, + InitializeResponse, + InitializeResult, + JSONRPCError, + JSONRPCResponse, + ListPromptsRequest, + ListPromptsResponse, + ListResourcesRequest, + ListResourcesResponse, + ListToolsRequest, + ListToolsResponse, + ListToolsResult, + PingRequest, + PingResponse, + ProgressNotification, + ServerCapabilities, + ShutdownRequest, + ShutdownResponse, + Tool, +) + +logger = logging.getLogger("arcade.mcp") + +MCP_PROTOCOL_VERSION = "2024-11-05" + + +class MessageMethod(str, Enum): + """Enumeration of supported MCP message methods""" + + PING = "ping" + INITIALIZE = "initialize" + LIST_TOOLS = "tools/list" + CALL_TOOL = "tools/call" + PROGRESS = "progress" + CANCEL = "$/cancelRequest" + SHUTDOWN = "shutdown" + LIST_RESOURCES = "resources/list" + LIST_PROMPTS = "prompts/list" + + +class MCPServer: + """ + Unified async MCP server that manages connections, middleware, and tool invocation. + Handles protocol-level messages (ping, initialize, list_tools, call_tool, etc.). + """ + + def __init__( + self, + tool_catalog: Any, + enable_logging: bool = True, + **client_kwargs: dict[str, Any], + ) -> None: + """ + Initialize the MCP server. + + Args: + tool_catalog: Catalog of available tools + **client_kwargs: Additional arguments to pass to the AsyncArcade client + """ + self.tool_catalog: ToolCatalog = tool_catalog + self.message_processor: MCPMessageProcessor = create_message_processor() + + # Pop middleware_config from client_kwargs regardless of logging state, + # as it's internal config not meant for AsyncArcade. + middleware_config = client_kwargs.pop("middleware_config", {}) + + if enable_logging: + # Create and add the logging middleware if logging is enabled. + # Note: enable_logging must be True for this middleware (and its stdio_mode behavior) + # to be activated. + self.message_processor.add_middleware( + create_mcp_logging_middleware(**middleware_config) + ) + + self._shutdown: bool = False + # Initialize AsyncArcade with the *remaining* client_kwargs + self.arcade = AsyncArcade(**client_kwargs) # type: ignore[arg-type] + + # Initialize handler dispatch table + self._method_handlers: dict[str, Callable] = { + MessageMethod.PING: self._handle_ping, + MessageMethod.INITIALIZE: self._handle_initialize, + MessageMethod.LIST_TOOLS: self._handle_list_tools, + MessageMethod.CALL_TOOL: self._handle_call_tool, + MessageMethod.PROGRESS: self._handle_progress, + MessageMethod.CANCEL: self._handle_cancel, + MessageMethod.SHUTDOWN: self._handle_shutdown, + MessageMethod.LIST_RESOURCES: self._handle_list_resources, + MessageMethod.LIST_PROMPTS: self._handle_list_prompts, + } + + async def run_connection( + self, + read_stream: Any, + write_stream: Any, + init_options: Any, + ) -> None: + """ + Handle a single MCP connection (SSE or stdio). + + Args: + read_stream: Async iterable yielding incoming messages. + write_stream: Object with an async send(message) method. + init_options: Initialization options for the connection. + """ + # Generate a user ID if possible + user_id = self._get_user_id(init_options) + + try: + logger.info(f"Starting MCP connection for user {user_id}") + + async for message in read_stream: + # Process the message + response = await self.handle_message(message, user_id=user_id) + + # Skip sending responses for None (e.g., notifications) + if response is None: + continue + + await self._send_response(write_stream, response) + + except asyncio.CancelledError: + logger.info("Connection cancelled") + except Exception: + logger.exception("Error in connection") + + def _get_user_id(self, init_options: Any) -> str: + """ + Get the user ID for a connection. + + Args: + init_options: Initialization options for the connection + + Returns: + A user ID string + """ + try: + from arcade.core.config import config + + # Prefer config.user.email if available + if config.user and config.user.email: + return config.user.email + except ValueError: + logger.debug("No logged in user for MCP Server") + + fallback = str(uuid.uuid4()) + if os.environ.get("ARCADE_USER_ID", None): + return os.environ.get("ARCADE_USER_ID", fallback) + elif isinstance(init_options, dict): + user_id = init_options.get("user_id") + if user_id: + return str(user_id) + # Fallback to random UUID + return str(fallback) + + async def _send_response(self, write_stream: Any, response: Any) -> None: + """ + Send a response to the client. + + Args: + write_stream: Stream to write the response to + response: Response object to send + """ + # Ensure the response is properly serialized to JSON + if hasattr(response, "model_dump_json"): + # It's a Pydantic model, serialize it + json_response = response.model_dump_json() + # Ensure it ends with a newline for JSON-RPC-over-stdio + if not json_response.endswith("\n"): + json_response += "\n" + logger.debug(f"Sending response: {json_response[:200]}...") + await write_stream.send(json_response) + elif isinstance(response, dict): + # It's a dict, convert to JSON + import json + + json_response = json.dumps(response) + # Ensure it ends with a newline for JSON-RPC-over-stdio + if not json_response.endswith("\n"): + json_response += "\n" + logger.debug(f"Sending response: {json_response[:200]}...") + await write_stream.send(json_response) + else: + # It's already a string or something else + response_str = str(response) + # Ensure it ends with a newline for JSON-RPC-over-stdio + if not response_str.endswith("\n"): + response_str += "\n" + logger.debug(f"Sending raw response type: {type(response)}") + await write_stream.send(response_str) + + async def handle_message(self, message: Any, user_id: str | None = None) -> Any: + """ + Handle an incoming MCP message. Processes it through middleware and dispatches + to the appropriate handler based on the message method. + + Args: + message: The raw incoming message + user_id: Optional user ID for authentication + + Returns: + A properly formatted response message + """ + # Pre-process message through middleware + processed = await self.message_processor.process_request(message) + + # Handle special case for JSON string initialize requests + if isinstance(processed, str): + try: + import json + + parsed = json.loads(processed) + if ( + isinstance(parsed, dict) + and parsed.get("method") == MessageMethod.INITIALIZE + and "id" in parsed + ): + # This is an initialize request + init_response = await self._handle_initialize(InitializeRequest(**parsed)) + return init_response + except Exception: + logger.exception("Error processing JSON string") + # Not parseable JSON, continue with normal processing + pass + + # Check if it's a notification + if hasattr(processed, "method"): + method = getattr(processed, "method", None) + + # Handle notifications (methods starting with "notifications/") + if method and method.startswith("notifications/"): + await self._handle_notification(method, processed) + return None + + # Handle regular methods using the dispatch table + if method in self._method_handlers: + # If it's a call_tool request, we need to pass the user_id + if method == MessageMethod.CALL_TOOL: + return await self._method_handlers[method](processed, user_id=user_id) + # For other methods, just pass the processed message + return await self._method_handlers[method](processed) + + # Unknown method + return JSONRPCError( + id=getattr(processed, "id", None), + error={ + "code": -32601, + "message": f"Method not found: {method}", + }, + ) + + # If it's not a method request, just pass it through + return processed + + async def _handle_notification(self, method: str, message: Any) -> None: + """ + Handle notification messages. + + Args: + method: The notification method + message: The notification message + """ + if method == "notifications/cancelled": + logger.info(f"Request cancelled: {getattr(message, 'params', {})}") + else: + logger.debug(f"Received notification: {method}") + + async def _handle_ping(self, message: PingRequest) -> PingResponse: + """ + Handle a ping request and return a pong response. + + Args: + message: The ping request + + Returns: + A properly formatted pong response + """ + return PingResponse(id=message.id) + + async def _handle_initialize(self, message: InitializeRequest) -> InitializeResponse: + """ + Handle an initialize request and return a proper initialize response. + + Args: + message: The initialize request + + Returns: + A properly formatted initialize response + """ + # Create the result data + result = InitializeResult( + protocolVersion=MCP_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="Arcade MCP Worker", version="0.1.0"), + instructions="Arcade MCP Worker initialized.", + ) + + # Construct proper response with result field + response = InitializeResponse(id=message.id, result=result) + + logger.debug(f"Initialize response: {response.model_dump_json()}") + return response + + async def _handle_list_tools( + self, message: ListToolsRequest + ) -> Union[ListToolsResponse, JSONRPCError]: + """ + Handle a tools/list request and return a list of available tools. + + Args: + message: The tools/list request + + Returns: + A properly formatted tools/list response or error + """ + try: + # Get all tools from the catalog + tools = [] + tool_conversion_errors = [] + + for tool in self.tool_catalog: + try: + mcp_tool = create_mcp_tool(tool) + if mcp_tool: + tools.append(mcp_tool) + except Exception: + tool_name = getattr(tool, "name", str(tool)) + logger.exception(f"Error converting tool: {tool_name}") + tool_conversion_errors.append(tool_name) + + # Log summary if we had errors + if tool_conversion_errors: + logger.warning( + f"Failed to convert {len(tool_conversion_errors)} tools: {tool_conversion_errors}" + ) + + # Create tool objects with exception handling for each one + tool_objects = [] + for t in tools: + try: + # Make input schema optional if missing + tool_dict = dict(t) + if "inputSchema" not in tool_dict: + tool_dict["inputSchema"] = {"type": "object", "properties": {}} + + tool_objects.append(Tool(**tool_dict)) + except Exception: + logger.exception(f"Error creating Tool object for {t.get('name', 'unknown')}") + + # Return successful response with the tools we were able to convert + result = ListToolsResult(tools=tool_objects) + response = ListToolsResponse(id=message.id, result=result) + + except Exception: + logger.exception("Error listing tools") + return JSONRPCError( + id=message.id, + error={ + "code": -32603, + "message": "Internal error listing tools", + }, + ) + return response + + async def _handle_call_tool( + self, message: CallToolRequest, user_id: str | None = None + ) -> CallToolResponse: + """ + Handle a tools/call request to execute a tool. + + Args: + message: The tools/call request + user_id: Optional user ID for authentication + + Returns: + A properly formatted tools/call response + """ + tool_name: str = message.params["name"] + # Extract input from the correct field + input_params: dict[str, Any] = message.params.get("input", {}) + if not input_params: + input_params = message.params.get("arguments", {}) + + logger.info(f"Handling tool call for {tool_name}") + + try: + tool = self.tool_catalog.get_tool_by_name(tool_name, separator="_") + tool_context = ToolContext() + + # Set up context with secrets + if tool.definition.requirements and tool.definition.requirements.secrets: + self._setup_tool_secrets(tool, tool_context) + + # Handle authorization if needed + requirement = self._get_auth_requirement(tool) + if requirement: + auth_result = await self._check_authorization(requirement, user_id=user_id) + if auth_result.status != "completed": + return CallToolResponse( + id=message.id, + result=CallToolResult(content=[{"type": "text", "text": auth_result.url}]), + ) + else: + tool_context.authorization = ToolAuthorizationContext( + token=auth_result.context.token if auth_result.context else None, + user_info={"user_id": user_id} if user_id else {}, + ) + + # Execute the tool + logger.debug(f"Executing tool {tool_name} with input: {input_params}") + result = await ToolExecutor.run( + func=tool.tool, + definition=tool.definition, + input_model=tool.input_model, + output_model=tool.output_model, + context=tool_context, + **input_params, + ) + logger.debug(f"Tool result: {result}") + if result.value: + return CallToolResponse( + id=message.id, + result=CallToolResult(content=convert_to_mcp_content(result.value)), + ) + else: + error = result.error or "Error calling tool" + logger.error(f"Tool {tool_name} returned error: {error}") + return CallToolResponse( + id=message.id, + result=CallToolResult( + content=[{"type": "text", "text": convert_to_mcp_content(error)}] + ), + ) + except Exception as e: + logger.exception(f"Error calling tool {tool_name}") + error = f"Error calling tool {tool_name}: {e!s}" + return CallToolResponse( + id=message.id, + result=CallToolResult( + content=[{"type": "text", "text": convert_to_mcp_content(error)}] + ), + ) + + def _setup_tool_secrets(self, tool: Any, tool_context: ToolContext) -> None: + """ + Set up tool secrets in the tool context. + + Args: + tool: The tool to set up secrets for + tool_context: The tool context to update + """ + for secret in tool.definition.requirements.secrets: + value = os.environ.get(secret.key) + if value is not None: + tool_context.set_secret(secret.key, value) + + async def _handle_progress(self, message: ProgressNotification) -> JSONRPCResponse: + """ + Handle a progress notification. + + Args: + message: The progress notification + + Returns: + A response acknowledging the notification + """ + return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True}) + + async def _handle_cancel(self, message: CancelRequest) -> JSONRPCResponse: + """ + Handle a cancel request. + + Args: + message: The cancel request + + Returns: + A response acknowledging the cancellation + """ + return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True}) + + async def _handle_shutdown(self, message: ShutdownRequest) -> ShutdownResponse: + """ + Handle a shutdown request. + + Args: + message: The shutdown request + + Returns: + A response acknowledging the shutdown request + """ + # Schedule a task to shutdown the server after sending the response + proc = asyncio.create_task(self.shutdown()) + proc.add_done_callback(lambda _: logger.info("MCP server shutdown complete")) + return ShutdownResponse(id=message.id, result={"ok": True}) + + async def _handle_list_resources(self, message: ListResourcesRequest) -> ListResourcesResponse: + """ + Handle a resources/list request. + + Args: + message: The resources/list request + + Returns: + A properly formatted resources/list response + """ + return ListResourcesResponse(id=message.id, result={"resources": []}) + + async def _handle_list_prompts(self, message: ListPromptsRequest) -> ListPromptsResponse: + """ + Handle a prompts/list request. + + Args: + message: The prompts/list request + + Returns: + A properly formatted prompts/list response + """ + return ListPromptsResponse(id=message.id, result={"prompts": []}) + + def _get_auth_requirement(self, tool: MaterializedTool) -> AuthRequirement | None: + """ + Get the authentication requirement for a tool. + + Args: + tool: The tool to get the requirement for + + Returns: + An authentication requirement or None if not required + """ + req = tool.definition.requirements.authorization + if not req: + return None + if not req.provider_id and not req.provider_type: + return None + if hasattr(req, "oauth2") and req.oauth2: + return AuthRequirement( + provider_id=str(req.provider_id), + provider_type=str(req.provider_type), + oauth2=AuthRequirementOauth2(scopes=req.oauth2.scopes or []), + ) + return AuthRequirement( + provider_id=str(req.provider_id), + provider_type=str(req.provider_type), + ) + + async def _check_authorization( + self, auth_requirement: AuthRequirement, user_id: str | None = None + ) -> AuthorizationResponse: + """ + Check if a tool is authorized for a user. + + Args: + tool: The tool to check authorization for + user_id: The user ID to check authorization for + + Returns: + An authorization response + + Raises: + RuntimeError: If the tool has no authorization requirement + Exception: If authorization fails + """ + try: + response = await self.arcade.auth.authorize( + auth_requirement=auth_requirement, + user_id=user_id or "anonymous", + ) + logger.debug(f"Authorization response: {response}") + + except ArcadeError: + logger.exception("Error authorizing tool") + raise + return response + + async def shutdown(self) -> None: + """Shutdown the server.""" + self._shutdown = True + logger.info("MCP server shutdown complete") diff --git a/arcade/arcade/worker/mcp/stdio.py b/arcade/arcade/worker/mcp/stdio.py new file mode 100644 index 00000000..4c637005 --- /dev/null +++ b/arcade/arcade/worker/mcp/stdio.py @@ -0,0 +1,185 @@ +import asyncio +import logging +import queue +import signal +import sys +import threading +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + pass + +from arcade.worker.mcp.server import MCPServer + +logger = logging.getLogger("arcade.mcp") + +T = TypeVar("T") + + +def stdio_reader(stdin: object, q: queue.Queue[str | None]) -> None: + """Read lines from stdin and put them into a queue.""" + for line in stdin: # type: ignore[attr-defined] + q.put(line) + q.put(None) + + +def stdio_writer(stdout: object, q: queue.Queue[str | None]) -> None: + """Write messages from a queue to stdout.""" + try: + while True: + msg = q.get() + if msg is None: + break + + # Ensure message ends with a newline for proper JSON-RPC-over-stdio + if not msg.endswith("\n"): + msg += "\n" + + stdout.write(msg) # type: ignore[attr-defined] + stdout.flush() # type: ignore[attr-defined] + except Exception: + logger.exception("Error in stdio writer") + + +class StdioServer(MCPServer): + """ + Stdio server that handles signals and cleanup. + """ + + def __init__( + self, + tool_catalog: Any, + enable_logging: bool = True, + **client_kwargs: dict[str, Any], + ): + # Set up stdio-specific middleware configuration + middleware_config = client_kwargs.get("middleware_config", {}) + middleware_config["stdio_mode"] = True + client_kwargs["middleware_config"] = middleware_config + + super().__init__(tool_catalog, enable_logging, **client_kwargs) + self.read_q: queue.Queue[str | None] = queue.Queue() + self.write_q: queue.Queue[str | None] = queue.Queue() + self.reader_thread: threading.Thread | None = None + self.writer_thread: threading.Thread | None = None + self.running = False + self.shutdown_event = asyncio.Event() + + def start_io_threads(self) -> None: + """Start stdio reader and writer threads.""" + self.reader_thread = threading.Thread( + target=self._stdio_reader, args=(sys.stdin, self.read_q), daemon=True + ) + self.writer_thread = threading.Thread( + target=self._stdio_writer, args=(sys.stdout, self.write_q), daemon=True + ) + self.reader_thread.start() + self.writer_thread.start() + + def _stdio_reader(self, stdin: object, q: queue.Queue[str | None]) -> None: + """Read lines from stdin and put them into a queue.""" + try: + for line in stdin: # type: ignore[attr-defined] + if not self.running: + break + q.put(line) + except Exception: + logger.exception("Error in stdio reader") + finally: + q.put(None) # Signal EOF + + def _stdio_writer(self, stdout: object, q: queue.Queue[str | None]) -> None: + """Write messages from a queue to stdout.""" + try: + while self.running: + msg = q.get() + if msg is None: + break + stdout.write(msg) # type: ignore[attr-defined] + stdout.flush() # type: ignore[attr-defined] + except Exception: + logger.exception("Error in stdio writer") + + async def _read_stream(self) -> AsyncGenerator[str, None]: + """Async generator that yields lines from the read queue.""" + while self.running: + try: + line = await asyncio.to_thread(self.read_q.get) + if line is None: + break + yield line + except asyncio.CancelledError: + break + except Exception: + logger.exception("Error reading from stdin") + break + + async def shutdown(self) -> None: + """Gracefully shut down the server.""" + if not self.running: + return + + logger.info("Shutting down stdio server...") + self.running = False + + # Signal shutdown to MCP server + await self.shutdown() + + # Clean up IO queues and threads + try: + if self.read_q: + self.read_q.put(None) + if self.write_q: + self.write_q.put(None) + except Exception: + logger.exception("Error during shutdown") + + # Signal completion + self.shutdown_event.set() + logger.info("Stdio server shutdown complete") + + async def run(self) -> None: + """Run the stdio server with signal handling.""" + self.running = True + + # Set up signal handlers + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, lambda: asyncio.create_task(self.shutdown())) + except NotImplementedError: + # Windows doesn't support POSIX signals + if sys.platform == "win32": + logger.warning("Signal handling not fully supported on Windows") + else: + logger.warning(f"Failed to set up signal handler for {sig}") + + # Start IO threads + self.start_io_threads() + + logger.info("Starting MCP server with stdio transport") + + # Create WriteStream class for MCP server + class WriteStream: + async def send(self_, message: str) -> None: + if self.running: + await asyncio.to_thread(self.write_q.put, message) + + try: + # Run MCP server connection + await self.run_connection(self._read_stream(), WriteStream(), None) + except asyncio.CancelledError: + # Handle cancellation + logger.info("Server operation cancelled") + except KeyboardInterrupt: + # Handle keyboard interrupt + logger.info("Keyboard interrupt received") + except Exception: + # Handle unexpected errors + logger.exception("Unexpected error") + finally: + # Ensure we clean up + await self.shutdown() + # Wait for shutdown to complete + await self.shutdown_event.wait() diff --git a/arcade/arcade/worker/mcp/types.py b/arcade/arcade/worker/mcp/types.py new file mode 100644 index 00000000..4354241b --- /dev/null +++ b/arcade/arcade/worker/mcp/types.py @@ -0,0 +1,383 @@ +import json +from collections.abc import Callable +from typing import ( + Any, + Generic, + Literal, + TypeAlias, + TypeVar, + Union, +) + +from pydantic import BaseModel, ConfigDict, Field + +ProgressToken = str | int +Cursor = str +Role = Literal["user", "assistant"] +RequestId = str | int +AnyFunction: TypeAlias = Callable[..., Any] + + +class RequestParams(BaseModel): + class Meta(BaseModel): + progressToken: ProgressToken | None = None + model_config = ConfigDict(extra="allow") + + meta: Meta | None = Field(alias="_meta", default=None) + + model_config = ConfigDict(extra="allow") + + +class NotificationParams(BaseModel): + class Meta(BaseModel): + model_config = ConfigDict(extra="allow") + + meta: Meta | None = Field(alias="_meta", default=None) + model_config = ConfigDict(extra="allow") + + +RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) +NotificationParamsT = TypeVar( + "NotificationParamsT", bound=NotificationParams | dict[str, Any] | None +) +MethodT = TypeVar("MethodT", bound=str) + + +class Request(BaseModel, Generic[RequestParamsT, MethodT]): + method: MethodT + params: RequestParamsT + model_config = ConfigDict(extra="allow") + + +class PaginatedRequest(Request[RequestParamsT, MethodT]): + cursor: Cursor | None = None + model_config = ConfigDict(extra="allow") + + +class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): + method: MethodT + params: NotificationParamsT + model_config = ConfigDict(extra="allow") + + +class Result(BaseModel): + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + model_config = ConfigDict(extra="allow") + + +class PaginatedResult(Result): + nextCursor: Cursor | None = None + model_config = ConfigDict(extra="allow") + + +class JSONRPCMessage(BaseModel): + """Base class for all JSON-RPC messages.""" + + model_config = ConfigDict(extra="allow") + jsonrpc: str = Field(default="2.0", frozen=True) + + +class JSONRPCRequest(JSONRPCMessage): + """A JSON-RPC request message.""" + + id: str | int | None = None + method: str + params: dict[str, Any] | None = None + + +class JSONRPCResponse(JSONRPCMessage): + """A JSON-RPC response message.""" + + id: str | int | None + result: Any = None + error: dict[str, Any] | None = None + + def model_dump_json(self, **kwargs: Any) -> str: + """Convert to JSON string with proper formatting.""" + + # Convert to dict + data = { + "jsonrpc": self.jsonrpc, + "id": self.id, + } + + # Add result if present + if self.result is not None: + # Check if result is a Pydantic model + if hasattr(self.result, "model_dump"): + data["result"] = self.result.model_dump(exclude_none=True) + # Check if result is already a dict/list/primitive + elif ( + isinstance(self.result, (dict, list, str, int, float, bool)) or self.result is None + ): + data["result"] = self.result # type: ignore[assignment] + else: + # Try to convert using str() as a fallback + data["result"] = str(self.result) + + # Add error if present + if self.error is not None: + data["error"] = self.error # type: ignore[assignment] + + return json.dumps(data, ensure_ascii=False) + + +class JSONRPCError(JSONRPCMessage): + """A JSON-RPC error message.""" + + id: str | int | None + error: dict[str, Any] + + +PARSE_ERROR = -32700 +INVALID_REQUEST = -32600 +METHOD_NOT_FOUND = -32601 +INVALID_PARAMS = -32602 +INTERNAL_ERROR = -32603 + + +class ErrorData(BaseModel): + code: int + message: str + data: Any | None = None + model_config = ConfigDict(extra="allow") + + +JSONRPCMessageBaseModel = BaseModel | JSONRPCRequest | JSONRPCResponse | JSONRPCError + + +class EmptyResult(Result): + pass + + +class Implementation(BaseModel): + """Describes the server or client implementation.""" + + name: str + version: str + model_config = ConfigDict(extra="allow") + + +class RootsCapability(BaseModel): + listChanged: bool | None = None + model_config = ConfigDict(extra="allow") + + +class SamplingCapability(BaseModel): + model_config = ConfigDict(extra="allow") + + +class ClientCapabilities(BaseModel): + experimental: dict[str, dict[str, Any]] | None = None + sampling: SamplingCapability | None = None + roots: RootsCapability | None = None + model_config = ConfigDict(extra="allow") + + +class PromptsCapability(BaseModel): + listChanged: bool | None = None + model_config = ConfigDict(extra="allow") + + +class ResourcesCapability(BaseModel): + subscribe: bool | None = None + listChanged: bool | None = None + model_config = ConfigDict(extra="allow") + + +class ToolsCapability(BaseModel): + listChanged: bool | None = None + model_config = ConfigDict(extra="allow") + + +class LoggingCapability(BaseModel): + model_config = ConfigDict(extra="allow") + + +class ServerCapabilities(BaseModel): + """Describes the server's capabilities.""" + + model_config = ConfigDict(extra="allow") + tools: dict[str, Any] | None = None + resources: dict[str, Any] | None = None + prompts: dict[str, Any] | None = None + + +class InitializeRequestParams(RequestParams): + protocolVersion: str | int + capabilities: ClientCapabilities + clientInfo: Implementation + model_config = ConfigDict(extra="allow") + + +class InitializeRequest(JSONRPCRequest): + method: str = Field(default="initialize", frozen=True) + params: dict[str, Any] | None = None + + +class InitializeResult(BaseModel): + protocolVersion: str + capabilities: ServerCapabilities + serverInfo: Implementation + instructions: str | None = None + + +class InitializedNotification( + Notification[NotificationParams | None, Literal["notifications/initialized"]] +): + method: Literal["notifications/initialized"] + params: NotificationParams | None = None + model_config = ConfigDict(extra="allow") + + +class PingRequest(JSONRPCRequest): + method: str = Field(default="ping", frozen=True) + params: dict[str, Any] | None = None + + +class ProgressNotificationParams(NotificationParams): + progressToken: ProgressToken + progress: float + total: float | None = None + model_config = ConfigDict(extra="allow") + + +class ProgressNotification(JSONRPCMessage): + method: str = Field(default="progress", frozen=True) + params: dict[str, Any] + + +class PingResponse(JSONRPCResponse): + result: dict[str, Any] = Field(default_factory=lambda: {"pong": True}) + + +class ShutdownRequest(JSONRPCRequest): + method: str = Field(default="shutdown", frozen=True) + params: dict[str, Any] | None = None + + +class ShutdownResponse(JSONRPCResponse): + result: dict[str, Any] = Field(default_factory=lambda: {"ok": True}) + + +class CancelRequest(JSONRPCRequest): + method: str = Field(default="$/cancelRequest", frozen=True) + params: dict[str, Any] + + +class InitializeResponse(JSONRPCResponse): + """ + Response to an initialize request. + + Note: This must be a properly formatted JSON-RPC response with a `result` field + containing the initialization data, not another request. + """ + + result: InitializeResult + + def model_dump_json(self, **kwargs: Any) -> str: + """Convert to JSON string with proper formatting.""" + # Convert to dict + data = { + "jsonrpc": self.jsonrpc, + "id": self.id, + "result": self.result.model_dump(exclude_none=True), + } + + # Return JSON string + return json.dumps(data, ensure_ascii=False) + + +class ListToolsRequest(JSONRPCRequest): + method: str = Field(default="tools/list", frozen=True) + params: dict[str, Any] | None = None + + +class ToolAnnotations(BaseModel): + """ + Represents tool annotations for hints about behavior. + """ + + title: str | None = None + readOnlyHint: bool | None = None + destructiveHint: bool | None = None + idempotentHint: bool | None = None + openWorldHint: bool | None = None + model_config = ConfigDict(extra="allow") + + +class Tool(BaseModel): + """ + Represents an MCP tool definition. + """ + + name: str + description: str + inputSchema: dict[str, Any] | None = None + annotations: ToolAnnotations | None = None + + model_config = ConfigDict(extra="allow") + + +class ListToolsResult(BaseModel): + tools: list[Tool] + + +class ListToolsResponse(JSONRPCResponse): + result: ListToolsResult + + +class CallToolRequest(JSONRPCRequest): + method: str = Field(default="tools/call", frozen=True) + params: dict[str, Any] + + +class CallToolResult(BaseModel): + content: Any + + +class CallToolResponse(JSONRPCResponse): + result: CallToolResult + + +# Resource and Prompt protocol stubs (expand as needed) +class ListResourcesRequest(JSONRPCRequest): + method: str = Field(default="resources/list", frozen=True) + params: dict[str, Any] | None = None + + +class ListResourcesResponse(JSONRPCResponse): + result: dict[str, Any] + + +class ListPromptsRequest(JSONRPCRequest): + method: str = Field(default="prompts/list", frozen=True) + params: dict[str, Any] | None = None + + +class ListPromptsResponse(JSONRPCResponse): + result: dict[str, Any] + + +# Utility type alias for all MCP protocol messages +MCPMessage = Union[ + JSONRPCRequest, + JSONRPCResponse, + JSONRPCError, + PingRequest, + PingResponse, + InitializeRequest, + InitializeResponse, + ListToolsRequest, + ListToolsResponse, + CallToolRequest, + CallToolResponse, + ProgressNotification, + CancelRequest, + ShutdownRequest, + ShutdownResponse, + ListResourcesRequest, + ListResourcesResponse, + ListPromptsRequest, + ListPromptsResponse, +] diff --git a/arcade/tests/deployment/test_config.py b/arcade/tests/deployment/test_config.py index 531bfaac..7d62c461 100644 --- a/arcade/tests/deployment/test_config.py +++ b/arcade/tests/deployment/test_config.py @@ -7,7 +7,7 @@ from pathlib import Path import pytest -from arcade.worker.config.deployment import ( +from arcade.cli.deployment import ( Config, Deployment, LocalPackages, diff --git a/arcade/tests/mcp/test_convert.py b/arcade/tests/mcp/test_convert.py new file mode 100644 index 00000000..a4bcf0c2 --- /dev/null +++ b/arcade/tests/mcp/test_convert.py @@ -0,0 +1,44 @@ +import json +from typing import Annotated + +from arcade.core.catalog import ToolCatalog +from arcade.sdk import tool +from arcade.worker.mcp.convert import convert_to_mcp_content, create_mcp_tool + + +@tool +def sample_tool(x: Annotated[int, "first"], y: Annotated[int, "second"]) -> int: + """Return x+y""" + + return x + y + + +def test_convert_to_mcp_content_primitives(): + assert convert_to_mcp_content(42) == [{"type": "text", "text": "42"}] + assert convert_to_mcp_content("hello") == [{"type": "text", "text": "hello"}] + assert convert_to_mcp_content(True) == [{"type": "text", "text": "True"}] + + +def test_convert_to_mcp_content_complex(): + data = {"a": 1} + expected_json = json.dumps(data) + assert convert_to_mcp_content(data) == [{"type": "text", "text": expected_json}] + + +def test_create_mcp_tool(): + # Materialize a tool via catalog then feed it to create_mcp_tool + catalog = ToolCatalog() + catalog.add_tool(sample_tool, "convert_toolkit") + mat_tool = next(iter(catalog)) # only tool + mcp_tool = create_mcp_tool(mat_tool) + + assert mcp_tool is not None + assert mcp_tool["name"] == "ConvertToolkit_SampleTool" + assert mcp_tool["description"] + # Ensure input schema contains both parameters and marks them required + props = mcp_tool["inputSchema"]["properties"] + assert set(props.keys()) == {"x", "y"} + + required_fields = set(mcp_tool["inputSchema"].get("required", [])) + # Ensure no unexpected required fields and that declared ones are subset of expected + assert required_fields.issubset({"x", "y"}) diff --git a/arcade/tests/mcp/test_message_processor.py b/arcade/tests/mcp/test_message_processor.py new file mode 100644 index 00000000..fac0d91a --- /dev/null +++ b/arcade/tests/mcp/test_message_processor.py @@ -0,0 +1,57 @@ +import asyncio + +import pytest + +from arcade.worker.mcp.message_processor import MCPMessageProcessor, create_message_processor +from arcade.worker.mcp.types import InitializeRequest, PingRequest + + +@pytest.mark.asyncio +async def test_message_processor_parses_initialize_json(): + """Ensure JSON initialize strings are converted into InitializeRequest objects.""" + json_init = '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}\n' + processor = MCPMessageProcessor() + + result = await processor.process_request(json_init) + + assert isinstance(result, InitializeRequest) + assert result.id == 1 + assert result.method == "initialize" + + +@pytest.mark.asyncio +async def test_message_processor_passes_notifications_unchanged(): + """Unknown notifications should be passed through as parsed dictionaries without errors.""" + json_notification = '{"jsonrpc":"2.0","id":null,"method":"notifications/custom","params":{}}\n' + processor = MCPMessageProcessor() + + result = await processor.process_request(json_notification) + + # The MCPMessageProcessor keeps unknown notifications as simple dicts + assert isinstance(result, dict) + assert result["method"] == "notifications/custom" + + +@pytest.mark.asyncio +async def test_message_processor_middleware_execution_order(monkeypatch): + """Middleware (sync + async) should be executed in the order they were added.""" + + order: list[str] = [] + + def mw_sync(msg, direction): # type: ignore[return-value] + order.append("sync") + return msg + + async def mw_async(msg, direction): # type: ignore[return-value] + await asyncio.sleep(0) # ensure it is truly async + order.append("async") + return msg + + processor = create_message_processor(mw_sync, mw_async) + + # Use a pre-parsed PingRequest instance so we don't test parsing again here + ping = PingRequest(id=42) + + _ = await processor.process_request(ping) + + assert order == ["sync", "async"] diff --git a/arcade/tests/mcp/test_server.py b/arcade/tests/mcp/test_server.py new file mode 100644 index 00000000..bfb75eeb --- /dev/null +++ b/arcade/tests/mcp/test_server.py @@ -0,0 +1,153 @@ +import sys +import types +from typing import Annotated, Any + +import pytest + +from arcade.core.catalog import ToolCatalog +from arcade.sdk import tool +from arcade.worker.mcp import server as mcp_server +from arcade.worker.mcp.types import ( + CallToolRequest, + CancelRequest, + InitializeRequest, + ListToolsRequest, + PingRequest, +) + +# --------------------------------------------------------------------------- +# Test helpers / stubs +# --------------------------------------------------------------------------- + + +class _FakeAuth: + async def authorize(self, auth_requirement: Any, user_id: str): + """Return an object that mimics AuthorizationResponse with completed status.""" + + class _Ctx: # minimal stub + token = "dummy-token" # noqa: S105 + + class _Resp: # pylint: disable=too-few-public-methods + status = "completed" + url = "" + context = _Ctx() + + return _Resp() + + +class _FakeArcade: # pylint: disable=too-few-public-methods + def __init__(self, **_: Any): + self.auth = _FakeAuth() + + +# Ensure that the AsyncArcade & ArcadeError symbols inside server.py point to our stubs. +pytestmark = pytest.mark.asyncio + + +@pytest.fixture(autouse=True) +def _patch_arcadepy(monkeypatch): + """Patch the external `arcadepy` dependency used by mcp.server.""" + + # Patch the imported symbols on the already-imported server module + monkeypatch.setattr(mcp_server, "AsyncArcade", _FakeArcade, raising=True) + monkeypatch.setattr(mcp_server, "ArcadeError", Exception, raising=True) + + # Provide a dummy `arcadepy` module in sys.modules for any other importers + fake_arcadepy = types.ModuleType("arcadepy") + fake_arcadepy.AsyncArcade = _FakeArcade # type: ignore[attr-defined] + fake_arcadepy.ArcadeError = Exception # type: ignore[attr-defined] + sys.modules["arcadepy"] = fake_arcadepy + + yield + + # Cleanup + sys.modules.pop("arcadepy", None) + + +# --------------------------------------------------------------------------- +# Fixtures for a sample tool / catalog / server +# --------------------------------------------------------------------------- + + +@tool +def multiply(a: Annotated[int, "a"], b: Annotated[int, "b"]) -> Annotated[int, "result"]: + """Return the product of *a* and *b*.""" + + return a * b + + +@pytest.fixture(scope="module") +def sample_catalog(): + catalog = ToolCatalog() + catalog.add_tool(multiply, "test_toolkit") + return catalog + + +@pytest.fixture() +def server(sample_catalog): + # MCPServer constructor is synchronous, so fixture need not be async + return mcp_server.MCPServer(sample_catalog, enable_logging=False) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +async def test_handle_ping(server): + req = PingRequest(id=123) + resp = await server._handle_ping(req) # pylint: disable=protected-access + assert resp.id == 123 + assert resp.result == {"pong": True} + + +async def test_handle_initialize(server): + req = InitializeRequest(id=1) + resp = await server._handle_initialize(req) # pylint: disable=protected-access + assert resp.id == 1 + assert resp.result.protocolVersion == mcp_server.MCP_PROTOCOL_VERSION + assert resp.result.serverInfo.name.startswith("Arcade") + + +async def test_handle_list_tools(server): + req = ListToolsRequest(id=99) + resp = await server._handle_list_tools(req) # pylint: disable=protected-access + assert resp.id == 99 + # Should list our sample tool only + tool_names = [t.name for t in resp.result.tools] + assert "TestToolkit_Multiply" in tool_names # toolkit + "_" + tool + + +async def test_handle_call_tool_success(server): + req = CallToolRequest( + id="call-1", + params={ + "name": "TestToolkit_Multiply", + "input": {"a": 6, "b": 7}, + }, + ) + resp = await server._handle_call_tool(req, user_id="tester@example.com") # pylint: disable=protected-access + + assert resp.id == "call-1" + # convert_to_mcp_content wraps primitives in list-of-dicts + assert resp.result.content == [{"type": "text", "text": "42"}] + + +async def test_send_response_dict(server, monkeypatch): + """_send_response should JSON-serialize plain dictionaries.""" + + sent: list[str] = [] + + class _Write: + async def send(self, msg): + sent.append(msg) + + await server._send_response(_Write(), {"foo": "bar"}) # pylint: disable=protected-access + + assert sent and sent[0].strip() == '{"foo": "bar"}' + + +async def test_handle_cancel(server): + req = CancelRequest(id=77, params={"id": "abc"}) + resp = await server._handle_cancel(req) # pylint: disable=protected-access + assert resp.result == {"ok": True} diff --git a/arcade/tests/mcp/test_stdio.py b/arcade/tests/mcp/test_stdio.py new file mode 100644 index 00000000..6d512157 --- /dev/null +++ b/arcade/tests/mcp/test_stdio.py @@ -0,0 +1,32 @@ +import io +import queue + +from arcade.worker.mcp.stdio import stdio_reader, stdio_writer + + +def test_stdio_reader_puts_lines_and_none(): + q: queue.Queue[str | None] = queue.Queue() + test_input = io.StringIO("line1\nline2\n") + + stdio_reader(test_input, q) + + # We should get the two lines followed by None sentinel + assert q.get_nowait() == "line1\n" + assert q.get_nowait() == "line2\n" + assert q.get_nowait() is None + + +def test_stdio_writer_reads_until_none(): + q: queue.Queue[str | None] = queue.Queue() + output_stream = io.StringIO() + + # preload queue with two messages and sentinel + q.put("msg1") + q.put("msg2\n") + q.put(None) + + stdio_writer(output_stream, q) + + # Ensure writer appended newlines when missing + output_stream.seek(0) + assert output_stream.read() == "msg1\nmsg2\n" diff --git a/arcade/tests/worker/test_worker_base.py b/arcade/tests/worker/test_worker_base.py new file mode 100644 index 00000000..f50eb871 --- /dev/null +++ b/arcade/tests/worker/test_worker_base.py @@ -0,0 +1,237 @@ +import os +from typing import Annotated +from unittest.mock import MagicMock + +import pytest + +from arcade.core.errors import ToolDefinitionError +from arcade.core.schema import ( + ToolCallRequest, + ToolCallResponse, + ToolContext, + ToolReference, +) +from arcade.sdk import tool +from arcade.worker.core.base import BaseWorker +from arcade.worker.core.common import RequestData, Router +from arcade.worker.core.components import ( + CallToolComponent, + CatalogComponent, + HealthCheckComponent, +) + + +@tool() +def sample_tool( + context: ToolContext, a: Annotated[int, "a"], b: Annotated[int, "b"] +) -> Annotated[int, "output"]: + """Sample tool for testing.""" + return a + b + + +# Define error tool at module level to avoid indentation issues with getsource +@tool() +def error_tool(context: ToolContext) -> int: + """This tool always raises an error.""" + raise ValueError("Something went wrong") + + +@pytest.fixture +def mock_router(): + router = MagicMock(spec=Router) + router.add_route = MagicMock() + return router + + +@pytest.fixture +def base_worker(mock_router): + # Set env var temporarily for testing secret loading + os.environ["ARCADE_WORKER_SECRET"] = "test_secret_env" # noqa: S105 + worker = BaseWorker() + worker.register_routes(mock_router) # Register routes using the mock router + # Clean up env var + del os.environ["ARCADE_WORKER_SECRET"] + return worker + + +@pytest.fixture +def base_worker_no_auth(): + return BaseWorker(disable_auth=True) + + +# --- BaseWorker Tests --- + + +def test_base_worker_init_with_secret(): + worker = BaseWorker(secret="explicit_secret") # noqa: S106 + assert worker.secret == "explicit_secret" # noqa: S105 + assert not worker.disable_auth + + +def test_base_worker_init_with_env_secret(): + os.environ["ARCADE_WORKER_SECRET"] = "env_secret_value" # noqa: S105 + worker = BaseWorker() + assert worker.secret == "env_secret_value" # noqa: S105 + assert not worker.disable_auth + del os.environ["ARCADE_WORKER_SECRET"] + + +def test_base_worker_init_no_secret_raises_error(): + # Ensure env var is not set + if "ARCADE_WORKER_SECRET" in os.environ: + del os.environ["ARCADE_WORKER_SECRET"] + with pytest.raises(ValueError, match="No secret provided for worker"): + BaseWorker() + + +def test_base_worker_init_disable_auth(): + worker = BaseWorker(disable_auth=True) + assert worker.secret == "" + assert worker.disable_auth + + +def test_register_tool(base_worker_no_auth): + assert len(base_worker_no_auth.catalog) == 0 + base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit") + assert len(base_worker_no_auth.catalog) == 1 + tool_def = base_worker_no_auth.get_catalog()[0] + assert tool_def.name == "SampleTool" + assert tool_def.toolkit.name == "TestKit" + + +def test_get_catalog(base_worker_no_auth): + base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit") + catalog = base_worker_no_auth.get_catalog() + assert isinstance(catalog, list) + assert len(catalog) == 1 + assert catalog[0].name == "SampleTool" + + +def test_health_check(base_worker_no_auth): + base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit") + health = base_worker_no_auth.health_check() + assert health == {"status": "ok", "tool_count": "1"} + + +@pytest.mark.asyncio +async def test_call_tool_success(base_worker_no_auth): + base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit") + # Create ToolReference WITHOUT version, as register_tool doesn't seem to set it + tool_ref = ToolReference(toolkit="TestKit", name="SampleTool") + tool_request = ToolCallRequest( + execution_id="test_exec_id", + tool=tool_ref, + inputs={"a": 5, "b": 3}, + ) + + response = await base_worker_no_auth.call_tool(tool_request) + + assert response.success is True + assert response.output.value == 8 + assert response.output.error is None + assert response.execution_id == "test_exec_id" + assert response.duration > 0 + + +@pytest.mark.asyncio +async def test_call_tool_execution_error(base_worker_no_auth): + # Tool is now defined at module level + try: + base_worker_no_auth.register_tool(error_tool, toolkit_name="error_kit") + except ToolDefinitionError as e: + pytest.fail(f"Failed to register error_tool: {e}") + + # Create ToolReference WITHOUT version + tool_ref = ToolReference(toolkit="ErrorKit", name="ErrorTool") + tool_request = ToolCallRequest( + execution_id="test_exec_error", + tool=tool_ref, + inputs={}, + ) + + response = await base_worker_no_auth.call_tool(tool_request) + + assert response.success is False + assert response.output.value is None + assert response.output.error is not None + + +@pytest.mark.asyncio +async def test_call_tool_not_found(base_worker_no_auth): + # Use ToolReference without version for lookup consistency + tool_ref = ToolReference(toolkit="nonexistent", name="nosuchtool") + tool_request = ToolCallRequest( + execution_id="test_exec_notfound", + tool=tool_ref, + inputs={}, + ) + + # Update regex to match actual error format + with pytest.raises(ValueError): + await base_worker_no_auth.call_tool(tool_request) + + +# --- Component Tests (tested via BaseWorker registration) --- + + +def test_register_routes_registers_default_components(base_worker, mock_router): + # BaseWorker calls register_routes in its init via the fixture + assert mock_router.add_route.call_count == len(BaseWorker.default_components) + + calls = mock_router.add_route.call_args_list + expected_paths = ["tools", "tools/invoke", "health"] + registered_paths = [ + call[0][0] for call in calls + ] # call[0] are positional args, call[0][0] is endpoint_path + + assert sorted(registered_paths) == sorted(expected_paths) + + # Check if components were instantiated and passed to add_route + assert any(isinstance(call[0][1], CatalogComponent) for call in calls) + assert any(isinstance(call[0][1], CallToolComponent) for call in calls) + assert any(isinstance(call[0][1], HealthCheckComponent) for call in calls) + + +@pytest.mark.asyncio +async def test_catalog_component_call(base_worker_no_auth): + base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit") + component = CatalogComponent(base_worker_no_auth) + # Mock request data - not actually used by this component's __call__ + mock_request = MagicMock(spec=RequestData) + catalog_response = await component(mock_request) + + assert isinstance(catalog_response, list) + assert len(catalog_response) == 1 + assert catalog_response[0].name == "SampleTool" + + +@pytest.mark.asyncio +async def test_call_tool_component_call(base_worker_no_auth): + base_worker_no_auth.register_tool(sample_tool, toolkit_name="test_kit") + component = CallToolComponent(base_worker_no_auth) + + # Create ToolReference WITHOUT version + tool_ref = ToolReference(toolkit="TestKit", name="SampleTool") + request_body = { + "execution_id": "comp_test_exec", + "tool": tool_ref.model_dump(), + "inputs": {"a": 10, "b": 5}, + } + mock_request = MagicMock(spec=RequestData) + mock_request.body_json = request_body + + response = await component(mock_request) + + assert isinstance(response, ToolCallResponse) + assert response.success is True + assert response.output.value == 15 + assert response.execution_id == "comp_test_exec" + + +@pytest.mark.asyncio +async def test_health_check_component_call(base_worker_no_auth): + component = HealthCheckComponent(base_worker_no_auth) + mock_request = MagicMock(spec=RequestData) + health_response = await component(mock_request) + + assert health_response == {"status": "ok", "tool_count": "0"} diff --git a/arcade/tests/worker/test_worker_fastapi.py b/arcade/tests/worker/test_worker_fastapi.py new file mode 100644 index 00000000..83c6f105 --- /dev/null +++ b/arcade/tests/worker/test_worker_fastapi.py @@ -0,0 +1,167 @@ +from typing import Annotated + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from arcade.core.schema import ToolCallRequest, ToolContext, ToolReference +from arcade.sdk import tool +from arcade.worker.fastapi.worker import FastAPIWorker + + +@tool() +def sample_tool_fastapi( + context: ToolContext, x: Annotated[int, "x"], y: Annotated[str, "y"] +) -> Annotated[str, "output"]: + """A sample tool for FastAPI tests.""" + return f"{y}-{x}" + + +# Define tool at module level to avoid indentation issues with getsource +@tool() +def error_throwing_tool( + context: ToolContext, + a: Annotated[int, "a", "Input integer a"], # Added description for parameter +) -> int: + """This tool throws a ValueError.""" # Added description for tool + raise ValueError("Test execution error") + + +@pytest.fixture +def test_app(): + return FastAPI() + + +@pytest.fixture +def worker_secret(): + return "test-secret-fastapi" + + +@pytest.fixture +def fastapi_worker(test_app, worker_secret): + worker = FastAPIWorker(app=test_app, secret=worker_secret) + worker.register_tool(sample_tool_fastapi, toolkit_name="fastapi_kit") + return worker + + +@pytest.fixture +def fastapi_worker_no_auth(test_app): + worker = FastAPIWorker(app=test_app, disable_auth=True) + worker.register_tool(sample_tool_fastapi, toolkit_name="fastapi_kit") + return worker + + +@pytest.fixture +def client(test_app, fastapi_worker): # Use the worker fixture to ensure routes are registered + return TestClient(test_app) + + +@pytest.fixture +def client_no_auth(test_app, fastapi_worker_no_auth): + return TestClient(test_app) + + +# --- FastAPIWorker Tests --- + + +def test_fastapi_worker_registers_routes(client, fastapi_worker): + # Check if routes exist by trying to access them (even if auth fails) + response = client.get(f"{fastapi_worker.base_path}/health") + assert response.status_code != 404 # Should be 200 + + response = client.get(f"{fastapi_worker.base_path}/tools") + assert response.status_code != 404 # Should be 403 without auth + + # Prepare a dummy request body for invoke + tool_ref = ToolReference(toolkit="FastapiKit", name="SampleToolFastapi") + request_body = ToolCallRequest( + execution_id="test", tool=tool_ref, inputs={"x": 1, "y": "test"} + ).model_dump() + + response = client.post(f"{fastapi_worker.base_path}/tools/invoke", json=request_body) + assert response.status_code != 404 # Should be 403 without auth + + +# --- Route Tests (using TestClient) --- + + +# Health Check +def test_health_check_route(client, worker_secret): + response = client.get("/worker/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok", "tool_count": "1"} + + +def test_health_check_route_no_auth(client_no_auth): + response = client_no_auth.get("/worker/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok", "tool_count": "1"} + + +# Catalog +def test_get_catalog_route_no_auth_header(client): + response = client.get("/worker/tools") + assert response.status_code == 403 + assert "Not authenticated" in response.text + + +def test_get_catalog_route_invalid_auth_header(client, worker_secret): + response = client.get("/worker/tools", headers={"Authorization": "Bearer invalid-token"}) + assert response.status_code == 401 # Unauthorized + # Updated expected error message based on last run + assert "Invalid token. Error: Not enough segments" in response.text + + +def test_get_catalog_route_no_auth_worker(client_no_auth): + response = client_no_auth.get("/worker/tools") + assert response.status_code == 200 + catalog = response.json() + assert isinstance(catalog, list) + assert len(catalog) == 1 + assert catalog[0]["name"] == "SampleToolFastapi" + + +# Call Tool +@pytest.fixture +def call_tool_payload(): + tool_ref = ToolReference(toolkit="FastapiKit", name="SampleToolFastapi") + return ToolCallRequest( + execution_id="fastapi-test-exec", tool=tool_ref, inputs={"x": 123, "y": "hello"} + ).model_dump() + + +def test_call_tool_route_no_auth_header(client, call_tool_payload): + response = client.post("/worker/tools/invoke", json=call_tool_payload) + assert response.status_code == 403 + + +def test_call_tool_route_invalid_auth_header(client, worker_secret, call_tool_payload): + response = client.post( + "/worker/tools/invoke", + json=call_tool_payload, + headers={"Authorization": "Bearer invalid-token"}, + ) + assert response.status_code == 401 + + +def test_call_tool_route_no_auth_worker(client_no_auth, call_tool_payload): + response = client_no_auth.post("/worker/tools/invoke", json=call_tool_payload) + assert response.status_code == 200 + result = response.json() + assert result["success"] is True + assert result["output"]["value"] == "hello-123" + + +def test_call_tool_route_tool_not_found(client_no_auth, call_tool_payload): + call_tool_payload["tool"]["name"] = "NonExistentTool" + call_tool_payload["tool"]["toolkit"] = "FastapiKit" + + with pytest.raises(ValueError): + _ = client_no_auth.post( + "/worker/tools/invoke", + json=call_tool_payload, + ) + # The handler catches the ValueError and returns a 500 internal server error + # Ideally, this might be a 404 or 400, but BaseWorker.call_tool raises ValueError + # which isn't automatically mapped to a 4xx by FastAPI unless handled explicitly. + # TODO fix this. diff --git a/examples/mcp/claude.json b/examples/mcp/claude.json new file mode 100644 index 00000000..d3f9832f --- /dev/null +++ b/examples/mcp/claude.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "arcade": { + "command": "bash", + "args": ["-c", "export ARCADE_API_KEY=arc_xxxx && /path/to/python /path/to/arcade mcp"] + } + } +} diff --git a/examples/mcp/run_stdio.py b/examples/mcp/run_stdio.py new file mode 100644 index 00000000..3011640e --- /dev/null +++ b/examples/mcp/run_stdio.py @@ -0,0 +1,25 @@ +import arcade_google # pip install arcade_google +import arcade_search # pip install arcade_search + +from arcade.core.catalog import ToolCatalog +from arcade.worker.mcp.stdio import StdioServer + +# 2. Create and populate the tool catalog +catalog = ToolCatalog() +catalog.add_module(arcade_google) # Registers all tools in the package +catalog.add_module(arcade_search) + + +# 3. Main entrypoint +async def main(): + # Create the worker with the tool catalog + worker = StdioServer(catalog) + + # Run the worker + await worker.run() + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/examples/simple_chatbot.py b/examples/simple_chatbot.py deleted file mode 100644 index 859d914e..00000000 --- a/examples/simple_chatbot.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -Example script demonstrating how to build a simple chatbot with Arcade. - -For this example, we are using the prebuilt Google Docs toolkit to create and edit documents. - -Try asking questions like: -- "Create a document with the title 'My New Document' and content 'Hello, World!'" -- "List my 2 most recently modified documents and tell me the title, document id, and document URL of each one and summarize them." -- "Edit the second document from the list you just returned and add the text 'Hello, World!' to the end of it." -""" - -import os - -from openai import OpenAI - - -def chat(openai_client: OpenAI, tool_names: list[str], user_id: str) -> None: - history = [] - - print("Hello! How can I help you today?") - while True: - message = {"role": "user", "content": input(">")} - history.append(message) - chat_result = call_tool_with_openai(openai_client, tool_names, user_id, history) - - # If the tool call requires authorization, then wait for the user to authorize and then call the tool again - if ( - chat_result.choices[0].tool_authorizations - and chat_result.choices[0].tool_authorizations[0].get("status") == "pending" - ): - print("\n" + chat_result.choices[0].message.content) - input("\nAfter you have authorized, press Enter to continue...") - chat_result = call_tool_with_openai(openai_client, tool_names, user_id, history) - - history.append({"role": "assistant", "content": chat_result.choices[0].message.content}) - - print(chat_result.choices[0].message.content) - - -def call_tool_with_openai( - client: OpenAI, tool_names: list[str], user_id: str, messages: list[dict] -) -> dict: - response = client.chat.completions.create( - messages=messages, - model="gpt-4o-mini", - user=user_id, - tools=tool_names, - tool_choice="generate", - ) - - return response - - -if __name__ == "__main__": - arcade_api_key = os.environ.get( - "ARCADE_API_KEY" - ) # If you forget your Arcade API key, it is stored at ~/.arcade/credentials.yaml on `arcade login` - cloud_host = "https://api.arcade.dev/v1" - user_id = "user@example.com" - - openai_client = OpenAI( - api_key=arcade_api_key, - base_url=cloud_host, - ) - - tool_names = [ - "Google.SendEmail", - "Google.SendDraftEmail", - "Google.WriteDraftEmail", - "Google.UpdateDraftEmail", - "Google.ListDraftEmails", - "Google.ListEmailsByHeader", - "Google.ListEmails", - ] - - chat(openai_client, tool_names, user_id)