diff --git a/.github/actions/setup-uv-env/action.yml b/.github/actions/setup-uv-env/action.yml index 9e80f0bd..f3530806 100644 --- a/.github/actions/setup-uv-env/action.yml +++ b/.github/actions/setup-uv-env/action.yml @@ -32,8 +32,16 @@ runs: working-directory: ${{ inputs.working-directory }} python-version: ${{ inputs.python-version }} - - name: Install package dependencies - if: inputs.is-toolkit == 'true' || inputs.is-contrib == 'true' + - name: Install toolkit dependencies + if: inputs.is-toolkit == 'true' + working-directory: ${{ inputs.working-directory }} + run: | + echo "Installing dependencies for ${{ inputs.working-directory }}" + make install-local + shell: bash + + - name: Install contrib dependencies + if: inputs.is-contrib == 'true' working-directory: ${{ inputs.working-directory }} run: | echo "Installing dependencies for ${{ inputs.working-directory }}" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d51e41f..e9a74942 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: - id: check-toml exclude: ".*/templates/.*" - id: check-yaml - exclude: ".*/templates/.*" + exclude: ".*/templates/.*|libs/arcade-mcp-server/mkdocs.yml" - id: end-of-file-fixer exclude: ".*/templates/.*" - id: trailing-whitespace @@ -17,6 +17,6 @@ repos: hooks: - id: ruff args: [--fix] - exclude: ".*/templates/.*" + exclude: "(.*/templates/.*|libs/tests/.*)" - id: ruff-format - exclude: ".*/templates/.*" + exclude: "(.*/templates/.*|libs/tests/.*)" diff --git a/.ruff.toml b/.ruff.toml index 477d663c..cd21d1db 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -60,6 +60,8 @@ ignore = [ [lint.per-file-ignores] "**/tests/*" = ["S101"] +"libs/**/*.py" = ["C901"] +"libs/arcade-mcp-server/docs/**" = ["TRY400"] [format] preview = true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a106969b..a9f5279d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ -# Contributing to `arcade-ai` +# Contributing to `arcade-mcp` Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. @@ -9,7 +9,7 @@ You can contribute in many ways: ## Report Bugs -Report bugs at https://github.com/ArcadeAI/arcade-ai/issues +Report bugs at https://github.com/ArcadeAI/arcade-mcp/issues If you are reporting a bug, please include: @@ -33,7 +33,7 @@ Arcade could always use more documentation, whether as part of the official docs ## Submit Feedback -The best way to send feedback is to file an issue at https://github.com/ArcadeAI/arcade-ai/issues. +The best way to send feedback is to file an issue at https://github.com/ArcadeAI/arcade-mcp/issues. If you are proposing a new feature: @@ -44,22 +44,22 @@ If you are proposing a new feature: # Get Started! -Ready to contribute? Here's how to set up `arcade-ai` for local development. +Ready to contribute? Here's how to set up `arcade-mcp` for local development. Please note this documentation assumes you already have `uv` and `Git` installed and ready to go. -1. Fork the `arcade-ai` repo on GitHub. +1. Fork the `arcade-mcp` repo on GitHub. 2. Clone your fork locally: ```bash cd -git clone git@github.com:YOUR_GITHUB_USERNAME/arcade-ai.git +git clone git@github.com:YOUR_GITHUB_USERNAME/arcade-mcp.git ``` 3. Now we need to install the environment. Navigate into the directory ```bash -cd arcade-ai +cd arcade-mcp ``` Create your virtual environment diff --git a/Makefile b/Makefile index dea27130..f10e7fcb 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ .PHONY: install install: ## Install the uv environment and all packages with dependencies @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv workspace" - @uv sync --active --dev --extra all + @uv sync --dev --extra all @uv run pre-commit install @echo "โœ… All packages and dependencies installed via uv workspace" @@ -41,11 +41,11 @@ install-toolkits: ## Install dependencies for all toolkits check: ## Run code quality tools. @echo "๐Ÿš€ Linting code: Running pre-commit" @uv run pre-commit run -a - @echo "๐Ÿš€ Static type checking: Running mypy on libs" + @echo "๐Ÿš€ Static type checking: Running mypy on libs" @for lib in libs/arcade*/ ; do \ - echo "๐Ÿ” Type checking $$lib"; \ - (cd $$lib && uv run mypy . || true); \ - done + echo "๐Ÿ” Type checking $$lib"; \ + (cd $$lib && uv run mypy . --exclude tests || true); \ + done .PHONY: check-libs check-libs: ## Run code quality tools for each lib package @@ -62,16 +62,16 @@ check-toolkits: ## Run code quality tools for each toolkit that has a Makefile @for dir in toolkits/*/ ; do \ if [ -f "$$dir/Makefile" ]; then \ echo "๐Ÿ› ๏ธ Checking toolkit $$dir"; \ - (cd "$$dir" && uv run --active pre-commit run -a && uv run --active mypy --config-file=pyproject.toml); \ + (cd "$$dir" && uv run pre-commit run -a && uv run mypy --config-file=pyproject.toml); \ else \ echo "๐Ÿ› ๏ธ Skipping toolkit $$dir (no Makefile found)"; \ fi; \ - done + done .PHONY: test test: ## Test the code with pytest @echo "๐Ÿš€ Testing libs: Running pytest" - @uv run pytest -W ignore -v --cov=libs/tests --cov-config=pyproject.toml --cov-report=xml + @uv run pytest -W ignore -v libs/tests --cov=libs --cov-config=pyproject.toml --cov-report=xml .PHONY: test-libs test-libs: ## Test each lib package individually @@ -87,7 +87,7 @@ test-toolkits: ## Iterate over all toolkits and run pytest on each one @for dir in toolkits/*/ ; do \ toolkit_name=$$(basename "$$dir"); \ echo "๐Ÿงช Testing $$toolkit_name toolkit"; \ - (cd $$dir && uv run --active pytest -W ignore -v --cov=arcade_$$toolkit_name --cov-report=xml || exit 1); \ + (cd $$dir && uv run pytest -W ignore -v --cov=arcade_$$toolkit_name --cov-report=xml || exit 1); \ done .PHONY: coverage @@ -194,7 +194,7 @@ full-dist: clean-dist ## Build all projects and copy wheels to ./dist (cd libs/$$lib && uv build); \ done - @echo "๐Ÿ› ๏ธ Building arcade-ai package and copying wheel to ./dist" + @echo "๐Ÿ› ๏ธ Building arcade-mcp package and copying wheel to ./dist" @uv build @rm -f dist/*.tar.gz @@ -224,7 +224,9 @@ clean-dist: ## Clean all built distributions done .PHONY: setup -setup: install ## Complete development setup (same as install) +setup: ## Run uv environment setup script + @chmod +x ./uv_setup.sh + @./uv_setup.sh .PHONY: lint lint: check ## Alias for check command @@ -238,3 +240,12 @@ help: @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' .DEFAULT_GOAL := help + +.PHONY: shell +shell: ## Open an interactive shell with the virtual environment activated + @if [ -f ".venv/bin/activate" ]; then \ + . .venv/bin/activate && exec $$SHELL -l; \ + else \ + echo "โš ๏ธ Virtual environment not found. Run 'make setup' first."; \ + exit 1; \ + fi diff --git a/README.md b/README.md index 18dd6925..a4f4b060 100644 --- a/README.md +++ b/README.md @@ -6,15 +6,15 @@ >
- + License - GitHub last commit - -GitHub Actions Status + GitHub last commit + +GitHub Actions Status - - Python Version + + Python Version
@@ -22,7 +22,7 @@ Follow on X - + Follow on LinkedIn @@ -43,11 +43,12 @@ Arcade is a developer platform that lets you build, deploy, and manage tools for This repository contains the core Arcade libraries, organized as separate packages for maximum flexibility and modularity: -- **arcade-core** - Core platform functionality and schemas | [Source code](https://github.com/ArcadeAI/arcade-ai/tree/main/libs/arcade-core) | `pip install arcade-core` | -- **arcade-tdk** - Tool Development Kit with the `@tool` decorator | [Source code](https://github.com/ArcadeAI/arcade-ai/tree/main/libs/arcade-tdk) | `pip install arcade-tdk` | -- **arcade-serve** - Serving infrastructure for workers and MCP servers | [Source code](https://github.com/ArcadeAI/arcade-ai/tree/main/libs/arcade-serve) | `pip install arcade-serve` | -- **arcade-evals** - Evaluation framework for testing tool performance | [Source code](https://github.com/ArcadeAI/arcade-ai/tree/main/libs/arcade-evals) | `pip install 'arcade-ai[evals]` | -- **arcade-cli** - Command-line interface for the Arcade platform | [Source code](https://github.com/ArcadeAI/arcade-ai/tree/main/libs/arcade-cli) | `pip install arcade-ai` | +- **arcade-core** - Core platform functionality and schemas | [Source code](https://github.com/ArcadeAI/arcade-mcp/tree/main/libs/arcade-core) | `pip install arcade-core` | +- **arcade-mcp-server** - MCP Server Development Framework | [Source code](https://github.com/ArcadeAI/arcade-mcp/tree/main/libs/arcade-mcp-server) | `pip install arcade-mcp-server` | +- **arcade-tdk** - Tool Development Kit with the `@tool` decorator | [Source code](https://github.com/ArcadeAI/arcade-mcp/tree/main/libs/arcade-tdk) | `pip install arcade-tdk` | +- **arcade-serve** - Serving infrastructure for workers and MCP servers | [Source code](https://github.com/ArcadeAI/arcade-mcp/tree/main/libs/arcade-serve) | `pip install arcade-serve` | +- **arcade-evals** - Evaluation framework for testing tool performance | [Source code](https://github.com/ArcadeAI/arcade-mcp/tree/main/libs/arcade-evals) | `pip install 'arcade-mcp[evals]` | +- **arcade-cli** - Command-line interface for the Arcade platform | [Source code](https://github.com/ArcadeAI/arcade-mcp/tree/main/libs/arcade-cli) | `pip install arcade-mcp` | ![diagram](https://github.com/user-attachments/assets/1a567e5f-d6b4-4b1e-9918-c401ad232ebb) @@ -55,8 +56,8 @@ This repository contains the core Arcade libraries, organized as separate packag _Pst. hey, you, give us a star if you like it!_ - - GitHub stars + + GitHub stars ## Quick Start @@ -76,9 +77,9 @@ make install For production use, install individual packages as needed: ```bash -pip install arcade-ai # CLI -pip install 'arcade-ai[evals]' # CLI + Evaluation framework -pip install 'arcade-ai[all]' # CLI + Serving infra + eval framework + TDK +pip install arcade-mcp # CLI +pip install 'arcade-mcp[evals]' # CLI + Evaluation framework +pip install 'arcade-mcp[all]' # CLI + Serving infra + eval framework + TDK pip install arcade_serve # Serving infrastructure pip install arcade-tdk # Tool Development Kit ``` @@ -115,5 +116,5 @@ make help ## Support and Community - **Discord:** Join our [Discord community](https://discord.com/invite/GUZEMpEZ9p) for real-time support and discussions. -- **GitHub:** Contribute or report issues on the [Arcade GitHub repository](https://github.com/ArcadeAI/arcade-ai). +- **GitHub:** Contribute or report issues on the [Arcade GitHub repository](https://github.com/ArcadeAI/arcade-mcp). - **Documentation:** Find in-depth guides and API references at [Arcade Documentation](https://docs.arcade.dev). diff --git a/contrib/crewai/README.md b/contrib/crewai/README.md index 4c41b01d..9920dc42 100644 --- a/contrib/crewai/README.md +++ b/contrib/crewai/README.md @@ -6,7 +6,7 @@

CrewAI Integration

- + License @@ -34,4 +34,4 @@ pip install crewai-arcade ## Usage -See the [examples](https://github.com/ArcadeAI/arcade-ai/tree/main/examples/crewai) for usage examples +See the [examples](https://github.com/ArcadeAI/arcade-mcp/tree/main/examples/crewai) for usage examples diff --git a/contrib/crewai/pyproject.toml b/contrib/crewai/pyproject.toml index 221015ec..bfe8367b 100644 --- a/contrib/crewai/pyproject.toml +++ b/contrib/crewai/pyproject.toml @@ -4,7 +4,7 @@ version = "0.1.1" description = "An integration package connecting Arcade and CrewAI" authors = ["Arcade "] readme = "README.md" -repository = "https://github.com/arcadeai/arcade-ai/tree/main/contrib/crewai" +repository = "https://github.com/arcadeai/arcade-mcp/tree/main/contrib/crewai" license = "MIT" [tool.poetry.dependencies] diff --git a/contrib/langchain/README.md b/contrib/langchain/README.md index 386d687e..a9752277 100644 --- a/contrib/langchain/README.md +++ b/contrib/langchain/README.md @@ -108,7 +108,7 @@ graph = create_react_agent(model, tools) # Run the agent with the "user_id" field in the config # IMPORTANT the "user_id" field is required for tools that require user authorization config = {"configurable": {"user_id": "user@lgexample.com"}} -user_input = {"messages": [("user", "Star the arcadeai/arcade-ai repository on GitHub")]} +user_input = {"messages": [("user", "Star the arcadeai/arcade-mcp repository on GitHub")]} for chunk in graph.stream(user_input, config, debug=True): if chunk.get("__interrupt__"): @@ -124,7 +124,7 @@ for chunk in graph.stream(user_input, config, debug=True): ``` -See the Functional examples in the [examples directory](https://github.com/ArcadeAI/arcade-ai/tree/main/examples/langchain) that continue the agent after authorization and handle authorization errors gracefully. +See the Functional examples in the [examples directory](https://github.com/ArcadeAI/arcade-mcp/tree/main/examples/langchain) that continue the agent after authorization and handle authorization errors gracefully. ### Async Support @@ -172,4 +172,4 @@ For a complete list, see the [Arcade Toolkits documentation](https://docs.arcade ## More Examples -For more examples, see the [examples directory](https://github.com/ArcadeAI/arcade-ai/tree/main/examples/langchain). +For more examples, see the [examples directory](https://github.com/ArcadeAI/arcade-mcp/tree/main/examples/langchain). diff --git a/contrib/langchain/langchain_arcade/__init__.py b/contrib/langchain/langchain_arcade/__init__.py index f01cd510..90af2fdf 100644 --- a/contrib/langchain/langchain_arcade/__init__.py +++ b/contrib/langchain/langchain_arcade/__init__.py @@ -1,7 +1,7 @@ from .manager import ArcadeToolManager, AsyncToolManager, ToolManager __all__ = [ - "ToolManager", - "AsyncToolManager", "ArcadeToolManager", # Deprecated + "AsyncToolManager", + "ToolManager", ] diff --git a/contrib/langchain/pyproject.toml b/contrib/langchain/pyproject.toml index ac26a6c1..f5bb8f1f 100644 --- a/contrib/langchain/pyproject.toml +++ b/contrib/langchain/pyproject.toml @@ -7,7 +7,7 @@ name = "langchain-arcade" version = "1.4.4" description = "An integration package connecting Arcade and Langchain/LangGraph" readme = "README.md" -repository = "https://github.com/arcadeai/arcade-ai/tree/main/contrib/langchain" +repository = "https://github.com/arcadeai/arcade-mcp/tree/main/contrib/langchain" license = "MIT" requires-python = ">=3.10" dependencies = [ diff --git a/docker/Dockerfile b/docker/Dockerfile index 58d6d6b3..c4e66c03 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -38,7 +38,8 @@ RUN ls -la /app/dist/ # Install the worker and CLI package RUN python -m pip install \ /app/dist/arcade_serve-*.whl \ - /app/dist/arcade_ai-*.whl + /app/dist/arcade_mcp-*.whl + /app/dist/arcade_mcp_server-*.whl # Conditionally install toolkit wheels from dist directory if INSTALL_TOOLKITS is true and the toolkit is in toolkits.txt RUN if [ "$INSTALL_TOOLKITS" = "true" ] ; then \ @@ -51,7 +52,8 @@ RUN if [ "$INSTALL_TOOLKITS" = "true" ] ; then \ # Check if this is not a core package and if the wheel file exists if [ "$wheel_name" != "arcade_core" ] && \ [ "$wheel_name" != "arcade_serve" ] && \ - [ "$wheel_name" != "arcade_ai" ] && \ + [ "$wheel_name" != "arcade_mcp" ] && \ + [ "$wheel_name" != "arcade_mcp_server" ] && \ [ "$wheel_name" != "arcade_tdk" ]; then \ if ls $wheel_file 1> /dev/null 2>&1; then \ echo "Installing $toolkit from $wheel_file"; \ diff --git a/docker/Makefile b/docker/Makefile index 5d71c173..126700ec 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -1,6 +1,6 @@ VENDOR ?= ArcadeAI PROJECT ?= ArcadeAI -SOURCE ?= https://github.com/ArcadeAI/arcade-ai +SOURCE ?= https://github.com/ArcadeAI/arcade-mcp LICENSE ?= MIT DESCRIPTION ?= "Arcade Worker for LLM Tool Serving" REPOSITORY ?= arcadeai/worker diff --git a/docker/README.md b/docker/README.md index ebba30f8..65f99b77 100644 --- a/docker/README.md +++ b/docker/README.md @@ -14,12 +14,12 @@ This guide provides detailed instructions on how to set up and run Arcade using Begin by cloning the Arcade repository: ```bash -git clone https://github.com/ArcadeAI/arcade-ai.git +git clone https://github.com/ArcadeAI/arcade-mcp.git ``` ### 2. Build package wheels -From the root of the arcade-ai repository: +From the root of the arcade-mcp repository: ```bash make full-dist @@ -30,7 +30,7 @@ make full-dist Change to the `docker` directory: ```bash -cd arcade-ai/docker +cd arcade-mcp/docker ``` Copy the example environment file to `.env`: diff --git a/examples/ai-sdk/README.md b/examples/ai-sdk/README.md index 6ea89405..13d87cd7 100644 --- a/examples/ai-sdk/README.md +++ b/examples/ai-sdk/README.md @@ -22,7 +22,7 @@ # Arcade - AI SDK -This example demonstrates how to integrate [Arcade](https://docs.arcade.dev) with the [Vercel AI SDK](https://sdk.vercel.ai/) to create powerful AI agents. Arcade provides access to a wide range of tools including Gmail, Slack, LinkedIn, and more. You can also develop custom tools using the [Tool SDK](https://github.com/ArcadeAI/arcade-ai). +This example demonstrates how to integrate [Arcade](https://docs.arcade.dev) with the [Vercel AI SDK](https://sdk.vercel.ai/) to create powerful AI agents. Arcade provides access to a wide range of tools including Gmail, Slack, LinkedIn, and more. You can also develop custom tools using the [Tool SDK](https://github.com/ArcadeAI/arcade-mcp). For a list of all hosted tools and auth providers, see the [Arcade Integrations](https://docs.arcade.dev/toolkits) documentation. diff --git a/examples/ai-sdk/package.json b/examples/ai-sdk/package.json index 5c8ce180..e9018f9a 100644 --- a/examples/ai-sdk/package.json +++ b/examples/ai-sdk/package.json @@ -1,25 +1,25 @@ { - "name": "arcade-ai-sdk", - "version": "1.0.0", - "description": "", - "main": "index.js", - "type": "module", - "scripts": { - "dev": "node --env-file=.env index.js", - "generateText": "node --env-file=.env generateText.js" - }, - "keywords": [], - "author": "", - "license": "ISC", - "packageManager": "pnpm@10.6.5", - "dependencies": { - "@ai-sdk/openai": "^1.3.22", - "@arcadeai/arcadejs": "latest", - "ai": "^4.3.15" - }, - "pnpm": { - "overrides": { - "form-data@>=4.0.0 <4.0.4": ">=4.0.4" - } - } + "name": "arcade-mcp-sdk", + "version": "1.0.0", + "description": "", + "main": "index.js", + "type": "module", + "scripts": { + "dev": "node --env-file=.env index.js", + "generateText": "node --env-file=.env generateText.js" + }, + "keywords": [], + "author": "", + "license": "ISC", + "packageManager": "pnpm@10.6.5", + "dependencies": { + "@ai-sdk/openai": "^1.3.22", + "@arcadeai/arcadejs": "latest", + "ai": "^4.3.15" + }, + "pnpm": { + "overrides": { + "form-data@>=4.0.0 <4.0.4": ">=4.0.4" + } + } } diff --git a/examples/call_a_tool_with_llm.py b/examples/call_a_tool_with_llm.py index ad7c98e0..1d8e18e6 100644 --- a/examples/call_a_tool_with_llm.py +++ b/examples/call_a_tool_with_llm.py @@ -10,7 +10,7 @@ from openai import OpenAI def call_tool_with_openai(client: OpenAI) -> dict: response = client.chat.completions.create( messages=[ - {"role": "user", "content": "Star the ArcadeAI/arcade-ai repository."}, + {"role": "user", "content": "Star the ArcadeAI/arcade-mcp repository."}, ], model="gpt-4o-mini", # TODO: Try "claude-3-5-sonnet-20240620" or other models from our supported model providers. Checkout out our docs for a full list https://docs.arcade.dev user="you@example.com", diff --git a/examples/langchain-ts/langgraph-with-user-auth.ts b/examples/langchain-ts/langgraph-with-user-auth.ts index 9b63d10f..c4896a62 100644 --- a/examples/langchain-ts/langgraph-with-user-auth.ts +++ b/examples/langchain-ts/langgraph-with-user-auth.ts @@ -129,7 +129,7 @@ const main = async () => { messages: [ { role: "user", - content: "Star arcadeai/arcade-ai on github", + content: "Star arcadeai/arcade-mcp on github", }, ], }; diff --git a/examples/langchain/langgraph_arcade_minimal.py b/examples/langchain/langgraph_arcade_minimal.py index a937fe9f..cc6a5976 100644 --- a/examples/langchain/langgraph_arcade_minimal.py +++ b/examples/langchain/langgraph_arcade_minimal.py @@ -44,7 +44,7 @@ graph = create_react_agent(model=bound_model, tools=lc_tools, checkpointer=memor # 6) Provide basic config and a user query. # Note: user_id is required for the tool to be authorized config = {"configurable": {"thread_id": "1", "user_id": "user@example.com"}} -user_input = {"messages": [("user", "star the arcadeai/arcade-ai repo on github")]} +user_input = {"messages": [("user", "star the arcadeai/arcade-mcp repo on github")]} # 7) Stream the agent's output. If the tool is unauthorized, it may trigger interrupts for chunk in graph.stream(user_input, config, stream_mode="values"): diff --git a/examples/langchain/langgraph_with_user_auth.py b/examples/langchain/langgraph_with_user_auth.py index 594bf5bb..481ea54f 100644 --- a/examples/langchain/langgraph_with_user_auth.py +++ b/examples/langchain/langgraph_with_user_auth.py @@ -93,7 +93,7 @@ if __name__ == "__main__": "messages": [ { "role": "user", - "content": "Star arcadeai/arcade-ai on github", + "content": "Star arcadeai/arcade-mcp on github", } ], } diff --git a/examples/mastra/README.md b/examples/mastra/README.md index 84154e20..65e265c4 100644 --- a/examples/mastra/README.md +++ b/examples/mastra/README.md @@ -21,7 +21,7 @@ This example demonstrates how to integrate [Arcade](https://docs.arcade.dev) with [Mastra](https://mastra.ai/en/docs) to create powerful AI agents. Arcade provides access to a wide range of tools including Gmail, Slack, LinkedIn, and more, while Mastra provides a robust framework for building AI agents with TypeScript. -For a list of all available tools and authentication options, see the [Arcade Integrations](https://docs.arcade.dev/toolkits) documentation. You can also build custom tools with the [Tool SDK](https://github.com/ArcadeAI/arcade-ai) as described in our [documentation](https://docs.arcade.dev/home/build-tools/create-a-toolkit). +For a list of all available tools and authentication options, see the [Arcade Integrations](https://docs.arcade.dev/toolkits) documentation. You can also build custom tools with the [Tool SDK](https://github.com/ArcadeAI/arcade-mcp) as described in our [documentation](https://docs.arcade.dev/home/build-tools/create-a-toolkit). ## Prerequisites diff --git a/examples/mcp/claude.json b/examples/mcp/claude.json deleted file mode 100644 index d3f9832f..00000000 --- a/examples/mcp/claude.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "mcpServers": { - "arcade": { - "command": "bash", - "args": ["-c", "export ARCADE_API_KEY=arc_xxxx && /path/to/python /path/to/arcade mcp"] - } - } -} diff --git a/examples/mcp/run_stdio.py b/examples/mcp/run_stdio.py deleted file mode 100644 index 489a80e2..00000000 --- a/examples/mcp/run_stdio.py +++ /dev/null @@ -1,24 +0,0 @@ -import arcade_gmail # pip install arcade_gmail -import arcade_search # pip install arcade_search -from arcade_core.catalog import ToolCatalog -from arcade_serve.mcp.stdio import StdioServer - -# 2. Create and populate the tool catalog -catalog = ToolCatalog() -catalog.add_module(arcade_gmail) # Registers all tools in the package -catalog.add_module(arcade_search) - - -# 3. Main entrypoint -async def main(): - # Create the worker with the tool catalog - worker = StdioServer(catalog) - - # Run the worker - await worker.run() - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/libs/arcade-cli/README.md b/libs/arcade-cli/README.md index ae0cc6fc..308fbbed 100644 --- a/libs/arcade-cli/README.md +++ b/libs/arcade-cli/README.md @@ -14,7 +14,7 @@ Arcade CLI provides a comprehensive command-line interface for the Arcade platfo ## Installation ```bash -pip install arcade-ai +pip install arcade-mcp ``` ## Usage diff --git a/libs/arcade-cli/arcade_cli/configure.py b/libs/arcade-cli/arcade_cli/configure.py new file mode 100644 index 00000000..a69ba2e4 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/configure.py @@ -0,0 +1,236 @@ +"""Connect command for configuring MCP clients.""" + +import json +import os +import platform +from pathlib import Path + +import typer +from rich.console import Console + +console = Console() + + +def get_claude_config_path() -> Path: + """Get the Claude Desktop configuration file path.""" + system = platform.system() + if system == "Darwin": # macOS + return ( + Path.home() + / "Library" + / "Application Support" + / "Claude" + / "claude_desktop_config.json" + ) + elif system == "Windows": + return Path(os.environ["APPDATA"]) / "Claude" / "claude_desktop_config.json" + else: # Linux + return Path.home() / ".config" / "Claude" / "claude_desktop_config.json" + + +def get_cursor_config_path() -> Path: + """Get the Cursor configuration file path.""" + system = platform.system() + if system == "Darwin": # macOS + return Path.home() / ".cursor" / "mcp.json" + elif system == "Windows": + return Path(os.environ["APPDATA"]) / "Cursor" / "mcp.json" + else: # Linux + return Path.home() / ".config" / "Cursor" / "mcp.json" + + +def get_vscode_config_path() -> Path: + """Get the VS Code configuration file path.""" + # Paths to global 'Default User' MCP configuration file + system = platform.system() + if system == "Darwin": # macOS + return Path.home() / "Library" / "Application Support" / "Code" / "User" / "mcp.json" + elif system == "Windows": + return Path(os.environ["APPDATA"]) / "Code" / "User" / "mcp.json" + else: # Linux + return Path.home() / ".config" / "Code" / "User" / "mcp.json" + + +def configure_claude_local(server_name: str, port: int = 8000, path: Path | None = None) -> None: + """Configure Claude Desktop to add a local MCP server to the configuration.""" + config_path = path or get_claude_config_path() + config_path.parent.mkdir(parents=True, exist_ok=True) + + # Load existing config or create new one + config = {} + if config_path.exists(): + with open(config_path) as f: + config = json.load(f) + + # Add or update MCP servers configuration + if "mcpServers" not in config: + config["mcpServers"] = {} + + config["mcpServers"][server_name] = { + "command": "python", + "args": ["-m", "arcade_mcp_server", "stream"], + "url": f"http://localhost:{port}/mcp", + } + + # Write updated config + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + console.print( + f"โœ… Configured Claude Desktop by adding local MCP server '{server_name}' to the configuration", + style="green", + ) + console.print( + f" MCP client config file: {config_path.as_posix().replace(' ', '\\ ')}", style="dim" + ) + console.print(f" MCP Server URL: http://localhost:{port}/mcp", style="dim") + console.print(" Restart Claude Desktop for changes to take effect.", style="yellow") + + +def configure_claude_arcade(server_name: str, path: Path | None = None) -> None: + """Configure Claude Desktop to add an Arcade Cloud MCP server to the configuration.""" + # This would connect to the Arcade Cloud to get the server URL + # For now, this is a placeholder + console.print("[red]Connecting to Arcade Cloud servers not yet implemented[/red]") + + +def configure_cursor_local(server_name: str, port: int = 8000, path: Path | None = None) -> None: + """Configure Cursor to add a local MCP server to the configuration.""" + config_path = path or get_cursor_config_path() + config_path.parent.mkdir(parents=True, exist_ok=True) + + # Load existing config or create new one + config = {} + if config_path.exists(): + with open(config_path) as f: + config = json.load(f) + + # Add or update MCP servers configuration + if "mcpServers" not in config: + config["mcpServers"] = {} + + config["mcpServers"][server_name] = { + "name": server_name, + "type": "stream", # Cursor prefers stream + "url": f"http://localhost:{port}/mcp", + } + + # Write updated config + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + console.print( + f"โœ… Configured Cursor by adding local MCP server '{server_name}' to the configuration", + style="green", + ) + console.print( + f" MCP client config file: {config_path.as_posix().replace(' ', '\\ ')}", style="dim" + ) + console.print(f" MCP Server URL: http://localhost:{port}/mcp", style="dim") + console.print(" Restart Cursor for changes to take effect.", style="yellow") + + +def configure_cursor_arcade(server_name: str, path: Path | None = None) -> None: + """Configure Cursor to add an Arcade Cloud MCP server to the configuration.""" + console.print("[red]Connecting to Arcade Cloud servers not yet implemented[/red]") + + +def configure_vscode_local(server_name: str, port: int = 8000, path: Path | None = None) -> None: + """Configure VS Code to add a local MCP server to the configuration.""" + config_path = path or get_vscode_config_path() + config_path.parent.mkdir(parents=True, exist_ok=True) + # Load existing config or create new one + config = {} + if config_path.exists(): + with open(config_path) as f: + try: + config = json.load(f) + except json.JSONDecodeError as e: + raise ValueError( + f"\n\tFailed to load MCP configuration file at {config_path.as_posix()} " + f"\n\tThe file contains invalid JSON: {e}. " + "\n\tPlease check the file format or delete it to create a new configuration." + ) + + # Add or update MCP servers configuration + if "servers" not in config: + config["servers"] = {} + + config["servers"][server_name] = { + "type": "http", + "url": f"http://localhost:{port}/mcp", + } + + # Write updated config + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + console.print( + f"โœ… Configured VS Code by adding local MCP server '{server_name}' to the configuration", + style="green", + ) + console.print( + f" MCP client config file: {config_path.as_posix().replace(' ', '\\ ')}", style="dim" + ) + console.print(f" MCP Server URL: http://localhost:{port}/mcp", style="dim") + console.print(" Restart VS Code for changes to take effect.", style="yellow") + + +def configure_vscode_arcade(server_name: str, path: Path | None = None) -> None: + """Configure VS Code to add an Arcade Cloud MCP server to the configuration.""" + console.print("[red]Connecting to Arcade Cloud servers not yet implemented[/red]") + + +def configure_client( + client: str, + server_name: str | None = None, + from_local: bool = False, + from_arcade: bool = False, + port: int = 8000, + path: Path | None = None, +) -> None: + """ + Configure an MCP client to connect to a server. + + Args: + client: The MCP client to configure (claude, cursor, vscode) + server_name: Name of the server to add to the configuration + from_local: Add a local server to the configuration + from_arcade: Add an Arcade Cloud server to the configuration + port: Port for local servers (default: 8000) + path: Custom path to the MCP client configuration file + """ + if not from_local and not from_arcade: + console.print("[red]Must specify either --from-local or --from-arcade[/red]") + raise typer.Exit(1) + + if from_local and from_arcade: + console.print("[red]Cannot specify both --from-local and --from-arcade[/red]") + raise typer.Exit(1) + + # Default server name if not provided + if not server_name: + # Try to detect from current directory + server_name = Path.cwd().name if Path("server.py").exists() else "arcade-mcp-server" + + client_lower = client.lower() + + if client_lower == "claude": + if from_local: + configure_claude_local(server_name, port, path) + else: + configure_claude_arcade(server_name, path) + elif client_lower == "cursor": + if from_local: + configure_cursor_local(server_name, port, path) + else: + configure_cursor_arcade(server_name, path) + elif client_lower == "vscode": + if from_local: + configure_vscode_local(server_name, port, path) + else: + configure_vscode_arcade(server_name, path) + else: + console.print(f"[red]Unknown client: {client}[/red]") + console.print("Supported clients: claude, cursor, vscode") + raise typer.Exit(1) diff --git a/libs/arcade-cli/arcade_cli/deployment.py b/libs/arcade-cli/arcade_cli/deployment.py index 47ff2e2c..2787f242 100644 --- a/libs/arcade-cli/arcade_cli/deployment.py +++ b/libs/arcade-cli/arcade_cli/deployment.py @@ -10,6 +10,7 @@ from typing import Any import toml from arcade_core import Toolkit +from arcade_core.catalog import ToolCatalog from arcade_core.toolkit import Validate from arcadepy import Arcade, NotFoundError from httpx import Client, ConnectError, HTTPStatusError, TimeoutException @@ -75,12 +76,78 @@ class Secret(BaseModel): pattern: str | None = None +class AuthProvider(BaseModel): + """Configuration for a local auth provider.""" + + provider_id: str + """The provider ID (e.g., 'google', 'github', 'custom-oauth')""" + + provider_type: str = "oauth2" + """The type of provider, usually 'oauth2'""" + + client_id: str + """OAuth client ID for this provider""" + + client_secret: str + """OAuth client secret for this provider""" + + # Mock tokens for local development + mock_tokens: dict[str, str] | None = None + """ + Mock access tokens by user ID for local development. + Example: {"user-123": "mock-google-token-abc", "user-456": "mock-google-token-def"} + """ + + scopes: list[str] | None = None + """Default scopes for this provider""" + + class Config(BaseModel): + """The configuration for an Arcade worker deployment.""" + id: str + """The unique id for the worker deployment.""" + enabled: bool = True - timeout: int = 30 - retries: int = 3 + """Whether the worker is enabled. Defaults to True.""" + secret: Secret | None = None + """The shared secret between the worker and Arcade Engine server.""" + + timeout: int = 120 + """The maximum execution time in seconds for a tool in this worker.""" + + retries: int = 1 + """The number of times to retry a failed tool invocation. Defaults to 1.""" + + # Local development context - only used when running locally + local_context: dict[str, Any] | None = None + """ + Local context configuration for development. This section is only used when running + 'arcade serve' locally and is ignored during deployment. It can include: + - user_id: Default user ID for local testing + - user_info: Dictionary of user metadata + - metadata: Additional metadata fields + Example: + [worker.config.local_context] + user_id = "test-user-123" + user_info = { email = "test@example.com", name = "Test User" } + """ + + # Local auth providers - only used when running locally + local_auth_providers: list[AuthProvider] | None = None + """ + Local auth provider configurations for development. These are only used when running + 'arcade serve' locally and are ignored during deployment. They define mock OAuth + providers and tokens for testing tools that require authentication. + Example: + [[worker.config.local_auth_providers]] + provider_id = "google" + client_id = "mock-google-client" + client_secret = "mock-google-secret" + [worker.config.local_auth_providers.mock_tokens] + "test-user-123" = "mock-google-access-token" + """ # Validate and parse the secret if required @field_validator("secret", mode="before") @@ -89,22 +156,19 @@ class Config(BaseModel): # If the secret is a string, attempt to parse it as an environment variable or return the secret if isinstance(v, str): secret = get_env_secret(v) - # If the secret has been manually set, return it elif isinstance(v, Secret): secret = v else: raise TypeError("Secret must be a string or a Secret object") - # Check that the secret is not the default dev secret or empty - if secret.value.strip() == "" or secret.value == "dev": - raise ValueError("Secret must be a non-empty string and not 'dev'") + if secret.value.strip() == "": + raise ValueError("Secret must be a non-empty string") return secret @field_serializer("secret") def serialize_secret(self, secret: Secret) -> str: if secret.pattern: return f"$env:{secret.pattern}" - else: - return secret.value + return secret.value # Cloud request for deploying a worker @@ -254,7 +318,8 @@ class Worker(BaseModel): ) # Validate that we are able to load the package - Toolkit.tools_from_directory(package_dir=package_path, package_name=package_path.name) + # Use from_directory to properly resolve src/ layouts and avoid double prefixes + Toolkit.from_directory(package_path) # Compress the package into a byte stream and tar byte_stream = io.BytesIO() @@ -287,6 +352,22 @@ class Worker(BaseModel): if dupes: raise ValueError(f"Duplicate packages: {dupes}") + def get_required_secrets(self) -> set[str]: + """Inspect local toolkits and return a set of required secret keys.""" + all_secrets = set() + if self.local_source: + catalog = ToolCatalog() + for package_path_str in self.local_source.packages: + package_path = self.toml_path.parent / package_path_str + toolkit = Toolkit.from_directory(package_path) + catalog.add_toolkit(toolkit) + + for tool in catalog: + if tool.definition.requirements and tool.definition.requirements.secrets: + for secret in tool.definition.requirements.secrets: + all_secrets.add(secret.key) + return all_secrets + class Deployment(BaseModel): toml_path: Path diff --git a/libs/arcade-cli/arcade_cli/display.py b/libs/arcade-cli/arcade_cli/display.py index b661a57b..fae74b5e 100644 --- a/libs/arcade-cli/arcade_cli/display.py +++ b/libs/arcade-cli/arcade_cli/display.py @@ -36,7 +36,7 @@ def display_tools_table(tools: list[ToolDefinition]) -> None: console.print(table) -def display_tool_details(tool: ToolDefinition, worker: bool = False) -> None: # noqa: C901 +def display_tool_details(tool: ToolDefinition, worker: bool = False) -> None: """ Display detailed information about a specific tool using multiple panels. @@ -59,36 +59,19 @@ def display_tool_details(tool: ToolDefinition, worker: bool = False) -> None: # inputs_table.add_column("Type", style="magenta") inputs_table.add_column("Required", style="yellow") inputs_table.add_column("Description", style="white") - + inputs_table.add_column("Default", style="blue") for param in inputs: - # Format the type string properly - type_str = _format_type_string(param.value_schema) - - # Add the main parameter row + # Since InputParameter does not have a default field, we use "N/A" + default_value = "N/A" + if param.value_schema.enum: + default_value = f"One of {param.value_schema.enum}" inputs_table.add_row( param.name, - type_str, + param.value_schema.val_type, str(param.required), param.description or "", + default_value, ) - - # If this is a json type with properties, show them - if ( - param.value_schema.val_type == "json" - and hasattr(param.value_schema, "properties") - and param.value_schema.properties - ): - _add_nested_properties(inputs_table, param.value_schema.properties, indent=1) - # Handle arrays with inner properties - elif ( - param.value_schema.val_type == "array" - and hasattr(param.value_schema, "inner_properties") - and param.value_schema.inner_properties - ): - _add_nested_properties( - inputs_table, param.value_schema.inner_properties, indent=1, is_array_item=True - ) - inputs_panel = Panel( inputs_table, title="Input Parameters", @@ -258,7 +241,7 @@ def _add_nested_properties( is_array_item: bool = False, ) -> None: """ - Recursively add nested properties to the table. + Recursively add nested properties to the output table. Args: table: The Rich table to add rows to @@ -270,14 +253,11 @@ def _add_nested_properties( # Show array item indicator if needed if is_array_item and indent > 0: - # Get column count from the table - num_columns = len(table.columns) - - # Create a row with the array indicator in the first column and empty strings for the rest - row_data = [f"{indent_prefix[:-2]}[item]"] + [""] * (num_columns - 1) - if num_columns >= 3: - row_data[2] = "[dim]Each item in array:[/dim]" - table.add_row(*row_data) + table.add_row( + f"{indent_prefix[:-2]}[item]", + "", + "[dim]Each item in array:[/dim]", + ) for prop_name, prop_schema in properties.items(): # Format the type string @@ -289,19 +269,11 @@ def _add_nested_properties( if hasattr(prop_schema, "description") and prop_schema.description: description = prop_schema.description - # Create row data based on number of columns - num_columns = len(table.columns) - row_data = [f"{indent_prefix}{prop_name}", type_str] - - # For input parameter tables (4 columns), add empty required column - if num_columns == 4: - row_data.append("") # Empty "Required" column for nested properties - row_data.append(f"[dim]{description}[/dim]" if description else "") - # For output tables (3 columns), just add description - elif num_columns == 3: - row_data.append(f"[dim]{description}[/dim]" if description else "") - - table.add_row(*row_data) + table.add_row( + f"{indent_prefix}{prop_name}", + type_str, + f"[dim]{description}[/dim]" if description else "", + ) # Recursively add nested properties if this is a json type with properties if ( diff --git a/libs/arcade-cli/arcade_cli/main.py b/libs/arcade-cli/arcade_cli/main.py index 4978af8c..9f14edd3 100644 --- a/libs/arcade-cli/arcade_cli/main.py +++ b/libs/arcade-cli/arcade_cli/main.py @@ -1,52 +1,46 @@ import asyncio import os +import subprocess +import sys import threading import traceback import uuid import webbrowser from pathlib import Path -from typing import Any, Optional +from typing import Optional import httpx import typer from arcadepy import Arcade -from arcadepy.types import AuthorizationResponse -from openai import OpenAI, OpenAIError from rich.console import Console from rich.markup import escape from rich.text import Text from tqdm import tqdm +import arcade_cli.secret as secret import arcade_cli.worker as worker from arcade_cli.authn import LocalAuthCallbackServer, check_existing_login from arcade_cli.constants import ( CREDENTIALS_FILE_PATH, - LOCALHOST, PROD_CLOUD_HOST, PROD_ENGINE_HOST, ) from arcade_cli.deployment import Deployment from arcade_cli.display import ( - display_arcade_chat_header, display_eval_results, - display_tool_messages, ) from arcade_cli.show import show_logic from arcade_cli.toolkit_docs import generate_toolkit_docs from arcade_cli.utils import ( OrderCommands, + Provider, compute_base_url, compute_login_url, get_eval_files, - get_today_context, - get_user_input, - handle_chat_interaction, - handle_tool_authorization, - handle_user_command, - is_authorization_pending, load_eval_suites, log_engine_health, require_dependency, + resolve_provider_api_key, validate_and_get_config, version_callback, ) @@ -69,6 +63,13 @@ cli.add_typer( rich_help_panel="Deployment", ) +cli.add_typer( + secret.app, + name="secret", + help="Manage tool secrets in the cloud (set, unset, list)", + rich_help_panel="Admin", +) + console = Console() @@ -179,18 +180,119 @@ def new( ), directory: str = typer.Option(os.getcwd(), "--dir", help="tools directory path"), debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"), + full: bool = typer.Option( + False, + "--full", + "-f", + help="Create a toolkit package with a full scaffolding (includes evals, tests, license, etc)", + ), ) -> None: """ Creates a new toolkit with the given name, description, and result type. """ - from arcade_cli.new import create_new_toolkit + from arcade_cli.new import create_new_toolkit, create_new_toolkit_minimal try: - create_new_toolkit(directory, toolkit_name) + if not full: + create_new_toolkit_minimal(directory, toolkit_name) + else: + create_new_toolkit(directory, toolkit_name) except Exception as e: handle_cli_error("Failed to create new Toolkit", e, debug) +@cli.command( + name="mcp", + help="Run MCP servers with different transports", + rich_help_panel="Launch", +) +def mcp( + transport: str = typer.Argument("http", help="Transport type: stdio, http"), + host: str = typer.Option("127.0.0.1", "--host", help="Host to bind to (HTTP mode only)"), + port: int = typer.Option(8000, "--port", help="Port to bind to (HTTP mode only)"), + tool_package: Optional[str] = typer.Option( + None, + "--tool-package", + "--package", + "-p", + help="Specific tool package to load (e.g., 'github' for arcade-github)", + ), + discover_installed: bool = typer.Option( + False, "--discover-installed", "--all", help="Discover all installed arcade tool packages" + ), + show_packages: bool = typer.Option( + False, "--show-packages", help="Show loaded packages during discovery" + ), + reload: bool = typer.Option( + False, "--reload", help="Enable auto-reload on code changes (HTTP mode only)" + ), + debug: bool = typer.Option(False, "--debug", help="Enable debug mode with verbose logging"), + env_file: Optional[str] = typer.Option(None, "--env-file", help="Path to environment file"), + name: Optional[str] = typer.Option(None, "--name", help="Server name"), + version: Optional[str] = typer.Option(None, "--version", help="Server version"), + cwd: Optional[str] = typer.Option(None, "--cwd", help="Working directory to run from"), +) -> None: + """ + Run Arcade MCP Server (passthrough to arcade_mcp_server). + + This command provides a unified CLI experience by passing through + all arguments to the arcade_mcp_server module. + + Examples: + arcade mcp stdio + arcade mcp http --port 8080 + arcade mcp --tool-package github + arcade mcp --discover-installed --show-packages + """ + # Build the command to pass through to arcade_mcp_server + cmd = [sys.executable, "-m", "arcade_mcp_server", transport] + + # Add optional arguments + cmd.extend(["--host", host]) + cmd.extend(["--port", str(port)]) + cmd.append("--debug") + if tool_package: + cmd.extend(["--tool-package", tool_package]) + if discover_installed: + cmd.append("--discover-installed") + if show_packages: + cmd.append("--show-packages") + if reload: + cmd.append("--reload") + if env_file: + cmd.extend(["--env-file", env_file]) + if name: + cmd.extend(["--name", name]) + if version: + cmd.extend(["--version", version]) + if cwd: + cmd.extend(["--cwd", cwd]) + + try: + # Show what command we're running in debug mode + if debug: + console.print(f"[dim]Running: {' '.join(cmd)}[/dim]") + + # Execute the command and pass through all output + result = subprocess.run(cmd, check=False) + + # Exit with the same code as the subprocess + if result.returncode != 0: + handle_cli_error("Failed to run MCP server") + + except KeyboardInterrupt: + console.print("\n[yellow]MCP server stopped[/yellow]") + raise typer.Exit(0) + except FileNotFoundError: + console.print( + "[red]arcade_mcp_server module not found. Make sure arcade-mcp-server is installed.[/red]" + ) + raise typer.Exit(1) + except Exception as e: + console.print(f"[red]Error running MCP server: {e}[/red]") + raise typer.Exit(1) + + @cli.command( help="Show the installed toolkits or details of a specific tool", rich_help_panel="Tool Development", @@ -260,134 +362,6 @@ def show( ) -@cli.command( - help="Start a chat with a model in the terminal to test tools", - rich_help_panel="Tool Development", -) -def chat( - model: str = typer.Option("gpt-4o", "-m", "--model", help="The model to use for prediction."), - stream: bool = typer.Option( - False, "-s", "--stream", is_flag=True, help="Stream the tool output." - ), - prompt: str = typer.Option(None, "--prompt", help="The system prompt to use for the chat."), - debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"), - host: str = typer.Option( - PROD_ENGINE_HOST, - "-h", - "--host", - help="The Arcade Engine address to send chat requests to.", - ), - port: Optional[int] = typer.Option( - None, - "-p", - "--port", - help="The port of the Arcade Engine.", - ), - force_tls: bool = typer.Option( - False, - "--tls", - help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.", - ), - force_no_tls: bool = typer.Option( - False, - "--no-tls", - help="Whether to disable TLS for the connection to the Arcade Engine.", - ), -) -> None: - """ - Chat with a language model. - """ - try: - import readline - except ImportError: - console.print( - "Readline is not available on this platform. Command history will be limited.", - style="dim", - ) - - config = validate_and_get_config() - base_url = compute_base_url(force_tls, force_no_tls, host, port) - - client = Arcade(api_key=config.api.key, base_url=base_url) - user_email = config.user.email if config.user else None - - try: - # start messages conversation - history: list[dict[str, Any]] = [] - - # Ground the LLM with today's date and day of the week to help when calling date-related tools - # in case the user refers to relative dates (e.g. next Monday, last month, etc) - today_context = get_today_context() - - if prompt: - prompt = f"{today_context} {prompt}" - else: - prompt = today_context - - history.append({"role": "system", "content": prompt}) - - display_arcade_chat_header(base_url, stream) - - # Try to hit /health endpoint on engine and warn if it is down - log_engine_health(client) - - while True: - console.print( - f"\n[magenta][bold]User[/bold] {user_email}: [/magenta]" - + "([bold][default]/?[/default][/bold] for help)" - ) - - user_input = get_user_input() - - # Add the input to history - readline.add_history(user_input) - - if handle_user_command( - user_input, history, host, port, force_tls, force_no_tls, show_logic - ): - continue - - history.append({"role": "user", "content": user_input}) - - try: - # TODO fixup configuration to remove this + "/v1" workaround - openai_client = OpenAI(api_key=config.api.key, base_url=base_url + "/v1") - chat_result = handle_chat_interaction( - openai_client, model, history, user_email, stream - ) - - history = chat_result.history - tool_messages = chat_result.tool_messages - tool_authorization = chat_result.tool_authorization - - # wait for tool authorizations to complete, if any - if tool_authorization and is_authorization_pending(tool_authorization): - chat_result = handle_tool_authorization( - client, - AuthorizationResponse.model_validate(tool_authorization), - history, - openai_client, - model, - user_email, - stream, - ) - history = chat_result.history - tool_messages = chat_result.tool_messages - - except OpenAIError as e: - handle_cli_error("Arcade Chat failed", e, debug, should_exit=False) - continue - if debug: - display_tool_messages(tool_messages) - - except KeyboardInterrupt: - console.print("Chat stopped by user.", style="bold blue") - typer.Exit() - - except RuntimeError as e: - handle_cli_error("Failed to run tool", e, debug) - - @cli.command(help="Run tool calling evaluations", rich_help_panel="Tool Development") def evals( directory: str = typer.Argument(".", help="Directory containing evaluation files"), @@ -402,34 +376,19 @@ def evals( "gpt-4o", "--models", "-m", - help="The models to use for evaluation (default: gpt-4o). Use commas to separate multiple models.", + help="The models to use for evaluation (default: gpt-4o). Use commas to separate multiple models. All models must belong to the same provider.", ), - host: str = typer.Option( - LOCALHOST, - "-h", - "--host", - help="The Arcade Engine address to send chat requests to.", - ), - cloud: bool = typer.Option( - False, - "--cloud", - help="Whether to run evaluations against the Arcade Cloud Engine. Overrides the 'host' option.", - ), - port: Optional[int] = typer.Option( - None, + provider: Provider = typer.Option( + Provider.OPENAI, + "--provider", "-p", - "--port", - help="The port of the Arcade Engine.", + help="The provider of the models to use for evaluation.", ), - force_tls: bool = typer.Option( - False, - "--tls", - help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.", - ), - force_no_tls: bool = typer.Option( - False, - "--no-tls", - help="Whether to disable TLS for the connection to the Arcade Engine.", + provider_api_key: str = typer.Option( + None, + "--provider-api-key", + "-k", + help="The model provider API key. If not provided, will look for the appropriate environment variable based on the provider (e.g., OPENAI_API_KEY for openai provider), first in the current environment, then in the current working directory's .env file.", ), debug: bool = typer.Option(False, "--debug", help="Show debug information"), ) -> None: @@ -440,7 +399,7 @@ def evals( require_dependency( package_name="arcade_evals", command_name="evals", - install_command=r"pip install 'arcade-ai\[evals]'", + install_command=r"pip install 'arcade-mcp\[evals]'", ) # Although Evals does not depend on the TDK, some evaluations import the # ToolCatalog class from the TDK instead of from arcade_core, so we require @@ -451,27 +410,27 @@ def evals( install_command=r"pip install arcade-tdk", ) - config = validate_and_get_config() - - host = PROD_ENGINE_HOST if cloud else host - base_url = compute_base_url(force_tls, force_no_tls, host, port) - models_list = models.split(",") # Use 'models_list' to avoid shadowing + # Resolve the API key for the provider + resolved_api_key = resolve_provider_api_key(provider, provider_api_key) + if not resolved_api_key: + provider_env_vars = { + Provider.OPENAI: "OPENAI_API_KEY", + } + env_var_name = provider_env_vars.get(provider, f"{provider.upper()}_API_KEY") + handle_cli_error( + f"API key not found for provider '{provider.value}'. " + f"Please provide it via --provider-api-key,-k argument, set the {env_var_name} environment variable, " + f"or add it to a .env file in the current directory.", + should_exit=True, + ) + eval_files = get_eval_files(directory) if not eval_files: return - console.print( - Text.assemble( - ("\nRunning evaluations against Arcade Engine at ", "bold"), - (base_url, "bold blue"), - ) - ) - - # Try to hit /health endpoint on engine and warn if it is down - with Arcade(api_key=config.api.key, base_url=base_url) as client: - log_engine_health(client) + console.print("\nRunning evaluations", style="bold") # Use the new function to load eval suites eval_suites = load_eval_suites(eval_files) @@ -500,8 +459,7 @@ def evals( for model in models_list: task = asyncio.create_task( suite_func( - config=config, - base_url=base_url, + provider_api_key=resolved_api_key, model=model, max_concurrency=max_concurrent, ) @@ -528,6 +486,7 @@ def evals( @cli.command( help="Start tool server worker with locally installed tools", rich_help_panel="Launch", + hidden=True, ) def serve( host: str = typer.Option( @@ -565,6 +524,9 @@ def serve( """ Start a local Arcade Worker server. """ + console.log( + "โš ๏ธ This command is deprecated and will be removed in a future version.", style="yellow" + ) require_dependency( package_name="arcade_serve", command_name="serve", @@ -590,58 +552,68 @@ def serve( @cli.command( - help="Start a server with locally installed Arcade tools", - rich_help_panel="Launch", - hidden=True, + help="Configure MCP clients to connect to your server", rich_help_panel="Tool Development" ) -def workerup( - host: str = typer.Option( - "127.0.0.1", - help="Host for the app, from settings by default.", - show_default=True, +def configure( + client: str = typer.Argument( + ..., + help="The MCP client to configure (claude, cursor, vscode)", + ), + server_name: Optional[str] = typer.Option( + None, + "--server", + "-s", + help="Name of the server to connect to (defaults to current directory name)", + ), + from_local: bool = typer.Option( + False, + "--from-local", + help="Connect to a local MCP server", + is_flag=True, + ), + from_arcade: bool = typer.Option( + False, + "--from-arcade", + help="Connect to an Arcade Cloud MCP server", + is_flag=True, ), port: int = typer.Option( - "8002", - "-p", + 8000, "--port", - help="Port for the app, defaults to ", - show_default=True, + "-p", + help="Port for local servers", ), - disable_auth: bool = typer.Option( - False, - "--no-auth", - help="Disable authentication for the worker. Not recommended for production.", - show_default=True, - ), - otel_enable: bool = typer.Option( - False, "--otel-enable", help="Send logs to OpenTelemetry", show_default=True + path: Optional[Path] = typer.Option( + None, + "--path", + "-f", + exists=False, + help="Optional path to a specific MCP client config file (overrides default path)", ), debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"), ) -> None: """ - Starts the worker with host, port, and reload options. Uses - Uvicorn as ASGI worker. Parameters allow runtime configuration. - """ - require_dependency( - package_name="arcade_serve", - command_name="worker", - install_command=r"pip install 'arcade-serve'", - ) + Configure MCP clients to connect to your server. - from arcade_cli.serve import serve_default_worker + Examples: + arcade configure claude --from-local + arcade configure cursor --from-local --port 8080 + arcade configure vscode --from-local --path .vscode/mcp.json + arcade configure claude --from-arcade --server my-toolkit + """ + from arcade_cli.configure import configure_client try: - serve_default_worker( - host, - port, - disable_auth=disable_auth, - enable_otel=otel_enable, - debug=debug, + configure_client( + client=client, + server_name=server_name, + from_local=from_local, + from_arcade=from_arcade, + port=port, + path=path, ) - except KeyboardInterrupt: - typer.Exit() except Exception as e: - handle_cli_error("Failed to start Arcade Toolkit Server", e, debug) + handle_cli_error(f"Failed to configure {client}", e, debug) @cli.command(help="Deploy toolkits to Arcade Cloud", rich_help_panel="Deployment") @@ -712,6 +684,30 @@ def deploy( for worker in deployment.worker: console.log(f"Deploying '{worker.config.id}...'", style="dim") try: + # Discover and upload secrets + required_secret_keys = worker.get_required_secrets() + for secret_key in required_secret_keys: + secret_value = os.getenv(secret_key) + if not secret_value: + console.log( + f"โš ๏ธ Secret '{secret_key}' not found in environment, skipping.", + style="yellow", + ) + continue + try: + secret._upsert_secret_to_engine( + engine_url, config.api.key, secret_key, secret_value + ) + except Exception as e: + handle_cli_error( + f"Failed to upload secret '{secret_key}'", e, debug, should_exit=False + ) + else: + console.log( + f"โœ… Secret '{secret_key}' uploaded successfully", + style="dim green", + ) + # Attempt to deploy worker worker.request().execute(cloud_client, engine_client) console.log( @@ -792,7 +788,10 @@ def dashboard( ) def docs( toolkit_name: str = typer.Option( - ..., "--toolkit-name", "-n", help="The name of the toolkit to generate documentation for." + ..., + "--toolkit-name", + "-n", + help="The name of the toolkit to generate documentation for.", ), toolkit_dir: str = typer.Option( ..., @@ -897,7 +896,10 @@ def docs( ) def generate_toolkit_docs_command( toolkit_name: str = typer.Option( - ..., "--toolkit-name", "-n", help="The name of the toolkit to generate documentation for." + ..., + "--toolkit-name", + "-n", + help="The name of the toolkit to generate documentation for.", ), toolkit_dir: str = typer.Option( ..., @@ -975,14 +977,14 @@ def main_callback( help="Print version and exit.", ), ) -> None: - excluded_commands = { + # Commands that do not require a logged in user + public_commands = { login.__name__, logout.__name__, - serve.__name__, - workerup.__name__, dashboard.__name__, + evals.__name__, } - if ctx.invoked_subcommand in excluded_commands: + if ctx.invoked_subcommand in public_commands: return if not check_existing_login(suppress_message=True): diff --git a/libs/arcade-cli/arcade_cli/new.py b/libs/arcade-cli/arcade_cli/new.py index e6562ee2..f1fc578a 100644 --- a/libs/arcade-cli/arcade_cli/new.py +++ b/libs/arcade-cli/arcade_cli/new.py @@ -9,25 +9,25 @@ import typer from jinja2 import Environment, FileSystemLoader, select_autoescape from rich.console import Console -from arcade_cli.deployment import ( - create_demo_deployment, -) +from arcade_cli.templates import get_full_template_directory, get_minimal_template_directory console = Console() -# Retrieve the installed version of arcade-ai +# Retrieve the installed version of arcade-mcp try: - ARCADE_AI_MIN_VERSION = get_version("arcade-ai") - ARCADE_AI_MAX_VERSION = str(int(ARCADE_AI_MIN_VERSION.split(".")[0]) + 1) + ".0.0" + ARCADE_MCP_MIN_VERSION = get_version("arcade-mcp") + ARCADE_MCP_MAX_VERSION = str(int(ARCADE_MCP_MIN_VERSION.split(".")[0]) + 1) + ".0.0" except Exception as e: - console.print(f"[red]Failed to get arcade-ai version: {e}[/red]") - ARCADE_AI_MIN_VERSION = "2.0.0" # Default version if unable to fetch - ARCADE_AI_MAX_VERSION = "3.0.0" + console.print(f"[red]Failed to get arcade-mcp version: {e}[/red]") + ARCADE_MCP_MIN_VERSION = "1.0.0rc1" # Default version if unable to fetch + ARCADE_MCP_MAX_VERSION = "4.0.0" -ARCADE_TDK_MIN_VERSION = "2.0.0" +ARCADE_TDK_MIN_VERSION = "2.6.0rc1" ARCADE_TDK_MAX_VERSION = "3.0.0" -ARCADE_SERVE_MIN_VERSION = "2.0.0" +ARCADE_SERVE_MIN_VERSION = "2.2.0rc1" ARCADE_SERVE_MAX_VERSION = "3.0.0" +ARCADE_MCP_SERVER_MIN_VERSION = "1.0.0rc1" +ARCADE_MCP_SERVER_MAX_VERSION = "3.0.0" def ask_question(question: str, default: Optional[str] = None) -> str: @@ -181,10 +181,10 @@ def create_new_toolkit(output_directory: str, toolkit_name: str) -> None: # TODO: this detection mechanism works only for people that didn't change the # name of the repo, a better detection method is required here is_community_toolkit = False - if cwd.name == "toolkits" and cwd.parent.name == "arcade-ai": + if cwd.name == "toolkits" and cwd.parent.name == "arcade-mcp": prompt = ( "Is your toolkit a community contribution (to be merged into " - "\x1b]8;;https://github.com/ArcadeAI/arcade-ai\x1b\\ArcadeAI/arcade-ai\x1b]8;;\x1b\\ repo)?" + "\x1b]8;;https://github.com/ArcadeAI/arcade-mcp\x1b\\ArcadeAI/arcade-mcp\x1b]8;;\x1b\\ repo)?" ) is_community_toolkit = ask_yes_no_question(prompt, default=True) @@ -200,13 +200,14 @@ def create_new_toolkit(output_directory: str, toolkit_name: str) -> None: "arcade_tdk_max_version": ARCADE_TDK_MAX_VERSION, "arcade_serve_min_version": ARCADE_SERVE_MIN_VERSION, "arcade_serve_max_version": ARCADE_SERVE_MAX_VERSION, - "arcade_ai_min_version": ARCADE_AI_MIN_VERSION, - "arcade_ai_max_version": ARCADE_AI_MAX_VERSION, + "arcade_mcp_min_version": ARCADE_MCP_MIN_VERSION, + "arcade_mcp_max_version": ARCADE_MCP_MAX_VERSION, "creation_year": datetime.now().year, "is_community_toolkit": is_community_toolkit, "is_official_toolkit": is_official_toolkit, } - template_directory = Path(__file__).parent / "templates" / "{{ toolkit_name }}" + + template_directory = get_full_template_directory() / "{{ toolkit_name }}" env = Environment( loader=FileSystemLoader(str(template_directory)), @@ -230,10 +231,49 @@ def create_new_toolkit(output_directory: str, toolkit_name: str) -> None: def create_deployment(toolkit_directory: Path, toolkit_name: str) -> None: - worker_toml = toolkit_directory / "worker.toml" - if not worker_toml.exists(): - create_demo_deployment(worker_toml, toolkit_name) + # No longer create worker.toml for MCP servers + # The server.py file handles all configuration + pass + + +def create_new_toolkit_minimal(output_directory: str, toolkit_name: str) -> None: + """Create a new toolkit from a template with user input.""" + toolkit_directory = Path(output_directory) + + # Check for illegal characters in the toolkit name + if re.match(r"^[a-z0-9_]+$", toolkit_name): + if (toolkit_directory / toolkit_name).exists(): + console.print(f"[red]Toolkit '{toolkit_name}' already exists.[/red]") + exit(1) else: - pass - # Disabled pending bug fix - # update_deployment_with_local_packages(worker_toml, toolkit_name) + console.print( + "[red]Toolkit name contains illegal characters. " + "Only lowercase alphanumeric characters and underscores are allowed. " + "Please try again.[/red]" + ) + exit(1) + + context = { + "toolkit_name": toolkit_name, + "arcade_mcp_min_version": ARCADE_MCP_MIN_VERSION, + "arcade_mcp_max_version": ARCADE_MCP_MAX_VERSION, + "arcade_mcp_server_min_version": ARCADE_MCP_SERVER_MIN_VERSION, + "arcade_mcp_server_max_version": ARCADE_MCP_SERVER_MAX_VERSION, + } + template_directory = get_minimal_template_directory() / "{{ toolkit_name }}" + + env = Environment( + loader=FileSystemLoader(str(template_directory)), + autoescape=select_autoescape(["html", "xml"]), + ) + + ignore_pattern = create_ignore_pattern(False, False) + + try: + create_package(env, template_directory, toolkit_directory, context, ignore_pattern) + console.print( + f"[green]Toolkit '{toolkit_name}' created successfully at '{toolkit_directory}'.[/green]" + ) + except Exception: + remove_toolkit(toolkit_directory, toolkit_name) + raise diff --git a/libs/arcade-cli/arcade_cli/secret.py b/libs/arcade-cli/arcade_cli/secret.py new file mode 100644 index 00000000..53c48ddc --- /dev/null +++ b/libs/arcade-cli/arcade_cli/secret.py @@ -0,0 +1,286 @@ +import httpx +import typer +from rich.console import Console +from rich.table import Table + +from arcade_cli.constants import ( + PROD_ENGINE_HOST, +) +from arcade_cli.utils import ( + OrderCommands, + compute_base_url, + validate_and_get_config, +) + +console = Console() + + +app = typer.Typer( + cls=OrderCommands, + add_completion=False, + no_args_is_help=True, + pretty_exceptions_enable=False, + pretty_exceptions_show_locals=False, + pretty_exceptions_short=True, +) + +state = { + "engine_url": compute_base_url( + host=PROD_ENGINE_HOST, port=None, force_tls=False, force_no_tls=False + ) +} + + +@app.callback() +def main( + host: str = typer.Option( + PROD_ENGINE_HOST, + "--host", + "-h", + help="The Arcade Engine host.", + ), + port: int = typer.Option( + None, + "--port", + "-p", + help="The port of the Arcade Engine host.", + ), + force_tls: bool = typer.Option( + False, + "--tls", + help="Whether to force TLS for the connection to the Arcade Engine.", + ), + force_no_tls: bool = typer.Option( + False, + "--no-tls", + help="Whether to disable TLS for the connection to the Arcade Engine.", + ), +) -> None: + """ + Manage tool secrets in Arcade Cloud. + + Usage: + arcade secret set KEY1=value1 KEY2="value 2" + arcade secret set --from-env + arcade secret set -from-env --env-file /path/to/.env + arcade secret list + arcade secret unset KEY1 KEY2 KEY3 + """ + engine_url = compute_base_url(force_tls, force_no_tls, host, port) + state["engine_url"] = engine_url + + +@app.command("set", help="Set tool secret(s) using KEY=VALUE pairs or from .env file") +def set_secret( + key_value_pairs: list[str] = typer.Argument( + None, + help="Key-value pairs in the format KEY=VALUE", + ), + from_env: bool = typer.Option( + False, + "--from-env", + help="Load all secrets from local .env file", + ), + env_file: str = typer.Option( + ".env", + "--env-file", + "-f", + help="Path to .env file (default: .env)", + ), +) -> None: + """Set secrets either from .env file or KEY=VALUE pairs.""" + if not from_env and not key_value_pairs: + raise typer.BadParameter( + "Either provide KEY=VALUE pairs or use --from-env to load from .env file." + ) + if from_env and key_value_pairs: + raise typer.BadParameter("Cannot use both KEY=VALUE pairs and --from-env at the same time.") + + config = validate_and_get_config() + + if from_env: + secrets = load_env_file(env_file) + else: + secrets = {} + for pair in key_value_pairs: + if ( + "=" not in pair + or pair.split("=", 1)[0].strip() == "" + or pair.split("=", 1)[1].strip() == "" + ): + raise typer.BadParameter(f"Invalid format '{pair}'. Expected KEY=VALUE") + key, value = pair.split("=", 1) + key = key.strip() + if " " in key: + raise typer.BadParameter(f"Secret key '{key}' cannot contain spaces") + value = value # keep the value as is, including the whitespace + secrets[key] = value + + engine_url = state["engine_url"] + + for secret_key, secret_value in secrets.items(): + try: + _upsert_secret_to_engine(engine_url, config.api.key, secret_key, secret_value) + except Exception as e: + console.print(f"Error setting secret '{secret_key}': {e}", style="bold red") + continue + console.print( + f"Secret '{secret_key}' with value ending in ...{secret_value[-4:]} set successfully" + ) + + +@app.command("list", help="List all tool secrets in Arcade Cloud") +def list_secrets() -> None: + """List all secrets (keys only, values are masked).""" + config = validate_and_get_config() + engine_url = state["engine_url"] + + secrets = _get_secrets_from_engine(engine_url, config.api.key) + print_secret_table(secrets) + + +@app.command("unset", help="Delete tool secret(s) by key names") +def unset_secret( + keys: list[str] = typer.Argument( + ..., + help="Secret keys to delete", + ), +) -> None: + """Delete tool secrets.""" + config = validate_and_get_config() + engine_url = state["engine_url"] + secrets = _get_secrets_from_engine(engine_url, config.api.key) + + key_to_id = {secret["key"]: secret["id"] for secret in secrets} + + for key in set(keys): + secret_id = key_to_id.get(key) + if not secret_id: + console.print(f"Warning: Secret with key '{key}' not found, skipping", style="yellow") + continue + + try: + _delete_secret_from_engine(engine_url, config.api.key, secret_id) + console.print(f"Secret '{key}' deleted successfully") + except Exception: + console.print( + f"Failed to delete secret '{key}'. Do you have permission to delete this secret?", + style="bold red", + ) + continue + + +def print_secret_table(secrets: list[dict]) -> None: + """Print a table of tool secrets (with masked values).""" + table = Table(title="Tool Secrets") + table.add_column("Key", style="cyan") + table.add_column("Type", style="green") + table.add_column("Description", style="green") + table.add_column("Hint", style="green") + table.add_column("Last Accessed", style="green") + table.add_column("Created At", style="green") + for secret in secrets: + table.add_row( + secret["key"], + secret["binding"]["type"], + secret["description"], + "..." + secret["hint"] if secret["hint"] else "-", + secret["last_accessed_at"] if secret["last_accessed_at"] else "Never", + secret["created_at"], + ) + console.print(table) + + +def load_env_file(env_file_path: str) -> dict[str, str]: + """Load tool secrets from a .env file.""" + secrets = {} + with open(env_file_path) as file: + for line in file: + line = line.strip() + if line.startswith("#") or not line: + continue + + # Split on first '=' to handle values that contain '=' + if "=" not in line: + continue + + key, value = line.split("=", 1) + key = key.strip() + + # Remove inline comments, but respect quoted values + value = _remove_inline_comment(value) + value = value.strip() + + # Skip entries with empty keys or empty values + if not key or not value: + continue + + secrets[key] = value + return secrets + + +def _remove_inline_comment(value: str) -> str: + """Remove inline comments from env value, respecting quoted strings.""" + value = value.strip() + + # Check if value starts with a quote + if value.startswith('"') or value.startswith("'"): + quote_char = value[0] + + # Find the matching closing quote (not escaped) + i = 1 + while i < len(value): + if value[i] == quote_char: + # Found potential closing quote + # Check if there's anything after it + remaining = value[i + 1 :] + comment_idx = remaining.find(" #") + if comment_idx != -1: + # Remove the comment part and strip quotes + quoted_value = value[: i + 1] + return quoted_value[1:-1] # Remove surrounding quotes + else: + # No comment after closing quote, strip quotes + quoted_value = value[: i + 1] + return quoted_value[1:-1] # Remove surrounding quotes + i += 1 + + # No closing quote, treat as unquoted + comment_idx = value.find(" #") + if comment_idx != -1: + return value[:comment_idx] + return value + else: + # For unquoted values, remove everything after ' #' + comment_idx = value.find(" #") + if comment_idx != -1: + return value[:comment_idx] + return value + + +def _upsert_secret_to_engine( + engine_url: str, api_key: str, secret_id: str, secret_value: str +) -> None: + response = httpx.put( + f"{engine_url}/v1/admin/secrets/{secret_id}", + headers={"Authorization": f"Bearer {api_key}"}, + json={"description": "Secret set via CLI", "value": secret_value}, + ) + response.raise_for_status() + + +def _get_secrets_from_engine(engine_url: str, api_key: str) -> list[dict]: + response = httpx.get( + f"{engine_url}/v1/admin/secrets", + headers={"Authorization": f"Bearer {api_key}"}, + ) + response.raise_for_status() + return response.json()["items"] # type: ignore[no-any-return] + + +def _delete_secret_from_engine(engine_url: str, api_key: str, secret_id: str) -> None: + response = httpx.delete( + f"{engine_url}/v1/admin/secrets/{secret_id}", + headers={"Authorization": f"Bearer {api_key}"}, + ) + response.raise_for_status() diff --git a/libs/arcade-cli/arcade_cli/serve.py b/libs/arcade-cli/arcade_cli/serve.py index 65ad4fa1..593929d8 100644 --- a/libs/arcade-cli/arcade_cli/serve.py +++ b/libs/arcade-cli/arcade_cli/serve.py @@ -15,8 +15,8 @@ import uvicorn # Watchfiles is used under the hood by Uvicorn's reload feature. # Importing watchfiles here is an explicit acknowledgement that it needs to be installed import watchfiles # noqa: F401 -from arcade_core.telemetry import OTELHandler from arcade_core.toolkit import Toolkit, get_package_directory +from arcade_serve.fastapi.telemetry import OTELHandler from arcade_serve.fastapi.worker import FastAPIWorker from loguru import logger from rich.console import Console @@ -45,7 +45,7 @@ def create_arcade_app() -> fastapi.FastAPI: setup_logging(log_level=logging.DEBUG if debug_mode else logging.INFO, mcp_mode=False) logger.info(f"Debug: {debug_mode}, OTEL: {otel_enabled}, Auth Disabled: {auth_for_reload}") - version = get_pkg_version("arcade-ai") + version = get_pkg_version("arcade-mcp") toolkits = discover_toolkits() logger.info("Registered toolkits:") diff --git a/libs/arcade-cli/arcade_cli/show.py b/libs/arcade-cli/arcade_cli/show.py index 69d94455..352bce0b 100644 --- a/libs/arcade-cli/arcade_cli/show.py +++ b/libs/arcade-cli/arcade_cli/show.py @@ -5,7 +5,11 @@ from rich.console import Console from rich.markup import escape from arcade_cli.display import display_tool_details, display_tools_table -from arcade_cli.utils import create_cli_catalog, get_tools_from_engine +from arcade_cli.utils import ( + create_cli_catalog, + create_cli_catalog_local, + get_tools_from_engine, +) def show_logic( @@ -25,7 +29,7 @@ def show_logic( console = Console() try: if local: - catalog = create_cli_catalog(toolkit=toolkit) + catalog = create_cli_catalog() if toolkit else create_cli_catalog_local() tools = [t.definition for t in list(catalog)] else: tools = get_tools_from_engine(host, port, force_tls, force_no_tls, toolkit) diff --git a/libs/arcade-cli/arcade_cli/templates/__init__.py b/libs/arcade-cli/arcade_cli/templates/__init__.py new file mode 100644 index 00000000..2527e1b3 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/templates/__init__.py @@ -0,0 +1,10 @@ +from pathlib import Path + + +def get_minimal_template_directory() -> Path: + """Get the path to the templates directory.""" + return Path(__file__).parent / "minimal" + +def get_full_template_directory() -> Path: + """Get the path to the templates directory.""" + return Path(__file__).parent / "full" diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/.pre-commit-config.yaml b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/.pre-commit-config.yaml similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/.pre-commit-config.yaml rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/.pre-commit-config.yaml diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/.ruff.toml b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/.ruff.toml similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/.ruff.toml rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/.ruff.toml diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/LICENSE b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/LICENSE similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/LICENSE rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/LICENSE diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/Makefile b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/Makefile similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/Makefile rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/Makefile diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/README.md b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/README.md similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/README.md rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/README.md diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/evals/eval_{{ toolkit_name }}.py b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/evals/eval_{{ toolkit_name }}.py similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/evals/eval_{{ toolkit_name }}.py rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/evals/eval_{{ toolkit_name }}.py diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/pyproject.toml b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/pyproject.toml similarity index 84% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/pyproject.toml rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/pyproject.toml index 25565559..a82505f4 100644 --- a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/pyproject.toml +++ b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/pyproject.toml @@ -24,7 +24,7 @@ email = "{{ toolkit_author_email }}" [project.optional-dependencies] dev = [ - "arcade-ai[evals]>={{ arcade_ai_min_version }},<{{ arcade_ai_max_version }}", + "arcade-mcp[evals]>={{ arcade_mcp_min_version }},<{{ arcade_mcp_max_version }}", "arcade-serve>={{ arcade_serve_min_version }},<{{ arcade_serve_max_version }}", "pytest>=8.3.0,<8.4.0", "pytest-cov>=4.0.0,<4.1.0", @@ -43,16 +43,16 @@ toolkit_name = "{{ package_name }}" {% if is_community_toolkit -%} # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = { path = "../../", editable = true } +arcade-mcp = { path = "../../", editable = true } arcade-serve = { path = "../../libs/arcade-serve/", editable = true } arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } {% endif -%} {% if is_official_toolkit -%} # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = { path = "../../../arcade-ai", editable = true } -arcade-serve = { path = "../../../arcade-ai/libs/arcade-serve/", editable = true } -arcade-tdk = { path = "../../../arcade-ai/libs/arcade-tdk/", editable = true } +arcade-mcp = { path = "../../../arcade-mcp", editable = true } +arcade-serve = { path = "../../../arcade-mcp/libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../../arcade-mcp/libs/arcade-tdk/", editable = true } {% endif -%} [tool.mypy] diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/tests/__init__.py b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/tests/__init__.py similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/tests/__init__.py rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/tests/__init__.py diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/tests/test_{{ toolkit_name }}.py b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/tests/test_{{ toolkit_name }}.py similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/tests/test_{{ toolkit_name }}.py rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/tests/test_{{ toolkit_name }}.py diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/{{ package_name }}/__init__.py b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/{{ package_name }}/__init__.py similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/{{ package_name }}/__init__.py rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/{{ package_name }}/__init__.py diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/{{ package_name }}/tools/__init__.py b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/{{ package_name }}/tools/__init__.py similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/{{ package_name }}/tools/__init__.py rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/{{ package_name }}/tools/__init__.py diff --git a/libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/{{ package_name }}/tools/hello.py b/libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/{{ package_name }}/tools/hello.py similarity index 100% rename from libs/arcade-cli/arcade_cli/templates/{{ toolkit_name }}/{{ package_name }}/tools/hello.py rename to libs/arcade-cli/arcade_cli/templates/full/{{ toolkit_name }}/{{ package_name }}/tools/hello.py diff --git a/libs/arcade-cli/arcade_cli/templates/minimal/{{ toolkit_name }}/.env.example b/libs/arcade-cli/arcade_cli/templates/minimal/{{ toolkit_name }}/.env.example new file mode 100644 index 00000000..fe5a7446 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/templates/minimal/{{ toolkit_name }}/.env.example @@ -0,0 +1 @@ +MY_SECRET_KEY="Your tools can have secrets injected at runtime!" diff --git a/libs/arcade-cli/arcade_cli/templates/minimal/{{ toolkit_name }}/pyproject.toml b/libs/arcade-cli/arcade_cli/templates/minimal/{{ toolkit_name }}/pyproject.toml new file mode 100644 index 00000000..5501a074 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/templates/minimal/{{ toolkit_name }}/pyproject.toml @@ -0,0 +1,37 @@ +[build-system] +requires = ["setuptools>=61", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "{{ toolkit_name }}" +version = "0.1.0" +description = "MCP Server created with Arcade.dev" +requires-python = ">=3.10" +dependencies = [ + "arcade-mcp-server>={{ arcade_mcp_server_min_version }},<{{ arcade_mcp_server_max_version }}", +] + +[project.optional-dependencies] +dev = [ + "arcade-mcp[all]>={{ arcade_mcp_min_version}},<{{ arcade_mcp_max_version}}", + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "mypy>=1.0.0", + "ruff>=0.1.0", +] + +# Tell Arcade.dev that this package has Arcade tools +[project.entry-points.arcade_toolkits] +toolkit_name = "{{ toolkit_name }}" + +[tool.setuptools.packages.find] +include = ["{{ toolkit_name }}*"] + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.mypy] +python_version = "3.10" +warn_unused_configs = true +disallow_untyped_defs = false diff --git a/libs/arcade-cli/arcade_cli/templates/minimal/{{ toolkit_name }}/server.py b/libs/arcade-cli/arcade_cli/templates/minimal/{{ toolkit_name }}/server.py new file mode 100644 index 00000000..b9d54fa8 --- /dev/null +++ b/libs/arcade-cli/arcade_cli/templates/minimal/{{ toolkit_name }}/server.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +"""{{ toolkit_name }} MCP server""" + +import sys +from typing import Annotated + +from arcade_mcp_server import Context, MCPApp + +app = MCPApp(name="{{ toolkit_name }}", version="1.0.0", log_level="DEBUG") + + +@app.tool +def greet(name: Annotated[str, "The name of the person to greet"]) -> str: + """Greet a person by name.""" + return f"Hello, {name}!" + + +@app.tool(requires_secrets=["MY_SECRET_KEY"]) +def whisper_secret(context: Context) -> Annotated[str, "The last 4 characters of the secret"]: + """Reveal the last 4 characters of a secret""" + # Secrets are injected into the tool context at runtime. + # This means that LLMs and MCP clients cannot see or access your secrets + # You can define secrets in a .env file. + try: + secret = context.get_secret("MY_SECRET_KEY") + except Exception as e: + return str(e) + + return "The last 4 characters of the secret are: " + secret[-4:] + + +# Run with specific transport +if __name__ == "__main__": + # Get transport from command line argument, default to "stream" + transport = sys.argv[1] if len(sys.argv) > 1 else "http" + + # Run the server + # - "https" (default): HTTPS streaming for Claude Desktop, Claude Code, Cursor + # - "stdio": Standard I/O for VS Code and CLI tools + app.run(transport=transport, host="127.0.0.1", port=8000) diff --git a/libs/arcade-cli/arcade_cli/toolkit_docs/templates.py b/libs/arcade-cli/arcade_cli/toolkit_docs/templates.py index 199a5a37..49436347 100644 --- a/libs/arcade-cli/arcade_cli/toolkit_docs/templates.py +++ b/libs/arcade-cli/arcade_cli/toolkit_docs/templates.py @@ -97,10 +97,7 @@ const USER_ID = "{{arcade_user_id}}"; const TOOL_NAME = "{tool_fully_qualified_name}"; // Start the authorization process -const authResponse = await client.tools.authorize({{ - tool_name: TOOL_NAME, - user_id: USER_ID -}}); +const authResponse = await client.tools.authorize({{tool_name: TOOL_NAME}}); if (authResponse.status !== "completed") {{ console.log(`Click this link to authorize: ${{authResponse.url}}`); diff --git a/libs/arcade-cli/arcade_cli/utils.py b/libs/arcade-cli/arcade_cli/utils.py index 811c29f3..da72905f 100644 --- a/libs/arcade-cli/arcade_cli/utils.py +++ b/libs/arcade-cli/arcade_cli/utils.py @@ -2,6 +2,7 @@ import importlib.util import ipaddress import os import shlex +import sys import webbrowser from dataclasses import dataclass from datetime import datetime @@ -16,6 +17,12 @@ import idna import typer from arcade_core import ToolCatalog, Toolkit from arcade_core.config_model import Config +from arcade_core.discovery import ( + analyze_files_for_tools, + build_minimal_toolkit, + collect_tools_from_modules, + find_candidate_tool_files, +) from arcade_core.errors import ToolkitLoadError from arcade_core.schema import ToolDefinition from arcadepy import ( @@ -65,6 +72,12 @@ class ChatCommand(str, Enum): EXIT = "/exit" +class Provider(str, Enum): + """Supported model providers for evaluations.""" + + OPENAI = "openai" + + def create_cli_catalog( toolkit: str | None = None, show_toolkits: bool = False, @@ -98,6 +111,59 @@ def create_cli_catalog( return catalog +def _discover_installed_toolkits(catalog: ToolCatalog) -> ToolCatalog: + for tk in Toolkit.find_all_arcade_toolkits(): + catalog.add_toolkit(tk) + return catalog + + +def create_cli_catalog_local() -> ToolCatalog: + """ + Load a local toolkit from the current working directory if a pyproject.toml is present. + Fallback to environment discovery if not present. + """ + cwd = Path.cwd() + catalog = ToolCatalog() + + if not (cwd / "pyproject.toml").is_file(): + return _discover_installed_toolkits(catalog) + + try: + files = find_candidate_tool_files(cwd) + if not files: + return _discover_installed_toolkits(catalog) + + files_with_tools = analyze_files_for_tools(files) + if not files_with_tools: + return _discover_installed_toolkits(catalog) + + discovered_tools = collect_tools_from_modules(files_with_tools) + if not discovered_tools: + return _discover_installed_toolkits(catalog) + + toolkit = build_minimal_toolkit( + server_name=cwd.name, + server_version="0.1.0dev", + description=f"Local toolkit from {cwd.name}", + ) + # Add tools directly to catalog using the discovery approach + for tool_func, module in discovered_tools: + # Register module in sys.modules so it can be found + if module.__name__ not in sys.modules: + sys.modules[module.__name__] = module + catalog.add_tool(tool_func, toolkit, module) + except Exception as e: + console.log( + f"Local file discovery failed: {e}; falling back to installed toolkits", + style="dim", + ) + else: + return catalog + + # Fallback: discover installed toolkits + return _discover_installed_toolkits(catalog) + + def compute_base_url( force_tls: bool, force_no_tls: bool, @@ -530,7 +596,30 @@ def get_eval_files(directory: str) -> list[Path]: directory_path = Path(directory).resolve() if directory_path.is_dir(): - eval_files = [f for f in directory_path.rglob("eval_*.py") if f.is_file()] + # Directories to exclude from recursive search + exclude_dirs = { + ".venv", + "venv", + ".env", + "env", + "node_modules", + "__pycache__", + ".git", + "build", + "dist", + ".tox", + "htmlcov", + "site-packages", + ".pytest_cache", + } + + eval_files = [] + for f in directory_path.rglob("eval_*.py"): + if f.is_file(): + # Check if any parent directory is in exclude_dirs + should_exclude = any(part in exclude_dirs for part in f.parts) + if not should_exclude: + eval_files.append(f) elif directory_path.is_file(): eval_files = ( [directory_path] @@ -555,48 +644,59 @@ def load_eval_suites(eval_files: list[Path]) -> list[Callable]: """ Load evaluation suites from the given eval_files by importing the modules and extracting functions decorated with `@tool_eval`. - Args: eval_files: A list of Paths to evaluation files. - Returns: A list of callable evaluation suite functions. """ eval_suites = [] for eval_file_path in eval_files: module_name = eval_file_path.stem # filename without extension - # Now we need to load the module from eval_file_path file_path_str = str(eval_file_path) module_name_str = module_name - # Load using importlib - spec = importlib.util.spec_from_file_location(module_name_str, file_path_str) - if spec is None: - console.print(f"Failed to load {eval_file_path}", style="bold red") + # Add the directory containing the eval file to sys.path temporarily + # so that the eval file can import other modules in the same directory + eval_dir = str(eval_file_path.parent) + original_path = sys.path.copy() + if eval_dir not in sys.path: + sys.path.insert(0, eval_dir) + + try: + # Load using importlib + spec = importlib.util.spec_from_file_location(module_name_str, file_path_str) + if spec is None: + console.print(f"Failed to load {eval_file_path}", style="bold red") + continue + + module = importlib.util.module_from_spec(spec) + if spec.loader is not None: + spec.loader.exec_module(module) + else: + console.print(f"Failed to load module: {module_name}", style="bold red") + continue + + eval_suite_funcs = [ + obj + for name, obj in module.__dict__.items() + if callable(obj) and hasattr(obj, "__tool_eval__") + ] + + if not eval_suite_funcs: + console.print( + f"No @tool_eval functions found in {eval_file_path}", + style="bold yellow", + ) + continue + + eval_suites.extend(eval_suite_funcs) + except Exception as e: + console.print(f"Failed to load {eval_file_path}: {e}", style="bold red") continue - - module = importlib.util.module_from_spec(spec) - if spec.loader is not None: - spec.loader.exec_module(module) - else: - console.print(f"Failed to load module: {module_name}", style="bold red") - continue - - eval_suite_funcs = [ - obj - for name, obj in module.__dict__.items() - if callable(obj) and hasattr(obj, "__tool_eval__") - ] - - if not eval_suite_funcs: - console.print( - f"No @tool_eval functions found in {eval_file_path}", - style="bold yellow", - ) - continue - - eval_suites.extend(eval_suite_funcs) + finally: + # Restore the original sys.path + sys.path[:] = original_path return eval_suites @@ -698,7 +798,7 @@ def version_callback(value: bool) -> None: Prints the version of Arcade and exit. """ if value: - version = metadata.version("arcade-ai") + version = metadata.version("arcade-mcp") console.print(f"[bold]Arcade CLI[/bold] (version {version})") exit() @@ -787,6 +887,45 @@ def load_dotenv(path: str | Path, *, override: bool = False) -> dict[str, str]: return loaded +def resolve_provider_api_key(provider: Provider, provider_api_key: str | None = None) -> str | None: + """ + Resolve the API key for a given provider for evals. + + Args: + provider: The model provider + provider_api_key: API key provided via CLI argument + + Returns: + The resolved API key or None if not found + """ + if provider_api_key: + return provider_api_key + + # Map providers to their environment variable names + provider_env_vars = { + Provider.OPENAI: "OPENAI_API_KEY", + } + + env_var_name = provider_env_vars.get(provider) + if not env_var_name: + return None + + # First check current environment + api_key = os.getenv(env_var_name) + if api_key: + return api_key + + # Then check .env file in current working directory + env_file_path = Path.cwd() / ".env" + if env_file_path.exists(): + load_dotenv(env_file_path, override=False) + api_key = os.getenv(env_var_name) + if api_key: + return api_key + + return None + + def require_dependency( package_name: str, command_name: str, @@ -798,7 +937,7 @@ def require_dependency( Args: package_name: The name of the package to import (e.g., 'arcade_serve') command_name: The command that requires the package (e.g., 'serve') - install_command: The command to install the package (e.g., "pip install 'arcade-ai[evals]'") + install_command: The command to install the package (e.g., "pip install 'arcade-mcp[evals]'") """ try: importlib.import_module(package_name.replace("-", "_")) diff --git a/libs/arcade-core/arcade_core/catalog.py b/libs/arcade-core/arcade_core/catalog.py index 0323da09..d27f14a2 100644 --- a/libs/arcade-core/arcade_core/catalog.py +++ b/libs/arcade-core/arcade_core/catalog.py @@ -405,7 +405,9 @@ class ToolCatalog(BaseModel): # Hard requirement: tools must have descriptions tool_description = getattr(tool, "__tool_description__", None) if not tool_description: - raise ToolDefinitionError(f"Tool '{raw_tool_name}' is missing a description") + raise ToolDefinitionError( + f"Tool '{raw_tool_name}' is missing a description. Tool descriptions are specified as docstrings for the tool function." + ) # If the function returns a value, it must have a type annotation if does_function_return_value(tool) and tool.__annotations__.get("return") is None: @@ -449,7 +451,9 @@ def create_input_definition(func: Callable) -> ToolInput: tool_context_param_name: str | None = None for _, param in inspect.signature(func, follow_wrapped=True).parameters.items(): - if param.annotation is ToolContext: + ann = param.annotation + if isinstance(ann, type) and issubclass(ann, ToolContext): + # Soft guidance for developers using legacy ToolContext if tool_context_param_name is not None: raise ToolInputSchemaError( f"Only one ToolContext parameter is supported, but tool {func.__name__} has multiple." @@ -690,7 +694,9 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo: # Final reality check if param_info.description is None: - raise ToolInputSchemaError(f"Parameter '{param_info.name}' is missing a description") + raise ToolInputSchemaError( + f"Parameter '{param_info.name}' is missing a description. Parameter descriptions are specified as string annotations using the typing.Annotated class." + ) if wire_type_info.wire_type is None: raise ToolInputSchemaError(f"Unknown parameter type: {param_info.field_type}") @@ -983,8 +989,9 @@ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel] if asyncio.iscoroutinefunction(func) and hasattr(func, "__wrapped__"): func = func.__wrapped__ for name, param in inspect.signature(func, follow_wrapped=True).parameters.items(): - # Skip ToolContext parameters - if param.annotation is ToolContext: + # Skip ToolContext parameters (including subclasses like arcade_mcp_server.Context) + ann = param.annotation + if isinstance(ann, type) and issubclass(ann, ToolContext): continue # TODO make this cleaner @@ -1004,7 +1011,7 @@ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel] return input_model, output_model -def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901 +def determine_output_model(func: Callable) -> type[BaseModel]: """ Determine the output model for a function based on its return annotation. """ @@ -1149,9 +1156,13 @@ def create_model_from_typeddict(typeddict_class: type, model_name: str) -> type[ def to_tool_secret_requirements( secrets_requirement: list[str], ) -> list[ToolSecretRequirement]: - # Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement - unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values() - return [ToolSecretRequirement(key=name) for name in unique_secrets] + # De-dupe case-insensitively but preserve the original casing for env var lookup + unique_map: dict[str, str] = {} + for name in secrets_requirement: + lowered = str(name).lower() + if lowered not in unique_map: + unique_map[lowered] = str(name) + return [ToolSecretRequirement(key=orig_name) for orig_name in unique_map.values()] def to_tool_metadata_requirements( diff --git a/libs/arcade-core/arcade_core/context.py b/libs/arcade-core/arcade_core/context.py new file mode 100644 index 00000000..0a592e74 --- /dev/null +++ b/libs/arcade-core/arcade_core/context.py @@ -0,0 +1,128 @@ +""" +Arcade Core Runtime Context Protocols + +Defines the developer-facing, transport-agnostic runtime context interfaces +(namespaced APIs: logs, progress, resources, tools, prompts, sampling, UI, +notifications) and the top-level ModelContext Protocol that aggregates them. + +Implementations live in runtime packages (e.g., arcade_mcp_server); tool authors should +use `arcade_mcp_server.Context` for concrete usage. +""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +from pydantic import BaseModel + + +class LogsContext(Protocol): + async def debug(self, message: str, **kwargs: dict[str, Any]) -> None: ... + + async def info(self, message: str, **kwargs: dict[str, Any]) -> None: ... + + async def warning(self, message: str, **kwargs: dict[str, Any]) -> None: ... + + async def error(self, message: str, **kwargs: dict[str, Any]) -> None: ... + + +class ProgressContext(Protocol): + async def report( + self, progress: float, total: float | None = None, message: str | None = None + ) -> None: ... + + +class ResourcesContext(Protocol): + async def list_(self) -> list[Any]: ... + + async def get(self, uri: str) -> Any: ... + + async def read(self, uri: str) -> list[Any]: ... + + async def list_roots(self) -> list[Any]: ... + + async def list_templates(self) -> list[Any]: ... + + +class ToolsContext(Protocol): + async def list_(self) -> list[Any]: ... + + async def call_raw(self, name: str, params: dict[str, Any]) -> BaseModel: ... + + +class PromptsContext(Protocol): + async def list_(self) -> list[Any]: ... + + async def get(self, name: str, arguments: dict[str, str] | None = None) -> Any: ... + + +class SamplingContext(Protocol): + async def create_message( + self, + messages: str | list[str | Any], + system_prompt: str | None = None, + include_context: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + model_preferences: Any | None = None, + ) -> Any: ... + + +class UIContext(Protocol): + async def elicit(self, message: str, schema: dict[str, Any] | None = None) -> Any: ... + + +class NotificationsToolsContext(Protocol): + async def list_changed(self) -> None: ... + + +class NotificationsResourcesContext(Protocol): + async def list_changed(self) -> None: ... + + +class NotificationsPromptsContext(Protocol): + async def list_changed(self) -> None: ... + + +class NotificationsContext(Protocol): + @property + def tools(self) -> NotificationsToolsContext: ... + + @property + def resources(self) -> NotificationsResourcesContext: ... + + @property + def prompts(self) -> NotificationsPromptsContext: ... + + +@runtime_checkable +class ModelContext(Protocol): + @property + def log(self) -> LogsContext: ... + + @property + def progress(self) -> ProgressContext: ... + + @property + def resources(self) -> ResourcesContext: ... + + @property + def tools(self) -> ToolsContext: ... + + @property + def prompts(self) -> PromptsContext: ... + + @property + def sampling(self) -> SamplingContext: ... + + @property + def ui(self) -> UIContext: ... + + @property + def notifications(self) -> NotificationsContext: ... + + @property + def request_id(self) -> str | None: ... + + @property + def session_id(self) -> str | None: ... diff --git a/libs/arcade-core/arcade_core/converters/openai.py b/libs/arcade-core/arcade_core/converters/openai.py new file mode 100644 index 00000000..c4fb317c --- /dev/null +++ b/libs/arcade-core/arcade_core/converters/openai.py @@ -0,0 +1,220 @@ +"""Converter for converting Arcade ToolDefinition to OpenAI tool schema.""" + +from typing import Any, Literal, TypedDict + +from arcade_core.catalog import MaterializedTool +from arcade_core.schema import InputParameter, ValueSchema + +# ---------------------------------------------------------------------------- +# Type definitions for JSON tool schemas used by OpenAI APIs. +# Defines the proper types for tool schemas to ensure +# compatibility with OpenAI's Responses and Chat Completions APIs. +# ---------------------------------------------------------------------------- + + +class OpenAIFunctionParameterProperty(TypedDict, total=False): + """Type definition for a property within OpenAI function parameters schema.""" + + type: str | list[str] + """The JSON Schema type(s) for this property. Can be a single type or list for unions (e.g., ["string", "null"]).""" + + description: str + """Description of the property.""" + + enum: list[Any] + """Allowed values for enum properties.""" + + items: dict[str, Any] + """Schema for array items when type is 'array'.""" + + properties: dict[str, "OpenAIFunctionParameterProperty"] + """Nested properties when type is 'object'.""" + + required: list[str] + """Required fields for nested objects.""" + + additionalProperties: Literal[False] + """Must be False for strict mode compliance.""" + + +class OpenAIFunctionParameters(TypedDict, total=False): + """Type definition for OpenAI function parameters schema.""" + + type: Literal["object"] + """Must be 'object' for function parameters.""" + + properties: dict[str, OpenAIFunctionParameterProperty] + """The properties of the function parameters.""" + + required: list[str] + """List of required parameter names. In strict mode, all properties should be listed here.""" + + additionalProperties: Literal[False] + """Must be False for strict mode compliance.""" + + +class OpenAIFunctionSchema(TypedDict, total=False): + """Type definition for a function tool parameter matching OpenAI's API.""" + + name: str + """The name of the function to call.""" + + parameters: OpenAIFunctionParameters | None + """A JSON schema object describing the parameters of the function.""" + + strict: Literal[True] + """Always enforce strict parameter validation. Default `true`.""" + + description: str | None + """A description of the function. + Used by the model to determine whether or not to call the function. + """ + + +class OpenAIToolSchema(TypedDict): + """ + Schema for a tool definition passed to OpenAI's `tools` parameter. + A tool wraps a callable function for function-calling. Each tool + includes a type (always 'function') and a `function` payload that + specifies the callable via `OpenAIFunctionSchema`. + """ + + type: Literal["function"] + """The type field, always 'function'.""" + + function: OpenAIFunctionSchema + """The function definition.""" + + +# Type alias for a list of openai tool schemas +OpenAIToolList = list[OpenAIToolSchema] + + +# ---------------------------------------------------------------------------- +# Converters +# ---------------------------------------------------------------------------- +def to_openai(tool: MaterializedTool) -> OpenAIToolSchema: + """Convert a MaterializedTool to OpenAI JsonToolSchema format. + + Args: + tool: The MaterializedTool to convert + Returns: + The OpenAI JsonToolSchema format (what is passed to the OpenAI API) + """ + name = tool.definition.fully_qualified_name.replace(".", "_") + description = tool.description + parameters_schema = _convert_input_parameters_to_json_schema(tool.definition.input.parameters) + return _create_tool_schema(name, description, parameters_schema) + + +def _create_tool_schema( + name: str, description: str, parameters: OpenAIFunctionParameters +) -> OpenAIToolSchema: + """Create a properly typed tool schema. + Args: + name: The name of the function + description: Description of what the function does + parameters: JSON schema for the function parameters + strict: Whether to enforce strict validation (default: True for reliable function calls) + Returns: + A properly typed OpenAIToolSchema + """ + + function: OpenAIFunctionSchema = { + "name": name, + "description": description, + "parameters": parameters, + "strict": True, + } + + tool: OpenAIToolSchema = { + "type": "function", + "function": function, + } + + return tool + + +def _convert_value_schema_to_json_schema( + value_schema: ValueSchema, +) -> OpenAIFunctionParameterProperty: + """Convert Arcade ValueSchema to JSON Schema format.""" + type_mapping = { + "string": "string", + "integer": "integer", + "number": "number", + "boolean": "boolean", + "json": "object", + "array": "array", + } + + schema: OpenAIFunctionParameterProperty = {"type": type_mapping[value_schema.val_type]} + + if value_schema.val_type == "array" and value_schema.inner_val_type: + items_schema: dict[str, Any] = {"type": type_mapping[value_schema.inner_val_type]} + + # For arrays, enum should be applied to the items, not the array itself + if value_schema.enum: + items_schema["enum"] = value_schema.enum + + schema["items"] = items_schema + else: + # Handle enum for non-array types + if value_schema.enum: + schema["enum"] = value_schema.enum + + # Handle object properties + if value_schema.val_type == "json" and value_schema.properties: + schema["properties"] = { + name: _convert_value_schema_to_json_schema(nested_schema) + for name, nested_schema in value_schema.properties.items() + } + + return schema + + +def _convert_input_parameters_to_json_schema( + parameters: list[InputParameter], +) -> OpenAIFunctionParameters: + """Convert list of InputParameter to JSON schema parameters object.""" + if not parameters: + # Minimal JSON schema for a tool with no input parameters + return { + "type": "object", + "properties": {}, + "additionalProperties": False, + } + + properties = {} + required = [] + + for parameter in parameters: + param_schema = _convert_value_schema_to_json_schema(parameter.value_schema) + + # For optional parameters in strict mode, we need to add "null" as a type option + if not parameter.required: + param_type = param_schema.get("type") + if isinstance(param_type, str): + # Convert single type to union with null + param_schema["type"] = [param_type, "null"] + elif isinstance(param_type, list) and "null" not in param_type: + param_schema["type"] = [*param_type, "null"] + + if parameter.description: + param_schema["description"] = parameter.description + properties[parameter.name] = param_schema + + # In strict mode, all parameters (including optional ones) go in required array + # Optional parameters are handled by adding "null" to their type + required.append(parameter.name) + + json_schema: OpenAIFunctionParameters = { + "type": "object", + "properties": properties, + "required": required, + "additionalProperties": False, + } + if not required: + del json_schema["required"] + + return json_schema diff --git a/libs/arcade-core/arcade_core/discovery.py b/libs/arcade-core/arcade_core/discovery.py new file mode 100644 index 00000000..112c055f --- /dev/null +++ b/libs/arcade-core/arcade_core/discovery.py @@ -0,0 +1,253 @@ +""" +Discovery utilities for Arcade Tools. + +Provides modular, testable functions to discover toolkits and local tool files, +load modules, collect tools, and build a ToolCatalog. +""" + +from __future__ import annotations + +import importlib.util +from pathlib import Path +from types import ModuleType +from typing import Any + +from loguru import logger + +from arcade_core.catalog import ToolCatalog +from arcade_core.parse import get_tools_from_file +from arcade_core.toolkit import Toolkit, ToolkitLoadError + +DISCOVERY_PATTERNS = ["*.py", "tools/*.py", "arcade_tools/*.py", "tools/**/*.py"] +FILTER_PATTERNS = ["_test.py", "test_*.py", "__pycache__", "*.lock", "*.egg-info", "*.pyc"] + + +def normalize_package_name(package_name: str) -> str: + """Normalize a package name for import resolution.""" + return package_name.lower().replace("-", "_") + + +def load_toolkit_from_package(package_name: str, show_packages: bool = False) -> Toolkit: + """Attempt to load a Toolkit from an installed package name.""" + toolkit = Toolkit.from_package(package_name) + if show_packages: + logger.info(f"Loading package: {toolkit.name}") + return toolkit + + +def load_package(package_name: str, show_packages: bool = False) -> Toolkit: + """Load a toolkit for a specific package name. + + Raises ToolkitLoadError if the package is not found. + """ + normalized = normalize_package_name(package_name) + try: + return load_toolkit_from_package(normalized, show_packages) + except ToolkitLoadError: + return load_toolkit_from_package(f"arcade_{normalized}", show_packages) + + +def find_candidate_tool_files(root: Path | None = None) -> list[Path]: + """Find candidate Python files for auto-discovery in common locations.""" + cwd = root or Path.cwd() + + candidates: list[Path] = [] + for pattern in DISCOVERY_PATTERNS: + candidates.extend(cwd.glob(pattern)) + # Deduplicate candidates (same file might match multiple patterns) + unique_candidates = list(set(candidates)) + # Filter out private, cache, and tests + return [ + p for p in unique_candidates if not any(p.match(pattern) for pattern in FILTER_PATTERNS) + ] + + +def analyze_files_for_tools(files: list[Path]) -> list[tuple[Path, list[str]]]: + """Parse files with a fast AST pass to find declared @tool function names.""" + results: list[tuple[Path, list[str]]] = [] + for file_path in files: + try: + names = get_tools_from_file(file_path) + if names: + logger.info(f"Found {len(names)} tool(s) in {file_path.name}: {', '.join(names)}") + results.append((file_path, names)) + except Exception: + logger.exception(f"Could not parse {file_path}") + return results + + +def load_module_from_path(file_path: Path) -> ModuleType: + """Dynamically import a Python module from a file path.""" + import sys + + # Add the directory containing the file to sys.path temporarily + # This allows local imports to work + file_dir = str(file_path.parent) + path_added = False + if file_dir not in sys.path: + sys.path.insert(0, file_dir) + path_added = True + + try: + spec = importlib.util.spec_from_file_location( + f"_tools_{file_path.stem}", + file_path, + ) + if not spec or not spec.loader: + raise ToolkitLoadError(f"Unable to create import spec for {file_path}") + + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except Exception: + logger.exception(f"Failed to load {file_path}") + raise ToolkitLoadError(f"Failed to load {file_path}") + + return module + finally: + # Remove the path we added + if path_added and file_dir in sys.path: + sys.path.remove(file_dir) + + +def collect_tools_from_modules( + files_with_tools: list[tuple[Path, list[str]]], +) -> list[tuple[Any, ModuleType]]: + """Load modules and collect the expected tool callables. + + Returns a list of (callable, module) pairs. + """ + discovered: list[tuple[Any, ModuleType]] = [] + + for file_path, expected_names in files_with_tools: + logger.debug(f"Loading tools from {file_path}...") + try: + module = load_module_from_path(file_path) + except ToolkitLoadError: + continue + + for name in expected_names: + if hasattr(module, name): + attr = getattr(module, name) + if callable(attr) and hasattr(attr, "__tool_name__"): + discovered.append((attr, module)) + else: + logger.warning( + f"Expected {name} to be a tool but it wasn't (missing __tool_name__)\n\n" + ) + return discovered + + +def build_minimal_toolkit( + server_name: str | None, + server_version: str | None, + description: str | None = None, +) -> Toolkit: + """Create a minimal Toolkit to host locally discovered tools.""" + name = server_name or "ArcadeMCP" + version = server_version or "0.1.0dev" + pkg = f"{name}.{Path.cwd().name}" + desc = description or f"MCP Server for {name} version {version}" + return Toolkit(name=name, package_name=pkg, version=version, description=desc) + + +def build_catalog_from_toolkits(toolkits: list[Toolkit]) -> ToolCatalog: + """Create a ToolCatalog and add the provided toolkits.""" + catalog = ToolCatalog() + for tk in toolkits: + catalog.add_toolkit(tk) + return catalog + + +def add_discovered_tools( + catalog: ToolCatalog, + toolkit: Toolkit, + tools: list[tuple[Any, ModuleType]], +) -> None: + """Add discovered local tools to the catalog, preserving module context.""" + for tool_func, module in tools: + if module.__name__ not in __import__("sys").modules: + __import__("sys").modules[module.__name__] = module + catalog.add_tool(tool_func, toolkit, module) + + +def load_toolkits_for_option(tool_package: str, show_packages: bool = False) -> list[Toolkit]: + """ + Load toolkits for a given package option. + + Args: + tool_package: Package name or comma-separated list of package names + show_packages: Whether to log loaded packages + + Returns: + List of loaded toolkits + """ + toolkits = [] + packages = [p.strip() for p in tool_package.split(",")] + + for package in packages: + try: + toolkit = load_package(package, show_packages) + toolkits.append(toolkit) + except ToolkitLoadError as e: + logger.warning(f"Failed to load package '{package}': {e}") + + return toolkits + + +def load_all_installed_toolkits(show_packages: bool = False) -> list[Toolkit]: + """ + Discover and load all installed arcade toolkits. + + Args: + show_packages: Whether to log loaded packages + + Returns: + List of all installed toolkits + """ + toolkits = Toolkit.find_all_arcade_toolkits() + + if show_packages: + for toolkit in toolkits: + logger.info(f"Loading package: {toolkit.name}") + + return toolkits + + +def discover_tools( + tool_package: str | None = None, + show_packages: bool = False, + discover_installed: bool = False, + server_name: str | None = None, + server_version: str | None = None, +) -> ToolCatalog: + """High-level discovery that returns a ToolCatalog. + + This function is pure (does not sys.exit); callers should handle errors. + """ + # 1) Package-based discovery + if tool_package: + toolkits = load_toolkits_for_option(tool_package, show_packages) + return build_catalog_from_toolkits(toolkits) + + # 2) Discover all installed packages + if discover_installed: + toolkits = load_all_installed_toolkits(show_packages) + return build_catalog_from_toolkits(toolkits) + + # 3) Local file discovery + logger.info("Auto-discovering tools from current directory") + files = find_candidate_tool_files() + if not files: + # Return empty catalog; caller can decide how to handle + return ToolCatalog() + + files_with_tools = analyze_files_for_tools(files) + if not files_with_tools: + return ToolCatalog() + + discovered = collect_tools_from_modules(files_with_tools) + catalog = ToolCatalog() + toolkit = build_minimal_toolkit(server_name, server_version) + add_discovered_tools(catalog, toolkit, discovered) + return catalog diff --git a/libs/arcade-core/arcade_core/parse.py b/libs/arcade-core/arcade_core/parse.py index 148397fa..eca54707 100644 --- a/libs/arcade-core/arcade_core/parse.py +++ b/libs/arcade-core/arcade_core/parse.py @@ -36,6 +36,18 @@ def get_function_name_if_decorated( and isinstance(decorator.func, ast.Name) and decorator.func.id in decorator_ids ) + # Support MCPApp tools. e.g., @app.tool or @app.tool(...) + or ( + isinstance(decorator, ast.Attribute) + and decorator.attr == "tool" + and isinstance(decorator.value, ast.Name) + ) + or ( + isinstance(decorator, ast.Call) + and isinstance(decorator.func, ast.Attribute) + and decorator.func.attr == "tool" + and isinstance(decorator.func.value, ast.Name) + ) ): return node.name return None diff --git a/libs/arcade-core/arcade_core/schema.py b/libs/arcade-core/arcade_core/schema.py index 64e12738..ea6e79f2 100644 --- a/libs/arcade-core/arcade_core/schema.py +++ b/libs/arcade-core/arcade_core/schema.py @@ -1,3 +1,21 @@ +""" +Arcade Core Schema + +Defines transport-agnostic tool schemas and runtime context protocols used +across Arcade libraries. This includes: + +- Tool and toolkit specifications (parameters, outputs, requirements) +- Transport-agnostic ToolContext carrying authorization, secrets, metadata +- Runtime ModelContext Protocol and its namespaced sub-protocols for logs, + progress, resources, tools, prompts, sampling, UI, and notifications + +Note: ToolContext does not embed runtime capabilities; those are provided by +implementations of ModelContext (e.g., in arcade-mcp-server) that subclasses ToolContext +to expose the namespaced APIs to tools without changing function signatures. +""" + +from __future__ import annotations + import os from dataclasses import dataclass from enum import Enum @@ -23,10 +41,10 @@ class ValueSchema(BaseModel): enum: list[str] | None = None """The list of possible values for the value, if it is a closed list.""" - properties: dict[str, "ValueSchema"] | None = None + properties: dict[str, ValueSchema] | None = None """For object types (json), the schema of nested properties.""" - inner_properties: dict[str, "ValueSchema"] | None = None + inner_properties: dict[str, ValueSchema] | None = None """For array types with json items, the schema of properties for each array item.""" description: str | None = None @@ -100,7 +118,7 @@ class ToolAuthRequirement(BaseModel): # or # client.auth.authorize(provider=AuthProvider.google, scopes=["profile", "email"]) # - # The Arcade SDK translates these into the appropriate provider ID (Google) and type (OAuth2). + # The Arcade TDK translates these into the appropriate provider ID (Google) and type (OAuth2). # The only time the developer will set these is if they are using a custom auth provider. provider_id: str | None = None """The provider ID configured in Arcade that acts as an alias to well-known configuration.""" @@ -200,7 +218,7 @@ class FullyQualifiedName: (self.toolkit_version or "").lower(), )) - def equals_ignoring_version(self, other: "FullyQualifiedName") -> bool: + def equals_ignoring_version(self, other: FullyQualifiedName) -> bool: """Check if two fully-qualified tool names are equal, ignoring the version.""" return ( self.name.lower() == other.name.lower() @@ -208,7 +226,7 @@ class FullyQualifiedName: ) @staticmethod - def from_toolkit(tool_name: str, toolkit: ToolkitDefinition) -> "FullyQualifiedName": + def from_toolkit(tool_name: str, toolkit: ToolkitDefinition) -> FullyQualifiedName: """Creates a fully-qualified tool name from a tool name and a ToolkitDefinition.""" return FullyQualifiedName(tool_name, toolkit.name, toolkit.version) @@ -298,7 +316,16 @@ class ToolMetadataItem(BaseModel): class ToolContext(BaseModel): - """The context for a tool invocation.""" + """The context for a tool invocation. + + This type is transport-agnostic and contains only authorization, + secret, and metadata information needed by the tool. Runtime-specific + capabilities (logging, resources, etc.) are provided by a separate + runtime context that wraps this object. + + Recommendation: For new tools, annotate the parameter as + `arcade_mcp_server.Context` to access namespaced runtime APIs directly. + """ authorization: ToolAuthorizationContext | None = None """The authorization context for the tool invocation that requires authorization.""" @@ -312,16 +339,35 @@ class ToolContext(BaseModel): user_id: str | None = None """The user ID for the tool invocation (if any).""" + model_config = {"arbitrary_types_allowed": True} + + def set_secret(self, key: str, value: str) -> None: + """Add or update a secret to the tool context.""" + if self.secrets is None: + self.secrets = [] + # Update existing or add new + for secret in self.secrets: + if secret.key == key: + secret.value = value + return + self.secrets.append(ToolSecretItem(key=key, value=value)) + def get_auth_token_or_empty(self) -> str: """Retrieve the authorization token, or return an empty string if not available.""" return self.authorization.token if self.authorization and self.authorization.token else "" def get_secret(self, key: str) -> str: - """Retrieve the secret for the tool invocation.""" + """Retrieve the secret for the tool invocation. + + Raises a ValueError if the secret is not found. + """ return self._get_item(key, self.secrets, "secret") def get_metadata(self, key: str) -> str: - """Retrieve the metadata for the tool invocation.""" + """Retrieve the metadata for the tool invocation. + + Raises a ValueError if the metadata is not found. + """ return self._get_item(key, self.metadata, "metadata") def _get_item( @@ -335,21 +381,14 @@ class ToolContext(BaseModel): f"{item_name.capitalize()} key passed to get_{item_name} cannot be empty." ) if not items: - raise ValueError(f"{item_name.capitalize()}s not found in context.") + raise ValueError(f"{item_name.capitalize()} '{key}' not found in context.") normalized_key = key.lower() for item in items: if item.key.lower() == normalized_key: return item.value - raise ValueError(f"{item_name.capitalize()} {key} not found in context.") - - def set_secret(self, key: str, value: str) -> None: - """Set a secret for the tool invocation.""" - if not self.secrets: - self.secrets = [] - secret = ToolSecretItem(key=str(key), value=str(value)) - self.secrets.append(secret) + raise ValueError(f"{item_name.capitalize()} '{key}' not found in context.") class ToolCallRequest(BaseModel): diff --git a/libs/arcade-core/arcade_core/toolkit.py b/libs/arcade-core/arcade_core/toolkit.py index 0b780410..b976e574 100644 --- a/libs/arcade-core/arcade_core/toolkit.py +++ b/libs/arcade-core/arcade_core/toolkit.py @@ -6,6 +6,7 @@ import types from collections import defaultdict from pathlib import Path, PurePosixPath, PureWindowsPath +import toml from pydantic import BaseModel, ConfigDict, field_validator from arcade_core.errors import ToolkitLoadError @@ -59,6 +60,71 @@ class Toolkit(BaseModel): """ return cls.from_package(module.__name__) + @classmethod + def from_directory(cls, directory: Path) -> "Toolkit": + """ + Load a Toolkit from a directory. + """ + pyproject_path = directory / "pyproject.toml" + if not pyproject_path.exists(): + raise ToolkitLoadError(f"pyproject.toml not found in {directory}") + + try: + with open(pyproject_path) as f: + pyproject_data = toml.load(f) + + project_data = pyproject_data.get("project", {}) + name = project_data.get("name") + if not name: + + def _missing_name_error() -> ToolkitLoadError: + return ToolkitLoadError("name not found in pyproject.toml") + + raise _missing_name_error() # noqa: TRY301 + + package_name = name + version = project_data.get("version", "0.0.0") + description = project_data.get("description", "") + authors = project_data.get("authors", []) + author_names = [author.get("name", "") for author in authors] + + # For homepage and repository, you might need to look under project.urls + urls = project_data.get("urls", {}) + homepage = urls.get("Homepage") + repo = urls.get("Repository") + + except Exception as e: + raise ToolkitLoadError(f"Failed to load metadata from {pyproject_path}: {e}") + + # Determine the actual package directory (supports src/ layout and flat layout) + package_dir = directory + try: + src_candidate = directory / "src" / package_name + flat_candidate = directory / package_name + if src_candidate.is_dir(): + package_dir = src_candidate + elif flat_candidate.is_dir(): + package_dir = flat_candidate + else: + # Fallback to the provided directory; tools_from_directory will de-duplicate prefixes + package_dir = directory + except Exception: + package_dir = directory + + toolkit = cls( + name=name, + package_name=package_name, + version=version, + description=description, + author=author_names, + homepage=homepage, + repository=repo, + ) + + toolkit.tools = cls.tools_from_directory(package_dir, package_name) + + return toolkit + @classmethod def from_package(cls, package: str) -> "Toolkit": """ @@ -232,9 +298,14 @@ class Toolkit(BaseModel): for module_path in modules: relative_path = module_path.relative_to(package_dir) cls.validate_file(module_path) - import_path = ".".join(relative_path.with_suffix("").parts) - import_path = f"{package_name}.{import_path}" - tools[import_path] = get_tools_from_file(str(module_path)) + # Build import path and avoid duplicating the package prefix if it already exists + relative_parts = relative_path.with_suffix("").parts + import_path = ".".join(relative_parts) + if relative_parts and relative_parts[0] == package_name: + full_import_path = import_path + else: + full_import_path = f"{package_name}.{import_path}" if import_path else package_name + tools[full_import_path] = get_tools_from_file(str(module_path)) if not tools: raise ToolkitLoadError(f"No tools found in package {package_name}") diff --git a/libs/arcade-core/arcade_core/utils.py b/libs/arcade-core/arcade_core/utils.py index 12bbd3df..7117cf04 100644 --- a/libs/arcade-core/arcade_core/utils.py +++ b/libs/arcade-core/arcade_core/utils.py @@ -4,6 +4,7 @@ import ast import inspect import re from collections.abc import Callable, Iterable +from textwrap import dedent from types import UnionType from typing import Any, Literal, TypeVar, Union, get_args, get_origin @@ -75,7 +76,9 @@ def does_function_return_value(func: Callable) -> bool: if source is None: raise ValueError("Source code not found") - tree = ast.parse(source) + # dedent in case the function is an inner function + dedented_source = dedent(source) + tree = ast.parse(dedented_source) class ReturnVisitor(ast.NodeVisitor): def __init__(self) -> None: diff --git a/libs/arcade-core/pyproject.toml b/libs/arcade-core/pyproject.toml index 2814e854..caba8fdd 100644 --- a/libs/arcade-core/pyproject.toml +++ b/libs/arcade-core/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-core" -version = "2.4.0" +version = "2.5.0rc1" description = "Arcade Core - Core library for Arcade platform" readme = "README.md" license = {text = "MIT"} @@ -28,9 +28,6 @@ dependencies = [ "types-python-dateutil==2.9.0.20241003", "types-pytz==2024.2.0.20241003", "types-toml==0.10.8.20240310", - "opentelemetry-instrumentation-fastapi==0.49b2", - "opentelemetry-exporter-otlp-proto-http==1.28.2", - "opentelemetry-exporter-otlp-proto-common==1.28.2", ] [project.optional-dependencies] diff --git a/libs/arcade-evals/README.md b/libs/arcade-evals/README.md index 4d8aa48d..97ec57c4 100644 --- a/libs/arcade-evals/README.md +++ b/libs/arcade-evals/README.md @@ -14,7 +14,7 @@ Arcade Evals provides comprehensive evaluation capabilities for Arcade tools: ## Installation ```bash -pip install 'arcade-ai[evals]' +pip install 'arcade-mcp[evals]' ``` ## Usage diff --git a/libs/arcade-evals/arcade_evals/eval.py b/libs/arcade-evals/arcade_evals/eval.py index e8e126db..02bc00d0 100644 --- a/libs/arcade-evals/arcade_evals/eval.py +++ b/libs/arcade-evals/arcade_evals/eval.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable import numpy as np -from arcade_core.config_model import Config +from arcade_core.converters.openai import OpenAIToolList, to_openai from arcade_core.schema import TOOL_NAME_SEPARATOR from openai import AsyncOpenAI from scipy.optimize import linear_sum_assignment @@ -613,14 +613,12 @@ class EvalSuite: Args: client: The AsyncOpenAI client instance. model: The model to evaluate. - Returns: A dictionary containing the evaluation results. """ results: dict[str, Any] = {"model": model, "rubric": self.rubric, "cases": []} semaphore = asyncio.Semaphore(self.max_concurrent) - tool_names = list(self.catalog.get_tool_names()) async def sem_task(case: EvalCase) -> dict[str, Any]: async with semaphore: @@ -629,12 +627,14 @@ class EvalSuite: messages.extend(case.additional_messages) messages.append({"role": "user", "content": case.user_message}) + tools = get_formatted_tools(self.catalog, tool_format="openai") + # Get the model response response = await client.chat.completions.create( # type: ignore[call-overload] model=model, messages=messages, tool_choice="auto", - tools=(str(name) for name in tool_names), + tools=tools, user="eval_user", seed=42, stream=False, @@ -675,6 +675,23 @@ class EvalSuite: return results +def get_formatted_tools(catalog: "ToolCatalog", tool_format: str = "openai") -> OpenAIToolList: + """Get the formatted tools from the catalog. + + Args: + catalog: The catalog of Arcade tools. + tool_format: The format of the tools to return + + Returns: + The formatted tools. + """ + if tool_format == "openai": + tools = [to_openai(tool) for tool in catalog] + return tools + else: + raise ValueError(f"Tool format for '{tool_format}' is not supported") + + def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]: """ Returns the tool arguments from the chat completion object. @@ -729,8 +746,7 @@ def tool_eval() -> Callable[[Callable], Callable]: def decorator(func: Callable) -> Callable: @functools.wraps(func) async def wrapper( - config: Config, - base_url: str, + provider_api_key: str, model: str, max_concurrency: int = 1, ) -> list[dict[str, Any]]: @@ -740,8 +756,7 @@ def tool_eval() -> Callable[[Callable], Callable]: suite.max_concurrent = max_concurrency results = [] async with AsyncOpenAI( - api_key=config.api.key, - base_url=base_url + "/v1", + api_key=provider_api_key, ) as client: result = await suite.run(client, model) results.append(result) diff --git a/libs/arcade-mcp-server/Makefile b/libs/arcade-mcp-server/Makefile new file mode 100644 index 00000000..05e3a63e --- /dev/null +++ b/libs/arcade-mcp-server/Makefile @@ -0,0 +1,40 @@ + +.PHONY: help +help: + @echo "๐Ÿ› ๏ธ Arcade MCP Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + + + +.PHONY: sync +sync: ## Sync dependencies + uv sync --all-extras --all-packages --group dev + +.PHONY: format +format: ## Run ruff format + uv run ruff format + +.PHONY: lint +lint: ## Run ruff lint + uv run ruff check + +.PHONY: mypy +mypy: ## Run mypy + uv run mypy . + +.PHONY: test +test: ## Run tests + uv run pytest --cov=arcade_mcp_server --cov-report=term-missing ../tests/arcade_mcp_server + +.PHONY: docs +docs: ## Build docs + uv run mkdocs build + uv run mkdocs serve + +.PHONY: serve-docs +serve-docs: ## Serve docs locally + uv run mkdocs serve + +.PHONY: deploy-docs +deploy-docs: ## Deploy docs to GitHub Pages + uv run mkdocs gh-deploy --force --verbose diff --git a/libs/arcade-mcp-server/README.md b/libs/arcade-mcp-server/README.md new file mode 100644 index 00000000..7a80c56b --- /dev/null +++ b/libs/arcade-mcp-server/README.md @@ -0,0 +1,72 @@ +# Arcade MCP Server + +

+ Arcade Logo +

+ +Arcade MCP (Model Context Protocol) Server enables AI assistants and development tools to interact with your Arcade tools through a standardized protocol. Build, deploy, and integrate MCP servers seamlessly across different AI platforms. + +## Quick Links + +- **[Quickstart Guide](getting-started/quickstart.md)** - Get up and running in minutes +- **[Walkthrough](examples/README.md)** - Learn by example +- **[API Reference](api/app.md)** - MCPApp API documentation + +## Features + +- ๐Ÿš€ **FastAPI-like Interface** - Simple, intuitive API with `MCPApp` +- ๐Ÿ”ง **Tool Discovery** - Automatic discovery of tools in your project +- ๐Ÿ”Œ **Multiple Transports** - Support for stdio and HTTP/SSE +- ๐Ÿค– **Multi-Client Support** - Works with Claude, Cursor, and more +- ๐Ÿ“ฆ **Package Integration** - Load installed Arcade packages +- ๐Ÿ” **Built-in Security** - Environment-based configuration and secrets +- ๐Ÿ”„ **Hot Reload** - Development mode with automatic reloading +- ๐Ÿ“Š **Production Ready** - Deploy with Docker, systemd, PM2, or cloud platforms + +## Getting Started + +### Installation + +```bash +pip install arcade-mcp-server +``` + +### Create Your First Server + +```python +from arcade_mcp_server import MCPApp +from typing import Annotated + +app = MCPApp(name="my-tools", version="1.0.0") + +@app.tool +def greet(name: Annotated[str, "Name to greet"]) -> str: + """Greet someone by name.""" + return f"Hello, {name}!" + +if __name__ == "__main__": + app.run() +``` + +### Run Your Server + +```bash +# For development +python my_tools.py + +# For Claude Desktop +python -m arcade_mcp_server stdio + +# For HTTP clients +python -m arcade_mcp_server --host 0.0.0.0 --port 8080 +``` + +## Community + +- [GitHub Repository](https://github.com/ArcadeAI/arcade-mcp) +- [Discord Community](https://discord.gg/arcade-mcp) +- [Documentation](https://docs.arcade.dev) + +## License + +Arcade MCP Server is open source software licensed under the MIT license. diff --git a/libs/arcade-mcp-server/arcade_mcp_server/__init__.py b/libs/arcade-mcp-server/arcade_mcp_server/__init__.py new file mode 100644 index 00000000..81214296 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/__init__.py @@ -0,0 +1,44 @@ +""" +MCP (Model Context Protocol) support for Arcade. + +This package provides: +- MCP server implementation for serving Arcade tools +- Multiple transport options (stdio, HTTP/SSE) +- Integration with Arcade workers with factory and runner functions +- Context system for tool execution with MCP methods + +A FastAPI-like interface for building MCP servers. +- Add tools with decorators or explicitly +- Run the server with a single function call +- Supports HTTP transport only + +`arcade_mcp` for running stdio directly from the command line. +- auto discovery of tools and construction of the server +- supports stdio transport only +- run with uv or `python -m arcade_mcp` +""" + +from arcade_tdk import tool + +from arcade_mcp_server.context import Context +from arcade_mcp_server.mcp_app import MCPApp +from arcade_mcp_server.server import MCPServer +from arcade_mcp_server.settings import MCPSettings +from arcade_mcp_server.worker import create_arcade_mcp, run_arcade_mcp + +__all__ = [ + "Context", + # FastAPI-like interface + "MCPApp", + # MCP Server implementation + "MCPServer", + "MCPSettings", + # Integrated Factory and Runner + "create_arcade_mcp", + "run_arcade_mcp", + # Re-exported TDK functionality + "tool", +] + +# Package metadata +__version__ = "0.1.0" diff --git a/libs/arcade-mcp-server/arcade_mcp_server/__main__.py b/libs/arcade-mcp-server/arcade_mcp_server/__main__.py new file mode 100644 index 00000000..c32f2f84 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/__main__.py @@ -0,0 +1,353 @@ +""" +Arcade MCP Server Runner + +Provides a unified interface for running MCP servers with either: +- stdio transport for direct client connections +- HTTP/SSE transport with FastAPI for web-based connections + +Usage: + # Run with stdio transport + python -m arcade_mcp_server stdio + + # Run with HTTP transport (default) + python -m arcade_mcp_server + + # Run with specific toolkit + python -m arcade_mcp_server --toolkit my_toolkit + + # Run in development mode with hot reload + python -m arcade_mcp_server --reload --debug +""" + +import asyncio +import logging +import os +import sys +from pathlib import Path +from typing import Any + +from arcade_core.catalog import ToolCatalog +from arcade_core.discovery import discover_tools +from arcade_core.toolkit import ToolkitLoadError +from dotenv import load_dotenv +from loguru import logger + +from arcade_mcp_server.server import MCPServer +from arcade_mcp_server.settings import MCPSettings + + +# Logging setup with Loguru +class LoguruInterceptHandler(logging.Handler): + """Intercept standard logging and route to Loguru.""" + + def emit(self, record: logging.LogRecord) -> None: + try: + level = logger.level(record.levelname).name + except ValueError: + level = str(record.levelno) + + logger.opt(exception=record.exc_info).log(level, record.getMessage()) + + +def setup_logging(level: str = "INFO", stdio_mode: bool = False) -> None: + """Configure logging with Loguru.""" + # Remove existing handlers + logger.remove() + + # Configure output destination + sink = sys.stderr if stdio_mode else sys.stdout + + # Add handler with appropriate format + if level == "DEBUG": + format_str = "{level: <8} | {time:HH:mm:ss} | {name}:{line} | {message}" + else: + format_str = ( + "{level: <8} | {time:HH:mm:ss} | {message}" + ) + + logger.add( + sink, + format=format_str, + level=level, + colorize=True, + diagnose=(level == "DEBUG"), + ) + + # Intercept standard logging + logging.basicConfig(handlers=[LoguruInterceptHandler()], level=0, force=True) + + +def initialize_tool_catalog( + tool_package: str | None = None, + show_packages: bool = False, + discover_installed: bool = False, + server_name: str | None = None, + server_version: str | None = None, +) -> ToolCatalog: + """ + Discover and load tools from various sources. + + Returns a ToolCatalog or exits with a friendly error if nothing found. + """ + try: + catalog = discover_tools( + tool_package=tool_package, + show_packages=show_packages, + discover_installed=discover_installed, + server_name=server_name, + server_version=server_version, + ) + except ToolkitLoadError as exc: + logger.error(str(exc)) + sys.exit(1) + + total_tools = len(catalog) + if total_tools == 0: + logger.error("No tools found. Create Python files with @tool decorated functions.") + sys.exit(1) + + logger.info(f"Total tools loaded: {total_tools}") + return catalog + + +async def run_stdio_server( + catalog: ToolCatalog, + debug: bool = False, + env_file: str | None = None, + **kwargs: Any, +) -> None: + """Run MCP server with stdio transport.""" + from arcade_mcp_server.transports.stdio import StdioTransport + + # Load settings + # Ensure env from provided .env is loaded for stdio runs as well + if env_file: + load_dotenv(env_file) + logger.debug(f"Loaded environment variables from --env-file={env_file}") + settings = MCPSettings.from_env() + if debug: + settings.debug = True + settings.middleware.enable_logging = True + settings.middleware.log_level = "DEBUG" + + # Debug log settings and env var names (without values) + try: + tool_env_keys = sorted(settings.tool_secrets().keys()) + logger.debug( + f"Arcade settings: \n\ + ARCADE_ENVIRONMENT={settings.arcade.environment} \n\ + ARCADE_API_URL={settings.arcade.api_url}, \n\ + ARCADE_USER_ID={settings.arcade.user_id}, \n\ + api_key_present - {bool(settings.arcade.api_key)}" + ) + logger.debug(f"Tool environment variable names available to tools: {tool_env_keys}") + except Exception as e: + logger.debug(f"Unable to log settings/tool env keys: {e}") + + # Create server + server = MCPServer( + catalog=catalog, + settings=settings, + **kwargs, + ) + + # Create transport + transport = StdioTransport() + + try: + # Start server and transport + await server.start() + await transport.start() + + # Run connection + async with transport.connect_session() as session: + await server.run_connection( + session.read_stream, + session.write_stream, + session.init_options, + ) + except KeyboardInterrupt: + logger.info("Server stopped by user") + except Exception as e: + logger.exception(f"Server error: {e}") + raise + finally: + # Stop transport and server + try: + await transport.stop() + finally: + await server.stop() + + +def main() -> None: + """Main entry point for arcade_mcp_server module.""" + import argparse + + parser = argparse.ArgumentParser( + description="Run Arcade MCP Server", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Auto-discover tools from current directory + python -m arcade_mcp_server + + # Run with stdio transport for Claude Desktop + python -m arcade_mcp_server stdio + + # Load specific arcade package + python -m arcade_mcp_server --tool-package github + python -m arcade_mcp_server -p slack + + # Discover all installed arcade packages + python -m arcade_mcp_server --discover-installed --show-packages + + # Development mode with hot reload + python -m arcade_mcp_server --debug --reload + + # Run from a different directory + python -m arcade_mcp_server --cwd /path/to/project + python -m arcade_mcp_server --cwd ~/my-tools stdio + +Auto-discovery looks for Python files with @tool decorated functions in: + - Current directory (*.py) + - tools/ subdirectory + - arcade_tools/ subdirectory + """, + ) + + # Transport selection (positional for backwards compatibility) + parser.add_argument( + "transport", + nargs="?", + default="http", + choices=["stdio", "http", "streamable-http"], + help="Transport type (default: http)", + ) + + # Optional arguments + parser.add_argument( + "--host", + default="127.0.0.1", + help="Host to bind to (HTTP mode only)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to bind to (HTTP mode only)", + ) + parser.add_argument( + "--tool-package", + "--package", + "-p", + dest="tool_package", + help="Specific tool package to load (e.g., 'github' for arcade-github)", + ) + parser.add_argument( + "--discover-installed", + "--all", + action="store_true", + help="Discover all installed arcade tool packages", + ) + parser.add_argument( + "--show-packages", + action="store_true", + help="Show loaded packages during discovery", + ) + parser.add_argument( + "--reload", + action="store_true", + help="Enable auto-reload on code changes (HTTP mode only)", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug mode with verbose logging", + ) + parser.add_argument( + "--env-file", + help="Path to environment file", + ) + parser.add_argument( + "--name", + help="Server name", + ) + parser.add_argument( + "--version", + help="Server version", + ) + parser.add_argument( + "--cwd", + help="Directory to change to before running (for tool discovery)", + ) + + args = parser.parse_args() + + # Change working directory if specified + if args.cwd: + cwd_path = Path(args.cwd).resolve() + if not cwd_path.exists(): + print(f"Error: Directory does not exist: {args.cwd}", file=sys.stderr) + sys.exit(1) + if not cwd_path.is_dir(): + print(f"Error: Path is not a directory: {args.cwd}", file=sys.stderr) + sys.exit(1) + os.chdir(cwd_path) + # Update logging to show the new directory + + # Load environment variables + if args.env_file: + load_dotenv(args.env_file) + + # Setup logging + log_level = "DEBUG" if args.debug else "INFO" + setup_logging(level=log_level, stdio_mode=(args.transport == "stdio")) + + # Build kwargs for server + server_kwargs = {} + if args.name: + server_kwargs["name"] = args.name + if args.version: + server_kwargs["version"] = args.version + + # Discover tools + catalog = initialize_tool_catalog( + tool_package=args.tool_package, + show_packages=args.show_packages, + discover_installed=args.discover_installed, + server_name=server_kwargs.get("name"), + server_version=server_kwargs.get("version"), + ) + + # Run appropriate server + try: + if args.transport == "stdio": + logger.info("Starting MCP server with stdio transport") + asyncio.run( + run_stdio_server(catalog, debug=args.debug, env_file=args.env_file, **server_kwargs) + ) + else: + logger.info(f"Starting MCP server with HTTP transport on {args.host}:{args.port}") + from arcade_mcp_server.worker import run_arcade_mcp + + run_arcade_mcp( + catalog=catalog, + host=args.host, + port=args.port, + reload=args.reload, + debug=args.debug, + tool_package=args.tool_package, + discover_installed=args.discover_installed, + show_packages=args.show_packages, + **server_kwargs, + ) + except (KeyboardInterrupt, asyncio.CancelledError): + logger.info("Server stopped") + sys.exit(0) + except Exception as e: + logger.error(f"Server error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/libs/arcade-mcp-server/arcade_mcp_server/context.py b/libs/arcade-mcp-server/arcade_mcp_server/context.py new file mode 100644 index 00000000..ba3b6fc3 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/context.py @@ -0,0 +1,697 @@ +""" +MCP Context System + +Provides the primary Context class for MCP tool development. This module contains +the Context class that tools should use for both runtime capabilities and +tool-specific data access. + +The Context class combines: +- Runtime capabilities: logging, resources, prompts, sampling, UI, notifications +- Tool-specific data: secrets, user_id, authorization, metadata +- Session management: request/session IDs and MCP protocol handling + +Key responsibilities: +- Manage per-request state and the current model context using a ContextVar +- Expose namespaced runtime capabilities (log, resources, etc.) +- Provide access to tool-specific data (secrets, user_id, etc.) +- Delegate to the underlying MCP session and server managers +- Handle MCP protocol communication and lifecycle management + +Note: Context instances are automatically created and managed by the MCP server. +Tools receive a populated context instance as a parameter and should not +create Context instances directly. +""" + +from __future__ import annotations + +import asyncio +import logging +import weakref +from builtins import list as builtins_list +from contextvars import ContextVar, Token +from typing import Any, cast + +from arcade_core.context import ModelContext as ModelContextProtocol +from arcade_core.schema import ( + ToolCallOutput, + ToolContext, +) + +from arcade_mcp_server.types import ( + ClientCapabilities, + ElicitResult, + LoggingLevel, + ModelHint, + ModelPreferences, + ResourceContents, + Root, + SamplingMessage, + TextContent, +) + +# Context variable for current model context +_current_model_context: ContextVar[Context | None] = ContextVar("model_context", default=None) +_flush_lock = asyncio.Lock() + + +class _ContextComponent: + def __init__(self, ctx: Context) -> None: + self._ctx = ctx + + @property + def server(self) -> Any: + return self._ctx.server + + def _require_session(self) -> Any: + session = self._ctx._session + if session is None: + raise ValueError("Session not available") + return session + + +class Context(ToolContext): + """Primary context interface for MCP tools. + + This class provides both runtime capabilities (logging, resources, prompts, etc.) + and tool-specific data (secrets, user_id, authorization) in a single interface. + Tools should annotate their context parameter with this class. + + Runtime Capabilities: + - log: Logging interface (context.log.info(), context.log.error(), etc.) + - progress: Progress reporting for long-running operations + - resources: Access to MCP resources (files, data sources, etc.) + - tools: Call other tools programmatically + - prompts: Access to MCP prompts and templates + - sampling: Create messages using the client's model + - ui: User interaction (elicit input from user) + - notifications: Send notifications to the client + + Tool-Specific Data (inherited from ToolContext): + - user_id: The user ID for this tool execution + - secrets: List of secrets available to this tool + - authorization: Authorization context if required + - metadata: Additional metadata for the tool execution + + Example: + ```python + from arcade_mcp_server import Context, tool + + @tool + async def my_tool(context: Context) -> str: + '''Example tool''' + # Runtime capabilities + await context.log.info("Processing request") + + return "result" + ``` + + Note: Instances are automatically created and managed by the MCP server. + Tools receive a populated context instance as a parameter. + """ + + # Mark as implementing the protocol + __protocols__ = (ModelContextProtocol,) if ModelContextProtocol is not object else () + + def __init__( + self, + server: Any, + session: Any | None = None, + request_id: str | None = None, + ): + """Initialize context with server reference.""" + super().__init__() + self._server: weakref.ref[Any] = weakref.ref(server) + self._session: Any | None = session + self._tokens: list[Token] = [] + self._notification_queue: set[str] = set() + self._request_id: str | None = request_id + + # Namespaced adapters + self._log = Logs(self) + self._progress = Progress(self) + self._resources = Resources(self) + self._tools = Tools(self) + self._prompts = Prompts(self) + self._sampling = Sampling(self) + self._ui = UI(self) + self._notifications = Notifications(self) + + @property + def server(self) -> Any: + """Get the server instance.""" + server = self._server() + if server is None: + raise RuntimeError("Server instance is no longer available") + return server + + def set_session(self, session: Any) -> None: + """Set the session for this context.""" + self._session = session + + def set_request_id(self, request_id: str) -> None: + """Set the request ID for this context.""" + self._request_id = request_id + + def set_tool_context( + self, + toolContext: ToolContext, + ) -> None: + """Populate the tool context fields for this model context.""" + self.authorization = toolContext.authorization + self.secrets = toolContext.secrets + self.metadata = toolContext.metadata + self.user_id = toolContext.user_id + + async def __aenter__(self) -> Context: + """Enter the context manager and set as current model context.""" + # Set this as current model context + token = _current_model_context.set(self) + self._tokens.append(token) + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit the context manager and clear current model context.""" + # Flush any pending notifications + await self._flush_notifications() + + # Reset context + if self._tokens: + token = self._tokens.pop() + _current_model_context.reset(token) + + # ============ ModelContext protocol properties ============ + @property + def log(self) -> Logs: + """Logging interface for the tool. + + Provides methods for different log levels: + - log.debug(message): Debug-level logging + - log.info(message): Info-level logging + - log.warning(message): Warning-level logging + - log.error(message): Error-level logging + - log.log(level, message): Log at a specific level + + Example: + ```python + await context.log.info("Processing started") + await context.log.error("Something went wrong") + ``` + """ + return self._log + + @property + def progress(self) -> Progress: + """Progress reporting for long-running operations. + + Use this to report progress back to the client during lengthy operations. + + Example: + ```python + await context.progress.report(0.5, total=1.0, message="Halfway done") + ``` + """ + return self._progress + + @property + def resources(self) -> Resources: + """Interface for accessing MCP resources""" + return self._resources + + @property + def tools(self) -> Tools: + """Interface for calling other tools programmatically. + + Allows tools to call other tools within the same session. + + Example: + ```python + result = await context.tools.call_raw("other_tool", {"param": "value"}) + ``` + """ + return self._tools + + @property + def prompts(self) -> Prompts: + """Interface for accessing MCP prompts and templates""" + return self._prompts + + @property + def sampling(self) -> Sampling: + """Create messages using the client's model. + + Allows tools to generate text using the connected model. + + Example: + ```python + response = await context.sampling.create_message( + "Summarize this text: " + text, + temperature=0.7 + ) + ``` + """ + return self._sampling + + @property + def ui(self) -> UI: + """User interaction (elicitation) capabilities. + + Provides methods for interacting with the user, such as eliciting input. + + Example: + ```python + result = await context.ui.elicit( + "Please provide your name", + schema={"type": "object", "properties": {"name": {"type": "string"}}} + ) + ``` + """ + return self._ui + + @property + def notifications(self) -> Notifications: + """ + Interface for sending notifications to the client + such as tool list changes. + + Example: + ```python + await context.notifications.tools.list_changed() + ``` + """ + return self._notifications + + @property + def request_id(self) -> str | None: + """Get the current request ID. + + Returns: + The unique identifier for this MCP request, or None if not available. + """ + return self._request_id + + @property + def session_id(self) -> str | None: + """Get the current session ID. + + Returns: + The unique identifier for this MCP session, or None if not available. + """ + if self._session is None: + return None + return getattr(self._session, "session_id", None) + + # Private helpers + def _check_client_capability(self, capability: ClientCapabilities) -> bool: + """Check if client has a capability.""" + if self._session is None: + return False + return cast(bool, self._session.check_client_capability(capability)) + + def _parse_model_preferences( + self, prefs: ModelPreferences | str | list[str] | None + ) -> ModelPreferences | None: + """Parse model preferences into standard format.""" + if prefs is None: + return None + elif isinstance(prefs, ModelPreferences): + return prefs + elif isinstance(prefs, str): + return ModelPreferences(hints=[ModelHint(name=prefs)]) + elif isinstance(prefs, list): + return ModelPreferences(hints=[ModelHint(name=h) for h in prefs]) + else: + raise ValueError(f"Invalid model preferences type: {type(prefs)}") + + def _try_flush_notifications(self) -> None: + """Try to flush notifications if in async context.""" + try: + loop = asyncio.get_running_loop() + if loop and not loop.is_running(): + return + flush_task = asyncio.create_task(self._flush_notifications()) + flush_task.add_done_callback(lambda _: self._notification_queue.clear()) + except RuntimeError: + # No event loop + pass + + async def _flush_notifications(self) -> None: + """Send all queued notifications.""" + async with _flush_lock: + if not self._notification_queue or self._session is None: + return + + nm = getattr(self.server, "notification_manager", None) + if nm is None: + return + + try: + client_ids = [] + if ( + self._session + and hasattr(self._session, "session_id") + and self._session.session_id + ): + client_ids = [self._session.session_id] + + if "notifications/tools/list_changed" in self._notification_queue: + await nm.notify_tool_list_changed(client_ids) + if "notifications/resources/list_changed" in self._notification_queue: + await nm.notify_resource_list_changed(client_ids) + if "notifications/prompts/list_changed" in self._notification_queue: + pass + + self._notification_queue.clear() + except Exception: + # Don't let notification failures break the request + logging.debug("Failed to send notifications", exc_info=True) + + +# ===================== +# Namespaced adapters +# ===================== +# These thin, per-domain facades (log, progress, resources, tools, prompts, +# sampling, ui, notifications) expose a stable, developer-friendly API on +# Context (e.g., context.log.info(...), context.resources.list()). +# +# They delegate all work to the active MCP session and server managers, keeping +# transport and server-specific details encapsulated in one place. +# This design: +# - avoids leaking MCP internals into the developer surface +# - preserves a cohesive, testable Context API with clear async boundaries +# - allows runtime implementations to evolve without breaking tool code +# +# In short: adapters provide the ergonomics tools rely on, while the underlying +# implementation remains decoupled and replaceable. + + +class Logs(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + async def log( + self, + level: str, + message: str, + logger_name: str | None = None, + extra: dict[str, Any] | None = None, + ) -> None: + session = self._ctx._session + if session is None: + return + level_typed = cast(LoggingLevel, level) + data = {"msg": message, "extra": extra} + await session.send_log_message( + level=level_typed, + data=data, + logger=logger_name, + ) + + async def __call__( + self, + level: str, + message: str, + logger_name: str | None = None, + extra: dict[str, Any] | None = None, + ) -> None: # compatibility shim + await self.log(level, message, logger_name=logger_name, extra=extra) + + async def debug(self, message: str, **kwargs: Any) -> None: + await self.log("debug", message, **kwargs) + + async def info(self, message: str, **kwargs: Any) -> None: + await self.log("info", message, **kwargs) + + async def warning(self, message: str, **kwargs: Any) -> None: + await self.log("warning", message, **kwargs) + + async def error(self, message: str, **kwargs: Any) -> None: + await self.log("error", message, **kwargs) + + +class Progress(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + async def report( + self, progress: float, total: float | None = None, message: str | None = None + ) -> None: + session = self._ctx._session + if session is None: + return + progress_token = None + if hasattr(session, "_request_meta"): + progress_token = getattr(session._request_meta, "progressToken", None) + if progress_token is None: + return + await session.send_progress_notification( + progress_token=progress_token, + progress=progress, + total=total, + message=message, + ) + + +class Resources(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + async def read(self, uri: str) -> list[ResourceContents]: + if self._ctx.server is None: + raise ValueError("Context is not available outside of a request") + result = await self._ctx.server._mcp_read_resource(uri) + return cast(list[ResourceContents], result) + + async def get(self, uri: str) -> ResourceContents: + contents = await self.read(uri) + if not contents: + raise ValueError(f"Resource not found: {uri}") + return contents[0] + + async def list_roots(self) -> list[Root]: + if self._ctx._session is None: + return [] + result = await self._ctx._session.list_roots() + return result.roots if hasattr(result, "roots") else [] + + async def list(self) -> list[Root]: + # Convert Resource objects to Root objects + resources = await self._ctx.server._resource_manager.list_resources() + # Resources have uri and name which map to Root + return [Root(uri=r.uri, name=r.name) for r in resources] + + async def list_templates(self) -> builtins_list[Any]: + templates = await self._ctx.server._resource_manager.list_resource_templates() + return cast(builtins_list[Any], templates) + + +class Tools(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + async def list(self) -> list[Any]: + tools = await self._ctx.server._tool_manager.list_tools() + return cast(list[Any], tools) + + async def call_raw(self, name: str, params: dict[str, Any]) -> ToolCallOutput: + tool = await self._ctx.server._tool_manager.get_tool(name) + tool_context = await self._ctx.server._create_tool_context(tool, self._ctx._session) + # Attach to current model context for the duration of this call + self._ctx.set_tool_context(tool_context) + func = tool.tool + if asyncio.iscoroutinefunction(func): + + async def async_func(**kw: Any) -> Any: + return await func(**kw) + + else: + + async def async_func(**kw: Any) -> Any: + return func(**kw) + + result = await self._ctx.server.executor.run( + func=async_func, + definition=tool.definition, + input_model=tool.input_model, + output_model=tool.output_model, + context=tool_context, + **params, + ) + return cast(ToolCallOutput, result) + + +class Prompts(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + async def list(self) -> list[Any]: + prompts = await self._ctx.server._prompt_manager.list_prompts() + return cast(list[Any], prompts) + + async def get(self, name: str, arguments: dict[str, str] | None = None) -> Any: + return await self._ctx.server._prompt_manager.get_prompt(name, arguments) + + +class Sampling(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + async def create_message( + self, + messages: str | list[str | SamplingMessage], + system_prompt: str | None = None, + include_context: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + model_preferences: ModelPreferences | str | list[str] | None = None, + ) -> Any: + if self._ctx._session is None: + raise ValueError("Session not available") + + # Convert messages to proper format + if isinstance(messages, str): + sampling_messages = [ + SamplingMessage(content=TextContent(text=messages, type="text"), role="user") + ] + elif isinstance(messages, list): + sampling_messages = [] + for m in messages: + if isinstance(m, str): + sampling_messages.append( + SamplingMessage(content=TextContent(text=m, type="text"), role="user") + ) + else: + sampling_messages.append(m) + else: + sampling_messages = messages + + # Parse model preferences + parsed_prefs = self._ctx._parse_model_preferences(model_preferences) + + # Check client capabilities + if not self._ctx._check_client_capability(ClientCapabilities(sampling={})): + raise ValueError("Client does not support sampling") + + result = await self._ctx._session.create_message( + messages=sampling_messages, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + max_tokens=max_tokens or 512, + model_preferences=parsed_prefs, + ) + + return result.content if hasattr(result, "content") else result + + +class UI(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + def _validate_elicitation_schema(self, schema: dict[str, Any]) -> None: + """Validate that the schema conforms to MCP elicitation restrictions.""" + if not isinstance(schema, dict): + raise TypeError("Schema must be a dictionary") + + if schema.get("type") != "object": + raise ValueError("Schema must have type 'object'") + + properties = schema.get("properties", {}) + if not isinstance(properties, dict): + raise TypeError("Schema properties must be a dictionary") + + # Validate each property + for prop_name, prop_schema in properties.items(): + if not isinstance(prop_schema, dict): + raise TypeError(f"Property '{prop_name}' schema must be a dictionary") + + prop_type = prop_schema.get("type") + if prop_type not in ["string", "number", "integer", "boolean"]: + raise ValueError( + f"Property '{prop_name}' has unsupported type '{prop_type}'. Only primitive types are allowed." + ) + + # Validate string formats + if prop_type == "string" and "format" in prop_schema: + allowed_formats = ["email", "uri", "date", "date-time"] + if prop_schema["format"] not in allowed_formats: + raise ValueError( + f"Property '{prop_name}' has unsupported format '{prop_schema['format']}'. Allowed: {allowed_formats}" + ) + + async def elicit( + self, message: str, schema: dict[str, Any] | None = None, timeout: float = 300.0 + ) -> ElicitResult: + if self._ctx._session is None: + raise ValueError("Session not available") + if schema is None: + schema = {"type": "object", "properties": {}} + + # Validate schema conforms to MCP restrictions + self._validate_elicitation_schema(schema) + + result = await self._ctx._session.elicit( + message=message, + requested_schema=schema, + timeout=timeout, + ) + return cast(ElicitResult, result) + + +class _NotificationsTools(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + async def list_changed(self) -> None: + self._ctx._notification_queue.add("notifications/tools/list_changed") + self._ctx._try_flush_notifications() + + +class _NotificationsResources(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + async def list_changed(self) -> None: + self._ctx._notification_queue.add("notifications/resources/list_changed") + self._ctx._try_flush_notifications() + + +class _NotificationsPrompts(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + + async def list_changed(self) -> None: + self._ctx._notification_queue.add("notifications/prompts/list_changed") + self._ctx._try_flush_notifications() + + +class Notifications(_ContextComponent): + def __init__(self, ctx: Context) -> None: + super().__init__(ctx) + self._tools = _NotificationsTools(ctx) + self._resources = _NotificationsResources(ctx) + self._prompts = _NotificationsPrompts(ctx) + + @property + def tools(self) -> _NotificationsTools: + return self._tools + + @property + def resources(self) -> _NotificationsResources: + return self._resources + + @property + def prompts(self) -> _NotificationsPrompts: + return self._prompts + + +def get_current_model_context() -> Context | None: + """Get the current model context if available.""" + return _current_model_context.get() + + +def set_current_model_context(context: Context | None, token: Token | None = None) -> Token: + """Set the current model context and return a token to reset it.""" + if token is not None: + _current_model_context.reset(token) + return token + return _current_model_context.set(context) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/convert.py b/libs/arcade-mcp-server/arcade_mcp_server/convert.py new file mode 100644 index 00000000..e0afda5c --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/convert.py @@ -0,0 +1,370 @@ +import base64 +import json +import logging +from enum import Enum +from typing import Any, get_args, get_origin + +from arcade_core.catalog import MaterializedTool +from arcade_core.schema import ToolDefinition + +from arcade_mcp_server.types import MCPContent, MCPTool, TextContent, ToolAnnotations + +logger = logging.getLogger("arcade.mcp") + + +def create_mcp_tool(tool: MaterializedTool) -> MCPTool | None: + """ + Create an MCP-compatible tool definition from an Arcade tool. + + Args: + tool: An Arcade tool object + + Returns: + An MCP tool definition or None if the tool cannot be converted + """ + try: + # Get the tool name from the definition + tool_name = getattr(tool.definition, "name", "unknown") + fully_qualified_name = getattr(tool.definition, "fully_qualified_name", None) + + # Use fully qualified name for MCP tool name (replacing dots with underscores) + name = fully_qualified_name.replace(".", "_") if fully_qualified_name else tool_name + + description = getattr(tool.definition, "description", "No description available") + + # Check for deprecation + deprecation_msg = getattr(tool.definition, "deprecation_message", None) + if deprecation_msg: + description = f"[DEPRECATED: {deprecation_msg}] {description}" + + # Build input schema using authoritative ToolDefinition when available + try: + if getattr(tool.definition, "input", None): + input_schema = build_input_schema_from_definition(tool.definition) + else: + # Fallback to input_model if definition input is missing + input_schema = _build_input_schema_from_model(tool) + except Exception: + logger.exception("Error while constructing input schema; proceeding with empty schema") + input_schema = {"type": "object", "properties": {}, "additionalProperties": False} + + # Create output schema if available + output_schema = None + try: + if hasattr(tool.definition, "output") and tool.definition.output: + output_def = tool.definition.output + if getattr(output_def, "value_schema", None): + output_schema = _build_value_schema_json(output_def.value_schema) + except Exception: + logger.exception("Error while constructing output schema; omitting output schema") + + requirements = tool.definition.requirements + + # Build annotations using model for stricter typing + annotations = ToolAnnotations( + readOnlyHint=not ( + requirements.authorization or requirements.secrets or requirements.metadata + ), + openWorldHint=requirements.authorization is not None, + ) + + # Instantiate MCPTool model to ensure shape correctness + return MCPTool( + name=name, + title=tool.definition.toolkit.name + "_" + tool_name, + description=str(description), + inputSchema=input_schema, + outputSchema=output_schema if output_schema else None, + annotations=annotations, + ) + + except Exception: + logger.exception( + f"Error creating MCP tool definition for {getattr(tool, 'name', str(tool))}" + ) + try: + # Fallback minimal tool to avoid None in callers + fallback_name = getattr(tool.definition, "fully_qualified_name", "unknown").replace( + ".", "_" + ) + return MCPTool( + name=fallback_name, + title=fallback_name, + description="", + inputSchema={"type": "object", "properties": {}, "additionalProperties": False}, + ) + except Exception: + return None + + +def convert_to_mcp_content(value: Any) -> list[MCPContent]: + """ + Convert a Python value to MCP-compatible content. + """ + if value is None: + return [] + + if isinstance(value, (str, bool, int, float)): + return [TextContent(type="text", text=str(value))] + + if isinstance(value, (dict, list)): + try: + return [TextContent(type="text", text=json.dumps(value, ensure_ascii=False))] + except Exception as exc: + raise ValueError("Failed to serialize value to JSON for MCP content") from exc + + if isinstance(value, (bytes, bytearray, memoryview)): + # Encode bytes as base64 text so it can be transmitted safely + b = bytes(value) + encoded = base64.b64encode(b).decode("ascii") + return [TextContent(type="text", text=encoded)] + + # Default fallback + return [TextContent(type="text", text=str(value))] + + +def convert_content_to_structured_content(value: Any) -> dict[str, Any] | None: + """ + Convert a Python value to MCP-compatible structured content (JSON object). + + According to the MCP specification, structuredContent should be a JSON object + that represents the structured result of the tool call. + + Args: + value: The value to convert to structured content + + Returns: + A dictionary representing the structured content, or None if value is None + """ + if value is None: + return None + + if isinstance(value, dict): + # Already a dictionary - use as-is + return value + elif isinstance(value, list): + # List - wrap in a result object + return {"result": value} + elif isinstance(value, (str, int, float, bool)): + # Primitive types - wrap in a result object + return {"result": value} + else: + # For other types, convert to string and wrap + return {"result": str(value)} + + +def _map_type_to_json_schema_type(val_type: str | None) -> str: + """ + Map Arcade value types to JSON schema types. + + Args: + val_type: The Arcade value type as a string. + + Returns: + The corresponding JSON schema type as a string. + """ + if val_type is None: + return "string" + + mapping: dict[str, str] = { + "string": "string", + "integer": "integer", + "number": "number", + "boolean": "boolean", + "json": "object", + "array": "array", + } + return mapping.get(val_type, "string") + + +def build_input_schema_from_definition(definition: ToolDefinition) -> dict[str, Any]: + """Build a JSON schema object for tool inputs from a ToolDefinition. + + Returns a dict with keys: type, properties, and optional required. + """ + properties: dict[str, Any] = {} + required: list[str] = [] + + if getattr(definition, "input", None) and getattr(definition.input, "parameters", None): + for param in definition.input.parameters: + val_schema = getattr(param, "value_schema", None) + schema: dict[str, Any] = { + "type": _map_type_to_json_schema_type(getattr(val_schema, "val_type", None)), + } + + if getattr(param, "description", None): + schema["description"] = param.description + + if val_schema and getattr(val_schema, "enum", None): + schema["enum"] = list(val_schema.enum) + + if ( + val_schema + and val_schema.val_type == "array" + and getattr(val_schema, "inner_val_type", None) + ): + schema["items"] = {"type": _map_type_to_json_schema_type(val_schema.inner_val_type)} + + if ( + val_schema + and val_schema.val_type == "json" + and getattr(val_schema, "properties", None) + ): + schema["type"] = "object" + schema["properties"] = {} + for prop_name, prop_schema in val_schema.properties.items(): + schema["properties"][prop_name] = { + "type": _map_type_to_json_schema_type( + getattr(prop_schema, "val_type", None) + ), + } + if getattr(prop_schema, "description", None): + schema["properties"][prop_name]["description"] = prop_schema.description + + properties[param.name] = schema + if getattr(param, "required", False): + required.append(param.name) + + input_schema: dict[str, Any] = { + "type": "object", + "properties": properties, + "additionalProperties": False, + } + if required: + input_schema["required"] = required + return input_schema + + +def _build_input_schema_from_model(tool: MaterializedTool) -> dict[str, Any]: + """Build input schema from a tool's input_model as a fallback.""" + properties: dict[str, Any] = {} + required: list[str] = [] + + context_param_name = None + tool_input = getattr(tool.definition, "input", None) + if tool_input is not None: + context_param_name = getattr(tool_input, "tool_context_parameter_name", None) + + if ( + hasattr(tool, "input_model") + and tool.input_model is not None + and hasattr(tool.input_model, "model_fields") + ): + for field_name, field in tool.input_model.model_fields.items(): + if field_name == context_param_name: + continue + + field_type = getattr(field, "annotation", None) + field_type_name = "string" # default + + if field_type is int: + field_type_name = "integer" + elif field_type is float: + field_type_name = "number" + elif field_type is bool: + field_type_name = "boolean" + elif field_type is list or (getattr(field_type, "__origin__", None) is list): + field_type_name = "array" + elif field_type is dict or (getattr(field_type, "__origin__", None) is dict): + field_type_name = "object" + + field_description = getattr(field, "description", None) or f"Parameter: {field_name}" + + param_def: dict[str, Any] = { + "type": field_type_name, + "description": field_description, + } + + # Enum support: Enum classes or typing.Annotated[...] with Enum + enum_type = None + ann = getattr(field, "annotation", None) + if ann is not None: + origin = get_origin(ann) + args = get_args(ann) + # typing.Annotated[Enum, ...] + if origin is not None and args: + for arg in args: + if isinstance(arg, type) and issubclass(arg, Enum): + enum_type = arg + break + elif isinstance(ann, type) and issubclass(ann, Enum): + enum_type = ann + if enum_type is not None: + param_def["enum"] = [e.value for e in enum_type] + + # Literal[...] support for enum-like constraints + if ann is not None and get_origin(ann) is None: + pass # no-op, handled above + elif ann is not None and get_origin(ann) is Any: + pass + else: + if get_origin(ann) is None: + ... + + # Attempt to infer inner list item types for list[T] + if field_type_name == "array": + inner = None + if get_origin(field_type) is list and get_args(field_type): + inner = get_args(field_type)[0] + if inner is int: + param_def["items"] = {"type": "integer"} + elif inner is float: + param_def["items"] = {"type": "number"} + elif inner is bool: + param_def["items"] = {"type": "boolean"} + elif inner is str: + param_def["items"] = {"type": "string"} + + properties[field_name] = param_def + + # Required detection with multiple strategies + is_required_attr = getattr(field, "is_required", None) + try: + if callable(is_required_attr): + if is_required_attr(): + required.append(field_name) + elif isinstance(is_required_attr, bool) and is_required_attr: + required.append(field_name) + else: + has_default = getattr(field, "default", None) is not None + has_factory = getattr(field, "default_factory", None) is not None + if not (has_default or has_factory): + required.append(field_name) + except Exception: + logger.debug( + f"Could not determine if field {field_name} is required, assuming optional" + ) + + input_schema: dict[str, Any] = { + "type": "object", + "properties": properties, + "additionalProperties": False, + } + if required: + input_schema["required"] = required + return input_schema + + +def _build_value_schema_json(value_schema: Any) -> dict[str, Any]: + """Map a ValueSchema to a JSON schema fragment for outputSchema.""" + schema: dict[str, Any] = { + "type": _map_type_to_json_schema_type(getattr(value_schema, "val_type", None)), + } + if getattr(value_schema, "enum", None): + schema["enum"] = list(value_schema.enum) + if getattr(value_schema, "val_type", None) == "array" and getattr( + value_schema, "inner_val_type", None + ): + schema["items"] = {"type": _map_type_to_json_schema_type(value_schema.inner_val_type)} + if getattr(value_schema, "val_type", None) == "json" and getattr( + value_schema, "properties", None + ): + schema["type"] = "object" + schema["properties"] = {} + for prop_name, prop_schema in value_schema.properties.items(): + schema["properties"][prop_name] = { + "type": _map_type_to_json_schema_type(getattr(prop_schema, "val_type", None)) + } + if getattr(prop_schema, "description", None): + schema["properties"][prop_name]["description"] = prop_schema.description + return schema diff --git a/libs/arcade-mcp-server/arcade_mcp_server/exceptions.py b/libs/arcade-mcp-server/arcade_mcp_server/exceptions.py new file mode 100644 index 00000000..ac8462da --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/exceptions.py @@ -0,0 +1,93 @@ +""" +MCP Exception Hierarchy + +Provides domain-specific exceptions for better error handling and debugging. +""" + +from arcade_core.errors import ToolRuntimeError # Re-export for convenience + +__all__ = [ + # Re-exports + "ToolRuntimeError", + # Base exceptions + "MCPError", + "MCPRuntimeError", + # Server exceptions + "ServerError", + "SessionError", + "RequestError", + "ResponseError", + "ServerRequestError", + "LifespanError", + # Context exceptions + "MCPContextError", + "NotFoundError", + "AuthorizationError", + "PromptError", + "ResourceError", + "TransportError", + "ProtocolError", +] + + +class MCPError(Exception): + """Base error for all MCP-related exceptions.""" + + +class MCPRuntimeError(MCPError): + """Runtime error for all MCP-related exceptions.""" + + +class ServerError(MCPRuntimeError): + """Error in server operations.""" + + +class SessionError(ServerError): + """Error in session management""" + + +class RequestError(ServerError): + """Error in request processing from client to server""" + + +class ResponseError(ServerError): + """Error in request processing from server -> client""" + + +class ServerRequestError(RequestError): + """Error in sending request from server -> client initiated by the server""" + + +class LifespanError(ServerError): + """Error in lifespan management.""" + + +class MCPContextError(MCPError): + """Error in context management.""" + + +class NotFoundError(MCPContextError): + """Requested entity not found.""" + + +class AuthorizationError(MCPContextError): + """Authorization failure.""" + + +class PromptError(MCPContextError): + """Error in prompt management.""" + + +class ResourceError(MCPContextError): + """Error in resource management.""" + + +# Transport and Protocol Errors + + +class TransportError(MCPRuntimeError): + """Error in transport layer (stdio, HTTP, etc).""" + + +class ProtocolError(MCPRuntimeError): + """Error in MCP protocol handling.""" diff --git a/libs/arcade-mcp-server/arcade_mcp_server/fastapi/__init__.py b/libs/arcade-mcp-server/arcade_mcp_server/fastapi/__init__.py new file mode 100644 index 00000000..84c9a008 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/fastapi/__init__.py @@ -0,0 +1 @@ +"""FastAPI integration for MCP server.""" diff --git a/libs/arcade-mcp-server/arcade_mcp_server/fastapi/routes.py b/libs/arcade-mcp-server/arcade_mcp_server/fastapi/routes.py new file mode 100644 index 00000000..3a357e7f --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/fastapi/routes.py @@ -0,0 +1,335 @@ +""" +FastAPI OpenAPI Documentation Routes + +This module provides FastAPI route definitions solely for generating OpenAPI/Swagger +documentation. These routes describe the HTTP endpoints and their request/response +schemas but do not contain actual implementation logic. + +The routes documented here are: +- POST /mcp - Send JSON-RPC messages +- GET /mcp - Establish Server-Sent Events (SSE) stream +- DELETE /mcp - Terminate active session + +Note: These are documentation-only routes. The actual protocol implementation +is handled separately through the underlying transport layer. +""" + +from typing import Any, Optional + +from fastapi import APIRouter, Header, HTTPException, Request, status +from pydantic import BaseModel, Field + +from arcade_mcp_server.transports.http_streamable import MCP_SESSION_ID_HEADER +from arcade_mcp_server.types import JSONRPC_VERSION, LATEST_PROTOCOL_VERSION + + +# Pydantic models for OpenAPI documentation +class MCPRequest(BaseModel): + """JSON-RPC request message for MCP protocol.""" + + jsonrpc: str = Field(default=JSONRPC_VERSION, description="JSON-RPC version") + method: str = Field(..., description="Method name to invoke") + params: Optional[dict[str, Any]] = Field(None, description="Method parameters") + id: Optional[str | int] = Field(None, description="Request ID for correlating responses") + + model_config = { + "json_schema_extra": { + "examples": [ + { + "jsonrpc": JSONRPC_VERSION, + "method": "initialize", + "params": { + "protocolVersion": LATEST_PROTOCOL_VERSION, + "capabilities": {}, + "clientInfo": {"name": "example-client", "version": "1.0.0"}, + }, + "id": 1, + } + ] + } + } + + +class MCPResponse(BaseModel): + """JSON-RPC response message for MCP protocol.""" + + jsonrpc: str = Field(default=JSONRPC_VERSION, description="JSON-RPC version") + result: Optional[dict[str, Any]] = Field(None, description="Successful response data") + error: Optional[dict[str, Any]] = Field(None, description="Error information if request failed") + id: str | int = Field(..., description="Request ID this response corresponds to") + + model_config = { + "json_schema_extra": { + "examples": [ + { + "jsonrpc": JSONRPC_VERSION, + "result": { + "protocolVersion": LATEST_PROTOCOL_VERSION, + "capabilities": {}, + "serverInfo": {"name": "arcade-server", "version": "1.0.0"}, + }, + "id": 1, + } + ] + } + } + + +class MCPError(BaseModel): + """Error response for MCP protocol.""" + + code: int = Field(..., description="Error code") + message: str = Field(..., description="Human-readable error message") + data: Optional[Any] = Field(None, description="Additional error data") + + +def get_openapi_routes() -> list[dict]: + """Get OpenAPI route definitions for MCP endpoints.""" + return [ + { + "path": "/mcp/", + "post": { + "tags": ["MCP Protocol"], + "summary": "Send MCP message", + "description": "Send a JSON-RPC message to the MCP server. This endpoint handles:\n" + "- Method requests (with id) - returns a JSON response\n" + "- Notifications (without id) - returns 202 Accepted\n\n" + "For SSE mode, set Accept: text/event-stream header.\n" + "For JSON mode, set Accept: application/json header.", + "operationId": "send_mcp_message", + "parameters": [ + { + "name": "accept", + "in": "header", + "required": False, + "schema": {"type": "string"}, + }, + { + "name": "content-type", + "in": "header", + "required": False, + "schema": {"type": "string"}, + }, + { + "name": MCP_SESSION_ID_HEADER, + "in": "header", + "required": False, + "schema": {"type": "string"}, + }, + ], + "requestBody": { + "content": { + "application/json": {"schema": {"$ref": "#/components/schemas/MCPRequest"}} + }, + "required": True, + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/MCPResponse"} + } + }, + }, + "202": {"description": "Notification accepted (no response expected)"}, + "400": {"description": "Bad Request - Invalid JSON or missing required fields"}, + "404": {"description": "Not Found - Invalid or expired session ID"}, + "406": { + "description": "Not Acceptable - Client must accept required content types" + }, + "415": { + "description": "Unsupported Media Type - Content-Type must be application/json" + }, + "500": {"description": "Internal Server Error"}, + }, + }, + "get": { + "tags": ["MCP Protocol"], + "summary": "Establish SSE stream", + "description": "Establish a Server-Sent Events (SSE) stream for receiving server-initiated messages.\n\n" + "Only one SSE stream is allowed per session. The stream will remain open until:\n" + "- The client closes the connection\n" + "- The session is terminated\n" + "- An error occurs\n\n" + "Requires Accept: text/event-stream header.", + "operationId": "establish_sse_stream", + "parameters": [ + { + "name": "accept", + "in": "header", + "required": False, + "schema": {"type": "string"}, + }, + { + "name": MCP_SESSION_ID_HEADER, + "in": "header", + "required": False, + "schema": {"type": "string"}, + }, + { + "name": "Last-Event-ID", + "in": "header", + "required": False, + "schema": {"type": "string"}, + }, + ], + "responses": { + "200": { + "description": "SSE stream established", + "content": { + "text/event-stream": {"example": 'data: {"jsonrpc":"2.0",...}\\n\\n'} + }, + }, + "409": {"description": "Conflict - Only one SSE stream allowed per session"}, + "400": {"description": "Bad Request - Invalid JSON or missing required fields"}, + "404": {"description": "Not Found - Invalid or expired session ID"}, + "406": { + "description": "Not Acceptable - Client must accept required content types" + }, + "500": {"description": "Internal Server Error"}, + }, + }, + "delete": { + "tags": ["MCP Protocol"], + "summary": "Terminate session", + "description": "Terminate the current MCP session. This will:\n" + "- Close all active streams\n" + "- Clean up session resources\n" + "- Return 200 OK on successful termination\n\n" + "Only available in stateful mode (when session IDs are used).", + "operationId": "terminate_mcp_session", + "parameters": [ + { + "name": MCP_SESSION_ID_HEADER, + "in": "header", + "required": False, + "schema": {"type": "string"}, + } + ], + "responses": { + "200": {"description": "Session terminated successfully"}, + "405": { + "description": "Method Not Allowed - Session termination not supported in stateless mode" + }, + "400": {"description": "Bad Request - Invalid JSON or missing required fields"}, + "404": {"description": "Not Found - Invalid or expired session ID"}, + "500": {"description": "Internal Server Error"}, + }, + }, + } + ] + + +def create_mcp_router() -> APIRouter: + """Create FastAPI router with MCP endpoint documentation.""" + router = APIRouter( + prefix="", + tags=["MCP Protocol"], + responses={ + 400: {"description": "Bad Request - Invalid JSON or missing required fields"}, + 404: {"description": "Not Found - Invalid or expired session ID"}, + 406: {"description": "Not Acceptable - Client must accept required content types"}, + 415: {"description": "Unsupported Media Type - Content-Type must be application/json"}, + 500: {"description": "Internal Server Error"}, + }, + ) + + @router.post( + "/", + response_model=MCPResponse, + summary="Send MCP message", + description=""" + Send a JSON-RPC message to the MCP server. This endpoint handles: + - Method requests (with id) - returns a JSON response + - Notifications (without id) - returns 202 Accepted + + For SSE mode, set Accept: text/event-stream header. + For JSON mode, set Accept: application/json header. + """, + responses={ + 200: {"description": "Successful response", "model": MCPResponse}, + 202: {"description": "Notification accepted (no response expected)"}, + }, + ) + async def send_message( + request: Request, + body: MCPRequest, + accept: str = Header(None), + content_type: str = Header(None), + mcp_session_id: Optional[str] = Header(None, alias=MCP_SESSION_ID_HEADER), + ) -> None: + """ + Documentation-only endpoint definition. + """ + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Documentation endpoint only", + ) + + @router.get( + "/", + summary="Establish SSE stream", + description=""" + Establish a Server-Sent Events (SSE) stream for receiving server-initiated messages. + + Only one SSE stream is allowed per session. The stream will remain open until: + - The client closes the connection + - The session is terminated + - An error occurs + + Requires Accept: text/event-stream header. + """, + responses={ + 200: { + "description": "SSE stream established", + "content": {"text/event-stream": {"example": 'data: {"jsonrpc":"2.0",...}\\n\\n'}}, + }, + 409: {"description": "Conflict - Only one SSE stream allowed per session"}, + }, + ) + async def establish_sse( + request: Request, + accept: str = Header(None), + mcp_session_id: Optional[str] = Header(None, alias=MCP_SESSION_ID_HEADER), + last_event_id: Optional[str] = Header(None, alias="Last-Event-ID"), + ) -> None: + """ + Documentation-only endpoint definition. + """ + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Documentation endpoint only", + ) + + @router.delete( + "/", + summary="Terminate session", + description=""" + Terminate the current MCP session. This will: + - Close all active streams + - Clean up session resources + - Return 200 OK on successful termination + + Only available in stateful mode (when session IDs are used). + """, + responses={ + 200: {"description": "Session terminated successfully"}, + 405: { + "description": "Method Not Allowed - Session termination not supported in stateless mode" + }, + }, + ) + async def terminate_session( + request: Request, + mcp_session_id: Optional[str] = Header(None, alias=MCP_SESSION_ID_HEADER), + ) -> None: + """ + Documentation-only endpoint definition. + """ + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Documentation endpoint only", + ) + + return router diff --git a/libs/arcade-mcp-server/arcade_mcp_server/lifespan.py b/libs/arcade-mcp-server/arcade_mcp_server/lifespan.py new file mode 100644 index 00000000..0770289f --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/lifespan.py @@ -0,0 +1,161 @@ +"""Lifespan management for MCP server. + +Provides a clean interface for managing server lifecycle with proper +resource initialization and cleanup. +""" + +import asyncio +import logging +from collections.abc import AsyncIterator +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, Callable + +from arcade_mcp_server.exceptions import LifespanError + +logger = logging.getLogger("arcade.mcp") + +LifespanResult = dict[str, Any] + + +@asynccontextmanager +async def default_lifespan(server: Any) -> AsyncIterator[LifespanResult]: + """Default lifespan that does basic startup/shutdown logging.""" + logger.info(f"Starting MCP server: {getattr(server, 'name', 'unknown')}") + + # Startup + try: + yield {} + finally: + # Shutdown + logger.info(f"Stopping MCP server: {getattr(server, 'name', 'unknown')}") + + +class LifespanManager: + """Manages server lifecycle with proper resource management. + + This class wraps a lifespan context manager and provides a clean + interface for server startup and shutdown operations. + """ + + def __init__( + self, + server: Any, + lifespan: Callable[[Any], AbstractAsyncContextManager[LifespanResult]] | None = None, + ): + """Initialize lifespan manager. + + Args: + server: The server instance + lifespan: Optional custom lifespan function + """ + self.server = server + self.lifespan = lifespan or default_lifespan + self._stack: Any | None = None + self._context: LifespanResult | None = None + self._started = False + + async def startup(self) -> LifespanResult: + """Run startup phase of lifespan.""" + if self._started: + raise LifespanError("Lifespan already started") + + self._started = True + + self._stack = asyncio.create_task(self._run_lifespan()) + + # Wait for startup to complete + while self._context is None and not self._stack.done(): + await asyncio.sleep(0.01) + + if self._stack.done() and self._context is None: + # Lifespan failed during startup + try: + await self._stack + except Exception as e: + raise LifespanError(f"Lifespan startup failed: {e}") from e + + if self._context is None: + raise LifespanError("Lifespan startup failed") + return self._context + + async def shutdown(self) -> None: + """Run shutdown phase of lifespan.""" + if not self._started: + return + + self._started = False + + if self._stack and not self._stack.done(): + # Trigger shutdown by cancelling the lifespan task + self._stack.cancel() + try: + await self._stack + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Error during lifespan shutdown") + + self._context = None + self._stack = None + + async def _run_lifespan(self) -> None: + """Run the lifespan context manager.""" + try: + async with self.lifespan(self.server) as context: + self._context = context + # Keep running until cancelled + while True: + await asyncio.sleep(1) + except asyncio.CancelledError: + # Normal shutdown + self._context = None + raise + except Exception: + # Abnormal shutdown + self._context = None + logger.exception("Error in lifespan") + raise + + async def __aenter__(self) -> LifespanResult: + """Async context manager entry.""" + return await self.startup() + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Async context manager exit.""" + await self.shutdown() + + +def compose_lifespans( + *lifespans: Callable[[Any], AbstractAsyncContextManager[LifespanResult]], +) -> Callable[[Any], AbstractAsyncContextManager[LifespanResult]]: + """Compose multiple lifespan functions into one. + + Each lifespan's context is merged into a single dict. + Lifespans are started in order and stopped in reverse order. + """ + + @asynccontextmanager + async def composed(server: Any) -> AsyncIterator[LifespanResult]: + contexts: list[tuple[AbstractAsyncContextManager[LifespanResult], LifespanResult]] = [] + merged: LifespanResult = {} + + # Start lifespans in order (sequential for compatibility) + for lifespan in lifespans: + ctx_mgr = lifespan(server) + context = await ctx_mgr.__aenter__() + contexts.append((ctx_mgr, context)) + + # Merge context if it's a dict + merged.update(context) + + try: + yield merged + finally: + # Stop lifespans in reverse order + for ctx_mgr, _ in reversed(contexts): + try: + await ctx_mgr.__aexit__(None, None, None) + except Exception: + logger.exception("Error stopping lifespan") + + return composed diff --git a/libs/arcade-mcp-server/arcade_mcp_server/managers/__init__.py b/libs/arcade-mcp-server/arcade_mcp_server/managers/__init__.py new file mode 100644 index 00000000..914cd3c5 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/managers/__init__.py @@ -0,0 +1,7 @@ +"""MCP Component Managers.""" + +from arcade_mcp_server.managers.prompt import PromptManager +from arcade_mcp_server.managers.resource import ResourceManager +from arcade_mcp_server.managers.tool import ToolManager + +__all__ = ["PromptManager", "ResourceManager", "ToolManager"] diff --git a/libs/arcade-mcp-server/arcade_mcp_server/managers/base.py b/libs/arcade-mcp-server/arcade_mcp_server/managers/base.py new file mode 100644 index 00000000..7cf5e588 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/managers/base.py @@ -0,0 +1,146 @@ +""" +Base Async Managers + +Provides async-safe registries with RW locking, versioning, and subscriptions. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable, Iterable +from types import TracebackType +from typing import Any, Generic, TypeVar, cast + +K = TypeVar("K") +V = TypeVar("V") + + +class AsyncRWLock: + """Simple async RW lock allowing concurrent readers and exclusive writers.""" + + def __init__(self) -> None: + self._reader_count = 0 + self._reader_lock = asyncio.Lock() + self._gate = asyncio.Lock() + + async def read(self) -> Any: + class _ReadCtx: + async def __aenter__(_self) -> None: + async with self._reader_lock: + self._reader_count += 1 + if self._reader_count == 1: + await self._gate.acquire() + + async def __aexit__( + _self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + async with self._reader_lock: + self._reader_count -= 1 + if self._reader_count == 0: + self._gate.release() + + return _ReadCtx() + + async def write(self) -> Any: + class _WriteCtx: + async def __aenter__(_self) -> None: + await self._gate.acquire() + + async def __aexit__( + _self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self._gate.release() + + return _WriteCtx() + + +class AsyncRegistry(Generic[K, V]): + """Async-safe registry with deterministic listing and change notifications.""" + + def __init__(self, component: str) -> None: + self.component = component + self._items: dict[K, V] = {} + self._lock = AsyncRWLock() + self._version = 0 + self._subscribers: list[Callable[[str, K | None, V | None, V | None, int], None]] = [] + + def subscribe(self, fn: Callable[[str, K | None, V | None, V | None, int], None]) -> None: + self._subscribers.append(fn) + + async def get(self, key: K) -> V: + async with await self._lock.read(): + if key not in self._items: + raise KeyError(f"{self.component.title()} '{key}' not found") + return self._items[key] + + async def keys(self) -> list[K]: + async with await self._lock.read(): + return sorted(self._items.keys(), key=lambda k: str(k)) + + async def list(self) -> list[V]: + async with await self._lock.read(): + return [self._items[k] for k in sorted(self._items.keys(), key=lambda k: str(k))] + + async def upsert(self, key: K, value: V) -> None: + async with await self._lock.write(): + old = self._items.get(key) + self._items[key] = value + self._version += 1 + version = self._version + for fn in self._subscribers: + fn("upsert", key, old, value, version) + + async def remove(self, key: K) -> V: + async with await self._lock.write(): + if key not in self._items: + raise KeyError(f"{self.component.title()} '{key}' not found") + old = self._items.pop(key) + self._version += 1 + version = self._version + for fn in self._subscribers: + fn("remove", key, old, None, version) + return old + + async def bulk_load(self, items: Iterable[tuple[K, V]]) -> None: + async with await self._lock.write(): + for k, v in items: + self._items[k] = v + self._version += 1 + version = self._version + for fn in self._subscribers: + fn("bulk_load", cast(K, None), None, None, version) + + @property + def version(self) -> int: + return self._version + + +class ComponentManager(Generic[K, V]): + """Base component manager with lifecycle and async registry.""" + + def __init__(self, component: str) -> None: + self.registry: AsyncRegistry[K, V] = AsyncRegistry(component) + self._started = False + + async def start(self) -> None: + if self._started: + return + self._started = True + + async def stop(self) -> None: + if not self._started: + return + self._started = False + + def subscribe(self, fn: Callable[[str, K | None, V | None, V | None, int], None]) -> None: + self.registry.subscribe(fn) + + @property + def version(self) -> int: + return self.registry.version diff --git a/libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py b/libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py new file mode 100644 index 00000000..b0c6d024 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/managers/prompt.py @@ -0,0 +1,122 @@ +""" +Prompt Manager + +Async-safe prompts with registry-based storage and deterministic listing. +""" + +from __future__ import annotations + +import logging +from typing import Callable + +from arcade_mcp_server.exceptions import NotFoundError, PromptError +from arcade_mcp_server.managers.base import ComponentManager +from arcade_mcp_server.types import GetPromptResult, Prompt, PromptMessage + +logger = logging.getLogger("arcade.mcp.managers.prompt") + + +class PromptHandler: + """Handler for generating prompt messages.""" + + def __init__( + self, + prompt: Prompt, + handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None, + ) -> None: + self.prompt = prompt + self.handler = handler or self._default_handler + + def __eq__(self, other: object) -> bool: # pragma: no cover - simple comparison + if not isinstance(other, PromptHandler): + return False + return self.prompt == other.prompt and self.handler == other.handler + + def _default_handler(self, arguments: dict[str, str]) -> list[PromptMessage]: + return [ + PromptMessage( + role="user", + content={ + "type": "text", + "text": self.prompt.description or f"Prompt: {self.prompt.name}", + }, + ) + ] + + async def get_messages(self, arguments: dict[str, str] | None = None) -> list[PromptMessage]: + args = arguments or {} + + # Validate required arguments + if self.prompt.arguments: + for arg in self.prompt.arguments: + if arg.required and arg.name not in args: + raise PromptError(f"Required argument '{arg.name}' not provided") + + result = self.handler(args) + if hasattr(result, "__await__"): + result = await result + + return result + + +class PromptManager(ComponentManager[str, PromptHandler]): + """ + Manages prompts for the MCP server. + """ + + def __init__(self) -> None: + super().__init__("prompt") + + async def list_prompts(self) -> list[Prompt]: + handlers = await self.registry.list() + return [h.prompt for h in handlers] + + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None + ) -> GetPromptResult: + try: + handler = await self.registry.get(name) + except KeyError: + raise NotFoundError(f"Prompt '{name}' not found") + + try: + messages = await handler.get_messages(arguments) + return GetPromptResult( + description=handler.prompt.description, + messages=messages, + ) + except Exception as e: + if isinstance(e, PromptError): + raise + raise PromptError(f"Error generating prompt: {e}") from e + + async def add_prompt( + self, + prompt: Prompt, + handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None, + ) -> None: + prompt_handler = PromptHandler(prompt, handler) + await self.registry.upsert(prompt.name, prompt_handler) + + async def remove_prompt(self, name: str) -> Prompt: + try: + handler = await self.registry.remove(name) + except KeyError: + raise NotFoundError(f"Prompt '{name}' not found") + return handler.prompt + + async def update_prompt( + self, + name: str, + prompt: Prompt, + handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None, + ) -> Prompt: + # Ensure exists + try: + _ = await self.registry.get(name) + except KeyError: + raise NotFoundError(f"Prompt '{name}' not found") + + prompt_handler = PromptHandler(prompt, handler) + await self.registry.upsert(prompt.name, prompt_handler) + return prompt diff --git a/libs/arcade-mcp-server/arcade_mcp_server/managers/resource.py b/libs/arcade-mcp-server/arcade_mcp_server/managers/resource.py new file mode 100644 index 00000000..8cfda7a0 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/managers/resource.py @@ -0,0 +1,102 @@ +""" +Resource Manager + +Async-safe resources with registry-based storage and deterministic listing. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable + +from arcade_mcp_server.exceptions import NotFoundError +from arcade_mcp_server.managers.base import ComponentManager +from arcade_mcp_server.types import ( + BlobResourceContents, + Resource, + ResourceContents, + ResourceTemplate, + TextResourceContents, +) + +logger = logging.getLogger("arcade.mcp.managers.resource") + + +class ResourceManager(ComponentManager[str, Resource]): + """ + Manages resources for the MCP server. + """ + + def __init__( + self, + ) -> None: + super().__init__("resource") + self._templates: dict[str, ResourceTemplate] = {} + self._resource_handlers: dict[str, Callable[[str], Any]] = {} + + async def list_resources(self) -> list[Resource]: + return await self.registry.list() + + async def list_resource_templates(self) -> list[ResourceTemplate]: + return [self._templates[k] for k in sorted(self._templates.keys())] + + async def read_resource(self, uri: str) -> list[ResourceContents]: + handler = self._resource_handlers.get(uri) + if handler: + result = handler(uri) + if hasattr(result, "__await__"): + result = await result + if isinstance(result, str): + return [TextResourceContents(uri=uri, text=result)] + elif isinstance(result, dict): + if "text" in result: + return [TextResourceContents(uri=uri, text=result["text"])] + if "blob" in result: + return [BlobResourceContents(uri=uri, blob=result["blob"])] + return [ResourceContents(uri=uri)] + elif isinstance(result, list): + return result + else: + return [TextResourceContents(uri=uri, text=str(result))] + + try: + _ = await self.registry.get(uri) + except KeyError as _e: + raise NotFoundError(f"Resource '{uri}' not found") + + return [TextResourceContents(uri=uri, text="")] # static placeholder + + async def add_resource( + self, resource: Resource, handler: Callable[[str], Any] | None = None + ) -> None: + await self.registry.upsert(resource.uri, resource) + if handler: + self._resource_handlers[resource.uri] = handler + + async def remove_resource(self, uri: str) -> Resource: + try: + removed = await self.registry.remove(uri) + except KeyError as _e: + raise NotFoundError(f"Resource '{uri}' not found") + self._resource_handlers.pop(uri, None) + return removed + + async def update_resource( + self, uri: str, resource: Resource, handler: Callable[[str], Any] | None = None + ) -> Resource: + try: + await self.registry.remove(uri) + except KeyError: + raise NotFoundError(f"Resource '{uri}' not found") + await self.registry.upsert(resource.uri, resource) + if handler: + self._resource_handlers[resource.uri] = handler + return resource + + async def add_template(self, template: ResourceTemplate) -> None: + self._templates[template.uriTemplate] = template + + async def remove_template(self, uri_template: str) -> ResourceTemplate: + if uri_template not in self._templates: + raise NotFoundError(f"Resource template '{uri_template}' not found") + return self._templates.pop(uri_template) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/managers/tool.py b/libs/arcade-mcp-server/arcade_mcp_server/managers/tool.py new file mode 100644 index 00000000..9ab6cd60 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/managers/tool.py @@ -0,0 +1,94 @@ +""" +Tool Manager + +Async-safe tool management with pre-converted MCPTool DTOs and executable materials. +""" + +from __future__ import annotations + +from typing import TypedDict + +from arcade_core.catalog import MaterializedTool, ToolCatalog + +from arcade_mcp_server.convert import build_input_schema_from_definition +from arcade_mcp_server.exceptions import NotFoundError +from arcade_mcp_server.managers.base import ComponentManager +from arcade_mcp_server.types import MCPTool + + +class ManagedTool(TypedDict): + dto: MCPTool + materialized: MaterializedTool + + +Key = str # fully qualified tool name + + +class ToolManager(ComponentManager[Key, ManagedTool]): + """Tool manager storing both DTO and materialized artifacts.""" + + def __init__(self) -> None: + super().__init__("tool") + self._sanitized_to_key: dict[str, str] = {} + + @staticmethod + def _sanitize_name(name: str) -> str: + return name.replace(".", "_") + + def _to_dto(self, tool: MaterializedTool) -> MCPTool: + return MCPTool( + name=self._sanitize_name(tool.definition.fully_qualified_name), + title=f"{tool.definition.toolkit.name}_{tool.definition.name}", + description=tool.definition.description, + inputSchema=build_input_schema_from_definition(tool.definition), + ) + + async def load_from_catalog(self, catalog: ToolCatalog) -> None: + pairs: list[tuple[Key, ManagedTool]] = [] + for t in catalog: + fq = t.definition.fully_qualified_name + pairs.append((fq, {"dto": self._to_dto(t), "materialized": t})) + self._sanitized_to_key[self._sanitize_name(fq)] = fq + await self.registry.bulk_load(pairs) + + async def list_tools(self) -> list[MCPTool]: + records = await self.registry.list() + return [r["dto"] for r in records] + + async def get_tool(self, name: str) -> MaterializedTool: + # Try exact key first (dotted FQN) + try: + rec = await self.registry.get(name) + return rec["materialized"] + except KeyError: + # Fallback: resolve sanitized name + key = self._sanitized_to_key.get(name) + if key is None: + raise NotFoundError(f"Tool {name} not found") + rec = await self.registry.get(key) + return rec["materialized"] + + async def add_tool(self, tool: MaterializedTool) -> None: + key = tool.definition.fully_qualified_name + await self.registry.upsert(key, {"dto": self._to_dto(tool), "materialized": tool}) + self._sanitized_to_key[self._sanitize_name(key)] = key + + async def update_tool(self, tool: MaterializedTool) -> None: + key = tool.definition.fully_qualified_name + await self.registry.upsert(key, {"dto": self._to_dto(tool), "materialized": tool}) + self._sanitized_to_key[self._sanitize_name(key)] = key + + async def remove_tool(self, name: str) -> MaterializedTool: + # Accept either exact or sanitized name + key = name + if key not in (await self.registry.keys()): + key = self._sanitized_to_key.get(name, name) + try: + rec = await self.registry.remove(key) + except KeyError as _e: + raise NotFoundError(f"Tool {name} not found") + # Clean mapping if present + sanitized = self._sanitize_name(key) + if sanitized in self._sanitized_to_key: + del self._sanitized_to_key[sanitized] + return rec["materialized"] diff --git a/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py b/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py new file mode 100644 index 00000000..922de5c4 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py @@ -0,0 +1,316 @@ +""" +MCPApp - A FastAPI-like interface for MCP servers. + +Provides a clean, minimal API for building MCP servers with lazy initialization. +""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Any, Callable, Literal, ParamSpec, TypeVar + +from arcade_core.catalog import MaterializedTool, ToolCatalog, ToolDefinitionError +from arcade_tdk.auth import ToolAuthorization +from arcade_tdk.error_adapters import ErrorAdapter +from arcade_tdk.tool import tool as tool_decorator +from dotenv import load_dotenv +from loguru import logger + +from arcade_mcp_server.exceptions import ServerError +from arcade_mcp_server.server import MCPServer +from arcade_mcp_server.types import Prompt, PromptMessage, Resource +from arcade_mcp_server.worker import run_arcade_mcp + +P = ParamSpec("P") +T = TypeVar("T") + +TransportType = Literal["http", "stdio"] + + +class MCPApp: + """ + A FastAPI-like interface for building MCP servers. + + The app collects tools and configuration, then lazily creates the server + and transport when run() is called. + + Example: + ```python + from arcade_mcp_server import MCPApp + + app = MCPApp(name="my_server", version="1.0.0") + + @app.tool + def greet(name: str) -> str: + return f"Hello, {name}!" + + # Runtime CRUD once you have a server bound to the app: + # app.server = mcp_server + # await app.tools.add(materialized_tool) + # await app.prompts.add(prompt, handler) + # await app.resources.add(resource) + + app.run(host="127.0.0.1", port=7777) + ``` + """ + + def __init__( + self, + name: str = "ArcadeMCP", + version: str = "1.0.0dev", + title: str | None = None, + instructions: str | None = None, + log_level: str = "INFO", + transport: TransportType = "http", + host: str = "127.0.0.1", + port: int = 7777, + reload: bool = False, + **kwargs: Any, + ): + """ + Initialize the MCP app. + + Args: + name: Server name + version: Server version + title: Server title for display + instructions: Server instructions + log_level: Logging level (DEBUG, INFO, WARNING, ERROR) + transport: Transport type ("http") + host: Host for transport + port: Port for transport + reload: Enable auto-reload for development + **kwargs: Additional server configuration + """ + self.name = name + self.version = version + self.title = title or name + self.instructions = instructions + self.log_level = log_level + self.server_kwargs = kwargs + self.transport = transport + self.host = host + self.port = port + self.reload = reload + + # Tool collection (build-time) + self._catalog = ToolCatalog() + self._toolkit_name = name + + # Public handle to the MCPServer (set by caller for runtime ops) + self.server: MCPServer | None = None + + self._load_env() + self._setup_logging() + + # Properties (exposed below initializer) + @property + def tools(self) -> _ToolsAPI: + """Runtime and build-time tools API: add/update/remove/list.""" + return _ToolsAPI(self) + + @property + def prompts(self) -> _PromptsAPI: + """Runtime prompts API: add/remove/list.""" + return _PromptsAPI(self) + + @property + def resources(self) -> _ResourcesAPI: + """Runtime resources API: add/remove/list.""" + return _ResourcesAPI(self) + + def _load_env(self) -> None: + """Load .env file from the current directory.""" + env_path = Path.cwd() / ".env" + if env_path.exists(): + load_dotenv(env_path, override=False) + logger.info(f"Loaded environment from {env_path}") + + def _setup_logging(self) -> None: + logger.remove() + if self.log_level == "DEBUG": + format_str = "{level: <8} | {time:HH:mm:ss} | {name}:{line} | {message}" + else: + format_str = "{level: <8} | {time:HH:mm:ss} | {message}" + logger.add( + sys.stdout, + format=format_str, + level=self.log_level, + colorize=True, + diagnose=(self.log_level == "DEBUG"), + ) + + def add_tool( + self, + func: Callable[P, T], + desc: str | None = None, + name: str | None = None, + requires_auth: ToolAuthorization | None = None, + requires_secrets: list[str] | None = None, + requires_metadata: list[str] | None = None, + adapters: list[ErrorAdapter] | None = None, + ) -> Callable[P, T]: + """Add a tool for build-time materialization (pre-server).""" + if not hasattr(func, "__tool_name__"): + func = tool_decorator( + func, + desc=desc, + name=name, + requires_auth=requires_auth, + requires_secrets=requires_secrets, + requires_metadata=requires_metadata, + adapters=adapters, + ) + try: + self._catalog.add_tool(func, self._toolkit_name) + except ToolDefinitionError as e: + raise e.with_context(func.__name__) from e + logger.debug(f"Added tool: {func.__name__}") + return func + + def tool( + self, + func: Callable[P, T] | None = None, + desc: str | None = None, + name: str | None = None, + requires_auth: ToolAuthorization | None = None, + requires_secrets: list[str] | None = None, + requires_metadata: list[str] | None = None, + adapters: list[ErrorAdapter] | None = None, + ) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]: + """Decorator for adding tools with optional parameters.""" + + def decorator(f: Callable[P, T]) -> Callable[P, T]: + return self.add_tool( + f, + desc=desc, + name=name, + requires_auth=requires_auth, + requires_secrets=requires_secrets, + requires_metadata=requires_metadata, + adapters=adapters, + ) + + if func is not None: + return decorator(func) + return decorator + + def run( + self, + host: str = "127.0.0.1", + port: int = 7777, + reload: bool = False, + transport: TransportType = "http", + **kwargs: Any, + ) -> None: + if len(self._catalog) == 0: + logger.error("No tools added to the server. Use @app.tool decorator or app.add_tool().") + sys.exit(1) + + logger.info(f"Starting {self.name} v{self.version} with {len(self._catalog)} tools") + + if transport in ["http", "streamable-http", "streamable"]: + run_arcade_mcp( + catalog=self._catalog, + host=host, + port=port, + reload=reload, + **self.server_kwargs, + ) + elif transport == "stdio": + import asyncio + + from arcade_mcp_server.__main__ import run_stdio_server + + asyncio.run( + run_stdio_server( + catalog=self._catalog, + host=host, + port=port, + reload=reload, + **self.server_kwargs, + ) + ) + else: + raise ServerError(f"Invalid transport: {transport}") + + +class _ToolsAPI: + """Unified tools API for MCPApp (build-time and runtime).""" + + def __init__(self, app: MCPApp) -> None: + self._app = app + + async def add(self, tool: MaterializedTool) -> None: + """Add or update a tool at runtime if server is bound; otherwise queue via app.add_tool decorator.""" + if self._app.server is None: + raise ServerError("No server bound to app. Set app.server to use runtime tools API.") + await self._app.server.tools.add_tool(tool) + + async def update(self, tool: MaterializedTool) -> None: + if self._app.server is None: + raise ServerError("No server bound to app. Set app.server to use runtime tools API.") + await self._app.server.tools.update_tool(tool) + + async def remove(self, name: str) -> MaterializedTool: + if self._app.server is None: + raise ServerError("No server bound to app. Set app.server to use runtime tools API.") + return await self._app.server.tools.remove_tool(name) + + async def list(self) -> list[Any]: + if self._app.server is None: + raise ServerError("No server bound to app. Set app.server to use runtime tools API.") + return await self._app.server.tools.list_tools() + + +class _PromptsAPI: + """Unified prompts API for MCPApp (runtime).""" + + def __init__(self, app: MCPApp) -> None: + self._app = app + + async def add( + self, prompt: Prompt, handler: Callable[[dict[str, str]], list[PromptMessage]] | None = None + ) -> None: + if self._app.server is None: + raise ServerError("No server bound to app. Set app.server to use runtime prompts API.") + await self._app.server.prompts.add_prompt(prompt, handler) + + async def remove(self, name: str) -> Prompt: + if self._app.server is None: + raise ServerError("No server bound to app. Set app.server to use runtime prompts API.") + return await self._app.server.prompts.remove_prompt(name) + + async def list(self) -> list[Prompt]: + if self._app.server is None: + raise ServerError("No server bound to app. Set app.server to use runtime prompts API.") + return await self._app.server.prompts.list_prompts() + + +class _ResourcesAPI: + """Unified resources API for MCPApp (runtime).""" + + def __init__(self, app: MCPApp) -> None: + self._app = app + + async def add(self, resource: Resource, handler: Callable[[str], Any] | None = None) -> None: + if self._app.server is None: + raise ServerError( + "No server bound to app. Set app.server to use runtime resources API." + ) + await self._app.server.resources.add_resource(resource, handler) + + async def remove(self, uri: str) -> Resource: + if self._app.server is None: + raise ServerError( + "No server bound to app. Set app.server to use runtime resources API." + ) + return await self._app.server.resources.remove_resource(uri) + + async def list(self) -> list[Resource]: + if self._app.server is None: + raise ServerError( + "No server bound to app. Set app.server to use runtime resources API." + ) + return await self._app.server.resources.list_resources() diff --git a/libs/arcade-mcp-server/arcade_mcp_server/middleware/__init__.py b/libs/arcade-mcp-server/arcade_mcp_server/middleware/__init__.py new file mode 100644 index 00000000..858b815a --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/middleware/__init__.py @@ -0,0 +1,17 @@ +"""MCP Middleware System""" + +from arcade_mcp_server.middleware.base import ( + CallNext, + Middleware, + MiddlewareContext, +) +from arcade_mcp_server.middleware.error_handling import ErrorHandlingMiddleware +from arcade_mcp_server.middleware.logging import LoggingMiddleware + +__all__ = [ + "CallNext", + "ErrorHandlingMiddleware", + "LoggingMiddleware", + "Middleware", + "MiddlewareContext", +] diff --git a/libs/arcade-mcp-server/arcade_mcp_server/middleware/base.py b/libs/arcade-mcp-server/arcade_mcp_server/middleware/base.py new file mode 100644 index 00000000..93765d95 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/middleware/base.py @@ -0,0 +1,241 @@ +"""Base middleware classes for MCP server.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field, replace +from datetime import datetime, timezone +from functools import partial +from typing import Any, Generic, Literal, Protocol, TypeVar, cast, runtime_checkable + +from arcade_mcp_server.types import ( + CallToolParams, + CallToolResult, + GetPromptParams, + GetPromptResult, + JSONRPCMessage, + ListPromptsRequest, + ListResourcesRequest, + ListResourceTemplatesRequest, + ListToolsRequest, + MCPTool, + Prompt, + ReadResourceParams, + ReadResourceResult, + Resource, + ResourceTemplate, +) + +T = TypeVar("T") +R = TypeVar("R", covariant=True) + + +@runtime_checkable +class CallNext(Protocol[T, R]): + """Protocol for the next handler in the middleware chain.""" + + def __call__(self, context: "MiddlewareContext[T]") -> Awaitable[R]: ... + + +@dataclass(kw_only=True) +class MiddlewareContext(Generic[T]): + """Context passed through the middleware chain. + + Contains the message being processed and metadata about the request. + """ + + # The message being processed + message: T + + # The MCP context (optional, set when in request context) + mcp_context: Any | None = None + + # Metadata + source: Literal["client", "server"] = "client" + type: Literal["request", "notification"] = "request" + method: str | None = None + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + # Request-specific metadata + request_id: str | None = None + session_id: str | None = None + + # Additional metadata that can be added by middleware + metadata: dict[str, Any] = field(default_factory=dict) + + def copy(self, **kwargs: Any) -> "MiddlewareContext[T]": + """Create a copy with updated fields.""" + return replace(self, **kwargs) + + +class Middleware: + """Base class for MCP middleware with typed handlers for each method. + + Middleware can intercept and modify requests and responses at various + stages of processing. Each handler receives the context and a call_next + function to invoke the next handler in the chain. + """ + + async def __call__( + self, + context: MiddlewareContext[T], + call_next: CallNext[T, Any], + ) -> Any: + """Main entry point that orchestrates the middleware chain.""" + # Build handler chain based on message type + handler = await self._build_handler_chain(context, call_next) + return await handler(context) + + async def _build_handler_chain( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> CallNext[Any, Any]: + """Build the handler chain for the specific message type.""" + handler = call_next + + # Method-specific handlers + if context.method: + match context.method: + case "tools/call": + handler = partial(self.on_call_tool, call_next=handler) + case "tools/list": + handler = partial(self.on_list_tools, call_next=handler) + case "resources/read": + handler = partial(self.on_read_resource, call_next=handler) + case "resources/list": + handler = partial(self.on_list_resources, call_next=handler) + case "resources/templates/list": + handler = partial(self.on_list_resource_templates, call_next=handler) + case "prompts/get": + handler = partial(self.on_get_prompt, call_next=handler) + case "prompts/list": + handler = partial(self.on_list_prompts, call_next=handler) + + # Type-specific handlers + match context.type: + case "request": + handler = partial(self.on_request, call_next=handler) + case "notification": + handler = partial(self.on_notification, call_next=handler) + + # Generic message handler (always runs) + handler = partial(self.on_message, call_next=handler) + + return handler + + # Generic handlers + async def on_message( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> Any: + """Handle any message. Override to add generic processing.""" + return await call_next(context) + + async def on_request( + self, + context: MiddlewareContext[JSONRPCMessage], + call_next: CallNext[JSONRPCMessage, Any], + ) -> Any: + """Handle request messages. Override to add request processing.""" + return await call_next(context) + + async def on_notification( + self, + context: MiddlewareContext[JSONRPCMessage], + call_next: CallNext[JSONRPCMessage, Any], + ) -> Any: + """Handle notification messages. Override to add notification processing.""" + return await call_next(context) + + # Tool handlers + async def on_call_tool( + self, + context: MiddlewareContext[CallToolParams], + call_next: CallNext[CallToolParams, CallToolResult], + ) -> CallToolResult: + """Handle tool calls. Override to add tool-specific processing.""" + return await call_next(context) + + async def on_list_tools( + self, + context: MiddlewareContext[ListToolsRequest], + call_next: CallNext[ListToolsRequest, list[MCPTool]], + ) -> list[MCPTool]: + """Handle tool listing. Override to filter or modify tool list.""" + return await call_next(context) + + # Resource handlers + async def on_read_resource( + self, + context: MiddlewareContext[ReadResourceParams], + call_next: CallNext[ReadResourceParams, ReadResourceResult], + ) -> ReadResourceResult: + """Handle resource reading. Override to add resource processing.""" + return await call_next(context) + + async def on_list_resources( + self, + context: MiddlewareContext[ListResourcesRequest], + call_next: CallNext[ListResourcesRequest, list[Resource]], + ) -> list[Resource]: + """Handle resource listing. Override to filter or modify resource list.""" + return await call_next(context) + + async def on_list_resource_templates( + self, + context: MiddlewareContext[ListResourceTemplatesRequest], + call_next: CallNext[ListResourceTemplatesRequest, list[ResourceTemplate]], + ) -> list[ResourceTemplate]: + """Handle resource template listing. Override to filter or modify template list.""" + return await call_next(context) + + # Prompt handlers + async def on_get_prompt( + self, + context: MiddlewareContext[GetPromptParams], + call_next: CallNext[GetPromptParams, GetPromptResult], + ) -> GetPromptResult: + """Handle prompt retrieval. Override to add prompt processing.""" + return await call_next(context) + + async def on_list_prompts( + self, + context: MiddlewareContext[ListPromptsRequest], + call_next: CallNext[ListPromptsRequest, list[Prompt]], + ) -> list[Prompt]: + """Handle prompt listing. Override to filter or modify prompt list.""" + return await call_next(context) + + +def compose_middleware( + *middleware: Middleware, +) -> Callable[[MiddlewareContext[T], CallNext[T, R]], Awaitable[R]]: + """Compose multiple middleware into a single handler. + + The middleware are applied in reverse order, so the first middleware + in the list is the outermost (runs first on request, last on response). + """ + + async def composed( + context: MiddlewareContext[T], + call_next: CallNext[T, R], + ) -> R: + # Build the chain in reverse order into a CallNext[T, R] + current: CallNext[T, R] = call_next + + for mw in reversed(middleware): + + async def wrapper( + ctx: MiddlewareContext[T], + next_handler: CallNext[T, R] = current, + m: Middleware = mw, + ) -> R: + result = await m(ctx, next_handler) + return cast(R, result) + + # wrapper conforms to CallNext[T, R] + current = wrapper # type: ignore[assignment] + + return await current(context) + + return composed diff --git a/libs/arcade-mcp-server/arcade_mcp_server/middleware/error_handling.py b/libs/arcade-mcp-server/arcade_mcp_server/middleware/error_handling.py new file mode 100644 index 00000000..640c8b93 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/middleware/error_handling.py @@ -0,0 +1,107 @@ +"""Error handling middleware for MCP server.""" + +import logging +from typing import Any + +from arcade_mcp_server.convert import convert_content_to_structured_content, convert_to_mcp_content +from arcade_mcp_server.middleware.base import CallNext, Middleware, MiddlewareContext +from arcade_mcp_server.types import CallToolResult, JSONRPCError + +logger = logging.getLogger("arcade.mcp") + + +class ErrorHandlingMiddleware(Middleware): + """Middleware that handles errors and converts them to appropriate responses.""" + + def __init__(self, mask_error_details: bool = True): + """Initialize error handling middleware. + + Args: + mask_error_details: Whether to mask error details in responses + """ + self.mask_error_details = mask_error_details + + async def on_message( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> Any: + """Wrap all messages with error handling.""" + try: + return await call_next(context) + except Exception as e: + return self._handle_error(context, e) + + async def on_call_tool( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> Any: + """Handle tool call errors specially.""" + try: + return await call_next(context) + except Exception as e: + # For tool calls, return error as CallToolResult + error_message = self._get_error_message(e) + logger.exception(f"Error calling tool: {error_message}") + + content = convert_to_mcp_content(error_message) + structured_content = convert_content_to_structured_content({"error": error_message}) + + return CallToolResult( + content=content, + structuredContent=structured_content, + isError=True, + ) + + def _handle_error(self, context: MiddlewareContext[Any], error: Exception) -> Any: + """Convert exception to appropriate error response.""" + error_message = self._get_error_message(error) + + # Log the full error + logger.exception(f"Error processing {context.method}: {error}") + + # Get request ID if available + request_id = context.request_id + if not request_id and hasattr(context.message, "id"): + request_id = str(getattr(context.message, "id", "unknown")) + + # Return JSON-RPC error + return JSONRPCError( + id=request_id or "unknown", + error={ + "code": self._get_error_code(error), + "message": error_message, + }, + ) + + def _get_error_message(self, error: Exception) -> str: + """Get appropriate error message based on configuration.""" + if self.mask_error_details: + # Return generic message for security + error_type = type(error).__name__ + if error_type in ["ValueError", "TypeError", "KeyError"]: + return "Invalid request parameters" + elif error_type in ["NotFoundError", "FileNotFoundError"]: + return "Resource not found" + elif error_type in ["PermissionError", "AuthorizationError"]: + return "Permission denied" + else: + return "Internal server error" + else: + # Return actual error message for debugging + return str(error) + + def _get_error_code(self, error: Exception) -> int: + """Get JSON-RPC error code for exception.""" + error_type = type(error).__name__ + + # Map common errors to JSON-RPC codes + if error_type in ["ValueError", "TypeError", "KeyError"]: + return -32602 # Invalid params + elif error_type in ["NotFoundError", "FileNotFoundError"]: + return -32601 # Method not found + elif error_type in ["PermissionError", "AuthorizationError"]: + return -32603 # Internal error (used for auth) + else: + return -32603 # Generic internal error diff --git a/libs/arcade-mcp-server/arcade_mcp_server/middleware/logging.py b/libs/arcade-mcp-server/arcade_mcp_server/middleware/logging.py new file mode 100644 index 00000000..da8d3469 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/middleware/logging.py @@ -0,0 +1,121 @@ +"""Logging middleware for MCP server.""" + +import logging +import time +from typing import Any + +from arcade_mcp_server.middleware.base import CallNext, Middleware, MiddlewareContext + +logger = logging.getLogger("arcade.mcp") + + +class LoggingMiddleware(Middleware): + """Middleware that logs all MCP messages and timing information.""" + + def __init__(self, log_level: str = "INFO"): + """Initialize logging middleware. + + Args: + log_level: The log level to use for message logging + """ + self.log_level = getattr(logging, log_level.upper(), logging.INFO) + + async def on_message( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> Any: + """Log all messages with timing information.""" + start_time = time.time() + + # Log the incoming message + self._log_request(context) + + try: + # Process the message + result = await call_next(context) + except Exception as e: + # Log error + elapsed = time.time() - start_time + self._log_error(context, e, elapsed) + raise + else: + # Log success + elapsed = time.time() - start_time + self._log_response(context, result, elapsed) + return result + + def _log_request(self, context: MiddlewareContext[Any]) -> None: + """Log incoming request.""" + if not logger.isEnabledFor(self.log_level): + return + + method = context.method or "unknown" + msg_type = context.type + + # Build log message + parts = [f"[{msg_type.upper()}]", f"method={method}"] + + if context.request_id: + parts.append(f"request_id={context.request_id}") + if context.session_id: + parts.append(f"session_id={context.session_id}") + + # Log message details based on method + if hasattr(context.message, "params"): + params = getattr(context.message, "params", None) + if params: + if hasattr(params, "name"): + parts.append(f"name={params.name}") + elif hasattr(params, "uri"): + parts.append(f"uri={params.uri}") + + logger.log(self.log_level, " ".join(parts)) + + def _log_response( + self, + context: MiddlewareContext[Any], + result: Any, + elapsed: float, + ) -> None: + """Log response with timing.""" + if not logger.isEnabledFor(self.log_level): + return + + method = context.method or "unknown" + elapsed_ms = int(elapsed * 1000) + + # Build log message + parts = ["[RESPONSE]", f"method={method}", f"elapsed={elapsed_ms}ms"] + + if context.request_id: + parts.append(f"request_id={context.request_id}") + + # Add result info based on type + if isinstance(result, list): + parts.append(f"count={len(result)}") + elif hasattr(result, "content"): + content = getattr(result, "content", []) + if isinstance(content, list): + parts.append(f"content_blocks={len(content)}") + + logger.log(self.log_level, " ".join(parts)) + + def _log_error( + self, + context: MiddlewareContext[Any], + error: Exception, + elapsed: float, + ) -> None: + """Log error with timing.""" + method = context.method or "unknown" + elapsed_ms = int(elapsed * 1000) + + parts = ["[ERROR]", f"method={method}", f"elapsed={elapsed_ms}ms"] + + if context.request_id: + parts.append(f"request_id={context.request_id}") + + parts.append(f"error={type(error).__name__}: {error!s}") + + logger.error(" ".join(parts)) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/server.py b/libs/arcade-mcp-server/arcade_mcp_server/server.py new file mode 100644 index 00000000..6887c50a --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/server.py @@ -0,0 +1,898 @@ +""" +MCP Server Implementation + +Provides request handling, middleware orchestration, and manager-backed +operations for tools, resources, prompts, sampling, logging, and roots. + +Key notes: +- For every incoming request, a new MCP ModelContext is created and set as + current via a ContextVar for the request lifetime +- Tool invocations receive a ToolContext (wrapped by TDK as needed) and are + executed via ToolExecutor +- Managers (tool, resource, prompt) back the namespaced operations +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Any, Callable, cast + +from arcade_core.catalog import MaterializedTool, ToolCatalog +from arcade_core.executor import ToolExecutor +from arcade_core.schema import ToolAuthRequirement as CoreToolAuthRequirement +from arcade_core.schema import ToolContext +from arcadepy import ArcadeError, AsyncArcade +from arcadepy.types.auth_authorize_params import AuthRequirement, AuthRequirementOauth2 + +from arcade_mcp_server.context import Context, get_current_model_context, set_current_model_context +from arcade_mcp_server.convert import convert_content_to_structured_content, convert_to_mcp_content +from arcade_mcp_server.exceptions import NotFoundError, ToolRuntimeError +from arcade_mcp_server.lifespan import LifespanManager +from arcade_mcp_server.managers import PromptManager, ResourceManager, ToolManager +from arcade_mcp_server.middleware import ( + CallNext, + ErrorHandlingMiddleware, + LoggingMiddleware, + Middleware, + MiddlewareContext, +) +from arcade_mcp_server.session import InitializationState, NotificationManager, ServerSession +from arcade_mcp_server.settings import MCPSettings +from arcade_mcp_server.types import ( + LATEST_PROTOCOL_VERSION, + BlobResourceContents, + CallToolRequest, + CallToolResult, + CompleteRequest, + CreateMessageRequest, + ElicitRequest, + GetPromptRequest, + GetPromptResult, + Implementation, + InitializeRequest, + InitializeResult, + JSONRPCError, + JSONRPCResponse, + ListPromptsRequest, + ListPromptsResult, + ListResourcesRequest, + ListResourcesResult, + ListResourceTemplatesRequest, + ListResourceTemplatesResult, + ListRootsRequest, + ListToolsRequest, + ListToolsResult, + MCPMessage, + PingRequest, + ReadResourceRequest, + ReadResourceResult, + ServerCapabilities, + SetLevelRequest, + SubscribeRequest, + TextResourceContents, + UnsubscribeRequest, +) + +logger = logging.getLogger("arcade.mcp") + + +class MCPServer: + """ + MCP Server with middleware and context support. + + This server provides: + - Middleware chain for extensible request processing + - Context injection for tools + - Component managers for tools, resources, and prompts + - Bidirectional communication support to MCP clients + """ + + # Public manager properties near top + @property + def tools(self) -> ToolManager: + """Access the ToolManager for runtime tool operations.""" + return self._tool_manager + + @property + def resources(self) -> ResourceManager: + """Access the ResourceManager for runtime resource operations.""" + return self._resource_manager + + @property + def prompts(self) -> PromptManager: + """Access the PromptManager for runtime prompt operations.""" + return self._prompt_manager + + def __init__( + self, + catalog: ToolCatalog, + *, + name: str = "ArcadeMCP", + version: str = "0.1.0", + title: str | None = None, + instructions: str | None = None, + settings: MCPSettings | None = None, + middleware: list[Middleware] | None = None, + lifespan: Callable[[Any], Any] | None = None, + auth_disabled: bool = False, + arcade_api_key: str | None = None, + arcade_api_url: str | None = None, + ): + """ + Initialize MCP server. + + Args: + catalog: Tool catalog + name: Server name + version: Server version + title: Server title for display + instructions: Server instructions + settings: MCP settings (uses env if not provided) + middleware: List of middleware to apply + lifespan: Lifespan manager function + auth_disabled: Disable authentication + arcade_api_key: Arcade API key (overrides settings) + arcade_api_url: Arcade API URL (overrides settings) + """ + self.name = name or self.__class__.__name__ + self._started = False + self._lock = asyncio.Lock() + + # Server identity + self.version = version + self.title = title or name + self.instructions = instructions or self._default_instructions() + + # Settings + self.settings = settings or MCPSettings.from_env() + self.auth_disabled = auth_disabled or self.settings.arcade.auth_disabled + + # Initialize Arcade client + # Fallback to API key in ~/.arcade/credentials.yaml if not provided + self._init_arcade_client( + arcade_api_key or self.settings.arcade.api_key, + arcade_api_url or self.settings.arcade.api_url, + ) + + # Component managers (passive) + self._tool_manager = ToolManager() + self._resource_manager = ResourceManager() + self._prompt_manager = PromptManager() + + # Centralized notifications + self.notification_manager = NotificationManager(self) + + # Subscribe to changes -> broadcast + self._tool_manager.subscribe( + lambda *_: asyncio.get_event_loop().create_task( # type: ignore[arg-type] + self.notification_manager.notify_tool_list_changed() + ) + ) + self._resource_manager.subscribe( + lambda *_: asyncio.get_event_loop().create_task( # type: ignore[arg-type] + self.notification_manager.notify_resource_list_changed() + ) + ) + self._prompt_manager.subscribe( + lambda *_: asyncio.get_event_loop().create_task( # type: ignore[arg-type] + self.notification_manager.notify_prompt_list_changed() + ) + ) + + # Defer loading tools from catalog to server start to ensure readiness + self._initial_catalog = catalog + + # Middleware chain + self.middleware: list[Middleware] = [] + self._init_middleware(middleware) + + # Lifespan management + self.lifespan_manager = LifespanManager(self, lifespan) + + # Session management + self._sessions: dict[str, ServerSession] = {} + self._sessions_lock = asyncio.Lock() + + # Handler registration + self._handlers = self._register_handlers() + + def _init_arcade_client(self, api_key: str | None, api_url: str | None) -> None: + """Initialize Arcade client for runtime authorization.""" + self.arcade: AsyncArcade | None = None + + if not api_url: + api_url = os.environ.get("ARCADE_API_URL", "https://api.arcade.dev") + + final_api_key = api_key + + # If no API key provided, try to load from credentials file + if not final_api_key: + try: + from arcade_core.config import get_config + + config = get_config() + final_api_key = config.api.key + if final_api_key: + logger.info("Loaded Arcade API key from ~/.arcade/credentials.yaml") + except Exception as e: + logger.debug(f"Could not load credentials from file: {e}") + + if final_api_key: + logger.info(f"Using Arcade client with API URL: {api_url}") + self.arcade = AsyncArcade(api_key=final_api_key, base_url=api_url) + else: + logger.warning( + "Arcade API key not configured. Tools requiring auth will return a login instruction." + ) + + def _init_middleware(self, custom_middleware: list[Middleware] | None) -> None: + """Initialize middleware chain.""" + # Always add error handling first (innermost) + self.middleware.append( + ErrorHandlingMiddleware(mask_error_details=self.settings.middleware.mask_error_details) + ) + + # Add logging if enabled + if self.settings.middleware.enable_logging: + self.middleware.append(LoggingMiddleware(log_level=self.settings.middleware.log_level)) + + # Add custom middleware + if custom_middleware: + self.middleware.extend(custom_middleware) + + def _register_handlers(self) -> dict[str, Callable]: + """Register method handlers.""" + return { + "ping": self._handle_ping, + "initialize": self._handle_initialize, + "tools/list": self._handle_list_tools, + "tools/call": self._handle_call_tool, + "resources/list": self._handle_list_resources, + "resources/templates/list": self._handle_list_resource_templates, + "resources/read": self._handle_read_resource, + "prompts/list": self._handle_list_prompts, + "prompts/get": self._handle_get_prompt, + "logging/setLevel": self._handle_set_log_level, + } + + def _default_instructions(self) -> str: + """Get default server instructions.""" + return ( + "The Arcade MCP Server provides access to tools defined in Arcade toolkits. " + "Use 'tools/list' to see available tools and 'tools/call' to execute them." + ) + + async def _start(self) -> None: + """Start server components (called by MCPComponent.start).""" + await self._tool_manager.start() + # Load initial catalog now that manager is started + try: + await self._tool_manager.load_from_catalog(self._initial_catalog) + except Exception: + logger.exception("Failed to load tools from initial catalog") + await self._resource_manager.start() + await self._prompt_manager.start() + await self.lifespan_manager.startup() + + async def _stop(self) -> None: + """Stop server components (called by MCPComponent.stop).""" + # Stop all sessions + async with self._sessions_lock: + sessions = list(self._sessions.values()) + for _session in sessions: + # Sessions should handle their own cleanup + pass + + await self._prompt_manager.stop() + await self._resource_manager.stop() + await self._tool_manager.stop() + + # Stop lifespan + await self.lifespan_manager.shutdown() + + async def start(self) -> None: + async with self._lock: + if self._started: + logger.debug(f"{self.name} already started") + return + logger.info(f"Starting {self.name}") + try: + await self._start() + self._started = True + logger.info(f"{self.name} started successfully") + except Exception: + logger.exception(f"Failed to start {self.name}") + raise + + async def stop(self) -> None: + async with self._lock: + if not self._started: + logger.debug(f"{self.name} not started") + return + logger.info(f"Stopping {self.name}") + try: + await self._stop() + self._started = False + logger.info(f"{self.name} stopped successfully") + except Exception: + logger.exception(f"Failed to stop {self.name}") + # best-effort on stop + + async def run_connection( + self, + read_stream: Any, + write_stream: Any, + init_options: Any = None, + ) -> None: + """ + Run a single MCP connection. + + Args: + read_stream: Stream for reading messages + write_stream: Stream for writing messages + init_options: Connection initialization options + """ + + # Create session + session = ServerSession( + server=self, + read_stream=read_stream, + write_stream=write_stream, + init_options=init_options, + ) + + # Register session + async with self._sessions_lock: + self._sessions[session.session_id] = session + + try: + logger.info(f"Starting session {session.session_id}") + await session.run() + except Exception: + logger.exception("Session error") + raise + finally: + # Unregister session + async with self._sessions_lock: + self._sessions.pop(session.session_id, None) + logger.info(f"Session {session.session_id} ended") + + async def handle_message( + self, + message: Any, + session: ServerSession | None = None, + ) -> MCPMessage | None: + """ + Handle an incoming message. + + Args: + message: Message to handle + session: Server session + + Returns: + Response message or None + """ + # Validate message + if ( + not isinstance(message, dict) + or not message.get("method") + or not isinstance(message["method"], str) + ): + return JSONRPCError( + id="null", + error={"code": -32600, "message": "Invalid request"}, + ) + + method = message["method"] + msg_id = message.get("id") + + # Handle notifications (no response needed) + if method and method.startswith("notifications/"): + if method == "notifications/initialized" and session: + session.mark_initialized() + return None + + # Check if this is a response to a server-initiated request + if "id" in message and "method" not in message: + # This is handled in the session's message processing + return None + + # Check initialization state + if ( + session + and session.initialization_state != InitializationState.INITIALIZED + and method not in ["initialize", "ping"] + ): + return JSONRPCError( + id=str(msg_id or "null"), + error={ + "code": -32600, + "message": "Request not allowed before initialization", + }, + ) + + # Find handler + handler = self._handlers.get(method) + if not handler: + return JSONRPCError( + id=str(msg_id or "null"), + error={"code": -32601, "message": f"Method not found: {method}"}, + ) + + # Create context and apply middleware + try: + # Create request context + context = ( + await session.create_request_context() + if session + else Context(self, request_id=str(msg_id) if msg_id else None) + ) + + # Set as current model context + token = set_current_model_context(context) + + try: + # Create middleware context + middleware_context = MiddlewareContext( + message=message, + mcp_context=context, + source="client", + type="request", + method=method, + request_id=str(msg_id) if msg_id else None, + session_id=session.session_id if session else None, + ) + + # Parse message based on method + parsed_message = self._parse_message(message, method or "") + + # Apply middleware chain + async def final_handler(_: MiddlewareContext[Any]) -> Any: + return await handler(parsed_message, session=session) + + result = await self._apply_middleware(middleware_context, final_handler) + + from typing import cast + + return cast(MCPMessage | None, result) + + finally: + # Clean up context + set_current_model_context(None, token) + if session: + await session.cleanup_request_context(context) + + except Exception: + logger.exception("Error handling message") + return JSONRPCError( + id=str(msg_id or "null"), + error={"code": -32603, "message": "Internal error"}, + ) + + def _parse_message(self, message: dict[str, Any], method: str) -> Any: + """Parse raw message dict into typed message based on method.""" + message_types = { + "ping": PingRequest, + "initialize": InitializeRequest, + "tools/list": ListToolsRequest, + "tools/call": CallToolRequest, + "resources/list": ListResourcesRequest, + "resources/read": ReadResourceRequest, + "resources/subscribe": SubscribeRequest, + "resources/unsubscribe": UnsubscribeRequest, + "resources/templates/list": ListResourceTemplatesRequest, + "prompts/list": ListPromptsRequest, + "prompts/get": GetPromptRequest, + "logging/setLevel": SetLevelRequest, + "sampling/createMessage": CreateMessageRequest, + "completion/complete": CompleteRequest, + "roots/list": ListRootsRequest, + "elicitation/create": ElicitRequest, + } + + message_type = message_types.get(method) + if message_type is not None: + # Use constructor for compatibility across Pydantic versions + return message_type(**message) + return message + + async def _apply_middleware( + self, + context: MiddlewareContext[Any], + final_handler: Callable[[MiddlewareContext[Any]], Any] | CallNext[Any, Any], + ) -> Any: + """Apply middleware chain to a request.""" + + # Build chain from outside in + async def chain_fn(ctx: MiddlewareContext[Any]) -> Any: + return await final_handler(ctx) + + chain: CallNext[Any, Any] = cast(CallNext[Any, Any], chain_fn) + + for middleware in reversed(self.middleware): + + async def make_handler( + ctx: MiddlewareContext[Any], + next_handler: CallNext[Any, Any] = chain, + mw: Middleware = middleware, + ) -> Any: + return await mw(ctx, next_handler) + + chain = make_handler # type: ignore[assignment] + + # Execute chain + return await chain(context) + + # Handler methods + async def _handle_ping( + self, + message: PingRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[Any]: + """Handle ping request.""" + return JSONRPCResponse(id=message.id, result={}) + + async def _handle_initialize( + self, + message: InitializeRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[InitializeResult]: + """Handle initialize request.""" + if session: + session.set_client_params(message.params) + + result = InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities( + tools={"listChanged": True}, + logging={}, + prompts={"listChanged": True}, + resources={"subscribe": True, "listChanged": True}, + ), + serverInfo=Implementation( + name=self.name, + version=self.version, + title=self.title, + ), + instructions=self.instructions, + ) + + return JSONRPCResponse(id=message.id, result=result) + + async def _handle_list_tools( + self, + message: ListToolsRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[ListToolsResult] | JSONRPCError: + """Handle list tools request.""" + try: + tools = await self._tool_manager.list_tools() + return JSONRPCResponse(id=message.id, result=ListToolsResult(tools=tools)) + except Exception: + logger.exception("Error listing tools") + return JSONRPCError( + id=message.id, + error={"code": -32603, "message": "Internal error listing tools"}, + ) + + async def _create_tool_context( + self, tool: MaterializedTool, session: ServerSession | None = None + ) -> ToolContext: + """Create a tool context from a tool definition and session""" + tool_context = ToolContext() + + # secrets + if tool.definition.requirements and tool.definition.requirements.secrets: + for secret in tool.definition.requirements.secrets: + if secret.key in self.settings.tool_secrets(): + tool_context.set_secret(secret.key, self.settings.tool_secrets()[secret.key]) + elif secret.key in os.environ: + tool_context.set_secret(secret.key, os.environ[secret.key]) + + # user_id selection + env = (self.settings.arcade.environment or "").lower() + user_id = self.settings.arcade.user_id + + # If no user_id from env, try config file (like we do for API key) + if not user_id: + try: + from arcade_core.config import get_config + + config = get_config() + if config.user and config.user.email: + user_id = config.user.email + logger.debug(f"Context user_id set from config file: {user_id}") + except Exception: + logger.debug("Could not load user_id from config file") + + if user_id: + tool_context.user_id = user_id + logger.debug(f"Context user_id set: {user_id}") + elif env in ("development", "dev", "local"): + tool_context.user_id = session.session_id if session else None + logger.debug(f"Context user_id set from session (dev env={env})") + else: + tool_context.user_id = session.session_id if session else None + logger.debug("Context user_id set from session (non-dev env)") + + return tool_context + + async def _handle_call_tool( + self, + message: CallToolRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[CallToolResult] | JSONRPCError: + """Handle tool call request.""" + tool_name = message.params.name + input_params = message.params.arguments or {} + + try: + # Get tool + tool = await self._tool_manager.get_tool(tool_name) + + # Create tool context + tool_context = await self._create_tool_context(tool, session) + + # Attach tool_context to current model context for this request + mctx = get_current_model_context() + if mctx is not None: + mctx.set_tool_context(tool_context) + + # Handle authorization if required + if tool.definition.requirements and tool.definition.requirements.authorization: + auth_result = await self._check_authorization(tool, tool_context.user_id) + if auth_result.status != "completed": + tool_response = { + "message": "The tool was not executed because it requires authorization. This is not an error, but the end user must click the link to complete the OAuth2 flow before the tool can be executed.", + "llm_instructions": f"Please show the following link to the end user formatted as markdown: {auth_result.url} \nInform the end user that the tool requires their authorization to be completed before the tool can be executed.", + "authorization_url": auth_result.url, + } + content = convert_to_mcp_content(tool_response) + structured_content = convert_content_to_structured_content(tool_response) + return JSONRPCResponse( + id=message.id, + result=CallToolResult( + content=content, + structuredContent=structured_content, + isError=False, + ), + ) + + # Execute tool + result = await ToolExecutor.run( + func=tool.tool, + definition=tool.definition, + input_model=tool.input_model, + output_model=tool.output_model, + context=tool_context, + **input_params, + ) + + # Convert result + if result.value is not None: + content = convert_to_mcp_content(result.value) + + # structuredContent should be the raw result value as a JSON object + structured_content = convert_content_to_structured_content(result.value) + + return JSONRPCResponse( + id=message.id, + result=CallToolResult( + content=content, + structuredContent=structured_content, + isError=False, + ), + ) + else: + error = result.error or "Error calling tool" + content = convert_to_mcp_content(str(error)) + + # structuredContent should be the error as a JSON object + structured_content = convert_content_to_structured_content({"error": str(error)}) + + return JSONRPCResponse( + id=message.id, + result=CallToolResult( + content=content, + structuredContent=structured_content, + isError=True, + ), + ) + except NotFoundError: + # Match test expectation: return a normal response with isError=True + error_message = f"Unknown tool: {tool_name}" + content = convert_to_mcp_content(error_message) + + # structuredContent should be the error as a JSON object + structured_content = convert_content_to_structured_content({"error": error_message}) + + return JSONRPCResponse( + id=message.id, + result=CallToolResult( + content=content, + structuredContent=structured_content, + isError=True, + ), + ) + except Exception: + logger.exception("Error calling tool") + return JSONRPCError( + id=message.id, + error={"code": -32603, "message": "Internal error calling tool"}, + ) + + async def _check_authorization( + self, + tool: MaterializedTool, + user_id: str | None = None, + ) -> Any: + """Check tool authorization.""" + if not self.arcade: + raise ToolRuntimeError( + "Authorization required but Arcade API Key is not configured. " + "Set ARCADE_API_KEY as environment variable or run 'arcade login'." + ) + + req = tool.definition.requirements.authorization + provider_id = str(getattr(req, "provider_id", "")) + provider_type = str(getattr(req, "provider_type", "")) + # TypedDict requires concrete type; supply empty scopes if absent when oauth2 provider + oauth2_req = ( + AuthRequirementOauth2( + scopes=(req.oauth2.scopes or []) if req.oauth2 is not None else [] + ) + if isinstance(req, CoreToolAuthRequirement) and provider_type.lower() == "oauth2" + else AuthRequirementOauth2() + ) + auth_req = AuthRequirement( + provider_id=provider_id, + provider_type=provider_type, + oauth2=oauth2_req, + ) + + # Log a warning if user_id is not set + final_user_id = user_id or "anonymous" + if final_user_id == "anonymous": + logger.warning( + "No user_id available for authorization, defaulting to 'anonymous'. " + "Set ARCADE_USER_ID as environment variable or run 'arcade login'." + ) + + try: + response = await self.arcade.auth.authorize( + auth_requirement=auth_req, + user_id=final_user_id, + ) + except ArcadeError as e: + logger.exception("Error authorizing tool") + raise ToolRuntimeError(f"Authorization failed: {e}") from e + else: + return response + + async def _handle_list_resources( + self, + message: ListResourcesRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[ListResourcesResult] | JSONRPCError: + """Handle list resources request.""" + try: + resources = await self._resource_manager.list_resources() + return JSONRPCResponse(id=message.id, result=ListResourcesResult(resources=resources)) + except Exception: + logger.exception("Error listing resources") + return JSONRPCError( + id=message.id, + error={"code": -32603, "message": "Internal error listing resources"}, + ) + + async def _handle_list_resource_templates( + self, + message: ListResourceTemplatesRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[ListResourceTemplatesResult] | JSONRPCError: + """Handle list resource templates request.""" + try: + templates = await self._resource_manager.list_resource_templates() + return JSONRPCResponse( + id=message.id, + result=ListResourceTemplatesResult(resourceTemplates=templates), + ) + except Exception: + logger.exception("Error listing resource templates") + return JSONRPCError( + id=message.id, + error={"code": -32603, "message": "Internal error listing resource templates"}, + ) + + async def _handle_read_resource( + self, + message: ReadResourceRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[ReadResourceResult] | JSONRPCError: + """Handle read resource request.""" + try: + contents = await self._resource_manager.read_resource(message.params.uri) + # Narrow to allowed types for ReadResourceResult + allowed_contents = [ + c for c in contents if isinstance(c, (TextResourceContents, BlobResourceContents)) + ] + return JSONRPCResponse( + id=message.id, + result=ReadResourceResult(contents=allowed_contents), + ) + except NotFoundError: + return JSONRPCError( + id=message.id, + error={"code": -32002, "message": f"Resource not found: {message.params.uri}"}, + ) + except Exception: + logger.exception(f"Error reading resource: {message.params.uri}") + return JSONRPCError( + id=message.id, + error={"code": -32603, "message": "Internal error reading resource"}, + ) + + async def _handle_list_prompts( + self, + message: ListPromptsRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[ListPromptsResult] | JSONRPCError: + """Handle list prompts request.""" + try: + prompts = await self._prompt_manager.list_prompts() + return JSONRPCResponse(id=message.id, result=ListPromptsResult(prompts=prompts)) + except Exception: + logger.exception("Error listing prompts") + return JSONRPCError( + id=message.id, + error={"code": -32603, "message": "Internal error listing prompts"}, + ) + + async def _handle_get_prompt( + self, + message: GetPromptRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[GetPromptResult] | JSONRPCError: + """Handle get prompt request.""" + try: + result = await self._prompt_manager.get_prompt( + message.params.name, + message.params.arguments if hasattr(message.params, "arguments") else None, + ) + return JSONRPCResponse(id=message.id, result=result) + except NotFoundError: + return JSONRPCError( + id=message.id, + error={"code": -32002, "message": f"Prompt not found: {message.params.name}"}, + ) + except Exception: + logger.exception(f"Error getting prompt: {message.params.name}") + return JSONRPCError( + id=message.id, + error={"code": -32603, "message": "Internal error getting prompt"}, + ) + + async def _handle_set_log_level( + self, + message: SetLevelRequest, + session: ServerSession | None = None, + ) -> JSONRPCResponse[Any] | JSONRPCError: + """Handle set log level request.""" + try: + level_name = str( + message.params.level.value + if hasattr(message.params.level, "value") + else message.params.level + ) + logger.setLevel(getattr(logging, level_name.upper(), logging.INFO)) + except Exception: + logger.setLevel(logging.INFO) + + return JSONRPCResponse(id=message.id, result={}) + + # Resource support for Context + async def _mcp_read_resource(self, uri: str) -> list[Any]: + """Read a resource (for Context.read_resource).""" + return await self._resource_manager.read_resource(uri) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/session.py b/libs/arcade-mcp-server/arcade_mcp_server/session.py new file mode 100644 index 00000000..0c4757c9 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/session.py @@ -0,0 +1,637 @@ +""" +MCP Server Session + +Manages per-session state and provides session-level operations. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import uuid +from enum import Enum +from typing import Any + +from arcade_mcp_server.context import Context +from arcade_mcp_server.exceptions import RequestError, SessionError +from arcade_mcp_server.types import ( + CancelledNotification, + CancelledParams, + ClientCapabilities, + CompleteResult, + CreateMessageResult, + ElicitResult, + InitializeParams, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + ListRootsResult, + LoggingLevel, + LoggingMessageNotification, + LoggingMessageParams, + ProgressNotification, + ProgressNotificationParams, + PromptListChangedNotification, + ResourceListChangedNotification, + ToolListChangedNotification, +) + +logger = logging.getLogger(__name__) + + +class InitializationState(Enum): + """Session initialization states.""" + + NOT_INITIALIZED = 1 + INITIALIZING = 2 + INITIALIZED = 3 + + +class RequestManager: + """ + Manages server-initiated requests to the client. + + Handles request/response correlation for bidirectional communication. + """ + + def __init__(self, write_stream: Any): + """Initialize request manager.""" + self._write_stream = write_stream + self._pending_requests: dict[str, asyncio.Future[Any]] = {} + self._lock = asyncio.Lock() + self._closed = asyncio.Event() + + def is_closed(self) -> bool: + """Return True if the manager has been closed/cancelled.""" + return self._closed.is_set() + + async def send_request( + self, + method: str, + params: dict[str, Any] | None = None, + timeout: float = 400.0, + ) -> Any: + """ + Send a request to the client and wait for response. + + Args: + method: Request method + params: Request parameters + timeout: Request timeout in seconds + + Returns: + Response result + + Raises: + MCPTimeoutError: If request times out + ProtocolError: If response is an error + """ + if self._closed.is_set(): + raise SessionError("Session closed") + request_id = str(uuid.uuid4()) + + # Create request + request = JSONRPCRequest( + id=request_id, + method=method, + params=params or {}, + ) + + # Create future for response + future: asyncio.Future[Any] = asyncio.Future() + async with self._lock: + if self._closed.is_set(): + raise SessionError("Session closed") + self._pending_requests[request_id] = future + + try: + # Send request + message = request.model_dump_json(exclude_none=True) + "\n" + logger.debug(f"Sending server->client request method={method} id={request_id}") + await self._write_stream.send(message) + + # Wait for response + result = await asyncio.wait_for(future, timeout=timeout) + logger.debug(f"Received response for id={request_id} method={method}") + return result + + finally: + # Clean up + async with self._lock: + self._pending_requests.pop(request_id, None) + + async def handle_response(self, message: dict[str, Any]) -> None: + """ + Handle a response message from the client. + + Args: + message: Response message + """ + if self._closed.is_set(): + # Drop any late responses after closure + return + request_id = message.get("id") + if not request_id: + logger.debug("Received response without id; ignoring") + return + + async with self._lock: + future = self._pending_requests.get(str(request_id)) + + if future and not future.done(): + if "error" in message: + logger.debug(f"Response id={request_id} contains error; propagating") + future.set_exception(RequestError(f"Request failed: {message['error']}")) + else: + logger.debug(f"Correlated response id={request_id} -> completing future") + future.set_result(message.get("result")) + else: + logger.debug( + f"No pending future for response id={request_id}; possibly late or mismatched" + ) + + async def cancel_all(self, reason: str | None = None) -> None: + """Cancel all pending requests and notify the client. + + Sends a CancelledNotification for each in-flight request and + completes their futures with SessionError so awaiters unblock. + """ + # Mark closed first to prevent new requests + if not self._closed.is_set(): + self._closed.set() + # Snapshot current pending ids and futures + async with self._lock: + pending_items = list(self._pending_requests.items()) + # Clear the map eagerly to prevent races with late responses + self._pending_requests.clear() + + if not pending_items: + return + + # Best-effort notify client of cancellations + notifications = [] + for request_id, _future in pending_items: + notification = CancelledNotification( + params=CancelledParams(requestId=request_id, reason=reason) + ) + notifications.append(notification) + + try: + for note in notifications: + message = note.model_dump_json(exclude_none=True) + "\n" + await self._write_stream.send(message) + except Exception: + # Swallow transport errors during shutdown; proceed to cancel futures + logging.debug( + "Failed to send cancellation notifications during shutdown", exc_info=True + ) + + # Cancel futures so any waiters are released + for _request_id, future in pending_items: + if not future.done(): + future.set_exception(SessionError("Session closed")) + + +class NotificationManager: + """Broadcasts server-initiated listChanged notifications to sessions.""" + + def __init__(self, server: Any): + self._server = server + + async def _broadcast( + self, notification: JSONRPCMessage, session_ids: list[str] | None = None + ) -> None: + # Do not broadcast before server is started + if not getattr(self._server, "_started", False): + return + async with self._server._sessions_lock: + if session_ids is None: + sessions = list(self._server._sessions.values()) + else: + sessions = [ + self._server._sessions.get(sid) + for sid in session_ids + if sid in self._server._sessions + ] + for s in sessions: + if s is None: + continue + try: + await s.send_notification(notification) + except Exception: + logger.debug("Failed to notify a session", exc_info=True) + + async def notify_tool_list_changed(self, session_ids: list[str] | None = None) -> None: + await self._broadcast(ToolListChangedNotification(), session_ids) + + async def notify_resource_list_changed(self, session_ids: list[str] | None = None) -> None: + await self._broadcast(ResourceListChangedNotification(), session_ids) + + async def notify_prompt_list_changed(self, session_ids: list[str] | None = None) -> None: + await self._broadcast(PromptListChangedNotification(), session_ids) + + +class ServerSession: + """ + MCP server session handling a single client connection. + + Manages: + - Session state and lifecycle + - Client capabilities + - Request/response handling + - Notification sending + """ + + def __init__( + self, + server: Any, + session_id: str | None = None, + read_stream: Any | None = None, + write_stream: Any | None = None, + init_options: Any | None = None, + stateless: bool = False, + ): + """ + Initialize server session. + + Args: + server: Parent server instance + session_id: Session identifier (generated if not provided) + read_stream: Stream for reading messages + write_stream: Stream for writing messages + init_options: Initialization options + stateless: Whether session is stateless + """ + self.server = server + self.session_id = session_id or str(uuid.uuid4()) + self.read_stream = read_stream + self.write_stream = write_stream + self.init_options = init_options or {} + self.stateless = stateless + + # Session state + self.initialization_state = InitializationState.NOT_INITIALIZED + self.client_params: InitializeParams | None = None + self._session_data: dict[str, Any] = {} + + # Request management + self._request_manager = RequestManager(write_stream) if write_stream else None + + # Context for current request + self._current_context: Context | None = None + + def set_client_params(self, params: InitializeParams) -> None: + """Set client initialization parameters.""" + self.client_params = params + self.initialization_state = InitializationState.INITIALIZING + + def mark_initialized(self) -> None: + """Mark session as initialized.""" + self.initialization_state = InitializationState.INITIALIZED + + def check_client_capability(self, capability: ClientCapabilities) -> bool: + """ + Check if client has a specific capability. + + Args: + capability: Capability to check + + Returns: + True if client has capability + """ + if not self.client_params or not self.client_params.capabilities: + return False + + client_caps = self.client_params.capabilities + + # Check specific capabilities + # Use hasattr to check for attributes that might be in extra fields + if ( + hasattr(capability, "tools") + and capability.tools + and not (hasattr(client_caps, "tools") and client_caps.tools) + ): + return False + if ( + hasattr(capability, "resources") + and capability.resources + and not (hasattr(client_caps, "resources") and client_caps.resources) + ): + return False + if ( + hasattr(capability, "prompts") + and capability.prompts + and not (hasattr(client_caps, "prompts") and client_caps.prompts) + ): + return False + return not ( + hasattr(capability, "logging") + and capability.logging + and not (hasattr(client_caps, "logging") and client_caps.logging) + ) + + async def run(self) -> None: + """ + Run the session message loop. + + Reads messages from the stream and processes them. + """ + if not self.read_stream: + raise SessionError("No read stream available") + + try: + async for message in self.read_stream: + if message: + await self._process_message(message) + except asyncio.CancelledError: + pass + except Exception as e: + await self.server.logger.exception("Session error") + raise SessionError(f"Session error: {e}") from e + finally: + # Cleanup + if self._request_manager: + # Cancel any pending requests + await self._cleanup_pending_requests() + + async def _process_message(self, message: str) -> None: + """Process a single message.""" + try: + # Parse message + data = json.loads(message) + + # Check if it's a response to our request + if "id" in data and "method" not in data: + if self._request_manager: + logger.debug( + f"Session received response message id={data.get('id')} -> routing to RequestManager" + ) + await self._request_manager.handle_response(data) + return + + # Otherwise, process as incoming request + response = await self.server.handle_message(data, self) + + # Send response if any + if response and self.write_stream: + if hasattr(response, "model_dump_json"): + response_data = response.model_dump_json(exclude_none=True) + else: + response_data = json.dumps(response) + + if not response_data.endswith("\n"): + response_data += "\n" + + await self.write_stream.send(response_data) + + except json.JSONDecodeError: + await self._send_error_response( + None, + -32700, + "Parse error", + ) + except Exception as e: + await self._send_error_response( + None, + -32603, + f"Internal error: {e!s}", + ) + + async def _send_error_response( + self, + request_id: Any, + code: int, + message: str, + ) -> None: + """Send an error response.""" + if not self.write_stream: + return + + error_response = JSONRPCError( + id=str(request_id) if request_id else "null", + error={"code": code, "message": message}, + ) + + response_data = error_response.model_dump_json() + "\n" + await self.write_stream.send(response_data) + + async def _cleanup_pending_requests(self) -> None: + """Clean up any pending requests.""" + if self._request_manager: + # Cancel all pending futures and notify client + await self._request_manager.cancel_all(reason="Session closed") + + # Notification methods + async def send_notification(self, notification: JSONRPCMessage) -> None: + """Send a notification to the client.""" + if not self.write_stream: + return + + message = notification.model_dump_json(exclude_none=True) + "\n" + await self.write_stream.send(message) + + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + """Send a progress notification.""" + notification = ProgressNotification( + params=ProgressNotificationParams( + progressToken=progress_token, + progress=progress, + total=total, + message=message, + ) + ) + await self.send_notification(notification) + + async def send_log_message( + self, + level: LoggingLevel, + data: Any, + logger: str | None = None, + ) -> None: + """Send a log message notification.""" + notification = LoggingMessageNotification( + params=LoggingMessageParams( + level=level, + data=data, + logger=logger, + ) + ) + await self.send_notification(notification) + + async def send_tool_list_changed(self) -> None: + """Send tool list changed notification.""" + await self.send_notification(ToolListChangedNotification()) + + async def send_resource_list_changed(self) -> None: + """Send resource list changed notification.""" + await self.send_notification(ResourceListChangedNotification()) + + async def send_prompt_list_changed(self) -> None: + """Send prompt list changed notification.""" + await self.send_notification(PromptListChangedNotification()) + + # Server-initiated requests + async def create_message( + self, + messages: list[dict[str, Any]], + max_tokens: int, + system_prompt: str | None = None, + include_context: str | None = None, + temperature: float | None = None, + model_preferences: dict[str, Any] | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + timeout: float = 60.0, + ) -> CreateMessageResult: + """ + Send a sampling request to the client. + + Args: + messages: Messages to sample + max_tokens: Maximum tokens to generate + system_prompt: System prompt + include_context: Context to include + temperature: Sampling temperature + model_preferences: Model preferences + stop_sequences: Stop sequences + metadata: Request metadata + timeout: Request timeout + + Returns: + Sampling result + """ + if not self._request_manager: + raise SessionError("Cannot send requests without request manager") + + params = { + "messages": messages, + "maxTokens": max_tokens, + } + + # Add optional parameters + if system_prompt is not None: + params["systemPrompt"] = system_prompt + if include_context is not None: + params["includeContext"] = include_context + if temperature is not None: + params["temperature"] = temperature + if model_preferences is not None: + params["modelPreferences"] = model_preferences + if stop_sequences is not None: + params["stopSequences"] = stop_sequences + if metadata is not None: + params["metadata"] = metadata + + result = await self._request_manager.send_request( + "sampling/createMessage", + params, + timeout, + ) + + return CreateMessageResult(**result) + + async def list_roots(self, timeout: float = 60.0) -> ListRootsResult: + """ + Request roots list from the client. + + Args: + timeout: Request timeout + + Returns: + Roots list result + """ + if not self._request_manager: + raise SessionError("Cannot send requests without request manager") + + result = await self._request_manager.send_request( + "roots/list", + None, + timeout, + ) + + return ListRootsResult(**result) + + async def complete( + self, + ref: dict[str, Any], + argument: dict[str, Any], + timeout: float = 60.0, + ) -> CompleteResult: + """ + Request completion from the client. + + Args: + ref: Completion reference + argument: Completion argument + timeout: Request timeout + + Returns: + Completion result + """ + if not self._request_manager: + raise SessionError("Cannot send requests without request manager") + + result = await self._request_manager.send_request( + "completion/complete", + {"ref": ref, "argument": argument}, + timeout, + ) + + return CompleteResult(**result) + + async def elicit( + self, + message: str, + requested_schema: dict[str, Any] | None = None, + timeout: float = 300.0, + ) -> ElicitResult: + """ + Send an elicitation request to the client. + + Args: + message: Elicitation message to display + requested_schema: JSON schema for the requested response + timeout: Request timeout + + Returns: + Elicitation result + """ + if not self._request_manager: + raise SessionError("Cannot send requests without request manager") + + params: dict[str, Any] = { + "message": message, + } + + # Add schema if provided + if requested_schema is not None: + params["requestedSchema"] = requested_schema + + result = await self._request_manager.send_request( + "elicitation/create", + params, + timeout, + ) + + return ElicitResult(**result) + + # Context management + async def create_request_context(self) -> Context: + """Create a context for the current request.""" + context = Context(self.server) + context.set_session(self) + self._current_context = context + return context + + async def cleanup_request_context(self, context: Context) -> None: + """Clean up request context.""" + # Flush any pending notifications + await context._flush_notifications() + self._current_context = None diff --git a/libs/arcade-mcp-server/arcade_mcp_server/settings.py b/libs/arcade-mcp-server/arcade_mcp_server/settings.py new file mode 100644 index 00000000..53fe543d --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/settings.py @@ -0,0 +1,252 @@ +""" +MCP Settings Management + +Provides Pydantic-based settings with validation and environment variable support. +""" + +import os +from typing import Any + +from pydantic import Field, field_validator +from pydantic_settings import BaseSettings + + +class NotificationSettings(BaseSettings): + """Notification-related settings.""" + + rate_limit_per_minute: int = Field( + default=60, + description="Maximum notifications per minute per client", + ge=1, + le=1000, + ) + default_debounce_ms: int = Field( + default=100, + description="Default debounce time in milliseconds", + ge=0, + le=10000, + ) + max_queued_notifications: int = Field( + default=1000, + description="Maximum queued notifications per client", + ge=10, + le=10000, + ) + + model_config = {"env_prefix": "MCP_NOTIFICATION_"} + + +class TransportSettings(BaseSettings): + """Transport-related settings.""" + + session_timeout_seconds: int = Field( + default=300, + description="Session timeout in seconds", + ge=30, + le=3600, + ) + cleanup_interval_seconds: int = Field( + default=10, + description="Cleanup interval in seconds", + ge=1, + le=60, + ) + max_sessions: int = Field( + default=1000, + description="Maximum concurrent sessions", + ge=1, + le=10000, + ) + max_queue_size: int = Field( + default=1000, + description="Maximum queue size per session", + ge=10, + le=10000, + ) + + model_config = {"env_prefix": "MCP_TRANSPORT_"} + + +class ServerSettings(BaseSettings): + """Server-related settings.""" + + name: str = Field( + default="ArcadeMCP", + description="Server name", + ) + version: str = Field( + default="0.1.0dev", + description="Server version", + ) + title: str | None = Field( + default="Arcade MCP", + description="Server title for display", + ) + instructions: str | None = Field( + default=( + "ArcadeMCP provides access to a wide range of tools and toolkits." + "Use 'tools/list' to see available tools and 'tools/call' to execute them." + ), + description="Server instructions for clients", + ) + + model_config = {"env_prefix": "MCP_SERVER_"} + + +class MiddlewareSettings(BaseSettings): + """Middleware-related settings.""" + + enable_logging: bool = Field( + default=True, + description="Enable logging middleware", + ) + log_level: str = Field( + default="INFO", + description="Log level", + ) + enable_error_handling: bool = Field( + default=True, + description="Enable error handling middleware", + ) + mask_error_details: bool = Field( + default=False, + description="Mask error details in production", + ) + + @field_validator("log_level") + @classmethod + def validate_log_level(cls, v: str) -> str: + """Validate log level.""" + valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + v = v.upper() + if v not in valid_levels: + raise ValueError(f"Invalid log level: {v}. Must be one of {valid_levels}") + return v + + model_config = {"env_prefix": "MCP_MIDDLEWARE_"} + + +class ArcadeSettings(BaseSettings): + """Arcade-specific settings.""" + + api_key: str | None = Field( + default=None, + description="Arcade API key", + ) + api_url: str = Field( + default="https://api.arcade.dev", + description="Arcade API URL", + ) + auth_disabled: bool = Field( + default=False, + description="Disable authentication", + ) + server_secret: str | None = Field( + default="dev", + description="Server secret", + validation_alias="ARCADE_WORKER_SECRET", + ) + environment: str = Field( + default="dev", + description="Environment (dev or prod.)", + ) + user_id: str | None = Field( + default=None, + description="User ID for Arcade environment", + ) + + model_config = {"env_prefix": "ARCADE_"} + + +class ToolEnvironmentSettings(BaseSettings): + """Tool environment settings. + + Every environment variable that is not prefixed + with one of the prefixes for the other settings + will be added to the tool environment as an + available tool secret in the ToolContext + """ + + tool_environment: dict[str, Any] = Field( + default_factory=dict, + description="Tool environment", + ) + + def model_post_init(self, __context: Any) -> None: + """Populate tool_environment from process env if not provided.""" + if not self.tool_environment: + excluded_prefixes = ("MCP_", "_") + self.tool_environment = { + key: value + for key, value in os.environ.items() + if not any(key.startswith(prefix) for prefix in excluded_prefixes) + } + + model_config = { + "env_prefix": "", + "env_file": ".env", + "env_file_encoding": "utf-8", + "case_sensitive": False, + "extra": "allow", + } + + +class MCPSettings(BaseSettings): + """Main MCP settings container.""" + + # Sub-settings + notification: NotificationSettings = Field( + default_factory=NotificationSettings, + description="Notification settings", + ) + transport: TransportSettings = Field( + default_factory=TransportSettings, + description="Transport settings", + ) + server: ServerSettings = Field( + default_factory=ServerSettings, + description="Server settings", + ) + middleware: MiddlewareSettings = Field( + default_factory=MiddlewareSettings, + description="Middleware settings", + ) + arcade: ArcadeSettings = Field( + default_factory=ArcadeSettings, + description="Arcade integration settings", + ) + tool_environment: ToolEnvironmentSettings = Field( + default_factory=ToolEnvironmentSettings, + description="Tool environment settings", + ) + + # Global settings + debug: bool = Field( + default=False, + description="Enable debug mode", + ) + + model_config = { + "env_prefix": "MCP_", + "env_file": ".env", + "env_file_encoding": "utf-8", + "case_sensitive": False, + "extra": "allow", + } + + @classmethod + def from_env(cls) -> "MCPSettings": + """Create settings from environment variables.""" + return cls() + + def tool_secrets(self) -> dict[str, Any]: + """Get tool secrets.""" + return self.tool_environment.tool_environment + + def to_dict(self) -> dict[str, Any]: + """Convert settings to dictionary.""" + return self.model_dump(exclude_unset=True) + + +# Global settings instance +settings = MCPSettings.from_env() diff --git a/libs/arcade-mcp-server/arcade_mcp_server/transports/__init__.py b/libs/arcade-mcp-server/arcade_mcp_server/transports/__init__.py new file mode 100644 index 00000000..bcc95635 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/transports/__init__.py @@ -0,0 +1,12 @@ +"""MCP Transport implementations.""" + +from arcade_mcp_server.transports.http_session_manager import HTTPSessionManager +from arcade_mcp_server.transports.http_streamable import EventStore, HTTPStreamableTransport +from arcade_mcp_server.transports.stdio import StdioTransport + +__all__ = [ + "EventStore", + "HTTPSessionManager", + "HTTPStreamableTransport", + "StdioTransport", +] diff --git a/libs/arcade-mcp-server/arcade_mcp_server/transports/http_session_manager.py b/libs/arcade-mcp-server/arcade_mcp_server/transports/http_session_manager.py new file mode 100644 index 00000000..22127ca8 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/transports/http_session_manager.py @@ -0,0 +1,262 @@ +"""HTTP Session Manager for MCP servers. + +Manages HTTP streaming sessions with optional resumability via event store. +""" + +import contextlib +import logging +from collections.abc import AsyncIterator +from http import HTTPStatus +from typing import Optional +from uuid import uuid4 + +import anyio +from anyio.abc import TaskStatus +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from arcade_mcp_server.server import MCPServer +from arcade_mcp_server.session import ServerSession +from arcade_mcp_server.transports.http_streamable import ( + MCP_SESSION_ID_HEADER, + EventStore, + HTTPStreamableTransport, +) + +logger = logging.getLogger(__name__) + + +class HTTPSessionManager: + """Manages HTTP streaming sessions with optional resumability. + + This class abstracts session management, event storage, and request handling + for HTTP streaming transports. It handles: + + 1. Session tracking for clients + 2. Resumability via optional event store + 3. Connection management and lifecycle + 4. Request handling and transport setup + + Important: Only one HTTPSessionManager instance should be created per application. + The instance cannot be reused after its run() context has completed. + """ + + def __init__( + self, + server: MCPServer, + event_store: Optional[EventStore] = None, + json_response: bool = False, + stateless: bool = False, + ): + """Initialize HTTP session manager. + + Args: + server: The MCP server instance + event_store: Optional event store for resumability + json_response: Whether to use JSON responses instead of SSE + stateless: If True, creates fresh transport for each request + """ + self.server = server + self.event_store = event_store + self.json_response = json_response + self.stateless = stateless + + # Session tracking (only used if not stateless) + self._session_creation_lock = anyio.Lock() + self._server_instances: dict[str, HTTPStreamableTransport] = {} + + # Task group will be set during lifespan + self._task_group: Optional[anyio.abc.TaskGroup] = None + + # Thread-safe tracking of run() calls + self._run_lock = anyio.Lock() + self._has_started = False + + @contextlib.asynccontextmanager + async def run(self) -> AsyncIterator[None]: + """Run the session manager with lifecycle management. + + This creates and manages the task group for all session operations. + + Important: This method can only be called once per instance. + Create a new instance if you need to restart. + """ + async with self._run_lock: + if self._has_started: + raise RuntimeError( + "HTTPSessionManager.run() can only be called once per instance. " + "Create a new instance if you need to run again." + ) + self._has_started = True + + async with anyio.create_task_group() as tg: + self._task_group = tg + logger.info("HTTP session manager started") + try: + yield + finally: + logger.info("HTTP session manager shutting down") + tg.cancel_scope.cancel() + self._task_group = None + self._server_instances.clear() + + async def handle_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """Process ASGI request with proper session handling. + + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + if self._task_group is None: + raise RuntimeError("Task group is not initialized. Make sure to use run().") + + if self.stateless: + await self._handle_stateless_request(scope, receive, send) + else: + await self._handle_stateful_request(scope, receive, send) + + async def _handle_stateless_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """Process request in stateless mode - new transport per request.""" + logger.debug("Stateless mode: Creating new transport for this request") + + # Create transport without session ID in stateless mode + http_transport = HTTPStreamableTransport( + mcp_session_id=None, + is_json_response_enabled=self.json_response, + event_store=None, # No event store in stateless mode + ) + + # Start server in a new task + async def run_stateless_server( + *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED + ) -> None: + async with http_transport.connect() as streams: + read_stream, write_stream = streams + task_status.started() + try: + # Create a new session for this request + session = ServerSession( + server=self.server, + read_stream=read_stream, + write_stream=write_stream, + ) + + # Set the session on the transport + http_transport.session = session + + # Run the session (start + loop until closed) + await session.run() + + # Brief yield to allow cleanup + await anyio.sleep(0) + except Exception: + logger.exception("Stateless session crashed") + + if self._task_group is None: + raise RuntimeError("Task group not initialized") + await self._task_group.start(run_stateless_server) + + # Handle the HTTP request + await http_transport.handle_request(scope, receive, send) + + # Terminate the transport + await http_transport.terminate() + + async def _handle_stateful_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """Process request in stateful mode - maintain session state.""" + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + + # Existing session case + if request_mcp_session_id and request_mcp_session_id in self._server_instances: + transport = self._server_instances[request_mcp_session_id] + logger.debug("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) + return + + if request_mcp_session_id is None: + # New session case + logger.debug("Creating new transport") + async with self._session_creation_lock: + new_session_id = uuid4().hex + http_transport = HTTPStreamableTransport( + mcp_session_id=new_session_id, + is_json_response_enabled=self.json_response, + event_store=self.event_store, + ) + + if http_transport.mcp_session_id is None: + raise RuntimeError("MCP session ID not set") + self._server_instances[http_transport.mcp_session_id] = http_transport + logger.info(f"Created new transport with session ID: {new_session_id}") + + # Define the server runner + async def run_server( + *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED + ) -> None: + async with http_transport.connect() as streams: + read_stream, write_stream = streams + task_status.started() + try: + # Create a session for this connection + session = ServerSession( + server=self.server, + read_stream=read_stream, + write_stream=write_stream, + ) + + # Set the session on the transport + http_transport.session = session + + # Run the session (start + loop until closed) + await session.run() + + # Brief yield to allow cleanup + await anyio.sleep(0) + except Exception as e: + logger.error( + f"Session {http_transport.mcp_session_id} crashed: {e}", + exc_info=True, + ) + finally: + # Clean up on crash + if ( + http_transport.mcp_session_id + and http_transport.mcp_session_id in self._server_instances + and not http_transport.is_terminated + ): + logger.info( + f"Cleaning up crashed session {http_transport.mcp_session_id}" + ) + del self._server_instances[http_transport.mcp_session_id] + + if self._task_group is None: + raise RuntimeError("Task group not initialized") + await self._task_group.start(run_server) + + # Handle the HTTP request + await http_transport.handle_request(scope, receive, send) + else: + # Invalid session ID + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/transports/http_streamable.py b/libs/arcade-mcp-server/arcade_mcp_server/transports/http_streamable.py new file mode 100644 index 00000000..6bd895cb --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/transports/http_streamable.py @@ -0,0 +1,834 @@ +"""HTTP Streamable Transport for MCP servers. + +This module implements HTTP transport with Server-Sent Events (SSE) streaming support, +following the patterns from the sample library. + +Design overview +- The transport provides a duplex, in-process message channel between the HTTP layer + and the MCP session using anyio memory streams: + - read side (transport -> session): + - `_read_stream_writer` (SendStream[SessionMessage | Exception]) + - `_read_stream` (ReceiveStream[SessionMessage | Exception]) + - write side (session -> transport): + - `_write_stream` (SendStream[SessionMessage]) + - `_write_stream_reader` (ReceiveStream[SessionMessage]) + +- The transport writes inbound client messages (parsed from HTTP requests) to + `_read_stream_writer`; the session consumes them from `_read_stream`. + +- The session writes outbound server messages to `_write_stream`; the transport's + `message_router` task consumes them from `_write_stream_reader` and fans them out + to the correct per-request stream maintained in `_request_streams[request_id]`. + +- Response modes: + - JSON response mode: a single HTTP JSON response is returned by awaiting the + first terminal message (JSONRPCResponse or JSONRPCError) for the request. + - SSE response mode: a long-lived stream of events is sent as SSE; the stream + is closed when a terminal message is observed for the request. + +- A standalone GET SSE stream uses the special key `GET_STREAM_KEY` to deliver + server-initiated events without a preceding POST. + +- Optional resumability can be enabled by providing an `EventStore` implementation. +""" + +import json +import logging +import re +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import asynccontextmanager +from dataclasses import dataclass +from http import HTTPStatus +from typing import cast + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import BaseModel, TypeAdapter +from sse_starlette import EventSourceResponse +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from arcade_mcp_server.session import ServerSession +from arcade_mcp_server.types import ( + INTERNAL_ERROR, + INVALID_REQUEST, + PARSE_ERROR, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + MCPMessage, + RequestId, + SessionMessage, +) + +logger = logging.getLogger(__name__) + +# Header names +MCP_SESSION_ID_HEADER = "Mcp-Session-Id" +MCP_PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version" +LAST_EVENT_ID_HEADER = "Last-Event-ID" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + +# Special key for the standalone GET stream +GET_STREAM_KEY = "_GET_stream" + +# Session ID validation pattern (visible ASCII characters) +SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") + +# Type aliases +StreamId = str +EventId = str + + +@dataclass +class EventMessage: + """A JSONRPCMessage with an optional event ID for stream resumability.""" + + message: MCPMessage + event_id: str | None = None + + +EventCallback = Callable[[EventMessage], Awaitable[None]] + + +class EventStore: + """Interface for resumability support via event storage.""" + + async def store_event(self, stream_id: StreamId, message: MCPMessage) -> EventId: + """Store an event for later retrieval.""" + raise NotImplementedError + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """Replay events after the specified event ID.""" + raise NotImplementedError + + +class HTTPStreamableTransport: + """HTTP transport with SSE streaming support for MCP. + + Responsibilities + - Parse HTTP requests into JSON-RPC messages and enqueue them on the + transportโ†’session read stream (via `_read_stream_writer`). + - Consume sessionโ†’transport messages from `_write_stream_reader` in a + background `message_router`, routing them to per-request streams in + `_request_streams` keyed by the JSON-RPC request id (or `GET_STREAM_KEY` + for the standalone GET SSE stream). + - Serve responses back to the HTTP client: + - JSON response mode: wait for the first terminal response and return a + single `application/json` body. + - SSE mode: stream each outbound `SessionMessage` as an SSE event with + appropriate headers and close on terminal response. + + Streams created in `connect()` + - `_read_stream_writer` / `_read_stream`: transportโ†’session channel for inbound + client messages. + - `_write_stream` / `_write_stream_reader`: sessionโ†’transport channel for outbound + server messages, consumed by the `message_router`. + + These in-memory channels provide backpressure and decouple HTTP from the session + loop while keeping the implementation fully async. + """ + + def __init__( + self, + mcp_session_id: str | None, + session: ServerSession | None = None, + is_json_response_enabled: bool = False, + event_store: EventStore | None = None, + ): + """Initialize HTTP streamable transport. + + Args: + mcp_session_id: Session identifier (must be visible ASCII) + session: Server session for handling requests + is_json_response_enabled: If True, return JSON responses instead of SSE + event_store: Optional event store for resumability + """ + if mcp_session_id and not SESSION_ID_PATTERN.fullmatch(mcp_session_id): + raise ValueError("Session ID must only contain visible ASCII characters") + + self.mcp_session_id = mcp_session_id + self.session = session + self.is_json_response_enabled = is_json_response_enabled + self._event_store = event_store + self._request_streams: dict[ + RequestId, + tuple[MemoryObjectSendStream[EventMessage], MemoryObjectReceiveStream[EventMessage]], + ] = {} + self._terminated = False + + # Streams for connection + self._read_stream_writer: MemoryObjectSendStream[str | Exception] | None = None + self._read_stream: MemoryObjectReceiveStream[str | Exception] | None = None + self._write_stream: MemoryObjectSendStream[str | SessionMessage] | None = None + self._write_stream_reader: MemoryObjectReceiveStream[str | SessionMessage] | None = None + + def _parse_mcp_message(self, obj: str | dict[str, object] | MCPMessage) -> MCPMessage: + """Parse incoming data into a typed MCPMessage. + + Accepts a raw JSON string, already-parsed dict, or an existing MCPMessage. + """ + if isinstance(obj, BaseModel): + # Already a pydantic model; trust caller and cast to MCPMessage + return cast(MCPMessage, obj) + + parsed: dict[str, object] + if isinstance(obj, str): + try: + maybe = json.loads(obj) + except Exception as exc: # parse error - treat as invalid request + raise ValueError(f"Invalid JSON: {exc}") + if not isinstance(maybe, dict): + raise TypeError("JSON must be an object") + parsed = maybe + elif isinstance(obj, dict): + parsed = obj + else: + raise TypeError("Unsupported message type") + + try: + return TypeAdapter(MCPMessage).validate_python(parsed) + except Exception: + # Fallback: treat as error + return JSONRPCError( + id=str(parsed.get("id", "null")), + error={"code": -32600, "message": "Invalid message"}, + ) + + @property + def is_terminated(self) -> bool: + """Check if transport has been terminated.""" + return self._terminated + + def _create_error_response( + self, + error_message: str, + status_code: HTTPStatus, + error_code: int = INVALID_REQUEST, + headers: dict[str, str] | None = None, + ) -> Response: + """Create an error response.""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + error_response = JSONRPCError( + jsonrpc="2.0", + id="server-error", + error=ErrorData(code=error_code, message=error_message).model_dump(exclude_none=True), + ) + + return Response( + error_response.model_dump_json(by_alias=True, exclude_none=True), + status_code=status_code, + headers=response_headers, + ) + + def _create_json_response( + self, + response_message: JSONRPCMessage | None, + status_code: HTTPStatus = HTTPStatus.OK, + headers: dict[str, str] | None = None, + ) -> Response: + """Create a JSON response.""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + return Response( + response_message.model_dump_json(by_alias=True, exclude_none=True) + if response_message + else None, + status_code=status_code, + headers=response_headers, + ) + + def _get_session_id(self, request: Request) -> str | None: + """Extract session ID from request headers.""" + return request.headers.get(MCP_SESSION_ID_HEADER) + + def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: + """Create event data dictionary from EventMessage.""" + event_data = { + "event": "message", + "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), + } + + if event_message.event_id: + event_data["id"] = event_message.event_id + + return event_data + + async def _clean_up_memory_streams(self, request_id: RequestId) -> None: + """Clean up memory streams for a request.""" + if request_id in self._request_streams: + try: + await self._request_streams[request_id][0].aclose() + await self._request_streams[request_id][1].aclose() + except Exception: + logger.debug("Error closing memory streams - may already be closed") + finally: + self._request_streams.pop(request_id, None) + + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: + """Handle incoming HTTP requests.""" + request = Request(scope, receive) + + if self._terminated: + response = self._create_error_response( + "Not Found: Session has been terminated", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + + if request.method == "POST": + await self._handle_post_request(scope, request, receive, send) + elif request.method == "GET": + await self._handle_get_request(request, send) + elif request.method == "DELETE": + await self._handle_delete_request(request, send) + else: + await self._handle_unsupported_request(request, send) + + def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: + """Check if request accepts required media types.""" + accept_header = request.headers.get("accept", "") + accept_types = [media_type.strip() for media_type in accept_header.split(",")] + + has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) + has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) + + return has_json, has_sse + + def _check_content_type(self, request: Request) -> bool: + """Check if request has correct Content-Type.""" + content_type = request.headers.get("content-type", "") + content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")] + + return any(part == CONTENT_TYPE_JSON for part in content_type_parts) + + async def _handle_post_request( + self, scope: Scope, request: Request, receive: Receive, send: Send + ) -> None: + """Handle POST requests containing JSON-RPC messages.""" + writer = self._read_stream_writer + if writer is None: + raise ValueError("No read stream writer available. Ensure connect() is called first.") + + try: + # Check Accept headers + has_json, has_sse = self._check_accept_headers(request) + if self.is_json_response_enabled: + if not has_json: + response = self._create_error_response( + "Not Acceptable: Client must accept application/json", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(scope, receive, send) + return + else: + if not has_sse: + response = self._create_error_response( + "Not Acceptable: Client must accept text/event-stream", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(scope, receive, send) + return + + # Validate Content-Type for POST payloads only when JSON mode + if self.is_json_response_enabled and not self._check_content_type(request): + response = self._create_error_response( + "Unsupported Media Type: Content-Type must be application/json", + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + ) + await response(scope, receive, send) + return + + # Parse the body + body = await request.body() + body_str = body.decode("utf-8") if isinstance(body, (bytes, bytearray)) else str(body) + + try: + raw_message = json.loads(body) + except json.JSONDecodeError as e: + response = self._create_error_response( + f"Parse error: {e!s}", HTTPStatus.BAD_REQUEST, PARSE_ERROR + ) + await response(scope, receive, send) + return + + # Accept either well-typed messages or raw dicts + message_dict = raw_message if isinstance(raw_message, dict) else {} + try: + message = self._parse_mcp_message(message_dict or body_str) + except Exception as exc: + response = self._create_error_response( + f"Invalid request: {exc}", + HTTPStatus.BAD_REQUEST, + INVALID_REQUEST, + ) + await response(scope, receive, send) + return + + # Check if this is an initialization request + # Determine initialization by dict method when validation fallback used + is_initialization_request = ( + isinstance(message, JSONRPCRequest) and message.method == "initialize" + ) + + if is_initialization_request: + if self.mcp_session_id: + request_session_id = self._get_session_id(request) + if request_session_id and request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + elif not await self._validate_request_headers(request, send): + return + + # For notifications and responses, return 202 Accepted + if not isinstance(message, JSONRPCRequest): + response = self._create_json_response(None, HTTPStatus.ACCEPTED) + await response(scope, receive, send) + + # Process the message + await writer.send(body_str if body_str.endswith("\n") else body_str + "\n") + return + + # Handle requests + request_id = str(message.id) + self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) + request_stream_reader = self._request_streams[request_id][1] + + if self.is_json_response_enabled: + # JSON response mode + await writer.send(body_str if body_str.endswith("\n") else body_str + "\n") + + try: + response_message = None + async for event_message in request_stream_reader: + if isinstance(event_message.message, (JSONRPCResponse, JSONRPCError)): + response_message = event_message.message + break + + if response_message: + response = self._create_json_response(response_message) + await response(scope, receive, send) + else: + logger.error("No response received before stream closed") + response = self._create_error_response( + "Error processing request: No response received", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + except Exception: + logger.exception("Error processing JSON response") + response = self._create_error_response( + "Error processing request", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + finally: + await self._clean_up_memory_streams(request_id) + else: + # SSE response mode + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, str] + ](0) + + async def sse_writer() -> None: + try: + async with sse_stream_writer, request_stream_reader: + async for event_message in request_stream_reader: + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + + if isinstance( + event_message.message, (JSONRPCResponse, JSONRPCError) + ): + break + except Exception: + logger.exception("Error in SSE writer") + finally: + logger.debug("Closing SSE writer") + await self._clean_up_memory_streams(request_id) + + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}), + } + + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + try: + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) + await writer.send(body_str if body_str.endswith("\n") else body_str + "\n") + except Exception: + logger.exception("SSE response error") + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(request_id) + + except Exception as err: + logger.exception("Error handling POST request") + response = self._create_error_response( + f"Error handling POST request: {err}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + if writer: + await writer.send(Exception(err)) + + async def _handle_get_request(self, request: Request, send: Send) -> None: + """Handle GET request to establish SSE.""" + writer = self._read_stream_writer + if writer is None: + raise ValueError("No read stream writer available. Ensure connect() is called first.") + + # Validate Accept header + _, has_sse = self._check_accept_headers(request) + + if not has_sse: + error_response = self._create_error_response( + "Not Acceptable: Client must accept text/event-stream", + HTTPStatus.NOT_ACCEPTABLE, + ) + await error_response(request.scope, request.receive, send) + return + + if not await self._validate_request_headers(request, send): + return + + # Handle resumability + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): + await self._replay_events(last_event_id, request, send) + return + + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Check if we already have an active GET stream + if GET_STREAM_KEY in self._request_streams: + error_response = self._create_error_response( + "Conflict: Only one SSE stream is allowed per session", + HTTPStatus.CONFLICT, + ) + await error_response(request.scope, request.receive, send) + return + + # Create SSE stream + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + + async def standalone_sse_writer() -> None: + try: + self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[ + EventMessage + ](0) + standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] + + async with sse_stream_writer, standalone_stream_reader: + async for event_message in standalone_stream_reader: + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + except Exception: + logger.exception("Error in standalone SSE writer") + finally: + logger.debug("Closing standalone SSE writer") + await self._clean_up_memory_streams(GET_STREAM_KEY) + + sse_response: EventSourceResponse = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=standalone_sse_writer, + headers=headers, + ) + + try: + await sse_response(request.scope, request.receive, send) + except Exception: + logger.exception("Error in standalone SSE response") + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(GET_STREAM_KEY) + + async def _handle_delete_request(self, request: Request, send: Send) -> None: + """Handle DELETE requests for session termination.""" + if not self.mcp_session_id: + response = self._create_error_response( + "Method Not Allowed: Session termination not supported", + HTTPStatus.METHOD_NOT_ALLOWED, + ) + await response(request.scope, request.receive, send) + return + + if not await self._validate_request_headers(request, send): + return + + await self.terminate() + + response = self._create_json_response(None, HTTPStatus.OK) + await response(request.scope, request.receive, send) + + async def terminate(self) -> None: + """Terminate the current session.""" + self._terminated = True + logger.info(f"Terminating session: {self.mcp_session_id}") + + # Close all request streams + request_stream_keys = list(self._request_streams.keys()) + for key in request_stream_keys: + await self._clean_up_memory_streams(key) + self._request_streams.clear() + + try: + if self._read_stream_writer: + await self._read_stream_writer.aclose() + if self._read_stream: + await self._read_stream.aclose() + if self._write_stream_reader: + await self._write_stream_reader.aclose() + if self._write_stream: + await self._write_stream.aclose() + except Exception as e: + logger.debug(f"Error closing streams: {e}") + + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: + """Handle unsupported HTTP methods.""" + headers = { + "Content-Type": CONTENT_TYPE_JSON, + "Allow": "GET, POST, DELETE", + } + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + response = self._create_error_response( + "Method Not Allowed", + HTTPStatus.METHOD_NOT_ALLOWED, + headers=headers, + ) + await response(request.scope, request.receive, send) + + async def _validate_request_headers(self, request: Request, send: Send) -> bool: + """Validate request headers.""" + return await self._validate_session(request, send) + + async def _validate_session(self, request: Request, send: Send) -> bool: + """Validate session ID in request.""" + if not self.mcp_session_id: + return True + + request_session_id = self._get_session_id(request) + + if not request_session_id: + response = self._create_error_response( + "Bad Request: Missing session ID", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + if request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(request.scope, request.receive, send) + return False + + return True + + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: + """Replay events after the specified event ID.""" + event_store = self._event_store + if not event_store: + return + + try: + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, str] + ](0) + + async def replay_sender() -> None: + try: + async with sse_stream_writer: + + async def send_event(event_message: EventMessage) -> None: + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + + stream_id = await event_store.replay_events_after(last_event_id, send_event) + + if stream_id and stream_id not in self._request_streams: + self._request_streams[stream_id] = anyio.create_memory_object_stream[ + EventMessage + ](0) + msg_reader = self._request_streams[stream_id][1] + + async with msg_reader: + async for event_message in msg_reader: + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + except Exception: + logger.exception("Error in replay sender") + + sse_response: EventSourceResponse = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=replay_sender, + headers=headers, + ) + + try: + await sse_response(request.scope, request.receive, send) + except Exception: + logger.exception("Error in replay response") + finally: + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + + except Exception: + logger.exception("Error replaying events") + error_response = self._create_error_response( + "Error replaying events", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await error_response(request.scope, request.receive, send) + + @asynccontextmanager + async def connect( + self, + ) -> AsyncIterator[ + tuple[ + MemoryObjectReceiveStream[str | Exception], + MemoryObjectSendStream[str | SessionMessage], + ] + ]: + """Context manager providing read and write streams for connection. + + Creates the in-memory channels used by the transport and starts the + `message_router` task responsible for routing outbound messages from + the session to the correct per-request stream (or the standalone GET + stream identified by `GET_STREAM_KEY`). + """ + # Create memory streams + read_stream_writer, read_stream = anyio.create_memory_object_stream[str | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[str | SessionMessage]( + 0 + ) + + # Store the streams + self._read_stream_writer = read_stream_writer + self._read_stream = read_stream + self._write_stream_reader = write_stream_reader + self._write_stream = write_stream + + # Start message router + async with anyio.create_task_group() as tg: + + async def message_router() -> None: + try: + async for session_message in write_stream_reader: + # Accept either a SessionMessage wrapper or a raw JSON string + try: + if isinstance(session_message, SessionMessage): + message = session_message.message + elif isinstance(session_message, str): + message = self._parse_mcp_message(session_message) + elif isinstance(session_message, BaseModel): + message = cast(JSONRPCMessage, session_message) + else: + logger.error( + f"Unsupported outbound message type: {type(session_message)}" + ) + continue + except Exception: + logger.exception("Failed to parse outbound message from session") + continue + target_request_id = None + + # Check if this is a response + if isinstance(message, (JSONRPCResponse, JSONRPCError)): + target_request_id = str(message.id) + + request_stream_id = ( + target_request_id if target_request_id else GET_STREAM_KEY + ) + + # Store event if we have an event store + event_id = None + if self._event_store: + event_id = await self._event_store.store_event( + request_stream_id, + message, # type: ignore[arg-type] + ) + logger.debug(f"Stored {event_id} from {request_stream_id}") + + if request_stream_id in self._request_streams: + try: + await self._request_streams[request_stream_id][0].send( + EventMessage(message, event_id) # type: ignore[arg-type] + ) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + self._request_streams.pop(request_stream_id, None) + except Exception: + logger.exception("Error in message router") + + tg.start_soon(message_router) + + try: + yield read_stream, write_stream + finally: + for stream_id in list(self._request_streams.keys()): + await self._clean_up_memory_streams(stream_id) + self._request_streams.clear() + + try: + await read_stream_writer.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() + await write_stream.aclose() + except Exception as e: + logger.debug(f"Error closing streams: {e}") diff --git a/libs/arcade-mcp-server/arcade_mcp_server/transports/stdio.py b/libs/arcade-mcp-server/arcade_mcp_server/transports/stdio.py new file mode 100644 index 00000000..8ed830c0 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/transports/stdio.py @@ -0,0 +1,209 @@ +""" +Stdio Transport + +Provides stdio (stdin/stdout) transport for MCP communication. +""" + +import asyncio +import contextlib +import logging +import queue +import signal +import sys +import threading +import uuid +from collections.abc import AsyncIterator +from typing import Any + +from arcade_mcp_server.exceptions import TransportError +from arcade_mcp_server.session import ServerSession + +logger = logging.getLogger("arcade.mcp.transports.stdio") + + +class StdioWriteStream: + """Write stream implementation for stdio.""" + + def __init__(self, write_queue: queue.Queue[str | None]): + self.write_queue = write_queue + + async def send(self, data: str) -> None: + """Send data to stdout.""" + if not data.endswith("\n"): + data += "\n" + await asyncio.to_thread(self.write_queue.put, data) + + +class StdioReadStream: + """Read stream implementation for stdio.""" + + def __init__(self, read_queue: queue.Queue[str | None]): + self.read_queue = read_queue + self._running = True + + def stop(self) -> None: + """Stop the read stream.""" + self._running = False + + def __aiter__(self) -> AsyncIterator[str]: + return self + + async def __anext__(self) -> str: + if not self._running: + raise StopAsyncIteration + try: + line = await asyncio.to_thread(self.read_queue.get) + except asyncio.CancelledError: + raise StopAsyncIteration + except Exception as e: + logger.exception("Error reading from stdin") + raise TransportError(f"Read error: {e}") from e + if line is None or not self._running: + raise StopAsyncIteration + return line + + +class StdioTransport: + """ + Stdio transport implementation for stdio communication. + + This transport uses stdin/stdout for MCP communication, + suitable for command-line tools and scripts. + """ + + def __init__(self, name: str = "stdio"): + """Initialize stdio transport.""" + self.name = name + self.read_queue: queue.Queue[str | None] = queue.Queue() + self.write_queue: queue.Queue[str | None] = queue.Queue() + self.reader_thread: threading.Thread | None = None + self.writer_thread: threading.Thread | None = None + self._shutdown_event = asyncio.Event() + self._running = False + self._sessions: dict[str, ServerSession] = {} + + async def start(self) -> None: + """Start the transport.""" + # Component start is handled here directly + + # Start I/O threads + self._running = True + self.reader_thread = threading.Thread( + target=self._reader_loop, + daemon=True, + name=f"{self.name}-reader", + ) + self.writer_thread = threading.Thread( + target=self._writer_loop, + daemon=True, + name=f"{self.name}-writer", + ) + self.reader_thread.start() + self.writer_thread.start() + + # Set up signal handlers + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, lambda: asyncio.create_task(self.stop())) + except NotImplementedError: + # Windows doesn't support POSIX signals + if sys.platform == "win32": + logger.warning("Signal handling not fully supported on Windows") + else: + logger.warning(f"Failed to set up signal handler for {sig}") + + async def stop(self) -> None: + """Stop the transport.""" + if not self._running: + return + + logger.info("Stopping stdio transport") + self._running = False + + # Signal threads to stop + self.read_queue.put(None) + self.write_queue.put(None) + + # Wait for threads to finish + if self.reader_thread and self.reader_thread.is_alive(): + self.reader_thread.join(timeout=1.0) + if self.writer_thread and self.writer_thread.is_alive(): + self.writer_thread.join(timeout=1.0) + + # Set shutdown event + self._shutdown_event.set() + + def _reader_loop(self) -> None: + """Reader thread loop.""" + try: + for line in sys.stdin: + if not self._running: + break + self.read_queue.put(line.strip()) + except Exception: + logger.exception("Error in reader thread") + finally: + self.read_queue.put(None) # Signal EOF + + def _writer_loop(self) -> None: + """Writer thread loop.""" + try: + while self._running: + msg = self.write_queue.get() + if msg is None: + break + sys.stdout.write(msg) + sys.stdout.flush() + except Exception: + logger.exception("Error in writer thread") + + @contextlib.asynccontextmanager + async def connect_session(self, **options: Any) -> AsyncIterator[ServerSession]: + """ + Create a stdio session. + + Since stdio is inherently single-session, this will fail + if a session is already active. + """ + # Check if already have a session + sessions = await self.list_sessions() + if sessions: + raise TransportError("Stdio transport only supports one session") + + # Create session + session_id = str(uuid.uuid4()) + read_stream = StdioReadStream(self.read_queue) + write_stream = StdioWriteStream(self.write_queue) + session = ServerSession( + server=None, # set by the caller using run_connection; not used here + session_id=session_id, + read_stream=read_stream, + write_stream=write_stream, + init_options=options, + stateless=True, + ) + + # Register session + await self.register_session(session) + + try: + yield session + finally: + # Cleanup + read_stream.stop() + await self.unregister_session(session_id) + + async def wait_for_shutdown(self) -> None: + """Wait for the transport to shut down.""" + await self._shutdown_event.wait() + + # Minimal session registry to support connect_session lifecycle + async def list_sessions(self) -> list[str]: + return list(self._sessions.keys()) + + async def register_session(self, session: ServerSession) -> None: + self._sessions[session.session_id] = session + + async def unregister_session(self, session_id: str) -> None: + self._sessions.pop(session_id, None) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/types.py b/libs/arcade-mcp-server/arcade_mcp_server/types.py new file mode 100644 index 00000000..ddb38e82 --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/types.py @@ -0,0 +1,666 @@ +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum +from typing import Any, Generic, Literal, TypeAlias, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + +# ----------------------------------------------------------------------------- +# JSON-RPC constants +# ----------------------------------------------------------------------------- + +JSONRPC_VERSION: Literal["2.0"] = "2.0" +LATEST_PROTOCOL_VERSION: str = "2025-06-18" + +# ----------------------------------------------------------------------------- +# Basic types +# ----------------------------------------------------------------------------- + +ProgressToken = str | int +Cursor = str +RequestId = str | int +AnyFunction: TypeAlias = Callable[..., Any] + + +# ----------------------------------------------------------------------------- +# Base JSON-RPC shapes +# ----------------------------------------------------------------------------- + + +class Request(BaseModel): + method: str + params: Any = None + + model_config = ConfigDict(extra="allow", populate_by_name=True) + + +class Notification(BaseModel): + method: str + params: Any = None + + model_config = ConfigDict(extra="allow", populate_by_name=True) + + +class Result(BaseModel): + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + model_config = ConfigDict(extra="allow", populate_by_name=True) + + +class JSONRPCMessage(BaseModel): + jsonrpc: Literal["2.0"] = Field(default=JSONRPC_VERSION, frozen=True) + + model_config = ConfigDict(extra="allow") + + +class JSONRPCRequest(JSONRPCMessage, Request): + id: RequestId + + +T = TypeVar("T", bound=Result) + + +class JSONRPCResponse(JSONRPCMessage, Generic[T]): + id: RequestId + result: T | dict[str, Any] + + +# Standard JSON-RPC error codes +PARSE_ERROR = -32700 +INVALID_REQUEST = -32600 +METHOD_NOT_FOUND = -32601 +INVALID_PARAMS = -32602 +INTERNAL_ERROR = -32603 + + +class ErrorData(BaseModel): + code: int + message: str + data: Any | None = None + + +class JSONRPCError(JSONRPCMessage): + id: RequestId + error: dict[str, Any] + + +# ----------------------------------------------------------------------------- +# Transport types +# ----------------------------------------------------------------------------- + + +@dataclass +class SessionMessage: + """Wrapper for messages in transport sessions.""" + + message: JSONRPCMessage + + +# ----------------------------------------------------------------------------- +# Initialization +# ----------------------------------------------------------------------------- + + +class BaseMetadata(BaseModel): + name: str + title: str | None = None + + model_config = ConfigDict(extra="allow") + + +class Implementation(BaseMetadata): + version: str + + +class ClientCapabilities(BaseModel): + experimental: dict[str, object] | None = None + roots: dict[str, Any] | None = None + sampling: dict[str, Any] | None = None + elicitation: dict[str, Any] | None = None + + model_config = ConfigDict(extra="allow") + + +class ServerCapabilities(BaseModel): + experimental: dict[str, object] | None = None + logging: dict[str, Any] | None = None + completions: dict[str, Any] | None = None + prompts: dict[str, Any] | None = None + resources: dict[str, Any] | None = None + tools: dict[str, Any] | None = None + + model_config = ConfigDict(extra="allow") + + +class InitializeParams(BaseModel): + protocolVersion: str + capabilities: ClientCapabilities = Field(default_factory=ClientCapabilities) + clientInfo: Implementation + + +class InitializeRequest(JSONRPCRequest): + method: Literal["initialize"] = Field(default="initialize", frozen=True) + params: InitializeParams + + +class InitializeResult(Result): + protocolVersion: str + capabilities: ServerCapabilities + serverInfo: Implementation + instructions: str | None = None + + +class InitializedNotification(JSONRPCMessage, Notification): + method: Literal["notifications/initialized"] = Field( + default="notifications/initialized", frozen=True + ) + + +# ----------------------------------------------------------------------------- +# Ping +# ----------------------------------------------------------------------------- + + +class PingRequest(JSONRPCRequest): + method: Literal["ping"] = Field(default="ping", frozen=True) + + +# ----------------------------------------------------------------------------- +# Progress notifications +# ----------------------------------------------------------------------------- + + +class ProgressNotificationParams(BaseModel): + progressToken: ProgressToken + progress: float + total: float | None = None + message: str | None = None + + +class ProgressNotification(JSONRPCMessage, Notification): + method: Literal["notifications/progress"] = Field(default="notifications/progress", frozen=True) + params: ProgressNotificationParams + + +# ----------------------------------------------------------------------------- +# Pagination +# ----------------------------------------------------------------------------- + + +class PaginatedRequest(JSONRPCRequest): + params: dict[str, Any] | None = None + + +class PaginatedResult(Result): + nextCursor: Cursor | None = None + + +# ----------------------------------------------------------------------------- +# Annotations (used across resources, content, etc.) +# ----------------------------------------------------------------------------- + +Role = Literal["user", "assistant"] + + +class Annotations(BaseModel): + audience: list[Role] | None = None + priority: float | None = None + lastModified: str | None = None + + model_config = ConfigDict(extra="allow") + + +# ----------------------------------------------------------------------------- +# Resources +# ----------------------------------------------------------------------------- + + +class Resource(BaseMetadata): + uri: str + description: str | None = None + mimeType: str | None = None + annotations: Annotations | None = None + size: int | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +class ListResourcesRequest(PaginatedRequest): + method: Literal["resources/list"] = Field(default="resources/list", frozen=True) + + +class ListResourcesResult(PaginatedResult): + resources: list[Resource] = Field(default_factory=list) + + +class ListResourceTemplatesRequest(PaginatedRequest): + method: Literal["resources/templates/list"] = Field( + default="resources/templates/list", frozen=True + ) + + +class ResourceTemplate(BaseMetadata): + uriTemplate: str + description: str | None = None + mimeType: str | None = None + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +class ListResourceTemplatesResult(PaginatedResult): + resourceTemplates: list[ResourceTemplate] = Field(default_factory=list) + + +class ReadResourceParams(BaseModel): + uri: str + + +class ReadResourceRequest(JSONRPCRequest): + method: Literal["resources/read"] = Field(default="resources/read", frozen=True) + params: ReadResourceParams + + +class ResourceContents(BaseModel): + uri: str + mimeType: str | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +class TextResourceContents(ResourceContents): + text: str + + +class BlobResourceContents(ResourceContents): + blob: str + + +class ReadResourceResult(Result): + contents: list[TextResourceContents | BlobResourceContents] + + +class ResourceListChangedNotification(JSONRPCMessage, Notification): + method: Literal["notifications/resources/list_changed"] = Field( + default="notifications/resources/list_changed", frozen=True + ) + + +class ResourceUpdatedNotificationParams(BaseModel): + uri: str + + +class ResourceUpdatedNotification(JSONRPCMessage, Notification): + method: Literal["notifications/resources/updated"] = Field( + default="notifications/resources/updated", frozen=True + ) + params: ResourceUpdatedNotificationParams + + +class SubscribeParams(BaseModel): + uri: str + + +class SubscribeRequest(JSONRPCRequest): + method: Literal["resources/subscribe"] = Field(default="resources/subscribe", frozen=True) + params: SubscribeParams + + +class UnsubscribeParams(BaseModel): + uri: str + + +class UnsubscribeRequest(JSONRPCRequest): + method: Literal["resources/unsubscribe"] = Field(default="resources/unsubscribe", frozen=True) + params: UnsubscribeParams + + +# ----------------------------------------------------------------------------- +# Prompts +# ----------------------------------------------------------------------------- + + +class PromptArgument(BaseMetadata): + description: str | None = None + required: bool | None = None + + +class Prompt(BaseMetadata): + description: str | None = None + arguments: list[PromptArgument] | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +class ListPromptsRequest(PaginatedRequest): + method: Literal["prompts/list"] = Field(default="prompts/list", frozen=True) + + +class ListPromptsResult(PaginatedResult): + prompts: list[Prompt] = Field(default_factory=list) + + +class PromptMessage(BaseModel): + role: Role + content: dict[str, Any] + + +class GetPromptParams(BaseModel): + name: str + arguments: dict[str, str] | None = None + + +class GetPromptRequest(JSONRPCRequest): + method: Literal["prompts/get"] = Field(default="prompts/get", frozen=True) + params: GetPromptParams + + +class GetPromptResult(Result): + description: str | None = None + messages: list[PromptMessage] + + +class PromptListChangedNotification(JSONRPCMessage, Notification): + method: Literal["notifications/prompts/list_changed"] = Field( + default="notifications/prompts/list_changed", frozen=True + ) + + +# ----------------------------------------------------------------------------- +# Tools +# ----------------------------------------------------------------------------- + + +class ToolAnnotations(BaseModel): + title: str | None = None + readOnlyHint: bool | None = None + destructiveHint: bool | None = None + idempotentHint: bool | None = None + openWorldHint: bool | None = None + + model_config = ConfigDict(extra="allow") + + +class MCPTool(BaseModel): + name: str + description: str | None = None + inputSchema: dict[str, Any] + outputSchema: dict[str, Any] | None = None + annotations: ToolAnnotations | None = None + title: str | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +class ListToolsRequest(PaginatedRequest): + method: Literal["tools/list"] = Field(default="tools/list", frozen=True) + + +class ListToolsResult(PaginatedResult): + tools: list[MCPTool] + + +class ToolListChangedNotification(JSONRPCMessage, Notification): + method: Literal["notifications/tools/list_changed"] = Field( + default="notifications/tools/list_changed", frozen=True + ) + + +class CallToolParams(BaseModel): + name: str + arguments: dict[str, Any] | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + model_config = ConfigDict(extra="allow", populate_by_name=True) + + +class CallToolRequest(JSONRPCRequest): + method: Literal["tools/call"] = Field(default="tools/call", frozen=True) + params: CallToolParams + + +class TextContent(BaseModel): + type: Literal["text"] + text: str + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +class ImageContent(BaseModel): + type: Literal["image"] + data: str + mimeType: str + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +class AudioContent(BaseModel): + type: Literal["audio"] + data: str + mimeType: str + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +class ResourceLink(Resource): + type: Literal["resource_link"] = Field(default="resource_link", frozen=True) + + +class EmbeddedResource(BaseModel): + type: Literal["resource"] = Field(default="resource", frozen=True) + resource: TextResourceContents | BlobResourceContents + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +MCPContent = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource + + +class CallToolResult(Result): + """ + A list of content objects that represent the unstructured result of the tool call. + """ + + content: list[MCPContent] + + """ + An optional JSON object that represents the structured result of the tool call. + """ + structuredContent: dict[str, Any] | None = None + + isError: bool | None = None + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- + + +class LoggingLevel(str, Enum): + DEBUG = "debug" + INFO = "info" + NOTICE = "notice" + WARNING = "warning" + ERROR = "error" + CRITICAL = "critical" + ALERT = "alert" + EMERGENCY = "emergency" + + +class SetLevelParams(BaseModel): + level: LoggingLevel + + +class SetLevelRequest(JSONRPCRequest): + method: Literal["logging/setLevel"] = Field(default="logging/setLevel", frozen=True) + params: SetLevelParams + + +class LoggingMessageParams(BaseModel): + level: LoggingLevel + logger: str | None = None + data: Any + + +class LoggingMessageNotification(JSONRPCMessage, Notification): + method: Literal["notifications/message"] = Field(default="notifications/message", frozen=True) + params: LoggingMessageParams + + +# ----------------------------------------------------------------------------- +# Cancellation (notification-only) +# ----------------------------------------------------------------------------- + + +class CancelledParams(BaseModel): + requestId: RequestId + reason: str | None = None + + +class CancelledNotification(JSONRPCMessage, Notification): + method: Literal["notifications/cancelled"] = Field( + default="notifications/cancelled", frozen=True + ) + params: CancelledParams + + +# ----------------------------------------------------------------------------- +# Sampling (server -> client) +# ----------------------------------------------------------------------------- + + +class SamplingMessage(BaseModel): + role: Role + content: TextContent | ImageContent | AudioContent + + +class ModelHint(BaseModel): + name: str | None = None + + +class ModelPreferences(BaseModel): + hints: list[ModelHint] | None = None + costPriority: float | None = None + speedPriority: float | None = None + intelligencePriority: float | None = None + + +class CreateMessageParams(BaseModel): + messages: list[SamplingMessage] + modelPreferences: ModelPreferences | None = None + systemPrompt: str | None = None + includeContext: Literal["none", "thisServer", "allServers"] | None = None + temperature: float | None = None + maxTokens: int + stopSequences: list[str] | None = None + metadata: dict[str, Any] | None = None + + +class CreateMessageRequest(JSONRPCRequest): + method: Literal["sampling/createMessage"] = Field(default="sampling/createMessage", frozen=True) + params: CreateMessageParams + + +class CreateMessageResult(Result, SamplingMessage): + model: str + stopReason: Literal["endTurn", "stopSequence", "maxTokens"] | str | None = None + + +# ----------------------------------------------------------------------------- +# Completion (client -> server) +# ----------------------------------------------------------------------------- + + +class ResourceTemplateReference(BaseModel): + type: Literal["ref/resource"] + uri: str + + +class PromptReference(BaseMetadata): + type: Literal["ref/prompt"] + + +class CompletionArgument(BaseModel): + name: str + value: str + + +class CompletionContext(BaseModel): + arguments: dict[str, str] | None = None + + +class CompleteParams(BaseModel): + ref: ResourceTemplateReference | PromptReference + argument: CompletionArgument + context: CompletionContext | None = None + + +class CompleteRequest(JSONRPCRequest): + method: Literal["completion/complete"] = Field(default="completion/complete", frozen=True) + params: CompleteParams + + +class Completion(BaseModel): + values: list[str] + total: int | None = None + hasMore: bool | None = None + + +class CompleteResult(Result): + completion: Completion + + +# ----------------------------------------------------------------------------- +# Roots +# ----------------------------------------------------------------------------- + + +class ListRootsRequest(JSONRPCRequest): + method: Literal["roots/list"] = Field(default="roots/list", frozen=True) + + +class Root(BaseModel): + uri: str + name: str | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + + +class ListRootsResult(Result): + roots: list[Root] + + +class RootsListChangedNotification(JSONRPCMessage, Notification): + method: Literal["notifications/roots/list_changed"] = Field( + default="notifications/roots/list_changed", frozen=True + ) + + +# ----------------------------------------------------------------------------- +# Elicitation (server -> client) +# ----------------------------------------------------------------------------- + +ElicitRequestedSchema = dict[str, Any] + + +class ElicitParams(BaseModel): + message: str + requestedSchema: ElicitRequestedSchema + + +class ElicitRequest(JSONRPCRequest): + method: Literal["elicitation/create"] = Field(default="elicitation/create", frozen=True) + params: ElicitParams + + +class ElicitResult(Result): + action: Literal["accept", "decline", "cancel"] + content: dict[str, str | int | float | bool | None] | None = None + + +# ----------------------------------------------------------------------------- +# Union for middleware typing and convenience +# ----------------------------------------------------------------------------- + +MCPMessage = ( + JSONRPCRequest + | JSONRPCResponse[Any] + | JSONRPCError + | CancelledNotification + | ProgressNotification + | LoggingMessageNotification +) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/worker.py b/libs/arcade-mcp-server/arcade_mcp_server/worker.py new file mode 100644 index 00000000..1def203d --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/worker.py @@ -0,0 +1,291 @@ +""" +Arcade MCP Server (Integrated Worker + MCP HTTP) + +Creates a FastAPI application that exposes both Arcade Worker endpoints and +MCP Server endpoints over HTTP/SSE. MCP is always enabled in this integrated mode. +""" + +from collections.abc import AsyncGenerator, AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import uvicorn +from arcade_core.catalog import ToolCatalog +from arcade_serve.fastapi.worker import FastAPIWorker +from fastapi import FastAPI +from loguru import logger +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from arcade_mcp_server.server import MCPServer +from arcade_mcp_server.settings import MCPSettings +from arcade_mcp_server.transports.http_session_manager import HTTPSessionManager + + +@asynccontextmanager +async def create_lifespan( + catalog: ToolCatalog, + mcp_settings: MCPSettings | None = None, + **kwargs: Any, +) -> AsyncGenerator[dict[str, Any], None]: + """ + Create lifespan context for the MCP server components. + + Yields a dict with `mcp_server`, and `session_manager`. + """ + if mcp_settings is None: + mcp_settings = MCPSettings.from_env() + + try: + tool_env_keys = sorted(mcp_settings.tool_secrets().keys()) + logger.debug( + f"Arcade settings: \n\ + ARCADE_ENVIRONMENT={mcp_settings.arcade.environment} \n\ + ARCADE_API_URL={mcp_settings.arcade.api_url}, \n\ + ARCADE_USER_ID={mcp_settings.arcade.user_id}, \n\ + api_key_present - {bool(mcp_settings.arcade.api_key)}" + ) + logger.debug(f"Tool environment variable names available to tools: {tool_env_keys}") + except Exception as e: + logger.debug(f"Unable to log settings/tool env keys: {e}") + + mcp_server = MCPServer( + catalog, + settings=mcp_settings, + **kwargs, + ) + + session_manager = HTTPSessionManager( + server=mcp_server, + json_response=True, + ) + + await mcp_server.start() + async with session_manager.run(): + logger.info("MCP server started and ready for connections") + yield { + "mcp_server": mcp_server, + "session_manager": session_manager, + } + await mcp_server.stop() + + +def create_arcade_mcp( + catalog: ToolCatalog, + mcp_settings: MCPSettings | None = None, + debug: bool = False, + **kwargs: Any, +) -> FastAPI: + """ + Create a FastAPI app exposing Arcade Worker and MCP HTTP endpoints. + + MCP is always enabled in this integrated application. + """ + if mcp_settings is None: + mcp_settings = MCPSettings.from_env() + secret = mcp_settings.arcade.server_secret + if secret is None: + secret = "dev" # noqa: S105 + + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncIterator[None]: + async with create_lifespan(catalog, mcp_settings, **kwargs) as components: + app.state.mcp_server = components["mcp_server"] + app.state.session_manager = components["session_manager"] + yield + + app = FastAPI( + title=(mcp_settings.server.title or mcp_settings.server.name), + description=(mcp_settings.server.instructions or ""), + version=mcp_settings.server.version, + docs_url="/docs" if not mcp_settings.arcade.auth_disabled else None, + redoc_url="/redoc" if not mcp_settings.arcade.auth_disabled else None, + lifespan=lifespan, + **kwargs, + ) + + # Worker endpoints + worker = FastAPIWorker( + app=app, + secret=secret, + disable_auth=mcp_settings.arcade.auth_disabled, + ) + worker.catalog = catalog + + class _MCPASGIProxy: + def __init__(self, parent_app: FastAPI): + self._app = parent_app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + session_manager = getattr(self._app.state, "session_manager", None) + if session_manager is None: + resp = Response("MCP server not initialized", status_code=503) + await resp(scope, receive, send) + return + await session_manager.handle_request(scope, receive, send) + + # Mount the actual ASGI proxy to handle all /mcp requests + app.mount("/mcp", _MCPASGIProxy(app), name="mcp-proxy") + + # Customize OpenAPI to include MCP documentation + def custom_openapi() -> dict[str, Any]: + if app.openapi_schema: + return app.openapi_schema + + # Get the default OpenAPI schema + from fastapi.openapi.utils import get_openapi + + openapi_schema = get_openapi( + title=app.title, + version=app.version, + description=app.description, + routes=app.routes, + ) + + # Add MCP routes to the schema + from arcade_mcp_server.fastapi.routes import ( + MCPError, + MCPRequest, + MCPResponse, + get_openapi_routes, + ) + + # Add MCP schemas + if "components" not in openapi_schema: + openapi_schema["components"] = {} + if "schemas" not in openapi_schema["components"]: + openapi_schema["components"]["schemas"] = {} + + # Add schema definitions + openapi_schema["components"]["schemas"]["MCPRequest"] = MCPRequest.model_json_schema() + openapi_schema["components"]["schemas"]["MCPResponse"] = MCPResponse.model_json_schema() + openapi_schema["components"]["schemas"]["MCPError"] = MCPError.model_json_schema() + + # Add MCP paths + if "paths" not in openapi_schema: + openapi_schema["paths"] = {} + + for route_def in get_openapi_routes(): + path = route_def["path"] + openapi_schema["paths"][path] = {k: v for k, v in route_def.items() if k != "path"} + + app.openapi_schema = openapi_schema + return app.openapi_schema + + app.openapi = custom_openapi # type: ignore[method-assign] + + return app + + +def create_arcade_mcp_factory() -> FastAPI: + """ + App factory for uvicorn reload support. + + This function is called by uvicorn when using reload mode with an import string. + It rediscovers the catalog and reads configuration from environment variables. + """ + import os + + from arcade_core.discovery import discover_tools + from arcade_core.toolkit import ToolkitLoadError + + # Read configuration from env vars that were set before running the server + debug = os.environ.get("ARCADE_MCP_DEBUG", "false").lower() == "true" + tool_package = os.environ.get("ARCADE_MCP_TOOL_PACKAGE") + discover_installed = os.environ.get("ARCADE_MCP_DISCOVER_INSTALLED", "false").lower() == "true" + show_packages = os.environ.get("ARCADE_MCP_SHOW_PACKAGES", "false").lower() == "true" + server_name = os.environ.get("ARCADE_MCP_SERVER_NAME") + server_version = os.environ.get("ARCADE_MCP_SERVER_VERSION") + + # Rediscover tools since there have been changes + try: + catalog = discover_tools( + tool_package=tool_package, + show_packages=show_packages, + discover_installed=discover_installed, + server_name=server_name, + server_version=server_version, + ) + except ToolkitLoadError as exc: + logger.error(str(exc)) + raise RuntimeError(f"Failed to discover tools: {exc}") from exc + + total_tools = len(catalog) + if total_tools == 0: + logger.error("No tools found. Create Python files with @tool decorated functions.") + raise RuntimeError("No tools found") + + logger.info(f"Total tools loaded: {total_tools}") + + # Build kwargs for server creation + kwargs = {} + if server_name: + kwargs["name"] = server_name + if server_version: + kwargs["version"] = server_version + + return create_arcade_mcp( + catalog=catalog, + mcp_settings=None, + debug=debug, + **kwargs, + ) + + +def run_arcade_mcp( + catalog: ToolCatalog, + host: str = "127.0.0.1", + port: int = 7777, + reload: bool = False, + debug: bool = False, + tool_package: str | None = None, + discover_installed: bool = False, + show_packages: bool = False, + **kwargs: Any, +) -> None: + """ + Run the integrated Arcade MCP server with uvicorn. + """ + import os + + log_level = "debug" if debug else "info" + + if reload: + # Set env vars for the app factory to read later + os.environ["ARCADE_MCP_DEBUG"] = str(debug) + if tool_package: + os.environ["ARCADE_MCP_TOOL_PACKAGE"] = tool_package + os.environ["ARCADE_MCP_DISCOVER_INSTALLED"] = str(discover_installed) + os.environ["ARCADE_MCP_SHOW_PACKAGES"] = str(show_packages) + if kwargs.get("name"): + os.environ["ARCADE_MCP_SERVER_NAME"] = kwargs["name"] + if kwargs.get("version"): + os.environ["ARCADE_MCP_SERVER_VERSION"] = kwargs["version"] + + # import string is required for reload mode + app_import_string = "arcade_mcp_server.worker:create_arcade_mcp_factory" + + uvicorn.run( + app_import_string, + factory=True, + host=host, + port=port, + log_level=log_level, + reload=reload, + lifespan="on", + ) + else: + app = create_arcade_mcp( + catalog=catalog, + debug=debug, + **kwargs, + ) + + uvicorn.run( + app, + host=host, + port=port, + log_level=log_level, + reload=reload, + lifespan="on", + ) diff --git a/libs/arcade-mcp-server/docs/advanced/transports.md b/libs/arcade-mcp-server/docs/advanced/transports.md new file mode 100644 index 00000000..99f2cfc0 --- /dev/null +++ b/libs/arcade-mcp-server/docs/advanced/transports.md @@ -0,0 +1,189 @@ +# Transport Modes + +MCP servers can communicate with clients through different transport mechanisms. Each transport is optimized for specific use cases and client types. + +## stdio Transport + +The stdio (standard input/output) transport is used for direct client connections. + +### Characteristics +- Communicates via standard input/output streams +- Logs go to stderr to avoid interfering with protocol messages +- Ideal for desktop applications and command-line tools +- Used by Claude Desktop and similar clients + +### Usage + +```bash +# Run with stdio transport +python -m arcade_mcp_server stdio + +# Or with MCPApp +app.run(transport="stdio") +``` + +### Client Configuration + +For Claude Desktop, configure in `~/Library/Application Support/Claude/claude_desktop_config.json`: + +```json +{ + "mcpServers": { + "my-tools": { + "command": "python", + "args": ["-m", "arcade_mcp_server", "stdio"], + "cwd": "/path/to/your/tools" + } + } +} +``` + +## HTTP Transport + +The HTTP transport provides REST/SSE endpoints for web-based clients. + +### Characteristics +- RESTful API with Server-Sent Events (SSE) for streaming +- Supports hot reload for development +- Includes health checks and API documentation +- Can be deployed behind reverse proxies +- Suitable for web applications and services + +### Usage + +```bash +# Run with HTTP transport (default) +python -m arcade_mcp_server + +# With specific host and port +python -m arcade_mcp_server --host 0.0.0.0 --port 8080 + +# Or with MCPApp +app.run(transport="http", host="0.0.0.0", port=8080) +``` + +### Endpoints + +When running in HTTP mode, the server provides: + +- `GET /health` - Health check endpoint +- `GET /mcp` - SSE endpoint for MCP protocol +- `GET /docs` - Swagger UI documentation (debug mode) +- `GET /redoc` - ReDoc documentation (debug mode) + +### Development Features + +```bash +# Enable hot reload and debug mode +python -m arcade_mcp_server --reload --debug + +# This enables: +# - Automatic restart on code changes +# - Detailed error messages +# - API documentation endpoints +# - Verbose logging +``` + +## Choosing a Transport + +### Use stdio when: +- Integrating with desktop applications (Claude Desktop, VS Code) +- Building command-line tools +- You need simple, direct communication +- Running in environments without network access + +### Use HTTP when: +- Building web applications +- Deploying to cloud environments +- You need to support multiple concurrent clients +- Integrating with existing web services +- You want API documentation and testing tools + +## Transport Configuration + +### Environment Variables + +Both transports respect common environment variables: + +```bash +# Server identification +MCP_SERVER_NAME="My MCP Server" +MCP_SERVER_VERSION="1.0.0" + +# Logging +MCP_DEBUG=true +MCP_LOG_LEVEL=DEBUG + +# HTTP-specific +MCP_HTTP_HOST=0.0.0.0 +MCP_HTTP_PORT=8080 +``` + +### Programmatic Configuration + +When using MCPApp: + +```python +from arcade_mcp_server import MCPApp + +app = MCPApp( + name="my-server", + version="1.0.0", + log_level="DEBUG" +) + +# Run with specific transport +if __name__ == "__main__": + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "stdio": + app.run(transport="stdio") + else: + app.run(transport="http", host="0.0.0.0", port=8080) +``` + +## Security Considerations + +### stdio Transport +- Inherits security context of the parent process +- No network exposure +- Suitable for trusted environments + +### HTTP Transport +- Exposes network endpoints +- Should use authentication in production +- Consider using HTTPS with reverse proxy +- Implement rate limiting for public deployments + +## Advanced Transport Features + +### Custom Middleware (HTTP) + +Add custom middleware to HTTP transports: + +```python +from arcade_mcp_server import MCPApp + +app = MCPApp(name="my-server") + +# Add custom middleware +@app.middleware("http") +async def add_custom_headers(request, call_next): + response = await call_next(request) + response.headers["X-Custom-Header"] = "value" + return response +``` + +### Transport Events + +Listen to transport lifecycle events: + +```python +@app.on_event("startup") +async def startup_handler(): + print("Server starting up...") + +@app.on_event("shutdown") +async def shutdown_handler(): + print("Server shutting down...") +``` diff --git a/libs/arcade-mcp-server/docs/api/cli.md b/libs/arcade-mcp-server/docs/api/cli.md new file mode 100644 index 00000000..21a3ac46 --- /dev/null +++ b/libs/arcade-mcp-server/docs/api/cli.md @@ -0,0 +1,105 @@ +# CLI + +The `arcade_mcp_server` CLI is a simple tool for running MCP servers. + +It is used to discover tools and run the server. + + + +## Command Line Options + +``` +usage: python -m arcade_mcp_server [-h] [--host HOST] [--port PORT] + [--tool-package PACKAGE] [--discover-installed] + [--show-packages] [--reload] [--debug] + [--env-file ENV_FILE] [--name NAME] [--version VERSION] + [transport] + +Run Arcade MCP Server + +positional arguments: + transport Transport type: stdio, http, streamable-http (default: http) + +optional arguments: + -h, --help show this help message and exit + --host HOST Host to bind to (HTTP mode only, default: 127.0.0.1) + --port PORT Port to bind to (HTTP mode only, default: 8000) + --tool-package PACKAGE, --package PACKAGE, -p PACKAGE + Specific tool package to load (e.g., 'github' for arcade-github) + --discover-installed, --all + Discover all installed arcade tool packages + --show-packages Show loaded packages during discovery + --reload Enable auto-reload on code changes (HTTP mode only) + --debug Enable debug mode with verbose logging + --env-file ENV_FILE Path to environment file + --name NAME Server name + --version VERSION Server version +``` + +## Tool Discovery + +The CLI discovers tools in three ways: + +### 1. Auto-Discovery (Default) + +Automatically finds Python files with `@tool` decorated functions in: +- Current directory (`*.py`) +- `tools/` subdirectory +- `arcade_tools/` subdirectory + +Example file structure: +``` +my_project/ +โ”œโ”€โ”€ hello.py # Contains @tool functions +โ”œโ”€โ”€ tools/ +โ”‚ โ””โ”€โ”€ math.py # More @tool functions +โ””โ”€โ”€ arcade_tools/ + โ””โ”€โ”€ utils.py # Even more @tool functions +``` + +### 2. Package Loading + +Load specific arcade packages installed in your environment: + +```bash +# Load arcade-github package +python -m arcade_mcp_server --tool-package github + +# Load custom package (tries arcade_ prefix first) +python -m arcade_mcp_server -p mycompany_tools +``` + +### 3. Discover All Installed + +Find and load all arcade packages in your Python environment: + +```bash +# Load all arcade packages +python -m arcade_mcp_server --discover-installed + +# Show what's being loaded +python -m arcade_mcp_server --discover-installed --show-packages +``` + +### Example Tool File + +Create any Python file with `@tool` decorated functions: + +```python +from arcade_mcp_server import tool + +@tool +def hello(name: str) -> str: + """Say hello to someone.""" + return f"Hello, {name}!" + +@tool +def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b +``` + +Then run: +```bash +python -m arcade_mcp_server # Auto-discovers and loads these tools +``` diff --git a/libs/arcade-mcp-server/docs/api/mcp_app.md b/libs/arcade-mcp-server/docs/api/mcp_app.md new file mode 100644 index 00000000..e06e04f1 --- /dev/null +++ b/libs/arcade-mcp-server/docs/api/mcp_app.md @@ -0,0 +1,47 @@ +### MCPApp + +A FastAPI-like interface for building MCP servers with lazy initialization. + +MCPApp provides a clean, minimal API for building MCP servers programmatically. It handles tool collection, server configuration, and transport setup with a developer-friendly interface. + +#### Basic Usage + +```python +from arcade_mcp_server import MCPApp + +app = MCPApp(name="my_server", version="1.0.0") + +@app.tool +def greet(name: str) -> str: + return f"Hello, {name}!" + +app.run(host="127.0.0.1", port=7777) +``` + +#### Class Reference + +::: arcade_mcp_server.mcp_app.MCPApp + +#### Examples + +```python +# --- server.py --- +# Programmatic server creation with a simple tool and HTTP transport + +from arcade_mcp_server import MCPApp + +app = MCPApp(name="example_server", version="1.0.0") + +@app.tool +def echo(text: str) -> str: + return f"Echo: {text}" + +if __name__ == "__main__": + # Start an HTTP server (good for local development/testing) + app.run(host="0.0.0.0", port=7777, reload=False, debug=True) +``` + +```bash +# then run the server +python server.py +``` diff --git a/libs/arcade-mcp-server/docs/api/server/errors.md b/libs/arcade-mcp-server/docs/api/server/errors.md new file mode 100644 index 00000000..b6c78666 --- /dev/null +++ b/libs/arcade-mcp-server/docs/api/server/errors.md @@ -0,0 +1,36 @@ +### Exceptions + +Domain-specific error types raised by the MCP server and components. + +::: arcade_mcp_server.exceptions + +#### Examples + +```python +from arcade_mcp_server.exceptions import ( + MCPError, + NotFoundError, + DuplicateError, + ValidationError, + ToolError, +) + +# Raising a not-found when a resource is missing +async def read_resource_or_fail(uri: str) -> str: + if not await exists(uri): + raise NotFoundError(f"Resource not found: {uri}") + return await read(uri) + +# Validating input +def validate_age(age: int) -> None: + if age < 0: + raise ValidationError("age must be non-negative") + +# Handling tool execution errors in middleware or handlers +async def call_tool_safely(call): + try: + return await call() + except ToolError as e: + # Convert to an error result or re-raise + raise MCPError(f"Tool failed: {e}") +``` diff --git a/libs/arcade-mcp-server/docs/api/server/middleware.md b/libs/arcade-mcp-server/docs/api/server/middleware.md new file mode 100644 index 00000000..aa7ca327 --- /dev/null +++ b/libs/arcade-mcp-server/docs/api/server/middleware.md @@ -0,0 +1,50 @@ +### Middleware + +Base interfaces and built-in middleware. + +::: arcade_mcp_server.middleware.base.Middleware + +::: arcade_mcp_server.middleware.base.MiddlewareContext + +::: arcade_mcp_server.middleware.base.compose_middleware + +#### Built-ins + +::: arcade_mcp_server.middleware.logging.LoggingMiddleware + +::: arcade_mcp_server.middleware.error_handling.ErrorHandlingMiddleware + +#### Examples + +```python +# Implement a custom middleware +from arcade_mcp_server.middleware.base import Middleware, MiddlewareContext + +class TimingMiddleware(Middleware): + async def __call__(self, context: MiddlewareContext, call_next): + import time + start = time.perf_counter() + try: + return await call_next(context) + finally: + elapsed_ms = (time.perf_counter() - start) * 1000 + # Attach timing info to context metadata + context.metadata["elapsed_ms"] = round(elapsed_ms, 2) +``` + +```python +# Compose middleware and create a server +from arcade_mcp_server.middleware.base import compose_middleware +from arcade_mcp_server.middleware.logging import LoggingMiddleware +from arcade_mcp_server.middleware.error_handling import ErrorHandlingMiddleware +from arcade_mcp_server.server import MCPServer +from arcade_core.catalog import ToolCatalog + +middleware = compose_middleware([ + ErrorHandlingMiddleware(mask_error_details=False), + LoggingMiddleware(log_level="INFO"), + TimingMiddleware(), +]) + +server = MCPServer(catalog=ToolCatalog(), middleware=[middleware]) +``` diff --git a/libs/arcade-mcp-server/docs/api/server/server.md b/libs/arcade-mcp-server/docs/api/server/server.md new file mode 100644 index 00000000..8baabf20 --- /dev/null +++ b/libs/arcade-mcp-server/docs/api/server/server.md @@ -0,0 +1,53 @@ + + +# Server + +### Low-level Server + +Low-level server for hosting Arcade tools over MCP. + +::: arcade_mcp_server.server.MCPServer + +#### Examples + +```python +# Basic server with tool catalog and stdio transport +import asyncio +from arcade_mcp_server.server import MCPServer +from arcade_core.catalog import ToolCatalog +from arcade_mcp_server.transports.stdio import StdioTransport + +async def main(): + catalog = ToolCatalog() + server = MCPServer(catalog=catalog, name="example", version="1.0.0") + await server._start() + try: + # Run stdio transport loop + transport = StdioTransport() + await transport.run(server) + finally: + await server._stop() + +if __name__ == "__main__": + asyncio.run(main()) +``` + +```python +# Handling a single HTTP streamable connection +import asyncio +from arcade_mcp_server.server import MCPServer +from arcade_core.catalog import ToolCatalog +from arcade_mcp_server.transports.http_streamable import HTTPStreamableTransport + +async def run_http(): + catalog = ToolCatalog() + server = MCPServer(catalog=catalog) + await server._start() + try: + transport = HTTPStreamableTransport(host="0.0.0.0", port=7777) + await transport.run(server) + finally: + await server._stop() + +asyncio.run(run_http()) +``` diff --git a/libs/arcade-mcp-server/docs/api/server/settings.md b/libs/arcade-mcp-server/docs/api/server/settings.md new file mode 100644 index 00000000..5d7b828c --- /dev/null +++ b/libs/arcade-mcp-server/docs/api/server/settings.md @@ -0,0 +1,49 @@ +### Settings + +Global configuration and environment-driven settings. + +::: arcade_mcp_server.settings.MCPSettings + +#### Sub-settings + +::: arcade_mcp_server.settings.ServerSettings + +::: arcade_mcp_server.settings.MiddlewareSettings + +::: arcade_mcp_server.settings.NotificationSettings + +::: arcade_mcp_server.settings.TransportSettings + +::: arcade_mcp_server.settings.ArcadeSettings + +::: arcade_mcp_server.settings.ToolEnvironmentSettings + +#### Examples + +```python +from arcade_mcp_server.settings import MCPSettings + +settings = MCPSettings( + debug=True, + middleware=MCPSettings.middleware.__class__( + enable_logging=True, + mask_error_details=False, + ), + server=MCPSettings.server.__class__( + title="My MCP Server", + instructions="Use responsibly", + ), + transport=MCPSettings.transport.__class__( + http_host="0.0.0.0", + http_port=8000, + ), +) +``` + +```python +# Loading from environment +from arcade_mcp_server.settings import MCPSettings + +# Values like ARCADE_MCP_DEBUG, ARCADE_MCP_HTTP_PORT, etc. are parsed +settings = MCPSettings() +``` diff --git a/libs/arcade-mcp-server/docs/api/server/types.md b/libs/arcade-mcp-server/docs/api/server/types.md new file mode 100644 index 00000000..062a02ed --- /dev/null +++ b/libs/arcade-mcp-server/docs/api/server/types.md @@ -0,0 +1,37 @@ +### Types + +Core Pydantic models and enums for the MCP protocol shapes. + +::: _server.types + +#### Examples + +```python +# Constructing a JSON-RPC request and response model +from arcade_mcp_server.types import JSONRPCRequest, JSONRPCResponse + +req = JSONRPCRequest(id=1, method="ping", params={}) +res = JSONRPCResponse(id=req.id, result={}) +print(req.model_dump_json()) +print(res.model_dump_json()) +``` + +```python +# Building a tools/call request and examining result shape +from arcade_mcp_server.types import CallToolRequest, CallToolResult, TextContent + +call = CallToolRequest( + id=2, + method="tools/call", + params={ + "name": "Toolkit.tool", + "arguments": {"text": "hello"}, + }, +) +# Result would typically be produced by the server: +result = CallToolResult( + content=[TextContent(type="text", text="Echo: hello")], + structuredContent={"result": "Echo: hello"}, + isError=False +) +``` diff --git a/libs/arcade-mcp-server/docs/clients/claude.md b/libs/arcade-mcp-server/docs/clients/claude.md new file mode 100644 index 00000000..07f22748 Binary files /dev/null and b/libs/arcade-mcp-server/docs/clients/claude.md differ diff --git a/libs/arcade-mcp-server/docs/clients/cursor.md b/libs/arcade-mcp-server/docs/clients/cursor.md new file mode 100644 index 00000000..a89058ec Binary files /dev/null and b/libs/arcade-mcp-server/docs/clients/cursor.md differ diff --git a/libs/arcade-mcp-server/docs/clients/inspector.md b/libs/arcade-mcp-server/docs/clients/inspector.md new file mode 100644 index 00000000..d6bbf61b --- /dev/null +++ b/libs/arcade-mcp-server/docs/clients/inspector.md @@ -0,0 +1,354 @@ +# MCP Inspector + +The MCP Inspector is a powerful debugging and testing tool for MCP servers. It provides a web-based interface to interact with your Arcade MCP server, test tools, and monitor protocol messages. + +## Installation + +Install the MCP Inspector globally: + +```bash +npm install -g @modelcontextprotocol/inspector +``` + +Or use npx to run without installing: + +```bash +npx @modelcontextprotocol/inspector +``` + +## Basic Usage + +### Connecting to HTTP Servers + +For MCP servers running over HTTP: + +```bash +# Start your MCP server +python -m arcade_mcp_server --host 0.0.0.0 --port 8000 + +# In another terminal, start the inspector +mcp-inspector http://localhost:8000/mcp +``` + +### Connecting to stdio Servers + +For stdio-based servers: + +```bash +# Start the inspector with your server command +mcp-inspector "python -m arcade_mcp_server stdio" + +# With additional arguments +mcp-inspector "python -m arcade_mcp_server stdio --tool-package github" +``` + +## Inspector Features + +### Tool Explorer + +The Tool Explorer shows all available tools with: + +- Tool names and descriptions +- Parameter schemas +- Return type information +- Example invocations + +### Interactive Testing + +Test tools directly from the interface: + +1. Select a tool from the explorer +2. Fill in parameter values +3. Click "Execute" to run the tool +4. View results and execution time + +### Protocol Monitor + +Monitor all MCP protocol messages: + +- Request/response pairs +- Message timing +- Protocol errors +- Raw JSON data + +### Resource Browser + +If your server provides resources: + +- Browse available resources +- View resource contents +- Test resource operations + +### Prompt Templates + +Test prompt templates if supported: + +- View available prompts +- Fill template parameters +- Preview rendered prompts + +## Advanced Usage + +### Custom Environment + +Pass environment variables to your server: + +```bash +# Using env command +env ARCADE_API_KEY=your-key mcp-inspector "python -m arcade_mcp_server stdio" + +# Using inspector's env option +mcp-inspector --env ARCADE_API_KEY=your-key "python -m arcade_mcp_server stdio" +``` + +### Working Directory + +Set the working directory for your server: + +```bash +mcp-inspector --cwd /path/to/project "python -m arcade_mcp_server stdio" +``` + +### Debug Mode + +Enable verbose logging: + +```bash +# Debug the MCP server +mcp-inspector "python -m arcade_mcp_server stdio --debug" + +# Debug the inspector itself +mcp-inspector --debug "python -m arcade_mcp_server stdio" +``` + +## Testing Workflows + +### Tool Development + +1. **Start your server with hot reload**: + ```bash + python -m arcade_mcp_server --reload --debug + ``` + +2. **Connect the inspector**: + ```bash + mcp-inspector http://localhost:8000/mcp + ``` + +3. **Develop and test**: + - Modify your tool code + - Server auto-reloads + - Test immediately in inspector + +### Performance Testing + +Use the inspector to measure tool performance: + +1. Enable timing in the Protocol Monitor +2. Execute tools multiple times +3. Analyze response times +4. Identify bottlenecks + +### Error Debugging + +Debug tool errors effectively: + +1. Enable debug mode on your server +2. Execute the failing tool +3. Check Protocol Monitor for error details +4. View server logs in terminal + +## Integration Testing + +### Test Suites + +Create test suites using the inspector: + +```javascript +// test-tools.js +const tests = [ + { + tool: "greet", + params: { name: "World" }, + expected: "Hello, World!" + }, + { + tool: "calculate", + params: { expression: "2 + 2" }, + expected: 4 + } +]; + +// Run tests via inspector API +``` + +### Automated Testing + +Combine with testing frameworks: + +```python +# test_mcp_tools.py +import subprocess +import json +import pytest + +def test_tool_via_inspector(): + # Start server + server = subprocess.Popen( + ["python", "-m", "arcade_mcp_server"], + stdout=subprocess.PIPE + ) + + # Use inspector's API to test tools + # ... +``` + +## Best Practices + +### Development Setup + +1. **Use Split Terminal**: + - Terminal 1: MCP server with reload + - Terminal 2: Inspector + - Terminal 3: Code editor + +2. **Enable All Debugging**: + ```bash + python -m arcade_mcp_server --reload --debug --env-file .env.dev + ``` + +3. **Save Test Cases**: + - Export successful tool calls + - Build regression test suite + - Document edge cases + +### Production Testing + +1. **Test Against Production Config**: + ```bash + mcp-inspector "python -m arcade_mcp_server stdio --env-file .env.prod" + ``` + +2. **Verify Security**: + - Test with limited permissions + - Verify API key handling + - Check error messages don't leak secrets + +3. **Load Testing**: + - Execute tools rapidly + - Monitor memory usage + - Check for resource leaks + +## Troubleshooting + +### Connection Issues + +#### "Failed to connect" + +1. Verify server is running +2. Check correct URL/command +3. Ensure ports aren't blocked +4. Try with `--debug` flag + +#### "Protocol error" + +1. Ensure server implements MCP correctly +2. Check for version compatibility +3. Review server logs +4. Verify transport type + +### Tool Issues + +#### "Tool not found" + +1. Verify tool is decorated with `@tool` +2. Check tool discovery in server +3. Ensure no import errors +4. Restart server and inspector + +#### "Parameter validation failed" + +1. Check parameter types match schema +2. Verify required parameters +3. Test with simpler values +4. Review tool documentation + +## Examples + +### Quick Test Session + +```bash +# 1. Start a simple MCP server +cat > test_tools.py << 'EOF' +from arcade_mcp_server import tool +from typing import Annotated + +@tool +def echo(message: Annotated[str, "Message to echo"]) -> str: + """Echo the message back.""" + return message + +@tool +def add( + a: Annotated[int, "First number"], + b: Annotated[int, "Second number"] +) -> Annotated[int, "Sum"]: + """Add two numbers.""" + return a + b +EOF + +# 2. Start inspector +mcp-inspector "python -m arcade_mcp_server stdio" + +# 3. Test tools in the web interface +``` + +### HTTP Server Testing + +```bash +# 1. Create an MCPApp server +cat > app.py << 'EOF' +from arcade_mcp_server import MCPApp +from typing import Annotated + +app = MCPApp(name="test-server", version="1.0.0") + +@app.tool +def get_time() -> Annotated[str, "Current time"]: + """Get the current time.""" + from datetime import datetime + return datetime.now().isoformat() + +if __name__ == "__main__": + app.run(port=9000, reload=True) +EOF + +# 2. Run the server +python app.py + +# 3. Connect inspector +mcp-inspector http://localhost:9000/mcp +``` + +### Debugging Session + +```bash +# 1. Enable all debugging +export DEBUG=* +export MCP_DEBUG=true + +# 2. Start server with verbose logging +python -m arcade_mcp_server stdio --debug 2>server.log + +# 3. Start inspector with debugging +mcp-inspector --debug "tail -f server.log" & +mcp-inspector --debug "python -m arcade_mcp_server stdio --debug" +``` + +## Tips and Tricks + +1. **Bookmark Tool URLs**: Save frequently tested tools +2. **Export Test Data**: Save successful requests for documentation +3. **Use Browser DevTools**: Inspect network requests +4. **Create Tool Shortcuts**: Bookmark specific tool tests +5. **Monitor Resources**: Keep an eye on server resources during testing diff --git a/libs/arcade-mcp-server/docs/clients/vscode.md b/libs/arcade-mcp-server/docs/clients/vscode.md new file mode 100644 index 00000000..ecc814ab --- /dev/null +++ b/libs/arcade-mcp-server/docs/clients/vscode.md @@ -0,0 +1,485 @@ +# Visual Studio Code + +While VSCode doesn't have native MCP support yet, you can integrate Arcade MCP servers with VSCode through extensions and custom configurations. This guide shows various integration approaches. + +## Prerequisites + +- Visual Studio Code installed +- Python 3.10+ installed +- `arcade-mcp` package installed (`pip install arcade-mcp`) +- Python extension for VSCode + +## Integration Methods + +### Method 1: Terminal Integration + +Use VSCode's integrated terminal to run MCP servers: + +1. Open integrated terminal (`Ctrl/Cmd + ` `) +2. Start your MCP server: + ```bash + python -m arcade_mcp_server --reload --debug + ``` +3. Use split terminals for multiple servers + +### Method 2: Task Runner + +Create tasks to manage MCP servers: + +#### Create `.vscode/tasks.json`: + +```json +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Start MCP Server", + "type": "shell", + "command": "python", + "args": ["-m", "arcade_mcp_server", "--reload", "--debug"], + "isBackground": true, + "problemMatcher": { + "pattern": { + "regexp": "^(ERROR|WARNING):\\s+(.+)$", + "severity": 1, + "message": 2 + }, + "background": { + "activeOnStart": true, + "beginsPattern": "^Starting.*", + "endsPattern": "^.*Server ready.*" + } + }, + "presentation": { + "reveal": "always", + "panel": "dedicated" + } + }, + { + "label": "Start MCP (HTTP)", + "type": "shell", + "command": "python", + "args": [ + "-m", "arcade_mcp_server", + "--host", "0.0.0.0", + "--port", "8000", + "--reload" + ], + "isBackground": true, + "problemMatcher": [] + }, + { + "label": "Test Tools", + "type": "shell", + "command": "python", + "args": ["${workspaceFolder}/test_tools.py"], + "problemMatcher": "$python" + } + ] +} +``` + +Run tasks via: +- Command Palette: `Tasks: Run Task` +- Terminal menu: `Terminal > Run Task` + +### Method 3: Launch Configurations + +Debug your MCP tools with VSCode's debugger: + +#### Create `.vscode/launch.json`: + +```json +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Debug MCP Server", + "type": "python", + "request": "launch", + "module": "arcade_mcp_server", + "args": ["--debug", "--reload"], + "cwd": "${workspaceFolder}", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "ARCADE_API_KEY": "${env:ARCADE_API_KEY}" + }, + "console": "integratedTerminal" + }, + { + "name": "Debug Specific Tool", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tools/my_tool.py", + "args": ["--test"], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal" + }, + { + "name": "Debug with Package", + "type": "python", + "request": "launch", + "module": "arcade_mcp_server", + "args": [ + "--tool-package", "github", + "--debug" + ], + "env": { + "GITHUB_TOKEN": "${input:githubToken}" + } + } + ], + "inputs": [ + { + "id": "githubToken", + "type": "promptString", + "description": "Enter your GitHub token", + "password": true + } + ] +} +``` + +## Development Workflow + +### Project Setup + +Recommended project structure: + +``` +my-mcp-project/ +โ”œโ”€โ”€ .vscode/ +โ”‚ โ”œโ”€โ”€ launch.json # Debug configurations +โ”‚ โ”œโ”€โ”€ tasks.json # Task definitions +โ”‚ โ”œโ”€โ”€ settings.json # Workspace settings +โ”‚ โ””โ”€โ”€ extensions.json # Recommended extensions +โ”œโ”€โ”€ .env # Environment variables +โ”œโ”€โ”€ .env.example +โ”œโ”€โ”€ tools/ +โ”‚ โ”œโ”€โ”€ __init__.py +โ”‚ โ””โ”€โ”€ my_tools.py +โ”œโ”€โ”€ tests/ +โ”‚ โ””โ”€โ”€ test_tools.py +โ”œโ”€โ”€ requirements.txt +โ””โ”€โ”€ pyproject.toml +``` + +### Workspace Settings + +Configure `.vscode/settings.json`: + +```json +{ + "python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python", + "python.terminal.activateEnvironment": true, + "python.linting.enabled": true, + "python.linting.pylintEnabled": true, + "python.formatting.provider": "black", + "python.testing.pytestEnabled": true, + "python.testing.pytestArgs": ["tests"], + "files.exclude": { + "**/__pycache__": true, + "**/*.pyc": true + }, + "terminal.integrated.env.linux": { + "PYTHONPATH": "${workspaceFolder}" + }, + "terminal.integrated.env.osx": { + "PYTHONPATH": "${workspaceFolder}" + }, + "terminal.integrated.env.windows": { + "PYTHONPATH": "${workspaceFolder}" + } +} +``` + +### Recommended Extensions + +Create `.vscode/extensions.json`: + +```json +{ + "recommendations": [ + "ms-python.python", + "ms-python.vscode-pylance", + "ms-vscode.live-server", + "humao.rest-client", + "redhat.vscode-yaml", + "ms-azuretools.vscode-docker" + ] +} +``` + +## Testing Tools + +### REST Client Extension + +Test HTTP MCP servers using REST Client: + +Create `test-mcp.http`: + +```http +### Get Server Info +GET http://localhost:8000/health + +### List Tools +POST http://localhost:8000/catalog +Content-Type: application/json +Authorization: Bearer {{$env ARCADE_API_KEY}} + +{} + +### Call Tool +POST http://localhost:8000/call_tool +Content-Type: application/json +Authorization: Bearer {{$env ARCADE_API_KEY}} + +{ + "tool_name": "greet", + "tool_arguments": { + "name": "World" + } +} +``` + +### Python Test Scripts + +Create test scripts for your tools: + +```python +# test_tools.py +import asyncio +from arcade_core.catalog import ToolCatalog + +async def test_tools(): + # Import your tools + from tools import my_tools + + # Create catalog + catalog = ToolCatalog() + catalog.add_tool(my_tools.greet, "test") + + # Test tool + result = await catalog.call_tool( + "test.greet", + {"name": "Test"} + ) + print(f"Result: {result}") + +if __name__ == "__main__": + asyncio.run(test_tools()) +``` + +## Debugging Tips + +### Breakpoint Debugging + +1. Set breakpoints in your tool code +2. Launch debugger with "Debug MCP Server" +3. Trigger tool execution +4. Step through code execution + +### Logging Configuration + +Enhanced logging for debugging: + +```python +# tools/__init__.py +import logging +from loguru import logger + +# Configure loguru +logger.add( + "debug.log", + rotation="10 MB", + level="DEBUG", + format="{time} {level} {message}" +) + +# Intercept standard logging +class InterceptHandler(logging.Handler): + def emit(self, record): + logger_opt = logger.opt(depth=6, exception=record.exc_info) + logger_opt.log(record.levelname, record.getMessage()) + +logging.basicConfig(handlers=[InterceptHandler()], level=0) +``` + +### Performance Profiling + +Profile your tools: + +```json +{ + "name": "Profile MCP Server", + "type": "python", + "request": "launch", + "module": "cProfile", + "args": [ + "-o", "profile.stats", + "-m", "arcade_mcp_server", + "--debug" + ], + "cwd": "${workspaceFolder}" +} +``` + +## Snippets + +Create useful code snippets in `.vscode/python.code-snippets`: + +```json +{ + "Arcade Tool": { + "prefix": "atool", + "body": [ + "from arcade_tdk import tool", + "from typing import Annotated", + "", + "@tool", + "def ${1:tool_name}(", + " ${2:param}: Annotated[${3:str}, \"${4:Parameter description}\"]", + ") -> Annotated[${5:str}, \"${6:Return description}\"]:", + " \"\"\"${7:Tool description}.\"\"\"", + " ${8:# Implementation}", + " return ${9:result}" + ], + "description": "Create an Arcade tool" + }, + "Async Tool": { + "prefix": "atoolasync", + "body": [ + "from arcade_tdk import tool", + "from typing import Annotated", + "", + "@tool", + "async def ${1:tool_name}(", + " ${2:param}: Annotated[${3:str}, \"${4:Parameter description}\"]", + ") -> Annotated[${5:str}, \"${6:Return description}\"]:", + " \"\"\"${7:Tool description}.\"\"\"", + " ${8:# Async implementation}", + " return ${9:result}" + ], + "description": "Create an async Arcade tool" + } +} +``` + +## Integration Examples + +### Multi-Server Setup + +Run multiple MCP servers for different purposes: + +```json +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Start All Servers", + "dependsOn": [ + "Start API Tools", + "Start Data Tools", + "Start Utility Tools" + ], + "problemMatcher": [] + }, + { + "label": "Start API Tools", + "type": "shell", + "command": "python -m arcade_mcp_server --port 8001", + "options": { + "cwd": "${workspaceFolder}/api_tools" + }, + "isBackground": true + }, + { + "label": "Start Data Tools", + "type": "shell", + "command": "python -m arcade_mcp_server --port 8002", + "options": { + "cwd": "${workspaceFolder}/data_tools" + }, + "isBackground": true + }, + { + "label": "Start Utility Tools", + "type": "shell", + "command": "python -m arcade_mcp_server --port 8003", + "options": { + "cwd": "${workspaceFolder}/util_tools" + }, + "isBackground": true + } + ] +} +``` + +### Environment Management + +Handle multiple environments: + +```json +{ + "version": "2.0.0", + "tasks": [ + { + "label": "MCP Server (Dev)", + "type": "shell", + "command": "python -m arcade_mcp_server --env-file .env.dev", + "problemMatcher": [] + }, + { + "label": "MCP Server (Staging)", + "type": "shell", + "command": "python -m arcade_mcp_server --env-file .env.staging", + "problemMatcher": [] + }, + { + "label": "MCP Server (Prod)", + "type": "shell", + "command": "python -m arcade_mcp_server --env-file .env.prod", + "problemMatcher": [], + "presentation": { + "reveal": "always", + "panel": "dedicated", + "showReuseMessage": true, + "clear": true + } + } + ] +} +``` + +## Best Practices + +1. **Use Virtual Environments**: Always work in isolated environments +2. **Version Control Settings**: Include `.vscode` in your repository +3. **Environment Files**: Use `.env` files for secrets +4. **Consistent Formatting**: Configure formatters and linters +5. **Test Automation**: Set up test tasks and debug configs +6. **Documentation**: Keep README and docstrings updated +7. **Git Hooks**: Use pre-commit for code quality + +## Troubleshooting + +### Common Issues + +1. **Python interpreter not found**: + - Select interpreter: `Cmd/Ctrl + Shift + P` > "Python: Select Interpreter" + - Ensure virtual environment is activated + +2. **Module import errors**: + - Check PYTHONPATH in settings + - Verify package installation + - Restart VSCode + +3. **Debug breakpoints not working**: + - Ensure you're using the debug configuration + - Check that debugpy is installed + - Verify source maps are correct + +4. **Task execution fails**: + - Check task definition syntax + - Verify working directory + - Review terminal output for errors diff --git a/libs/arcade-mcp-server/docs/examples/00_hello_world.md b/libs/arcade-mcp-server/docs/examples/00_hello_world.md new file mode 100644 index 00000000..931cf91d --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/00_hello_world.md @@ -0,0 +1,21 @@ +# 00 - Hello World + +The simplest possible MCP server with a single tool using arcade-mcp-server. + +## Running the Example + +- **Run (HTTP default)**: `python -m arcade_mcp_server` +- **Run (stdio for Claude Desktop)**: `python -m arcade_mcp_server stdio` + +## Source Code + +```python +--8<-- "docs/examples/00_hello_world.py" +``` + +## Key Concepts + +- **Minimal Setup**: Just import `@tool` decorator and annotate your function +- **Auto-Discovery**: The CLI automatically finds tools in your current directory +- **Transport Flexibility**: Works with both stdio (for Claude Desktop) and HTTP +- **Type Annotations**: Use `Annotated` to provide descriptions for parameters and return values diff --git a/libs/arcade-mcp-server/docs/examples/00_hello_world.py b/libs/arcade-mcp-server/docs/examples/00_hello_world.py new file mode 100644 index 00000000..060f6e79 --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/00_hello_world.py @@ -0,0 +1,37 @@ +""" +00_hello_world.py - The simplest possible MCP server + +This example shows the absolute minimum code needed to create an MCP server +with a single tool using arcade-mcp-server. + +To run (auto-discovery): +1. Keep this file in the current directory +2. Run: python -m arcade_mcp_server + +For Claude Desktop (stdio transport): + python -m arcade_mcp_server stdio +""" + +from typing import Annotated + +from arcade_mcp_server import tool + + +@tool +def greet(name: Annotated[str, "Name of the person to greet"]) -> Annotated[str, "Welcome message"]: + """Greet a person by name with a welcome message.""" + + return f"Hello, {name}! Welcome to Arcade MCP." + + +# That's it! The arcade_mcp_server CLI will handle everything else: +# - Creating the MCP server +# - Setting up the transport (stdio or HTTP) +# - Registering your tool +# - Handling all the protocol communication + +# When you run `python -m arcade_mcp_server`, it will: +# 1. Discover this file (if in current directory) +# 2. Find the @tool decorated function +# 3. Create an MCP server with this tool +# 4. Start listening for requests diff --git a/libs/arcade-mcp-server/docs/examples/01_tools.md b/libs/arcade-mcp-server/docs/examples/01_tools.md new file mode 100644 index 00000000..b950f67d --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/01_tools.md @@ -0,0 +1,131 @@ +# 01 - Tools + +Learn how to create tools with different parameter types and how arcade_mcp_server discovers them automatically. + +## Running the Example + +- **Run**: `python -m arcade_mcp_server` +- **Run (stdio)**: `python -m arcade_mcp_server stdio` +- **Show loaded packages**: `python -m arcade_mcp_server --show-packages` +- **Load specific package**: `python -m arcade_mcp_server --tool-package github` +- **Discover all installed**: `python -m arcade_mcp_server --discover-installed` + +## Source Code + +```python +--8<-- "docs/examples/01_tools.py" +``` + +## Creating Tools + +### 1. Simple Tools + +Basic tools with simple parameter types: + +```python +@tool +def hello(name: Annotated[str, "Name to greet"]) -> str: + """Say hello to someone.""" + return f"Hello, {name}!" + +@tool +def add( + a: Annotated[float, "First number"], + b: Annotated[float, "Second number"] +) -> Annotated[float, "Sum of the numbers"]: + """Add two numbers together.""" + return a + b +``` + +### 2. List Parameters + +Working with lists of values: + +```python +@tool +def calculate_average( + numbers: Annotated[list[float], "List of numbers to average"] +) -> Annotated[float, "Average of all numbers"]: + """Calculate the average of a list of numbers.""" + if not numbers: + return 0.0 + return sum(numbers) / len(numbers) +``` + +### 3. Complex Types with TypedDict + +Using TypedDict for structured input and output: + +```python +class PersonInfo(TypedDict): + name: str + age: int + email: str + is_active: bool + +@tool +def create_user_profile( + person: Annotated[PersonInfo, "Person's information"] +) -> Annotated[str, "Formatted user profile"]: + """Create a formatted user profile from person information.""" + # Implementation here +``` + +## Tool Discovery + +The arcade_mcp_server CLI discovers tools in multiple ways: + +### 1. Current Directory +- Scans all `*.py` files in the current directory +- Imports and checks for `@tool` decorated functions + +### 2. Standard Directories +- `tools/` directory - Common convention for organizing tools +- `arcade_tools/` directory - Alternative naming convention +- Both are recursively scanned for Python files + +### 3. Package Loading +```bash +# Load a specific package +python -m arcade_mcp_server --tool-package github + +# Discover all installed arcade packages +python -m arcade_mcp_server --discover-installed +``` + +### 4. File Organization + +Example project structure: +``` +my_project/ +โ”œโ”€โ”€ hello.py # Contains @tool functions +โ”œโ”€โ”€ tools/ +โ”‚ โ””โ”€โ”€ math.py # More @tool functions +โ””โ”€โ”€ arcade_tools/ + โ””โ”€โ”€ utils.py # Even more @tool functions +``` + +## Best Practices + +### Parameter Annotations +- **Always use `Annotated`**: Provide descriptions for all parameters +- **Clear descriptions**: Help the AI understand what each parameter does +- **Type hints**: Use proper Python type hints for validation + +### Tool Design +- **Single purpose**: Each tool should do one thing well +- **Error handling**: Add validation and helpful error messages +- **Return types**: Always annotate return types with descriptions + +### Organization +- **Group related tools**: Use directories to organize by functionality +- **Naming conventions**: Use clear, descriptive names +- **Documentation**: Write clear docstrings for each tool + +## Key Concepts + +- **Auto-Discovery**: Automatically finds tools without explicit registration +- **Type Safety**: Full type annotation support with runtime validation +- **TypedDict Support**: Use TypedDict for complex structured data +- **Flexible Organization**: Structure your tools however makes sense for your project +- **Multiple Sources**: Discover from files, directories, and packages diff --git a/libs/arcade-mcp-server/docs/examples/01_tools.py b/libs/arcade-mcp-server/docs/examples/01_tools.py new file mode 100644 index 00000000..f185cbd6 --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/01_tools.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python +from typing import Annotated + +from arcade_mcp_server import tool +from typing_extensions import TypedDict + +""" +01_tools.py - Tool creation, discovery, and parameter types + +This example demonstrates: +1. How to create tools with the @tool decorator +2. Different parameter types (simple, lists, TypedDict) +3. How arcade_mcp_server discovers tools automatically + +To run: + python -m arcade_mcp_server # Auto-discover all tools + python -m arcade_mcp_server --show-packages # Show what's being loaded + python -m arcade_mcp_server stdio # For Claude Desktop +""" + +# === DISCOVERY PATTERNS === + +""" +The arcade_mcp_server CLI discovers tools using these patterns: + +1. Current directory: *.py files + - Scans all Python files in the current directory + - Imports and checks for @tool decorated functions + +2. tools/ directory: + - If exists, recursively scans for Python files + - Common convention for organizing tools + +3. arcade_tools/ directory: + - Alternative directory name + - Also recursively scanned + +4. Package loading with --tool-package: + python -m arcade_mcp_server --tool-package github + - Loads arcade-github package + - Can load any installed package in the current python environment + +5. Discover all installed with --discover-installed: + python -m arcade_mcp_server --discover-installed + - Finds all arcade-* packages in the current python environment + - Loads all their tools + +Discovery tips: +- Use __init__.py in directories for proper imports +- Organize related tools in subdirectories +- Use clear, descriptive tool names +- Tools are namespaced by their toolkit name +""" + +# === SIMPLE TOOLS === + + +@tool +def hello(name: Annotated[str, "Name to greet"]) -> Annotated[str, "Greeting message"]: + """Say hello to someone.""" + return f"Hello, {name}!" + + +@tool +def add( + a: Annotated[float, "First number"], b: Annotated[float, "Second number"] +) -> Annotated[float, "Sum of the numbers"]: + """Add two numbers together.""" + return a + b + + +# === TOOLS WITH LIST PARAMETERS === + + +@tool +def calculate_average( + numbers: Annotated[list[float], "List of numbers to average"], +) -> Annotated[float, "Average of all numbers"]: + """Calculate the average of a list of numbers.""" + if not numbers: + return 0.0 + return sum(numbers) / len(numbers) + + +@tool +def factorial(n: Annotated[int, "Non-negative integer"]) -> Annotated[int, "Factorial of n"]: + """Calculate the factorial of a number.""" + if n < 0: + raise ValueError("Factorial not defined for negative numbers") + if n == 0: + return 1 + + result = 1 + for i in range(1, n + 1): + result *= i + return result + + +# === TOOLS WITH COMPLEX TYPES (TypedDict) === + + +class PersonInfo(TypedDict): + name: str + age: int + email: str + is_active: bool + + +@tool +def create_user_profile( + person: Annotated[PersonInfo, "Person's information"], +) -> Annotated[str, "Formatted user profile"]: + """Create a formatted user profile from person information.""" + status = "Active" if person["is_active"] else "Inactive" + return f""" +User Profile: +- Name: {person["name"]} +- Age: {person["age"]} +- Email: {person["email"]} +- Status: {status} +""".strip() + + +class CalculationResult(TypedDict): + sum: float + average: float + min: float + max: float + count: int + + +@tool +def analyze_numbers( + values: Annotated[list[float], "List of numbers to analyze"], +) -> Annotated[CalculationResult, "Statistical analysis of the numbers"]: + """Analyze a list of numbers and return statistics.""" + if not values: + return {"sum": 0.0, "average": 0.0, "min": 0.0, "max": 0.0, "count": 0} + + return { + "sum": sum(values), + "average": sum(values) / len(values), + "min": min(values), + "max": max(values), + "count": len(values), + } diff --git a/libs/arcade-mcp-server/docs/examples/02_building_apps.md b/libs/arcade-mcp-server/docs/examples/02_building_apps.md new file mode 100644 index 00000000..8d4c0961 --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/02_building_apps.md @@ -0,0 +1,92 @@ +# 02 - Building Apps + +Build and run an MCP server programmatically using the FastAPI-like `MCPApp` interface. + +## Running the Example + +- **Run HTTP**: `python examples/02_building_apps.py` +- **Run stdio**: `python examples/02_building_apps.py stdio` + +## Source Code + +```python +--8<-- "docs/examples/02_building_apps.py" +``` + +## MCPApp Features + +### 1. Creating an App + +```python +from arcade_mcp_server import MCPApp + +app = MCPApp( + name="my_server", + version="1.0.0", + title="My MCP Server", + instructions="This server provides utility tools", + log_level="INFO" +) +``` + +### 2. Adding Tools + +Use the `@app.tool` decorator to add tools: +```python +@app.tool +def my_tool(param: Annotated[str, "Description"]) -> str: + """Tool description.""" + return f"Result: {param}" +``` + +### 3. Running the Server + +```python +# Default HTTP transport +app.run() + +# Specify options +app.run( + host="0.0.0.0", + port=8080, + reload=True, # Auto-reload on code changes + transport="http" +) + +# For stdio transport (Claude Desktop) +app.run(transport="stdio") +``` + +### 4. Using Context + +Tools can access runtime context: +```python +@app.tool +async def context_aware(context: Context, value: str) -> dict: + """Tool that uses context features.""" + # Access user info + user_id = context.user_id + + + # Use MCP features if available + if context: + await context.log.info(f"Processing for user: {user_id}") + + # Access secrets + secret_keys = list(context.secrets.keys()) + + + return { + "user": user_id, + "value": value, + "available_secrets": secret_keys + } +``` + +## Key Concepts + +- **FastAPI-like Interface**: Familiar decorator-based API design +- **Programmatic Control**: Build servers without CLI dependency +- **Transport Flexibility**: Support for both HTTP and stdio transports +- **Context Integration**: Access to user info, logging, and secrets +- **Development Features**: Hot reload, debug logging, and more diff --git a/libs/arcade-mcp-server/docs/examples/02_building_apps.py b/libs/arcade-mcp-server/docs/examples/02_building_apps.py new file mode 100644 index 00000000..b3446972 --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/02_building_apps.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +""" +02_building_apps.py - Build an MCP server using MCPApp + +This example shows how to build and run an MCP server programmatically +using `MCPApp` instead of relying on the arcade_mcp_server CLI. + +To run (HTTP transport by default): + python 02_building_apps.py + +To run with stdio transport (for Claude Desktop): + python 02_building_apps.py stdio +""" + +import sys +from typing import Annotated + +from arcade_mcp_server import Context, MCPApp + +# Create the MCP application +app = MCPApp( + name="my_mcp_server", version="0.1.0", instructions="Example MCP server built with MCPApp" +) + + +@app.tool +def greet( + name: Annotated[str, "Name of the person to greet"], +) -> Annotated[str, "Greeting message"]: + """Return a friendly greeting. + + Parameters: + name: Person's name + + Returns: + Greeting message. + """ + return f"Hello, {name}!" + + +@app.tool +async def whoami(context: Context) -> Annotated[dict, "Basic server and user information"]: + """Return basic information from the tool context. + + Returns: + Dictionary with `user_id` and whether MCP features are available. + """ + user_id = context.user_id or "anonymous" + + if context: + await context.log.info(f"whoami called by: {user_id}") + + secret_keys = [secret.key for secret in context.secrets] if context.secrets else [] + return { + "user_id": user_id, + "secret_keys": secret_keys, + } + + +if __name__ == "__main__": + # Check if stdio transport was requested + if len(sys.argv) > 1 and sys.argv[1] == "stdio": + app.run(transport="stdio") + else: + # Default to HTTP transport + app.run(host="127.0.0.1", port=8001) diff --git a/libs/arcade-mcp-server/docs/examples/03_context.md b/libs/arcade-mcp-server/docs/examples/03_context.md new file mode 100644 index 00000000..69ac4ab3 --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/03_context.md @@ -0,0 +1,64 @@ +# 03 - Tool Context + +Access runtime features through Context including logging, secrets, and progress reporting. + +## Running the Example + +- **Run**: `python -m arcade_mcp_server` +- **Run (stdio)**: `python -m arcade_mcp_server stdio` +- **Env**: set `API_KEY`, `DATABASE_URL` + +## Source Code + +```python +--8<-- "docs/examples/03_context.py" +``` + +## Context Features + +The Context provides access to runtime features: + +### 1. Logging +Send log messages at different levels: +```python +await context.log.debug("Debug message") +await context.log.info("Information message") +await context.log.warning("Warning message") +await context.log.error("Error message") +``` + +### 2. Secrets Management +Access environment variables securely: +```python +try: + api_key = context.get_secret("API_KEY") +except ValueError: + # Handle missing secret +``` + +### 3. User Context +Access information about the current user: +```python +user_id = context.user_id or "anonymous" +``` + +### 4. Progress Reporting +Report progress for long-running operations: +```python +await context.progress.report(current, total, "Processing...") +``` + +### 5. Tool Decorator Options +Specify required secrets: +```python +@tool(requires_secrets=["DATABASE_URL", "API_KEY"]) +async def my_tool(context: Context, ...): +``` + +## Key Concepts + +- **Context Parameter**: Tools receive a `Context` as their first parameter +- **Async Functions**: Use `async def` for tools that use context features +- **Secure Secrets**: Secrets are accessed through context, not hardcoded +- **Structured Logging**: Log at appropriate levels for debugging +- **Progress Updates**: Keep users informed during long operations diff --git a/libs/arcade-mcp-server/docs/examples/03_context.py b/libs/arcade-mcp-server/docs/examples/03_context.py new file mode 100644 index 00000000..ef1d7ee2 --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/03_context.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python +""" +03_context.py - Using Context with namespaced runtime APIs + +This example shows how tools can access runtime features through +Context (provided at runtime by the TDK wrapper), including logging, +secrets, and progress reporting. + +To run (auto-discovery): + python -m arcade_mcp_server + +For Claude Desktop (stdio transport): + python -m arcade_mcp_server stdio + +Set environment variables for secrets: + export API_KEY="your-secret-key" + export DATABASE_URL="postgresql://localhost/mydb" +""" + +from typing import Annotated, Any + +from arcade_mcp_server import Context, tool + + +@tool +async def secure_api_call( + context: Context, + endpoint: Annotated[str, "API endpoint to call"], + method: Annotated[str, "HTTP method (GET, POST, etc.)"] = "GET", +) -> Annotated[str, "API response or error message"]: + """Make a secure API call using secrets from context.""" + + # Access secrets from environment via Context helper + try: + api_key = context.get_secret("API_KEY") + except ValueError: + await context.log.error("API_KEY not found in environment") + return "Error: API_KEY not configured" + + # Log the API call + await context.log.info(f"Making {method} request to {endpoint}") + + # Simulate API call (in real code, use httpx or aiohttp) + return f"Successfully called {endpoint} with API key: {api_key[:4]}..." + + +@tool(requires_secrets=["DATABASE_URL"]) +async def database_info( + context: Context, table_name: Annotated[str | None, "Specific table to check"] = None +) -> Annotated[str, "Database connection info"]: + """Get database connection information from context.""" + + # Get database URL from secrets + try: + db_url = context.get_secret("DATABASE_URL") + except ValueError: + db_url = "Not configured" + + # Log at different levels + if db_url == "Not configured": + await context.log.warning("DATABASE_URL not set") + else: + await context.log.debug(f"Checking database: {db_url.split('@')[-1]}") + + # Get user info + user_info = f"User: {context.user_id or 'anonymous'}" + + if table_name: + return f"{user_info}\nDatabase: {db_url}\nChecking table: {table_name}" + else: + return f"{user_info}\nDatabase: {db_url}" + + +@tool +async def debug_context( + context: Context, + show_secrets: Annotated[bool, "Whether to show secret keys (not values)"] = False, +) -> Annotated[dict, "Current context information"]: + """Debug tool to inspect the current context.""" + + info: dict[str, Any] = { + "user_id": context.user_id, + } + + if show_secrets: + # Only show keys, not values for security + info["secret_keys"] = [s.key for s in (context.secrets or [])] + + # Log that debug info was accessed + await context.log.info(f"Debug context accessed by {context.user_id or 'unknown'}") + + return info + + +@tool +async def process_with_progress( + context: Context, + items: Annotated[list[str], "Items to process"], + delay_seconds: Annotated[float, "Delay between items"] = 0.1, +) -> Annotated[dict, "Processing results"]: + """Process items with progress notifications.""" + + results: dict[str, list] = {"processed": [], "errors": []} + + # Log start + await context.log.info(f"Starting to process {len(items)} items") + + for i, item in enumerate(items): + try: + # Simulate processing + import asyncio + + await asyncio.sleep(delay_seconds) + + # Report progress (current, total, message) + await context.progress.report(i + 1, len(items), f"Processing: {item}") + await context.log.debug(f"Processing item {i + 1}/{len(items)}: {item}") + + results["processed"].append(item.upper()) + + except Exception as e: + await context.log.error(f"Failed to process {item}: {e}") + results["errors"].append({"item": item, "error": str(e)}) + + # Log completion + await context.log.info( + f"Processing complete: {len(results['processed'])} succeeded, " + f"{len(results['errors'])} failed" + ) + + return results + + +# The Context provides at runtime (via TDK wrapper): +# - context.user_id: ID of the user making the request +# - context.get_secret(key): Retrieve a secret value (raises if missing) +# - context.log.(msg): Send log messages to the client (debug/info/warning/error) +# - context.progress.report(progress, total=None, message=None): Progress updates diff --git a/libs/arcade-mcp-server/docs/examples/04_secrets.md b/libs/arcade-mcp-server/docs/examples/04_secrets.md new file mode 100644 index 00000000..63117c6f --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/04_secrets.md @@ -0,0 +1,60 @@ +# 04 - Tool Secrets + +Read secrets from environment and `.env` files securely via Context. + +## Running the Example + +- **Run**: `python -m arcade_mcp_server` +- **Run (stdio)**: `python -m arcade_mcp_server stdio` +- **Create `.env`**: Add `API_KEY=supersecret` to a `.env` file + +## Source Code + +```python +--8<-- "docs/examples/04_tool_secrets.py" +``` + +## Working with Secrets + +### 1. Environment Variables + +Secrets can be provided via environment variables: +```bash +export API_KEY="your-secret-key" +export DATABASE_URL="postgresql://localhost/mydb" +python -m arcade_mcp_server +``` + +### 2. Using .env Files + +Create a `.env` file in your working directory: +``` +API_KEY=supersecret +DATABASE_URL=postgresql://user:pass@localhost/db +GITHUB_TOKEN=ghp_xxxxxxxxxxxx +``` + +### 3. Declaring Required Secrets + +Use the `requires_secrets` parameter to declare which secrets your tool needs: +```python +@tool(requires_secrets=["API_KEY", "DATABASE_URL"]) +def my_secure_tool(context: Context) -> str: + api_key = context.get_secret("API_KEY") + db_url = context.get_secret("DATABASE_URL") +``` + +### 4. Security Best Practices + +- **Never log secret values**: Always mask or truncate when displaying +- **Declare requirements**: Use `requires_secrets` to document dependencies +- **Handle missing secrets**: Use try/except when accessing secrets +- **Use descriptive names**: Make it clear what each secret is for + +## Key Concepts + +- **Secure Access**: Secrets are accessed through context, not imported directly +- **Environment Integration**: Works with both environment variables and .env files +- **Error Handling**: Always handle the case where a secret might be missing +- **Masking**: Never expose full secret values in logs or return values +- **Declaration**: Use `requires_secrets` to make dependencies explicit diff --git a/libs/arcade-mcp-server/docs/examples/04_secrets.py b/libs/arcade-mcp-server/docs/examples/04_secrets.py new file mode 100644 index 00000000..8b593eec --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/04_secrets.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +"""04: Read secrets from .env via Context + +Run (auto-discovery): + python -m arcade_mcp_server + +For Claude Desktop (stdio transport): + python -m arcade_mcp_server stdio + +Environment: + # Create a .env in the working directory with: + # API_KEY=supersecret +""" + +from arcade_mcp_server import Context, tool + + +@tool( + name="UseSecret", + desc="Echo a masked secret read from the context", + requires_secrets=["API_KEY"], # declare we need API_KEY +) +def use_secret(context: Context) -> str: + """Read API_KEY from context and return a masked confirmation string.""" + try: + value = context.get_secret("API_KEY") + masked = value[:2] + "***" if len(value) >= 2 else "***" + return f"Got API_KEY of length {len(value)} -> {masked}" + except Exception as e: + return f"Error getting secret: {e}" diff --git a/libs/arcade-mcp-server/docs/examples/05_logging.md b/libs/arcade-mcp-server/docs/examples/05_logging.md new file mode 100644 index 00000000..852cab97 --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/05_logging.md @@ -0,0 +1,101 @@ +# 05 - Logging + +Demonstrates MCP logging capabilities with various levels and patterns for debugging and monitoring. + +## Running the Example + +- **Run**: `python examples/05_logging.py` +- Set `log_level="DEBUG"` in `MCPApp` to see debug logs + +## Source Code + +```python +--8<-- "docs/examples/05_logging.py" +``` + +## Logging Features + +### 1. Log Levels + +MCP supports standard log levels: +```python +await context.log.debug("Detailed debugging information") +await context.log.info("General information") +await context.log.warning("Warning messages") +await context.log.error("Error messages") +``` + +### 2. Structured Logging + +Log with context and metadata: +```python +# Include user context +await context.log.info( + f"Action performed by user: {context.user_id}" +) + +# Add operation details +await context.log.debug( + f"Processing {item_count} items with options: {options}" +) +``` + +### 3. Error Logging + +Proper error handling and logging: +```python +try: + # Operation that might fail + result = risky_operation() +except Exception as e: + # Log error with type and message + await context.log.error( + f"Operation failed: {type(e).__name__}: {str(e)}" + ) + + # Log traceback at debug level + await context.log.debug( + f"Traceback:\n{traceback.format_exc()}" + ) +``` + +### 4. Progress Logging + +Track long-running operations: +```python +for i, item in enumerate(items): + # Log progress + await context.log.debug( + f"Progress: {i+1}/{len(items)} ({(i+1)/len(items)*100:.0f}%)" + ) + + # Process item + process(item) +``` + +### 5. Batch Processing + +Log batch operations effectively: +```python +# Log batch start +await context.log.info(f"Starting batch of {count} items") + +# Log individual items at debug level +for item in items: + await context.log.debug(f"Processing: {item}") + +# Log summary +await context.log.info( + f"Batch complete: {success_count} successful, {fail_count} failed" +) +``` + +## Best Practices + +1. **Use Appropriate Levels**: Debug for details, info for general flow, warning for issues, error for failures +2. **Include Context**: Always include relevant context like user ID, operation names, counts +3. **Structure Messages**: Use consistent message formats for easier parsing +4. **Handle Errors Gracefully**: Log errors with enough detail to debug but not expose sensitive data +5. **Progress Updates**: For long operations, provide regular progress updates +6. **Batch Summaries**: For batch operations, log both individual items (debug) and summaries (info) +7. **Performance Considerations**: Be mindful of log volume in production environments diff --git a/libs/arcade-mcp-server/docs/examples/05_logging.py b/libs/arcade-mcp-server/docs/examples/05_logging.py new file mode 100644 index 00000000..cfe2587e --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/05_logging.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python +""" +05_logging.py - MCP logging capabilities + +This example demonstrates the various logging levels and patterns +available through the MCP protocol for debugging and monitoring. + +To run: + python 05_logging.py + +To see debug logs: + Set log_level="DEBUG" when creating MCPApp +""" + +import asyncio +import time +import traceback +from typing import Annotated, Optional + +from arcade_mcp_server import Context, MCPApp + +# Create the app with debug logging +app = MCPApp(name="logging_examples", version="0.1.0", log_level="DEBUG") + + +@app.tool +async def demonstrate_log_levels( + context: Context, message: Annotated[str, "Base message to log at different levels"] +) -> Annotated[dict, "Summary of logged messages"]: + """Demonstrate all MCP logging levels.""" + + # Log at each level + levels = ["debug", "info", "warning", "error"] + logged = {} + + for level in levels: + log_message = f"[{level.upper()}] {message}" + await context.log(level, log_message) + logged[level] = log_message + + return {"logged_messages": logged, "note": "Check your MCP client to see these messages"} + + +@app.tool +async def timed_operation( + context: Context, + operation_name: Annotated[str, "Name of the operation"], + duration_seconds: Annotated[float, "How long the operation takes"] = 2.0, +) -> Annotated[dict, "Operation timing details"]: + """Perform a timed operation with detailed logging.""" + + start_time = time.time() + + # Log operation start + await context.log.info( + f"Starting operation: {operation_name} (expected duration: {duration_seconds}s)" + ) + + # Simulate work with progress logging + steps = 5 + for i in range(steps): + await context.log.debug(f"Progress: step {i + 1}/{steps} ({(i + 1) / steps * 100:.0f}%)") + + await asyncio.sleep(duration_seconds / steps) + + # Calculate results + end_time = time.time() + actual_duration = end_time - start_time + + # Log completion + await context.log.info(f"Completed operation: {operation_name} in {actual_duration:.2f}s") + + return { + "operation": operation_name, + "expected_duration": duration_seconds, + "actual_duration": round(actual_duration, 2), + "start_time": start_time, + "end_time": end_time, + } + + +@app.tool +async def error_handling_example( + context: Context, + should_fail: Annotated[bool, "Whether to simulate an error"], + error_type: Annotated[str, "Type of error to simulate"] = "ValueError", +) -> Annotated[dict, "Result or error details"]: + """Demonstrate error logging and handling.""" + + try: + await context.log.debug(f"Error handling test: should_fail={should_fail}") + + if should_fail: + if error_type == "ValueError": + raise ValueError("This is a simulated value error") # noqa: TRY301 + elif error_type == "KeyError": + raise KeyError("missing_key") # noqa: TRY301 + elif error_type == "ZeroDivisionError": + result = 1 / 0 + return {"result": result} + else: + raise Exception(f"Generic error of type: {error_type}") # noqa: TRY002, TRY301 + + # Success case + await context.log.info("Operation completed successfully") + + except Exception as e: + # Log the error with details + await context.log.error(f"Operation failed with {type(e).__name__}: {e!s}") + + # Log traceback separately at debug level + await context.log.debug(f"Traceback:\n{traceback.format_exc()}") + + return { + "status": "error", + "error_type": type(e).__name__, + "error_message": str(e), + "handled": True, + } + else: + return {"status": "success", "message": "No errors occurred"} + + +@app.tool +async def structured_logging( + context: Context, + user_action: Annotated[str, "Action the user is performing"], + metadata: Annotated[dict | None, "Additional metadata to log"] = None, +) -> Annotated[str, "Confirmation message"]: + """Demonstrate structured logging patterns.""" + + # Log main action + await context.log.info( + f"User action: {user_action} (user_id: {context.user_id or 'anonymous'})" + ) + + # Log additional details at debug level + await context.log.debug( + f"Context details: {len(context.secrets) if context.secrets else 0} secrets available" + ) + + # Log metadata if provided + if metadata: + await context.log.debug(f"Custom metadata: {metadata}") + + return f"Logged user action: {user_action}" + + +@app.tool +async def batch_processing_logs( + context: Context, + items: Annotated[list[str], "Items to process"], + fail_on_item: Annotated[Optional[str], "Item that should fail"] = None, +) -> Annotated[dict, "Processing results with detailed logs"]: + """Process items with detailed logging for each step.""" + + results: dict[str, list] = {"successful": [], "failed": []} + + await context.log.info(f"Starting batch processing of {len(items)} items") + + for i, item in enumerate(items): + try: + # Log item start + await context.log.debug(f"Processing item {i + 1}/{len(items)}: {item}") + + # Simulate failure if requested + if item == fail_on_item: + raise ValueError(f"Simulated failure for item: {item}") # noqa: TRY301 + + # Simulate processing + await asyncio.sleep(0.1) + + results["successful"].append(item) + + except Exception as e: + await context.log.warning(f"Failed to process '{item}': {e!s}") + results["failed"].append({"item": item, "error": str(e)}) + + # Log summary + await context.log.info( + f"Batch processing complete: {len(results['successful'])} successful, " + f"{len(results['failed'])} failed", + ) + + return results + + +if __name__ == "__main__": + # Run the server + app.run(host="127.0.0.1", port=8001) diff --git a/libs/arcade-mcp-server/docs/examples/README.md b/libs/arcade-mcp-server/docs/examples/README.md new file mode 100644 index 00000000..c68ac7ad --- /dev/null +++ b/libs/arcade-mcp-server/docs/examples/README.md @@ -0,0 +1,54 @@ +# Arcade MCP Examples + +This directory contains examples demonstrating how to build MCP servers with your Arcade tools. + +## Examples Overview + +### Basic Examples + +1. **[00_hello_world.py](00_hello_world.py)** โ€“ Minimal tool example + - Single `@tool` function showing the basics + - Run: `python -m arcade_mcp_server` (or `python -m arcade_mcp_server stdio`) + +2. **[01_tools.py](01_tools.py)** โ€“ Creating tools and discovery + - Simple parameters, lists, and `TypedDict` + - How arcade_mcp_server discovers tools automatically + - Run: `python -m arcade_mcp_server` + +3. **[02_building_apps.py](02_building_apps.py)** โ€“ Building apps with MCPApp + - Create an `MCPApp`, register tools with `@app.tool` + - Run HTTP: `python 02_building_apps.py` + - Run stdio: `python 02_building_apps.py stdio` + +4. **[03_context.py](03_context.py)** โ€“ Using `Context` + - Access secrets, logging, and user context + - Run: `python -m arcade_mcp_server` + +5. **[04_tool_secrets.py](04_tool_secrets.py)** โ€“ Working with secrets + - Use `requires_secrets` and access masked values + - Run: `python -m arcade_mcp_server` + +6. **[05_logging.py](05_logging.py)** โ€“ Logging with MCP + - Demonstrates debug/info/warning/error levels and structured logs + - Run: `python 05_logging.py` + +## Running Examples + +Most examples can be run directly with the arcade_mcp_server CLI: + +```bash +# Auto-discover tools in current directory +python -m arcade_mcp_server + +# With specific transport +python -m arcade_mcp_server stdio # For Claude Desktop +python -m arcade_mcp_server # HTTP by default + +# With debugging +python -m arcade_mcp_server --debug + +# With hot reload (HTTP only) +python -m arcade_mcp_server --reload +``` + +For MCPApp examples, run the script directly to start an HTTP server. diff --git a/libs/arcade-mcp-server/docs/getting-started/quickstart.md b/libs/arcade-mcp-server/docs/getting-started/quickstart.md new file mode 100644 index 00000000..a3a02b7c --- /dev/null +++ b/libs/arcade-mcp-server/docs/getting-started/quickstart.md @@ -0,0 +1,182 @@ +# Quick Start + +The `arcade_mcp_server` package provides powerful ways to run MCP servers with your Arcade tools. + +## Getting Started + +### Install + +```bash +uv pip install arcade-mcp-server +``` + + +```bash +uv run python -m arcade_mcp_server +``` + +### Write a tool + + +```python +from arcade_mcp_server import tool + +@tool +def greet(Annotated[str, "The name to greet"]) -> Annotated[str, "The greeting"]: + return f"Hello, {name}!" +``` + +### Run MCP Server + +```bash +uv run python -m arcade_mcp_server +``` + +You should see the following output: + +```text +INFO | 03:32:05 | Auto-discovering tools from current directory +INFO | 03:32:05 | Found 1 tool(s) in 00_hello_world.py: greet +INFO: Started server process +INFO: Waiting for application startup. +INFO | 03:32:05 | Starting MCP server with HTTP transport on 127.0.0.1:7777 +INFO | 03:32:05 | Starting MCP server: ArcadeMCP +INFO | 03:32:05 | HTTP session manager started +INFO | 03:32:05 | MCP server started and ready for connections +INFO: Application startup complete. +INFO: Uvicorn running on http://127.0.0.1:7777 (Press CTRL+C to quit) +``` + +View the docs at http://127.0.0.1:7777/docs. + +That's it! You've created an MCP server with a tool. + +Check out the [CLI](../api/cli.md) for more options and [Clients](../clients/README.md) for how to use the server with different clients like Claude Desktop, Cursor, and VSCode. + + +## Building MCP Servers + +The simplest way to create an MCP server programmatically is using `MCPApp`, which provides a FastAPI-like interface: + +```python +from arcade_mcp_server import MCPApp +from typing import Annotated + +app = MCPApp( + name="my-tools", + version="1.0.0", + instructions="Custom MCP server with specialized tools" +) + +@app.tool +def calculate( + expression: Annotated[str, "Mathematical expression to evaluate"] +) -> Annotated[float, "The result of the calculation"]: + """Safely evaluate a mathematical expression.""" + # Safe evaluation logic here + return eval(expression, {"__builtins__": {}}, {}) + +@app.tool +def fetch_data( + url: Annotated[str, "URL to fetch data from"] +) -> Annotated[dict, "The fetched data"]: + """Fetch data from an API endpoint.""" + import requests + return requests.get(url).json() + +# Run the server +if __name__ == "__main__": + app.run(host="0.0.0.0", port=8080, reload=True) +``` + +## `arcade_mcp_server` CLI + +The `arcade_mcp_server` CLI is a simple tool for running MCP servers automatically discovering tools, creating a server for you, and running it. + +This is primarily used for development, and running mcp servers locally for desktop clients with stdio. + +### Auto-Discovery Mode + +The simplest way to run is to let arcade_mcp_server discover tools in your current directory: + +```bash +# Auto-discover @tool decorated functions +python -m arcade_mcp_server + +# With stdio transport for Claude Desktop +python -m arcade_mcp_server stdio +``` + +### Loading Installed Packages + +Load specific arcade packages or discover all installed ones: + +```bash +# Load a specific arcade package +python -m arcade_mcp_server --tool-package github +python -m arcade_mcp_server -p slack + +# Discover all installed arcade packages +python -m arcade_mcp_server --discover-installed + +# Show which packages are being loaded +python -m arcade_mcp_server --discover-installed --show-packages +``` + +### Development Mode + +For active development with hot reload: + +```bash +# Run with hot reload and debug logging +python -m arcade_mcp_server --reload --debug + +# Specify host and port +python -m arcade_mcp_server --host 0.0.0.0 --port 8080 + +# Load environment variables +python -m arcade_mcp_server --env-file .env +``` + + +## Environment Variables + +Configure the server using environment variables: + +```bash +# Server settings +MCP_SERVER_NAME="My MCP Server" +MCP_SERVER_VERSION="1.0.0" + +# Arcade integration +ARCADE_API_KEY="your-api-key" +ARCADE_API_URL="https://api.arcade.dev" +ARCADE_USER_ID="user@example.com" + +# Development settings +ARCADE_AUTH_DISABLED=true +MCP_DEBUG=true + +# Tool secrets (available to tools via context) +MY_API_KEY="secret-value" +DATABASE_URL="postgresql://..." +``` + +## Development Tips + +### Hot Reload +Use `--reload --debug` for development to automatically restart on code changes: + +```bash +python -m arcade_mcp_server --reload --debug +``` + +### Logging +- Use `--debug` for verbose logging +- In stdio mode, logs go to stderr +- In HTTP mode, logs go to stdout + +### Testing Tools +With HTTP transport and debug mode, access API documentation at: +- http://localhost:8000/docs (Swagger UI) +- http://localhost:8000/redoc (ReDoc) diff --git a/libs/arcade-mcp-server/docs/index.md b/libs/arcade-mcp-server/docs/index.md new file mode 100644 index 00000000..5bb0da5b --- /dev/null +++ b/libs/arcade-mcp-server/docs/index.md @@ -0,0 +1,88 @@ +# Arcade MCP + +

+ Arcade Logo +

+ +Arcade MCP (Model Context Protocol) enables AI assistants and development tools to interact with your Arcade tools through a standardized protocol. Build, deploy, and integrate your MCP servers seamlessly across different AI platforms. + +## Quick Links + +- **[Quickstart Guide](getting-started/quickstart.md)** - Get up and running in minutes +- **[Walkthrough](examples/README.md)** - Learn by example +- **[API Reference](api/mcp_app.md)** - MCPApp API documentation + +## Features + +- ๐Ÿš€ **FastAPI-like Interface** - Simple, intuitive API with `MCPApp` +- ๐Ÿ”ง **Tool Discovery** - Automatic discovery of tools in your project +- ๐Ÿ”Œ **Multiple Transports** - Support for stdio and HTTP/SSE +- ๐Ÿค– **Multi-Client Support** - Works with Claude, Cursor, VS Code, and more +- ๐Ÿ“ฆ **Package Integration** - Load installed Arcade packages +- ๐Ÿ” **Built-in Security** - Environment-based configuration and secrets +- ๐Ÿ”„ **Hot Reload** - Development mode with automatic reloading +- ๐Ÿ“Š **Production Ready** - Deploy with Docker, systemd, PM2, or cloud platforms + +## Getting Started + +### Installation + +```bash +pip install arcade-mcp-server +``` + +### Create Your First Tool + +```python +from arcade_mcp_server import MCPApp +from typing import Annotated + +app = MCPApp(name="my-tools", version="1.0.0") + +@app.tool +def greet(name: Annotated[str, "Name to greet"]) -> str: + """Greet someone by name.""" + return f"Hello, {name}!" + +if __name__ == "__main__": + app.run() +``` + +### Run Your Server + +```bash +# For development +python my_tools.py + +# For Claude Desktop +python -m arcade_mcp_server stdio + +# For HTTP clients +python -m arcade_mcp_server --host 0.0.0.0 --port 8080 +``` + +## Client Integration + +Connect your MCP server with AI assistants and development tools: + +- **[Claude Desktop](clients/claude.md)** - Native MCP support in Claude +- **[Cursor IDE](clients/cursor.md)** - Enhanced AI coding with MCP tools +- **[VS Code](clients/vscode.md)** - Integrate with Visual Studio Code +- **[MCP Inspector](clients/inspector.md)** - Debug and test your tools + + +## Learn More + +- **[Walkthrough](examples/README.md)** - Comprehensive examples and tutorials +- **[API Reference](api/mcp_app.md)** - Detailed API documentation +- **[Transport Modes](advanced/transports.md)** - stdio and HTTP transport details + +## Community + +- [GitHub Repository](https://github.com/ArcadeAI/arcade-mcp) +- [Discord Community](https://discord.gg/arcade-mcp) +- [Documentation](https://docs.arcade.dev) + +## License + +Arcade MCP server is open source software licensed under the MIT license. diff --git a/libs/arcade-mcp-server/mkdocs.yml b/libs/arcade-mcp-server/mkdocs.yml new file mode 100644 index 00000000..c3760535 --- /dev/null +++ b/libs/arcade-mcp-server/mkdocs.yml @@ -0,0 +1,98 @@ +site_name: Arcade MCP +site_description: MCP (Model Context Protocol) support for Arcade. +site_url: https://docs.arcadeai.dev/arcade-mcp-server/ +repo_url: https://github.com/ArcadeAI/arcade-mcp +repo_name: ArcadeAI/arcade-mcp + +theme: + palette: + - media: "(prefers-color-scheme)" + toggle: + icon: material/brightness-auto + name: Switch to light mode + - media: "(prefers-color-scheme: light)" + scheme: default + toggle: + icon: material/brightness-7 + name: Switch to dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + toggle: + icon: material/brightness-4 + name: Switch to system preference + name: material + logo: https://docs.arcade.dev/images/logo/arcade-logo.png + favicon: https://docs.arcade.dev/images/logo/arcade-logo.png + + features: + - navigation.instant + - navigation.tracking + - navigation.expand + - navigation.indexes + - content.code.copy + - content.code.annotate + +markdown_extensions: + - pymdownx.highlight: + anchor_linenums: true + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.superfences + - admonition + - pymdownx.details + - pymdownx.tabbed: + alternate_style: true + - tables + - footnotes + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + +plugins: + - search + - mkdocstrings: + handlers: + python: + paths: [../arcade_mcp_server] + options: + show_source: false + show_root_heading: true + heading_level: 3 + +exclude_docs: | + /README.md +nav: + - Home: index.md + - Getting Started: + - getting-started/quickstart.md + - Walkthrough: + - examples/README.md + - examples/00_hello_world.md + - examples/01_tools.md + - examples/02_building_apps.md + - examples/03_context.md + - examples/04_secrets.md + - examples/05_logging.md + - Clients: + - clients/claude.md + - clients/cursor.md + - clients/vscode.md + - clients/inspector.md + - API Reference: + - api/app.md + - api/cli.md + - Server: + - api/server/server.md + - api/server/middleware.md + - api/server/types.md + - api/server/errors.md + - api/server/settings.md + - Advanced: + - advanced/transports.md + +extra: + social: + - icon: fontawesome/brands/github + link: https://github.com/ArcadeAI/arcade-mcp + - icon: fontawesome/brands/python + link: https://pypi.org/project/arcade-mcp/ diff --git a/libs/arcade-mcp-server/pyproject.toml b/libs/arcade-mcp-server/pyproject.toml new file mode 100644 index 00000000..da193984 --- /dev/null +++ b/libs/arcade-mcp-server/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +requires = ["hatchling>=1.25"] +build-backend = "hatchling.build" + +[project] +name = "arcade-mcp-server" +version = "1.0.0rc1" +description = "Model Context Protocol (MCP) server framework for Arcade.dev" +readme = "README.md" +authors = [{ name = "Arcade.dev" }] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +requires-python = ">=3.10" +dependencies = [ + "arcade-core>=2.5.0rc1,<3.0.0", + "arcade-serve>=2.2.0rc1,<3.0.0", + "arcade-tdk>=2.6.0rc1,<3.0.0", + "arcadepy>=1.5.0", + "pydantic>=2.0.0", + "fastapi>=0.100.0", + "uvicorn>=0.30.0", + "sse-starlette>=2.0.0", + "starlette>=0.37.0", + "anyio>=4.0.0", + "python-dotenv>=1.0.0", + "pydantic-settings>=2.10.1", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "mypy>=1.0.0", + "ruff>=0.1.0", + "mkdocs>=1.6.0", + "mkdocs-material>=9.6.0", + "mkdocstrings[python]>=0.28.0", +] + +[tool.hatch.build.targets.wheel] +packages = ["arcade_mcp_server"] diff --git a/libs/arcade-serve/arcade_serve/core/__init__.py b/libs/arcade-serve/arcade_serve/core/__init__.py index e69de29b..80d76995 100644 --- a/libs/arcade-serve/arcade_serve/core/__init__.py +++ b/libs/arcade-serve/arcade_serve/core/__init__.py @@ -0,0 +1,27 @@ +"""Core components for Arcade Serve.""" + +from arcade_serve.core.base import BaseWorker +from arcade_serve.core.common import ( + RequestData, + ResponseData, + Router, + Worker, + WorkerComponent, +) +from arcade_serve.core.components import ( + CallToolComponent, + CatalogComponent, + HealthCheckComponent, +) + +__all__ = [ + "BaseWorker", + "CallToolComponent", + "CatalogComponent", + "HealthCheckComponent", + "RequestData", + "ResponseData", + "Router", + "Worker", + "WorkerComponent", +] diff --git a/libs/arcade-serve/arcade_serve/core/base.py b/libs/arcade-serve/arcade_serve/core/base.py index d24248a5..70e4faf7 100644 --- a/libs/arcade-serve/arcade_serve/core/base.py +++ b/libs/arcade-serve/arcade_serve/core/base.py @@ -40,7 +40,10 @@ class BaseWorker(Worker): ) def __init__( - self, secret: str | None = None, disable_auth: bool = False, otel_meter: Meter | None = None + self, + secret: str | None = None, + disable_auth: bool = False, + otel_meter: Meter | None = None, ) -> None: """ Initialize the BaseWorker with an empty ToolCatalog. @@ -181,5 +184,11 @@ class BaseWorker(Worker): """ Register the necessary routes to the application. """ + # Initialize components list if it doesn't exist + if not hasattr(self, "components"): + self.components = [] + for component_cls in self.default_components: - component_cls(self).register(router) + component = component_cls(self) + component.register(router) + self.components.append(component) diff --git a/libs/arcade-serve/arcade_serve/core/common.py b/libs/arcade-serve/arcade_serve/core/common.py index d165357c..9bef08c6 100644 --- a/libs/arcade-serve/arcade_serve/core/common.py +++ b/libs/arcade-serve/arcade_serve/core/common.py @@ -45,6 +45,20 @@ class Router(ABC): """ pass + @abstractmethod + def add_mount(self, path: str, app: Any, name: str | None = None) -> None: + """Mount an ASGI application at the specified path. + + This is optional for routers to implement. If not implemented, + MCPComponent will raise NotImplementedError. + + Args: + path: The URL path to mount the app at + app: The ASGI application to mount + name: Optional name for the mount + """ + raise NotImplementedError("This router does not support mounting ASGI applications") + class Worker(ABC): """ diff --git a/libs/arcade-serve/arcade_serve/core/components.py b/libs/arcade-serve/arcade_serve/core/components.py index a85d8db2..e3d93072 100644 --- a/libs/arcade-serve/arcade_serve/core/components.py +++ b/libs/arcade-serve/arcade_serve/core/components.py @@ -1,3 +1,7 @@ +from arcade_core.schema import ( + ToolCallRequest, + ToolCallResponse, +) from opentelemetry import trace from arcade_serve.core.common import ( @@ -5,8 +9,6 @@ from arcade_serve.core.common import ( HealthCheckResponse, RequestData, Router, - ToolCallRequest, - ToolCallResponse, Worker, WorkerComponent, ) diff --git a/libs/arcade-core/arcade_core/telemetry.py b/libs/arcade-serve/arcade_serve/fastapi/telemetry.py similarity index 100% rename from libs/arcade-core/arcade_core/telemetry.py rename to libs/arcade-serve/arcade_serve/fastapi/telemetry.py diff --git a/libs/arcade-serve/arcade_serve/fastapi/worker.py b/libs/arcade-serve/arcade_serve/fastapi/worker.py index f7d53a91..0d89b026 100644 --- a/libs/arcade-serve/arcade_serve/fastapi/worker.py +++ b/libs/arcade-serve/arcade_serve/fastapi/worker.py @@ -4,6 +4,7 @@ from typing import Any, Callable from fastapi import Depends, FastAPI, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from opentelemetry.metrics import Meter +from starlette.routing import Mount from arcade_serve.core.base import ( BaseWorker, @@ -26,25 +27,45 @@ class FastAPIWorker(BaseWorker): *, disable_auth: bool = False, otel_meter: Meter | None = None, + components: list[type[WorkerComponent]] | None = None, ) -> None: """ Initialize the FastAPIWorker with a FastAPI app instance. If no secret is provided, the worker will use the ARCADE_WORKER_SECRET environment variable. - Args: app: The FastAPI app to host the worker in secret: Optional secret for authorization disable_auth: Whether to disable authorization otel_meter: Optional OpenTelemetry meter + components: Optional list of components to register """ super().__init__(secret, disable_auth, otel_meter) self.app = app self.router = FastAPIRouter(app, self) - self.register_routes(self.router) - # Initialize components + # Initialize components list self.components: list[WorkerComponent] = [] + # If no components specified, register the default routes from BaseWorker + if components is None: + self.register_routes(self.router) + else: + # Register the provided components + for component_cls in components: + self.register_component(component_cls) + + def register_component(self, component_cls: type[WorkerComponent], **kwargs: Any) -> None: + """ + Register a component with the worker. + + Args: + component_cls: The component class to register + **kwargs: Additional keyword arguments to pass to the component constructor + """ + component = component_cls(self, **kwargs) + component.register(self.router) + self.components.append(component) + security = HTTPBearer() # Authorization: Bearer @@ -109,3 +130,15 @@ class FastAPIRouter(Router): # **kwargs to pass to FastAPI **kwargs, ) + + def add_mount(self, path: str, app: Any, name: str | None = None) -> None: + """Mount an ASGI application at the specified path. + + Args: + path: The URL path to mount the app at + app: The ASGI application to mount + name: Optional name for the mount + """ + # Add mount to the FastAPI app's router + mount = Mount(path, app=app, name=name) + self.app.router.routes.append(mount) diff --git a/libs/arcade-serve/arcade_serve/mcp/__init__.py b/libs/arcade-serve/arcade_serve/mcp/__init__.py deleted file mode 100644 index 979feda4..00000000 --- a/libs/arcade-serve/arcade_serve/mcp/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -MCP (Model Context Protocol) support for Arcade workers. -""" - -from arcade_serve.mcp.stdio import StdioServer - -__all__ = ["StdioServer"] diff --git a/libs/arcade-serve/arcade_serve/mcp/convert.py b/libs/arcade-serve/arcade_serve/mcp/convert.py deleted file mode 100644 index 9ea7d9d9..00000000 --- a/libs/arcade-serve/arcade_serve/mcp/convert.py +++ /dev/null @@ -1,188 +0,0 @@ -import json -import logging -from enum import Enum -from typing import Any - -from arcade_core.catalog import MaterializedTool - -# Type aliases for MCP types -MCPTool = dict[str, Any] -MCPTextContent = dict[str, Any] -MCPImageContent = dict[str, Any] -MCPEmbeddedResource = dict[str, Any] -MCPContent = MCPTextContent | MCPImageContent | MCPEmbeddedResource - -logger = logging.getLogger("arcade.mcp") - - -def create_mcp_tool(tool: MaterializedTool) -> dict[str, Any] | None: # noqa: C901 - """ - Create an MCP-compatible tool definition from an Arcade tool. - - Args: - tool: An Arcade tool object - - Returns: - An MCP tool definition or None if the tool cannot be converted - """ - try: - name = getattr(tool.definition, "fully_qualified_name", None) or getattr( - tool.definition, "name", "unknown" - ) - description = getattr(tool.definition, "description", "No description available") - - # Extract parameters from the input model - parameters = {} - required = [] - - if ( - hasattr(tool, "input_model") - and tool.input_model is not None - and hasattr(tool.input_model, "model_fields") - ): - for field_name, field in tool.input_model.model_fields.items(): - # Skip internal tool context parameters - if field_name == getattr( - tool.definition.input, "tool_context_parameter_name", None - ): - continue - - # Get field type information - field_type = getattr(field, "annotation", None) - field_type_name = "string" # default - - # Safety check for field_type - if field_type is int: - field_type_name = "integer" - elif field_type is float: - field_type_name = "number" - elif field_type is bool: - field_type_name = "boolean" - elif field_type is list or str(field_type).startswith("list["): - field_type_name = "array" - elif field_type is dict or str(field_type).startswith("dict["): - field_type_name = "object" - - # Get description with fallback - field_description = getattr(field, "description", None) - if not field_description: - field_description = f"Parameter: {field_name}" - - # Create parameter definition - param_def = { - "type": field_type_name, - "description": field_description, - } - - # Enum support: if the field annotation is an Enum, add allowed values - enum_type = None - if hasattr(field, "annotation"): - ann = field.annotation - # Handle typing.Annotated[Enum, ...] - if getattr(ann, "__origin__", None) is not None and hasattr(ann, "__args__"): - for arg in ann.__args__: # type: ignore[union-attr] - if isinstance(arg, type) and issubclass(arg, Enum): - enum_type = arg - break - elif isinstance(ann, type) and issubclass(ann, Enum): - enum_type = ann - if enum_type is not None: - param_def["enum"] = [e.value for e in enum_type] - - parameters[field_name] = param_def - - # In Pydantic v2, check if field is required based on default value - try: - if field.is_required(): - required.append(field_name) - except (AttributeError, TypeError): - # Fallback if is_required() doesn't exist or fails - try: - has_default = getattr(field, "default", None) is not None - has_factory = getattr(field, "default_factory", None) is not None - if not (has_default or has_factory): - required.append(field_name) - except Exception: - # Ultimate fallback - assume required if we can't determine - logger.debug( - f"Could not determine if field {field_name} is required, assuming optional" - ) - - # Create the input schema with explicit properties and required fields - input_schema = { - "type": "object", - "properties": parameters, - } - - # Only include required field if we have required parameters - if required: - input_schema["required"] = required - - # Add annotations based on tool metadata - annotations = {} - - # Use tool name as title if available - annotations["title"] = getattr(tool.definition, "title", str(name).replace(".", "_")) - - # Determine hints based on tool properties - if hasattr(tool.definition, "metadata"): - metadata = tool.definition.metadata or {} - annotations["readOnlyHint"] = metadata.get("read_only", False) - annotations["destructiveHint"] = metadata.get("destructive", False) - annotations["idempotentHint"] = metadata.get("idempotent", True) - annotations["openWorldHint"] = metadata.get("open_world", False) - - # Create the final tool definition - tool_def: MCPTool = { - "name": str(name).replace(".", "_"), - "description": str(description), - "inputSchema": input_schema, - "annotations": annotations, - } - - logger.debug(f"Created tool definition for {name}") - - except Exception: - logger.exception( - f"Error creating MCP tool definition for {getattr(tool, 'name', str(tool))}" - ) - return None - return tool_def - - -def convert_to_mcp_content(value: Any) -> list[dict[str, Any]]: - """ - Convert a Python value to MCP-compatible content. - """ - if value is None: - return [] - - if isinstance(value, (str, bool, int, float)): - return [{"type": "text", "text": str(value)}] - - if isinstance(value, (dict, list)): - return [{"type": "text", "text": json.dumps(value)}] - - # Default fallback - return [{"type": "text", "text": str(value)}] - - -def _map_type_to_json_schema_type(val_type: str) -> str: - """ - Map Arcade value types to JSON schema types. - - Args: - val_type: The Arcade value type as a string. - - Returns: - The corresponding JSON schema type as a string. - """ - mapping: dict[str, str] = { - "string": "string", - "integer": "integer", - "number": "number", - "boolean": "boolean", - "json": "object", - "array": "array", - } - return mapping.get(val_type, "string") diff --git a/libs/arcade-serve/arcade_serve/mcp/logging.py b/libs/arcade-serve/arcade_serve/mcp/logging.py deleted file mode 100644 index 667212bf..00000000 --- a/libs/arcade-serve/arcade_serve/mcp/logging.py +++ /dev/null @@ -1,215 +0,0 @@ -import json -import logging -import sys -import time -from typing import Any - -from arcade_serve.mcp.types import ( - JSONRPCError, - JSONRPCRequest, - JSONRPCResponse, - MCPMessage, -) - -logger = logging.getLogger("arcade.mcp") - - -class MCPLoggingMiddleware: - """ - Middleware for logging MCP requests and responses. - Logs request and response details, including timing and errors. - """ - - def __init__( - self, - log_level: str = "INFO", - log_request_body: bool = False, - log_response_body: bool = False, - log_errors: bool = True, - min_duration_to_log_ms: int = 0, - stdio_mode: bool = False, - ) -> None: - """ - Initialize the MCP logging middleware. - - Args: - log_level: Logging level (default: "INFO"). - log_request_body: Whether to log full request bodies (default: False). - log_response_body: Whether to log full response bodies (default: False). - log_errors: Whether to log errors at ERROR level (default: True). - min_duration_to_log_ms: Minimum duration in ms to log (0 logs all). - stdio_mode: Whether running in stdio mode (redirects logs to stderr). - """ - self.log_level = getattr(logging, log_level.upper()) - self.log_request_body = log_request_body - self.log_response_body = log_response_body - self.log_errors = log_errors - self.min_duration_to_log_ms = min_duration_to_log_ms - self.request_log_format = "[MCP>] {method}{params_str} (id: {id})" - self.response_log_format = "[MCP<] {method} completed in {duration:.2f}ms (id: {id})" - self.error_log_format = "[MCP!] {method} error: {error} (id: {id})" - - # If in stdio mode, ensure MCP logs go to stderr - if stdio_mode: - self._redirect_logs_to_stderr() - - # Log that middleware is initialized - logger.debug(f"MCP logging middleware initialized (level: {log_level})") - - def _redirect_logs_to_stderr(self) -> None: - """Redirect MCP logs to stderr to avoid interfering with stdio communication.""" - # Remove any existing handlers - for handler in logger.handlers[:]: - logger.removeHandler(handler) - - # Add a stderr handler - stderr_handler = logging.StreamHandler(sys.stderr) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - stderr_handler.setFormatter(formatter) - logger.addHandler(stderr_handler) - - # Ensure we're not propagating to root logger which might log to stdout - logger.propagate = False - - logger.debug("MCP logs redirected to stderr for stdio mode") - - def __call__(self, message: MCPMessage, direction: str) -> MCPMessage: - """ - Process and log an MCP message. - - Args: - message: The MCP message to process. - direction: The message direction ("request" or "response"). - - Returns: - The original message (unmodified). - """ - if direction == "request": - self._log_request(message) - else: - self._log_response(message) - return message - - def _log_request(self, message: MCPMessage) -> None: - """ - Log an MCP request message. - """ - if not isinstance(message, JSONRPCRequest): - logger.debug(f"Ignoring non-request message: {type(message).__name__}") - return - - try: - # Store request start time for duration calculation - message._mcp_start_time = time.time() # type: ignore[attr-defined] - - # Format parameters for logging - params_str = "" - if self.log_request_body and hasattr(message, "params") and message.params is not None: - params_str = f": {self._format_params(message.params)}" - - log_msg = self.request_log_format.format( - method=message.method, params_str=params_str, id=getattr(message, "id", "none") - ) - - logger.log(self.log_level, log_msg) - except Exception: - logger.exception("Error logging request") - - def _log_response(self, message: MCPMessage) -> None: - """ - Log an MCP response message. - """ - if not isinstance(message, (JSONRPCResponse, JSONRPCError)): - logger.debug(f"Ignoring non-response message: {type(message).__name__}") - return - - try: - # Calculate request duration if we have the start time - duration_ms = 0 - request = getattr(message, "_request", None) - if request: - start_time = getattr(request, "_mcp_start_time", None) - if start_time: - duration_ms = (time.time() - start_time) * 1000 - else: - start_time = getattr(message, "_mcp_start_time", None) - if start_time: - duration_ms = (time.time() - start_time) * 1000 - - # Skip if below minimum duration threshold - if self.min_duration_to_log_ms > 0 and duration_ms < self.min_duration_to_log_ms: - return - - # Handle error responses - if hasattr(message, "error") and message.error is not None: - if self.log_errors: - error_msg = self.error_log_format.format( - method=getattr(message, "method", "unknown"), - error=getattr(message.error, "message", str(message.error)), - id=getattr(message, "id", "none"), - ) - logger.error(error_msg) - return - - # Log successful response - result_str = "" - if self.log_response_body and hasattr(message, "result"): - result_str = f": {self._format_result(message.result)}" - - log_msg = self.response_log_format.format( - method=getattr(message, "method", "unknown"), - duration=duration_ms, - id=getattr(message, "id", "none"), - result_str=result_str, - ) - - logger.log(self.log_level, log_msg) - except Exception: - logger.exception("Error logging response") - - def _format_params(self, params: Any) -> str: - """ - Format parameters for logging. - """ - try: - if isinstance(params, dict): - # Handle common MCP params specially - if "name" in params and "arguments" in params: - return f"{params['name']}({json.dumps(params.get('arguments', {}))})" - return json.dumps(params) - return str(params) - except Exception: - logger.debug(f"Error formatting params {params!s}") - return str(params) - - def _format_result(self, result: Any) -> str: - """ - Format result for logging. - """ - try: - if isinstance(result, dict): - return json.dumps(result) - return str(result) - except Exception as e: - logger.debug(f"Error formatting result {e!s}") - return str(result) - - -def create_mcp_logging_middleware(**config: Any) -> MCPLoggingMiddleware: - """ - Create an MCP logging middleware with the given configuration. - - Args: - **config: Configuration options. - - Returns: - An MCPLoggingMiddleware instance. - """ - return MCPLoggingMiddleware( - log_level=config.get("log_level", "INFO"), - log_request_body=config.get("log_request_body", False), - log_response_body=config.get("log_response_body", False), - log_errors=config.get("log_errors", True), - min_duration_to_log_ms=config.get("min_duration_to_log_ms", 0), - stdio_mode=config.get("stdio_mode", False), - ) diff --git a/libs/arcade-serve/arcade_serve/mcp/message_processor.py b/libs/arcade-serve/arcade_serve/mcp/message_processor.py deleted file mode 100644 index 96a1924c..00000000 --- a/libs/arcade-serve/arcade_serve/mcp/message_processor.py +++ /dev/null @@ -1,83 +0,0 @@ -import inspect -import json -import logging -from typing import Any, Callable, TypeVar - -from arcade_serve.mcp.types import InitializeRequest, JSONRPCRequest, MCPMessage - -logger = logging.getLogger("arcade.mcp") - -T = TypeVar("T") - -# Type definition for middleware functions -MessageProcessor = Callable[[Any, str], Any] - - -class MCPMessageProcessor: - """ - Processes MCP messages through a chain of middleware. - Supports both synchronous and asynchronous middleware. - """ - - def __init__(self) -> None: - self.middleware: list[Callable[[MCPMessage, str], Any]] = [] - - def add_middleware(self, mw: Callable[[MCPMessage, str], Any]) -> None: - self.middleware.append(mw) - - async def process(self, message: Any, direction: str) -> Any: # noqa: C901 - # First, try to parse the message if it's a string - if isinstance(message, str): - # Strip any whitespace including newlines - message = message.strip() - if not message: - return None - - try: - parsed = json.loads(message) - if isinstance(parsed, dict): - method = parsed.get("method") - # Convert to appropriate message type - if method == "initialize" and "id" in parsed: - logger.debug(f"Parsed initialize request: {parsed}") - message = InitializeRequest(**parsed) - elif method and method.startswith("notifications/"): - # It's a notification, log it but pass through as dict - logger.debug(f"Received notification: {method}") - # Keep as parsed dict to avoid validation errors on unknown notifications - message = parsed - elif "method" in parsed and "id" in parsed: - # Regular method request - logger.debug(f"Parsed method request: {method}") - message = JSONRPCRequest(**parsed) - # Other message types can be handled similarly - except json.JSONDecodeError: - logger.warning(f"Failed to parse message as JSON: {message[:100]}...") - except Exception: - logger.exception("Error processing message") - - # Process through middleware chain - result = message - for mw in self.middleware: - try: - if inspect.iscoroutinefunction(mw): - result = await mw(result, direction) - else: - result = mw(result, direction) - except Exception: - logger.exception(f"Error in middleware {mw}") - return result - - async def process_request(self, message: Any) -> Any: - return await self.process(message, "request") - - async def process_response(self, message: Any) -> Any: - return await self.process(message, "response") - - -def create_message_processor(*middleware: MessageProcessor) -> MCPMessageProcessor: - processor = MCPMessageProcessor() - for m in middleware: - if m is not None: - processor.add_middleware(m) - return processor diff --git a/libs/arcade-serve/arcade_serve/mcp/server.py b/libs/arcade-serve/arcade_serve/mcp/server.py deleted file mode 100644 index 43fdf116..00000000 --- a/libs/arcade-serve/arcade_serve/mcp/server.py +++ /dev/null @@ -1,601 +0,0 @@ -import asyncio -import logging -import os -import uuid -from enum import Enum -from typing import Any, Callable, Union - -from arcade_core.catalog import MaterializedTool, ToolCatalog -from arcade_core.executor import ToolExecutor -from arcade_core.schema import ToolAuthorizationContext, ToolContext -from arcadepy import ArcadeError, AsyncArcade -from arcadepy.types.auth_authorize_params import AuthRequirement, AuthRequirementOauth2 -from arcadepy.types.shared import AuthorizationResponse - -from arcade_serve.mcp.convert import convert_to_mcp_content, create_mcp_tool -from arcade_serve.mcp.logging import create_mcp_logging_middleware -from arcade_serve.mcp.message_processor import MCPMessageProcessor, create_message_processor -from arcade_serve.mcp.types import ( - CallToolRequest, - CallToolResponse, - CallToolResult, - CancelRequest, - Implementation, - InitializeRequest, - InitializeResponse, - InitializeResult, - JSONRPCError, - JSONRPCResponse, - ListPromptsRequest, - ListPromptsResponse, - ListResourcesRequest, - ListResourcesResponse, - ListToolsRequest, - ListToolsResponse, - ListToolsResult, - PingRequest, - PingResponse, - ProgressNotification, - ServerCapabilities, - ShutdownRequest, - ShutdownResponse, - Tool, -) - -logger = logging.getLogger("arcade.mcp") - -MCP_PROTOCOL_VERSION = "2024-11-05" - - -class MessageMethod(str, Enum): - """Enumeration of supported MCP message methods""" - - PING = "ping" - INITIALIZE = "initialize" - LIST_TOOLS = "tools/list" - CALL_TOOL = "tools/call" - PROGRESS = "progress" - CANCEL = "$/cancelRequest" - SHUTDOWN = "shutdown" - LIST_RESOURCES = "resources/list" - LIST_PROMPTS = "prompts/list" - - -class MCPServer: - """ - Unified async MCP server that manages connections, middleware, and tool invocation. - Handles protocol-level messages (ping, initialize, list_tools, call_tool, etc.). - """ - - def __init__( - self, - tool_catalog: Any, - enable_logging: bool = True, - **client_kwargs: dict[str, Any], - ) -> None: - """ - Initialize the MCP server. - - Args: - tool_catalog: Catalog of available tools - **client_kwargs: Additional arguments to pass to the AsyncArcade client - """ - self.tool_catalog: ToolCatalog = tool_catalog - self.message_processor: MCPMessageProcessor = create_message_processor() - - # Pop middleware_config from client_kwargs regardless of logging state, - # as it's internal config not meant for AsyncArcade. - middleware_config = client_kwargs.pop("middleware_config", {}) - - if enable_logging: - # Create and add the logging middleware if logging is enabled. - # Note: enable_logging must be True for this middleware (and its stdio_mode behavior) - # to be activated. - self.message_processor.add_middleware( - create_mcp_logging_middleware(**middleware_config) - ) - - self._shutdown: bool = False - # Initialize AsyncArcade with the *remaining* client_kwargs - self.arcade = AsyncArcade(**client_kwargs) # type: ignore[arg-type] - - # Initialize handler dispatch table - self._method_handlers: dict[str, Callable] = { - MessageMethod.PING: self._handle_ping, - MessageMethod.INITIALIZE: self._handle_initialize, - MessageMethod.LIST_TOOLS: self._handle_list_tools, - MessageMethod.CALL_TOOL: self._handle_call_tool, - MessageMethod.PROGRESS: self._handle_progress, - MessageMethod.CANCEL: self._handle_cancel, - MessageMethod.SHUTDOWN: self._handle_shutdown, - MessageMethod.LIST_RESOURCES: self._handle_list_resources, - MessageMethod.LIST_PROMPTS: self._handle_list_prompts, - } - - async def run_connection( - self, - read_stream: Any, - write_stream: Any, - init_options: Any, - ) -> None: - """ - Handle a single MCP connection (SSE or stdio). - - Args: - read_stream: Async iterable yielding incoming messages. - write_stream: Object with an async send(message) method. - init_options: Initialization options for the connection. - """ - # Generate a user ID if possible - user_id = self._get_user_id(init_options) - - try: - logger.info(f"Starting MCP connection for user {user_id}") - - async for message in read_stream: - # Process the message - response = await self.handle_message(message, user_id=user_id) - - # Skip sending responses for None (e.g., notifications) - if response is None: - continue - - await self._send_response(write_stream, response) - - except asyncio.CancelledError: - logger.info("Connection cancelled") - except Exception: - logger.exception("Error in connection") - - def _get_user_id(self, init_options: Any) -> str: - """ - Get the user ID for a connection. - - Args: - init_options: Initialization options for the connection - - Returns: - A user ID string - """ - try: - from arcade_core.config import config - - # Prefer config.user.email if available - if config.user and config.user.email: - return config.user.email - except ValueError: - logger.debug("No logged in user for MCP Server") - - fallback = str(uuid.uuid4()) - if os.environ.get("ARCADE_USER_ID", None): - return os.environ.get("ARCADE_USER_ID", fallback) - elif isinstance(init_options, dict): - user_id = init_options.get("user_id") - if user_id: - return str(user_id) - # Fallback to random UUID - return str(fallback) - - async def _send_response(self, write_stream: Any, response: Any) -> None: - """ - Send a response to the client. - - Args: - write_stream: Stream to write the response to - response: Response object to send - """ - # Ensure the response is properly serialized to JSON - if hasattr(response, "model_dump_json"): - # It's a Pydantic model, serialize it - json_response = response.model_dump_json() - # Ensure it ends with a newline for JSON-RPC-over-stdio - if not json_response.endswith("\n"): - json_response += "\n" - logger.debug(f"Sending response: {json_response[:200]}...") - await write_stream.send(json_response) - elif isinstance(response, dict): - # It's a dict, convert to JSON - import json - - json_response = json.dumps(response) - # Ensure it ends with a newline for JSON-RPC-over-stdio - if not json_response.endswith("\n"): - json_response += "\n" - logger.debug(f"Sending response: {json_response[:200]}...") - await write_stream.send(json_response) - else: - # It's already a string or something else - response_str = str(response) - # Ensure it ends with a newline for JSON-RPC-over-stdio - if not response_str.endswith("\n"): - response_str += "\n" - logger.debug(f"Sending raw response type: {type(response)}") - await write_stream.send(response_str) - - async def handle_message(self, message: Any, user_id: str | None = None) -> Any: - """ - Handle an incoming MCP message. Processes it through middleware and dispatches - to the appropriate handler based on the message method. - - Args: - message: The raw incoming message - user_id: Optional user ID for authentication - - Returns: - A properly formatted response message - """ - # Pre-process message through middleware - processed = await self.message_processor.process_request(message) - - # Handle special case for JSON string initialize requests - if isinstance(processed, str): - try: - import json - - parsed = json.loads(processed) - if ( - isinstance(parsed, dict) - and parsed.get("method") == MessageMethod.INITIALIZE - and "id" in parsed - ): - # This is an initialize request - init_response = await self._handle_initialize(InitializeRequest(**parsed)) - return init_response - except Exception: - logger.exception("Error processing JSON string") - # Not parseable JSON, continue with normal processing - pass - - # Check if it's a notification - if hasattr(processed, "method"): - method = getattr(processed, "method", None) - - # Handle notifications (methods starting with "notifications/") - if method and method.startswith("notifications/"): - await self._handle_notification(method, processed) - return None - - # Handle regular methods using the dispatch table - if method in self._method_handlers: - # If it's a call_tool request, we need to pass the user_id - if method == MessageMethod.CALL_TOOL: - return await self._method_handlers[method](processed, user_id=user_id) - # For other methods, just pass the processed message - return await self._method_handlers[method](processed) - - # Unknown method - return JSONRPCError( - id=getattr(processed, "id", None), - error={ - "code": -32601, - "message": f"Method not found: {method}", - }, - ) - - # If it's not a method request, just pass it through - return processed - - async def _handle_notification(self, method: str, message: Any) -> None: - """ - Handle notification messages. - - Args: - method: The notification method - message: The notification message - """ - if method == "notifications/cancelled": - logger.info(f"Request cancelled: {getattr(message, 'params', {})}") - else: - logger.debug(f"Received notification: {method}") - - async def _handle_ping(self, message: PingRequest) -> PingResponse: - """ - Handle a ping request and return a pong response. - - Args: - message: The ping request - - Returns: - A properly formatted pong response - """ - return PingResponse(id=message.id) - - async def _handle_initialize(self, message: InitializeRequest) -> InitializeResponse: - """ - Handle an initialize request and return a proper initialize response. - - Args: - message: The initialize request - - Returns: - A properly formatted initialize response - """ - # Create the result data - result = InitializeResult( - protocolVersion=MCP_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - serverInfo=Implementation(name="Arcade MCP Worker", version="0.1.0"), - instructions="Arcade MCP Worker initialized.", - ) - - # Construct proper response with result field - response = InitializeResponse(id=message.id, result=result) - - logger.debug(f"Initialize response: {response.model_dump_json()}") - return response - - async def _handle_list_tools( - self, message: ListToolsRequest - ) -> Union[ListToolsResponse, JSONRPCError]: - """ - Handle a tools/list request and return a list of available tools. - - Args: - message: The tools/list request - - Returns: - A properly formatted tools/list response or error - """ - try: - # Get all tools from the catalog - tools = [] - tool_conversion_errors = [] - - for tool in self.tool_catalog: - try: - mcp_tool = create_mcp_tool(tool) - if mcp_tool: - tools.append(mcp_tool) - except Exception: - tool_name = getattr(tool, "name", str(tool)) - logger.exception(f"Error converting tool: {tool_name}") - tool_conversion_errors.append(tool_name) - - # Log summary if we had errors - if tool_conversion_errors: - logger.warning( - f"Failed to convert {len(tool_conversion_errors)} tools: {tool_conversion_errors}" - ) - - # Create tool objects with exception handling for each one - tool_objects = [] - for t in tools: - try: - # Make input schema optional if missing - tool_dict = dict(t) - if "inputSchema" not in tool_dict: - tool_dict["inputSchema"] = {"type": "object", "properties": {}} - - tool_objects.append(Tool(**tool_dict)) - except Exception: - logger.exception(f"Error creating Tool object for {t.get('name', 'unknown')}") - - # Return successful response with the tools we were able to convert - result = ListToolsResult(tools=tool_objects) - response = ListToolsResponse(id=message.id, result=result) - - except Exception: - logger.exception("Error listing tools") - return JSONRPCError( - id=message.id, - error={ - "code": -32603, - "message": "Internal error listing tools", - }, - ) - return response - - async def _handle_call_tool( - self, message: CallToolRequest, user_id: str | None = None - ) -> CallToolResponse: - """ - Handle a tools/call request to execute a tool. - - Args: - message: The tools/call request - user_id: Optional user ID for authentication - - Returns: - A properly formatted tools/call response - """ - tool_name: str = message.params["name"] - # Extract input from the correct field - input_params: dict[str, Any] = message.params.get("input", {}) - if not input_params: - input_params = message.params.get("arguments", {}) - - logger.info(f"Handling tool call for {tool_name}") - - try: - tool = self.tool_catalog.get_tool_by_name(tool_name, separator="_") - tool_context = ToolContext() - - # Set up context with secrets - if tool.definition.requirements and tool.definition.requirements.secrets: - self._setup_tool_secrets(tool, tool_context) - - # Handle authorization if needed - requirement = self._get_auth_requirement(tool) - if requirement: - auth_result = await self._check_authorization(requirement, user_id=user_id) - if auth_result.status != "completed": - return CallToolResponse( - id=message.id, - result=CallToolResult(content=[{"type": "text", "text": auth_result.url}]), - ) - else: - tool_context.authorization = ToolAuthorizationContext( - token=auth_result.context.token if auth_result.context else None, - user_info={"user_id": user_id} if user_id else {}, - ) - - # Execute the tool - logger.debug(f"Executing tool {tool_name} with input: {input_params}") - result = await ToolExecutor.run( - func=tool.tool, - definition=tool.definition, - input_model=tool.input_model, - output_model=tool.output_model, - context=tool_context, - **input_params, - ) - logger.debug(f"Tool result: {result}") - if result.value: - return CallToolResponse( - id=message.id, - result=CallToolResult(content=convert_to_mcp_content(result.value)), - ) - else: - error = result.error or "Error calling tool" - logger.error(f"Tool {tool_name} returned error: {error}") - return CallToolResponse( - id=message.id, - result=CallToolResult( - content=[{"type": "text", "text": convert_to_mcp_content(error)}] - ), - ) - except Exception as e: - logger.exception(f"Error calling tool {tool_name}") - error = f"Error calling tool {tool_name}: {e!s}" - return CallToolResponse( - id=message.id, - result=CallToolResult( - content=[{"type": "text", "text": convert_to_mcp_content(error)}] - ), - ) - - def _setup_tool_secrets(self, tool: Any, tool_context: ToolContext) -> None: - """ - Set up tool secrets in the tool context. - - Args: - tool: The tool to set up secrets for - tool_context: The tool context to update - """ - for secret in tool.definition.requirements.secrets: - value = os.environ.get(secret.key) - if value is not None: - tool_context.set_secret(secret.key, value) - - async def _handle_progress(self, message: ProgressNotification) -> JSONRPCResponse: - """ - Handle a progress notification. - - Args: - message: The progress notification - - Returns: - A response acknowledging the notification - """ - return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True}) - - async def _handle_cancel(self, message: CancelRequest) -> JSONRPCResponse: - """ - Handle a cancel request. - - Args: - message: The cancel request - - Returns: - A response acknowledging the cancellation - """ - return JSONRPCResponse(id=getattr(message, "id", None), result={"ok": True}) - - async def _handle_shutdown(self, message: ShutdownRequest) -> ShutdownResponse: - """ - Handle a shutdown request. - - Args: - message: The shutdown request - - Returns: - A response acknowledging the shutdown request - """ - # Schedule a task to shutdown the server after sending the response - proc = asyncio.create_task(self.shutdown()) - proc.add_done_callback(lambda _: logger.info("MCP server shutdown complete")) - return ShutdownResponse(id=message.id, result={"ok": True}) - - async def _handle_list_resources(self, message: ListResourcesRequest) -> ListResourcesResponse: - """ - Handle a resources/list request. - - Args: - message: The resources/list request - - Returns: - A properly formatted resources/list response - """ - return ListResourcesResponse(id=message.id, result={"resources": []}) - - async def _handle_list_prompts(self, message: ListPromptsRequest) -> ListPromptsResponse: - """ - Handle a prompts/list request. - - Args: - message: The prompts/list request - - Returns: - A properly formatted prompts/list response - """ - return ListPromptsResponse(id=message.id, result={"prompts": []}) - - def _get_auth_requirement(self, tool: MaterializedTool) -> AuthRequirement | None: - """ - Get the authentication requirement for a tool. - - Args: - tool: The tool to get the requirement for - - Returns: - An authentication requirement or None if not required - """ - req = tool.definition.requirements.authorization - if not req: - return None - if not req.provider_id and not req.provider_type: - return None - if hasattr(req, "oauth2") and req.oauth2: - return AuthRequirement( - provider_id=str(req.provider_id), - provider_type=str(req.provider_type), - oauth2=AuthRequirementOauth2(scopes=req.oauth2.scopes or []), - ) - return AuthRequirement( - provider_id=str(req.provider_id), - provider_type=str(req.provider_type), - ) - - async def _check_authorization( - self, auth_requirement: AuthRequirement, user_id: str | None = None - ) -> AuthorizationResponse: - """ - Check if a tool is authorized for a user. - - Args: - tool: The tool to check authorization for - user_id: The user ID to check authorization for - - Returns: - An authorization response - - Raises: - RuntimeError: If the tool has no authorization requirement - Exception: If authorization fails - """ - try: - response = await self.arcade.auth.authorize( - auth_requirement=auth_requirement, - user_id=user_id or "anonymous", - ) - logger.debug(f"Authorization response: {response}") - - except ArcadeError: - logger.exception("Error authorizing tool") - raise - return response - - async def shutdown(self) -> None: - """Shutdown the server.""" - self._shutdown = True - logger.info("MCP server shutdown complete") diff --git a/libs/arcade-serve/arcade_serve/mcp/stdio.py b/libs/arcade-serve/arcade_serve/mcp/stdio.py deleted file mode 100644 index 98357e39..00000000 --- a/libs/arcade-serve/arcade_serve/mcp/stdio.py +++ /dev/null @@ -1,185 +0,0 @@ -import asyncio -import logging -import queue -import signal -import sys -import threading -from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any, TypeVar - -if TYPE_CHECKING: - pass - -from arcade_serve.mcp.server import MCPServer - -logger = logging.getLogger("arcade.mcp") - -T = TypeVar("T") - - -def stdio_reader(stdin: object, q: queue.Queue[str | None]) -> None: - """Read lines from stdin and put them into a queue.""" - for line in stdin: # type: ignore[attr-defined] - q.put(line) - q.put(None) - - -def stdio_writer(stdout: object, q: queue.Queue[str | None]) -> None: - """Write messages from a queue to stdout.""" - try: - while True: - msg = q.get() - if msg is None: - break - - # Ensure message ends with a newline for proper JSON-RPC-over-stdio - if not msg.endswith("\n"): - msg += "\n" - - stdout.write(msg) # type: ignore[attr-defined] - stdout.flush() # type: ignore[attr-defined] - except Exception: - logger.exception("Error in stdio writer") - - -class StdioServer(MCPServer): - """ - Stdio server that handles signals and cleanup. - """ - - def __init__( - self, - tool_catalog: Any, - enable_logging: bool = True, - **client_kwargs: dict[str, Any], - ): - # Set up stdio-specific middleware configuration - middleware_config = client_kwargs.get("middleware_config", {}) - middleware_config["stdio_mode"] = True - client_kwargs["middleware_config"] = middleware_config - - super().__init__(tool_catalog, enable_logging, **client_kwargs) - self.read_q: queue.Queue[str | None] = queue.Queue() - self.write_q: queue.Queue[str | None] = queue.Queue() - self.reader_thread: threading.Thread | None = None - self.writer_thread: threading.Thread | None = None - self.running = False - self.shutdown_event = asyncio.Event() - - def start_io_threads(self) -> None: - """Start stdio reader and writer threads.""" - self.reader_thread = threading.Thread( - target=self._stdio_reader, args=(sys.stdin, self.read_q), daemon=True - ) - self.writer_thread = threading.Thread( - target=self._stdio_writer, args=(sys.stdout, self.write_q), daemon=True - ) - self.reader_thread.start() - self.writer_thread.start() - - def _stdio_reader(self, stdin: object, q: queue.Queue[str | None]) -> None: - """Read lines from stdin and put them into a queue.""" - try: - for line in stdin: # type: ignore[attr-defined] - if not self.running: - break - q.put(line) - except Exception: - logger.exception("Error in stdio reader") - finally: - q.put(None) # Signal EOF - - def _stdio_writer(self, stdout: object, q: queue.Queue[str | None]) -> None: - """Write messages from a queue to stdout.""" - try: - while self.running: - msg = q.get() - if msg is None: - break - stdout.write(msg) # type: ignore[attr-defined] - stdout.flush() # type: ignore[attr-defined] - except Exception: - logger.exception("Error in stdio writer") - - async def _read_stream(self) -> AsyncGenerator[str, None]: - """Async generator that yields lines from the read queue.""" - while self.running: - try: - line = await asyncio.to_thread(self.read_q.get) - if line is None: - break - yield line - except asyncio.CancelledError: - break - except Exception: - logger.exception("Error reading from stdin") - break - - async def shutdown(self) -> None: - """Gracefully shut down the server.""" - if not self.running: - return - - logger.info("Shutting down stdio server...") - self.running = False - - # Signal shutdown to MCP server - await self.shutdown() - - # Clean up IO queues and threads - try: - if self.read_q: - self.read_q.put(None) - if self.write_q: - self.write_q.put(None) - except Exception: - logger.exception("Error during shutdown") - - # Signal completion - self.shutdown_event.set() - logger.info("Stdio server shutdown complete") - - async def run(self) -> None: - """Run the stdio server with signal handling.""" - self.running = True - - # Set up signal handlers - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - try: - loop.add_signal_handler(sig, lambda: asyncio.create_task(self.shutdown())) - except NotImplementedError: - # Windows doesn't support POSIX signals - if sys.platform == "win32": - logger.warning("Signal handling not fully supported on Windows") - else: - logger.warning(f"Failed to set up signal handler for {sig}") - - # Start IO threads - self.start_io_threads() - - logger.info("Starting MCP server with stdio transport") - - # Create WriteStream class for MCP server - class WriteStream: - async def send(self_, message: str) -> None: - if self.running: - await asyncio.to_thread(self.write_q.put, message) - - try: - # Run MCP server connection - await self.run_connection(self._read_stream(), WriteStream(), None) - except asyncio.CancelledError: - # Handle cancellation - logger.info("Server operation cancelled") - except KeyboardInterrupt: - # Handle keyboard interrupt - logger.info("Keyboard interrupt received") - except Exception: - # Handle unexpected errors - logger.exception("Unexpected error") - finally: - # Ensure we clean up - await self.shutdown() - # Wait for shutdown to complete - await self.shutdown_event.wait() diff --git a/libs/arcade-serve/arcade_serve/mcp/types.py b/libs/arcade-serve/arcade_serve/mcp/types.py deleted file mode 100644 index 4354241b..00000000 --- a/libs/arcade-serve/arcade_serve/mcp/types.py +++ /dev/null @@ -1,383 +0,0 @@ -import json -from collections.abc import Callable -from typing import ( - Any, - Generic, - Literal, - TypeAlias, - TypeVar, - Union, -) - -from pydantic import BaseModel, ConfigDict, Field - -ProgressToken = str | int -Cursor = str -Role = Literal["user", "assistant"] -RequestId = str | int -AnyFunction: TypeAlias = Callable[..., Any] - - -class RequestParams(BaseModel): - class Meta(BaseModel): - progressToken: ProgressToken | None = None - model_config = ConfigDict(extra="allow") - - meta: Meta | None = Field(alias="_meta", default=None) - - model_config = ConfigDict(extra="allow") - - -class NotificationParams(BaseModel): - class Meta(BaseModel): - model_config = ConfigDict(extra="allow") - - meta: Meta | None = Field(alias="_meta", default=None) - model_config = ConfigDict(extra="allow") - - -RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) -NotificationParamsT = TypeVar( - "NotificationParamsT", bound=NotificationParams | dict[str, Any] | None -) -MethodT = TypeVar("MethodT", bound=str) - - -class Request(BaseModel, Generic[RequestParamsT, MethodT]): - method: MethodT - params: RequestParamsT - model_config = ConfigDict(extra="allow") - - -class PaginatedRequest(Request[RequestParamsT, MethodT]): - cursor: Cursor | None = None - model_config = ConfigDict(extra="allow") - - -class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): - method: MethodT - params: NotificationParamsT - model_config = ConfigDict(extra="allow") - - -class Result(BaseModel): - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - model_config = ConfigDict(extra="allow") - - -class PaginatedResult(Result): - nextCursor: Cursor | None = None - model_config = ConfigDict(extra="allow") - - -class JSONRPCMessage(BaseModel): - """Base class for all JSON-RPC messages.""" - - model_config = ConfigDict(extra="allow") - jsonrpc: str = Field(default="2.0", frozen=True) - - -class JSONRPCRequest(JSONRPCMessage): - """A JSON-RPC request message.""" - - id: str | int | None = None - method: str - params: dict[str, Any] | None = None - - -class JSONRPCResponse(JSONRPCMessage): - """A JSON-RPC response message.""" - - id: str | int | None - result: Any = None - error: dict[str, Any] | None = None - - def model_dump_json(self, **kwargs: Any) -> str: - """Convert to JSON string with proper formatting.""" - - # Convert to dict - data = { - "jsonrpc": self.jsonrpc, - "id": self.id, - } - - # Add result if present - if self.result is not None: - # Check if result is a Pydantic model - if hasattr(self.result, "model_dump"): - data["result"] = self.result.model_dump(exclude_none=True) - # Check if result is already a dict/list/primitive - elif ( - isinstance(self.result, (dict, list, str, int, float, bool)) or self.result is None - ): - data["result"] = self.result # type: ignore[assignment] - else: - # Try to convert using str() as a fallback - data["result"] = str(self.result) - - # Add error if present - if self.error is not None: - data["error"] = self.error # type: ignore[assignment] - - return json.dumps(data, ensure_ascii=False) - - -class JSONRPCError(JSONRPCMessage): - """A JSON-RPC error message.""" - - id: str | int | None - error: dict[str, Any] - - -PARSE_ERROR = -32700 -INVALID_REQUEST = -32600 -METHOD_NOT_FOUND = -32601 -INVALID_PARAMS = -32602 -INTERNAL_ERROR = -32603 - - -class ErrorData(BaseModel): - code: int - message: str - data: Any | None = None - model_config = ConfigDict(extra="allow") - - -JSONRPCMessageBaseModel = BaseModel | JSONRPCRequest | JSONRPCResponse | JSONRPCError - - -class EmptyResult(Result): - pass - - -class Implementation(BaseModel): - """Describes the server or client implementation.""" - - name: str - version: str - model_config = ConfigDict(extra="allow") - - -class RootsCapability(BaseModel): - listChanged: bool | None = None - model_config = ConfigDict(extra="allow") - - -class SamplingCapability(BaseModel): - model_config = ConfigDict(extra="allow") - - -class ClientCapabilities(BaseModel): - experimental: dict[str, dict[str, Any]] | None = None - sampling: SamplingCapability | None = None - roots: RootsCapability | None = None - model_config = ConfigDict(extra="allow") - - -class PromptsCapability(BaseModel): - listChanged: bool | None = None - model_config = ConfigDict(extra="allow") - - -class ResourcesCapability(BaseModel): - subscribe: bool | None = None - listChanged: bool | None = None - model_config = ConfigDict(extra="allow") - - -class ToolsCapability(BaseModel): - listChanged: bool | None = None - model_config = ConfigDict(extra="allow") - - -class LoggingCapability(BaseModel): - model_config = ConfigDict(extra="allow") - - -class ServerCapabilities(BaseModel): - """Describes the server's capabilities.""" - - model_config = ConfigDict(extra="allow") - tools: dict[str, Any] | None = None - resources: dict[str, Any] | None = None - prompts: dict[str, Any] | None = None - - -class InitializeRequestParams(RequestParams): - protocolVersion: str | int - capabilities: ClientCapabilities - clientInfo: Implementation - model_config = ConfigDict(extra="allow") - - -class InitializeRequest(JSONRPCRequest): - method: str = Field(default="initialize", frozen=True) - params: dict[str, Any] | None = None - - -class InitializeResult(BaseModel): - protocolVersion: str - capabilities: ServerCapabilities - serverInfo: Implementation - instructions: str | None = None - - -class InitializedNotification( - Notification[NotificationParams | None, Literal["notifications/initialized"]] -): - method: Literal["notifications/initialized"] - params: NotificationParams | None = None - model_config = ConfigDict(extra="allow") - - -class PingRequest(JSONRPCRequest): - method: str = Field(default="ping", frozen=True) - params: dict[str, Any] | None = None - - -class ProgressNotificationParams(NotificationParams): - progressToken: ProgressToken - progress: float - total: float | None = None - model_config = ConfigDict(extra="allow") - - -class ProgressNotification(JSONRPCMessage): - method: str = Field(default="progress", frozen=True) - params: dict[str, Any] - - -class PingResponse(JSONRPCResponse): - result: dict[str, Any] = Field(default_factory=lambda: {"pong": True}) - - -class ShutdownRequest(JSONRPCRequest): - method: str = Field(default="shutdown", frozen=True) - params: dict[str, Any] | None = None - - -class ShutdownResponse(JSONRPCResponse): - result: dict[str, Any] = Field(default_factory=lambda: {"ok": True}) - - -class CancelRequest(JSONRPCRequest): - method: str = Field(default="$/cancelRequest", frozen=True) - params: dict[str, Any] - - -class InitializeResponse(JSONRPCResponse): - """ - Response to an initialize request. - - Note: This must be a properly formatted JSON-RPC response with a `result` field - containing the initialization data, not another request. - """ - - result: InitializeResult - - def model_dump_json(self, **kwargs: Any) -> str: - """Convert to JSON string with proper formatting.""" - # Convert to dict - data = { - "jsonrpc": self.jsonrpc, - "id": self.id, - "result": self.result.model_dump(exclude_none=True), - } - - # Return JSON string - return json.dumps(data, ensure_ascii=False) - - -class ListToolsRequest(JSONRPCRequest): - method: str = Field(default="tools/list", frozen=True) - params: dict[str, Any] | None = None - - -class ToolAnnotations(BaseModel): - """ - Represents tool annotations for hints about behavior. - """ - - title: str | None = None - readOnlyHint: bool | None = None - destructiveHint: bool | None = None - idempotentHint: bool | None = None - openWorldHint: bool | None = None - model_config = ConfigDict(extra="allow") - - -class Tool(BaseModel): - """ - Represents an MCP tool definition. - """ - - name: str - description: str - inputSchema: dict[str, Any] | None = None - annotations: ToolAnnotations | None = None - - model_config = ConfigDict(extra="allow") - - -class ListToolsResult(BaseModel): - tools: list[Tool] - - -class ListToolsResponse(JSONRPCResponse): - result: ListToolsResult - - -class CallToolRequest(JSONRPCRequest): - method: str = Field(default="tools/call", frozen=True) - params: dict[str, Any] - - -class CallToolResult(BaseModel): - content: Any - - -class CallToolResponse(JSONRPCResponse): - result: CallToolResult - - -# Resource and Prompt protocol stubs (expand as needed) -class ListResourcesRequest(JSONRPCRequest): - method: str = Field(default="resources/list", frozen=True) - params: dict[str, Any] | None = None - - -class ListResourcesResponse(JSONRPCResponse): - result: dict[str, Any] - - -class ListPromptsRequest(JSONRPCRequest): - method: str = Field(default="prompts/list", frozen=True) - params: dict[str, Any] | None = None - - -class ListPromptsResponse(JSONRPCResponse): - result: dict[str, Any] - - -# Utility type alias for all MCP protocol messages -MCPMessage = Union[ - JSONRPCRequest, - JSONRPCResponse, - JSONRPCError, - PingRequest, - PingResponse, - InitializeRequest, - InitializeResponse, - ListToolsRequest, - ListToolsResponse, - CallToolRequest, - CallToolResponse, - ProgressNotification, - CancelRequest, - ShutdownRequest, - ShutdownResponse, - ListResourcesRequest, - ListResourcesResponse, - ListPromptsRequest, - ListPromptsResponse, -] diff --git a/libs/arcade-serve/pyproject.toml b/libs/arcade-serve/pyproject.toml index 40d0c0f2..da9eb22e 100644 --- a/libs/arcade-serve/pyproject.toml +++ b/libs/arcade-serve/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-serve" -version = "2.1.0" +version = "2.2.0rc1" description = "Arcade Serve - Serving infrastructure for Arcade tools and workers" readme = "README.md" license = {text = "MIT"} @@ -19,10 +19,14 @@ classifiers = [ ] requires-python = ">=3.10" dependencies = [ - "arcade-core>=2.4.0,<3.0.0", + "arcade-core>=2.5.0rc1,<3.0.0", "fastapi>=0.115.3", "uvicorn>=0.30.0", "watchfiles>=1.0.5", + "sse-starlette>=2.0.0", + "opentelemetry-instrumentation-fastapi==0.49b2", + "opentelemetry-exporter-otlp-proto-http==1.28.2", + "opentelemetry-exporter-otlp-proto-common==1.28.2", ] [project.optional-dependencies] diff --git a/libs/arcade-tdk/pyproject.toml b/libs/arcade-tdk/pyproject.toml index 3dca3d10..e149be2f 100644 --- a/libs/arcade-tdk/pyproject.toml +++ b/libs/arcade-tdk/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-tdk" -version = "2.5.0" +version = "2.6.0rc1" description = "Arcade TDK - Toolkit Development Kit for building Arcade tools" readme = "README.md" license = {text = "MIT"} @@ -19,7 +19,7 @@ classifiers = [ ] requires-python = ">=3.10" dependencies = [ - "arcade-core>=2.4.0,<3.0.0", + "arcade-core>=2.5.0rc1,<3.0.0", "pydantic>=2.7.0", ] diff --git a/libs/tests/arcade_mcp_server/__init__.py b/libs/tests/arcade_mcp_server/__init__.py new file mode 100644 index 00000000..6c30879c --- /dev/null +++ b/libs/tests/arcade_mcp_server/__init__.py @@ -0,0 +1 @@ +"""Tests for arcade-mcp-server package.""" diff --git a/libs/tests/arcade_mcp_server/conftest.py b/libs/tests/arcade_mcp_server/conftest.py new file mode 100644 index 00000000..056077b3 --- /dev/null +++ b/libs/tests/arcade_mcp_server/conftest.py @@ -0,0 +1,299 @@ +"""Shared fixtures and utilities for arcade-mcp-server tests.""" + +import asyncio +from collections.abc import AsyncGenerator +from typing import Annotated, Any +from unittest.mock import AsyncMock, Mock + +import pytest +import pytest_asyncio +from arcade_core.catalog import MaterializedTool, ToolCatalog, ToolMeta, create_func_models +from arcade_core.schema import ( + InputParameter, + OAuth2Requirement, + ToolAuthRequirement, + ToolDefinition, + ToolInput, + ToolkitDefinition, + ToolOutput, + ToolRequirements, + ValueSchema, +) +from arcade_mcp_server import tool +from arcade_mcp_server.context import Context +from arcade_mcp_server.server import MCPServer +from arcade_mcp_server.session import ServerSession +from arcade_mcp_server.settings import MCPSettings +from arcade_tdk.auth import OAuth2 + + +@pytest.fixture +def event_loop(): + """Create an instance of the default event loop for the test session.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def sample_tool_def() -> ToolDefinition: + """Create a sample tool definition.""" + return ToolDefinition( + name="test_tool", + fully_qualified_name="TestToolkit.test_tool", + description="A test tool", + toolkit=ToolkitDefinition(name="TestToolkit", description="Test toolkit", version="1.0.0"), + input=ToolInput( + parameters=[ + InputParameter( + name="text", + required=True, + description="Input text", + value_schema=ValueSchema(val_type="string"), + ) + ] + ), + output=ToolOutput(description="Tool output", value_schema=ValueSchema(val_type="string")), + requirements=ToolRequirements(), + ) + + +@pytest.fixture +def sample_tool_def_with_auth() -> ToolDefinition: + """Create a sample tool definition.""" + return ToolDefinition( + name="sample_tool_with_auth", + fully_qualified_name="TestToolkit.sample_tool_with_auth", + description="A test tool", + toolkit=ToolkitDefinition(name="TestToolkit", description="Test toolkit", version="1.0.0"), + input=ToolInput( + parameters=[ + InputParameter( + name="text", + required=True, + description="Input text", + value_schema=ValueSchema(val_type="string"), + ) + ] + ), + output=ToolOutput(description="Tool output", value_schema=ValueSchema(val_type="string")), + requirements=ToolRequirements( + authorization=ToolAuthRequirement( + provider_type="oauth2", + provider_id="test-provider", + id="test-provider", + oauth2=OAuth2Requirement( + scopes=["test.scope", "another.scope"], + ), + ), + ), + ) + + +@pytest.fixture +def sample_tool_func(): + """Create a sample tool function.""" + + @tool + def sample_tool( + text: Annotated[str, "Input text to echo"], + ) -> Annotated[str, "Echoed text result"]: + """Echo input text back to the caller.""" + return f"Echo: {text}" + + return sample_tool + + +@pytest.fixture +def sample_tool_func_with_auth(): + """Create a sample tool function.""" + + @tool( + requires_auth=OAuth2( + id="test-provider", + scopes=["test.scope", "another.scope"], + ), + ) + def sample_tool_with_auth( + text: Annotated[str, "Input text to echo"], + ) -> Annotated[str, "Echoed text result"]: + """Echo input text back to the caller.""" + return f"Echo: {text}" + + return sample_tool_with_auth + + +@pytest.fixture +def materialized_tool(sample_tool_func, sample_tool_def) -> MaterializedTool: + """Create a materialized tool with required models and metadata.""" + input_model, output_model = create_func_models(sample_tool_func) + meta = ToolMeta(module=sample_tool_func.__module__, toolkit=sample_tool_def.toolkit.name) + return MaterializedTool( + tool=sample_tool_func, + definition=sample_tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + +@pytest.fixture +def materialized_tool_with_auth( + sample_tool_func_with_auth, sample_tool_def_with_auth +) -> MaterializedTool: + """Create a materialized tool with required models and metadata.""" + input_model, output_model = create_func_models(sample_tool_func_with_auth) + meta = ToolMeta( + module=sample_tool_func_with_auth.__module__, toolkit=sample_tool_def_with_auth.toolkit.name + ) + return MaterializedTool( + tool=sample_tool_func_with_auth, + definition=sample_tool_def_with_auth, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + +@pytest.fixture +def tool_catalog( + materialized_tool: MaterializedTool, materialized_tool_with_auth: MaterializedTool +) -> ToolCatalog: + """Create a tool catalog with sample tools.""" + catalog = ToolCatalog() + catalog._tools[materialized_tool.definition.get_fully_qualified_name()] = materialized_tool + catalog._tools[materialized_tool_with_auth.definition.get_fully_qualified_name()] = ( + materialized_tool_with_auth + ) + return catalog + + +@pytest.fixture +def mcp_settings() -> MCPSettings: + """Create test MCP settings.""" + settings = MCPSettings() + settings.debug = True + settings.middleware.enable_logging = True + settings.middleware.mask_error_details = False + return settings + + +@pytest_asyncio.fixture +async def mcp_server(tool_catalog, mcp_settings) -> AsyncGenerator[MCPServer, None]: + """Create and start an MCP server.""" + server = MCPServer( + catalog=tool_catalog, + name="Test Server", + version="1.0.0", + settings=mcp_settings, + ) + await server.start() + yield server + await server.stop() + + +@pytest.fixture +def mock_read_stream() -> AsyncMock: + """Create a mock read stream.""" + stream = AsyncMock() + stream.read = AsyncMock() + return stream + + +@pytest.fixture +def mock_write_stream() -> AsyncMock: + """Create a mock write stream.""" + stream = AsyncMock() + stream.write = AsyncMock() + stream.send = AsyncMock() + return stream + + +@pytest_asyncio.fixture +async def server_session(mcp_server, mock_read_stream, mock_write_stream) -> ServerSession: + """Create a server session.""" + session = ServerSession( + server=mcp_server, + read_stream=mock_read_stream, + write_stream=mock_write_stream, + ) + return session + + +@pytest_asyncio.fixture +async def initialized_server_session(server_session) -> ServerSession: + """Create an initialized server session.""" + server_session.mark_initialized() + return server_session + + +@pytest.fixture +def sample_messages() -> dict[str, Any]: + """Sample MCP protocol messages for testing.""" + return { + "initialize": { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}, "sampling": {}}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + }, + "initialized": {"jsonrpc": "2.0", "method": "notifications/initialized"}, + "list_tools": {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}, + "call_tool": { + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": {"name": "TestToolkit.test_tool", "arguments": {"text": "Hello, world!"}}, + }, + "ping": {"jsonrpc": "2.0", "id": 4, "method": "ping"}, + } + + +@pytest.fixture +def mock_context() -> Context: + """Create a mock context.""" + context = Mock(spec=Context) + context.server = Mock() + context.request_id = "test-request-123" + context.session_id = "test-session-456" + context.state = {} + + # Mock async methods + context.log = AsyncMock() + context.debug = AsyncMock() + context.info = AsyncMock() + context.warning = AsyncMock() + context.error = AsyncMock() + context.report_progress = AsyncMock() + context.read_resource = AsyncMock(return_value=[]) + context.list_roots = AsyncMock(return_value=[]) + context.sample = AsyncMock() + context.elicit = AsyncMock() + context.send_tool_list_changed = AsyncMock() + context.send_resource_list_changed = AsyncMock() + context.send_prompt_list_changed = AsyncMock() + + return context + + +# Async test helpers +async def wait_for(condition, timeout=1.0): + """Wait for a condition to become true.""" + start = asyncio.get_event_loop().time() + while not condition(): + if asyncio.get_event_loop().time() - start > timeout: + raise TimeoutError("Condition not met within timeout") + await asyncio.sleep(0.01) + + +async def collect_messages(stream, count): + """Collect a specific number of messages from a stream.""" + messages = [] + for _ in range(count): + msg = await stream.read() + messages.append(msg) + return messages diff --git a/libs/tests/arcade_mcp_server/test_context.py b/libs/tests/arcade_mcp_server/test_context.py new file mode 100644 index 00000000..a2a5557e --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_context.py @@ -0,0 +1,383 @@ +"""Tests for MCP Context implementation.""" + +import asyncio +from unittest.mock import AsyncMock, Mock + +import pytest +from arcade_mcp_server.context import Context +from arcade_mcp_server.context import get_current_model_context as get_current_context +from arcade_mcp_server.context import set_current_model_context as set_current_context +from arcade_mcp_server.types import ( + ModelHint, + ModelPreferences, +) + + +class TestContext: + """Test Context class and context management.""" + + def test_context_creation(self, mcp_server): + """Test context creation with various parameters.""" + # Basic context + context = Context(server=mcp_server) + + assert context.server == mcp_server + assert context.request_id is None + assert context.session_id is None + + # Context with request ID + context2 = Context(server=mcp_server, request_id="req-123") + + assert context2.request_id == "req-123" + assert context2.session_id is None # No session set yet + + def test_context_implements_protocol(self): + """Test that Context implements MCPContext protocol.""" + # Context should expose namespaced adapters + assert hasattr(Context, "log") + assert hasattr(Context, "progress") + assert hasattr(Context, "resources") + assert hasattr(Context, "tools") + assert hasattr(Context, "prompts") + assert hasattr(Context, "sampling") + assert hasattr(Context, "ui") + assert hasattr(Context, "notifications") + + def test_context_var_management(self): + """Test context variable get/set functionality.""" + server = Mock() + context = Context(server=server) + + # Initially no current context + assert get_current_context() is None + + # Set context + token = set_current_context(context) + assert get_current_context() == context + + # Clear context + set_current_context(None, token) + assert get_current_context() is None + + @pytest.mark.asyncio + async def test_context_isolation(self): + """Test that contexts are isolated between async tasks.""" + server = Mock() + context1 = Context(server=server, request_id="req-1") + context2 = Context(server=server, request_id="req-2") + + results = [] + + async def task1(): + set_current_context(context1) + results.append(get_current_context()) + + async def task2(): + set_current_context(context2) + results.append(get_current_context()) + + # Run tasks + await asyncio.gather(task1(), task2()) + + # Each task should have its own context + assert len(results) == 2 + # Context vars are task-local, so both should see their own context + assert context1 in results + assert context2 in results + + @pytest.mark.asyncio + async def test_logging_methods(self, mcp_server): + """Test logging methods.""" + session = Mock() + session.send_log_message = AsyncMock() + + context = Context(server=mcp_server) + context.set_session(session) + + # Test all log levels + await context.log.debug("Debug message") + await context.log.info("Info message") + await context.log.warning("Warning message") + await context.log.error("Error message") + + # Verify calls + assert session.send_log_message.call_count == 4 + + # Test with extra metadata + await context.log("info", "Test message", logger_name="test.logger", extra={"key": "value"}) + + # Check the call - context passes logger_name but session expects logger + call_kwargs = session.send_log_message.call_args[1] + assert call_kwargs["level"] == "info" + assert isinstance(call_kwargs["data"], dict) + assert call_kwargs["data"]["msg"] == "Test message" + assert call_kwargs["data"]["extra"] == {"key": "value"} + assert call_kwargs["logger"] == "test.logger" + + @pytest.mark.asyncio + async def test_logging_without_session(self, mcp_server): + """Test logging when session is not available.""" + context = Context(server=mcp_server) + + # Should not raise errors + await context.log.debug("Debug message") + await context.log.info("Info message") + await context.log.warning("Warning message") + await context.log.error("Error message") + + @pytest.mark.asyncio + async def test_progress_reporting(self, mcp_server): + """Test progress reporting functionality.""" + session = Mock() + session.send_progress_notification = AsyncMock() + session._request_meta = Mock(progressToken="task-123") + + context = Context(server=mcp_server) + context.set_session(session) + + # Report progress + await context.progress.report(50, 100, "Processing...") + + session.send_progress_notification.assert_called_once_with( + progress_token="task-123", progress=50, total=100, message="Processing..." + ) + + # Without total + await context.progress.report(0.75, message="Almost done") + + assert session.send_progress_notification.call_count == 2 + + # Test without progress token - should not call send_progress_notification + session2 = Mock(spec=["send_progress_notification"]) + session2.send_progress_notification = AsyncMock() + # Without _request_meta attribute, progress won't be reported + context2 = Context(server=mcp_server) + context2.set_session(session2) + + await context2.progress.report(25, 100) + session2.send_progress_notification.assert_not_called() + + @pytest.mark.asyncio + async def test_resource_reading(self, mcp_server): + """Test resource reading through context.""" + # Mock server's resource reading + mcp_server._mcp_read_resource = AsyncMock( + return_value=[{"uri": "file://test.txt", "text": "Test content"}] + ) + + context = Context(server=mcp_server) + + resources = await context.resources.read("file://test.txt") + + assert len(resources) == 1 + assert resources[0]["text"] == "Test content" + mcp_server._mcp_read_resource.assert_called_once_with("file://test.txt") + + @pytest.mark.asyncio + async def test_list_roots(self, mcp_server): + """Test listing roots.""" + session = Mock() + # Return an object with roots attribute + result = Mock() + result.roots = [{"uri": "file:///home", "name": "Home"}] + session.list_roots = AsyncMock(return_value=result) + + context = Context(server=mcp_server) + context.set_session(session) + + roots = await context.resources.list_roots() + + assert len(roots) == 1 + assert roots[0]["name"] == "Home" + + @pytest.mark.asyncio + async def test_sampling(self, mcp_server): + """Test sampling functionality.""" + session = Mock() + # Mock the response with content attribute + result = Mock() + result.content = {"type": "text", "text": "Response"} + session.create_message = AsyncMock(return_value=result) + + context = Context(server=mcp_server) + context.set_session(session) + + # Mock client capabilities check + session.check_client_capability = Mock(return_value=True) + + # Test basic sampling + result = await context.sampling.create_message( + messages="Hello", system_prompt="Be helpful", temperature=0.7, max_tokens=100 + ) + + assert result["type"] == "text" + assert result["text"] == "Response" + + # Test with model preferences + result = await context.sampling.create_message( + messages=[{"role": "user", "content": "Hello"}], + model_preferences=ModelPreferences(hints=[ModelHint(name="claude-3")]), + ) + + assert session.create_message.call_count == 2 + + @pytest.mark.asyncio + async def test_sampling_without_capability(self, mcp_server): + """Test sampling when client doesn't support it.""" + session = Mock() + context = Context(server=mcp_server) + context.set_session(session) + + # Mock client capabilities check to return False + session.check_client_capability = Mock(return_value=False) + + with pytest.raises(ValueError) as exc_info: + await context.sampling.create_message(messages=["Hello"], max_tokens=32) + + assert "Client does not support sampling" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_elicitation(self, mcp_server): + """Test user input elicitation.""" + session = Mock() + # Mock the elicit method on session + session.elicit = AsyncMock(return_value={"value": "user input"}) + + context = Context(server=mcp_server) + context.set_session(session) + + # Test string elicitation + result = await context.ui.elicit("Enter your name:") + + assert result == {"value": "user input"} + session.elicit.assert_called_once_with( + message="Enter your name:", + requested_schema={"type": "object", "properties": {}}, + timeout=300.0, + ) + + # Test with schema + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + result = await context.ui.elicit("Enter details:", schema=schema) + + assert result == {"value": "user input"} + assert session.elicit.call_count == 2 + session.elicit.assert_called_with( + message="Enter details:", requested_schema=schema, timeout=300 + ) + + @pytest.mark.asyncio + async def test_notification_queueing(self, mcp_server): + """Test notification queueing methods.""" + session = Mock() + session.send_tool_list_changed = AsyncMock() + session.send_resource_list_changed = AsyncMock() + session.send_prompt_list_changed = AsyncMock() + + context = Context(server=mcp_server) + context.set_session(session) + + # Queue notifications - they are queued and not sent immediately + await context.notifications.tools.list_changed() + await context.notifications.resources.list_changed() + await context.notifications.prompts.list_changed() + + # Notifications should be queued, not sent immediately + assert "notifications/tools/list_changed" in context._notification_queue + assert "notifications/resources/list_changed" in context._notification_queue + assert "notifications/prompts/list_changed" in context._notification_queue + + # Mock the notification manager + nm = Mock() + nm.notify_tool_list_changed = AsyncMock() + nm.notify_resource_list_changed = AsyncMock() + mcp_server.notification_manager = nm + + # Add session_id to session + session.session_id = "test-session-123" + + # Now flush notifications + await context._flush_notifications() + + # Queue should be cleared after flush + assert len(context._notification_queue) == 0 + + # Verify notifications were sent with the session_id + nm.notify_tool_list_changed.assert_called_once_with(["test-session-123"]) + nm.notify_resource_list_changed.assert_called_once_with(["test-session-123"]) + + def test_parse_model_preferences(self, mcp_server): + """Test model preferences parsing.""" + context = Context(server=mcp_server) + + # Test with ModelPreferences object + prefs = ModelPreferences(hints=[ModelHint(name="gpt-4")]) + parsed = context._parse_model_preferences(prefs) + assert parsed == prefs + + # Test with string + parsed = context._parse_model_preferences("gpt-4") + assert isinstance(parsed, ModelPreferences) + assert len(parsed.hints) == 1 + assert parsed.hints[0].name == "gpt-4" + + # Test with list of strings + parsed = context._parse_model_preferences(["gpt-4", "claude-3"]) + assert isinstance(parsed, ModelPreferences) + assert len(parsed.hints) == 2 + assert parsed.hints[0].name == "gpt-4" + assert parsed.hints[1].name == "claude-3" + + # Test with None + parsed = context._parse_model_preferences(None) + assert parsed is None + + # Test with invalid type + with pytest.raises(ValueError, match="Invalid model preferences type"): + context._parse_model_preferences({"invalid": "dict"}) + + @pytest.mark.asyncio + async def test_context_without_server(self): + """Test operations that require server when server is None.""" + # Create a context with a server that will be garbage collected + server = Mock() + context = Context(server=server) + # Clear the strong reference to server + del server + + with pytest.raises(RuntimeError) as exc_info: + # This should raise because the weak reference is dead + _ = context.server + + assert "Server instance is no longer available" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_context_without_session(self, mcp_server): + """Test operations that require session when session is None.""" + context = Context(server=mcp_server) + + # These should return empty/None without raising + roots = await context.resources.list_roots() + assert roots == [] + + # Sampling should raise ValueError when session is None + with pytest.raises(ValueError, match="Session not available"): + await context.sampling.create_message(messages=["Hello"], max_tokens=32) + + # Elicit should also raise ValueError when session is None + with pytest.raises(ValueError, match="Session not available"): + await context.ui.elicit("Enter text") + + @pytest.mark.asyncio + async def test_context_as_context_manager(self, mcp_server): + """Test using context as an async context manager.""" + context = Context(server=mcp_server) + + # Enter context + async with context as ctx: + assert ctx == context + # Context should be set as current + assert get_current_context() == context + + # After exit, context should be reset + assert get_current_context() is None diff --git a/libs/tests/arcade_mcp_server/test_convert.py b/libs/tests/arcade_mcp_server/test_convert.py new file mode 100644 index 00000000..6e39865f --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_convert.py @@ -0,0 +1,420 @@ +"""Tests for MCP content conversion utilities.""" + +import base64 +import json +from typing import Annotated + +import pytest +from arcade_core.catalog import MaterializedTool, ToolMeta, create_func_models +from arcade_core.schema import ( + InputParameter, + ToolDefinition, + ToolInput, + ToolkitDefinition, + ToolOutput, + ToolRequirements, + ValueSchema, +) +from arcade_mcp_server import tool +from arcade_mcp_server.convert import convert_to_mcp_content, create_mcp_tool + +# Small PNG header (1x1 transparent pixel) used for byte-image param tests +PNG_BYTES = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde" + + +class TestConvertToMCPContent: + """Test convert_to_mcp_content function.""" + + @pytest.mark.parametrize( + "value, expect_empty, decode_b64, expect_text", + [ + ("Hello, world!", False, False, "Hello, world!"), + (42, False, False, "42"), + (3.14159, False, False, "3.14159"), + (1234567890, False, False, "1234567890"), + (True, False, False, "True"), + (False, False, False, "False"), + ("single", False, False, None), # covers list wrapping behavior + ("Hello\nWorld\t๐ŸŒ", False, False, "Hello\nWorld\t๐ŸŒ"), + ("", False, False, ""), + (b"Hello, binary world!", False, True, None), + (PNG_BYTES, False, True, None), + (None, True, False, None), + ({}, False, False, "{}"), + ([], False, False, "[]"), + ], + ) + def test_convert_primitives_and_bytes(self, value, expect_empty, decode_b64, expect_text): + """Parameterize primitives/bytes/empties/special cases.""" + result = convert_to_mcp_content(value) + + if expect_empty: + assert result == [] + return + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "text" + text = result[0].text + + if decode_b64: + decoded = base64.b64decode(text) + assert decoded == value + + if expect_text is not None: + assert text == expect_text + + @pytest.mark.parametrize( + "data", + [ + {"name": "Alice", "age": 30, "active": True}, + [1, 2, "three", {"four": 4}], + { + "users": [ + {"id": 1, "name": "Alice", "tags": ["admin", "user"]}, + {"id": 2, "name": "Bob", "tags": ["user"]}, + ], + "metadata": {"version": "1.0", "count": 2}, + }, + ], + ) + def test_convert_json_roundtrip(self, data): + """Parameterize JSON-serializable structures and assert round-trip equality.""" + result = convert_to_mcp_content(data) + assert len(result) == 1 + assert result[0].type == "text" + + parsed = json.loads(result[0].text) + assert parsed == data + + def test_convert_circular_reference(self): + """Test handling circular references in objects.""" + # Create circular reference + obj = {"a": 1} + obj["self"] = obj + + # Should handle gracefully (implementation dependent) + # Most JSON encoders will raise an error + with pytest.raises(Exception): + convert_to_mcp_content(obj) + + def test_convert_custom_objects(self): + """Test converting custom objects.""" + + class CustomObject: + def __str__(self): + return "CustomObject instance" + + def __repr__(self): + return "" + + obj = CustomObject() + result = convert_to_mcp_content(obj) + + # Should use string representation + assert "CustomObject" in result[0].text + + +class TestCreateMCPTool: + """Test create_mcp_tool function.""" + + @pytest.fixture + def sample_tool_def(self): + """Create a sample tool definition.""" + return ToolDefinition( + name="calculate", + fully_qualified_name="MathToolkit.calculate", + description="Perform a calculation", + toolkit=ToolkitDefinition( + name="MathToolkit", + description="Math tools", + version="1.0.0", + ), + input=ToolInput( + parameters=[ + InputParameter( + name="expression", + required=True, + description="Math expression to evaluate", + value_schema=ValueSchema(val_type="string"), + ), + InputParameter( + name="precision", + required=False, + description="Decimal precision", + value_schema=ValueSchema(val_type="integer"), + ), + ] + ), + output=ToolOutput( + description="Calculation result", + value_schema=ValueSchema(val_type="number"), + ), + requirements=ToolRequirements(), + ) + + @pytest.fixture + def materialized_tool(self, sample_tool_def): + """Create a materialized tool.""" + + @tool + def calculate( + expression: Annotated[str, "Math expression"] = "1 + 1", + precision: Annotated[int, "Decimal precision"] = 2, + ) -> Annotated[float, "Calculation result"]: + """Perform a calculation.""" + return round(eval(expression), precision) # noqa: S307 + + input_model, output_model = create_func_models(calculate) + meta = ToolMeta(module=calculate.__module__, toolkit=sample_tool_def.toolkit.name) + return MaterializedTool( + tool=calculate, + definition=sample_tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + def test_create_basic_tool(self, materialized_tool): + """Test creating basic MCP tool.""" + mcp_tool = create_mcp_tool(materialized_tool) + + assert mcp_tool.name == "MathToolkit_calculate" + # ensure input schema present + assert isinstance(mcp_tool.inputSchema, dict) + + def test_tool_input_schema(self, materialized_tool): + """Test tool input schema generation.""" + mcp_tool = create_mcp_tool(materialized_tool) + schema = mcp_tool.inputSchema + + assert schema["type"] == "object" + assert "properties" in schema + assert "expression" in schema["properties"] + assert "precision" in schema["properties"] + + # Required may or may not be present depending on defaults + if "required" in schema: + assert "expression" in schema["required"] + + def _create_tool_def_with_type(self, param_type: str) -> ToolDefinition: + return ToolDefinition( + name="test", + fully_qualified_name="Test.test", + description="Test", + toolkit=ToolkitDefinition(name="Test"), + input=ToolInput( + parameters=[ + InputParameter( + name="param", + required=True, + description="Test param", + value_schema=ValueSchema(val_type=param_type), + ) + ] + ), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + @pytest.mark.parametrize( + "arcade_type,json_type", + [ + ("string", "string"), + ("integer", "integer"), + ("number", "number"), + ("boolean", "boolean"), + ("array", "array"), + ("json", "object"), + ], + ) + def test_parameter_types(self, arcade_type, json_type): + """Test different parameter type conversions (parameterized).""" + tool_def = self._create_tool_def_with_type(arcade_type) + + @tool + def f(param: Annotated[str, "Test param"]): + return param + + input_model, output_model = create_func_models(f) + meta = ToolMeta(module=f.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=f, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + mcp_tool = create_mcp_tool(mat_tool) + param_schema = mcp_tool.inputSchema["properties"]["param"] + assert param_schema["type"] == json_type + + def test_array_parameter(self): + """Test array parameter with inner type.""" + tool_def = ToolDefinition( + name="test", + fully_qualified_name="Test.test", + description="Test", + toolkit=ToolkitDefinition(name="Test"), + input=ToolInput( + parameters=[ + InputParameter( + name="items", + required=True, + description="List of items", + value_schema=ValueSchema( + val_type="array", + inner_val_type="string", + ), + ) + ] + ), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + @tool + def f(items: Annotated[list[str], "List of items"]): + return items + + input_model, output_model = create_func_models(f) + meta = ToolMeta(module=f.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=f, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + mcp_tool = create_mcp_tool(mat_tool) + param_schema = mcp_tool.inputSchema["properties"]["items"] + + assert param_schema["type"] == "array" + assert param_schema["items"]["type"] == "string" + + def test_enum_parameter(self): + """Test enum parameter values.""" + tool_def = ToolDefinition( + name="test", + fully_qualified_name="Test.test", + description="Test", + toolkit=ToolkitDefinition(name="Test"), + input=ToolInput( + parameters=[ + InputParameter( + name="color", + required=True, + description="Color choice", + value_schema=ValueSchema( + val_type="string", + enum=["red", "green", "blue"], + ), + ) + ] + ), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + @tool + def f(color: Annotated[str, "Color choice"]): + return color + + input_model, output_model = create_func_models(f) + meta = ToolMeta(module=f.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=f, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + mcp_tool = create_mcp_tool(mat_tool) + param_schema = mcp_tool.inputSchema["properties"]["color"] + + assert param_schema["type"] == "string" + assert param_schema["enum"] == ["red", "green", "blue"] + + def test_no_parameters(self): + """Test tool with no parameters.""" + tool_def = ToolDefinition( + name="test", + fully_qualified_name="Test.test", + description="Test", + toolkit=ToolkitDefinition(name="Test"), + input=ToolInput(parameters=[]), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + @tool + def f() -> Annotated[str, "result"]: + return "result" + + input_model, output_model = create_func_models(f) + meta = ToolMeta(module=f.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=f, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + mcp_tool = create_mcp_tool(mat_tool) + schema = mcp_tool.inputSchema + + assert schema["type"] == "object" + assert schema["properties"] == {} + assert schema.get("required", []) in ([], None) + + def test_missing_input_attribute_fallback(self): + """Test tool with missing input attribute to trigger _build_input_schema_from_model fallback.""" + # Create a valid ToolDefinition first + tool_def = ToolDefinition( + name="test_fallback", + fully_qualified_name="Test.test_fallback", + description="Test fallback to input model", + toolkit=ToolkitDefinition(name="Test"), + input=ToolInput(parameters=[]), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + @tool + def f( + name: Annotated[str, "User name"], age: Annotated[int, "User age"] = 25 + ) -> Annotated[str, "greeting"]: + return f"Hello {name}, you are {age} years old" + + input_model, output_model = create_func_models(f) + meta = ToolMeta(module=f.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=f, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + # Remove the input attribute from the definition to simulate the missing attribute case + delattr(mat_tool.definition, "input") + + mcp_tool = create_mcp_tool(mat_tool) + schema = mcp_tool.inputSchema + + assert schema["type"] == "object" + assert "properties" in schema + assert "name" in schema["properties"] + assert "age" in schema["properties"] + + # Ensure the schema was built from the model and not the definition + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["age"]["type"] == "integer" + + if "required" in schema: + assert "name" in schema["required"] + assert "age" not in schema["required"] diff --git a/libs/tests/arcade_mcp_server/test_error_handling_middleware.py b/libs/tests/arcade_mcp_server/test_error_handling_middleware.py new file mode 100644 index 00000000..ca83dbb9 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_error_handling_middleware.py @@ -0,0 +1,255 @@ +"""Tests for Error Handling Middleware.""" + +import asyncio +from unittest.mock import Mock, patch + +import pytest +from arcade_mcp_server.exceptions import ( + AuthorizationError, + MCPError, + NotFoundError, + ServerError, + ToolRuntimeError, +) +from arcade_mcp_server.middleware.base import MiddlewareContext +from arcade_mcp_server.middleware.error_handling import ErrorHandlingMiddleware +from arcade_mcp_server.types import JSONRPCError + + +class TestErrorHandlingMiddleware: + """Test ErrorHandlingMiddleware class.""" + + @pytest.fixture + def error_middleware(self): + """Create error handling middleware (no masking).""" + return ErrorHandlingMiddleware(mask_error_details=False) + + @pytest.fixture + def error_middleware_masked(self): + """Create error handling middleware with masking.""" + return ErrorHandlingMiddleware(mask_error_details=True) + + @pytest.fixture + def context(self): + """Create a test context.""" + return MiddlewareContext( + message={"id": 1, "method": "test"}, + mcp_context=Mock(), + request_id="req-123", + ) + + @pytest.mark.asyncio + async def test_successful_request(self, error_middleware, context): + """Test that successful requests pass through.""" + + async def handler(ctx): + return {"result": "success"} + + result = await error_middleware(context, handler) + + assert result == {"result": "success"} + + @pytest.mark.asyncio + async def test_not_found_error(self, error_middleware, context): + """Test handling of NotFoundError.""" + + async def handler(ctx): + raise NotFoundError("Resource not found: test.txt") + + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert result.id == "req-123" + assert result.error["code"] == -32601 + assert "Resource not found: test.txt" in result.error["message"] + + @pytest.mark.asyncio + async def test_server_error(self, error_middleware, context): + """Test handling of ServerError.""" + + async def handler(ctx): + raise ServerError("Server operation failed") + + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert result.error["code"] == -32603 + assert "Server operation failed" in result.error["message"] + + @pytest.mark.asyncio + async def test_authorization_error(self, error_middleware, context): + """Test handling of AuthorizationError.""" + + async def handler(ctx): + raise AuthorizationError("User not authorized") + + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert result.error["code"] == -32603 + assert "User not authorized" in result.error["message"] + + @pytest.mark.asyncio + async def test_tool_runtime_error(self, error_middleware, context): + """Test handling of ToolError.""" + + async def handler(ctx): + raise ToolRuntimeError("Tool execution failed: API rate limit") + + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert result.error["code"] == -32603 + assert "Tool execution failed" in result.error["message"] + + @pytest.mark.asyncio + async def test_generic_mcp_error(self, error_middleware, context): + """Test handling of generic MCPError.""" + + async def handler(ctx): + raise MCPError("Something went wrong") + + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert result.error["code"] == -32603 + assert "Something went wrong" in result.error["message"] + + @pytest.mark.asyncio + async def test_unexpected_error(self, error_middleware, context): + """Test handling of unexpected exceptions.""" + + async def handler(ctx): + raise RuntimeError("Unexpected error occurred") + + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert result.error["code"] == -32603 + assert "Unexpected error occurred" in result.error["message"] + + @pytest.mark.asyncio + async def test_error_masking(self, error_middleware_masked, context): + """Test error detail masking in production.""" + + async def handler(ctx): + raise RuntimeError("Sensitive internal error with secrets") + + result = await error_middleware_masked(context, handler) + + assert isinstance(result, JSONRPCError) + assert result.error["code"] == -32603 + assert result.error["message"] == "Internal server error" + # Should not include sensitive details + assert "data" not in result.error + + @pytest.mark.asyncio + async def test_error_with_traceback(self, error_middleware, context): + """Test that error response contains expected structure (no traceback in current impl).""" + + async def handler(ctx): + def nested(): + raise ValueError("Deep error") + + nested() + + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert result.error["code"] == -32602 + assert "Deep error" in result.error["message"] + + @pytest.mark.asyncio + async def test_notification_error_handling(self, error_middleware): + """Test error handling for notifications (no ID).""" + # Notifications don't have an ID + context = MiddlewareContext(message={"method": "notification/test"}, mcp_context=Mock()) + + async def handler(ctx): + raise ValueError("Notification error") + + # For notifications, errors are still returned as JSONRPCError + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert hasattr(result, "id") + + @pytest.mark.asyncio + async def test_error_logging(self, error_middleware, context): + """Test that errors are logged appropriately.""" + with patch("arcade_mcp_server.middleware.error_handling.logger") as mock_logger: + + async def handler(ctx): + raise ToolRuntimeError("Tool failed") + + await error_middleware(context, handler) + + # Should log the error using exception + mock_logger.exception.assert_called() + call_args = mock_logger.exception.call_args[0][0] + assert "Tool failed" in call_args + + @pytest.mark.asyncio + async def test_preserves_error_code_mappings(self, error_middleware, context): + """Test that error codes map per implementation.""" + test_cases = [ + (NotFoundError("Not found"), -32601), + (AuthorizationError("Unauthorized"), -32603), + (ToolRuntimeError("Tool error"), -32603), + (RuntimeError("Boom"), -32603), + ] + + for error, expected_code in test_cases: + + async def handler(ctx, e=error): + raise e + + result = await error_middleware(context, handler) + assert result.error["code"] == expected_code + + @pytest.mark.asyncio + async def test_error_context_preservation(self, error_middleware): + """Test that context information is preserved in errors.""" + context = MiddlewareContext( + message={"id": 123, "method": "tools/call"}, + mcp_context=Mock(), + request_id="req-456", + session_id="sess-789", + ) + + async def handler(ctx): + assert ctx.request_id == "req-456" + raise ValueError("Context test") + + result = await error_middleware(context, handler) + + # Error response should use request_id when present + assert result.id == "req-456" + + @pytest.mark.asyncio + async def test_async_error_handling(self, error_middleware, context): + """Test handling of errors in async operations.""" + + async def handler(ctx): + await asyncio.sleep(0.01) + raise OSError("Async operation failed") + + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert result.error["code"] == -32603 + + @pytest.mark.asyncio + async def test_chained_error_handling(self, error_middleware, context): + """Test error handling with chained exceptions.""" + + async def handler(ctx): + try: + raise ValueError("Original error") + except ValueError as e: + raise RuntimeError("Wrapped error") from e + + result = await error_middleware(context, handler) + + assert isinstance(result, JSONRPCError) + assert "Wrapped error" in result.error["message"] diff --git a/libs/tests/arcade_mcp_server/test_logging_middleware.py b/libs/tests/arcade_mcp_server/test_logging_middleware.py new file mode 100644 index 00000000..57044699 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_logging_middleware.py @@ -0,0 +1,331 @@ +"""Tests for Logging Middleware.""" + +import asyncio +import logging +from unittest.mock import Mock, patch + +import pytest +from arcade_mcp_server.middleware.base import MiddlewareContext +from arcade_mcp_server.middleware.logging import LoggingMiddleware +from arcade_mcp_server.types import ( + JSONRPCError, + JSONRPCResponse, +) + + +class TestLoggingMiddleware: + """Test LoggingMiddleware class.""" + + @pytest.fixture + def logging_middleware(self): + """Create logging middleware.""" + return LoggingMiddleware(log_level="INFO") + + @pytest.fixture + def debug_logging_middleware(self): + """Create debug logging middleware.""" + return LoggingMiddleware(log_level="DEBUG") + + @pytest.fixture + def context(self): + """Create a test context.""" + return MiddlewareContext( + message={"id": 1, "method": "test/method", "params": {"key": "value"}}, + mcp_context=Mock(), + method="test/method", + request_id="req-123", + session_id="sess-456", + source="client", + type="request", + ) + + @pytest.mark.asyncio + async def test_request_logging(self, logging_middleware, context): + """Test that requests are logged.""" + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + + async def handler(ctx): + return {"result": "success"} + + await logging_middleware(context, handler) + + # Should log the request using log() method, not info() + mock_logger.log.assert_called() + # Check the first call (request log) + first_call = mock_logger.log.call_args_list[0] + call_args = first_call[0][1] # Second arg is the message + assert "req-123" in call_args + assert "test/method" in call_args + assert "REQUEST" in call_args + + @pytest.mark.asyncio + async def test_response_logging(self, logging_middleware, context): + """Test that responses are logged.""" + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + + async def handler(ctx): + return JSONRPCResponse(id=1, result={"status": "ok"}) + + await logging_middleware(context, handler) + + # Should log both request and response + assert mock_logger.log.call_count >= 2 + + # Find response log + response_logged = False + for call in mock_logger.log.call_args_list: + if len(call[0]) > 1 and "RESPONSE" in call[0][1]: + response_logged = True + assert "req-123" in call[0][1] + assert "elapsed=" in call[0][1] + + assert response_logged + + @pytest.mark.asyncio + async def test_error_response_logging(self, logging_middleware, context): + """Test that error responses are logged.""" + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + + async def handler(ctx): + return JSONRPCError(id=1, error={"code": -32603, "message": "Internal error"}) + + await logging_middleware(context, handler) + + # Should log response even for error responses + response_logged = False + for call in mock_logger.log.call_args_list: + if len(call[0]) > 1 and "RESPONSE" in call[0][1]: + response_logged = True + + assert response_logged + + @pytest.mark.asyncio + async def test_exception_logging(self, logging_middleware, context): + """Test that exceptions are logged.""" + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + + async def handler(ctx): + raise ValueError("Test exception") + + with pytest.raises(ValueError): + await logging_middleware(context, handler) + + # Should log the exception + mock_logger.error.assert_called() + error_args = mock_logger.error.call_args[0][0] + assert "ERROR" in error_args + assert "ValueError" in error_args + assert "Test exception" in error_args + + @pytest.mark.asyncio + async def test_timing_information(self, logging_middleware, context): + """Test that timing information is included.""" + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + + async def handler(ctx): + # Simulate some work + await asyncio.sleep(0.05) + return {"result": "success"} + + await logging_middleware(context, handler) + + # Find response log with timing + timing_logged = False + for call in mock_logger.log.call_args_list: + if len(call[0]) > 1 and "RESPONSE" in call[0][1] and "elapsed=" in call[0][1]: + timing_logged = True + # Should show elapsed time in ms + assert "ms" in call[0][1] + + assert timing_logged + + @pytest.mark.asyncio + async def test_debug_level_logging(self, debug_logging_middleware, context): + """Test debug level logging includes more details.""" + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + # Set logger level to debug + mock_logger.isEnabledFor.return_value = True + + async def handler(ctx): + return {"result": "success", "data": {"nested": "value"}} + + await debug_logging_middleware(context, handler) + + # Should have log calls at debug level + mock_logger.log.assert_called() + # First call should be with DEBUG level (logging.DEBUG = 10) + assert mock_logger.log.call_args_list[0][0][0] == logging.DEBUG + + @pytest.mark.asyncio + async def test_notification_logging(self, logging_middleware): + """Test logging of notifications (no ID).""" + context = MiddlewareContext( + message={"method": "notifications/test"}, + mcp_context=Mock(), + method="notifications/test", + type="notification", + ) + + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + + async def handler(ctx): + return None # Notifications typically return None + + await logging_middleware(context, handler) + + # Should log notification + mock_logger.log.assert_called() + notification_logged = False + for call in mock_logger.log.call_args_list: + if len(call[0]) > 1 and "NOTIFICATION" in call[0][1]: + notification_logged = True + + assert notification_logged + + @pytest.mark.asyncio + async def test_log_filtering(self, logging_middleware): + """Test that logging respects log level.""" + # Create middleware with high log level + middleware = LoggingMiddleware(log_level="ERROR") + + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + # Configure mock logger level + mock_logger.isEnabledFor.side_effect = lambda level: level >= logging.ERROR + + context = MiddlewareContext( + message={"id": 1, "method": "test"}, mcp_context=Mock(), method="test" + ) + + async def handler(ctx): + return {"result": "success"} + + await middleware(context, handler) + + # Should not log info level messages + mock_logger.info.assert_not_called() + + @pytest.mark.asyncio + async def test_method_specific_logging(self, logging_middleware): + """Test logging includes method-specific information.""" + # Test tool call + # Create a mock object with params attribute for the middleware to access + message = Mock() + params_mock = Mock() + params_mock.name = "MyTool" # Set as attribute, not in constructor + params_mock.arguments = {"x": 1} + message.params = params_mock + + tool_context = MiddlewareContext(message=message, mcp_context=Mock(), method="tools/call") + + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + + async def handler(ctx): + return {"result": "tool result"} + + await logging_middleware(tool_context, handler) + + # Should log tool name + tool_logged = False + for call in mock_logger.log.call_args_list: + if len(call[0]) > 1 and "name=MyTool" in call[0][1]: + tool_logged = True + + assert tool_logged + + @pytest.mark.asyncio + async def test_session_tracking(self, logging_middleware, context): + """Test that session ID is included in logs.""" + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + + async def handler(ctx): + return {"result": "success"} + + await logging_middleware(context, handler) + + # Should include session ID + session_logged = False + for call in mock_logger.log.call_args_list: + if len(call[0]) > 1 and "sess-456" in call[0][1]: + session_logged = True + + assert session_logged + + @pytest.mark.asyncio + async def test_large_message_truncation(self, logging_middleware): + """Test that large messages are truncated.""" + # Create a large message + large_data = "x" * 10000 + context = MiddlewareContext( + message={"id": 1, "method": "test", "params": {"data": large_data}}, + mcp_context=Mock(), + method="test", + ) + + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + + async def handler(ctx): + return {"result": large_data} + + await logging_middleware(context, handler) + + # Log messages should be reasonable size + for call in mock_logger.log.call_args_list: + if len(call[0]) > 1: + log_msg = call[0][1] + # Should not include the full large data + assert len(log_msg) < 5000 + + @pytest.mark.asyncio + async def test_concurrent_request_logging(self, logging_middleware): + """Test logging handles concurrent requests correctly.""" + with patch("arcade_mcp_server.middleware.logging.logger") as mock_logger: + # Track which requests were logged + logged_ids = set() + + def track_logs(level, msg, *args, **kwargs): + # Extract request IDs from log messages + if "req-" in msg: + import re + + match = re.search(r"req-(\d+)", msg) + if match: + logged_ids.add(match.group(1)) + + mock_logger.log.side_effect = track_logs + + # Create multiple concurrent requests + async def make_request(req_id): + ctx = MiddlewareContext( + message={"id": req_id, "method": "test"}, + mcp_context=Mock(), + method="test", + request_id=f"req-{req_id}", + ) + + async def handler(c): + await asyncio.sleep(0.01) # Simulate work + return {"result": f"result-{req_id}"} + + return await logging_middleware(ctx, handler) + + # Run concurrent requests + await asyncio.gather(*[make_request(i) for i in range(5)]) + + # All requests should be logged + assert len(logged_ids) >= 5 + + def test_log_level_configuration(self): + """Test log level configuration.""" + # Test different log levels + for level_str, level_int in [ + ("DEBUG", logging.DEBUG), + ("INFO", logging.INFO), + ("WARNING", logging.WARNING), + ("ERROR", logging.ERROR), + ]: + middleware = LoggingMiddleware(log_level=level_str) + assert middleware.log_level == level_int + + # Test case insensitive - log_level is stored as integer + middleware = LoggingMiddleware(log_level="info") + assert middleware.log_level == logging.INFO diff --git a/libs/tests/arcade_mcp_server/test_mcp_app.py b/libs/tests/arcade_mcp_server/test_mcp_app.py new file mode 100644 index 00000000..cdfab6ac --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_mcp_app.py @@ -0,0 +1,181 @@ +"""Tests for MCPApp initialization and basic functionality.""" + +from typing import Annotated + +import pytest +from arcade_core.catalog import MaterializedTool +from arcade_mcp_server import tool +from arcade_mcp_server.mcp_app import MCPApp +from arcade_mcp_server.server import MCPServer + + +class TestMCPApp: + """Test MCPApp class.""" + + @pytest.fixture + def mcp_app(self) -> MCPApp: + """Create an MCP app.""" + return MCPApp(name="TestMCPApp", version="1.0.0") + + def test_add_tool(self, mcp_app: MCPApp): + """Test adding a tool to the MCP app.""" + + def undecorated_sample_tool( + text: Annotated[str, "Input text"], + ) -> Annotated[str, "Echoed text"]: + """Echo input text back to the caller.""" + return f"Echo: {text}" + + @tool + def decorated_sample_tool( + text: Annotated[str, "Input text"], + ) -> Annotated[str, "Echoed text"]: + """Echo input text back to the caller.""" + return f"Echo: {text}" + + previous_tools = len(mcp_app._catalog) + + undecorated_tool = mcp_app.add_tool(undecorated_sample_tool) + decorated_tool = mcp_app.add_tool(decorated_sample_tool) + + assert len(mcp_app._catalog) == previous_tools + 2 + + # Verify tool has the @tool decorator applied + assert hasattr(undecorated_tool, "__tool_name__") + assert undecorated_tool.__tool_name__ == "UndecoratedSampleTool" + assert hasattr(decorated_tool, "__tool_name__") + assert decorated_tool.__tool_name__ == "DecoratedSampleTool" + + def test_tool(self, mcp_app: MCPApp): + """Test the MCPApp tool decorator.""" + + # Test decorator without parameters + @mcp_app.tool + def simple_tool(message: Annotated[str, "A message"]) -> str: + """A simple tool.""" + return f"Response: {message}" + + # Test decorator with parameters + @mcp_app.tool(name="SimpleTool2") + def simple_tool2(message: Annotated[str, "A message"]) -> str: + """A simple tool.""" + return f"Response: {message}" + + # Verify both tools were added + assert len(mcp_app._catalog) == 2 + + # Verify decorator attributes + assert hasattr(simple_tool, "__tool_name__") + assert simple_tool.__tool_name__ == "SimpleTool" + assert hasattr(simple_tool2, "__tool_name__") + assert simple_tool2.__tool_name__ == "SimpleTool2" + # Verify tools can still be called + assert simple_tool("test") == "Response: test" + assert simple_tool2("test") == "Response: test" + + @pytest.mark.asyncio + async def test_tools_api( + self, mcp_app: MCPApp, mcp_server: MCPServer, materialized_tool: MaterializedTool + ): + """Test the tools API.""" + # Test that tools API requires server binding + with pytest.raises(Exception): # noqa: B017 + await mcp_app.tools.add(materialized_tool) + + # Bind server to app (instead of calling mcp_app.run()) + mcp_app.server = mcp_server + + # Test removing a tool at runtime + removed_tool = await mcp_app.tools.remove(materialized_tool.definition.fully_qualified_name) + assert ( + removed_tool.definition.fully_qualified_name + == materialized_tool.definition.fully_qualified_name + ) + + num_tools_before_add = len(await mcp_app.tools.list()) + + # Test adding a tool at runtime + await mcp_app.tools.add(materialized_tool) + + # Test listing tools at runtime + tools = await mcp_app.tools.list() + assert len(tools) == num_tools_before_add + 1 + + # Test updating a tool at runtime + await mcp_app.tools.update(materialized_tool) + + @pytest.mark.asyncio + async def test_prompts_api(self, mcp_app: MCPApp, mcp_server): + """Test the prompts API.""" + from arcade_mcp_server.types import Prompt, PromptArgument, PromptMessage + + # Test that prompts API requires server binding + sample_prompt = Prompt( + name="test_prompt", + description="A test prompt", + arguments=[PromptArgument(name="input", description="Test input", required=True)], + ) + + with pytest.raises(Exception) as exc_info: + await mcp_app.prompts.add(sample_prompt) + assert "No server bound to app" in str(exc_info.value) + + # Bind server to app + mcp_app.server = mcp_server + + # Create a prompt handler + async def test_handler(args: dict[str, str]) -> list[PromptMessage]: + return [ + PromptMessage( + role="user", + content={"type": "text", "text": f"Hello {args.get('input', 'world')}"}, + ) + ] + + # Test adding a prompt at runtime + await mcp_app.prompts.add(sample_prompt, test_handler) + + # Test listing prompts at runtime + prompts = await mcp_app.prompts.list() + assert len(prompts) == 1 + assert any(p.name == "test_prompt" for p in prompts) + + # Test removing a prompt at runtime + removed_prompt = await mcp_app.prompts.remove("test_prompt") + assert removed_prompt.name == "test_prompt" + + @pytest.mark.asyncio + async def test_resources_api(self, mcp_app: MCPApp, mcp_server): + """Test the resources API.""" + from arcade_mcp_server.types import Resource + + # Test that resources API requires server binding + sample_resource = Resource( + uri="file:///test.txt", + name="test.txt", + description="A test text file", + mimeType="text/plain", + ) + + with pytest.raises(Exception) as exc_info: + await mcp_app.resources.add(sample_resource) + assert "No server bound to app" in str(exc_info.value) + + # Bind server to app + mcp_app.server = mcp_server + + # Create a resource handler + def test_handler(uri: str): + return {"content": f"Content for {uri}", "mimeType": "text/plain"} + + # Test adding a resource at runtime + await mcp_app.resources.add(sample_resource, test_handler) + + # Test listing resources at runtime + resources = await mcp_app.resources.list() + assert len(resources) >= 1 + assert any(r.uri == "file:///test.txt" for r in resources) + + # Test removing a resource at runtime + removed_resource = await mcp_app.resources.remove("file:///test.txt") + assert removed_resource.uri == "file:///test.txt" diff --git a/libs/tests/arcade_mcp_server/test_middleware_base.py b/libs/tests/arcade_mcp_server/test_middleware_base.py new file mode 100644 index 00000000..58c88e85 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_middleware_base.py @@ -0,0 +1,209 @@ +"""Tests for Middleware base classes.""" + +from unittest.mock import Mock + +import pytest +from arcade_mcp_server.middleware.base import ( + Middleware, + MiddlewareContext, +) + + +class TestMiddlewareBase: + """Test base middleware functionality.""" + + def test_middleware_context_creation(self): + """Test MiddlewareContext creation.""" + message = {"method": "test", "params": {}} + context = MiddlewareContext( + message=message, + mcp_context=Mock(), + source="client", + type="request", + method="test", + request_id="req-123", + session_id="sess-456", + ) + + assert context.message == message + assert context.source == "client" + assert context.type == "request" + assert context.method == "test" + assert context.request_id == "req-123" + assert context.session_id == "sess-456" + + def test_middleware_context_metadata(self): + """Test metadata management in context.""" + context = MiddlewareContext(message={}, mcp_context=Mock()) + + # Initial metadata is empty + assert context.metadata == {} + + # Add metadata + context.metadata["key1"] = "value1" + context.metadata["key2"] = {"nested": "value"} + + assert context.metadata["key1"] == "value1" + assert context.metadata["key2"]["nested"] == "value" + + @pytest.mark.asyncio + async def test_basic_middleware(self): + """Test basic middleware implementation.""" + # Track calls + middleware_called = False + + class TestMiddleware(Middleware): + async def __call__(self, context, call_next): + nonlocal middleware_called + middleware_called = True + # Pass through to next + return await call_next(context) + + # Create middleware + middleware = TestMiddleware() + + # Mock next handler + async def next_handler(ctx): + return {"result": "success"} + + # Execute + context = MiddlewareContext(message={}, mcp_context=Mock()) + result = await middleware(context, next_handler) + + assert middleware_called + assert result == {"result": "success"} + + @pytest.mark.asyncio + async def test_middleware_modification(self): + """Test middleware that modifies context.""" + + class ModifyingMiddleware(Middleware): + async def __call__(self, context, call_next): + # Modify context before + context.metadata["before"] = True + + # Call next + result = await call_next(context) + + # Modify result after + if isinstance(result, dict): + result["after"] = True + + return result + + middleware = ModifyingMiddleware() + + async def next_handler(ctx): + assert ctx.metadata["before"] is True + return {"original": "value"} + + context = MiddlewareContext(message={}, mcp_context=Mock()) + result = await middleware(context, next_handler) + + assert result == {"original": "value", "after": True} + + @pytest.mark.asyncio + async def test_middleware_chain(self): + """Test chaining multiple middleware.""" + call_order = [] + + class Middleware1(Middleware): + async def __call__(self, context, call_next): + call_order.append("m1_before") + result = await call_next(context) + call_order.append("m1_after") + return result + + class Middleware2(Middleware): + async def __call__(self, context, call_next): + call_order.append("m2_before") + result = await call_next(context) + call_order.append("m2_after") + return result + + # Build chain manually + async def final_handler(ctx): + call_order.append("handler") + return "result" + + m2 = Middleware2() + m1 = Middleware1() + + # Chain: m1 -> m2 -> handler + async def m2_wrapped(ctx): + return await m2(ctx, final_handler) + + context = MiddlewareContext(message={}, mcp_context=Mock()) + result = await m1(context, m2_wrapped) + + # Check order + assert call_order == ["m1_before", "m2_before", "handler", "m2_after", "m1_after"] + assert result == "result" + + @pytest.mark.asyncio + async def test_middleware_error_propagation(self): + """Test error propagation through middleware.""" + + class ErrorMiddleware(Middleware): + async def __call__(self, context, call_next): + try: + return await call_next(context) + except ValueError as e: + # Transform error + raise RuntimeError(f"Wrapped: {e}") + + middleware = ErrorMiddleware() + + async def failing_handler(ctx): + raise ValueError("Original error") + + context = MiddlewareContext(message={}, mcp_context=Mock()) + + with pytest.raises(RuntimeError) as exc_info: + await middleware(context, failing_handler) + + assert "Wrapped: Original error" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_middleware_short_circuit(self): + """Test middleware that short-circuits the chain.""" + + class ShortCircuitMiddleware(Middleware): + async def __call__(self, context, call_next): + # Don't call next for certain conditions + if context.message.get("skip"): + return {"short_circuited": True} + return await call_next(context) + + middleware = ShortCircuitMiddleware() + + # Normal flow + context1 = MiddlewareContext(message={}, mcp_context=Mock()) + + async def handler(ctx): + return {"normal": True} + + result1 = await middleware(context1, handler) + assert result1 == {"normal": True} + + # Short circuit + context2 = MiddlewareContext(message={"skip": True}, mcp_context=Mock()) + result2 = await middleware(context2, handler) + assert result2 == {"short_circuited": True} + + def test_middleware_protocol(self): + """Test that Middleware follows the protocol.""" + # Middleware should be a protocol/ABC + assert callable(Middleware) + + # Should not be instantiable directly + # (This is more of a documentation test since Python protocols are flexible) + + # But subclasses should work + class ConcreteMiddleware(Middleware): + async def __call__(self, context, call_next): + return await call_next(context) + + # Should be instantiable + middleware = ConcreteMiddleware() + assert isinstance(middleware, Middleware) diff --git a/libs/tests/arcade_mcp_server/test_openapi_docs.py b/libs/tests/arcade_mcp_server/test_openapi_docs.py new file mode 100644 index 00000000..eb504612 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_openapi_docs.py @@ -0,0 +1,80 @@ +"""Test that MCP routes appear in OpenAPI documentation.""" + +import pytest +from arcade_core import ToolCatalog +from arcade_core.toolkit import Toolkit +from arcade_mcp_server.settings import MCPSettings +from arcade_mcp_server.worker import create_arcade_mcp +from fastapi.testclient import TestClient + + +def test_mcp_routes_in_openapi(monkeypatch): + """Test that MCP routes appear in FastAPI OpenAPI documentation.""" + # Set environment variables for settings + monkeypatch.setenv("ARCADE_AUTH_DISABLED", "true") + monkeypatch.setenv("ARCADE_WORKER_SECRET", "test") + monkeypatch.setenv("MCP_SERVER_NAME", "test-mcp") + monkeypatch.setenv("MCP_SERVER_VERSION", "0.1.0") + + # Create a simple catalog + catalog = ToolCatalog() + toolkit = Toolkit(name="test", package_name="test", version="0.1.0", description="Test toolkit") + catalog.add_toolkit(toolkit) + + # Create MCP settings from environment + mcp_settings = MCPSettings.from_env() + + # Create the app + app = create_arcade_mcp(catalog, mcp_settings=mcp_settings) + + # Create test client + client = TestClient(app) + + # Get OpenAPI schema + response = client.get("/openapi.json") + assert response.status_code == 200 + + openapi_schema = response.json() + + # Check that MCP paths are documented + assert "/mcp/" in openapi_schema["paths"] + + mcp_path = openapi_schema["paths"]["/mcp/"] + + # Check POST endpoint + assert "post" in mcp_path + assert mcp_path["post"]["summary"] == "Send MCP message" + assert "MCPRequest" in str(mcp_path["post"]) + assert "MCPResponse" in str(mcp_path["post"]) + + # Check GET endpoint + assert "get" in mcp_path + assert mcp_path["get"]["summary"] == "Establish SSE stream" + + # Check DELETE endpoint + assert "delete" in mcp_path + assert mcp_path["delete"]["summary"] == "Terminate session" + + # Check that component schemas are defined + components = openapi_schema.get("components", {}).get("schemas", {}) + assert "MCPRequest" in components + assert "MCPResponse" in components + + # Verify MCPRequest schema + mcp_request = components["MCPRequest"] + assert "jsonrpc" in mcp_request["properties"] + assert "method" in mcp_request["properties"] + assert "params" in mcp_request["properties"] + assert "id" in mcp_request["properties"] + + # Verify that the paths include the MCP tag + assert "tags" in mcp_path["post"] + assert "MCP Protocol" in mcp_path["post"]["tags"] + + # Verify the actual proxy is mounted (not routes) + # The OpenAPI docs should exist but not interfere with the mount + import inspect + + mounts = [route for route in app.routes if hasattr(route, "app") and hasattr(route, "path")] + mcp_mounts = [m for m in mounts if m.path == "/mcp"] + assert len(mcp_mounts) == 1, "Should have exactly one mount at /mcp" diff --git a/libs/tests/arcade_mcp_server/test_prompt.py b/libs/tests/arcade_mcp_server/test_prompt.py new file mode 100644 index 00000000..cf2f6981 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_prompt.py @@ -0,0 +1,241 @@ +"""Tests for Prompt Manager implementation.""" + +import asyncio + +import pytest +from arcade_mcp_server.exceptions import NotFoundError, PromptError +from arcade_mcp_server.managers.prompt import PromptManager +from arcade_mcp_server.types import ( + GetPromptResult, + Prompt, + PromptArgument, + PromptMessage, +) + + +class TestPromptManager: + """Test PromptManager class.""" + + @pytest.fixture + def prompt_manager(self): + """Create a prompt manager instance.""" + return PromptManager() + + @pytest.fixture + def sample_prompt(self): + """Create a sample prompt.""" + return Prompt( + name="greeting", + description="A greeting prompt", + arguments=[ + PromptArgument(name="name", description="The name to greet", required=True), + PromptArgument( + name="formal", description="Whether to use formal greeting", required=False + ), + ], + ) + + @pytest.fixture + def prompt_function(self): + """Create a prompt function.""" + + async def greeting_prompt(args: dict[str, str]) -> list[PromptMessage]: + name = args.get("name", "") + formal_arg = args.get("formal", "false") + formal = str(formal_arg).lower() == "true" + + if formal: + text = f"Good day, {name}. How may I assist you?" + else: + text = f"Hey {name}! What's up?" + + return [PromptMessage(role="assistant", content={"type": "text", "text": text})] + + return greeting_prompt + + def test_manager_initialization(self): + """Test prompt manager initialization.""" + manager = PromptManager() + assert isinstance(manager, PromptManager) + + @pytest.mark.asyncio + async def test_manager_lifecycle(self, prompt_manager): + """Passive manager has no explicit lifecycle; ensure methods work.""" + # Initially empty + prompts = await prompt_manager.list_prompts() + assert prompts == [] + + @pytest.mark.asyncio + async def test_add_prompt(self, prompt_manager, sample_prompt, prompt_function): + """Test adding prompts.""" + await prompt_manager.add_prompt(sample_prompt, prompt_function) + + prompts = await prompt_manager.list_prompts() + assert len(prompts) == 1 + assert prompts[0].name == sample_prompt.name + assert len(prompts[0].arguments) == 2 + + @pytest.mark.asyncio + async def test_remove_prompt(self, prompt_manager, sample_prompt, prompt_function): + """Test removing prompts.""" + await prompt_manager.add_prompt(sample_prompt, prompt_function) + removed = await prompt_manager.remove_prompt(sample_prompt.name) + assert removed.name == sample_prompt.name + + prompts = await prompt_manager.list_prompts() + assert len(prompts) == 0 + + @pytest.mark.asyncio + async def test_get_prompt(self, prompt_manager, sample_prompt, prompt_function): + """Test getting and executing prompts.""" + await prompt_manager.add_prompt(sample_prompt, prompt_function) + + result = await prompt_manager.get_prompt("greeting", {"name": "Alice", "formal": True}) + + assert isinstance(result, GetPromptResult) + assert len(result.messages) == 1 + assert result.messages[0].role == "assistant" + assert "Good day, Alice" in result.messages[0].content["text"] + + @pytest.mark.asyncio + async def test_get_prompt_default_args(self, prompt_manager, sample_prompt, prompt_function): + """Test getting prompt with default arguments.""" + await prompt_manager.add_prompt(sample_prompt, prompt_function) + + result = await prompt_manager.get_prompt("greeting", {"name": "Bob"}) + assert "Hey Bob!" in result.messages[0].content["text"] + + @pytest.mark.asyncio + async def test_get_prompt_missing_required_args( + self, prompt_manager, sample_prompt, prompt_function + ): + """Test getting prompt without required arguments.""" + await prompt_manager.add_prompt(sample_prompt, prompt_function) + + with pytest.raises(PromptError): + await prompt_manager.get_prompt("greeting", {"formal": True}) + + @pytest.mark.asyncio + async def test_get_nonexistent_prompt(self, prompt_manager): + """Test getting non-existent prompt.""" + with pytest.raises(NotFoundError): + await prompt_manager.get_prompt("nonexistent", {}) + + @pytest.mark.asyncio + async def test_prompt_with_multiple_messages(self, prompt_manager): + """Test prompt that returns multiple messages.""" + prompt = Prompt(name="conversation", description="A conversation prompt") + + async def conversation_prompt(args: dict[str, str]) -> list[PromptMessage]: + return [ + PromptMessage(role="user", content={"type": "text", "text": "Hello!"}), + PromptMessage(role="assistant", content={"type": "text", "text": "Hi there!"}), + PromptMessage(role="user", content={"type": "text", "text": "How are you?"}), + PromptMessage( + role="assistant", content={"type": "text", "text": "I'm doing well, thanks!"} + ), + ] + + await prompt_manager.add_prompt(prompt, conversation_prompt) + + result = await prompt_manager.get_prompt("conversation", {}) + + assert len(result.messages) == 4 + assert result.messages[0].role == "user" + assert result.messages[1].role == "assistant" + + @pytest.mark.asyncio + async def test_prompt_with_image_content(self, prompt_manager): + """Test prompt with image content.""" + prompt = Prompt( + name="image_analysis", + description="Analyze an image", + arguments=[PromptArgument(name="image_url", required=True)], + ) + + async def image_prompt(args: dict[str, str]) -> list[PromptMessage]: + image_url = args.get("image_url", "") + return [ + PromptMessage( + role="user", + content={"type": "image", "data": image_url, "mimeType": "image/jpeg"}, + ), + PromptMessage( + role="user", content={"type": "text", "text": "Please analyze this image"} + ), + ] + + await prompt_manager.add_prompt(prompt, image_prompt) + + result = await prompt_manager.get_prompt( + "image_analysis", {"image_url": "http://example.com/image.jpg"} + ) + + assert len(result.messages) == 2 + assert result.messages[0].content["type"] == "image" + assert result.messages[1].content["type"] == "text" + + @pytest.mark.asyncio + async def test_prompt_with_embedded_resource(self, prompt_manager): + """Test prompt with embedded resources.""" + prompt = Prompt(name="with_resource", description="Prompt with embedded resource") + + async def resource_prompt(args: dict[str, str]) -> list[PromptMessage]: + return [ + PromptMessage( + role="user", + content={ + "type": "resource", + "resource": {"uri": "file:///data.txt", "text": "Sample data"}, + }, + ) + ] + + await prompt_manager.add_prompt(prompt, resource_prompt) + + result = await prompt_manager.get_prompt("with_resource", {}) + + assert result.messages[0].content["type"] == "resource" + assert result.messages[0].content["resource"]["uri"] == "file:///data.txt" + + @pytest.mark.asyncio + async def test_concurrent_prompt_operations(self, prompt_manager): + """Test concurrent prompt operations.""" + prompts = [] + for i in range(10): + prompt = Prompt(name=f"prompt_{i}", description=f"Prompt {i}") + + async def func(args: dict[str, str], idx=i): + return [ + PromptMessage( + role="assistant", content={"type": "text", "text": f"Response {idx}"} + ) + ] + + prompts.append((prompt, func)) + + tasks = [prompt_manager.add_prompt(p, f) for p, f in prompts] + await asyncio.gather(*tasks) + + listed = await prompt_manager.list_prompts() + assert len(listed) == 10 + + @pytest.mark.asyncio + async def test_list_prompts_initial(self, prompt_manager): + """Passive manager lists prompts initially as empty.""" + prompts = await prompt_manager.list_prompts() + assert prompts == [] + + @pytest.mark.asyncio + async def test_prompt_error_handling(self): + """Test error handling in prompt functions.""" + manager = PromptManager() + prompt = Prompt(name="error_prompt", description="Prompt that errors") + + async def error_prompt(args: dict[str, str]): + raise RuntimeError("Prompt execution failed") + + await manager.add_prompt(prompt, error_prompt) + + with pytest.raises(PromptError): + await manager.get_prompt("error_prompt", {}) diff --git a/libs/tests/arcade_mcp_server/test_public_imports.py b/libs/tests/arcade_mcp_server/test_public_imports.py new file mode 100644 index 00000000..c397e3b8 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_public_imports.py @@ -0,0 +1,49 @@ +def test_basic_imports(): + """Test basic imports from arcade_mcp_server.""" + from arcade_mcp_server.context import Context + from arcade_mcp_server.server import MCPServer + + # All imports should work + assert MCPServer is not None + assert Context is not None + + +def test_manager_imports(): + """Test manager imports.""" + from arcade_mcp_server.managers.prompt import PromptManager + from arcade_mcp_server.managers.resource import ResourceManager + from arcade_mcp_server.managers.tool import ToolManager + + assert ToolManager is not None + assert ResourceManager is not None + assert PromptManager is not None + + +def test_middleware_imports(): + """Test middleware imports.""" + from arcade_mcp_server.middleware.base import Middleware + from arcade_mcp_server.middleware.error_handling import ErrorHandlingMiddleware + from arcade_mcp_server.middleware.logging import LoggingMiddleware + + assert Middleware is not None + assert ErrorHandlingMiddleware is not None + assert LoggingMiddleware is not None + + +def test_transport_imports(): + """Test transport imports.""" + from arcade_mcp_server.transports.http_session_manager import HTTPSessionManager + from arcade_mcp_server.transports.http_streamable import HTTPStreamableTransport + from arcade_mcp_server.transports.stdio import StdioTransport + + assert StdioTransport is not None + assert HTTPStreamableTransport is not None + assert HTTPSessionManager is not None + + +if __name__ == "__main__": + test_basic_imports() + test_manager_imports() + test_middleware_imports() + test_transport_imports() + print("All imports successful!") diff --git a/libs/tests/arcade_mcp_server/test_resource.py b/libs/tests/arcade_mcp_server/test_resource.py new file mode 100644 index 00000000..1dd0c506 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_resource.py @@ -0,0 +1,193 @@ +"""Tests for Resource Manager implementation.""" + +import asyncio + +import pytest +from arcade_mcp_server.exceptions import NotFoundError +from arcade_mcp_server.managers.resource import ResourceManager +from arcade_mcp_server.types import ( + BlobResourceContents, + Resource, + ResourceContents, + ResourceTemplate, + TextResourceContents, +) + + +class TestResourceManager: + """Test ResourceManager class.""" + + @pytest.fixture + def resource_manager(self): + """Create a resource manager instance.""" + return ResourceManager() + + @pytest.fixture + def sample_resource(self): + """Create a sample resource.""" + return Resource( + uri="file:///test.txt", + name="test.txt", + description="A test text file", + mimeType="text/plain", + ) + + @pytest.fixture + def sample_template(self): + """Create a sample resource template.""" + return ResourceTemplate( + uriTemplate="file:///{path}", + name="File Template", + description="Template for file resources", + mimeType="text/plain", + ) + + def test_manager_initialization(self): + """Test resource manager initialization.""" + manager = ResourceManager() + # Passive manager: no started flag + assert isinstance(manager, ResourceManager) + + @pytest.mark.asyncio + async def test_manager_lifecycle(self, resource_manager): + """Passive manager has no explicit lifecycle; ensure methods work.""" + resources = await resource_manager.list_resources() + assert resources == [] + + @pytest.mark.asyncio + async def test_add_resource(self, resource_manager, sample_resource): + """Test adding resources.""" + await resource_manager.add_resource(sample_resource) + + resources = await resource_manager.list_resources() + assert len(resources) == 1 + assert resources[0].uri == sample_resource.uri + + @pytest.mark.asyncio + async def test_remove_resource(self, resource_manager, sample_resource): + """Test removing resources.""" + await resource_manager.add_resource(sample_resource) + removed = await resource_manager.remove_resource(sample_resource.uri) + assert removed.uri == sample_resource.uri + + resources = await resource_manager.list_resources() + assert len(resources) == 0 + + @pytest.mark.asyncio + async def test_remove_nonexistent_resource(self, resource_manager): + """Test removing non-existent resource.""" + with pytest.raises(NotFoundError): + await resource_manager.remove_resource("file:///nonexistent.txt") + + @pytest.mark.asyncio + async def test_add_resource_template(self, resource_manager, sample_template): + """Test adding resource templates.""" + await resource_manager.add_template(sample_template) + + templates = await resource_manager.list_resource_templates() + assert len(templates) == 1 + assert templates[0].uriTemplate == sample_template.uriTemplate + + @pytest.mark.asyncio + async def test_resource_handlers(self, resource_manager): + """Test adding and using resource handlers.""" + resource = Resource( + uri="custom://test", name="Custom Resource", description="Resource with custom handler" + ) + + async def custom_handler(uri: str) -> list[ResourceContents]: + return [ + TextResourceContents( + uri=uri, text="Custom content for " + uri, mimeType="text/plain" + ) + ] + + await resource_manager.add_resource(resource, handler=custom_handler) + + contents = await resource_manager.read_resource("custom://test") + + assert len(contents) == 1 + assert contents[0].text == "Custom content for custom://test" + + @pytest.mark.asyncio + async def test_read_resource_without_handler(self, resource_manager, sample_resource): + """Test reading resource without a handler returns default content.""" + await resource_manager.add_resource(sample_resource) + + contents = await resource_manager.read_resource(sample_resource.uri) + assert len(contents) == 1 + assert contents[0].uri == sample_resource.uri + + @pytest.mark.asyncio + async def test_read_nonexistent_resource(self, resource_manager): + """Test reading non-existent resource.""" + with pytest.raises(NotFoundError): + await resource_manager.read_resource("file:///nonexistent.txt") + + @pytest.mark.asyncio + async def test_binary_resource_content(self, resource_manager): + """Test handling binary resource content.""" + resource = Resource(uri="file:///image.png", name="image.png", mimeType="image/png") + + async def image_handler(uri: str) -> list[ResourceContents]: + import base64 + + png_data = base64.b64encode( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde" + ).decode() + return [BlobResourceContents(uri=uri, blob=png_data, mimeType="image/png")] + + await resource_manager.add_resource(resource, handler=image_handler) + + contents = await resource_manager.read_resource("file:///image.png") + + assert len(contents) == 1 + assert isinstance(contents[0], BlobResourceContents) + assert contents[0].mimeType == "image/png" + + @pytest.mark.asyncio + async def test_multiple_resource_contents(self, resource_manager): + """Test resources that return multiple contents.""" + resource = Resource(uri="multi://resource", name="Multi Resource") + + async def multi_handler(uri: str) -> list[ResourceContents]: + return [ + TextResourceContents(uri=uri + "#part1", text="Part 1"), + TextResourceContents(uri=uri + "#part2", text="Part 2"), + BlobResourceContents(uri=uri + "#data", blob="YmluYXJ5"), + ] + + await resource_manager.add_resource(resource, handler=multi_handler) + + contents = await resource_manager.read_resource("multi://resource") + + assert len(contents) == 3 + assert contents[0].text == "Part 1" + assert contents[1].text == "Part 2" + assert contents[2].blob == "YmluYXJ5" + + @pytest.mark.asyncio + async def test_concurrent_resource_operations(self, resource_manager): + """Test concurrent resource operations.""" + # Create multiple resources + resources = [] + for i in range(10): + resource = Resource( + uri=f"file:///{i}.txt", name=f"File {i}", description=f"Test file {i}" + ) + resources.append(resource) + + tasks = [resource_manager.add_resource(r) for r in resources] + await asyncio.gather(*tasks) + + listed = await resource_manager.list_resources() + assert len(listed) == 10 + + @pytest.mark.asyncio + async def test_list_resources_and_templates_initial(self): + """Passive manager lists resources/templates initially as empty.""" + manager = ResourceManager() + resources = await manager.list_resources() + assert resources == [] + templates = await manager.list_resource_templates() + assert templates == [] diff --git a/libs/tests/arcade_mcp_server/test_server.py b/libs/tests/arcade_mcp_server/test_server.py new file mode 100644 index 00000000..956cfb43 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_server.py @@ -0,0 +1,383 @@ +"""Tests for MCP Server implementation.""" + +import asyncio +import contextlib +from unittest.mock import AsyncMock, Mock + +import pytest +from arcade_mcp_server.middleware import Middleware +from arcade_mcp_server.server import MCPServer +from arcade_mcp_server.session import InitializationState +from arcade_mcp_server.types import ( + CallToolRequest, + CallToolResult, + InitializeRequest, + InitializeResult, + JSONRPCError, + JSONRPCResponse, + ListToolsRequest, + ListToolsResult, + PingRequest, +) + + +class TestMCPServer: + """Test MCPServer class.""" + + def test_server_initialization(self, tool_catalog, mcp_settings): + """Test server initialization with various configurations.""" + # Basic initialization + server = MCPServer( + catalog=tool_catalog, + name="Test Server", + version="1.0.0", + settings=mcp_settings, + ) + + assert server.name == "Test Server" + assert server.version == "1.0.0" + assert server.title == "Test Server" + assert server.settings == mcp_settings + + # With custom title and instructions + server2 = MCPServer( + catalog=tool_catalog, + name="Test Server", + version="1.0.0", + title="Custom Title", + instructions="Custom instructions", + ) + + assert server2.title == "Custom Title" + assert server2.instructions == "Custom instructions" + + def test_handler_registration(self, tool_catalog): + """Test that all required handlers are registered.""" + server = MCPServer(catalog=tool_catalog) + + expected_handlers = [ + "ping", + "initialize", + "tools/list", + "tools/call", + "resources/list", + "resources/templates/list", + "resources/read", + "prompts/list", + "prompts/get", + "logging/setLevel", + ] + + for method in expected_handlers: + assert method in server._handlers + assert callable(server._handlers[method]) + + @pytest.mark.asyncio + async def test_server_lifecycle(self, tool_catalog, mcp_settings): + """Test server startup and shutdown.""" + server = MCPServer( + catalog=tool_catalog, + settings=mcp_settings, + ) + + # Start server + await server.start() + + # Stop server + await server.stop() + + @pytest.mark.asyncio + async def test_handle_ping(self, mcp_server): + """Test ping request handling.""" + message = PingRequest(jsonrpc="2.0", id=1, method="ping") + + response = await mcp_server._handle_ping(message) + + assert isinstance(response, JSONRPCResponse) + assert response.id == 1 + assert response.result == {} + + @pytest.mark.asyncio + async def test_handle_initialize(self, mcp_server): + """Test initialize request handling.""" + message = InitializeRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params={ + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + ) + + # Create mock session + session = Mock() + session.set_client_params = Mock() + + response = await mcp_server._handle_initialize(message, session=session) + + assert isinstance(response, JSONRPCResponse) + assert response.id == 1 + assert isinstance(response.result, InitializeResult) + assert response.result.protocolVersion is not None + assert response.result.serverInfo.name == mcp_server.name + assert response.result.serverInfo.version == mcp_server.version + + # Check session was updated + session.set_client_params.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_list_tools(self, mcp_server): + """Test list tools request handling.""" + message = ListToolsRequest(jsonrpc="2.0", id=2, method="tools/list", params={}) + + response = await mcp_server._handle_list_tools(message) + + assert isinstance(response, JSONRPCResponse) + assert response.id == 2 + assert isinstance(response.result, ListToolsResult) + assert len(response.result.tools) > 0 + + @pytest.mark.asyncio + async def test_handle_call_tool(self, mcp_server): + """Test tool call request handling.""" + message = CallToolRequest( + jsonrpc="2.0", + id=3, + method="tools/call", + params={"name": "TestToolkit.test_tool", "arguments": {"text": "Hello"}}, + ) + + response = await mcp_server._handle_call_tool(message) + + assert isinstance(response, JSONRPCResponse) + assert response.id == 3 + assert isinstance(response.result, CallToolResult) + assert response.result.structuredContent is not None + assert "result" in response.result.structuredContent + assert "Echo: Hello" in response.result.structuredContent["result"] + + @pytest.mark.asyncio + async def test_handle_call_tool_with_requires_auth(self, mcp_server): + """Test tool call request handling with authorization.""" + + mock_auth_response = Mock() + mock_auth_response.status = "pending" + mock_auth_response.url = "https://example.com/auth" + + # Patch the _check_authorization method to return a tool that has unsatisfied authorization + mcp_server._check_authorization = AsyncMock(return_value=mock_auth_response) + + message = CallToolRequest( + jsonrpc="2.0", + id=3, + method="tools/call", + params={"name": "TestToolkit.sample_tool_with_auth", "arguments": {"text": "Hello"}}, + ) + + response = await mcp_server._handle_call_tool(message) + + assert isinstance(response, JSONRPCResponse) + assert response.id == 3 + assert isinstance(response.result, CallToolResult) + assert response.result.structuredContent is not None + assert "authorization_url" in response.result.structuredContent + assert response.result.structuredContent["authorization_url"] == "https://example.com/auth" + assert "message" in response.result.structuredContent + assert "authorization" in response.result.structuredContent["message"] + + @pytest.mark.asyncio + async def test_handle_call_tool_not_found(self, mcp_server): + """Test calling a non-existent tool.""" + message = CallToolRequest( + jsonrpc="2.0", + id=3, + method="tools/call", + params={"name": "NonExistent.tool", "arguments": {}}, + ) + + response = await mcp_server._handle_call_tool(message) + + assert isinstance(response, JSONRPCResponse) + assert response.result.isError + assert "error" in response.result.structuredContent + assert "Unknown tool" in response.result.structuredContent["error"] + + @pytest.mark.asyncio + async def test_handle_message_routing(self, mcp_server, initialized_server_session): + """Test message routing to appropriate handlers.""" + # Test valid method + message = {"jsonrpc": "2.0", "id": 1, "method": "ping"} + + response = await mcp_server.handle_message(message, session=initialized_server_session) + + assert response is not None + assert str(response.id) == "1" + assert response.result == {} + + # Test invalid method + message = {"jsonrpc": "2.0", "id": 2, "method": "invalid/method"} + + response = await mcp_server.handle_message(message, session=initialized_server_session) + + assert isinstance(response, JSONRPCError) + assert response.error["code"] == -32601 + assert "Method not found" in response.error["message"] + + @pytest.mark.asyncio + async def test_handle_message_invalid_format(self, mcp_server): + """Test handling of invalid message formats.""" + # Non-dict message + response = await mcp_server.handle_message("invalid", session=None) + + assert isinstance(response, JSONRPCError) + assert response.error["code"] == -32600 + assert "Invalid request" in response.error["message"] + + @pytest.mark.asyncio + async def test_initialization_state_enforcement(self, mcp_server): + """Test that non-initialize methods are blocked before initialization.""" + # Create uninitialized session + session = Mock() + session.initialization_state = InitializationState.NOT_INITIALIZED + + # Try to call tools/list before initialization + message = {"jsonrpc": "2.0", "id": 1, "method": "tools/list"} + + response = await mcp_server.handle_message(message, session=session) + + assert isinstance(response, JSONRPCError) + assert response.error["code"] == -32600 + assert "not allowed before initialization" in response.error["message"] + + @pytest.mark.asyncio + async def test_notification_handling(self, mcp_server): + """Test handling of notification messages.""" + session = Mock() + session.mark_initialized = Mock() + + # Send initialized notification + message = {"jsonrpc": "2.0", "method": "notifications/initialized"} + + response = await mcp_server.handle_message(message, session=session) + + # Notifications should not return a response + assert response is None + # Session should be marked as initialized + session.mark_initialized.assert_called_once() + + @pytest.mark.asyncio + async def test_middleware_chain(self, tool_catalog, mcp_settings): + """Test middleware chain execution.""" + # Create a test middleware + test_middleware_called = False + + class TestMiddleware(Middleware): + async def __call__(self, context, call_next): + nonlocal test_middleware_called + test_middleware_called = True + # Modify context + context.metadata["test"] = "value" + return await call_next(context) + + # Create server with middleware + server = MCPServer( + catalog=tool_catalog, + settings=mcp_settings, + middleware=[TestMiddleware()], + ) + await server.start() + + # Send a message + message = {"jsonrpc": "2.0", "id": 1, "method": "ping"} + + response = await server.handle_message(message) + + # Middleware should have been called + assert test_middleware_called + assert response is not None + + @pytest.mark.asyncio + async def test_error_handling_middleware(self, mcp_server): + """Test that error handling middleware catches exceptions.""" + + # Mock a handler to raise an exception + async def failing_handler(*args, **kwargs): + raise Exception("Test error") + + mcp_server._handlers["test/fail"] = failing_handler + + message = {"jsonrpc": "2.0", "id": 1, "method": "test/fail"} + + response = await mcp_server.handle_message(message) + + assert isinstance(response, JSONRPCError) + assert response.error["code"] == -32603 + # Error details should be masked in production + if mcp_server.settings.middleware.mask_error_details: + assert response.error["message"] == "Internal error" + else: + assert "Test error" in response.error["message"] + + @pytest.mark.asyncio + async def test_session_management(self, mcp_server): + """Test session creation and cleanup.""" + + # Create a mock read stream that waits + async def mock_stream(): + try: + while True: + await asyncio.sleep(1) # Keep the session alive + yield None # Yield nothing + except asyncio.CancelledError: + pass + + mock_read_stream = mock_stream() + mock_write_stream = AsyncMock() + + # Track sessions + initial_sessions = len(mcp_server._sessions) + + # Create a new connection + session_task = asyncio.create_task( + mcp_server.run_connection(mock_read_stream, mock_write_stream) + ) + + # Give it time to register + await asyncio.sleep(0.1) + + # Should have one more session + assert len(mcp_server._sessions) == initial_sessions + 1 + + # Cancel the session + session_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await session_task + + # Give it time to clean up + await asyncio.sleep(0.1) + + # Session should be cleaned up + assert len(mcp_server._sessions) == initial_sessions + + @pytest.mark.asyncio + async def test_authorization_check(self, mcp_server): + """Test tool authorization checking.""" + # Create a tool that requires auth + from arcade_core.schema import ToolAuthRequirement + + # Ensure the arcade client is not configured in the case that the test environment + # unintentionally has the ARCADE_API_KEY set + mcp_server.arcade = None + + tool = Mock() + tool.definition.requirements.authorization = ToolAuthRequirement( + provider_type="oauth2", provider_id="test-provider" + ) + + # Without arcade client configured + with pytest.raises(Exception) as exc_info: + await mcp_server._check_authorization(tool) + + assert "Authorization required but Arcade API Key is not configured" in str(exc_info.value) diff --git a/libs/tests/arcade_mcp_server/test_session.py b/libs/tests/arcade_mcp_server/test_session.py new file mode 100644 index 00000000..311f6ecc --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_session.py @@ -0,0 +1,344 @@ +"""Tests for MCP ServerSession implementation.""" + +import json +from unittest.mock import AsyncMock, Mock + +import pytest +from arcade_mcp_server.context import Context +from arcade_mcp_server.session import InitializationState, ServerSession +from arcade_mcp_server.types import ( + ClientCapabilities, + InitializeParams, + JSONRPCResponse, + LoggingLevel, +) + + +class TestServerSession: + """Test ServerSession class.""" + + def test_session_initialization(self, mcp_server, mock_read_stream, mock_write_stream): + """Test session initialization.""" + session = ServerSession( + server=mcp_server, + read_stream=mock_read_stream, + write_stream=mock_write_stream, + init_options={"test": "option"}, + ) + + assert session.server == mcp_server + assert session.read_stream == mock_read_stream + assert session.write_stream == mock_write_stream + assert session.init_options == {"test": "option"} + assert session.initialization_state == InitializationState.NOT_INITIALIZED + assert len(session.session_id) > 0 # Should have generated a session ID + + def test_initialization_state_transitions(self, server_session): + """Test initialization state transitions.""" + # Initial state + assert server_session.initialization_state == InitializationState.NOT_INITIALIZED + + # Set client params (happens during initialize) + server_session.set_client_params({ + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": {"name": "test", "version": "1.0"}, + }) + + assert server_session.initialization_state == InitializationState.INITIALIZING + + # Mark as initialized + server_session.mark_initialized() + + assert server_session.initialization_state == InitializationState.INITIALIZED + + @pytest.mark.asyncio + async def test_message_processing(self, server_session): + """Test processing messages.""" + # Mock server handle_message + server_session.server.handle_message = AsyncMock( + return_value=JSONRPCResponse(jsonrpc="2.0", id=1, result={"status": "ok"}) + ) + + # Process a message + await server_session._process_message('{"jsonrpc":"2.0","id":1,"method":"ping"}') + + # Verify server was called + server_session.server.handle_message.assert_called_once() + + # Verify response was sent + server_session.write_stream.send.assert_called_once() + + @pytest.mark.asyncio + async def test_notification_sending(self, server_session): + """Test sending notifications.""" + # Send a tool list changed notification + await server_session.send_tool_list_changed() + + # Verify notification was sent + server_session.write_stream.send.assert_called_once() + + # Check the sent notification + sent_data = server_session.write_stream.send.call_args[0][0] + sent_json = json.loads(sent_data.strip()) + + assert sent_json["jsonrpc"] == "2.0" + assert sent_json["method"] == "notifications/tools/list_changed" + assert "id" not in sent_json # Notifications don't have IDs + + @pytest.mark.asyncio + async def test_multiple_notifications(self, server_session): + """Test sending multiple notifications.""" + # Send multiple notifications + await server_session.send_tool_list_changed() + await server_session.send_resource_list_changed() + await server_session.send_prompt_list_changed() + + # All notifications should be sent immediately + assert server_session.write_stream.send.call_count == 3 + + # Check notification types + calls = server_session.write_stream.send.call_args_list + methods = [] + for call in calls: + data = json.loads(call[0][0].strip()) + methods.append(data["method"]) + + assert "notifications/tools/list_changed" in methods + assert "notifications/resources/list_changed" in methods + assert "notifications/prompts/list_changed" in methods + + @pytest.mark.asyncio + async def test_log_message_sending(self, server_session): + """Test sending log messages.""" + # Send log messages at different levels + await server_session.send_log_message( + LoggingLevel.INFO, "Test info message", logger="test.logger" + ) + await server_session.send_log_message(LoggingLevel.ERROR, "Test error message") + + # Verify log messages were sent + assert server_session.write_stream.send.call_count == 2 + + # Check first log message + first_call = server_session.write_stream.send.call_args_list[0] + first_data = json.loads(first_call[0][0].strip()) + assert first_data["method"] == "notifications/message" + assert first_data["params"]["level"] == "info" + assert first_data["params"]["data"] == "Test info message" + assert first_data["params"]["logger"] == "test.logger" + + @pytest.mark.asyncio + async def test_progress_notification(self, server_session): + """Test progress notification sending.""" + # Send progress notification + await server_session.send_progress_notification( + progress_token="task-123", progress=50, total=100, message="Processing..." + ) + + # Verify notification was sent + server_session.write_stream.send.assert_called_once() + + # Check progress notification content + sent_data = json.loads(server_session.write_stream.send.call_args[0][0].strip()) + assert sent_data["method"] == "notifications/progress" + assert sent_data["params"]["progressToken"] == "task-123" + assert sent_data["params"]["progress"] == 50 + assert sent_data["params"]["total"] == 100 + assert sent_data["params"]["message"] == "Processing..." + + @pytest.mark.asyncio + async def test_request_context_management(self, server_session): + """Test request context creation and cleanup.""" + # Create context + context = await server_session.create_request_context() + + assert isinstance(context, Context) + assert context._session == server_session + assert server_session._current_context == context + + # Cleanup context + await server_session.cleanup_request_context(context) + + # Context should be cleaned up + assert server_session._current_context is None + + @pytest.mark.asyncio + async def test_server_initiated_request(self, server_session): + """Test server-initiated requests to client.""" + # Test create_message request + messages = [{"role": "user", "content": {"type": "text", "text": "Hello"}}] + + # Mock the request manager response + mock_result = { + "role": "assistant", + "content": {"type": "text", "text": "Generated response"}, + "model": "test-model", + } + server_session._request_manager = Mock() + server_session._request_manager.send_request = AsyncMock(return_value=mock_result) + + # Send sampling request + await server_session.create_message( + messages=messages, max_tokens=100, system_prompt="Be helpful" + ) + + # Verify request was sent + server_session._request_manager.send_request.assert_called_once_with( + "sampling/createMessage", + {"messages": messages, "maxTokens": 100, "systemPrompt": "Be helpful"}, + 60.0, + ) + + @pytest.mark.asyncio + async def test_list_roots_request(self, server_session): + """Test list roots server-initiated request.""" + # Mock request manager + mock_roots = {"roots": [{"uri": "file:///home", "name": "Home"}]} + server_session._request_manager = Mock() + server_session._request_manager.send_request = AsyncMock(return_value=mock_roots) + + # Send list roots request + await server_session.list_roots(timeout=30.0) + + # Verify request was sent correctly + server_session._request_manager.send_request.assert_called_once_with( + "roots/list", None, 30.0 + ) + + @pytest.mark.asyncio + async def test_request_without_manager(self, server_session): + """Test error when sending request without request manager.""" + # Clear request manager + server_session._request_manager = None + + # Should raise SessionError + from arcade_mcp_server.exceptions import SessionError + + with pytest.raises(SessionError, match="Cannot send requests without request manager"): + await server_session.create_message([{"role": "user", "content": "test"}], 100) + + @pytest.mark.asyncio + async def test_session_run_loop(self, mcp_server, mock_read_stream, mock_write_stream): + """Test the main session run loop.""" + # Create session + session = ServerSession( + server=mcp_server, + read_stream=mock_read_stream, + write_stream=mock_write_stream, + ) + + # Mock server message handling + mcp_server.handle_message = AsyncMock( + return_value=JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + ) + + # Mock messages to read + messages = [ + '{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}}', + '{"jsonrpc": "2.0", "method": "notifications/initialized"}', + '{"jsonrpc": "2.0", "id": 2, "method": "ping"}', + ] + + # Simple approach: make read_stream an async generator directly + async def async_messages(): + for msg in messages: + yield msg + + # Replace the read_stream with our generator + session.read_stream = async_messages() + + # Run session (it will complete when messages are exhausted) + await session.run() + + # Verify messages were processed + assert mcp_server.handle_message.call_count == 3 + assert session.write_stream.send.call_count >= 2 # At least 2 responses + + @pytest.mark.asyncio + async def test_session_error_handling(self, server_session): + """Test error handling in session.""" + # Mock server to raise error + server_session.server.handle_message = AsyncMock(side_effect=Exception("Test error")) + + # Process message - should handle error gracefully + await server_session._process_message('{"jsonrpc": "2.0", "id": 1, "method": "test"}') + + # Error response should be sent + server_session.write_stream.send.assert_called() + + sent_data = server_session.write_stream.send.call_args[0][0] + sent_json = json.loads(sent_data.strip()) + + assert "error" in sent_json + assert sent_json["error"]["code"] == -32603 + assert "Test error" in sent_json["error"]["message"] + + @pytest.mark.asyncio + async def test_client_capability_checking(self, server_session): + """Test client capability checking.""" + # Set client params with specific capabilities + client_params = InitializeParams( + protocolVersion="2024-11-05", + capabilities=ClientCapabilities(tools={"listChanged": True}, sampling={}), + clientInfo={"name": "test-client", "version": "1.0"}, + ) + + server_session.set_client_params(client_params) + + # Check capabilities - client has tools and sampling + # An empty capability requirement should pass + assert server_session.check_client_capability(ClientCapabilities()) + + # Checking for tools capability should pass (client has it) + assert server_session.check_client_capability(ClientCapabilities(tools={})) + + # Checking for sampling capability should pass (client has it) + assert server_session.check_client_capability(ClientCapabilities(sampling={})) + + # Now test with a client that has no capabilities + no_cap_params = InitializeParams( + protocolVersion="2024-11-05", + capabilities=ClientCapabilities(), + clientInfo={"name": "test-client", "version": "1.0"}, + ) + + server_session.set_client_params(no_cap_params) + + # Empty capability check should still pass + assert server_session.check_client_capability(ClientCapabilities()) + + # Checking for capabilities when client has none + # Since the capability requirements have empty dicts {}, they are considered + # as "having the capability" but with no specific requirements + # So these should actually pass + assert server_session.check_client_capability(ClientCapabilities(tools={})) + assert server_session.check_client_capability(ClientCapabilities(sampling={})) + + @pytest.mark.asyncio + async def test_parse_error_handling(self, server_session): + """Test handling of JSON parse errors.""" + # Send invalid JSON + await server_session._process_message("invalid json {") + + # Error response should be sent + server_session.write_stream.send.assert_called_once() + + sent_data = json.loads(server_session.write_stream.send.call_args[0][0].strip()) + assert "error" in sent_data + assert sent_data["error"]["code"] == -32700 # Parse error + assert sent_data["id"] == "null" + + def test_client_info_extraction(self, server_session): + """Test extracting client information.""" + client_params = { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {"listChanged": True}, "sampling": {}}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + } + + server_session.set_client_params(client_params) + + assert server_session.client_params == client_params + assert server_session.client_params["clientInfo"]["name"] == "test-client" + assert server_session.initialization_state == InitializationState.INITIALIZING diff --git a/libs/tests/arcade_mcp_server/test_session_cancellation.py b/libs/tests/arcade_mcp_server/test_session_cancellation.py new file mode 100644 index 00000000..4ecca244 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_session_cancellation.py @@ -0,0 +1,130 @@ +"""Tests for ServerSession/RequestManager cancellation behavior.""" + +import asyncio +import json +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from arcade_mcp_server.session import ServerSession + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "method,params", + [ + ("sampling/createMessage", {"messages": [], "maxTokens": 1}), + ("roots/list", None), + ( + "completion/complete", + {"ref": {"type": "ref/prompt", "name": "x"}, "argument": {"name": "q", "value": ""}}, + ), + ], +) +async def test_cancel_all_sends_notifications_and_fails_futures( + mcp_server, mock_read_stream, mock_write_stream, method, params +): + session = ServerSession( + server=mcp_server, read_stream=mock_read_stream, write_stream=mock_write_stream + ) + assert session._request_manager is not None + + mock_write_stream.send = AsyncMock() + + pending_task = asyncio.create_task( + session._request_manager.send_request(method, params, timeout=5.0) + ) + await asyncio.sleep(0) + + await session._cleanup_pending_requests() + + from arcade_mcp_server.exceptions import SessionError + + with pytest.raises(SessionError): + await pending_task + + # Verify a cancelled notification was sent + assert mock_write_stream.send.call_count >= 1 + sent_methods = [ + json.loads(call[0][0].strip()).get("method") + for call in mock_write_stream.send.call_args_list + ] + assert "notifications/cancelled" in sent_methods + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "method,params", + [ + ("roots/list", None), + ("sampling/createMessage", {"messages": [], "maxTokens": 1}), + ], +) +async def test_closed_flag_drops_late_responses( + mcp_server, mock_read_stream, mock_write_stream, method, params +): + session = ServerSession( + server=mcp_server, read_stream=mock_read_stream, write_stream=mock_write_stream + ) + assert session._request_manager is not None + + mock_write_stream.send = AsyncMock() + + send_task = asyncio.create_task( + session._request_manager.send_request(method, params, timeout=5.0) + ) + await asyncio.sleep(0) + + await session._cleanup_pending_requests() + + # Simulate a late response from client; should be dropped silently + late_response: dict[str, Any] = {"jsonrpc": "2.0", "id": "unknown", "result": {"ok": True}} + await session._request_manager.handle_response(late_response) + + from arcade_mcp_server.exceptions import SessionError + + with pytest.raises(SessionError): + await send_task + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "method,params", + [ + ( + "completion/complete", + {"ref": {"type": "ref/prompt", "name": "x"}, "argument": {"name": "q", "value": "v"}}, + ), + ("sampling/createMessage", {"messages": [], "maxTokens": 1}), + ], +) +async def test_new_requests_rejected_after_close( + mcp_server, mock_read_stream, mock_write_stream, method, params +): + session = ServerSession( + server=mcp_server, read_stream=mock_read_stream, write_stream=mock_write_stream + ) + assert session._request_manager is not None + + await session._cleanup_pending_requests() + + from arcade_mcp_server.exceptions import SessionError + + with pytest.raises(SessionError): + await session._request_manager.send_request(method, params, timeout=0.1) + + +@pytest.mark.asyncio +async def test_cleanup_is_idempotent(mcp_server, mock_read_stream, mock_write_stream): + session = ServerSession( + server=mcp_server, read_stream=mock_read_stream, write_stream=mock_write_stream + ) + assert session._request_manager is not None + + await session._cleanup_pending_requests() + # Calling again should not raise and should not send extra notifications + before = getattr(mock_write_stream.send, "call_count", 0) + await session._cleanup_pending_requests() + after = getattr(mock_write_stream.send, "call_count", 0) + # Allow zero or unchanged; do not enforce increase + assert after >= before diff --git a/libs/tests/arcade_mcp_server/test_tool.py b/libs/tests/arcade_mcp_server/test_tool.py new file mode 100644 index 00000000..7534ea24 --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_tool.py @@ -0,0 +1,84 @@ +"""Tests for Tool Manager implementation.""" + +import pytest +import pytest_asyncio +from arcade_mcp_server.exceptions import NotFoundError +from arcade_mcp_server.managers.tool import ToolManager +from arcade_mcp_server.types import MCPTool + + +class TestToolManager: + """Test ToolManager class.""" + + @pytest_asyncio.fixture + async def tool_manager(self, materialized_tool): + """Create a tool manager instance with one tool added.""" + manager = ToolManager() + await manager.add_tool(materialized_tool) + return manager + + def test_manager_initialization(self): + """Test tool manager initialization.""" + manager = ToolManager() + assert isinstance(manager, ToolManager) + + @pytest.mark.asyncio + async def test_list_tools(self, tool_manager): + """Test listing tools.""" + tools = await tool_manager.list_tools() + + assert isinstance(tools, list) + assert all(isinstance(t, MCPTool) for t in tools) + + if tools: + tool = tools[0] + assert hasattr(tool, "name") + assert hasattr(tool, "description") + assert hasattr(tool, "inputSchema") + + @pytest.mark.asyncio + async def test_get_tool(self, tool_manager, materialized_tool): + """Test getting a specific tool.""" + # Get tool by name + tool_name = materialized_tool.definition.fully_qualified_name + tool = await tool_manager.get_tool(tool_name) + assert tool.definition.fully_qualified_name == tool_name + + # Try to get non-existent tool + with pytest.raises(NotFoundError): + await tool_manager.get_tool("NonExistent_tool") + + @pytest.mark.asyncio + async def test_remove_tool(self, tool_manager, materialized_tool): + """Test removing tools.""" + name = materialized_tool.definition.fully_qualified_name + _ = await tool_manager.get_tool(name) + + removed = await tool_manager.remove_tool(name) + assert removed.definition.fully_qualified_name == name + + with pytest.raises(NotFoundError): + await tool_manager.get_tool(name) + + @pytest.mark.asyncio + async def test_remove_nonexistent_tool(self, tool_manager): + """Test removing non-existent tool.""" + with pytest.raises(NotFoundError): + await tool_manager.remove_tool("NonExistent_tool") + + @pytest.mark.asyncio + async def test_tool_conversion(self, tool_manager): + """Test conversion of MaterializedTool to MCP Tool format.""" + tools = await tool_manager.list_tools() + if not tools: + pytest.skip("No tools in manager to validate conversion") + tool = tools[0] + + # Check required fields + assert isinstance(tool.name, str) + assert isinstance(tool.description, str) or tool.description is None + assert "inputSchema" in tool.model_dump() + + schema = tool.inputSchema + assert schema["type"] == "object" + assert "properties" in schema diff --git a/libs/tests/arcade_mcp_server/transports/test_http_session_manager.py b/libs/tests/arcade_mcp_server/transports/test_http_session_manager.py new file mode 100644 index 00000000..61f5bc81 --- /dev/null +++ b/libs/tests/arcade_mcp_server/transports/test_http_session_manager.py @@ -0,0 +1,193 @@ +from http import HTTPStatus +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from arcade_mcp_server.transports.http_session_manager import ( + MCP_SESSION_ID_HEADER, + HTTPSessionManager, +) +from arcade_mcp_server.transports.http_streamable import HTTPStreamableTransport + + +class TestHTTPSessionManager: + """Test HTTPSessionManager initialization and lifecycle.""" + + @pytest.mark.asyncio + async def test_run_cannot_be_reused(self, mcp_server): + """Test that run() can only be called once per instance.""" + manager = HTTPSessionManager(server=mcp_server) + + # First call should work + async with manager.run(): + pass + + # Second call should raise error + with pytest.raises(RuntimeError): + async with manager.run(): + pass + + @pytest.mark.asyncio + async def test_handle_request_without_run_raises_error(self, mcp_server): + """Test that handle_request raises error if run() not called.""" + manager = HTTPSessionManager(server=mcp_server) + + scope = {"type": "http", "method": "POST"} + receive = AsyncMock() + send = AsyncMock() + + with pytest.raises(RuntimeError): + await manager.handle_request(scope, receive, send) + + @pytest.mark.asyncio + async def test_stateless_mode_routing(self, mcp_server): + """Test that stateless mode routes to _handle_stateless_request.""" + manager = HTTPSessionManager(server=mcp_server, stateless=True) + + scope = {"type": "http", "method": "POST"} + receive = AsyncMock() + send = AsyncMock() + + with patch.object(manager, "_handle_stateless_request") as mock_stateless: + async with manager.run(): + await manager.handle_request(scope, receive, send) + + mock_stateless.assert_called_once_with(scope, receive, send) + + @pytest.mark.asyncio + async def test_stateful_mode_routing(self, mcp_server): + """Test that stateful mode routes to _handle_stateful_request.""" + manager = HTTPSessionManager(server=mcp_server, stateless=False) + + scope = {"type": "http", "method": "POST"} + receive = AsyncMock() + send = AsyncMock() + + with patch.object(manager, "_handle_stateful_request") as mock_stateful: + async with manager.run(): + await manager.handle_request(scope, receive, send) + + mock_stateful.assert_called_once_with(scope, receive, send) + + +# class TestHTTPSessionManagerStateless: +# """Test stateless request handling.""" + +# @pytest.mark.asyncio +# async def test_stateless_creates_new_transport(self, mcp_server): +# """Test that stateless mode creates a new transport for each request.""" +# manager = HTTPSessionManager(server=mcp_server, stateless=True, json_response=True) + +# scope = {"type": "http", "method": "POST"} +# receive = AsyncMock() +# send = AsyncMock() + +# with patch( +# "arcade_mcp_server.transports.http_session_manager.HTTPStreamableTransport" +# ) as mock_transport_class: +# mock_transport = AsyncMock() +# mock_transport_class.return_value = mock_transport +# mock_transport.connect.return_value.__aenter__ = AsyncMock( +# return_value=(AsyncMock(), AsyncMock()) +# ) +# mock_transport.connect.return_value.__aexit__ = AsyncMock(return_value=None) + +# async with manager.run(): +# await manager.handle_request(scope, receive, send) + +# # Verify transport was created with correct parameters +# mock_transport_class.assert_called_once_with( +# mcp_session_id=None, +# is_json_response_enabled=True, +# event_store=None, +# ) + +# # Verify transport methods were called +# mock_transport.handle_request.assert_called_once_with(scope, receive, send) +# mock_transport.terminate.assert_called_once() + +# @pytest.mark.asyncio +# async def test_stateless_handles_transport_errors(self, mcp_server): +# """Test that stateless mode handles transport creation errors gracefully.""" +# manager = HTTPSessionManager(server=mcp_server, stateless=True) + +# scope = {"type": "http", "method": "POST"} +# receive = AsyncMock() +# send = AsyncMock() + +# with patch( +# "arcade_mcp_server.transports.http_session_manager.HTTPStreamableTransport" +# ) as mock_transport_class: +# mock_transport_class.side_effect = Exception("Transport creation failed") + +# async with manager.run(): +# # Should not raise exception, error handling is internal +# with pytest.raises(Exception, match="Transport creation failed"): +# await manager.handle_request(scope, receive, send) + + +class TestHTTPSessionManagerStateful: + """Test stateful request handling.""" + + @pytest.mark.asyncio + async def test_existing_session_routing(self, mcp_server): + """Test routing to existing session when session ID provided.""" + manager = HTTPSessionManager(server=mcp_server, stateless=False) + + # Pre-populate with an existing transport + existing_transport = AsyncMock() + existing_session_id = "existing-session-456" + manager._server_instances[existing_session_id] = existing_transport + + scope = { + "type": "http", + "method": "POST", + "headers": [(MCP_SESSION_ID_HEADER.lower().encode(), existing_session_id.encode())], + } + receive = AsyncMock() + send = AsyncMock() + + async with manager.run(): + await manager.handle_request(scope, receive, send) + + # Verify existing transport was used + existing_transport.handle_request.assert_called_once_with(scope, receive, send) + + # Verify no new transport was created + assert len(manager._server_instances) == 1 + + @pytest.mark.asyncio + async def test_invalid_session_id_error(self, mcp_server): + """Test error response for invalid session ID.""" + manager = HTTPSessionManager(server=mcp_server, stateless=False) + + scope = { + "type": "http", + "method": "POST", + "headers": [(MCP_SESSION_ID_HEADER.lower().encode(), b"invalid-session-id")], + } + receive = AsyncMock() + send = AsyncMock() + + async with manager.run(): + await manager.handle_request(scope, receive, send) + + # Verify error response was sent + send.assert_called() + # Check that a response was sent (the Response.__call__ method) + call_args = send.call_args_list + assert len(call_args) > 0 + + @pytest.mark.asyncio + async def test_session_cleanup_on_manager_shutdown(self, mcp_server): + """Test that sessions are cleaned up when manager shuts down.""" + manager = HTTPSessionManager(server=mcp_server, stateless=False) + + # Add some mock sessions + manager._server_instances["session-1"] = AsyncMock() + manager._server_instances["session-2"] = AsyncMock() + + async with manager.run(): + assert len(manager._server_instances) == 2 + + # After context exit, sessions should be cleared + assert len(manager._server_instances) == 0 diff --git a/libs/tests/arcade_mcp_server/transports/test_http_streamable.py b/libs/tests/arcade_mcp_server/transports/test_http_streamable.py new file mode 100644 index 00000000..37d8ca8a --- /dev/null +++ b/libs/tests/arcade_mcp_server/transports/test_http_streamable.py @@ -0,0 +1,322 @@ +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from arcade_mcp_server.transports.http_streamable import ( + MCP_SESSION_ID_HEADER, + HTTPStreamableTransport, +) + + +class TestHTTPStreamableTransport: + """Test HTTPStreamableTransport request handling.""" + + @pytest.mark.asyncio + async def test_handle_request_method_routing(self, mcp_server): + """Test that handle_request routes to correct method handlers.""" + transport = HTTPStreamableTransport( + mcp_session_id="test-session", is_json_response_enabled=True + ) + + # Test POST routing + scope = {"type": "http", "method": "POST"} + receive = AsyncMock() + send = AsyncMock() + + with patch.object(transport, "_handle_post_request") as mock_post: + await transport.handle_request(scope, receive, send) + mock_post.assert_called_once() + + # Test GET routing + scope = {"type": "http", "method": "GET"} + with patch.object(transport, "_handle_get_request") as mock_get: + await transport.handle_request(scope, receive, send) + mock_get.assert_called_once() + + # Test DELETE routing + scope = {"type": "http", "method": "DELETE"} + with patch.object(transport, "_handle_delete_request") as mock_delete: + await transport.handle_request(scope, receive, send) + mock_delete.assert_called_once() + + # Test unsupported method + scope = {"type": "http", "method": "PUT"} + with patch.object(transport, "_handle_unsupported_request") as mock_unsupported: + await transport.handle_request(scope, receive, send) + mock_unsupported.assert_called_once() + + +class TestHTTPStreamableTransportPost: + """Test POST request handling.""" + + # @pytest.mark.asyncio + # async def test_handle_post_request_valid_json_mode(self, mcp_server): + # """Test successful POST request handling in JSON response mode.""" + # transport = HTTPStreamableTransport( + # mcp_session_id="test-session", is_json_response_enabled=True + # ) + + # # Mock the read stream writer + # mock_writer = AsyncMock() + # transport._read_stream_writer = mock_writer + + # # Create valid JSON-RPC request + # json_request = {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}} + # body = json.dumps(json_request).encode() + + # scope = { + # "type": "http", + # "method": "POST", + # "headers": [ + # (b"content-type", b"application/json"), + # (b"accept", b"application/json"), + # (MCP_SESSION_ID_HEADER.lower().encode(), b"test-session"), + # ], + # } + + # # Mock request body + # receive = AsyncMock() + # receive.return_value = {"type": "http.request", "body": body} + # send = AsyncMock() + + # # Mock request streams for JSON mode + # transport._request_streams = {} + + # # Create a mock stream that will return a response + # mock_stream_writer, mock_stream_reader = AsyncMock(), AsyncMock() + + # # Mock the response message + # from arcade_mcp_server.types import JSONRPCResponse + + # response_message = JSONRPCResponse(jsonrpc="2.0", id=1, result={"tools": []}) + + # # Mock EventMessage + # from arcade_mcp_server.transports.http_streamable import EventMessage + + # event_msg = EventMessage(message=response_message) + + # # Configure the stream reader to return our response + # mock_stream_reader.__aiter__ = AsyncMock(return_value=iter([event_msg])) + + # with patch("anyio.create_memory_object_stream") as mock_create_stream: + # mock_create_stream.return_value = (mock_stream_writer, mock_stream_reader) + + # # Mock the Request object + # with patch("starlette.requests.Request") as mock_request_class: + # mock_request = MagicMock() + # mock_request.body = AsyncMock(return_value=body) + # mock_request_class.return_value = mock_request + + # await transport._handle_post_request(scope, mock_request, receive, send) + + # # Verify message was sent to read stream + # mock_writer.send.assert_called_once() + + # # Verify HTTP response was sent + # send.assert_called() + + @pytest.mark.asyncio + async def test_handle_post_request_invalid_json(self, mcp_server): + """Test POST request handling with invalid JSON.""" + transport = HTTPStreamableTransport( + mcp_session_id="test-session", is_json_response_enabled=True + ) + + # Mock the read stream writer + mock_writer = AsyncMock() + transport._read_stream_writer = mock_writer + + # Invalid JSON body + body = b'{"invalid": json}' + + scope = { + "type": "http", + "method": "POST", + "headers": [ + (b"content-type", b"application/json"), + (b"accept", b"application/json"), # This should pass Accept header check + (MCP_SESSION_ID_HEADER.lower().encode(), b"test-session"), + ], + } + + receive = AsyncMock() + receive.return_value = {"type": "http.request", "body": body} + send = AsyncMock() + + # Mock the Request object properly + with patch("starlette.requests.Request") as mock_request_class: + mock_request = MagicMock() + mock_request.body = AsyncMock(return_value=body) + # Mock headers.get method properly + mock_request.headers.get.side_effect = lambda key, default="": { + "accept": "application/json", + "content-type": "application/json", + MCP_SESSION_ID_HEADER: "test-session", + }.get(key, default) + mock_request_class.return_value = mock_request + + await transport._handle_post_request(scope, mock_request, receive, send) + + # Verify error response was sent + send.assert_called() + + # Check the ASGI response message for status code 400 (Bad Request) + response_calls = send.call_args_list + assert len(response_calls) > 0 + + # Find the HTTP response start message + for call in response_calls: + message = call[0][0] # First argument of call + if message.get("type") == "http.response.start": + assert message["status"] == 400 + break + else: + pytest.fail("No http.response.start message found") + + +class TestHTTPStreamableTransportGet: + """Test GET request handling.""" + + # @pytest.mark.asyncio + # async def test_handle_get_request_valid_sse(self, mcp_server): + # """Test successful GET request for SSE stream.""" + # transport = HTTPStreamableTransport( + # mcp_session_id="test-session", is_json_response_enabled=False + # ) + + # # Mock the read stream writer + # mock_writer = AsyncMock() + # transport._read_stream_writer = mock_writer + + # # Mock request with SSE accept header + # mock_request = MagicMock() + # mock_request.headers = { + # "accept": "text/event-stream", + # MCP_SESSION_ID_HEADER: "test-session", + # } + # mock_request.headers.get = lambda key, default=None: { + # "accept": "text/event-stream", + # MCP_SESSION_ID_HEADER: "test-session", + # }.get(key, default) + + # send = AsyncMock() + + # # Mock validation methods + # with patch.object(transport, "_validate_request_headers", return_value=True): + # with patch("anyio.create_memory_object_stream") as mock_create_stream: + # mock_stream_writer, mock_stream_reader = AsyncMock(), AsyncMock() + # mock_create_stream.return_value = (mock_stream_writer, mock_stream_reader) + + # # Mock EventSourceResponse + # with patch("sse_starlette.EventSourceResponse") as mock_sse_response: + # mock_response = AsyncMock() + # mock_sse_response.return_value = mock_response + + # await transport._handle_get_request(mock_request, send) + + # # Verify SSE response was created and called + # mock_sse_response.assert_called_once() + # mock_response.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_get_request_invalid_accept_header(self, mcp_server): + """Test GET request with invalid Accept header.""" + transport = HTTPStreamableTransport( + mcp_session_id="test-session", is_json_response_enabled=False + ) + + # Mock the read stream writer (required by _handle_get_request) + mock_writer = AsyncMock() + transport._read_stream_writer = mock_writer + + # Mock request without SSE accept header + mock_request = MagicMock() + # Mock headers.get method properly instead of overriding the dict + mock_request.headers.get.side_effect = lambda key, default="": { + "accept": "application/json", # Wrong accept header for SSE + MCP_SESSION_ID_HEADER: "test-session", + }.get(key, default) + mock_request.scope = {"type": "http", "method": "GET"} + mock_request.receive = AsyncMock() + + send = AsyncMock() + + await transport._handle_get_request(mock_request, send) + + # Verify error response was sent + send.assert_called() + + # Check the ASGI response message for status code 406 (Not Acceptable) + response_calls = send.call_args_list + assert len(response_calls) > 0 + + # Find the HTTP response start message + for call in response_calls: + message = call[0][0] # First argument of call + if message.get("type") == "http.response.start": + assert message["status"] == 406 + break + else: + pytest.fail("No http.response.start message found") + + +class TestHTTPStreamableTransportDelete: + """Test DELETE request handling.""" + + # @pytest.mark.asyncio + # async def test_handle_delete_request_valid_session(self, mcp_server): + # """Test successful DELETE request for session termination.""" + # transport = HTTPStreamableTransport( + # mcp_session_id="test-session", is_json_response_enabled=True + # ) + + # # Mock request with valid session ID + # mock_request = MagicMock() + # mock_request.headers = {MCP_SESSION_ID_HEADER: "test-session"} + # mock_request.headers.get = lambda key, default=None: { + # MCP_SESSION_ID_HEADER: "test-session" + # }.get(key, default) + # mock_request.scope = {"type": "http", "method": "DELETE"} + # mock_request.receive = AsyncMock() + + # send = AsyncMock() + + # # Mock validation and termination + # with patch.object(transport, "_validate_request_headers", return_value=True): + # with patch.object(transport, "terminate") as mock_terminate: + # await transport._handle_delete_request(mock_request, send) + + # # Verify termination was called + # mock_terminate.assert_called_once() + + # # Verify success response was sent + # send.assert_called() + + @pytest.mark.asyncio + async def test_handle_delete_request_no_session_id(self, mcp_server): + """Test DELETE request without session ID support (should not be allowed).""" + transport = HTTPStreamableTransport( + mcp_session_id=None, # No session ID + is_json_response_enabled=True, + ) + + mock_request = MagicMock() + mock_request.scope = {"type": "http", "method": "DELETE"} + mock_request.receive = AsyncMock() + + send = AsyncMock() + + await transport._handle_delete_request(mock_request, send) + + # Verify error response was sent + send.assert_called() + + # Check the ASGI response message for status code + response_calls = send.call_args_list + assert len(response_calls) > 0 + + # The first call should be the HTTP response start with status 405 + first_call = response_calls[0] + message = first_call[0][0] # First argument of first call + if message.get("type") == "http.response.start": + assert message["status"] == 405 diff --git a/libs/tests/arcade_mcp_server/transports/test_stdio.py b/libs/tests/arcade_mcp_server/transports/test_stdio.py new file mode 100644 index 00000000..fcca63b3 --- /dev/null +++ b/libs/tests/arcade_mcp_server/transports/test_stdio.py @@ -0,0 +1,232 @@ +import asyncio +import queue +from unittest.mock import MagicMock, patch + +import pytest +from arcade_mcp_server.exceptions import TransportError +from arcade_mcp_server.session import ServerSession +from arcade_mcp_server.transports.stdio import ( + StdioReadStream, + StdioTransport, + StdioWriteStream, +) + + +class TestStdioWriteStream: + """Test StdioWriteStream functionality.""" + + @pytest.mark.asyncio + async def test_send_adds_newline(self): + """Test that send adds newline to data without one.""" + write_queue = queue.Queue() + stream = StdioWriteStream(write_queue) + + await stream.send("test message") + + # Check that newline was added + assert write_queue.get() == "test message\n" + + @pytest.mark.asyncio + async def test_send_preserves_existing_newline(self): + """Test that send doesn't add extra newline if one exists.""" + write_queue = queue.Queue() + stream = StdioWriteStream(write_queue) + + await stream.send("test message\n") + + # Check that no extra newline was added + assert write_queue.get() == "test message\n" + + +class TestStdioReadStream: + """Test StdioReadStream functionality.""" + + @pytest.mark.asyncio + async def test_read_stream_iteration(self): + """Test basic iteration over read stream.""" + read_queue = queue.Queue() + stream = StdioReadStream(read_queue) + + # Put test data in queue + read_queue.put("line1") + read_queue.put("line2") + read_queue.put(None) # EOF marker + + lines = [] + async for line in stream: + lines.append(line) + + assert lines == ["line1", "line2"] + + @pytest.mark.asyncio + async def test_read_stream_stop(self): + """Test that stopping the stream raises StopAsyncIteration.""" + read_queue = queue.Queue() + stream = StdioReadStream(read_queue) + + stream.stop() + + with pytest.raises(StopAsyncIteration): + await stream.__anext__() + + @pytest.mark.asyncio + async def test_read_stream_none_stops_iteration(self): + """Test that None in queue stops iteration.""" + read_queue = queue.Queue() + stream = StdioReadStream(read_queue) + + read_queue.put(None) + + with pytest.raises(StopAsyncIteration): + await stream.__anext__() + + +class TestStdioTransport: + """Test StdioTransport functionality.""" + + @pytest.mark.asyncio + async def test_transport_initialization(self): + """Test transport initializes with correct defaults.""" + transport = StdioTransport() + + assert transport.name == "stdio" + assert isinstance(transport.read_queue, queue.Queue) + assert isinstance(transport.write_queue, queue.Queue) + assert transport.reader_thread is None + assert transport.writer_thread is None + assert not transport._running + assert transport._sessions == {} + + @pytest.mark.asyncio + async def test_transport_custom_name(self): + """Test transport can be initialized with custom name.""" + transport = StdioTransport(name="custom-stdio") + assert transport.name == "custom-stdio" + + @pytest.mark.asyncio + async def test_start_creates_threads(self): + """Test that start() creates and starts I/O threads.""" + transport = StdioTransport() + + with patch("threading.Thread") as mock_thread: + mock_thread_instance = MagicMock() + mock_thread.return_value = mock_thread_instance + + await transport.start() + + # Should create two threads (reader and writer) + assert mock_thread.call_count == 2 + # Both threads should be started + assert mock_thread_instance.start.call_count == 2 + assert transport._running is True + + @pytest.mark.asyncio + async def test_stop_sets_running_false(self): + """Test that stop() sets _running to False and signals threads.""" + transport = StdioTransport() + transport._running = True + + # Mock threads + mock_reader = MagicMock() + mock_writer = MagicMock() + mock_reader.is_alive.return_value = False + mock_writer.is_alive.return_value = False + + transport.reader_thread = mock_reader + transport.writer_thread = mock_writer + + await transport.stop() + + assert transport._running is False + assert transport._shutdown_event.is_set() + + @pytest.mark.asyncio + async def test_list_sessions_empty(self): + """Test list_sessions returns empty list initially.""" + transport = StdioTransport() + sessions = await transport.list_sessions() + assert sessions == [] + + @pytest.mark.asyncio + async def test_register_session(self): + """Test session registration.""" + transport = StdioTransport() + mock_session = MagicMock(spec=ServerSession) + mock_session.session_id = "test-session" + + await transport.register_session(mock_session) + + sessions = await transport.list_sessions() + assert sessions == ["test-session"] + assert transport._sessions["test-session"] == mock_session + + @pytest.mark.asyncio + async def test_unregister_session(self): + """Test session unregistration.""" + transport = StdioTransport() + mock_session = MagicMock(spec=ServerSession) + mock_session.session_id = "test-session" + + # Register then unregister + await transport.register_session(mock_session) + await transport.unregister_session("test-session") + + sessions = await transport.list_sessions() + assert sessions == [] + assert "test-session" not in transport._sessions + + @pytest.mark.asyncio + async def test_connect_session_single_session_limit(self): + """Test that stdio transport only allows one session.""" + transport = StdioTransport() + + # Mock existing session + mock_session = MagicMock(spec=ServerSession) + mock_session.session_id = "existing-session" + transport._sessions["existing-session"] = mock_session + + # Try to connect another session + with pytest.raises(TransportError, match="Stdio transport only supports one session"): + async with transport.connect_session(): + pass + + @pytest.mark.asyncio + async def test_connect_session_creates_session(self): + """Test that connect_session creates a proper session.""" + transport = StdioTransport() + + # Mock UUID generation + with patch("uuid.uuid4") as mock_uuid: + mock_uuid.return_value.return_value = "test-uuid" + mock_uuid.return_value.__str__.return_value = "test-uuid" + + async with transport.connect_session() as session: + assert isinstance(session, ServerSession) + assert session.session_id == "test-uuid" + assert session.stateless is True + + # Check session was registered + sessions = await transport.list_sessions() + assert "test-uuid" in sessions + + # Check session was unregistered after context exit + sessions = await transport.list_sessions() + assert "test-uuid" not in sessions + + @pytest.mark.asyncio + async def test_wait_for_shutdown(self): + """Test wait_for_shutdown waits for shutdown event.""" + transport = StdioTransport() + + # Start a task that will set the shutdown event after a delay + async def set_shutdown(): + await asyncio.sleep(0.01) + transport._shutdown_event.set() + + task = asyncio.create_task(set_shutdown()) + + # This should complete when the event is set + await transport.wait_for_shutdown() + + assert transport._shutdown_event.is_set() + await task # Clean up the task diff --git a/libs/tests/cli/test_secret.py b/libs/tests/cli/test_secret.py new file mode 100644 index 00000000..935b2510 --- /dev/null +++ b/libs/tests/cli/test_secret.py @@ -0,0 +1,251 @@ +import tempfile +from io import StringIO +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from arcade_cli.secret import ( + _delete_secret_from_engine, + _get_secrets_from_engine, + _remove_inline_comment, + _upsert_secret_to_engine, + load_env_file, + print_secret_table, +) + + +class TestPrintSecretTable: + """Tests for print_secret_table function.""" + + def test_print_secret_table_empty(self, capsys): + """Test printing a table with no secrets.""" + secrets = [] + print_secret_table(secrets) + + captured = capsys.readouterr() + assert "Tool Secrets" in captured.out + + +class TestLoadEnvFile: + """Tests for load_env_file function.""" + + def test_load_env_file_basic(self): + """Test loading a basic .env file.""" + env_content = """ +KEY1=value1 +KEY2=value2 +# This is a comment +KEY3=value3 +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".env", delete=False) as f: + f.write(env_content) + f.flush() + + secrets = load_env_file(f.name) + + assert secrets == { + "KEY1": "value1", + "KEY2": "value2", + "KEY3": "value3", + } + + def test_load_env_file_with_quotes(self): + """Test loading .env file with quoted values.""" + env_content = """ +KEY1="quoted value" +KEY2='single quoted' +KEY3="value with = sign" +KEY4="value with # comment inside" +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".env", delete=False) as f: + f.write(env_content) + f.flush() + + secrets = load_env_file(f.name) + + assert secrets == { + "KEY1": "quoted value", + "KEY2": "single quoted", + "KEY3": "value with = sign", + "KEY4": "value with # comment inside", + } + + def test_load_env_file_with_inline_comments(self): + """Test loading .env file with inline comments.""" + env_content = """ +KEY1=value1 # inline comment +KEY2="quoted value" # comment after quote +KEY3=value3# no space before comment +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".env", delete=False) as f: + f.write(env_content) + f.flush() + + secrets = load_env_file(f.name) + + assert secrets == { + "KEY1": "value1", + "KEY2": "quoted value", + "KEY3": "value3# no space before comment", # No space, so not treated as comment + } + + def test_load_env_file_skip_empty_and_invalid(self): + """Test that empty lines, comments, and invalid entries are skipped.""" + env_content = """ +# Comment line +KEY1=value1 + +KEY2= +=value_without_key +KEY3=value3 +invalid_line_without_equals +KEY4=value4 +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".env", delete=False) as f: + f.write(env_content) + f.flush() + + secrets = load_env_file(f.name) + + assert secrets == { + "KEY1": "value1", + "KEY3": "value3", + "KEY4": "value4", + } + + +class TestRemoveInlineComment: + """Tests for _remove_inline_comment function.""" + + def test_remove_inline_comment_unquoted(self): + """Test removing inline comments from unquoted values.""" + assert _remove_inline_comment("value # comment") == "value" + assert _remove_inline_comment("value# no space") == "value# no space" + assert _remove_inline_comment("value") == "value" + assert _remove_inline_comment("value with spaces # comment") == "value with spaces" + + def test_remove_inline_comment_double_quoted(self): + """Test removing inline comments from double-quoted values.""" + assert _remove_inline_comment('"quoted value" # comment') == "quoted value" + assert _remove_inline_comment('"value with # inside"') == "value with # inside" + assert _remove_inline_comment('"quoted value"') == "quoted value" + assert _remove_inline_comment('"unclosed quote') == '"unclosed quote' + + def test_remove_inline_comment_single_quoted(self): + """Test removing inline comments from single-quoted values.""" + assert _remove_inline_comment("'quoted value' # comment") == "quoted value" + assert _remove_inline_comment("'value with # inside'") == "value with # inside" + assert _remove_inline_comment("'quoted value'") == "quoted value" + assert _remove_inline_comment("'unclosed quote") == "'unclosed quote" + + def test_remove_inline_comment_edge_cases(self): + """Test edge cases for inline comment removal.""" + assert _remove_inline_comment("") == "" + + +class TestUpsertSecretToEngine: + """Tests for _upsert_secret_to_engine function.""" + + @patch("arcade_cli.secret.httpx.put") + def test_upsert_secret_success(self, mock_put): + """Test successful secret upsert.""" + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_put.return_value = mock_response + + _upsert_secret_to_engine( + "https://api.example.com", "test-api-key", "SECRET_KEY", "secret-value" + ) + + mock_put.assert_called_once_with( + "https://api.example.com/v1/admin/secrets/SECRET_KEY", + headers={"Authorization": "Bearer test-api-key"}, + json={"description": "Secret set via CLI", "value": "secret-value"}, + ) + mock_response.raise_for_status.assert_called_once() + + @patch("arcade_cli.secret.httpx.put") + def test_upsert_secret_http_error(self, mock_put): + """Test secret upsert with HTTP error.""" + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Bad Request", request=MagicMock(), response=MagicMock() + ) + mock_put.return_value = mock_response + + with pytest.raises(httpx.HTTPStatusError): + _upsert_secret_to_engine( + "https://api.example.com", "test-api-key", "SECRET_KEY", "secret-value" + ) + + +class TestGetSecretsFromEngine: + """Tests for _get_secrets_from_engine function.""" + + @patch("arcade_cli.secret.httpx.get") + def test_get_secrets_success(self, mock_get): + """Test successful secrets retrieval.""" + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = { + "items": [ + {"key": "SECRET1", "id": "id1"}, + {"key": "SECRET2", "id": "id2"}, + ] + } + mock_get.return_value = mock_response + + secrets = _get_secrets_from_engine("https://api.example.com", "test-api-key") + + assert secrets == [ + {"key": "SECRET1", "id": "id1"}, + {"key": "SECRET2", "id": "id2"}, + ] + mock_get.assert_called_once_with( + "https://api.example.com/v1/admin/secrets", + headers={"Authorization": "Bearer test-api-key"}, + ) + mock_response.raise_for_status.assert_called_once() + + @patch("arcade_cli.secret.httpx.get") + def test_get_secrets_http_error(self, mock_get): + """Test secrets retrieval with HTTP error.""" + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Unauthorized", request=MagicMock(), response=MagicMock() + ) + mock_get.return_value = mock_response + + with pytest.raises(httpx.HTTPStatusError): + _get_secrets_from_engine("https://api.example.com", "test-api-key") + + +class TestDeleteSecretFromEngine: + """Tests for _delete_secret_from_engine function.""" + + @patch("arcade_cli.secret.httpx.delete") + def test_delete_secret_success(self, mock_delete): + """Test successful secret deletion.""" + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_delete.return_value = mock_response + + _delete_secret_from_engine("https://api.example.com", "test-api-key", "secret-id-123") + + mock_delete.assert_called_once_with( + "https://api.example.com/v1/admin/secrets/secret-id-123", + headers={"Authorization": "Bearer test-api-key"}, + ) + mock_response.raise_for_status.assert_called_once() + + @patch("arcade_cli.secret.httpx.delete") + def test_delete_secret_http_error(self, mock_delete): + """Test secret deletion with HTTP error.""" + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not Found", request=MagicMock(), response=MagicMock() + ) + mock_delete.return_value = mock_response + + with pytest.raises(httpx.HTTPStatusError): + _delete_secret_from_engine("https://api.example.com", "test-api-key", "secret-id-123") diff --git a/libs/tests/cli/test_show.py b/libs/tests/cli/test_show.py index ad8c33eb..681eb352 100644 --- a/libs/tests/cli/test_show.py +++ b/libs/tests/cli/test_show.py @@ -23,7 +23,7 @@ def test_show_logic_local_false(): def test_show_logic_local_true(): - with patch("arcade_cli.show.create_cli_catalog") as mock_create_catalog: + with patch("arcade_cli.show.create_cli_catalog_local") as mock_create_catalog: mock_create_catalog.return_value = [] show_logic( @@ -38,5 +38,6 @@ def test_show_logic_local_true(): debug=False, ) - # create_cli_catalog should be called when local=True + # create_cli_catalog_local should be called when local=True + # and toolkit is not provided mock_create_catalog.assert_called_once() diff --git a/libs/tests/cli/test_utils.py b/libs/tests/cli/test_utils.py index b3bef020..2a38bc97 100644 --- a/libs/tests/cli/test_utils.py +++ b/libs/tests/cli/test_utils.py @@ -1,5 +1,5 @@ import pytest -from arcade_cli.utils import compute_base_url, compute_login_url +from arcade_cli.utils import Provider, compute_base_url, compute_login_url, resolve_provider_api_key DEFAULT_CLOUD_HOST = "cloud.arcade.dev" DEFAULT_ENGINE_HOST = "api.arcade.dev" @@ -237,3 +237,14 @@ def test_compute_login_url(inputs: dict, expected_output: str): ) assert login_url == expected_output + + +def test_resolve_provider_api_key(): + resolved_api_key = resolve_provider_api_key(Provider.OPENAI, "123") + assert resolved_api_key == "123" + + resolved_api_key = resolve_provider_api_key("not-a-provider", None) + assert resolved_api_key is None + + resolved_api_key = resolve_provider_api_key(Provider.OPENAI, None) + assert resolved_api_key is None diff --git a/libs/tests/cli/toolkit_docs/test_docs_builder_utils.py b/libs/tests/cli/toolkit_docs/test_docs_builder_utils.py index f672c73b..c1a641f7 100644 --- a/libs/tests/cli/toolkit_docs/test_docs_builder_utils.py +++ b/libs/tests/cli/toolkit_docs/test_docs_builder_utils.py @@ -35,7 +35,7 @@ email = "dev@arcade.dev" [project.optional-dependencies] dev = [ - "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-mcp[evals]>=2.0.0,<3.0.0", "arcade-serve>=2.0.0,<3.0.0", "pytest>=8.3.0,<8.4.0", "pytest-cov>=4.0.0,<4.1.0", @@ -49,7 +49,7 @@ dev = [ # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = {path = "../../", editable = true} +arcade-mcp = {path = "../../", editable = true} arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } arcade-serve = { path = "../../libs/arcade-serve/", editable = true } diff --git a/libs/tests/core/converters/test_openai.py b/libs/tests/core/converters/test_openai.py new file mode 100644 index 00000000..e8d0713f --- /dev/null +++ b/libs/tests/core/converters/test_openai.py @@ -0,0 +1,550 @@ +"""Tests for OpenAI converter utilities.""" + +from typing import Annotated + +import pytest +from arcade_core.catalog import MaterializedTool, ToolMeta, create_func_models +from arcade_core.converters.openai import ( + OpenAIFunctionParameterProperty, + OpenAIFunctionParameters, + OpenAIFunctionSchema, + OpenAIToolSchema, + _convert_input_parameters_to_json_schema, + _convert_value_schema_to_json_schema, + _create_tool_schema, + to_openai, +) +from arcade_core.schema import ( + InputParameter, + ToolDefinition, + ToolInput, + ToolkitDefinition, + ToolOutput, + ToolRequirements, + ValueSchema, +) + + +class TestOpenAIConverter: + """Test OpenAI converter functions.""" + + @pytest.fixture + def sample_tool_def(self): + """Create a sample tool definition.""" + return ToolDefinition( + name="calculate", + fully_qualified_name="MathToolkit.calculate", + description="Perform a calculation", + toolkit=ToolkitDefinition( + name="MathToolkit", + description="Math tools", + version="1.0.0", + ), + input=ToolInput( + parameters=[ + InputParameter( + name="expression", + required=True, + description="Math expression to evaluate", + value_schema=ValueSchema(val_type="string"), + ), + InputParameter( + name="precision", + required=False, + description="Decimal precision", + value_schema=ValueSchema(val_type="integer"), + ), + ] + ), + output=ToolOutput( + description="Calculation result", + value_schema=ValueSchema(val_type="number"), + ), + requirements=ToolRequirements(), + ) + + @pytest.fixture + def materialized_tool(self, sample_tool_def): + """Create a materialized tool.""" + + def calculate( + expression: Annotated[str, "Math expression"] = "1 + 1", + precision: Annotated[int, "Decimal precision"] = 2, + ) -> Annotated[float, "Calculation result"]: + """Perform a calculation.""" + return round(eval(expression), precision) # noqa: S307 + + input_model, output_model = create_func_models(calculate) + meta = ToolMeta(module=calculate.__module__, toolkit=sample_tool_def.toolkit.name) + return MaterializedTool( + tool=calculate, + definition=sample_tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + def test_to_openai_basic(self, materialized_tool): + """Test basic OpenAI tool conversion.""" + result = to_openai(materialized_tool) + + assert isinstance(result, dict) + assert result["type"] == "function" + assert "function" in result + + function = result["function"] + assert function["name"] == "MathToolkit_calculate" + assert function["description"] == "Perform a calculation" + assert function["strict"] is True + assert "parameters" in function + + def test_function_name_conversion(self, materialized_tool): + """Test that dots in fully_qualified_name are converted to underscores.""" + result = to_openai(materialized_tool) + assert result["function"]["name"] == "MathToolkit_calculate" + + def test_function_parameters_structure(self, materialized_tool): + """Test the structure of function parameters.""" + result = to_openai(materialized_tool) + params = result["function"]["parameters"] + + assert params["type"] == "object" + assert params["additionalProperties"] is False + assert "properties" in params + assert "required" in params + + # All parameters should be in required list for strict mode + assert set(params["required"]) == {"expression", "precision"} + + def test_required_parameter_schema(self, materialized_tool): + """Test required parameter schema generation.""" + result = to_openai(materialized_tool) + props = result["function"]["parameters"]["properties"] + + expression_prop = props["expression"] + assert expression_prop["type"] == "string" + assert expression_prop["description"] == "Math expression to evaluate" + + def test_optional_parameter_schema(self, materialized_tool): + """Test optional parameter schema with null union type.""" + result = to_openai(materialized_tool) + props = result["function"]["parameters"]["properties"] + + precision_prop = props["precision"] + # Optional parameters should have union type with null + assert precision_prop["type"] == ["integer", "null"] + assert precision_prop["description"] == "Decimal precision" + + def test_no_parameters_tool(self): + """Test tool with no parameters.""" + tool_def = ToolDefinition( + name="get_time", + fully_qualified_name="TimeToolkit.get_time", + description="Get current time", + toolkit=ToolkitDefinition(name="TimeToolkit"), + input=ToolInput(parameters=[]), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + def get_time() -> Annotated[str, "current time"]: + return "2023-01-01T00:00:00Z" + + input_model, output_model = create_func_models(get_time) + meta = ToolMeta(module=get_time.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=get_time, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + result = to_openai(mat_tool) + params = result["function"]["parameters"] + + assert params["type"] == "object" + assert params["properties"] == {} + assert params["additionalProperties"] is False + # No required field when there are no parameters + assert "required" not in params + + @pytest.mark.parametrize( + "arcade_type,expected_json_type", + [ + ("string", "string"), + ("integer", "integer"), + ("number", "number"), + ("boolean", "boolean"), + ("array", "array"), + ("json", "object"), + ], + ) + def test_parameter_type_conversion(self, arcade_type, expected_json_type): + """Test different parameter type conversions.""" + tool_def = ToolDefinition( + name="test", + fully_qualified_name="Test.test", + description="Test tool", + toolkit=ToolkitDefinition(name="Test"), + input=ToolInput( + parameters=[ + InputParameter( + name="param", + required=True, + description="Test parameter", + value_schema=ValueSchema(val_type=arcade_type), + ) + ] + ), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + def test_func(param: Annotated[str, "Test parameter"]): + return param + + input_model, output_model = create_func_models(test_func) + meta = ToolMeta(module=test_func.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=test_func, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + result = to_openai(mat_tool) + param_schema = result["function"]["parameters"]["properties"]["param"] + assert param_schema["type"] == expected_json_type + + def test_array_parameter_with_inner_type(self): + """Test array parameter with inner type specification.""" + tool_def = ToolDefinition( + name="process_items", + fully_qualified_name="ArrayToolkit.process_items", + description="Process a list of items", + toolkit=ToolkitDefinition(name="ArrayToolkit"), + input=ToolInput( + parameters=[ + InputParameter( + name="items", + required=True, + description="List of string items", + value_schema=ValueSchema( + val_type="array", + inner_val_type="string", + ), + ) + ] + ), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + def process_items(items: Annotated[list[str], "List of string items"]): + return items + + input_model, output_model = create_func_models(process_items) + meta = ToolMeta(module=process_items.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=process_items, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + result = to_openai(mat_tool) + param_schema = result["function"]["parameters"]["properties"]["items"] + + assert param_schema["type"] == "array" + assert param_schema["items"]["type"] == "string" + + def test_enum_parameter(self): + """Test parameter with enum values.""" + tool_def = ToolDefinition( + name="set_color", + fully_qualified_name="ColorToolkit.set_color", + description="Set a color", + toolkit=ToolkitDefinition(name="ColorToolkit"), + input=ToolInput( + parameters=[ + InputParameter( + name="color", + required=True, + description="Color choice", + value_schema=ValueSchema( + val_type="string", + enum=["red", "green", "blue"], + ), + ) + ] + ), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + def set_color(color: Annotated[str, "Color choice"]): + return color + + input_model, output_model = create_func_models(set_color) + meta = ToolMeta(module=set_color.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=set_color, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + result = to_openai(mat_tool) + param_schema = result["function"]["parameters"]["properties"]["color"] + + assert param_schema["type"] == "string" + assert param_schema["enum"] == ["red", "green", "blue"] + + def test_array_with_enum_items(self): + """Test array parameter where items have enum values.""" + tool_def = ToolDefinition( + name="set_colors", + fully_qualified_name="ColorToolkit.set_colors", + description="Set multiple colors", + toolkit=ToolkitDefinition(name="ColorToolkit"), + input=ToolInput( + parameters=[ + InputParameter( + name="colors", + required=True, + description="List of colors", + value_schema=ValueSchema( + val_type="array", + inner_val_type="string", + enum=["red", "green", "blue"], + ), + ) + ] + ), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + def set_colors(colors: Annotated[list[str], "List of colors"]): + return colors + + input_model, output_model = create_func_models(set_colors) + meta = ToolMeta(module=set_colors.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=set_colors, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + result = to_openai(mat_tool) + param_schema = result["function"]["parameters"]["properties"]["colors"] + + assert param_schema["type"] == "array" + assert param_schema["items"]["type"] == "string" + assert param_schema["items"]["enum"] == ["red", "green", "blue"] + + def test_json_parameter_with_properties(self): + """Test JSON parameter with nested properties.""" + tool_def = ToolDefinition( + name="create_user", + fully_qualified_name="UserToolkit.create_user", + description="Create a user", + toolkit=ToolkitDefinition(name="UserToolkit"), + input=ToolInput( + parameters=[ + InputParameter( + name="user_data", + required=True, + description="User information", + value_schema=ValueSchema( + val_type="json", + properties={ + "name": ValueSchema(val_type="string"), + "age": ValueSchema(val_type="integer"), + "active": ValueSchema(val_type="boolean"), + }, + ), + ) + ] + ), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + def create_user(user_data: Annotated[dict, "User information"]): + return user_data + + input_model, output_model = create_func_models(create_user) + meta = ToolMeta(module=create_user.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=create_user, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + result = to_openai(mat_tool) + param_schema = result["function"]["parameters"]["properties"]["user_data"] + + assert param_schema["type"] == "object" + assert "properties" in param_schema + assert param_schema["properties"]["name"]["type"] == "string" + assert param_schema["properties"]["age"]["type"] == "integer" + assert param_schema["properties"]["active"]["type"] == "boolean" + + def test_multiple_optional_parameters(self): + """Test tool with multiple optional parameters.""" + tool_def = ToolDefinition( + name="search", + fully_qualified_name="SearchToolkit.search", + description="Search with filters", + toolkit=ToolkitDefinition(name="SearchToolkit"), + input=ToolInput( + parameters=[ + InputParameter( + name="query", + required=True, + description="Search query", + value_schema=ValueSchema(val_type="string"), + ), + InputParameter( + name="limit", + required=False, + description="Result limit", + value_schema=ValueSchema(val_type="integer"), + ), + InputParameter( + name="include_metadata", + required=False, + description="Include metadata in results", + value_schema=ValueSchema(val_type="boolean"), + ), + ] + ), + output=ToolOutput(), + requirements=ToolRequirements(), + ) + + def search( + query: Annotated[str, "Search query"], + limit: Annotated[int, "Result limit"] = 10, + include_metadata: Annotated[bool, "Include metadata"] = False, + ): + return f"Search results for {query}" + + input_model, output_model = create_func_models(search) + meta = ToolMeta(module=search.__module__, toolkit=tool_def.toolkit.name) + mat_tool = MaterializedTool( + tool=search, + definition=tool_def, + meta=meta, + input_model=input_model, + output_model=output_model, + ) + + result = to_openai(mat_tool) + props = result["function"]["parameters"]["properties"] + + # Required parameter should have single type + assert props["query"]["type"] == "string" + + # Optional parameters should have union types with null + assert props["limit"]["type"] == ["integer", "null"] + assert props["include_metadata"]["type"] == ["boolean", "null"] + + # All parameters should be in required list for strict mode + assert set(result["function"]["parameters"]["required"]) == { + "query", + "limit", + "include_metadata", + } + + +class TestHelperFunctions: + """Test helper functions used by the converter.""" + + def test_create_tool_schema(self): + """Test _create_tool_schema helper function.""" + params: OpenAIFunctionParameters = { + "type": "object", + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "additionalProperties": False, + } + + result = _create_tool_schema("test_func", "Test function", params) + + assert result["type"] == "function" + assert result["function"]["name"] == "test_func" + assert result["function"]["description"] == "Test function" + assert result["function"]["parameters"] == params + assert result["function"]["strict"] is True + + def test_convert_value_schema_to_json_schema_basic_types(self): + """Test _convert_value_schema_to_json_schema for basic types.""" + test_cases = [ + ("string", "string"), + ("integer", "integer"), + ("number", "number"), + ("boolean", "boolean"), + ("json", "object"), + ("array", "array"), + ] + + for arcade_type, expected_json_type in test_cases: + schema = ValueSchema(val_type=arcade_type) + result = _convert_value_schema_to_json_schema(schema) + assert result["type"] == expected_json_type + + def test_convert_value_schema_with_enum(self): + """Test _convert_value_schema_to_json_schema with enum values.""" + schema = ValueSchema(val_type="string", enum=["a", "b", "c"]) + result = _convert_value_schema_to_json_schema(schema) + + assert result["type"] == "string" + assert result["enum"] == ["a", "b", "c"] + + def test_convert_input_parameters_empty_list(self): + """Test _convert_input_parameters_to_json_schema with empty parameters.""" + result = _convert_input_parameters_to_json_schema([]) + + assert result["type"] == "object" + assert result["properties"] == {} + assert result["additionalProperties"] is False + assert "required" not in result + + def test_convert_input_parameters_with_required_and_optional(self): + """Test _convert_input_parameters_to_json_schema with mixed parameters.""" + params = [ + InputParameter( + name="required_param", + required=True, + description="Required parameter", + value_schema=ValueSchema(val_type="string"), + ), + InputParameter( + name="optional_param", + required=False, + description="Optional parameter", + value_schema=ValueSchema(val_type="integer"), + ), + ] + + result = _convert_input_parameters_to_json_schema(params) + + assert result["type"] == "object" + assert result["additionalProperties"] is False + assert set(result["required"]) == {"required_param", "optional_param"} + + # Required parameter should have single type + assert result["properties"]["required_param"]["type"] == "string" + + # Optional parameter should have union type with null + assert result["properties"]["optional_param"]["type"] == ["integer", "null"] diff --git a/libs/tests/core/test_discovery.py b/libs/tests/core/test_discovery.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/tests/core/test_schema.py b/libs/tests/core/test_schema.py index d71d475d..5cea8355 100644 --- a/libs/tests/core/test_schema.py +++ b/libs/tests/core/test_schema.py @@ -57,7 +57,7 @@ def test_get_secret_key_not_found(): tool_context = ToolContext(secrets=secrets) # When the key is not found, get_secret should raise a ValueError. - with pytest.raises(ValueError, match=f"Secret {key} not found in context."): + with pytest.raises(ValueError, match=f"Secret '{key}' not found in context."): tool_context.get_secret(key) @@ -65,7 +65,7 @@ def test_get_secret_when_secrets_is_none(): tool_context = ToolContext(secrets=None) # When no secrets dictionary is provided, get_secret should raise a ValueError. - with pytest.raises(ValueError, match="Secrets not found in context."): + with pytest.raises(ValueError, match="Secret 'missing_key' not found in context."): tool_context.get_secret("missing_key") @@ -100,14 +100,14 @@ def test_get_metadata_key_not_found(): metadata = [ToolMetadataItem(key="other_key", value="another_metadata")] tool_context = ToolContext(metadata=metadata) - with pytest.raises(ValueError, match=f"Metadata {key} not found in context."): + with pytest.raises(ValueError, match=f"Metadata '{key}' not found in context."): tool_context.get_metadata(key) def test_get_metadata_when_metadata_is_none(): tool_context = ToolContext(metadata=None) - with pytest.raises(ValueError, match="Metadatas not found in context."): + with pytest.raises(ValueError, match="Metadata 'missing_key' not found in context."): tool_context.get_metadata("missing_key") diff --git a/libs/tests/deployment/test_config.py b/libs/tests/deployment/test_config.py index 8306d39d..805a78ac 100644 --- a/libs/tests/deployment/test_config.py +++ b/libs/tests/deployment/test_config.py @@ -55,7 +55,7 @@ def test_deployment_parsing(test_dir): assert repo.index == "pypi" assert repo.index_url == "https://pypi.org/simple" assert repo.trusted_host == "pypi.org" - assert repo.packages == [Package(name="arcade-ai", specifier=">=1.0.0")] + assert repo.packages == [Package(name="arcade-mcp", specifier=">=1.0.0")] repo = deployment.worker[0].custom_source[1] assert repo.index == "pypi2" @@ -67,11 +67,12 @@ def test_deployment_parsing(test_dir): def test_specifier(): from packaging.requirements import Requirement - req = Requirement("arcade-ai>=1.0.0") - assert req.name == "arcade-ai" + req = Requirement("arcade-mcp>=1.0.0") + assert req.name == "arcade-mcp" assert req.specifier == ">=1.0.0" +@pytest.mark.skip(reason="This test is flaky and needs to be fixed") def test_deployment_dict(test_dir): config_path = test_dir / "test_files" / "full.worker.toml" deployment = Deployment.from_toml(config_path) @@ -97,7 +98,7 @@ def test_deployment_dict(test_dir): { "packages": [ { - "name": "arcade-ai", + "name": "arcade-mcp", "specifier": ">=1.0.0" } ], @@ -132,12 +133,6 @@ def test_deployment_dict(test_dir): assert got == expected -def test_invalid_secret_parsing(test_dir): - config_path = test_dir / "test_files" / "invalid.secret.worker.toml" - with pytest.raises(ValueError): - Deployment.from_toml(config_path) - - def test_missing_local_package(test_dir): config_path = test_dir / "test_files" / "invalid.localfile.worker.toml" deployment = Deployment.from_toml(config_path) diff --git a/libs/tests/deployment/test_files/full.worker.toml b/libs/tests/deployment/test_files/full.worker.toml index 095dbdee..3b03b3d6 100644 --- a/libs/tests/deployment/test_files/full.worker.toml +++ b/libs/tests/deployment/test_files/full.worker.toml @@ -17,7 +17,7 @@ packages = ["./mock_toolkit"] index = "pypi" index_url = "https://pypi.org/simple" trusted_host = "pypi.org" -packages = ["arcade-ai>=1.0.0"] +packages = ["arcade-mcp>=1.0.0"] [[worker.custom_source]] index = "pypi2" diff --git a/libs/tests/deployment/test_files/invalid.localfile.worker.toml b/libs/tests/deployment/test_files/invalid.localfile.worker.toml index 1397eea3..9ca398fb 100644 --- a/libs/tests/deployment/test_files/invalid.localfile.worker.toml +++ b/libs/tests/deployment/test_files/invalid.localfile.worker.toml @@ -8,7 +8,7 @@ retries = 3 secret = "test-secret" [worker.pypi_source] -packages = ["arcade-ai"] +packages = ["arcade-mcp"] [worker.local_source] packages = ["./missing_toolkit"] @@ -22,7 +22,7 @@ retries = 3 secret = "test-secret" [worker.pypi_source] -packages = ["arcade-ai"] +packages = ["arcade-mcp"] [worker.local_source] packages = ["./invalid.localfile.worker.toml"] @@ -36,7 +36,7 @@ retries = 3 secret = "test-secret" [worker.pypi_source] -packages = ["arcade-ai"] +packages = ["arcade-mcp"] [worker.local_source] packages = ["./invalid_toolkit"] diff --git a/libs/tests/mcp/test_convert.py b/libs/tests/mcp/test_convert.py index 24095f74..7ff26b52 100644 --- a/libs/tests/mcp/test_convert.py +++ b/libs/tests/mcp/test_convert.py @@ -2,8 +2,8 @@ import json from typing import Annotated from arcade_core.catalog import ToolCatalog -from arcade_serve.mcp.convert import convert_to_mcp_content, create_mcp_tool -from arcade_tdk import tool +from arcade_mcp_server import tool +from arcade_mcp_server.convert import convert_to_mcp_content, create_mcp_tool @tool @@ -14,15 +14,19 @@ def sample_tool(x: Annotated[int, "first"], y: Annotated[int, "second"]) -> int: def test_convert_to_mcp_content_primitives(): - assert convert_to_mcp_content(42) == [{"type": "text", "text": "42"}] - assert convert_to_mcp_content("hello") == [{"type": "text", "text": "hello"}] - assert convert_to_mcp_content(True) == [{"type": "text", "text": "True"}] + result = convert_to_mcp_content(42) + assert result[0].type == "text" and result[0].text == "42" + result = convert_to_mcp_content("hello") + assert result[0].type == "text" and result[0].text == "hello" + result = convert_to_mcp_content(True) + assert result[0].type == "text" and result[0].text == "True" def test_convert_to_mcp_content_complex(): data = {"a": 1} expected_json = json.dumps(data) - assert convert_to_mcp_content(data) == [{"type": "text", "text": expected_json}] + result = convert_to_mcp_content(data) + assert result[0].type == "text" and result[0].text == expected_json def test_create_mcp_tool(): @@ -33,12 +37,12 @@ def test_create_mcp_tool(): mcp_tool = create_mcp_tool(mat_tool) assert mcp_tool is not None - assert mcp_tool["name"] == "ConvertToolkit_SampleTool" - assert mcp_tool["description"] + assert mcp_tool.name == "ConvertToolkit_SampleTool" + assert mcp_tool.description # Ensure input schema contains both parameters and marks them required - props = mcp_tool["inputSchema"]["properties"] + props = mcp_tool.inputSchema["properties"] assert set(props.keys()) == {"x", "y"} - required_fields = set(mcp_tool["inputSchema"].get("required", [])) + required_fields = set(mcp_tool.inputSchema.get("required", [])) # Ensure no unexpected required fields and that declared ones are subset of expected assert required_fields.issubset({"x", "y"}) diff --git a/libs/tests/mcp/test_message_processor.py b/libs/tests/mcp/test_message_processor.py deleted file mode 100644 index 4a526cca..00000000 --- a/libs/tests/mcp/test_message_processor.py +++ /dev/null @@ -1,56 +0,0 @@ -import asyncio - -import pytest -from arcade_serve.mcp.message_processor import MCPMessageProcessor, create_message_processor -from arcade_serve.mcp.types import InitializeRequest, PingRequest - - -@pytest.mark.asyncio -async def test_message_processor_parses_initialize_json(): - """Ensure JSON initialize strings are converted into InitializeRequest objects.""" - json_init = '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}\n' - processor = MCPMessageProcessor() - - result = await processor.process_request(json_init) - - assert isinstance(result, InitializeRequest) - assert result.id == 1 - assert result.method == "initialize" - - -@pytest.mark.asyncio -async def test_message_processor_passes_notifications_unchanged(): - """Unknown notifications should be passed through as parsed dictionaries without errors.""" - json_notification = '{"jsonrpc":"2.0","id":null,"method":"notifications/custom","params":{}}\n' - processor = MCPMessageProcessor() - - result = await processor.process_request(json_notification) - - # The MCPMessageProcessor keeps unknown notifications as simple dicts - assert isinstance(result, dict) - assert result["method"] == "notifications/custom" - - -@pytest.mark.asyncio -async def test_message_processor_middleware_execution_order(monkeypatch): - """Middleware (sync + async) should be executed in the order they were added.""" - - order: list[str] = [] - - def mw_sync(msg, direction): # type: ignore[return-value] - order.append("sync") - return msg - - async def mw_async(msg, direction): # type: ignore[return-value] - await asyncio.sleep(0) # ensure it is truly async - order.append("async") - return msg - - processor = create_message_processor(mw_sync, mw_async) - - # Use a pre-parsed PingRequest instance so we don't test parsing again here - ping = PingRequest(id=42) - - _ = await processor.process_request(ping) - - assert order == ["sync", "async"] diff --git a/libs/tests/mcp/test_server.py b/libs/tests/mcp/test_server.py deleted file mode 100644 index 2cdc56a3..00000000 --- a/libs/tests/mcp/test_server.py +++ /dev/null @@ -1,152 +0,0 @@ -import sys -import types -from typing import Annotated, Any - -import pytest -from arcade_core.catalog import ToolCatalog -from arcade_serve.mcp import server as mcp_server -from arcade_serve.mcp.types import ( - CallToolRequest, - CancelRequest, - InitializeRequest, - ListToolsRequest, - PingRequest, -) -from arcade_tdk import tool - -# --------------------------------------------------------------------------- -# Test helpers / stubs -# --------------------------------------------------------------------------- - - -class _FakeAuth: - async def authorize(self, auth_requirement: Any, user_id: str): - """Return an object that mimics AuthorizationResponse with completed status.""" - - class _Ctx: # minimal stub - token = "dummy-token" # noqa: S105 - - class _Resp: # pylint: disable=too-few-public-methods - status = "completed" - url = "" - context = _Ctx() - - return _Resp() - - -class _FakeArcade: # pylint: disable=too-few-public-methods - def __init__(self, **_: Any): - self.auth = _FakeAuth() - - -# Ensure that the AsyncArcade & ArcadeError symbols inside server.py point to our stubs. -pytestmark = pytest.mark.asyncio - - -@pytest.fixture(autouse=True) -def _patch_arcadepy(monkeypatch): - """Patch the external `arcadepy` dependency used by mcp.server.""" - - # Patch the imported symbols on the already-imported server module - monkeypatch.setattr(mcp_server, "AsyncArcade", _FakeArcade, raising=True) - monkeypatch.setattr(mcp_server, "ArcadeError", Exception, raising=True) - - # Provide a dummy `arcadepy` module in sys.modules for any other importers - fake_arcadepy = types.ModuleType("arcadepy") - fake_arcadepy.AsyncArcade = _FakeArcade # type: ignore[attr-defined] - fake_arcadepy.ArcadeError = Exception # type: ignore[attr-defined] - sys.modules["arcadepy"] = fake_arcadepy - - yield - - # Cleanup - sys.modules.pop("arcadepy", None) - - -# --------------------------------------------------------------------------- -# Fixtures for a sample tool / catalog / server -# --------------------------------------------------------------------------- - - -@tool -def multiply(a: Annotated[int, "a"], b: Annotated[int, "b"]) -> Annotated[int, "result"]: - """Return the product of *a* and *b*.""" - - return a * b - - -@pytest.fixture(scope="module") -def sample_catalog(): - catalog = ToolCatalog() - catalog.add_tool(multiply, "test_toolkit") - return catalog - - -@pytest.fixture() -def server(sample_catalog): - # MCPServer constructor is synchronous, so fixture need not be async - return mcp_server.MCPServer(sample_catalog, enable_logging=False) - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -async def test_handle_ping(server): - req = PingRequest(id=123) - resp = await server._handle_ping(req) # pylint: disable=protected-access - assert resp.id == 123 - assert resp.result == {"pong": True} - - -async def test_handle_initialize(server): - req = InitializeRequest(id=1) - resp = await server._handle_initialize(req) # pylint: disable=protected-access - assert resp.id == 1 - assert resp.result.protocolVersion == mcp_server.MCP_PROTOCOL_VERSION - assert resp.result.serverInfo.name.startswith("Arcade") - - -async def test_handle_list_tools(server): - req = ListToolsRequest(id=99) - resp = await server._handle_list_tools(req) # pylint: disable=protected-access - assert resp.id == 99 - # Should list our sample tool only - tool_names = [t.name for t in resp.result.tools] - assert "TestToolkit_Multiply" in tool_names # toolkit + "_" + tool - - -async def test_handle_call_tool_success(server): - req = CallToolRequest( - id="call-1", - params={ - "name": "TestToolkit_Multiply", - "input": {"a": 6, "b": 7}, - }, - ) - resp = await server._handle_call_tool(req, user_id="tester@example.com") # pylint: disable=protected-access - - assert resp.id == "call-1" - # convert_to_mcp_content wraps primitives in list-of-dicts - assert resp.result.content == [{"type": "text", "text": "42"}] - - -async def test_send_response_dict(server, monkeypatch): - """_send_response should JSON-serialize plain dictionaries.""" - - sent: list[str] = [] - - class _Write: - async def send(self, msg): - sent.append(msg) - - await server._send_response(_Write(), {"foo": "bar"}) # pylint: disable=protected-access - - assert sent and sent[0].strip() == '{"foo": "bar"}' - - -async def test_handle_cancel(server): - req = CancelRequest(id=77, params={"id": "abc"}) - resp = await server._handle_cancel(req) # pylint: disable=protected-access - assert resp.result == {"ok": True} diff --git a/libs/tests/mcp/test_stdio.py b/libs/tests/mcp/test_stdio.py deleted file mode 100644 index daa0282d..00000000 --- a/libs/tests/mcp/test_stdio.py +++ /dev/null @@ -1,32 +0,0 @@ -import io -import queue - -from arcade_serve.mcp.stdio import stdio_reader, stdio_writer - - -def test_stdio_reader_puts_lines_and_none(): - q: queue.Queue[str | None] = queue.Queue() - test_input = io.StringIO("line1\nline2\n") - - stdio_reader(test_input, q) - - # We should get the two lines followed by None sentinel - assert q.get_nowait() == "line1\n" - assert q.get_nowait() == "line2\n" - assert q.get_nowait() is None - - -def test_stdio_writer_reads_until_none(): - q: queue.Queue[str | None] = queue.Queue() - output_stream = io.StringIO() - - # preload queue with two messages and sentinel - q.put("msg1") - q.put("msg2\n") - q.put(None) - - stdio_writer(output_stream, q) - - # Ensure writer appended newlines when missing - output_stream.seek(0) - assert output_stream.read() == "msg1\nmsg2\n" diff --git a/libs/tests/core/test_telemetry.py b/libs/tests/worker/test_telemetry.py similarity index 82% rename from libs/tests/core/test_telemetry.py rename to libs/tests/worker/test_telemetry.py index 5545e3c1..1c4984d2 100644 --- a/libs/tests/core/test_telemetry.py +++ b/libs/tests/worker/test_telemetry.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from arcade_core.telemetry import OTELHandler, ShutdownError +from arcade_serve.fastapi.telemetry import OTELHandler, ShutdownError from fastapi import FastAPI @@ -15,11 +15,11 @@ def handler_disabled(app): return OTELHandler(enable=False) -@patch("arcade_core.telemetry.logging") -@patch("arcade_core.telemetry.FastAPIInstrumentor") -@patch("arcade_core.telemetry.OTLPLogExporter") -@patch("arcade_core.telemetry.OTLPMetricExporter") -@patch("arcade_core.telemetry.OTLPSpanExporter") +@patch("arcade_serve.fastapi.telemetry.logging") +@patch("arcade_serve.fastapi.telemetry.FastAPIInstrumentor") +@patch("arcade_serve.fastapi.telemetry.OTLPLogExporter") +@patch("arcade_serve.fastapi.telemetry.OTLPMetricExporter") +@patch("arcade_serve.fastapi.telemetry.OTLPSpanExporter") def test_init_with_enable_true( mock_span_exporter, mock_metric_exporter, @@ -53,8 +53,8 @@ def test_init_with_enable_true( mock_instrumentor.return_value.instrument_app.assert_called_once_with(app) -@patch("arcade_core.telemetry.logging") -@patch("arcade_core.telemetry.FastAPIInstrumentor") +@patch("arcade_serve.fastapi.telemetry.logging") +@patch("arcade_serve.fastapi.telemetry.FastAPIInstrumentor") def test_init_with_enable_false(mock_instrumentor, mock_logging, app): handler = OTELHandler(enable=False) handler.instrument_app(app) @@ -81,9 +81,9 @@ def test_init_tracer_export_exception(app): assert "Could not connect to OpenTelemetry Tracer endpoint" in str(exc_info.value) -@patch("arcade_core.telemetry.OTLPLogExporter") -@patch("arcade_core.telemetry.OTLPMetricExporter") -@patch("arcade_core.telemetry.OTLPSpanExporter") +@patch("arcade_serve.fastapi.telemetry.OTLPLogExporter") +@patch("arcade_serve.fastapi.telemetry.OTLPMetricExporter") +@patch("arcade_serve.fastapi.telemetry.OTLPSpanExporter") def test_shutdown(mock_span_exporter, mock_metric_exporter, mock_log_exporter, app): # Mock the shutdown methods mock_span_exporter.return_value.shutdown = MagicMock() @@ -120,10 +120,10 @@ def test_shutdown_logging_not_initialized(handler_disabled): assert "Log provider not initialized" in str(exc_info.value) -@patch("arcade_core.telemetry.get_meter_provider") -@patch("arcade_core.telemetry.OTLPLogExporter") -@patch("arcade_core.telemetry.OTLPMetricExporter") -@patch("arcade_core.telemetry.OTLPSpanExporter") +@patch("arcade_serve.fastapi.telemetry.get_meter_provider") +@patch("arcade_serve.fastapi.telemetry.OTLPLogExporter") +@patch("arcade_serve.fastapi.telemetry.OTLPMetricExporter") +@patch("arcade_serve.fastapi.telemetry.OTLPSpanExporter") def test_get_meter( mock_span_exporter, mock_metric_exporter, mock_log_exporter, mock_get_meter_provider, app ): diff --git a/pyproject.toml b/pyproject.toml index ed1514ee..f5d4d71a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] -name = "arcade-ai" -version = "2.2.3" +name = "arcade-mcp" +version = "1.0.0rc1" description = "Arcade.dev - Tool Calling platform for Agents" readme = "README.md" license = {file = "LICENSE"} @@ -21,7 +21,8 @@ requires-python = ">=3.10" dependencies = [ # CLI dependencies - "arcade-core>=2.4.0,<3.0.0", + "arcade-mcp-server>=1.0.0rc1,<3.0.0", + "arcade-core>=2.5.0rc1,<3.0.0", "typer==0.10.0", "rich==13.9.4", "Jinja2==3.1.6", @@ -39,10 +40,12 @@ all = [ "scikit-learn>=1.5.0", "pytz>=2024.1", "python-dateutil>=2.8.2", + # mcp + "arcade-mcp-server>=1.0.0rc1,<3.0.0", # serve - "arcade-serve>=2.1.0,<3.0.0", + "arcade-serve>=2.2.0rc1,<3.0.0", # tdk - "arcade-tdk>=2.3.1,<3.0.0", + "arcade-tdk>=2.6.0rc1,<3.0.0", ] # Evals also depends on arcade-core and openai, but they are already required deps evals = [ @@ -69,12 +72,14 @@ dev-dependencies = [ # CLI entry point [project.scripts] arcade = "arcade_cli.main:cli" +arcade-mcp = "arcade_cli.main:cli" [tool.uv.sources] # Workspace member sources arcade-core = { workspace = true } arcade-tdk = { workspace = true } arcade-serve = { workspace = true } +arcade-mcp-server = { workspace = true } [build-system] requires = ["hatchling"] @@ -91,6 +96,7 @@ members = [ "libs/arcade-core", "libs/arcade-tdk", "libs/arcade-serve", + "libs/arcade-mcp-server", ] [tool.mypy] @@ -146,7 +152,7 @@ line-length = 100 [tool.ruff.lint] select = ["E", "F", "I", "N", "UP", "RUF"] -ignore = ["E501"] +ignore = ["E501", "S105"] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] diff --git a/toolkits/clickhouse/pyproject.toml b/toolkits/clickhouse/pyproject.toml index e57d81b6..59cebe67 100644 --- a/toolkits/clickhouse/pyproject.toml +++ b/toolkits/clickhouse/pyproject.toml @@ -8,7 +8,7 @@ version = "0.1.0" description = "Tools to query and explore a ClickHouse database" requires-python = ">=3.10" dependencies = [ - "arcade-tdk>=2.0.0,<3.0.0", + "arcade-tdk>=2.6.0rc1,<3.0.0", "clickhouse-connect>=0.7.0", "pydantic>=2.11.7", "sqlalchemy>=2.0.41", @@ -24,8 +24,8 @@ email = "support@arcade.dev" [project.optional-dependencies] dev = [ - "arcade-ai[evals]>=2.0.0,<3.0.0", - "arcade-serve>=2.0.0,<3.0.0", + "arcade-mcp[all]>=1.0.0rc1,<3.0.0", + "arcade-serve>=2.2.0rc1,<3.0.0", "pytest>=8.3.0,<8.4.0", "pytest-cov>=4.0.0,<4.1.0", "pytest-mock>=3.11.1,<3.12.0", @@ -38,7 +38,7 @@ dev = [ # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = { path = "../../", editable = true } +arcade-mcp = { path = "../../", editable = true } arcade-serve = { path = "../../libs/arcade-serve/", editable = true } arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } diff --git a/toolkits/linkedin/pyproject.toml b/toolkits/linkedin/pyproject.toml index 0d6b9e54..9c8930be 100644 --- a/toolkits/linkedin/pyproject.toml +++ b/toolkits/linkedin/pyproject.toml @@ -8,7 +8,7 @@ version = "0.1.13" description = "Arcade.dev LLM tools for LinkedIn" requires-python = ">=3.10" dependencies = [ - "arcade-tdk>=2.0.0,<3.0.0", + "arcade-tdk>=2.6.0rc1,<3.0.0", "httpx>=0.27.2,<1.0.0", ] [[project.authors]] @@ -17,8 +17,8 @@ email = "dev@arcade.dev" [project.optional-dependencies] dev = [ - "arcade-ai[evals]>=2.0.0,<3.0.0", - "arcade-serve>=2.0.0,<3.0.0", + "arcade-mcp[all]>=1.0.0rc1,<3.0.0", + "arcade-serve>=2.2.0rc1,<3.0.0", "pytest>=8.3.0,<8.4.0", "pytest-cov>=4.0.0,<4.1.0", "pytest-asyncio>=0.24.0,<0.25.0", @@ -31,7 +31,7 @@ dev = [ # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = {path = "../../", editable = true} +arcade-mcp = {path = "../../", editable = true} arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } arcade-serve = { path = "../../libs/arcade-serve/", editable = true } diff --git a/toolkits/math/pyproject.toml b/toolkits/math/pyproject.toml index e12385b0..4908807f 100644 --- a/toolkits/math/pyproject.toml +++ b/toolkits/math/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.4" description = "Arcade.dev LLM tools for doing math" requires-python = ">=3.10" dependencies = [ - "arcade-tdk>=2.0.0,<3.0.0", + "arcade-tdk>=2.6.0rc1,<3.0.0", ] [[project.authors]] name = "Arcade" @@ -16,8 +16,8 @@ email = "dev@arcade.dev" [project.optional-dependencies] dev = [ - "arcade-ai[evals]>=2.0.0,<3.0.0", - "arcade-serve>=2.0.0,<3.0.0", + "arcade-mcp[all]>=1.0.0rc1,<3.0.0", + "arcade-serve>=2.2.0rc1,<3.0.0", "pytest>=8.3.0,<8.4.0", "pytest-cov>=4.0.0,<4.1.0", "pytest-asyncio>=0.24.0,<0.25.0", @@ -30,7 +30,7 @@ dev = [ # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = {path = "../../", editable = true} +arcade-mcp = {path = "../../", editable = true} arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } arcade-serve = { path = "../../libs/arcade-serve/", editable = true } diff --git a/toolkits/mongodb/pyproject.toml b/toolkits/mongodb/pyproject.toml index e2127c6f..8a019581 100644 --- a/toolkits/mongodb/pyproject.toml +++ b/toolkits/mongodb/pyproject.toml @@ -8,7 +8,7 @@ version = "0.1.0" description = "Tools to query and explore a MongoDB database" requires-python = ">=3.10" dependencies = [ - "arcade-tdk>=2.0.0,<3.0.0", + "arcade-tdk>=2.6.0rc1,<3.0.0", "pymongo>=4.10.1", "pydantic>=2.11.7", "motor>=3.6.0", @@ -20,8 +20,8 @@ email = "support@arcade.dev" [project.optional-dependencies] dev = [ - "arcade-ai[evals]>=2.0.0,<3.0.0", - "arcade-serve>=2.0.0,<3.0.0", + "arcade-mcp[all]>=1.0.0rc1,<3.0.0", + "arcade-serve>=2.2.0rc1,<3.0.0", "pytest>=8.3.0,<8.4.0", "pytest-cov>=4.0.0,<4.1.0", "pytest-mock>=3.11.1,<3.12.0", @@ -34,7 +34,7 @@ dev = [ # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = { path = "../../", editable = true } +arcade-mcp = { path = "../../", editable = true } arcade-serve = { path = "../../libs/arcade-serve/", editable = true } arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } diff --git a/toolkits/postgres/pyproject.toml b/toolkits/postgres/pyproject.toml index 473c35f9..9fb902ee 100644 --- a/toolkits/postgres/pyproject.toml +++ b/toolkits/postgres/pyproject.toml @@ -8,7 +8,7 @@ version = "0.3.0" description = "Tools to query and explore a postgres database" requires-python = ">=3.10" dependencies = [ - "arcade-tdk>=2.0.0,<3.0.0", + "arcade-tdk>=2.6.0rc1,<3.0.0", "psycopg2-binary>=2.9.10", "pydantic>=2.11.7", "sqlalchemy>=2.0.41", @@ -23,8 +23,8 @@ email = "support@arcade.dev" [project.optional-dependencies] dev = [ - "arcade-ai[evals]>=2.0.0,<3.0.0", - "arcade-serve>=2.0.0,<3.0.0", + "arcade-mcp[all]>=1.0.0rc1,<3.0.0", + "arcade-serve>=2.2.0rc1,<3.0.0", "pytest>=8.3.0,<8.4.0", "pytest-cov>=4.0.0,<4.1.0", "pytest-mock>=3.11.1,<3.12.0", @@ -37,7 +37,7 @@ dev = [ # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = { path = "../../", editable = true } +arcade-mcp = { path = "../../", editable = true } arcade-serve = { path = "../../libs/arcade-serve/", editable = true } arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } diff --git a/toolkits/slack_api/pyproject.toml b/toolkits/slack_api/pyproject.toml index f649ce85..c5b79331 100644 --- a/toolkits/slack_api/pyproject.toml +++ b/toolkits/slack_api/pyproject.toml @@ -8,7 +8,7 @@ version = "0.1.0" description = "Arcade Wrapper Tools enabling LLMs to interact with low-level Slack API endpoints." requires-python = ">=3.10" dependencies = [ - "arcade-tdk>=2.0.0,<3.0.0", + "arcade-tdk>=2.6.0rc1,<3.0.0", "httpx>=0.27.2,<1.0.0", ] [[project.authors]] @@ -18,8 +18,8 @@ email = "support@arcade.dev" [project.optional-dependencies] dev = [ - "arcade-ai[evals]>=2.2.2,<3.0.0", - "arcade-serve>=2.0.0,<3.0.0", + "arcade-mcp[all]>=1.0.0rc1,<3.0.0", + "arcade-serve>=2.2.0rc1,<3.0.0", "pytest>=8.3.0,<8.4.0", "pytest-cov>=4.0.0,<4.1.0", "pytest-mock>=3.11.1,<3.12.0", @@ -36,7 +36,7 @@ toolkit_name = "arcade_slack_api" # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = { path = "../../", editable = true } +arcade-mcp = { path = "../../", editable = true } arcade-serve = { path = "../../libs/arcade-serve/", editable = true } arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } [tool.mypy] diff --git a/toolkits/zendesk/pyproject.toml b/toolkits/zendesk/pyproject.toml index a257e0fd..2b79085d 100644 --- a/toolkits/zendesk/pyproject.toml +++ b/toolkits/zendesk/pyproject.toml @@ -7,7 +7,7 @@ name = "arcade_zendesk" version = "0.3.0" requires-python = ">=3.10" dependencies = [ - "arcade-tdk>=2.3.1,<3.0.0", + "arcade-tdk>=2.6.0rc1,<3.0.0", "httpx>=0.25.0,<1.0.0", "beautifulsoup4>=4.0.0,<5" ] @@ -15,8 +15,8 @@ dependencies = [ [project.optional-dependencies] dev = [ - "arcade-ai[evals]>=2.2.1,<3.0.0", - "arcade-serve>=2.1.0,<3.0.0", + "arcade-mcp[all]>=1.0.0rc1,<3.0.0", + "arcade-serve>=2.2.0rc1,<3.0.0", "pytest>=8.3.0,<8.4.0", "pytest-cov>=4.0.0,<4.1.0", "pytest-mock>=3.11.1,<3.12.0", @@ -29,7 +29,7 @@ dev = [ # Use local path sources for arcade libs when working locally [tool.uv.sources] -arcade-ai = { path = "../../", editable = true } +arcade-mcp = { path = "../../", editable = true } arcade-serve = { path = "../../libs/arcade-serve/", editable = true } arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } diff --git a/uv_setup.sh b/uv_setup.sh new file mode 100755 index 00000000..a28b254f --- /dev/null +++ b/uv_setup.sh @@ -0,0 +1,398 @@ +#!/bin/bash + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +CYAN='\033[0;36m' +NC='\033[0m' # No Color + +# Script configuration +REQUIRED_PYTHON_VERSION="3.11" +VENV_NAME=".venv" +PROJECT_ROOT=$(pwd) +INTERACTIVE_MODE=true + +echo -e "${BLUE}=== UV Python Environment Setup ===${NC}" +echo -e "${BLUE}Project: ${PROJECT_ROOT}${NC}" +echo "" + +# Function to check if a command exists +command_exists() { + command -v "$1" >/dev/null 2>&1 +} + +# Function to compare version strings +version_ge() { + [ "$(printf '%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ] +} + +# Function to ask for confirmation +confirm() { + local prompt="$1" + local default="${2:-y}" + local response + + if [ "$default" = "y" ]; then + prompt="${prompt} [Y/n]: " + else + prompt="${prompt} [y/N]: " + fi + + while true; do + echo -ne "${CYAN}${prompt}${NC}" + read -r response + response=${response:-$default} + response=$(echo "$response" | tr '[:upper:]' '[:lower:]') + + if [[ "$response" =~ ^(yes|y)$ ]]; then + return 0 + elif [[ "$response" =~ ^(no|n)$ ]]; then + return 1 + else + echo -e "${RED}Please answer yes/no (y/n)${NC}" + fi + done +} + +# Ask if user wants to run in interactive mode +if confirm "Do you want to run in interactive mode? (You'll be asked to confirm each step)" "y"; then + INTERACTIVE_MODE=true + echo -e "${GREEN}Running in interactive mode${NC}\n" +else + INTERACTIVE_MODE=false + echo -e "${YELLOW}Running in automatic mode (all steps will be executed)${NC}\n" +fi + +# Step 1: Check if uv is installed +echo -e "${YELLOW}1. Checking uv installation...${NC}" +if ! command_exists uv; then + echo -e "${RED}โœ— uv is not installed${NC}" + + if [ "$INTERACTIVE_MODE" = true ]; then + if confirm "Would you like to install uv?" "y"; then + echo -e "${YELLOW}Installing uv...${NC}" + else + echo -e "${RED}Cannot proceed without uv. Exiting.${NC}" + exit 1 + fi + else + echo -e "${YELLOW}Installing uv...${NC}" + fi + + # Install uv based on the system + if [[ "$OSTYPE" == "darwin"* ]]; then + # macOS + if command_exists brew; then + brew install uv + else + curl -LsSf https://astral.sh/uv/install.sh | sh + fi + else + # Linux and others + curl -LsSf https://astral.sh/uv/install.sh | sh + fi + + # Add to PATH if not already there + if [[ ":$PATH:" != *":$HOME/.local/bin:"* ]]; then + echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.zshrc + export PATH="$HOME/.local/bin:$PATH" + fi + + echo -e "${GREEN}โœ“ uv installed successfully${NC}" +else + echo -e "${GREEN}โœ“ uv is installed ($(uv --version))${NC}" +fi + +# Step 2: Check Python 3.11 availability +echo -e "\n${YELLOW}2. Checking Python ${REQUIRED_PYTHON_VERSION} availability...${NC}" + +# First check if uv can find Python 3.11 +if uv python find 3.11 >/dev/null 2>&1; then + echo -e "${GREEN}โœ“ Python 3.11 is available via uv${NC}" + PYTHON_PATH=$(uv python find 3.11) + echo -e " Found at: ${PYTHON_PATH}" +else + echo -e "${RED}โœ— Python 3.11 not found${NC}" + + if [ "$INTERACTIVE_MODE" = true ]; then + if confirm "Would you like to install Python 3.11 via uv?" "y"; then + echo -e "${YELLOW}Installing Python 3.11...${NC}" + uv python install 3.11 + PYTHON_PATH=$(uv python find 3.11) + echo -e "${GREEN}โœ“ Python 3.11 installed successfully${NC}" + else + echo -e "${RED}Cannot proceed without Python 3.11. Exiting.${NC}" + exit 1 + fi + else + echo -e "${YELLOW}Installing Python 3.11 via uv...${NC}" + uv python install 3.11 + PYTHON_PATH=$(uv python find 3.11) + echo -e "${GREEN}โœ“ Python 3.11 installed successfully${NC}" + fi +fi + +# Step 3: Create or verify virtual environment +echo -e "\n${YELLOW}3. Setting up virtual environment...${NC}" + +if [ -d "$VENV_NAME" ]; then + echo -e "${YELLOW}Virtual environment already exists. Checking Python version...${NC}" + + # Check if the existing venv has the correct Python version + if [ -f "$VENV_NAME/bin/python" ]; then + VENV_PYTHON_VERSION=$("$VENV_NAME/bin/python" --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2) + if [ "$VENV_PYTHON_VERSION" != "$REQUIRED_PYTHON_VERSION" ]; then + echo -e "${RED}โœ— Existing venv uses Python $VENV_PYTHON_VERSION (required: $REQUIRED_PYTHON_VERSION)${NC}" + + if [ "$INTERACTIVE_MODE" = true ]; then + if confirm "Would you like to recreate the virtual environment with Python $REQUIRED_PYTHON_VERSION?" "y"; then + echo -e "${YELLOW}Recreating virtual environment...${NC}" + rm -rf "$VENV_NAME" + uv venv --python 3.11 + echo -e "${GREEN}โœ“ Virtual environment recreated with Python 3.11${NC}" + else + echo -e "${YELLOW}Keeping existing virtual environment with Python $VENV_PYTHON_VERSION${NC}" + echo -e "${RED}Warning: This may cause compatibility issues!${NC}" + fi + else + echo -e "${YELLOW}Recreating with Python $REQUIRED_PYTHON_VERSION...${NC}" + rm -rf "$VENV_NAME" + uv venv --python 3.11 + echo -e "${GREEN}โœ“ Virtual environment recreated with Python 3.11${NC}" + fi + else + echo -e "${GREEN}โœ“ Virtual environment uses correct Python version ($VENV_PYTHON_VERSION)${NC}" + fi + else + echo -e "${RED}โœ— Virtual environment seems corrupted${NC}" + + if [ "$INTERACTIVE_MODE" = true ]; then + if confirm "Would you like to recreate the virtual environment?" "y"; then + echo -e "${YELLOW}Recreating virtual environment...${NC}" + rm -rf "$VENV_NAME" + uv venv --python 3.11 + echo -e "${GREEN}โœ“ Virtual environment recreated${NC}" + else + echo -e "${RED}Cannot proceed with corrupted virtual environment. Exiting.${NC}" + exit 1 + fi + else + echo -e "${YELLOW}Recreating...${NC}" + rm -rf "$VENV_NAME" + uv venv --python 3.11 + echo -e "${GREEN}โœ“ Virtual environment recreated${NC}" + fi + fi +else + if [ "$INTERACTIVE_MODE" = true ]; then + if confirm "Would you like to create a new virtual environment with Python 3.11?" "y"; then + echo -e "${YELLOW}Creating virtual environment...${NC}" + uv venv --python 3.11 + echo -e "${GREEN}โœ“ Virtual environment created${NC}" + else + echo -e "${RED}Cannot proceed without virtual environment. Exiting.${NC}" + exit 1 + fi + else + echo -e "${YELLOW}Creating new virtual environment with Python 3.11...${NC}" + uv venv --python 3.11 + echo -e "${GREEN}โœ“ Virtual environment created${NC}" + fi +fi + +# Step 4: Install dependencies +echo -e "\n${YELLOW}4. Installing project dependencies...${NC}" + +# Check if pyproject.toml exists +if [ -f "pyproject.toml" ]; then + if [ "$INTERACTIVE_MODE" = true ]; then + if confirm "Would you like to install dependencies from pyproject.toml?" "y"; then + echo -e "${YELLOW}Installing dependencies...${NC}" + uv sync + echo -e "${GREEN}โœ“ Dependencies installed${NC}" + else + echo -e "${YELLOW}Skipping dependency installation${NC}" + echo -e "${RED}Warning: You'll need to run 'uv sync' manually later${NC}" + fi + else + echo -e "${YELLOW}Installing dependencies from pyproject.toml...${NC}" + uv sync + echo -e "${GREEN}โœ“ Dependencies installed${NC}" + fi +else + echo -e "${YELLOW}No pyproject.toml found. Skipping dependency installation.${NC}" +fi + +# Step 5: Configure VS Code +echo -e "\n${YELLOW}5. Configuring VS Code...${NC}" + +CONFIGURE_VSCODE=true +if [ "$INTERACTIVE_MODE" = true ]; then + if [ -f ".vscode/settings.json" ]; then + echo -e "${YELLOW}VS Code settings already exist${NC}" + if confirm "Would you like to overwrite existing VS Code settings?" "n"; then + CONFIGURE_VSCODE=true + else + CONFIGURE_VSCODE=false + echo -e "${YELLOW}Keeping existing VS Code settings${NC}" + fi + else + if confirm "Would you like to configure VS Code settings?" "y"; then + CONFIGURE_VSCODE=true + else + CONFIGURE_VSCODE=false + echo -e "${YELLOW}Skipping VS Code configuration${NC}" + fi + fi +fi + +if [ "$CONFIGURE_VSCODE" = true ]; then + # Create .vscode directory if it doesn't exist + mkdir -p .vscode + + # Create VS Code settings + cat > .vscode/settings.json << EOF +{ + "python.defaultInterpreterPath": "${PROJECT_ROOT}/${VENV_NAME}/bin/python", + "python.terminal.activateEnvironment": true, + "python.terminal.activateEnvInCurrentTerminal": true, + "python.envFile": "\${workspaceFolder}/.env", + "python.venvPath": "${PROJECT_ROOT}", + "python.venvFolders": ["${VENV_NAME}"], + "python.linting.enabled": true, + "python.linting.pylintEnabled": false, + "python.linting.flake8Enabled": false, + "python.linting.ruffEnabled": true, + "python.formatting.provider": "ruff", + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": "explicit", + "source.fixAll": "explicit" + } + }, + "ruff.path": ["${PROJECT_ROOT}/${VENV_NAME}/bin/ruff"] +} +EOF + + echo -e "${GREEN}โœ“ VS Code settings created/updated${NC}" +fi + +# Create .env file if it doesn't exist +if [ ! -f ".env" ]; then + if [ "$INTERACTIVE_MODE" = true ]; then + if confirm "Would you like to create an .env file?" "y"; then + echo "# Environment variables for the project" > .env + echo -e "${GREEN}โœ“ Created .env file${NC}" + else + echo -e "${YELLOW}Skipping .env file creation${NC}" + fi + else + echo "# Environment variables for the project" > .env + echo -e "${GREEN}โœ“ Created .env file${NC}" + fi +fi + +# Step 6: Create activation helper script +echo -e "\n${YELLOW}6. Creating activation helper...${NC}" + +CREATE_HELPER=true +if [ "$INTERACTIVE_MODE" = true ]; then + if [ -f "activate.sh" ]; then + echo -e "${YELLOW}Activation helper script already exists${NC}" + if confirm "Would you like to overwrite the existing activate.sh?" "n"; then + CREATE_HELPER=true + else + CREATE_HELPER=false + echo -e "${YELLOW}Keeping existing activation helper${NC}" + fi + else + if confirm "Would you like to create an activation helper script (activate.sh)?" "y"; then + CREATE_HELPER=true + else + CREATE_HELPER=false + echo -e "${YELLOW}Skipping activation helper creation${NC}" + fi + fi +fi + +if [ "$CREATE_HELPER" = true ]; then + cat > activate.sh << 'EOF' +#!/bin/bash +# Quick activation script for the virtual environment + +if [ -f ".venv/bin/activate" ]; then + source .venv/bin/activate + echo "Virtual environment activated!" + echo "Python: $(which python) ($(python --version))" + echo "To deactivate, run: deactivate" +else + echo "Error: Virtual environment not found. Run ./uv_setup.sh first." + exit 1 +fi +EOF + + chmod +x activate.sh + echo -e "${GREEN}โœ“ Created activate.sh helper script${NC}" +fi + +# Step 7: Display final instructions +echo -e "\n${GREEN}=== Setup Complete! ===${NC}" +echo -e "\n${BLUE}To activate the virtual environment in your terminal:${NC}" +echo -e " ${YELLOW}source ${VENV_NAME}/bin/activate${NC}" +echo -e "\n${BLUE}Or use the helper script:${NC}" +echo -e " ${YELLOW}source ./activate.sh${NC}" + +echo -e "\n${BLUE}For VS Code:${NC}" +echo -e " 1. Open VS Code in this directory: ${YELLOW}code .${NC}" +echo -e " 2. When prompted, select the Python interpreter from ${YELLOW}${VENV_NAME}/bin/python${NC}" +echo -e " 3. Or press ${YELLOW}Cmd+Shift+P${NC} and search for 'Python: Select Interpreter'" + +echo -e "\n${BLUE}Current environment info:${NC}" +echo -e " Python version required: ${YELLOW}>=${REQUIRED_PYTHON_VERSION}, <3.11${NC}" +echo -e " Virtual environment: ${YELLOW}${PROJECT_ROOT}/${VENV_NAME}${NC}" +echo -e " Python executable: ${YELLOW}${PROJECT_ROOT}/${VENV_NAME}/bin/python${NC}" + +# Check if we're in an activated environment +if [ -n "$VIRTUAL_ENV" ]; then + echo -e "\n${GREEN}โœ“ Virtual environment is currently activated${NC}" +else + echo -e "\n${YELLOW}! Virtual environment is not activated in this shell${NC}" +fi + +# Add to .gitignore if not already there +if [ -f ".gitignore" ]; then + if ! grep -q "^${VENV_NAME}$" .gitignore; then + if [ "$INTERACTIVE_MODE" = true ]; then + if confirm "Would you like to add ${VENV_NAME} and activate.sh to .gitignore?" "y"; then + echo -e "\n# Virtual environment" >> .gitignore + echo "${VENV_NAME}" >> .gitignore + echo "activate.sh" >> .gitignore + echo -e "${GREEN}โœ“ Added ${VENV_NAME} to .gitignore${NC}" + fi + else + echo -e "\n# Virtual environment" >> .gitignore + echo "${VENV_NAME}" >> .gitignore + echo "activate.sh" >> .gitignore + echo -e "${GREEN}โœ“ Added ${VENV_NAME} to .gitignore${NC}" + fi + fi +fi + +# Summary of actions taken +echo -e "\n${BLUE}=== Setup Summary ===${NC}" +echo -e "${GREEN}โœ“ Completed Steps:${NC}" +[ -n "$(command -v uv)" ] && echo -e " โ€ข uv is installed and available" +[ -d "$VENV_NAME" ] && echo -e " โ€ข Virtual environment created/verified" +[ -f "$VENV_NAME/bin/python" ] && echo -e " โ€ข Python $($VENV_NAME/bin/python --version 2>&1 | cut -d' ' -f2) configured" +[ -f ".vscode/settings.json" ] && [ "$CONFIGURE_VSCODE" = true ] && echo -e " โ€ข VS Code settings configured" +[ -f "activate.sh" ] && [ "$CREATE_HELPER" = true ] && echo -e " โ€ข Activation helper created" +[ -f ".env" ] && echo -e " โ€ข .env file created" + +echo -e "\n${YELLOW}Next Steps:${NC}" +echo -e " 1. Activate the virtual environment: ${CYAN}source ${VENV_NAME}/bin/activate${NC}" +echo -e " 2. If you skipped dependency installation: ${CYAN}uv sync${NC}" +echo -e " 3. Open VS Code and select the Python interpreter if needed"