diff --git a/libs/arcade-core/arcade_core/executor.py b/libs/arcade-core/arcade_core/executor.py index 3b2cca59..1dfc2e3f 100644 --- a/libs/arcade-core/arcade_core/executor.py +++ b/libs/arcade-core/arcade_core/executor.py @@ -59,7 +59,7 @@ class ToolExecutor: if asyncio.iscoroutinefunction(func): results = await func(**func_args) else: - results = func(**func_args) + results = await asyncio.to_thread(func, **func_args) # serialize the output model output = await ToolExecutor._serialize_output(output_model, results) diff --git a/libs/arcade-core/pyproject.toml b/libs/arcade-core/pyproject.toml index 2ca26c89..229d69aa 100644 --- a/libs/arcade-core/pyproject.toml +++ b/libs/arcade-core/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-core" -version = "3.3.4" +version = "3.3.5" description = "Arcade Core - Core library for Arcade platform" readme = "README.md" license = {text = "MIT"} diff --git a/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py b/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py index 12bc18bb..72d035e4 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/mcp_app.py @@ -15,7 +15,6 @@ from pathlib import Path from types import ModuleType from typing import Any, Callable, Literal, ParamSpec, TypeVar, cast -import uvicorn from arcade_core.catalog import MaterializedTool, ToolCatalog, ToolDefinitionError from arcade_tdk.auth import ToolAuthorization from arcade_tdk.error_adapters import ErrorAdapter @@ -29,7 +28,7 @@ from arcade_mcp_server.server import MCPServer from arcade_mcp_server.settings import MCPSettings, ServerSettings from arcade_mcp_server.types import Prompt, PromptMessage, Resource from arcade_mcp_server.usage import ServerTracker -from arcade_mcp_server.worker import create_arcade_mcp +from arcade_mcp_server.worker import create_arcade_mcp, serve_with_force_quit P = ParamSpec("P") T = TypeVar("T") @@ -410,13 +409,14 @@ class MCPApp: port=port, tool_count=len(self._catalog), ) - uvicorn.run( - app, - host=host, - port=port, - log_level=log_level, - reload=False, # MCPApp handles its own reload via parent/child process pattern - lifespan="on", + + asyncio.run( + serve_with_force_quit( + app=app, + host=host, + port=port, + log_level=log_level, + ) ) @staticmethod diff --git a/libs/arcade-mcp-server/arcade_mcp_server/worker.py b/libs/arcade-mcp-server/arcade_mcp_server/worker.py index 8755b1e7..6db9764b 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/worker.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/worker.py @@ -7,16 +7,19 @@ MCP Server endpoints over HTTP/SSE. MCP is always enabled in this integrated mod import asyncio import logging -from collections.abc import AsyncGenerator, AsyncIterator +import os +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable from contextlib import asynccontextmanager +from types import FrameType from typing import Any import uvicorn from arcade_core.catalog import ToolCatalog +from arcade_serve.fastapi import FastAPIWorker, TaskTrackerMiddleware from arcade_serve.fastapi.telemetry import OTELHandler -from arcade_serve.fastapi.worker import FastAPIWorker from fastapi import FastAPI from loguru import logger +from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send @@ -26,6 +29,44 @@ from arcade_mcp_server.settings import MCPSettings from arcade_mcp_server.transports.http_session_manager import HTTPSessionManager +class CustomUvicornServer(uvicorn.Server): + """Uvicorn server with force quit support on double SIGINT/SIGTERM.""" + + def __init__(self, config: uvicorn.Config, task_tracker: TaskTrackerMiddleware): + super().__init__(config) + self.task_tracker = task_tracker + self._signal_count = 0 + + def handle_exit(self, sig: int, frame: FrameType | None) -> None: + """ + Handle termination signals with force quit on second signal. + + First signal (SIGINT/SIGTERM): Graceful shutdown + Second signal: Force quit with os._exit(1) + """ + self._signal_count += 1 + + if self._signal_count == 1: + logger.info("Shutting down gracefully. Press Ctrl+C again to force quit.") + self.should_exit = True + else: + logger.warning("Force quit triggered - exiting immediately") + os._exit(1) + + async def _wait_tasks_to_complete(self) -> None: + try: + # Let Uvicorn's normal wait process run + await super()._wait_tasks_to_complete() + except asyncio.CancelledError: + # If we're cancelled (graceful shutdown time expired), then + # we need to cancel the active HTTP request tasks that we are tracking + logger.warning("Force quit triggered - cancelling all active requests") + cancelled = self.task_tracker.cancel_all_tasks() + logger.info(f"Cancelled {cancelled} active request(s)") + self.force_exit = True + os._exit(1) + + @asynccontextmanager async def create_lifespan( catalog: ToolCatalog, @@ -125,6 +166,17 @@ def create_arcade_mcp( lifespan=lifespan, ) otel_handler.instrument_app(app) + + task_tracker = TaskTrackerMiddleware(app) + app.state.task_tracker = task_tracker + + # Since this middleware tracks all HTTP requests, it must be added first + @app.middleware("http") + async def track_tasks_middleware( + request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: + return await task_tracker.dispatch(request, call_next) + app.add_middleware(AddTrailingSlashToPathMiddleware) # Worker endpoints @@ -267,6 +319,32 @@ def create_arcade_mcp_factory() -> FastAPI: ) +async def serve_with_force_quit( + app: FastAPI, + host: str, + port: int, + log_level: str, +) -> None: + """Serve the FastAPI app with force quit capability.""" + timeout_graceful_shutdown = int( + os.environ.get("ARCADE_UVICORN_TIMEOUT_GRACEFUL_SHUTDOWN", "15") + ) + + config = uvicorn.Config( + app=app, + host=host, + port=port, + log_level=log_level, + lifespan="on", + timeout_graceful_shutdown=timeout_graceful_shutdown, + ) + + task_tracker = app.state.task_tracker + server = CustomUvicornServer(config, task_tracker) + + await server.serve() + + def run_arcade_mcp( catalog: ToolCatalog, host: str = "127.0.0.1", @@ -286,11 +364,14 @@ def run_arcade_mcp( This is used for module execution (`arcade mcp` and `python -m arcade_mcp_server`) only. MCPApp has its own reload mechanism. """ - import os - log_level = "debug" if debug else "info" if reload: + # TODO: This reload path uses uvicorn.run(), which bypasses serve_with_force_quit(). + # This means that the server will not be able to force quit when there are active + # tool executions or active connections with MCP clients. For this reason, prefer + # to use MCPApp.run() for reload mode. + # Set env vars for the app factory to read later os.environ["ARCADE_MCP_DEBUG"] = str(debug) os.environ["ARCADE_MCP_OTEL_ENABLE"] = str(otel_enable) @@ -327,11 +408,11 @@ def run_arcade_mcp( **kwargs, ) - uvicorn.run( - app, - host=host, - port=port, - log_level=log_level, - reload=reload, - lifespan="on", + asyncio.run( + serve_with_force_quit( + app=app, + host=host, + port=port, + log_level=log_level, + ) ) diff --git a/libs/arcade-mcp-server/pyproject.toml b/libs/arcade-mcp-server/pyproject.toml index 6e38eaa5..3ef84b5e 100644 --- a/libs/arcade-mcp-server/pyproject.toml +++ b/libs/arcade-mcp-server/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "arcade-mcp-server" -version = "1.7.3" +version = "1.8.0" description = "Model Context Protocol (MCP) server framework for Arcade.dev" readme = "README.md" authors = [{ name = "Arcade.dev" }] diff --git a/libs/arcade-serve/arcade_serve/core/components.py b/libs/arcade-serve/arcade_serve/core/components.py index e3d93072..01d385c1 100644 --- a/libs/arcade-serve/arcade_serve/core/components.py +++ b/libs/arcade-serve/arcade_serve/core/components.py @@ -96,6 +96,4 @@ class HealthCheckComponent(WorkerComponent): """ Handle the request to check the health of the worker. """ - tracer = trace.get_tracer(__name__) - with tracer.start_as_current_span("HealthCheck"): - return self.worker.health_check() + return self.worker.health_check() diff --git a/libs/arcade-serve/arcade_serve/fastapi/__init__.py b/libs/arcade-serve/arcade_serve/fastapi/__init__.py index f3c85996..d141a4b8 100644 --- a/libs/arcade-serve/arcade_serve/fastapi/__init__.py +++ b/libs/arcade-serve/arcade_serve/fastapi/__init__.py @@ -1,3 +1,4 @@ +from .task_tracker import TaskTrackerMiddleware from .worker import FastAPIWorker -__all__ = ["FastAPIWorker"] +__all__ = ["FastAPIWorker", "TaskTrackerMiddleware"] diff --git a/libs/arcade-serve/arcade_serve/fastapi/task_tracker.py b/libs/arcade-serve/arcade_serve/fastapi/task_tracker.py new file mode 100644 index 00000000..6434c524 --- /dev/null +++ b/libs/arcade-serve/arcade_serve/fastapi/task_tracker.py @@ -0,0 +1,57 @@ +import asyncio +import threading +from collections.abc import Awaitable, Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import ASGIApp + + +class TaskTrackerMiddleware(BaseHTTPMiddleware): + """Middleware that tracks active HTTP request tasks for force quit functionality.""" + + def __init__(self, app: ASGIApp) -> None: + super().__init__(app) + self._active_tasks: set[asyncio.Task] = set() + self._lock = threading.Lock() + + async def dispatch( + self, request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: + """Track the current task while handling the request.""" + task = asyncio.current_task() + + with self._lock: + if task: + self._active_tasks.add(task) + + try: + response = await call_next(request) + return response + finally: + with self._lock: + if task: + self._active_tasks.discard(task) + + def cancel_all_tasks(self) -> int: + """ + Cancel all tracked (active) HTTP request tasks. + + This method must be called from within the asyncio event loop's thread + (not from background thread) because it calls task.cancel() + + Returns: + int: Number of tasks successfully cancelled. + """ + # Make a copy to avoid mutation during iteration + with self._lock: + tasks_to_cancel = list(self._active_tasks) + + cancelled_count = 0 + for task in tasks_to_cancel: + if not task.done(): + task.cancel() + cancelled_count += 1 + + return cancelled_count diff --git a/libs/arcade-serve/pyproject.toml b/libs/arcade-serve/pyproject.toml index 1e1548c7..b3ef31c6 100644 --- a/libs/arcade-serve/pyproject.toml +++ b/libs/arcade-serve/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-serve" -version = "3.0.0" +version = "3.1.0" description = "Arcade Serve - Serving infrastructure for Arcade tools and workers" readme = "README.md" license = {text = "MIT"} diff --git a/libs/tests/arcade_mcp_server/integration/server/src/server/timing_tools.py b/libs/tests/arcade_mcp_server/integration/server/src/server/timing_tools.py new file mode 100644 index 00000000..70d2a094 --- /dev/null +++ b/libs/tests/arcade_mcp_server/integration/server/src/server/timing_tools.py @@ -0,0 +1,21 @@ +import asyncio +import time +from typing import Annotated + +from arcade_mcp_server import Context, tool + + +@tool +async def slow_async_tool( + context: Context, delay_seconds: Annotated[float, "Delay in seconds"] +) -> str: + """A tool that takes time to execute (async).""" + await asyncio.sleep(delay_seconds) + return f"Completed async task after {delay_seconds}s" + + +@tool +def slow_sync_tool(context: Context, delay_seconds: Annotated[float, "Delay in seconds"]) -> str: + """A tool that takes time to execute (sync).""" + time.sleep(delay_seconds) + return f"Completed sync task after {delay_seconds}s" diff --git a/libs/tests/arcade_mcp_server/integration/test_end_to_end.py b/libs/tests/arcade_mcp_server/integration/test_end_to_end.py index befc689a..be716ed2 100644 --- a/libs/tests/arcade_mcp_server/integration/test_end_to_end.py +++ b/libs/tests/arcade_mcp_server/integration/test_end_to_end.py @@ -322,7 +322,7 @@ async def test_stdio_e2e(): assert "result" in list_tools_response assert "tools" in list_tools_response["result"] tools = list_tools_response["result"]["tools"] - assert len(tools) == 7 + assert len(tools) == 9 # 5. Call logging_tool logging_id = client.send_request( @@ -558,7 +558,7 @@ async def test_http_e2e(): assert "result" in list_tools_data assert "tools" in list_tools_data["result"] tools = list_tools_data["result"]["tools"] - assert len(tools) == 7 + assert len(tools) == 9 # 5. Call logging_tool logging_request = build_jsonrpc_request( @@ -628,3 +628,287 @@ async def test_http_e2e(): except subprocess.TimeoutExpired: process.kill() process.wait() + + +@pytest.mark.asyncio +async def test_http_mcp_concurrent_tool_execution(): + """Test that multiple tools can execute concurrently via the /mcp route.""" + process, port = start_mcp_server("http") + assert port is not None + + base_url = f"http://127.0.0.1:{port}" + + try: + wait_for_http_server_ready(port, timeout=10) + + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + async with httpx.AsyncClient(base_url=base_url, timeout=30.0, headers=headers) as client: + # Initialize the connection with the server + init_request = build_jsonrpc_request( + "initialize", + { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }, + request_id=1, + ) + + init_response = await client.post("/mcp", json=init_request) + assert init_response.status_code == 200 + session_id = init_response.headers.get("mcp-session-id") + assert session_id is not None + + client.headers.update({"Mcp-Session-Id": session_id}) + + init_notif = build_jsonrpc_request( + "notifications/initialized", params=None, request_id=None + ) + await client.post("/mcp", json=init_notif) + + # Call the tool three times concurrently. Each tool call takes 1 second to execute. + # Since the server should be able to execute the tools in parallel, the total time should be around 1 second + delay_seconds = 1.0 + + tool_requests = [ + build_jsonrpc_request( + "tools/call", + { + "name": "Server_SlowAsyncTool", + "arguments": {"delay_seconds": delay_seconds}, + }, + request_id=10, + ), + build_jsonrpc_request( + "tools/call", + { + "name": "Server_SlowSyncTool", + "arguments": {"delay_seconds": delay_seconds}, + }, + request_id=11, + ), + build_jsonrpc_request( + "tools/call", + { + "name": "Server_SlowSyncTool", + "arguments": {"delay_seconds": delay_seconds}, + }, + request_id=12, + ), + ] + + start_time = time.time() + responses = await asyncio.gather(*[ + client.post("/mcp", json=req) for req in tool_requests + ]) + total_time = time.time() - start_time + + assert all(r.status_code == 200 for r in responses), "All requests should succeed" + + for idx, response in enumerate(responses): + data = response.json() + assert data["jsonrpc"] == "2.0" + assert data["id"] == idx + 10 + assert "result" in data + assert "error" not in data + assert f"after {delay_seconds}s" in data["result"]["content"][0]["text"] + + # If parallel, should take ~1s, not ~3s + max_expected_time = delay_seconds + 0.5 # Allow 0.5s overhead + assert total_time < max_expected_time + + finally: + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + +@pytest.mark.asyncio +async def test_http_worker_concurrent_tool_execution(): + """Test that multiple tools can execute concurrently via the /worker/tools/invoke route.""" + process, port = start_mcp_server("http") + assert port is not None + + base_url = f"http://127.0.0.1:{port}" + + try: + wait_for_http_server_ready(port, timeout=10) + + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + async with httpx.AsyncClient(base_url=base_url, timeout=30.0, headers=headers) as client: + # Call the tool three times concurrently. Each tool call takes 1 second to execute. + # Since the server should be able to execute the tools in parallel, the total time should be around 1 second + delay_seconds = 1.0 + + tool_requests = [ + { + "execution_id": "worker_exec_0", + "tool": { + "toolkit": "Server", + "name": "SlowAsyncTool", + }, + "inputs": {"delay_seconds": delay_seconds}, + }, + { + "execution_id": "worker_exec_1", + "tool": { + "toolkit": "Server", + "name": "SlowSyncTool", + }, + "inputs": {"delay_seconds": delay_seconds}, + }, + { + "execution_id": "worker_exec_2", + "tool": { + "toolkit": "Server", + "name": "SlowSyncTool", + }, + "inputs": {"delay_seconds": delay_seconds}, + }, + ] + + start_time = time.time() + responses = await asyncio.gather(*[ + client.post("/worker/tools/invoke", json=req) for req in tool_requests + ]) + total_time = time.time() - start_time + + assert all(r.status_code == 200 for r in responses), "All requests should succeed" + + for idx, response in enumerate(responses): + data = response.json() + assert data["success"] is True + assert data["execution_id"] == f"worker_exec_{idx}" + assert data["output"]["value"] is not None + assert f"after {delay_seconds}s" in data["output"]["value"] + + # If parallel, should take ~1s, not ~3s + max_expected_time = delay_seconds + 0.5 # Allow 0.5s overhead + assert total_time < max_expected_time + + finally: + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + +@pytest.mark.asyncio +async def test_http_mixed_route_concurrent_execution(): + """Test concurrent tool execution across both MCP and Worker routes simultaneously.""" + process, port = start_mcp_server("http") + assert port is not None + + base_url = f"http://127.0.0.1:{port}" + + try: + wait_for_http_server_ready(port, timeout=10) + + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + async with httpx.AsyncClient(base_url=base_url, timeout=30.0, headers=headers) as client: + # First, set up the client-server connection for the /mcp route + init_request = build_jsonrpc_request( + "initialize", + { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }, + request_id=1, + ) + init_response = await client.post("/mcp", json=init_request) + session_id = init_response.headers.get("mcp-session-id") + + mcp_headers = {**headers, "Mcp-Session-Id": session_id} + + delay_seconds = 1.0 + + await client.post( + "/mcp", + json=build_jsonrpc_request("notifications/initialized", None, None), + headers=mcp_headers, + ) + + # Prepare the tool calls for both routes + mcp_requests = [ + build_jsonrpc_request( + "tools/call", + { + "name": "Server_SlowAsyncTool", + "arguments": {"delay_seconds": delay_seconds}, + }, + request_id=10, + ), + build_jsonrpc_request( + "tools/call", + { + "name": "Server_SlowSyncTool", + "arguments": {"delay_seconds": delay_seconds}, + }, + request_id=11, + ), + ] + + worker_requests = [ + { + "execution_id": "worker_exec_0", + "tool": { + "toolkit": "Server", + "name": "SlowAsyncTool", + }, + "inputs": {"delay_seconds": delay_seconds}, + }, + { + "execution_id": "worker_exec_1", + "tool": { + "toolkit": "Server", + "name": "SlowSyncTool", + }, + "inputs": {"delay_seconds": delay_seconds}, + }, + ] + + # Execute + start_time = time.time() + mcp_responses, worker_responses = await asyncio.gather( + asyncio.gather(*[ + client.post("/mcp", json=req, headers=mcp_headers) for req in mcp_requests + ]), + asyncio.gather(*[ + client.post("/worker/tools/invoke", json=req) for req in worker_requests + ]), + ) + total_time = time.time() - start_time + + assert all(r.status_code == 200 for r in mcp_responses) + assert all(r.status_code == 200 for r in worker_responses) + + # Called the tools four times concurrently (2 MCP + 2 Worker). Each tool call takes 1 second to execute. + # Since the server should be able to execute the tools in parallel, the total time should be around 1 second + max_expected_time = delay_seconds + 0.5 # Allow 0.5s overhead for mixed routes + assert total_time < max_expected_time + + finally: + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() diff --git a/libs/tests/arcade_mcp_server/test_mcp_app.py b/libs/tests/arcade_mcp_server/test_mcp_app.py index 4192b51e..a81865d0 100644 --- a/libs/tests/arcade_mcp_server/test_mcp_app.py +++ b/libs/tests/arcade_mcp_server/test_mcp_app.py @@ -329,7 +329,7 @@ class TestMCPApp: """Test _create_and_run_server method with mocked dependencies.""" with ( patch("arcade_mcp_server.mcp_app.create_arcade_mcp") as mock_create, - patch("arcade_mcp_server.mcp_app.uvicorn") as mock_uvicorn, + patch("arcade_mcp_server.mcp_app.serve_with_force_quit") as mock_serve, ): mock_fastapi_app = Mock() mock_create.return_value = mock_fastapi_app @@ -343,19 +343,17 @@ class TestMCPApp: mcp_settings=mcp_app._mcp_settings, debug=False, ) - mock_uvicorn.run.assert_called_once_with( - mock_fastapi_app, + mock_serve.assert_called_once_with( + app=mock_fastapi_app, host="127.0.0.1", port=8000, log_level="info", - reload=False, - lifespan="on", ) # Test with DEBUG log level with ( patch("arcade_mcp_server.mcp_app.create_arcade_mcp") as mock_create, - patch("arcade_mcp_server.mcp_app.uvicorn") as mock_uvicorn, + patch("arcade_mcp_server.mcp_app.serve_with_force_quit") as mock_serve, ): mock_fastapi_app = Mock() mock_create.return_value = mock_fastapi_app @@ -368,13 +366,11 @@ class TestMCPApp: mcp_settings=mcp_app._mcp_settings, debug=True, ) - mock_uvicorn.run.assert_called_once_with( - mock_fastapi_app, + mock_serve.assert_called_once_with( + app=mock_fastapi_app, host="192.168.1.1", port=9000, log_level="debug", - reload=False, - lifespan="on", ) def test_run_with_reload_spawns_child_process(self, mcp_app: MCPApp): diff --git a/pyproject.toml b/pyproject.toml index 1afdbf0d..deaa7fdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,6 +136,8 @@ omit = [ "*/test_*", "*/__pycache__/*", ] +parallel = true +patch = ["subprocess"] [tool.coverage.report] exclude_lines = [