diff --git a/examples/mcp_servers/authorization/.dockerignore b/examples/mcp_servers/authorization/.dockerignore new file mode 100644 index 00000000..ea8cf11c --- /dev/null +++ b/examples/mcp_servers/authorization/.dockerignore @@ -0,0 +1,33 @@ +# Virtual environment +.venv/ +venv/ +env/ + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db + +# Distribution +dist/ +build/ +*.egg-info/ + +# Docker +docker/ +.dockerignore +Dockerfile +docker-compose.yml diff --git a/examples/mcp_servers/authorization/docker/Dockerfile b/examples/mcp_servers/authorization/docker/Dockerfile new file mode 100644 index 00000000..ffeeae71 --- /dev/null +++ b/examples/mcp_servers/authorization/docker/Dockerfile @@ -0,0 +1,45 @@ +FROM ghcr.io/astral-sh/uv:python3.11-bookworm-slim + +# Create non-root user +RUN useradd -m -u 1000 appuser + +WORKDIR /app + +# Copy project files +COPY pyproject.toml uv.lock ./ +COPY src/ ./src/ + +# Auto-detect package name from pyproject.toml +# First try using Python's tomllib +# Fallback to grep/sed for compatibility +RUN PACKAGE_NAME=$(python3 -c "import tomllib; f=open('pyproject.toml','rb'); data=tomllib.load(f); print(data['project']['name'])" 2>/dev/null || \ + grep -E '^name\s*=' pyproject.toml | head -1 | sed -E "s/.*name\s*=\s*[\"']([^\"']+)[\"'].*/\1/" || \ + grep -E '^name\s*=' pyproject.toml | head -1 | sed -E 's/.*name\s*=\s*([^ ]+).*/\1/') && \ + if [ -z "$PACKAGE_NAME" ]; then \ + echo "ERROR: Could not detect package name from pyproject.toml" && exit 1; \ + fi && \ + echo "Detected package: $PACKAGE_NAME" && \ + echo "$PACKAGE_NAME" > /tmp/package_name.txt + +# Install dependencies +RUN uv sync --frozen --no-dev + +# Change ownership to non-root user +RUN chown -R appuser:appuser /app + +USER appuser + +# Expose the port +EXPOSE 8001 + +# Run the server from src//server.py +CMD PACKAGE_NAME=$(cat /tmp/package_name.txt) && \ + if [ -f "src/${PACKAGE_NAME}/server.py" ]; then \ + uv run src/${PACKAGE_NAME}/server.py; \ + else \ + echo "ERROR: Could not find server.py at src/${PACKAGE_NAME}/server.py" && \ + echo " Package detected: ${PACKAGE_NAME}" && \ + echo " Available directories in src/:" && \ + ls -la src/ 2>/dev/null || echo " src/ directory not found" && \ + exit 1; \ + fi diff --git a/examples/mcp_servers/authorization/docker/README.md b/examples/mcp_servers/authorization/docker/README.md new file mode 100644 index 00000000..ec27839e --- /dev/null +++ b/examples/mcp_servers/authorization/docker/README.md @@ -0,0 +1,93 @@ +# Docker Setup for MCP Servers + +This directory contains a generalized Docker configuration template that can be used with any MCP server in this repository. + +## Quick Start + +1. **Copy the Docker files to your MCP server directory:** + + ```bash + cp -r examples/docker-template/docker your-mcp-server/ + cp examples/docker-template/.dockerignore your-mcp-server/ + ``` + +2. **Build and run:** + + ```bash + cd your-mcp-server + docker-compose -f docker/docker-compose.yml up --build + ``` + +## Configuration + +### Package Detection + +The Dockerfile uses the package name from `pyproject.toml` by reading the `[project] name` field. It expects your server file at `src//server.py` (where `` is from `pyproject.toml`). + +If the server file is not found at this location, then the build will fail with an error message showing the detected package name and available directories in `src/`. + +### Environment Variables + +- `ARCADE_SERVER_TRANSPORT`: The transport protocol to use + - Default: `http` + - Options: `http`, `stdio` +- `ARCADE_SERVER_PORT`: The port to run the server on (internal) + - Default: `8001` +- `ARCADE_SERVER_HOST`: The host to bind to + - Default: `0.0.0.0` + +### Example: Simple MCP Server + +```bash +# From examples/mcp_servers/simple/ +docker-compose -f docker/docker-compose.yml up --build +``` + +The server will run internally on port 8001 but be accessible externally on port 8080 (http://localhost:8080). This demonstrates front-door auth working when the canonical URL differs from the internal bind address. + +You can customize the ports by editing `docker/docker-compose.yml` and changing: +- The port mapping (e.g., "8080:8001") +- The `ARCADE_SERVER_PORT` environment variable (internal port) +- The `MCP_RESOURCE_SERVER_CANONICAL_URL` (external URL) +## Building the Image + +```bash +docker build \ + -f docker/Dockerfile \ + -t your-mcp-server \ + . +``` + +## Running with Docker + +```bash +docker run -p 8080:8001 \ + -e ARCADE_SERVER_TRANSPORT=http \ + -e ARCADE_SERVER_HOST=0.0.0.0 \ + -e ARCADE_SERVER_PORT=8001 \ + your-mcp-server +``` + +## Features + +- **Automatic package detection**: Reads package name from `pyproject.toml` +- **Standard server location**: Expects server file at `src//server.py` +- **Secure by default**: Runs as non-root user +- **Arcade environment variable support**: Uses `ARCADE_SERVER_*` environment variables +- **Environment-based config**: Easy customization via environment variables +- **uv integration**: Uses uv for fast dependency management +- **Lightweight**: Based on Python 3.11 Bookworm slim image with uv + +## Connecting from Cursor + +Add to your `~/.cursor/mcp.json`: + +```json +"your-server-name": { + "name": "your-server-name", + "type": "stream", + "url": "http://localhost:8080/mcp" +} +``` + +Then restart Cursor to connect to the server. diff --git a/examples/mcp_servers/authorization/docker/docker-compose.yml b/examples/mcp_servers/authorization/docker/docker-compose.yml new file mode 100644 index 00000000..86091fb5 --- /dev/null +++ b/examples/mcp_servers/authorization/docker/docker-compose.yml @@ -0,0 +1,12 @@ +services: + mcp-server: + build: + context: .. + dockerfile: docker/Dockerfile + ports: + - "8080:8001" # External port 8080 maps to internal port 8001 + environment: + - ARCADE_SERVER_TRANSPORT=http + - ARCADE_SERVER_HOST=0.0.0.0 + - ARCADE_SERVER_PORT=8001 + - MCP_RESOURCE_SERVER_CANONICAL_URL=http://127.0.0.1:8080/mcp diff --git a/examples/mcp_servers/authorization/pyproject.toml b/examples/mcp_servers/authorization/pyproject.toml new file mode 100644 index 00000000..c71ea451 --- /dev/null +++ b/examples/mcp_servers/authorization/pyproject.toml @@ -0,0 +1,45 @@ +[project] +name = "authorization" +version = "0.1.0" +description = "MCP Server created with Arcade.dev" +requires-python = ">=3.10" +dependencies = [ + "arcade-mcp-server>=1.8.0,<2.0.0", + "httpx>=0.28.0,<1.0.0", +] + +[project.optional-dependencies] +dev = [ + "arcade-mcp[all]>=1.5.2,<2.0.0", + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "mypy>=1.0.0", + "ruff>=0.1.0", +] + +# Tell Arcade.dev that this package has Arcade tools +[project.entry-points.arcade_toolkits] +toolkit_name = "authorization" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/authorization"] + +[tool.ruff] +line-length = 100 +target-version = "py312" + +[tool.mypy] +python_version = "3.12" +warn_unused_configs = true +disallow_untyped_defs = false + +# # Uncomment the following if you are developing inside of the arcade-mcp repo & want to use editable mode +# # Otherwise, you will install the following packages from PyPI +# [tool.uv.sources] +# arcade-mcp = { path = "../../../", editable = true } +# arcade-serve = { path = "../../../libs/arcade-serve/", editable = true } +# arcade-mcp-server = { path = "../../../libs/arcade-mcp-server/", editable = true } diff --git a/examples/mcp_servers/authorization/src/authorization/.env.example b/examples/mcp_servers/authorization/src/authorization/.env.example new file mode 100644 index 00000000..cdbf0cc9 --- /dev/null +++ b/examples/mcp_servers/authorization/src/authorization/.env.example @@ -0,0 +1,16 @@ +# Server Auth environment variables +MCP_RESOURCE_SERVER_CANONICAL_URL="http://127.0.0.1:8000/mcp" +MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS='[ + { + "authorization_server_url": "https://your-workos.authkit.app", + "issuer": "https://your-workos.authkit.app", + "jwks_uri": "https://your-workos.authkit.app/oauth2/jwks", + "algorithm": "RS256", + "verify_options": { + "verify_aud": false + } + } +]' + +# Tool Secrets +MY_SECRET_KEY="Your tools can have secrets injected at runtime!" diff --git a/examples/mcp_servers/authorization/src/authorization/__init__.py b/examples/mcp_servers/authorization/src/authorization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/mcp_servers/authorization/src/authorization/server.py b/examples/mcp_servers/authorization/src/authorization/server.py new file mode 100644 index 00000000..e6635dc7 --- /dev/null +++ b/examples/mcp_servers/authorization/src/authorization/server.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +"""authorization MCP server""" + +from typing import Annotated + +import httpx +from arcade_mcp_server import Context, MCPApp +from arcade_mcp_server.auth import Reddit +from arcade_mcp_server.resource_server import ( + AuthorizationServerEntry, + ResourceServerAuth, +) + +# Option 1: Single authorization server with custom audience +# Use expected_audiences when your auth server returns a non-standard audience (aud) claim +# (e.g., client_id instead of canonical_url) +resource_server_auth = ResourceServerAuth( + canonical_url="http://127.0.0.1:8000/mcp", + authorization_servers=[ + AuthorizationServerEntry( # WorkOS Authkit example configuration + authorization_server_url="https://your-workos.authkit.app", + issuer="https://your-workos.authkit.app", + jwks_uri="https://your-workos.authkit.app/oauth2/jwks", + expected_audiences=["your-authkit-client-id"], # Override expected aud claim + ), + ], +) + +# Option 2: Multiple authorization servers with different keys (e.g., multi-IdP) +# resource_server_auth = ResourceServerAuth( +# canonical_url="http://127.0.0.1:8000/mcp", +# authorization_servers=[ +# AuthorizationServerEntry( # WorkOS Authkit example configuration +# authorization_server_url="https://your-workos.authkit.app", +# issuer="https://your-workos.authkit.app", +# jwks_uri="https://your-workos.authkit.app/oauth2/jwks", +# expected_audiences=["your-authkit-client-id"], +# ), +# AuthorizationServerEntry( # Keycloak example configuration +# authorization_server_url="http://localhost:8080/realms/mcp-test", +# issuer="http://localhost:8080/realms/mcp-test", +# jwks_uri="http://localhost:8080/realms/mcp-test/protocol/openid-connect/certs", +# algorithm="RS256", +# expected_audiences=["your-keycloak-client-id"], +# ) +# ], +# ) + +# Option 3: Authorization via env vars (place in your .env file) +# ```bash +# MCP_RESOURCE_SERVER_CANONICAL_URL=http://127.0.0.1:8000/mcp +# MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS='[ +# { +# "authorization_server_url": "https://your-workos.authkit.app", +# "issuer": "https://your-workos.authkit.app", +# "jwks_uri": "https://your-workos.authkit.app/oauth2/jwks", +# "algorithm": "RS256", +# "expected_audiences": ["your-authkit-client-id"] +# } +# ]' +# ``` +# resource_server_auth = ResourceServerAuth() + +app = MCPApp(name="authorization", version="1.0.0", log_level="DEBUG", auth=resource_server_auth) + + +@app.tool +def greet(name: Annotated[str, "The name of the person to greet"]) -> str: + """Greet a person by name.""" + return f"Hello, {name}!" + + +@app.tool(requires_secrets=["MY_SECRET_KEY"]) +def whisper_secret(context: Context) -> Annotated[str, "The last 4 characters of the secret"]: + """Reveal the last 4 characters of a secret""" + try: + secret = context.get_secret("MY_SECRET_KEY") + except Exception as e: + return str(e) + + return "The last 4 characters of the secret are: " + secret[-4:] + + +# To use this tool locally, you need to install the Arcade CLI (uv tool install arcade-mcp) +# and then run 'arcade login' to authenticate. +@app.tool(requires_auth=Reddit(scopes=["read"])) +async def get_posts_in_subreddit( + context: Context, subreddit: Annotated[str, "The name of the subreddit"] +) -> dict: + """Get posts from a specific subreddit""" + subreddit = subreddit.lower().replace("r/", "").replace(" ", "") + oauth_token = context.get_auth_token_or_empty() + headers = { + "Authorization": f"Bearer {oauth_token}", + "User-Agent": "authorization-mcp-server", + } + params = {"limit": 5} + url = f"https://oauth.reddit.com/r/{subreddit}/hot" + + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers, params=params) + response.raise_for_status() + + return response.json() + + +if __name__ == "__main__": + app.run(transport="http", host="127.0.0.1", port=8000) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/__init__.py b/libs/arcade-mcp-server/arcade_mcp_server/__init__.py index 81214296..a33d801c 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/__init__.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/__init__.py @@ -36,7 +36,7 @@ __all__ = [ # Integrated Factory and Runner "create_arcade_mcp", "run_arcade_mcp", - # Re-exported TDK functionality + # Re-exported from TDK functionality "tool", ] diff --git a/libs/arcade-mcp-server/arcade_mcp_server/context.py b/libs/arcade-mcp-server/arcade_mcp_server/context.py index 9fc7a5e1..0bb6fd87 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/context.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/context.py @@ -37,6 +37,7 @@ from arcade_core.schema import ( ToolContext, ) +from arcade_mcp_server.resource_server.base import ResourceOwner from arcade_mcp_server.types import ( AudioContent, CallToolParams, @@ -124,6 +125,7 @@ class Context(ToolContext): server: Any, session: Any | None = None, request_id: str | None = None, + resource_owner: ResourceOwner | None = None, ): """Initialize context with server reference.""" super().__init__() @@ -133,6 +135,9 @@ class Context(ToolContext): self._notification_queue: set[str] = set() self._request_id: str | None = request_id + # Resource owner from front-door auth (if the server is protected) + self._resource_owner: ResourceOwner | None = resource_owner + # Namespaced adapters self._log = Logs(self) self._progress = Progress(self) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/fastapi/auth_routes.py b/libs/arcade-mcp-server/arcade_mcp_server/fastapi/auth_routes.py new file mode 100644 index 00000000..df9c9146 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/fastapi/auth_routes.py @@ -0,0 +1,98 @@ +"""FastAPI routes for MCP Resource Server authorization endpoints. + +The routes defined here enable MCP clients to discover authorization servers +associated with this MCP server. +""" + +import logging +from urllib.parse import urlparse + +from fastapi import APIRouter +from fastapi.responses import JSONResponse + +from arcade_mcp_server.resource_server.base import ResourceServerValidator + +logger = logging.getLogger(__name__) + + +def create_auth_router( + resource_server_validator: ResourceServerValidator, + canonical_url: str | None, +) -> APIRouter: + """Create FastAPI router with OAuth discovery endpoints. + + The well-known URI is constructed by inserting the well-known path after the host. + If the canonical URL has a path component, then it becomes a suffix on the well-known path. + + For example: + - canonical_url "https://example.com" -> "/.well-known/oauth-protected-resource" + - canonical_url "https://example.com/mcp" -> "/.well-known/oauth-protected-resource/mcp" + + Args: + resource_server_validator: The resource server validator instance + canonical_url: Canonical URL of the MCP server + + Returns: + APIRouter configured with OAuth discovery endpoints + """ + router = APIRouter(tags=["MCP Protocol"]) + + path_suffix = "" + if canonical_url: + parsed = urlparse(canonical_url) + path_suffix = parsed.path + + well_known_base = "/.well-known/oauth-protected-resource" + well_known_path = f"{well_known_base}{path_suffix}" + + async def oauth_protected_resource() -> JSONResponse: + """OAuth 2.0 Protected Resource Metadata (RFC 9728)""" + if not canonical_url: + return JSONResponse( + {"error": "Server canonical URL not configured"}, + status_code=500, + ) + + metadata = resource_server_validator.get_resource_metadata() + if metadata is None: + logger.error( + "Resource metadata unavailable for OAuth discovery endpoint. " + "This is unexpected - the validator should provide metadata if OAuth discovery is enabled." + ) + return JSONResponse( + {"error": "Resource metadata not available"}, + status_code=500, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + + return JSONResponse( + metadata, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + + # Register the well-known endpoint at the RFC 9728 compliant path + router.add_api_route( + well_known_path, + oauth_protected_resource, + methods=["GET"], + name="oauth_protected_resource", + ) + + # Also register at base path if there's a suffix for extra compatibility + if path_suffix: + router.add_api_route( + well_known_base, + oauth_protected_resource, + methods=["GET"], + include_in_schema=False, + ) + + return router diff --git a/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py b/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py index 3c4fb555..91c59471 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py @@ -19,12 +19,12 @@ from arcade_core.catalog import MaterializedTool, ToolCatalog, ToolDefinitionErr from arcade_tdk.auth import ToolAuthorization from arcade_tdk.error_adapters import ErrorAdapter from arcade_tdk.tool import tool as tool_decorator -from dotenv import load_dotenv from loguru import logger from watchfiles import watch from arcade_mcp_server.exceptions import ServerError from arcade_mcp_server.logging_utils import intercept_standard_logging +from arcade_mcp_server.resource_server.base import ResourceServerValidator from arcade_mcp_server.server import MCPServer from arcade_mcp_server.settings import MCPSettings, ServerSettings from arcade_mcp_server.types import Prompt, PromptMessage, Resource @@ -75,6 +75,7 @@ class MCPApp: host: str = "127.0.0.1", port: int = 8000, reload: bool = False, + auth: ResourceServerValidator | None = None, **kwargs: Any, ): """ @@ -90,6 +91,7 @@ class MCPApp: host: Host for transport port: Port for transport reload: Enable auto-reload for development + auth: Resource Server validator for front-door authentication **kwargs: Additional server configuration """ self._name = self._validate_name(name) @@ -97,6 +99,7 @@ class MCPApp: self.title = title or name self.instructions = instructions self.log_level = log_level + self.resource_server_validator = auth self.server_kwargs = kwargs self.transport = transport self.host = host @@ -123,7 +126,6 @@ class MCPApp: # Store the actual instructions that ended up in ServerSettings self.instructions = self._mcp_settings.server.instructions - self._load_env() if not logger._core.handlers: # type: ignore[attr-defined] self._setup_logging(transport == "stdio") @@ -193,13 +195,6 @@ class MCPApp: """Runtime resources API: add/remove/list.""" return _ResourcesAPI(self) - def _load_env(self) -> None: - """Load .env file from the current directory.""" - env_path = Path.cwd() / ".env" - if env_path.exists(): - load_dotenv(env_path, override=False) - logger.info(f"Loaded environment from {env_path}") - def _setup_logging(self, stdio_mode: bool = False) -> None: logger.remove() @@ -313,6 +308,24 @@ class MCPApp: logger.info(f"Starting {self._name} v{self.version} with {len(self._catalog)} tools") if transport in ["http", "streamable-http", "streamable"]: + resource_server_auth_enabled = isinstance( + self.resource_server_validator, ResourceServerValidator + ) + if resource_server_auth_enabled: + logger.info("Resource Server authentication is enabled. MCP routes are protected.") + else: + logger.warning( + "Resource Server authentication is disabled. MCP routes are not protected, so tools requiring auth or secrets will fail." + ) + if ( + isinstance(self.resource_server_validator, ResourceServerValidator) + and self.resource_server_validator.supports_oauth_discovery() + ): + metadata = self.resource_server_validator.get_resource_metadata() + if metadata: + auth_servers = metadata.get("authorization_servers", []) + logger.info(f"Accepted authorization server(s): {', '.join(auth_servers)}") + if reload: self._run_with_reload(host, port) else: @@ -326,6 +339,7 @@ class MCPApp: host=None, port=None, tool_count=len(self._catalog), + resource_server_validator=self.resource_server_validator, ) asyncio.run( run_stdio_server( @@ -403,6 +417,7 @@ class MCPApp: catalog=self._catalog, mcp_settings=self._mcp_settings, debug=debug, + resource_server_validator=self.resource_server_validator, **self.server_kwargs, ) @@ -412,6 +427,7 @@ class MCPApp: host=host, port=port, tool_count=len(self._catalog), + resource_server_validator=self.resource_server_validator, ) asyncio.run(serve_with_force_quit(app=app, host=host, port=port, log_level=log_level)) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/resource_server/README.md b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/README.md new file mode 100644 index 00000000..6213140b --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/README.md @@ -0,0 +1,161 @@ +# MCP Resource Server Authentication + +OAuth 2.1-compliant Resource Server authentication for securing HTTP-based MCP servers. + +## Overview + +The MCP server acts as an OAuth 2.1 **Resource Server**, validating Bearer tokens on **every HTTP request** before processing MCP protocol messages. This enables: + +1. **Secure HTTP Transport** - Protect your MCP server with OAuth 2.1 +2. **Tool-Level Authorization** - Enable tools requiring end-user OAuth on HTTP transport +3. **OAuth Discovery** - MCP clients automatically discover authentication requirements via OAuth Protected Resource Metadata (RFC 9728) +4. **User Context** - Tools receive authenticated resource owner identity from the Authorization Server + +MCP servers can accept tokens from one or more authorization servers. Accepting tokens from multiple authorization servers supports scenarios like regional endpoints, multiple identity providers, or migrating between auth systems. + +**Note:** The MCP server (Resource Server) doesn't need to know how MCP clients are registered with the Authorization Server (for example, Dynamic Client Registration, static client secrets, etc.) - that's the authorization server's concern. The MCP server simply validates tokens and advertises the AS URLs. + +## Environment Variable Configuration + +`ResourceServerAuth` supports environment variable configuration for production deployments. This is the **recommended approach for production**. + +**Note:** `JWKSTokenValidator` does not support environment variables and requires explicit programmatic parameters to its initializer + +### Supported Environment Variables + +| Environment Variable | Type | Description | Required | +|---------------------|------|-------------|----------| +| `MCP_RESOURCE_SERVER_CANONICAL_URL` | string | MCP server canonical URL | Yes | +| `MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS` | JSON array | Authorization server entries | Yes | + +The `MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS` must be a JSON array of entry objects. Each object should include: +- `authorization_server_url`: Authorization server URL +- `issuer`: Expected token issuer +- `jwks_uri`: JWKS endpoint URL +- `algorithm`: (Optional) JWT algorithm, defaults to RS256 +- `expected_audiences`: (Optional) list of expected audience claim values. If not provided, defaults to the canonical_url. Use this when your auth server returns a different aud claim (e.g., client_id). +- `validation_options`: (Optional) dict with optional `verify_exp`, `verify_iat`, `verify_iss`, `verify_nbf`, and `leeway` (int, seconds). All verify flags default to True. + +### Precedence Rules + +**Explicit parameters take precedence over environment variables:** + +```python +from arcade_mcp_server import MCPApp +from arcade_mcp_server.resource_server import ( + AuthorizationServerEntry, + ResourceServerAuth, +) + +# Explicit parameters override env vars (if both are provided) +resource_server_auth = ResourceServerAuth( + canonical_url="http://127.0.0.1:8000/mcp", # used even if env var is set + authorization_servers=[ # used even if env var is set + AuthorizationServerEntry( + authorization_server_url="https://your-workos.authkit.app", + issuer="https://your-workos.authkit.app", + jwks_uri="https://your-workos.authkit.app/oauth2/jwks", + algorithm="RS256", + # Override expected aud if auth server returns different audience (e.g., client_id) + expected_audiences=["my-authkit-client-id"], + ) + ], +) +app = MCPApp(name="Protected", auth=resource_server_auth) + +# If no parameters provided, env vars are used as fallback +resource_server_auth = ResourceServerAuth() # Uses MCP_RESOURCE_SERVER_* env vars +``` + +### Example .env File + +#### Single Authorization Server + +```bash +# Resource Server Configuration +MCP_RESOURCE_SERVER_CANONICAL_URL=https://mcp.example.com/mcp +MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS='[ + { + "authorization_server_url": "https://auth.example.com", + "issuer": "https://auth.example.com", + "jwks_uri": "https://auth.example.com/.well-known/jwks.json", + "algorithm": "RS256" + } +]' +``` + +#### Single Authorization Server (Custom Audience) + +When your auth server returns a different `aud` claim (e.g., client_id instead of canonical URL): + +```bash +MCP_RESOURCE_SERVER_CANONICAL_URL=https://mcp.example.com/mcp +MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS='[ + { + "authorization_server_url": "https://auth.example.com", + "issuer": "https://auth.example.com", + "jwks_uri": "https://auth.example.com/.well-known/jwks.json", + "algorithm": "RS256", + "expected_audiences": ["my-client-id"] + } +]' +``` + +#### Multiple Authorization Servers (Shared Keys) + +```bash +# Regional endpoints with shared keys +MCP_RESOURCE_SERVER_CANONICAL_URL=https://mcp.example.com/mcp +MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS='[ + { + "authorization_server_url": "https://auth-us.example.com", + "issuer": "https://auth.example.com", + "jwks_uri": "https://auth.example.com/.well-known/jwks.json" + }, + { + "authorization_server_url": "https://auth-eu.example.com", + "issuer": "https://auth.example.com", + "jwks_uri": "https://auth.example.com/.well-known/jwks.json" + } +]' +``` + +#### Multiple Authorization Servers (Different Keys) + +```bash +# Multi-IdP configuration with custom audiences +MCP_RESOURCE_SERVER_CANONICAL_URL=https://mcp.example.com/mcp +MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS='[ + { + "authorization_server_url": "https://workos.authkit.app", + "issuer": "https://workos.authkit.app", + "jwks_uri": "https://workos.authkit.app/oauth2/jwks", + "expected_audiences": ["my-workos-client-id"] + }, + { + "authorization_server_url": "http://localhost:8080/realms/mcp-test", + "issuer": "http://localhost:8080/realms/mcp-test", + "jwks_uri": "http://localhost:8080/realms/mcp-test/protocol/openid-connect/certs", + "expected_audiences": ["my-keycloak-client-id"] + } +]' +``` + +### How It Works + +1. **Resource Server validates tokens** - Extracts user identity from validated token's `sub` claim +2. **User ID flows to ToolContext** - Used for tool-level OAuth via Arcade platform +3. **Transport restriction lifted** - HTTP is now safe for tools requiring auth/secrets +4. **Separate authorization layers** - Resource Server auth != tool OAuth (but building a protected server enables tool authorization) + +## Vendor-Specific Implementations + +The `ResourceServerAuth` class is designed to be subclassed for vendor-specific implementations: + +```python +# Your vendor-specific implementations +class ArcadeResourceServerAuth(ResourceServerAuth): ... +class WorkOSResourceServerAuth(ResourceServerAuth): ... +class Auth0ResourceServerAuth(ResourceServerAuth): ... +class DescopeResourceServerAuth(ResourceServerAuth): ... +``` diff --git a/libs/arcade-mcp-server/arcade_mcp_server/resource_server/__init__.py b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/__init__.py new file mode 100644 index 00000000..cccbe6b4 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/__init__.py @@ -0,0 +1,23 @@ +""" +MCP Resource Server authentication. + +This module provides OAuth 2.1 Resource Server capabilities for MCP servers. +It enables MCP servers to validate Bearer tokens on every HTTP request +before processing MCP messages. +""" + +from arcade_mcp_server.resource_server.base import ( + AccessTokenValidationOptions, + AuthorizationServerEntry, +) +from arcade_mcp_server.resource_server.validators import ( + JWKSTokenValidator, + ResourceServerAuth, +) + +__all__ = [ + "AccessTokenValidationOptions", + "AuthorizationServerEntry", + "JWKSTokenValidator", + "ResourceServerAuth", +] diff --git a/libs/arcade-mcp-server/arcade_mcp_server/resource_server/base.py b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/base.py new file mode 100644 index 00000000..f7f05df1 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/base.py @@ -0,0 +1,168 @@ +"""Base classes for MCP Resource Server authentication.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +from pydantic import BaseModel, Field + + +class AccessTokenValidationOptions(BaseModel): + """Options for access token validation. + + All validations are enabled by default for security. + Set to False to disable specific validations for authorization servers + that are not compliant with MCP. + + Note: Token signature verification and audience validation are always enabled + and cannot be disabled. Additionally, the subject (sub claim) must always be + present in the token. + """ + + verify_exp: bool = Field( + default=True, + description="Verify token expiration (exp claim)", + ) + verify_iat: bool = Field( + default=True, + description="Verify issued-at time (iat claim)", + ) + verify_iss: bool = Field( + default=True, + description="Verify issuer claim (iss claim)", + ) + verify_nbf: bool = Field( + default=True, + description="Verify not-before time (nbf claim). Rejects tokens used before their activation time.", + ) + leeway: int = Field( + default=0, + description="Clock skew tolerance in seconds for exp/nbf validation. Recommended: 30-60 seconds.", + ) + + +@dataclass +class ResourceOwner: + """User information extracted from validated access token. + + This represents the authenticated resource owner (end-user) making requests + to the MCP server. The user_id typically comes from the 'sub' (subject) claim + in JWT tokens. + """ + + user_id: str + """User identifier from token (typically 'sub' claim)""" + + client_id: str | None = None + """OAuth client identifier from 'client_id' or 'azp' claim""" + + email: str | None = None + """User email if available in token claims""" + + claims: dict[str, Any] = field(default_factory=dict) + """All claims from the validated token for advanced use cases""" + + +@dataclass +class AuthorizationServerEntry: + """Configuration entry for a single authorization server. + + Each authorization server that can issue valid tokens for this + MCP server (Resource Server) needs its own entry specifying how to + verify tokens from that server. + """ + + authorization_server_url: str + """Authorization server URL for client discovery (RFC 9728)""" + + issuer: str + """Expected issuer claim in JWT tokens from this server""" + + jwks_uri: str + """JWKS endpoint to fetch public keys for token verification""" + + algorithm: str = "RS256" + """JWT signature algorithm (RS256, ES256, PS256, etc.)""" + + expected_audiences: list[str] | None = None + """Optional list of expected audience claims. If not provided, + defaults to the MCP server's canonical_url. Use this when your + authorization server returns a different aud claim (e.g., client_id).""" + + validation_options: AccessTokenValidationOptions = field( + default_factory=AccessTokenValidationOptions + ) + """Token validation options for this authorization server""" + + +class AuthenticationError(Exception): + """Base authentication error.""" + + pass + + +class TokenExpiredError(AuthenticationError): + """Token has expired.""" + + pass + + +class InvalidTokenError(AuthenticationError): + """Token is invalid (signature, audience, issuer, etc.).""" + + pass + + +class ResourceServerValidator(ABC): + """Base class for MCP Resource Server token validation. + + An MCP server acts as an OAuth 2.1 Resource Server, validating Bearer tokens + on every HTTP request. Implementations must validate tokens according to + OAuth 2.1 Resource Server requirements, including: + - Token signature verification + - Expiration checking + - Issuer validation + - Audience validation + + Tokens are validated on every request - no caching is permitted per MCP spec. + """ + + @abstractmethod + async def validate_token(self, token: str) -> ResourceOwner: + """Validate bearer token and return authenticated resource owner info. + + Must validate: + - Token signature + - Expiration + - Issuer (matches expected authorization server) + - Audience (matches this MCP server's canonical URL) + + Args: + token: Bearer token from Authorization header + + Returns: + ResourceOwner with user_id and claims + + Raises: + TokenExpiredError: Token has expired + InvalidTokenError: Token is invalid (signature, audience, issuer mismatch) + AuthenticationError: Other validation errors + """ + pass + + def supports_oauth_discovery(self) -> bool: + """Whether this validator supports OAuth discovery endpoints. + + Returns True if the validator can serve OAuth 2.0 Protected Resource Metadata + (RFC 9728) to enable MCP clients to discover authorization servers. + """ + return False + + def get_resource_metadata(self) -> dict[str, Any] | None: + """Return OAuth Protected Resource Metadata (RFC 9728) if supported. + + Returns: + Metadata dictionary with 'resource' and 'authorization_servers' fields, + or None if discovery is not supported. + """ + return None diff --git a/libs/arcade-mcp-server/arcade_mcp_server/resource_server/middleware.py b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/middleware.py new file mode 100644 index 00000000..c5add389 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/middleware.py @@ -0,0 +1,201 @@ +"""ASGI middleware for MCP Resource Server authentication.""" + +from urllib.parse import urlparse, urlunparse + +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import ASGIApp, Receive, Scope, Send + +from arcade_mcp_server.resource_server.base import ( + AuthenticationError, + InvalidTokenError, + ResourceOwner, + ResourceServerValidator, + TokenExpiredError, +) + + +class ResourceServerMiddleware: + """ASGI middleware that validates Bearer tokens on every HTTP request. + + Validates tokens per MCP specification: + - Checks Authorization header for Bearer token + - Validates token on every request + - Returns 401 with WWW-Authenticate header if authentication fails + - Stores authenticated resource owner in scope for downstream use to lift + tool-auth and tool-secrets restrictions + + The WWW-Authenticate header includes: + - resource_metadata URL for OAuth discovery (if validator supports it) + - error and error_description for token validation failures (RFC 6750) + """ + + def __init__( + self, + app: ASGIApp, + validator: ResourceServerValidator, + canonical_url: str | None, + ): + """Initialize the Resource Server middleware. + + Args: + app: ASGI application to wrap + validator: Token validator for access token validation + canonical_url: Canonical URL of this MCP server (for OAuth metadata). + Required only for validators that support OAuth discovery. + """ + self.app = app + self.validator = validator + self.canonical_url = canonical_url + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Process ASGI request with authentication. + + For HTTP requests: + 1. Allow CORS preflight OPTIONS requests to pass through + 2. Extract Bearer token from Authorization header + 3. Validate token (on EVERY request - no caching) + 4. Store authenticated resource owner in scope + 5. Pass to wrapped app + + For non-HTTP requests, pass through without auth. + """ + # Only process HTTP requests + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = Request(scope, receive) + + # Allow CORS preflight requests to pass through without authentication. + # Browsers send OPTIONS requests without Authorization headers to check + # if the cross-origin request is allowed before sending the actual request. + if request.method == "OPTIONS": + response = self._create_cors_preflight_response() + await response(scope, receive, send) + return + + try: + resource_owner = await self._authenticate_request(request) + + # Store in scope for downstream usage & continue to app execution + scope["resource_owner"] = resource_owner + await self.app(scope, receive, send) + + except (TokenExpiredError, InvalidTokenError) as e: + response = self._create_401_response( + error="invalid_token", + error_description=str(e), + ) + await response(scope, receive, send) + + except AuthenticationError: + response = self._create_401_response() + await response(scope, receive, send) + + async def _authenticate_request(self, request: Request) -> ResourceOwner: + """Extract and validate Bearer token from Authorization header. + + Args: + request: Starlette request object + + Returns: + ResourceOwner from validated token + + Raises: + AuthenticationError: No token or invalid format + TokenExpiredError: Token has expired + InvalidTokenError: Token signature/audience/issuer invalid + """ + auth_header = request.headers.get("Authorization") + + if not auth_header: + raise AuthenticationError("No Authorization header") + + if not auth_header.startswith("Bearer "): + raise AuthenticationError("Invalid Authorization header format.") + + # Remove "Bearer " prefix + token = auth_header[7:] + + return await self.validator.validate_token(token) + + def _build_metadata_url(self) -> str: + """Build the OAuth Protected Resource Metadata URL per RFC 9728. + + For example, for a canonical_url of "https://example.com/mcp" the metadata URL is: + "https://example.com/.well-known/oauth-protected-resource/mcp" + + Returns: + Metadata URL + """ + if not self.canonical_url: + return "" + + parsed = urlparse(self.canonical_url) + # Insert well-known path after host, with resource path as suffix + well_known_path = f"/.well-known/oauth-protected-resource{parsed.path}" + return urlunparse((parsed.scheme, parsed.netloc, well_known_path, "", "", "")) + + def _create_cors_preflight_response(self) -> Response: + """Create a CORS preflight response for OPTIONS requests. + + Returns: + Response with 204 status and CORS headers + """ + return Response( + content=None, + status_code=204, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, DELETE, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization, Mcp-Session-Id, Accept", + "Access-Control-Expose-Headers": "WWW-Authenticate, Mcp-Session-Id", + "Access-Control-Max-Age": "86400", # 24 hr + }, + ) + + def _create_401_response( + self, + error: str | None = None, + error_description: str | None = None, + ) -> Response: + """Create RFC 6750 + RFC 9728 compliant 401 response. + + The WWW-Authenticate header format follows: + - RFC 6750 (OAuth 2.0 Bearer Token Usage) + - RFC 9728 (OAuth 2.0 Protected Resource Metadata) + + Args: + error: Error code (e.g., "invalid_token") + error_description: Human-readable error description + + Returns: + Response with 401 status with WWW-Authenticate header + """ + www_auth_parts = [] + + # Add resource metadata URL if validator supports discovery (RFC 9728) + if self.validator.supports_oauth_discovery() and self.canonical_url: + metadata_url = self._build_metadata_url() + www_auth_parts.append(f'resource_metadata="{metadata_url}"') + + # Add error details if token validation failed (RFC 6750) + if error: + www_auth_parts.append(f'error="{error}"') + if error_description: + www_auth_parts.append(f'error_description="{error_description}"') + + www_auth_value = "Bearer " + ", ".join(www_auth_parts) + + return Response( + content="Unauthorized", + status_code=401, + headers={ + "WWW-Authenticate": www_auth_value, + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, DELETE, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization, Mcp-Session-Id, Accept", + "Access-Control-Expose-Headers": "WWW-Authenticate, Mcp-Session-Id", + }, + ) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/resource_server/validators/__init__.py b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/validators/__init__.py new file mode 100644 index 00000000..880196bc --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/validators/__init__.py @@ -0,0 +1,13 @@ +""" +Token validator implementations for MCP Resource Servers. + +Provides concrete implementations of ResourceServerValidator for different auth scenarios. +""" + +from arcade_mcp_server.resource_server.validators.auth import ResourceServerAuth +from arcade_mcp_server.resource_server.validators.jwks import JWKSTokenValidator + +__all__ = [ + "JWKSTokenValidator", + "ResourceServerAuth", +] diff --git a/libs/arcade-mcp-server/arcade_mcp_server/resource_server/validators/auth.py b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/validators/auth.py new file mode 100644 index 00000000..9d2c4adc --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/validators/auth.py @@ -0,0 +1,209 @@ +"""ResourceServerAuth implementation with OAuth discovery metadata support. + +This module provides the base ResourceServerAuth class that validates JWT tokens +from one or more authorization servers and provides OAuth 2.0 Protected Resource +Metadata (RFC 9728) for discovery. + +Vendor specific implementations (WorkOS, Auth0, Descope, etc.) should inherit +from ResourceServerAuth. +""" + +from typing import Any + +from arcade_mcp_server.resource_server.base import ( + AuthenticationError, + AuthorizationServerEntry, + InvalidTokenError, + ResourceOwner, + ResourceServerValidator, + TokenExpiredError, +) +from arcade_mcp_server.resource_server.validators.jwks import JWKSTokenValidator +from arcade_mcp_server.settings import MCPSettings + + +class ResourceServerAuth(ResourceServerValidator): + """OAuth 2.1 Resource Server with discovery metadata support. + + This class implements the MCP server's role as an OAuth 2.1 Resource Server, + validating JWT tokens from one or more authorization servers and providing + OAuth 2.0 Protected Resource Metadata (RFC 9728) for discovery. + + """ + + def __init__( + self, + authorization_servers: list[AuthorizationServerEntry] | None = None, + canonical_url: str | None = None, + cache_ttl: int = 3600, + ): + """Initialize Resource Server. + + Supports environment variable configuration via MCP_RESOURCE_SERVER_* variables. + Explicit parameters take precedence over environment variables. + + Args: + authorization_servers: List of authorization server entries + canonical_url: MCP server canonical URL (or MCP_RESOURCE_SERVER_CANONICAL_URL) + cache_ttl: JWKS cache TTL in seconds + + Raises: + ValueError: If required fields not provided via params or env vars + + Example: + ```python + # Option 1: Use environment variables + # Set MCP_RESOURCE_SERVER_CANONICAL_URL and MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS env vars + resource_server_auth = ResourceServerAuth() + + # Option 2: Single Authorization Server (aud claim matches canonical_url) + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/jwks", + ) + ], + ) + + # Option 3: Custom audience (when auth server returns different aud claim) + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://workos.authkit.app", + issuer="https://workos.authkit.app", + jwks_uri="https://workos.authkit.app/oauth2/jwks", + expected_audiences=["my-authkit-client-id"], # Override expected aud + ), + AuthorizationServerEntry( # Keycloak example configuration + authorization_server_url="http://localhost:8080/realms/mcp-test", + issuer="http://localhost:8080/realms/mcp-test", + jwks_uri="http://localhost:8080/realms/mcp-test/protocol/openid-connect/certs", + algorithm="RS256", + expected_audiences=["my-keycloak-client-id"], + ), + ], + ) + ``` + """ + settings = MCPSettings.from_env() + + self.cache_ttl = cache_ttl + + # Explicit parameters take precedence over environment variables + if canonical_url is not None: + self.canonical_url = canonical_url + elif settings.resource_server.canonical_url is not None: + self.canonical_url = settings.resource_server.canonical_url + else: + raise ValueError( + "'canonical_url' required (parameter or MCP_RESOURCE_SERVER_CANONICAL_URL environment variable)" + ) + + if authorization_servers is not None: + configs = authorization_servers + elif settings.resource_server.authorization_servers: + configs = settings.resource_server.to_authorization_server_entries() + else: + raise ValueError( + "'authorization_servers' required (parameter or MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS environment variable)" + ) + + self._validators = self._create_validators(configs) + + self._resource_metadata = self._build_resource_metadata() + + def _build_resource_metadata(self) -> dict[str, Any]: + """Build RFC 9728 Protected Resource Metadata + + Returns: + Dictionary containing resource metadata per RFC 9728 + """ + return { + "resource": self.canonical_url, + "authorization_servers": list(self._validators.keys()), + "bearer_methods_supported": ["header"], + } + + def _create_validators( + self, entries: list[AuthorizationServerEntry] + ) -> dict[str, JWKSTokenValidator]: + """Create a mapping of authorization server URLs to their JWKSTokenValidator instances. + + Args: + entries: List of authorization server entries + + Returns: + Dictionary that maps authorization_server_url to its JWKSTokenValidator instance + """ + validators = {} + + for entry in entries: + # Use expected_audiences if provided, otherwise default to canonical_url + audience = ( + entry.expected_audiences if entry.expected_audiences else [self.canonical_url] + ) + validators[entry.authorization_server_url] = JWKSTokenValidator( + jwks_uri=entry.jwks_uri, + issuer=entry.issuer, + audience=audience, + algorithm=entry.algorithm, + cache_ttl=self.cache_ttl, + validation_options=entry.validation_options, + ) + + return validators + + async def validate_token(self, token: str) -> ResourceOwner: + """Validate the given token against each configured authorization server. + + Tries each validator until one succeeds. If all fail, raises InvalidTokenError. + + Error handling strategy: + - TokenExpiredError: Raise immediately. If any validator raises this, the token + is expired for all authorization servers (expiration is universal). No point + trying other validators. + - InvalidTokenError/AuthenticationError: Continue to next validator because another + authorization server might accept the token. These errors indicate wrong issuer, + audience, or signature mismatch. + + Args: + token: JWT Bearer token + + Returns: + ResourceOwner with user_id, client_id, and claims + + Raises: + TokenExpiredError: Token has expired + InvalidTokenError: Token signature, algorithm, audience, or issuer is invalid + AuthenticationError: Other validation errors + """ + for validator in self._validators.values(): + try: + return await validator.validate_token(token) + except TokenExpiredError: + raise + except (InvalidTokenError, AuthenticationError): + continue + + raise InvalidTokenError("Token validation failed for all configured authorization servers") + + def supports_oauth_discovery(self) -> bool: + """This Resource Server supports OAuth discovery.""" + return True + + def get_resource_metadata(self) -> dict[str, Any]: + """Return RFC 9728 Protected Resource Metadata. + + This metadata tells MCP clients: + 1. What resource this server protects (canonical URL) + 2. Which authorization server(s) can issue tokens for this resource + 3. Supported bearer token methods + + Returns: + Dictionary containing resource metadata per RFC 9728 + """ + return self._resource_metadata diff --git a/libs/arcade-mcp-server/arcade_mcp_server/resource_server/validators/jwks.py b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/validators/jwks.py new file mode 100644 index 00000000..de018dd4 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/resource_server/validators/jwks.py @@ -0,0 +1,329 @@ +""" +JWKS-based token validator for MCP Resource Servers. + +Implements OAuth 2.1 Resource Server token validation using JWT with JWKS. +""" + +import time +from typing import Any, cast + +import httpx +from jose import jwk, jwt + +from arcade_mcp_server.resource_server.base import ( + AccessTokenValidationOptions, + AuthenticationError, + InvalidTokenError, + ResourceOwner, + ResourceServerValidator, + TokenExpiredError, +) + +# Note: Only asymmetric algorithms supported +SUPPORTED_ALGORITHMS = { + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", +} + + +class JWKSTokenValidator(ResourceServerValidator): + """JWKS-based JWT token validator for simple, explicit token validation. + + This validator fetches public keys from a JWKS endpoint and validates + JWT access tokens against them. Use this when you need direct control + over token validation without OAuth discovery support. + """ + + def __init__( + self, + jwks_uri: str, + issuer: str | list[str], + audience: str | list[str], + algorithm: str = "RS256", + cache_ttl: int = 3600, + validation_options: AccessTokenValidationOptions | None = None, + ): + """Initialize JWKS token validator. + + Args: + jwks_uri: URL to fetch JWKS + issuer: Token issuer or list of allowed issuers + audience: Token audience or list of allowed audiences (typically your MCP server's canonical URL) + algorithm: Signature algorithm. Default RS256. + cache_ttl: JWKS cache TTL in seconds + validation_options: Access token validation options + + Raises: + ValueError: If required fields not provided or algorithm unsupported + + Example: + ```python + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/jwks", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + ) + + # Multiple issuers + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/jwks", + issuer=["https://auth1.example.com", "https://auth2.example.com"], + audience="https://mcp.example.com/mcp", + ) + + # Multiple audiences (e.g., URL migration) + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/jwks", + issuer="https://auth.example.com", + audience=["https://old-mcp.example.com/mcp", "https://new-mcp.example.com/mcp"], + ) + + # Different algorithm + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/jwks", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + algorithm="ES256", + ) + ``` + """ + if algorithm not in SUPPORTED_ALGORITHMS: + raise ValueError( + f"Unsupported algorithm '{algorithm}'. " + f"Supported asymmetric algorithms: {', '.join(sorted(SUPPORTED_ALGORITHMS))}" + ) + + if validation_options is None: + validation_options = AccessTokenValidationOptions() + + self.jwks_uri = jwks_uri + self.issuer = issuer + self.audience = audience + self.algorithm = algorithm + self.validation_options = validation_options + + self._cache_ttl = cache_ttl + self._http_client = httpx.AsyncClient(timeout=10.0) + self._jwks_cache: dict[str, Any] | None = None + self._cache_timestamp: float = 0 + + async def _fetch_jwks(self) -> dict[str, Any]: + """Fetch JWKS with caching. + + Returns: + JWKS dictionary containing public keys + + Raises: + AuthenticationError: If JWKS cannot be fetched + """ + current_time = time.time() + + # Use cached JWKS if it's still valid + if self._jwks_cache and (current_time - self._cache_timestamp) < self._cache_ttl: + return self._jwks_cache + + try: + response = await self._http_client.get(self.jwks_uri) + response.raise_for_status() + self._jwks_cache = response.json() + self._cache_timestamp = current_time + except httpx.HTTPError as e: + raise AuthenticationError(f"Failed to fetch JWKS: {e}") from e + else: + return self._jwks_cache + + def _find_signing_key(self, jwks: dict[str, Any], token: str) -> Any: + """Find the signing key from JWKS that matches the token's kid. + + Args: + jwks: JSON Web Key Set + token: JWT token + + Returns: + Signing key in PEM format + + Raises: + InvalidTokenError: If no matching key found or algorithm mismatch + """ + unverified_header = jwt.get_unverified_header(token) + kid = unverified_header.get("kid") + token_alg = unverified_header.get("alg") + + # Validate token algorithm matches configuration (prevent algorithm confusion) + if token_alg and token_alg != self.algorithm: + raise InvalidTokenError( + f"Token algorithm '{token_alg}' doesn't match " + f"configured algorithm '{self.algorithm}'" + ) + + for key_data in jwks.get("keys", []): + if key_data.get("kid") == kid: + key_alg = key_data.get("alg") + + if key_alg and key_alg != self.algorithm: + raise InvalidTokenError( + f"Key algorithm '{key_alg}' doesn't match " + f"configured algorithm '{self.algorithm}'" + ) + + key_obj = jwk.construct(key_data, algorithm=self.algorithm) + return key_obj.to_pem().decode("utf-8") + + raise InvalidTokenError("No matching key found in JWKS") + + def _decode_token(self, token: str, signing_key: str) -> dict[str, Any]: + """Decode and verify the provided JWT token. + + Args: + token: JWT token + signing_key: Public key in PEM format + + Returns: + Decoded token claims + + Raises: + jwt.ExpiredSignatureError: Token has expired + jwt.JWTClaimsError: Token claims validation failed (audience/issuer mismatch) + jwt.JWTError: Token is invalid + """ + decode_options = { + "verify_signature": True, # Always verify signature. Cannot be disabled. + "verify_exp": self.validation_options.verify_exp, + "verify_iat": self.validation_options.verify_iat, + "verify_nbf": self.validation_options.verify_nbf, + "verify_aud": False, # Manual validation for multi-audience support + "verify_iss": False, # Manual validation for multi-issuer support + "leeway": self.validation_options.leeway, + } + + # Decode token once without aud/iss validation + decoded = cast( + dict[str, Any], + jwt.decode( + token, + signing_key, + algorithms=[self.algorithm], + options=decode_options, + ), + ) + + # Manually validate issuer (if flag is enabled) + if self.validation_options.verify_iss: + token_iss = decoded.get("iss") + if isinstance(self.issuer, list): + if token_iss not in self.issuer: + raise InvalidTokenError( + f"Token issuer '{token_iss}' not in allowed issuers: {self.issuer}" + ) + else: + if token_iss != self.issuer: + raise InvalidTokenError( + f"Token issuer '{token_iss}' doesn't match expected '{self.issuer}'" + ) + + # Always validate audience + token_aud = decoded.get("aud") + token_audiences = [token_aud] if isinstance(token_aud, str) else (token_aud or []) + expected_audiences = [self.audience] if isinstance(self.audience, str) else self.audience + + # Token is valid if any of its aud values match any of our expected values + if not (set(token_audiences) & set(expected_audiences)): + raise InvalidTokenError( + f"Token audience {token_aud} doesn't match expected {self.audience}" + ) + + return decoded + + def _extract_user_id(self, decoded: dict[str, Any]) -> str: + """Extract and validate user_id from decoded token. + + Args: + decoded: Decoded token claims + + Returns: + User ID from 'sub' claim + + Raises: + InvalidTokenError: If 'sub' claim is missing + """ + user_id = decoded.get("sub") + if not user_id: + raise InvalidTokenError("Token missing 'sub' claim") + return cast(str, user_id) + + def _extract_client_id(self, decoded: dict[str, Any]) -> str | None: + """Extract client ID from decoded token. + + Args: + decoded: Decoded token claims + + Returns: + Client identifier or "unknown" if no client claim found + """ + client_id = decoded.get("client_id") or decoded.get("azp") or "unknown" + + return client_id + + async def validate_token(self, token: str) -> ResourceOwner: + """Validate JWT and return authenticated resource owner. + + Always validates (cannot be disabled): + - Token signature using JWKS public key + - Subject (sub claim) exists + - Audience (aud claim) matches configured audience(s) + + Optionally validates (controlled by validation_options, all default to True): + - Expiration (exp claim) - verify_exp + - Issued-at time (iat claim) - verify_iat + - Not-before time (nbf claim) - verify_nbf + - Issuer (iss claim) matches configured issuer(s) - verify_iss + + Clock skew tolerance can be configured via validation_options.leeway (in seconds). + + Args: + token: JWT Bearer token + + Returns: + ResourceOwner with user_id, client_id, and claims + + Raises: + TokenExpiredError: Token has expired + InvalidTokenError: Token signature, algorithm, audience, or issuer is invalid + AuthenticationError: Other validation errors + """ + try: + jwks = await self._fetch_jwks() + signing_key = self._find_signing_key(jwks, token) + decoded = self._decode_token(token, signing_key) + user_id = self._extract_user_id(decoded) + client_id = self._extract_client_id(decoded) + email = decoded.get("email") + + return ResourceOwner( + user_id=user_id, + client_id=client_id, + email=email, + claims=decoded, + ) + + except jwt.ExpiredSignatureError as e: + raise TokenExpiredError("Token has expired") from e + except jwt.JWTClaimsError as e: + raise InvalidTokenError(f"Token claims validation failed: {e}") from e + except jwt.JWTError as e: + raise InvalidTokenError(f"Invalid token: {e}") from e + except (InvalidTokenError, TokenExpiredError): + raise + except Exception as e: + raise AuthenticationError(f"Token validation failed: {e}") from e + + async def close(self) -> None: + """Close the HTTP client.""" + await self._http_client.aclose() diff --git a/libs/arcade-mcp-server/arcade_mcp_server/server.py b/libs/arcade-mcp-server/arcade_mcp_server/server.py index 6bc92021..e7342133 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/server.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/server.py @@ -38,6 +38,7 @@ from arcade_mcp_server.middleware import ( Middleware, MiddlewareContext, ) +from arcade_mcp_server.resource_server.base import ResourceOwner from arcade_mcp_server.session import InitializationState, NotificationManager, ServerSession from arcade_mcp_server.settings import MCPSettings, ServerSettings from arcade_mcp_server.types import ( @@ -402,6 +403,7 @@ class MCPServer: self, message: Any, session: ServerSession | None = None, + resource_owner: ResourceOwner | None = None, ) -> MCPMessage | None: """ Handle an incoming message. @@ -409,6 +411,7 @@ class MCPServer: Args: message: Message to handle session: Server session + resource_owner: Authenticated resource owner from front-door auth Returns: Response message or None @@ -502,9 +505,13 @@ class MCPServer: # Create request context context = ( - await session.create_request_context() + await session.create_request_context(resource_owner=resource_owner) if session - else Context(self, request_id=str(msg_id) if msg_id else None) + else Context( + self, + request_id=str(msg_id) if msg_id else None, + resource_owner=resource_owner, + ) ) # Set as current model context @@ -678,7 +685,7 @@ class MCPServer: }, ) - async def _create_tool_context( + def _create_tool_context( self, tool: MaterializedTool, session: ServerSession | None = None ) -> ToolContext: """Create a tool context from a tool definition and session""" @@ -692,27 +699,53 @@ class MCPServer: elif secret.key in os.environ: tool_context.set_secret(secret.key, os.environ[secret.key]) - # user_id selection - env = (self.settings.arcade.environment or "").lower() - user_id = self.settings.arcade.user_id - - # If no user_id from env, try credentials file - if not user_id: - _, config_user_id = self._load_config_values() - user_id = config_user_id - - if user_id: - tool_context.user_id = user_id - logger.debug(f"Context user_id set: {user_id}") - elif env in ("development", "dev", "local"): - tool_context.user_id = session.session_id if session else None - logger.debug(f"Context user_id set from session (dev env={env})") - else: - tool_context.user_id = session.session_id if session else None - logger.debug("Context user_id set from session (non-dev env)") + tool_context.user_id = self._select_user_id(session) return tool_context + def _select_user_id(self, session: ServerSession | None = None) -> str | None: + """Select the user_id for the tool's context. + + User ID selection priority: + - Authenticated user from front-door auth + - Configured user_id from settings + - Configured user_id from credentials file + - Use session ID if no other user_id is available + + Args: + session: Server session + + Returns: + User ID for the context + """ + env = (self.settings.arcade.environment or "").lower() + + # First priority: resource owner from front-door auth (from current model context) + mctx = get_current_model_context() + if mctx is not None and hasattr(mctx, "_resource_owner") and mctx._resource_owner: + user_id = mctx._resource_owner.user_id + logger.debug(f"Context user_id set from Authorization Server 'sub' claim: {user_id}") + return cast(str, user_id) + + # Second priority: configured user_id from settings + if (settings_user_id := self.settings.arcade.user_id) is not None: + logger.debug(f"Context user_id set from settings: {settings_user_id}") + return settings_user_id + + # Third priority: configured user_id from credentials file + _, config_user_id = self._load_config_values() + if config_user_id: + logger.debug(f"Context user_id set from credentials file: {config_user_id}") + return config_user_id + + # Fourth priority: use session ID if no other user_id is available + if env in ("development", "dev", "local"): + logger.debug(f"Context user_id set from session (dev env={env})") + else: + logger.debug("Context user_id set from session (non-dev env)") + + return session.session_id if session else None + async def _check_and_warn_missing_secrets(self) -> None: """ Check for missing tool secrets and log warnings. @@ -761,7 +794,7 @@ class MCPServer: tool = await self._tool_manager.get_tool(tool_name) # Create tool context - tool_context = await self._create_tool_context(tool, session) + tool_context = self._create_tool_context(tool, session) # Check restrictions for unauthenticated HTTP transport if transport_restriction_response := self._check_transport_restrictions( @@ -906,13 +939,31 @@ class MCPServer: tool_name: str, session: ServerSession | None = None, ) -> JSONRPCResponse[CallToolResult] | None: - """Check transport restrictions for tools requiring auth or secrets""" + """Check transport restrictions for tools requiring auth or secrets. + + Tools requiring authorization or secrets are blocked on unauthenticated HTTP + transport for security reasons. However, if the HTTP transport has front-door + authentication enabled (resource_owner is present), these tools are allowed + since we can safely identify the end-user and handle their authorization. + """ # Check transport restrictions for tools requiring auth or secrets if session and session.init_options: transport_type = session.init_options.get("transport_type") if transport_type != "stdio": + # Get resource_owner from current model context (set during handle_message) + mctx = get_current_model_context() + is_authenticated = ( + mctx is not None + and hasattr(mctx, "_resource_owner") + and mctx._resource_owner is not None + ) + requirements = tool.definition.requirements - if requirements and (requirements.authorization or requirements.secrets): + if ( + requirements + and (requirements.authorization or requirements.secrets) + and not is_authenticated + ): documentation_url = "https://docs.arcade.dev/en/home/compare-server-types" user_message = "✗ Unsupported transport\n\n" user_message += f" Tool '{tool_name}' cannot run over HTTP transport for security reasons.\n" diff --git a/libs/arcade-mcp-server/arcade_mcp_server/session.py b/libs/arcade-mcp-server/arcade_mcp_server/session.py index 4a7c6b5b..8ba3ae4d 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/session.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/session.py @@ -18,6 +18,7 @@ import anyio from arcade_mcp_server.context import Context from arcade_mcp_server.exceptions import RequestError, SessionError +from arcade_mcp_server.resource_server.base import ResourceOwner from arcade_mcp_server.types import ( CancelledNotification, CancelledParams, @@ -37,6 +38,7 @@ from arcade_mcp_server.types import ( ProgressNotificationParams, PromptListChangedNotification, ResourceListChangedNotification, + SessionMessage, ToolListChangedNotification, ) @@ -361,11 +363,24 @@ class ServerSession: # Cancel any pending requests await self._cleanup_pending_requests() - async def _process_message(self, message: str) -> None: - """Process a single message.""" + async def _process_message(self, message: str | Any) -> None: + """Process a single message. + + Args: + message: Either a JSON string (stdio) or SessionMessage object (http) + """ try: - # Parse message - data = json.loads(message) + if isinstance(message, str): + data = json.loads(message) + resource_owner = None + elif isinstance(message, SessionMessage): + # We must keep exclude_none=True to avoid Pydantic union type coersion + # when reconstructing models from dict (e.g., RequestId = str | int) + data = message.message.model_dump(exclude_none=True) + resource_owner = message.resource_owner + else: + logger.error(f"Unexpected message type: {type(message)}") + return # Check if it's a response to our request if "id" in data and "method" not in data: @@ -377,7 +392,7 @@ class ServerSession: return # Otherwise, process as incoming request - response = await self.server.handle_message(data, self) + response = await self.server.handle_message(data, self, resource_owner=resource_owner) # Send response if any if response and self.write_stream: @@ -646,9 +661,16 @@ class ServerSession: self._request_meta = None # Context management - async def create_request_context(self) -> Context: - """Create a context for the current request.""" - context = Context(self.server) + async def create_request_context(self, resource_owner: ResourceOwner | None = None) -> Context: + """Create a context for the current request. + + Args: + resource_owner: The authenticated resource owner from front-door auth. + """ + context = Context( + server=self.server, + resource_owner=resource_owner, + ) context.set_session(self) self._current_context = context return context diff --git a/libs/arcade-mcp-server/arcade_mcp_server/settings.py b/libs/arcade-mcp-server/arcade_mcp_server/settings.py index 942dcefb..e8e4b8c1 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/settings.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/settings.py @@ -5,6 +5,7 @@ Provides Pydantic-based settings with validation and environment variable suppor """ import os +from pathlib import Path from typing import Any from pydantic import Field, field_validator @@ -93,6 +94,71 @@ class ServerSettings(BaseSettings): model_config = {"env_prefix": "MCP_SERVER_"} +class ResourceServerSettings(BaseSettings): + """Settings for ResourceServer configuration via environment variables.""" + + canonical_url: str | None = Field( + default=None, + description="Canonical URL of this MCP server (e.g., https://mcp.example.com/mcp)", + ) + authorization_servers: list[dict[str, Any]] | None = Field( + default=None, + description="JSON array of authorization server entries." + 'Example: \'[{"authorization_server_url":"https://auth.example.com","issuer":"https://auth.example.com","jwks_uri":"https://auth.example.com/oauth2/jwks","algorithm":"RS256"}]\'', + ) + + @field_validator("authorization_servers", mode="before") + @classmethod + def parse_authorization_servers(cls, v: Any) -> list[dict[str, Any]] | None: + """Parse JSON array from environment variable.""" + if v is None: + return None + if isinstance(v, str): + import json + + try: + parsed = json.loads(v) + if not isinstance(parsed, list): + raise TypeError("authorization_servers must be a JSON array") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in authorization_servers: {e}") from e + else: + return parsed + if isinstance(v, list): + return v + return None + + def to_authorization_server_entries(self) -> list[Any]: + """Convert settings to list of AuthorizationServerEntry objects.""" + if not self.authorization_servers: + return [] + + from arcade_mcp_server.resource_server import ( + AccessTokenValidationOptions, + AuthorizationServerEntry, + ) + + return [ + AuthorizationServerEntry( + authorization_server_url=config["authorization_server_url"], + issuer=config["issuer"], + jwks_uri=config["jwks_uri"], + algorithm=config.get("algorithm", "RS256"), + expected_audiences=config.get("expected_audiences"), + validation_options=AccessTokenValidationOptions( + verify_exp=config.get("validation_options", {}).get("verify_exp", True), + verify_iat=config.get("validation_options", {}).get("verify_iat", True), + verify_iss=config.get("validation_options", {}).get("verify_iss", True), + verify_nbf=config.get("validation_options", {}).get("verify_nbf", True), + leeway=config.get("validation_options", {}).get("leeway", 0), + ), + ) + for config in self.authorization_servers + ] + + model_config = {"env_prefix": "MCP_RESOURCE_SERVER_"} + + class MiddlewareSettings(BaseSettings): """Middleware-related settings.""" @@ -207,6 +273,10 @@ class MCPSettings(BaseSettings): default_factory=ServerSettings, description="Server settings", ) + resource_server: ResourceServerSettings = Field( + default_factory=ResourceServerSettings, + description="Server authentication settings", + ) middleware: MiddlewareSettings = Field( default_factory=MiddlewareSettings, description="Middleware settings", @@ -236,7 +306,20 @@ class MCPSettings(BaseSettings): @classmethod def from_env(cls) -> "MCPSettings": - """Create settings from environment variables.""" + """Create settings from environment variables. + + Automatically loads .env file from current directory if it exists, + then creates settings from the combined environment. + + The .env file is loaded with override=False, meaning existing + environment variables take precedence. Multiple calls are safe + """ + from dotenv import load_dotenv + + env_path = Path.cwd() / ".env" + if env_path.exists(): + load_dotenv(env_path, override=False) + return cls() def tool_secrets(self) -> dict[str, Any]: diff --git a/libs/arcade-mcp-server/arcade_mcp_server/transports/http_streamable.py b/libs/arcade-mcp-server/arcade_mcp_server/transports/http_streamable.py index fd4ad486..21a2d596 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/transports/http_streamable.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/transports/http_streamable.py @@ -168,8 +168,8 @@ class HTTPStreamableTransport: self._terminated = False # Streams for connection - self._read_stream_writer: MemoryObjectSendStream[str | Exception] | None = None - self._read_stream: MemoryObjectReceiveStream[str | Exception] | None = None + self._read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None + self._read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None self._write_stream: MemoryObjectSendStream[str | SessionMessage] | None = None self._write_stream_reader: MemoryObjectReceiveStream[str | SessionMessage] | None = None @@ -218,7 +218,13 @@ class HTTPStreamableTransport: headers: dict[str, str] | None = None, ) -> Response: """Create an error response.""" - response_headers = {"Content-Type": CONTENT_TYPE_JSON} + response_headers = { + "Content-Type": CONTENT_TYPE_JSON, + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, DELETE", + "Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, Mcp-Session-Id", + "Access-Control-Expose-Headers": "Mcp-Session-Id", + } if headers: response_headers.update(headers) @@ -406,13 +412,20 @@ class HTTPStreamableTransport: elif not await self._validate_request_headers(request, send): return + # Extract resource owner from scope (set by ASGI Resource Server middleware) + resource_owner = request.scope.get("resource_owner") + # For notifications and responses, return 202 Accepted if not isinstance(message, JSONRPCRequest): response = self._create_json_response(None, HTTPStatus.ACCEPTED) await response(scope, receive, send) # Process the message - await writer.send(body_str if body_str.endswith("\n") else body_str + "\n") + session_message = SessionMessage( + message=message, + resource_owner=resource_owner, + ) + await writer.send(session_message) return # Handle requests @@ -421,8 +434,11 @@ class HTTPStreamableTransport: request_stream_reader = self._request_streams[request_id][1] if self.is_json_response_enabled: - # JSON response mode - await writer.send(body_str if body_str.endswith("\n") else body_str + "\n") + session_message = SessionMessage( + message=message, + resource_owner=resource_owner, + ) + await writer.send(session_message) try: response_message = None @@ -490,7 +506,12 @@ class HTTPStreamableTransport: try: async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) - await writer.send(body_str if body_str.endswith("\n") else body_str + "\n") + # Send SessionMessage object + session_message = SessionMessage( + message=message, + resource_owner=resource_owner, + ) + await writer.send(session_message) except Exception: logger.exception("SSE response error") await sse_stream_writer.aclose() @@ -742,7 +763,7 @@ class HTTPStreamableTransport: self, ) -> AsyncIterator[ tuple[ - MemoryObjectReceiveStream[str | Exception], + MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[str | SessionMessage], ] ]: @@ -754,7 +775,9 @@ class HTTPStreamableTransport: stream identified by `GET_STREAM_KEY`). """ # Create memory streams with buffer - read_stream_writer, read_stream = anyio.create_memory_object_stream[str | Exception](100) + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](100) write_stream, write_stream_reader = anyio.create_memory_object_stream[str | SessionMessage]( 100 ) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/types.py b/libs/arcade-mcp-server/arcade_mcp_server/types.py index ddb38e82..8bb4760a 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/types.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/types.py @@ -5,6 +5,8 @@ from typing import Any, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field +from arcade_mcp_server.resource_server.base import ResourceOwner + # ----------------------------------------------------------------------------- # JSON-RPC constants # ----------------------------------------------------------------------------- @@ -91,9 +93,14 @@ class JSONRPCError(JSONRPCMessage): @dataclass class SessionMessage: - """Wrapper for messages in transport sessions.""" + """Wrapper for messages in transport sessions. + + Carries both the MCP protocol message and optional authenticated user + information from front-door authentication. + """ message: JSONRPCMessage + resource_owner: ResourceOwner | None = None # ----------------------------------------------------------------------------- @@ -660,7 +667,13 @@ MCPMessage = ( JSONRPCRequest | JSONRPCResponse[Any] | JSONRPCError + | InitializedNotification | CancelledNotification | ProgressNotification | LoggingMessageNotification + | ResourceListChangedNotification + | ResourceUpdatedNotification + | PromptListChangedNotification + | ToolListChangedNotification + | RootsListChangedNotification ) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/usage/constants.py b/libs/arcade-mcp-server/arcade_mcp_server/usage/constants.py index c582b5b9..03047adb 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/usage/constants.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/usage/constants.py @@ -10,3 +10,4 @@ PROP_TOOL_COUNT = "tool_count" PROP_MCP_SERVER_VERSION = "arcade_mcp_server_version" PROP_IS_EXECUTION_SUCCESS = "is_execution_success" PROP_FAILURE_REASON = "failure_reason" +PROP_RESOURCE_SERVER_TYPE = "resource_server_type" diff --git a/libs/arcade-mcp-server/arcade_mcp_server/usage/server_tracker.py b/libs/arcade-mcp-server/arcade_mcp_server/usage/server_tracker.py index 99a896a0..7c72c302 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/usage/server_tracker.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/usage/server_tracker.py @@ -2,6 +2,7 @@ import platform import sys import time from importlib import metadata +from typing import Any from arcade_core.usage import UsageIdentity, UsageService, is_tracking_enabled from arcade_core.usage.constants import ( @@ -20,6 +21,7 @@ from arcade_mcp_server.usage.constants import ( PROP_IS_EXECUTION_SUCCESS, PROP_MCP_SERVER_VERSION, PROP_PORT, + PROP_RESOURCE_SERVER_TYPE, PROP_TOOL_COUNT, PROP_TRANSPORT, ) @@ -62,12 +64,27 @@ class ServerTracker: """Get the distinct_id based on developer's authentication state""" return self.identity.get_distinct_id() + def _get_resource_server_type(self, resource_server_validator: Any) -> str: + """Get the class name of the resource server validator. + + Args: + resource_server_validator: The resource server validator instance or None + + Returns: + The class name of the validator, or "none" if no validator + """ + if resource_server_validator is None: + return "none" + + return str(resource_server_validator.__class__.__name__) + def track_server_start( self, transport: str, host: str | None, port: int | None, tool_count: int, + resource_server_validator: Any = None, ) -> None: """Track MCP server start event. @@ -76,6 +93,7 @@ class ServerTracker: host: The host address (None for stdio) port: The port number (None for stdio) tool_count: The number of tools available at server start + resource_server_validator: The resource server validator instance (None if no auth) """ if not is_tracking_enabled(): return @@ -92,6 +110,7 @@ class ServerTracker: properties: dict[str, str | int | float] = { PROP_TRANSPORT: transport, PROP_TOOL_COUNT: tool_count, + PROP_RESOURCE_SERVER_TYPE: self._get_resource_server_type(resource_server_validator), PROP_MCP_SERVER_VERSION: self.mcp_server_version, PROP_RUNTIME_LANGUAGE: "python", PROP_RUNTIME_VERSION: self.runtime_version, diff --git a/libs/arcade-mcp-server/arcade_mcp_server/worker.py b/libs/arcade-mcp-server/arcade_mcp_server/worker.py index b5f949ad..bd7fa2bc 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/worker.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/worker.py @@ -23,7 +23,10 @@ from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send +from arcade_mcp_server.fastapi.auth_routes import create_auth_router from arcade_mcp_server.fastapi.middleware import AddTrailingSlashToPathMiddleware +from arcade_mcp_server.resource_server.base import ResourceServerValidator +from arcade_mcp_server.resource_server.middleware import ResourceServerMiddleware from arcade_mcp_server.server import MCPServer from arcade_mcp_server.settings import MCPSettings from arcade_mcp_server.transports.http_session_manager import HTTPSessionManager @@ -120,6 +123,7 @@ def create_arcade_mcp( mcp_settings: MCPSettings | None = None, debug: bool = False, otel_enable: bool = False, + resource_server_validator: ResourceServerValidator | None = None, **kwargs: Any, ) -> FastAPI: """ @@ -127,6 +131,14 @@ def create_arcade_mcp( and Arcade Worker endpoints if a secret is provided. MCP is always enabled in this integrated application. + + Args: + catalog: Tool catalog for available tools + mcp_settings: MCP configuration settings + debug: Enable debug mode + otel_enable: Enable OpenTelemetry + resource_server_validator: Resource Server validator for front-door authentication + **kwargs: Additional configuration options """ if mcp_settings is None: mcp_settings = MCPSettings.from_env() @@ -178,6 +190,18 @@ def create_arcade_mcp( app.add_middleware(AddTrailingSlashToPathMiddleware) + # Add OAuth discovery endpoint if auth is enabled + if resource_server_validator and resource_server_validator.supports_oauth_discovery(): + canonical_url = getattr(resource_server_validator, "canonical_url", None) + if not canonical_url: + raise ValueError( + "canonical_url must be set via parameter or " + "MCP_RESOURCE_SERVER_CANONICAL_URL environment variable" + ) + + auth_router = create_auth_router(resource_server_validator, canonical_url) + app.include_router(auth_router) + # Worker endpoints if secret is not None: worker = FastAPIWorker( @@ -201,8 +225,23 @@ def create_arcade_mcp( return await session_manager.handle_request(scope, receive, send) - # Mount the actual ASGI proxy to handle all /mcp requests - app.mount("/mcp", _MCPASGIProxy(app), name="mcp-proxy") + # Create MCP proxy and wrap with auth middleware if enabled + mcp_proxy: Any = _MCPASGIProxy(app) + if resource_server_validator: + # Get canonical_url from validator if it supports OAuth discovery + canonical_url = None + if resource_server_validator.supports_oauth_discovery(): + canonical_url = getattr(resource_server_validator, "canonical_url", None) + if not canonical_url: + raise ValueError( + "canonical_url must be set via parameter or " + "MCP_RESOURCE_SERVER_CANONICAL_URL environment variable" + ) + + mcp_proxy = ResourceServerMiddleware(mcp_proxy, resource_server_validator, canonical_url) + + # Mount the ASGI proxy to handle all /mcp requests + app.mount("/mcp", mcp_proxy, name="mcp-proxy") # Customize OpenAPI to include MCP documentation def custom_openapi() -> dict[str, Any]: diff --git a/libs/arcade-mcp-server/pyproject.toml b/libs/arcade-mcp-server/pyproject.toml index 028af861..531eed34 100644 --- a/libs/arcade-mcp-server/pyproject.toml +++ b/libs/arcade-mcp-server/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "arcade-mcp-server" -version = "1.11.2" +version = "1.12.0" description = "Model Context Protocol (MCP) server framework for Arcade.dev" readme = "README.md" authors = [{ name = "Arcade.dev" }] @@ -34,6 +34,8 @@ dependencies = [ "anyio>=4.0.0", "python-dotenv>=1.0.0", "pydantic-settings>=2.10.1", + "python-jose[cryptography]>=3.3.0,<4.0.0", + "httpx>=0.27.0,<1.0.0", ] [project.optional-dependencies] diff --git a/libs/tests/arcade_mcp_server/test_mcp_app.py b/libs/tests/arcade_mcp_server/test_mcp_app.py index a81865d0..655ed78d 100644 --- a/libs/tests/arcade_mcp_server/test_mcp_app.py +++ b/libs/tests/arcade_mcp_server/test_mcp_app.py @@ -342,6 +342,7 @@ class TestMCPApp: catalog=mcp_app._catalog, mcp_settings=mcp_app._mcp_settings, debug=False, + resource_server_validator=mcp_app.resource_server_validator, ) mock_serve.assert_called_once_with( app=mock_fastapi_app, @@ -365,6 +366,7 @@ class TestMCPApp: catalog=mcp_app._catalog, mcp_settings=mcp_app._mcp_settings, debug=True, + resource_server_validator=mcp_app.resource_server_validator, ) mock_serve.assert_called_once_with( app=mock_fastapi_app, diff --git a/libs/tests/arcade_mcp_server/test_resource_server_auth.py b/libs/tests/arcade_mcp_server/test_resource_server_auth.py new file mode 100644 index 00000000..14445e1b --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_resource_server_auth.py @@ -0,0 +1,1246 @@ +import base64 +import time +from unittest.mock import Mock, patch + +import pytest +from arcade_core.catalog import ToolCatalog +from arcade_mcp_server.resource_server import ( + AccessTokenValidationOptions, + AuthorizationServerEntry, + JWKSTokenValidator, + ResourceServerAuth, +) +from arcade_mcp_server.resource_server.base import ( + InvalidTokenError, + ResourceOwner, + TokenExpiredError, +) +from arcade_mcp_server.resource_server.middleware import ResourceServerMiddleware +from arcade_mcp_server.worker import create_arcade_mcp +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from jose import jwt + + +# Test fixtures +@pytest.fixture(autouse=True) +def clean_auth_env(monkeypatch): + """Clean server auth environment variables before each test.""" + env_vars = [ + "MCP_RESOURCE_SERVER_CANONICAL_URL", + "MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS", + ] + + for var in env_vars: + monkeypatch.delenv(var, raising=False) + + yield + + +@pytest.fixture +def rsa_keypair(): + """Generate RSA key pair for testing.""" + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + public_key = private_key.public_key() + + return private_key, public_key + + +@pytest.fixture +def serialized_private_key(rsa_keypair): + """Generate private key as PEM format for testing.""" + private_key, _ = rsa_keypair + # Serialize private key to PEM format for python-jose + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + return pem + + +@pytest.fixture +def jwks_data(rsa_keypair): + """Generate JWKS data for testing.""" + _, public_key = rsa_keypair + + # Export public key in JWK format + public_numbers = public_key.public_numbers() + n = public_numbers.n + e = public_numbers.e + + # Convert to base64url + n_bytes = n.to_bytes((n.bit_length() + 7) // 8, byteorder="big") + e_bytes = e.to_bytes((e.bit_length() + 7) // 8, byteorder="big") + n_b64 = base64.urlsafe_b64encode(n_bytes).decode("utf-8").rstrip("=") + e_b64 = base64.urlsafe_b64encode(e_bytes).decode("utf-8").rstrip("=") + + return { + "keys": [ + { + "kty": "RSA", + "kid": "test-key-1", + "use": "sig", + "alg": "RS256", + "n": n_b64, + "e": e_b64, + } + ] + } + + +@pytest.fixture +def valid_jwt_token(rsa_keypair): + """Generate valid JWT token for testing.""" + private_key, _ = rsa_keypair + + payload = { + "sub": "user123", + "email": "user@example.com", + "iss": "https://auth.example.com", + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + return token + + +@pytest.fixture +def expired_jwt_token(rsa_keypair): + """Generate expired JWT token for testing.""" + private_key, _ = rsa_keypair + + payload = { + "sub": "user123", + "email": "user@example.com", + "iss": "https://auth.example.com", + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) - 3600, + "iat": int(time.time()) - 7200, + } + + token = jwt.encode( + payload, + private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + return token + + +class TestJWKSTokenValidator: + """Tests for JWKSTokenValidator class.""" + + @pytest.mark.asyncio + async def test_validate_valid_token(self, valid_jwt_token, jwks_data): + """Test validating a valid JWT token.""" + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + ) + + user = await validator.validate_token(valid_jwt_token) + + assert isinstance(user, ResourceOwner) + assert user.user_id == "user123" + assert user.email == "user@example.com" + assert user.claims["iss"] == "https://auth.example.com" + + @pytest.mark.asyncio + async def test_validate_expired_token(self, expired_jwt_token, jwks_data): + """Test validating an expired JWT token.""" + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + ) + + with pytest.raises(TokenExpiredError): + await validator.validate_token(expired_jwt_token) + + @pytest.mark.asyncio + async def test_validate_wrong_audience(self, serialized_private_key, jwks_data): + """Test validating token with wrong audience.""" + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "https://wrong-server.com", # Wrong audience + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + ) + + with pytest.raises(InvalidTokenError, match="audience"): + await validator.validate_token(token) + + @pytest.mark.asyncio + async def test_validate_wrong_issuer(self, serialized_private_key, jwks_data): + """Test validating token with wrong issuer.""" + payload = { + "sub": "user123", + "iss": "https://wrong-issuer.com", # Wrong issuer + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + ) + + with pytest.raises(InvalidTokenError, match="issuer"): + await validator.validate_token(token) + + @pytest.mark.asyncio + async def test_validate_missing_sub_claim(self, serialized_private_key, jwks_data): + """Test validating token without sub claim.""" + payload = { + "iss": "https://auth.example.com", + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + ) + + with pytest.raises(InvalidTokenError, match="sub"): + await validator.validate_token(token) + + @pytest.mark.asyncio + async def test_jwks_caching(self, valid_jwt_token, jwks_data): + """Test that JWKS is cached to avoid repeated fetches.""" + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + cache_ttl=3600, + ) + + # First validation should fetch JWKS + await validator.validate_token(valid_jwt_token) + assert mock_get.call_count == 1 + + # Second validation should use cached JWKS + await validator.validate_token(valid_jwt_token) + assert mock_get.call_count == 1 + + @pytest.mark.asyncio + async def test_validate_multiple_audiences_single_token_aud( + self, serialized_private_key, jwks_data + ): + """Test validator with multiple audiences accepts token with matching single aud.""" + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "https://old-mcp.example.com", # Matches first audience + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience=["https://old-mcp.example.com", "https://new-mcp.example.com"], + ) + + user = await validator.validate_token(token) + assert user.user_id == "user123" + + @pytest.mark.asyncio + async def test_validate_multiple_audiences_list_token_aud( + self, serialized_private_key, jwks_data + ): + """Test validator with multiple audiences accepts token with list aud.""" + # Token with list of audiences where one matches the validator's accepted audiences + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": ["https://api1.com", "https://new-mcp.example.com"], # Second matches + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience=["https://old-mcp.example.com", "https://new-mcp.example.com"], + ) + + user = await validator.validate_token(token) + assert user.user_id == "user123" + + @pytest.mark.asyncio + async def test_validate_multiple_audiences_no_match(self, serialized_private_key, jwks_data): + """Test validator with multiple audiences rejects token with non-matching aud.""" + # Token with audience that doesn't match any of validator's accepted audiences + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "https://different-server.com", # Doesn't match + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience=["https://old-mcp.example.com", "https://new-mcp.example.com"], + ) + + with pytest.raises(InvalidTokenError, match="audience"): + await validator.validate_token(token) + + @pytest.mark.asyncio + async def test_validate_single_audience_with_list_token_aud( + self, serialized_private_key, jwks_data + ): + """Test validator with single audience accepts token with list aud containing match.""" + # Token with list of audiences where one matches validator's single audience + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": ["https://api1.com", "https://mcp.example.com/mcp", "https://api2.com"], + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", # Single audience + ) + + user = await validator.validate_token(token) + assert user.user_id == "user123" + + @pytest.mark.asyncio + async def test_validate_multiple_issuers_efficient(self, serialized_private_key, jwks_data): + """Test that multi-issuer validation is efficient (single decode).""" + # Token from second issuer in list + payload = { + "sub": "user123", + "iss": "https://auth2.example.com", # Second in list + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + # Patch jwt.decode to count calls + with patch( + "arcade_mcp_server.resource_server.validators.jwks.jwt.decode", wraps=jwt.decode + ) as mock_decode: + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer=[ + "https://auth1.example.com", + "https://auth2.example.com", + "https://auth3.example.com", + ], + audience="https://mcp.example.com/mcp", + ) + + user = await validator.validate_token(token) + assert user.user_id == "user123" + + # Should only need to decode once, not 3 times + assert mock_decode.call_count == 1 + + @pytest.mark.asyncio + async def test_validate_nbf_claim_before_time(self, serialized_private_key, jwks_data): + """Test that token with nbf claim in the future is rejected.""" + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 7200, # expires in 2 hours + "iat": int(time.time()), + "nbf": int(time.time()) + 3600, # Not valid for 1 hour + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + validation_options=AccessTokenValidationOptions(verify_nbf=True), + ) + + with pytest.raises(InvalidTokenError): + await validator.validate_token(token) + + @pytest.mark.asyncio + async def test_validate_nbf_claim_disabled(self, serialized_private_key, jwks_data): + """Test that token with nbf in future is accepted when verify_nbf=False.""" + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 7200, # expires in 2 hours + "iat": int(time.time()), + "nbf": int(time.time()) + 3600, # Not valid for 1 hour + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + validation_options=AccessTokenValidationOptions(verify_nbf=False), + ) + + # Should accept the token when nbf verification is disabled + user = await validator.validate_token(token) + assert user.user_id == "user123" + + @pytest.mark.asyncio + async def test_validate_with_leeway(self, serialized_private_key, jwks_data): + """Test that leeway allows slightly expired tokens.""" + # Token expired 30 seconds ago + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) - 30, + "iat": int(time.time()) - 3600, + } + + token = jwt.encode( + payload, + serialized_private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + # Validator with 60 second leeway should accept this token + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + validation_options=AccessTokenValidationOptions(leeway=60), + ) + + user = await validator.validate_token(token) + assert user.user_id == "user123" + + +# ResourceServerAuth Tests +class TestResourceServerAuth: + """Tests for ResourceServerAuth class.""" + + def test_supports_oauth_discovery(self): + """Test that ResourceServerAuth supports OAuth discovery.""" + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/.well-known/jwks.json", + ) + ], + ) + + assert resource_server_auth.supports_oauth_discovery() is True + + def test_get_resource_metadata(self): + """Test getting OAuth Protected Resource Metadata.""" + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/.well-known/jwks.json", + ) + ], + ) + + metadata = resource_server_auth.get_resource_metadata() + + assert metadata["resource"] == "https://mcp.example.com/mcp" + assert metadata["authorization_servers"] == ["https://auth.example.com"] + assert metadata["bearer_methods_supported"] == ["header"] + + @pytest.mark.asyncio + async def test_expected_audiences_override(self, rsa_keypair, jwks_data): + """Test that expected_audiences overrides canonical_url for audience validation.""" + private_key, _ = rsa_keypair + + # Token with custom audience + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "my-authkit-client-id", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/.well-known/jwks.json", + expected_audiences=["my-authkit-client-id"], + ) + ], + ) + + user = await resource_server_auth.validate_token(token) + assert user.user_id == "user123" + + @pytest.mark.asyncio + async def test_expected_audiences_multiple_values(self, rsa_keypair, jwks_data): + """Test that multiple expected_audiences work correctly.""" + private_key, _ = rsa_keypair + + # Token with one of the expected audiences + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "secondary-client-id", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/.well-known/jwks.json", + expected_audiences=[ + "primary-client-id", + "secondary-client-id", + "tertiary-client-id", + ], + ) + ], + ) + + user = await resource_server_auth.validate_token(token) + assert user.user_id == "user123" + + @pytest.mark.asyncio + async def test_expected_audiences_defaults_to_canonical_url(self, rsa_keypair, jwks_data): + """Test that without expected_audiences, canonical_url is used for audience validation.""" + private_key, _ = rsa_keypair + + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/.well-known/jwks.json", + ) + ], + ) + + user = await resource_server_auth.validate_token(token) + assert user.user_id == "user123" + + @pytest.mark.asyncio + async def test_expected_audiences_wrong_audience_rejected(self, rsa_keypair, jwks_data): + """Test that tokens with wrong audience are rejected even with expected_audiences.""" + private_key, _ = rsa_keypair + + payload = { + "sub": "user123", + "iss": "https://auth.example.com", + "aud": "wrong-client-id", # Not in expected_audiences list + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + + token = jwt.encode( + payload, + private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/.well-known/jwks.json", + expected_audiences=["correct-client-id"], + ) + ], + ) + + with pytest.raises(InvalidTokenError): + await resource_server_auth.validate_token(token) + + +# ResourceServerMiddleware Tests +class TestResourceServerMiddleware: + """Tests for ResourceServerMiddleware class.""" + + @pytest.mark.asyncio + async def test_authenticated_request(self, valid_jwt_token, jwks_data): + """Test authenticated request passes through.""" + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + ) + + # Mock app + app_called = False + + async def mock_app(scope, receive, send): + nonlocal app_called + app_called = True + assert "resource_owner" in scope + assert scope["resource_owner"].user_id == "user123" + + middleware = ResourceServerMiddleware( + mock_app, + validator, + "https://mcp.example.com/mcp", + ) + + # Create mock request + scope = { + "type": "http", + "method": "POST", + "headers": [(b"authorization", f"Bearer {valid_jwt_token}".encode())], + } + + async def receive(): + return {"type": "http.request", "body": b""} + + async def send(message): + pass + + await middleware(scope, receive, send) + assert app_called is True + + @pytest.mark.asyncio + async def test_missing_authorization_header(self, jwks_data): + """Test request without Authorization header returns 401.""" + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/.well-known/jwks.json", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + ) + + async def mock_app(scope, receive, send): + pytest.fail("App should not be called") + + middleware = ResourceServerMiddleware( + mock_app, + validator, + "https://mcp.example.com/mcp", + ) + + # mock request w/o auth header + scope = { + "type": "http", + "method": "POST", + "headers": [], + } + + async def receive(): + return {"type": "http.request", "body": b""} + + response_sent = {} + + async def send(message): + if message["type"] == "http.response.start": + response_sent["status"] = message["status"] + response_sent["headers"] = dict(message.get("headers", [])) + + await middleware(scope, receive, send) + + assert response_sent["status"] == 401 + assert any(k.lower() == b"www-authenticate" for k in response_sent["headers"]) + + @pytest.mark.asyncio + async def test_www_authenticate_header_format(self, jwks_data): + """Test WWW-Authenticate header format compliance.""" + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + resource_server = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/.well-known/jwks.json", + ) + ], + ) + + async def mock_app(scope, receive, send): + pytest.fail("App should not be called") + + middleware = ResourceServerMiddleware( + mock_app, + resource_server, + "https://mcp.example.com/mcp", + ) + + scope = { + "type": "http", + "method": "POST", + "headers": [], + } + + async def receive(): + return {"type": "http.request", "body": b""} + + response_headers = {} + + async def send(message): + if message["type"] == "http.response.start": + response_headers.update(dict(message.get("headers", []))) + + await middleware(scope, receive, send) + + www_auth = response_headers.get(b"www-authenticate", b"").decode() + + assert "Bearer" in www_auth + assert "resource_metadata=" in www_auth + assert "/.well-known/oauth-protected-resource" in www_auth + + +class TestEnvVarConfiguration: + """Tests for front-door auth env var configuration support.""" + + @pytest.mark.asyncio + async def test_resource_server_param_precedence(self, monkeypatch): + """Test that explicit parameters take precedence over environment variables.""" + monkeypatch.setenv("MCP_RESOURCE_SERVER_CANONICAL_URL", "https://env-mcp.example.com") + monkeypatch.setenv( + "MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS", + '[{"authorization_server_url":"https://env.example.com","issuer":"https://env.example.com","jwks_uri":"https://env.example.com/jwks"}]', + ) + + resource_server = ResourceServerAuth( + canonical_url="https://param-mcp.example.com", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://param.example.com", + issuer="https://param.example.com", + jwks_uri="https://param.example.com/jwks", + ) + ], + ) + + # Explicit parameters should take precedence over env vars + assert resource_server.canonical_url == "https://param-mcp.example.com" + metadata = resource_server.get_resource_metadata() + assert metadata["authorization_servers"] == ["https://param.example.com"] + + @pytest.mark.asyncio + async def test_resource_server_all_env_vars(self, monkeypatch): + """Test ResourceServerAuth with all env vars, no parameters.""" + monkeypatch.setenv("MCP_RESOURCE_SERVER_CANONICAL_URL", "https://mcp.example.com/mcp") + monkeypatch.setenv( + "MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS", + '[{"authorization_server_url":"https://auth.example.com","issuer":"https://auth.example.com","jwks_uri":"https://auth.example.com/jwks","algorithm":"RS256","expected_audiences":["custom-client-id"]}]', + ) + + resource_server_auth = ResourceServerAuth() + + assert resource_server_auth.canonical_url == "https://mcp.example.com/mcp" + metadata = resource_server_auth.get_resource_metadata() + assert metadata["authorization_servers"] == ["https://auth.example.com"] + + def test_resource_server_missing_required(self): + """Test that missing required fields raise ValueError.""" + with pytest.raises(ValueError, match="'canonical_url' required"): + ResourceServerAuth( + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/jwks", + ) + ], + # Missing canonical_url + ) + + @pytest.mark.asyncio + async def test_worker_no_canonical_url_for_jwks_validator(self): + """Test that worker doesn't require canonical_url for JWKSTokenValidator.""" + jwt_validator = JWKSTokenValidator( + jwks_uri="https://auth.example.com/jwks", + issuer="https://auth.example.com", + audience="https://mcp.example.com/mcp", + ) + + catalog = ToolCatalog() + # Shouldn't raise b/c JWKSTokenValidator doesn't support OAuth discovery + app = create_arcade_mcp(catalog, resource_server_validator=jwt_validator) + assert app is not None + + def test_worker_requires_canonical_url_for_resource_server(self): + """Test that ResourceServerAuth validation happens during init.""" + with pytest.raises(ValueError, match="'canonical_url' required"): + ResourceServerAuth( + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/jwks", + ) + ], + # Missing canonical_url + ) + + +class TestMultipleAuthorizationServers: + """Tests for multiple authorization server support.""" + + @pytest.mark.asyncio + async def test_resource_server_multiple_as_shared_jwks(self, jwks_data, valid_jwt_token): + """Test multiple AS URLs with same JWKS""" + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth-us.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/jwks", + ), + AuthorizationServerEntry( + authorization_server_url="https://auth-eu.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/jwks", + ), + ], + ) + + # Verify that metadata returns all Auth Server URLs + metadata = resource_server_auth.get_resource_metadata() + assert metadata["resource"] == "https://mcp.example.com/mcp" + assert metadata["authorization_servers"] == [ + "https://auth-us.example.com", + "https://auth-eu.example.com", + ] + + # Verify that token validation works + user = await resource_server_auth.validate_token(valid_jwt_token) + assert user.user_id == "user123" + assert user.email == "user@example.com" + + @pytest.mark.asyncio + async def test_resource_server_multiple_as_different_jwks(self, rsa_keypair, jwks_data): + """Test multiple AS with different JWKS (multi-IdP).""" + private_key, _ = rsa_keypair + + payload1 = { + "sub": "user123", + "email": "user@workos.com", + "iss": "https://workos.authkit.app", + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + token1 = jwt.encode( + payload1, + private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + payload2 = { + "sub": "user456", + "email": "user@keycloak.com", + "iss": "http://localhost:8080/realms/mcp-test", + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + token2 = jwt.encode( + payload2, + private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://workos.authkit.app", + issuer="https://workos.authkit.app", + jwks_uri="https://workos.authkit.app/oauth2/jwks", + ), + AuthorizationServerEntry( + authorization_server_url="http://localhost:8080/realms/mcp-test", + issuer="http://localhost:8080/realms/mcp-test", + jwks_uri="http://localhost:8080/realms/mcp-test/protocol/openid-connect/certs", + algorithm="RS256", + ), + ], + ) + + # Verify metadata returns all Auth Server URLs + metadata = resource_server_auth.get_resource_metadata() + assert metadata["authorization_servers"] == [ + "https://workos.authkit.app", + "http://localhost:8080/realms/mcp-test", + ] + + # Verify tokens from both Auth Servers work + user1 = await resource_server_auth.validate_token(token1) + assert user1.user_id == "user123" + assert user1.email == "user@workos.com" + + user2 = await resource_server_auth.validate_token(token2) + assert user2.user_id == "user456" + assert user2.email == "user@keycloak.com" + + @pytest.mark.asyncio + async def test_resource_server_rejects_unconfigured_as(self, rsa_keypair, jwks_data): + """Test that tokens from unlisted AS are rejected.""" + private_key, _ = rsa_keypair + + payload = { + "sub": "user123", + "email": "user@evil.com", + "iss": "https://evil.com", # Not in configured list (unauthorized issuer) + "aud": "https://mcp.example.com/mcp", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + token = jwt.encode( + payload, + private_key, + algorithm="RS256", + headers={"kid": "test-key-1"}, + ) + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.json.return_value = jwks_data + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth.example.com", + issuer="https://auth.example.com", + jwks_uri="https://auth.example.com/jwks", + ) + ], + ) + + # Should reject token from unauthorized Auth Server (issuer) + with pytest.raises( + InvalidTokenError, + match="Token validation failed for all configured authorization servers", + ): + await resource_server_auth.validate_token(token) + + def test_authorization_servers_env_var_parsing_json(self, monkeypatch): + """Test parsing JSON array of AS configs from env var.""" + monkeypatch.setenv("MCP_RESOURCE_SERVER_CANONICAL_URL", "https://mcp.example.com/mcp") + monkeypatch.setenv( + "MCP_RESOURCE_SERVER_AUTHORIZATION_SERVERS", + '[{"authorization_server_url": "https://auth1.com", "issuer": "https://auth1.com", "jwks_uri": "https://auth1.com/jwks"}]', + ) + + resource_server_auth = ResourceServerAuth() + + metadata = resource_server_auth.get_resource_metadata() + assert metadata["authorization_servers"] == ["https://auth1.com"] + + def test_resource_metadata_multiple_as(self): + """Test that resource metadata returns all AS URLs.""" + resource_server_auth = ResourceServerAuth( + canonical_url="https://mcp.example.com/mcp", + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://auth1.example.com", + issuer="https://auth1.example.com", + jwks_uri="https://auth1.example.com/jwks", + ), + AuthorizationServerEntry( + authorization_server_url="https://auth2.example.com", + issuer="https://auth2.example.com", + jwks_uri="https://auth2.example.com/jwks", + ), + AuthorizationServerEntry( + authorization_server_url="https://auth3.example.com", + issuer="https://auth3.example.com", + jwks_uri="https://auth3.example.com/jwks", + ), + ], + ) + + metadata = resource_server_auth.get_resource_metadata() + assert metadata["resource"] == "https://mcp.example.com/mcp" + assert len(metadata["authorization_servers"]) == 3 + assert "https://auth1.example.com" in metadata["authorization_servers"] + assert "https://auth2.example.com" in metadata["authorization_servers"] + assert "https://auth3.example.com" in metadata["authorization_servers"] diff --git a/libs/tests/arcade_mcp_server/test_server.py b/libs/tests/arcade_mcp_server/test_server.py index c26cd9bf..cdfae286 100644 --- a/libs/tests/arcade_mcp_server/test_server.py +++ b/libs/tests/arcade_mcp_server/test_server.py @@ -962,7 +962,6 @@ class TestMCPServer: "arguments": {"text": "test"}, }, ) - response = await mcp_server._handle_call_tool(message, session=session) assert isinstance(response, JSONRPCResponse)