Worker Stability (#688)

This PR does three things:
1. Executes synchronous tool calls in thread pool allowing for up to 4 +
# of CPUs executions in parallel.
2. Makes force quitting via double SIGINT/SIGTERM possible and via
single SIGINT/SIGTERM + graceful shutdown timeout expiry possible, even
if there are active connections.
3. Sets `timeout_graceful_shutdown` to
`ARCADE_UVICORN_TIMEOUT_GRACEFUL_SHUTDOWN` env var if set, else defaults
to 15.
4. Disable the worker health check span to reduce noise

Tradeoffs:
Since this PR introduces executing synchronous tools via `await
asyncio.to_thread(func, **func_args)`, this means that there is no way
for the thread to be killed until it finishes. The ramifications of this
is that the force quitting logic that is also implemented in this PR has
to be very harsh `os._exit(1)` just in case there is a sync tool
actively executing. This means that `MCPApp` teardown logic will not
execute when force quitting is required. Although this was already the
case because we weren't previously able to force quit! This tradeoff is
justified for now since "parallel" tool executions will relieve us of
many worker timeouts that we are seeing in prod.

Future work:
Minimize/eliminate the need for `os._exit(1)` such that `MCPApp`
teardown logic will always execute, even when force quitting. The
solution will likely be moving away from `await asyncio.to_thread(func,
**func_args)` (while maintaining "parallelism" and then utilize the
`TaskTrackerMiddleware` introduced in this PR to cancel all of the
active HTTP requests.

Resolves PLT-713
This commit is contained in:
Eric Gustin 2025-11-20 11:13:41 -08:00 committed by GitHub
parent b4720c2988
commit 5602578b2f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 480 additions and 40 deletions

View file

@ -59,7 +59,7 @@ class ToolExecutor:
if asyncio.iscoroutinefunction(func):
results = await func(**func_args)
else:
results = func(**func_args)
results = await asyncio.to_thread(func, **func_args)
# serialize the output model
output = await ToolExecutor._serialize_output(output_model, results)

View file

@ -1,6 +1,6 @@
[project]
name = "arcade-core"
version = "3.3.4"
version = "3.3.5"
description = "Arcade Core - Core library for Arcade platform"
readme = "README.md"
license = {text = "MIT"}

View file

@ -15,7 +15,6 @@ from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Literal, ParamSpec, TypeVar, cast
import uvicorn
from arcade_core.catalog import MaterializedTool, ToolCatalog, ToolDefinitionError
from arcade_tdk.auth import ToolAuthorization
from arcade_tdk.error_adapters import ErrorAdapter
@ -29,7 +28,7 @@ from arcade_mcp_server.server import MCPServer
from arcade_mcp_server.settings import MCPSettings, ServerSettings
from arcade_mcp_server.types import Prompt, PromptMessage, Resource
from arcade_mcp_server.usage import ServerTracker
from arcade_mcp_server.worker import create_arcade_mcp
from arcade_mcp_server.worker import create_arcade_mcp, serve_with_force_quit
P = ParamSpec("P")
T = TypeVar("T")
@ -410,13 +409,14 @@ class MCPApp:
port=port,
tool_count=len(self._catalog),
)
uvicorn.run(
app,
host=host,
port=port,
log_level=log_level,
reload=False, # MCPApp handles its own reload via parent/child process pattern
lifespan="on",
asyncio.run(
serve_with_force_quit(
app=app,
host=host,
port=port,
log_level=log_level,
)
)
@staticmethod

View file

@ -7,16 +7,19 @@ MCP Server endpoints over HTTP/SSE. MCP is always enabled in this integrated mod
import asyncio
import logging
from collections.abc import AsyncGenerator, AsyncIterator
import os
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from types import FrameType
from typing import Any
import uvicorn
from arcade_core.catalog import ToolCatalog
from arcade_serve.fastapi import FastAPIWorker, TaskTrackerMiddleware
from arcade_serve.fastapi.telemetry import OTELHandler
from arcade_serve.fastapi.worker import FastAPIWorker
from fastapi import FastAPI
from loguru import logger
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
@ -26,6 +29,44 @@ from arcade_mcp_server.settings import MCPSettings
from arcade_mcp_server.transports.http_session_manager import HTTPSessionManager
class CustomUvicornServer(uvicorn.Server):
"""Uvicorn server with force quit support on double SIGINT/SIGTERM."""
def __init__(self, config: uvicorn.Config, task_tracker: TaskTrackerMiddleware):
super().__init__(config)
self.task_tracker = task_tracker
self._signal_count = 0
def handle_exit(self, sig: int, frame: FrameType | None) -> None:
"""
Handle termination signals with force quit on second signal.
First signal (SIGINT/SIGTERM): Graceful shutdown
Second signal: Force quit with os._exit(1)
"""
self._signal_count += 1
if self._signal_count == 1:
logger.info("Shutting down gracefully. Press Ctrl+C again to force quit.")
self.should_exit = True
else:
logger.warning("Force quit triggered - exiting immediately")
os._exit(1)
async def _wait_tasks_to_complete(self) -> None:
try:
# Let Uvicorn's normal wait process run
await super()._wait_tasks_to_complete()
except asyncio.CancelledError:
# If we're cancelled (graceful shutdown time expired), then
# we need to cancel the active HTTP request tasks that we are tracking
logger.warning("Force quit triggered - cancelling all active requests")
cancelled = self.task_tracker.cancel_all_tasks()
logger.info(f"Cancelled {cancelled} active request(s)")
self.force_exit = True
os._exit(1)
@asynccontextmanager
async def create_lifespan(
catalog: ToolCatalog,
@ -125,6 +166,17 @@ def create_arcade_mcp(
lifespan=lifespan,
)
otel_handler.instrument_app(app)
task_tracker = TaskTrackerMiddleware(app)
app.state.task_tracker = task_tracker
# Since this middleware tracks all HTTP requests, it must be added first
@app.middleware("http")
async def track_tasks_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
return await task_tracker.dispatch(request, call_next)
app.add_middleware(AddTrailingSlashToPathMiddleware)
# Worker endpoints
@ -267,6 +319,32 @@ def create_arcade_mcp_factory() -> FastAPI:
)
async def serve_with_force_quit(
app: FastAPI,
host: str,
port: int,
log_level: str,
) -> None:
"""Serve the FastAPI app with force quit capability."""
timeout_graceful_shutdown = int(
os.environ.get("ARCADE_UVICORN_TIMEOUT_GRACEFUL_SHUTDOWN", "15")
)
config = uvicorn.Config(
app=app,
host=host,
port=port,
log_level=log_level,
lifespan="on",
timeout_graceful_shutdown=timeout_graceful_shutdown,
)
task_tracker = app.state.task_tracker
server = CustomUvicornServer(config, task_tracker)
await server.serve()
def run_arcade_mcp(
catalog: ToolCatalog,
host: str = "127.0.0.1",
@ -286,11 +364,14 @@ def run_arcade_mcp(
This is used for module execution (`arcade mcp` and `python -m arcade_mcp_server`) only.
MCPApp has its own reload mechanism.
"""
import os
log_level = "debug" if debug else "info"
if reload:
# TODO: This reload path uses uvicorn.run(), which bypasses serve_with_force_quit().
# This means that the server will not be able to force quit when there are active
# tool executions or active connections with MCP clients. For this reason, prefer
# to use MCPApp.run() for reload mode.
# Set env vars for the app factory to read later
os.environ["ARCADE_MCP_DEBUG"] = str(debug)
os.environ["ARCADE_MCP_OTEL_ENABLE"] = str(otel_enable)
@ -327,11 +408,11 @@ def run_arcade_mcp(
**kwargs,
)
uvicorn.run(
app,
host=host,
port=port,
log_level=log_level,
reload=reload,
lifespan="on",
asyncio.run(
serve_with_force_quit(
app=app,
host=host,
port=port,
log_level=log_level,
)
)

View file

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "arcade-mcp-server"
version = "1.7.3"
version = "1.8.0"
description = "Model Context Protocol (MCP) server framework for Arcade.dev"
readme = "README.md"
authors = [{ name = "Arcade.dev" }]

View file

@ -96,6 +96,4 @@ class HealthCheckComponent(WorkerComponent):
"""
Handle the request to check the health of the worker.
"""
tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("HealthCheck"):
return self.worker.health_check()
return self.worker.health_check()

View file

@ -1,3 +1,4 @@
from .task_tracker import TaskTrackerMiddleware
from .worker import FastAPIWorker
__all__ = ["FastAPIWorker"]
__all__ = ["FastAPIWorker", "TaskTrackerMiddleware"]

View file

@ -0,0 +1,57 @@
import asyncio
import threading
from collections.abc import Awaitable, Callable
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp
class TaskTrackerMiddleware(BaseHTTPMiddleware):
"""Middleware that tracks active HTTP request tasks for force quit functionality."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
self._active_tasks: set[asyncio.Task] = set()
self._lock = threading.Lock()
async def dispatch(
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Track the current task while handling the request."""
task = asyncio.current_task()
with self._lock:
if task:
self._active_tasks.add(task)
try:
response = await call_next(request)
return response
finally:
with self._lock:
if task:
self._active_tasks.discard(task)
def cancel_all_tasks(self) -> int:
"""
Cancel all tracked (active) HTTP request tasks.
This method must be called from within the asyncio event loop's thread
(not from background thread) because it calls task.cancel()
Returns:
int: Number of tasks successfully cancelled.
"""
# Make a copy to avoid mutation during iteration
with self._lock:
tasks_to_cancel = list(self._active_tasks)
cancelled_count = 0
for task in tasks_to_cancel:
if not task.done():
task.cancel()
cancelled_count += 1
return cancelled_count

View file

@ -1,6 +1,6 @@
[project]
name = "arcade-serve"
version = "3.0.0"
version = "3.1.0"
description = "Arcade Serve - Serving infrastructure for Arcade tools and workers"
readme = "README.md"
license = {text = "MIT"}

View file

@ -0,0 +1,21 @@
import asyncio
import time
from typing import Annotated
from arcade_mcp_server import Context, tool
@tool
async def slow_async_tool(
context: Context, delay_seconds: Annotated[float, "Delay in seconds"]
) -> str:
"""A tool that takes time to execute (async)."""
await asyncio.sleep(delay_seconds)
return f"Completed async task after {delay_seconds}s"
@tool
def slow_sync_tool(context: Context, delay_seconds: Annotated[float, "Delay in seconds"]) -> str:
"""A tool that takes time to execute (sync)."""
time.sleep(delay_seconds)
return f"Completed sync task after {delay_seconds}s"

View file

@ -322,7 +322,7 @@ async def test_stdio_e2e():
assert "result" in list_tools_response
assert "tools" in list_tools_response["result"]
tools = list_tools_response["result"]["tools"]
assert len(tools) == 7
assert len(tools) == 9
# 5. Call logging_tool
logging_id = client.send_request(
@ -558,7 +558,7 @@ async def test_http_e2e():
assert "result" in list_tools_data
assert "tools" in list_tools_data["result"]
tools = list_tools_data["result"]["tools"]
assert len(tools) == 7
assert len(tools) == 9
# 5. Call logging_tool
logging_request = build_jsonrpc_request(
@ -628,3 +628,287 @@ async def test_http_e2e():
except subprocess.TimeoutExpired:
process.kill()
process.wait()
@pytest.mark.asyncio
async def test_http_mcp_concurrent_tool_execution():
"""Test that multiple tools can execute concurrently via the /mcp route."""
process, port = start_mcp_server("http")
assert port is not None
base_url = f"http://127.0.0.1:{port}"
try:
wait_for_http_server_ready(port, timeout=10)
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
async with httpx.AsyncClient(base_url=base_url, timeout=30.0, headers=headers) as client:
# Initialize the connection with the server
init_request = build_jsonrpc_request(
"initialize",
{
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0"},
},
request_id=1,
)
init_response = await client.post("/mcp", json=init_request)
assert init_response.status_code == 200
session_id = init_response.headers.get("mcp-session-id")
assert session_id is not None
client.headers.update({"Mcp-Session-Id": session_id})
init_notif = build_jsonrpc_request(
"notifications/initialized", params=None, request_id=None
)
await client.post("/mcp", json=init_notif)
# Call the tool three times concurrently. Each tool call takes 1 second to execute.
# Since the server should be able to execute the tools in parallel, the total time should be around 1 second
delay_seconds = 1.0
tool_requests = [
build_jsonrpc_request(
"tools/call",
{
"name": "Server_SlowAsyncTool",
"arguments": {"delay_seconds": delay_seconds},
},
request_id=10,
),
build_jsonrpc_request(
"tools/call",
{
"name": "Server_SlowSyncTool",
"arguments": {"delay_seconds": delay_seconds},
},
request_id=11,
),
build_jsonrpc_request(
"tools/call",
{
"name": "Server_SlowSyncTool",
"arguments": {"delay_seconds": delay_seconds},
},
request_id=12,
),
]
start_time = time.time()
responses = await asyncio.gather(*[
client.post("/mcp", json=req) for req in tool_requests
])
total_time = time.time() - start_time
assert all(r.status_code == 200 for r in responses), "All requests should succeed"
for idx, response in enumerate(responses):
data = response.json()
assert data["jsonrpc"] == "2.0"
assert data["id"] == idx + 10
assert "result" in data
assert "error" not in data
assert f"after {delay_seconds}s" in data["result"]["content"][0]["text"]
# If parallel, should take ~1s, not ~3s
max_expected_time = delay_seconds + 0.5 # Allow 0.5s overhead
assert total_time < max_expected_time
finally:
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
process.wait()
@pytest.mark.asyncio
async def test_http_worker_concurrent_tool_execution():
"""Test that multiple tools can execute concurrently via the /worker/tools/invoke route."""
process, port = start_mcp_server("http")
assert port is not None
base_url = f"http://127.0.0.1:{port}"
try:
wait_for_http_server_ready(port, timeout=10)
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
async with httpx.AsyncClient(base_url=base_url, timeout=30.0, headers=headers) as client:
# Call the tool three times concurrently. Each tool call takes 1 second to execute.
# Since the server should be able to execute the tools in parallel, the total time should be around 1 second
delay_seconds = 1.0
tool_requests = [
{
"execution_id": "worker_exec_0",
"tool": {
"toolkit": "Server",
"name": "SlowAsyncTool",
},
"inputs": {"delay_seconds": delay_seconds},
},
{
"execution_id": "worker_exec_1",
"tool": {
"toolkit": "Server",
"name": "SlowSyncTool",
},
"inputs": {"delay_seconds": delay_seconds},
},
{
"execution_id": "worker_exec_2",
"tool": {
"toolkit": "Server",
"name": "SlowSyncTool",
},
"inputs": {"delay_seconds": delay_seconds},
},
]
start_time = time.time()
responses = await asyncio.gather(*[
client.post("/worker/tools/invoke", json=req) for req in tool_requests
])
total_time = time.time() - start_time
assert all(r.status_code == 200 for r in responses), "All requests should succeed"
for idx, response in enumerate(responses):
data = response.json()
assert data["success"] is True
assert data["execution_id"] == f"worker_exec_{idx}"
assert data["output"]["value"] is not None
assert f"after {delay_seconds}s" in data["output"]["value"]
# If parallel, should take ~1s, not ~3s
max_expected_time = delay_seconds + 0.5 # Allow 0.5s overhead
assert total_time < max_expected_time
finally:
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
process.wait()
@pytest.mark.asyncio
async def test_http_mixed_route_concurrent_execution():
"""Test concurrent tool execution across both MCP and Worker routes simultaneously."""
process, port = start_mcp_server("http")
assert port is not None
base_url = f"http://127.0.0.1:{port}"
try:
wait_for_http_server_ready(port, timeout=10)
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
async with httpx.AsyncClient(base_url=base_url, timeout=30.0, headers=headers) as client:
# First, set up the client-server connection for the /mcp route
init_request = build_jsonrpc_request(
"initialize",
{
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0"},
},
request_id=1,
)
init_response = await client.post("/mcp", json=init_request)
session_id = init_response.headers.get("mcp-session-id")
mcp_headers = {**headers, "Mcp-Session-Id": session_id}
delay_seconds = 1.0
await client.post(
"/mcp",
json=build_jsonrpc_request("notifications/initialized", None, None),
headers=mcp_headers,
)
# Prepare the tool calls for both routes
mcp_requests = [
build_jsonrpc_request(
"tools/call",
{
"name": "Server_SlowAsyncTool",
"arguments": {"delay_seconds": delay_seconds},
},
request_id=10,
),
build_jsonrpc_request(
"tools/call",
{
"name": "Server_SlowSyncTool",
"arguments": {"delay_seconds": delay_seconds},
},
request_id=11,
),
]
worker_requests = [
{
"execution_id": "worker_exec_0",
"tool": {
"toolkit": "Server",
"name": "SlowAsyncTool",
},
"inputs": {"delay_seconds": delay_seconds},
},
{
"execution_id": "worker_exec_1",
"tool": {
"toolkit": "Server",
"name": "SlowSyncTool",
},
"inputs": {"delay_seconds": delay_seconds},
},
]
# Execute
start_time = time.time()
mcp_responses, worker_responses = await asyncio.gather(
asyncio.gather(*[
client.post("/mcp", json=req, headers=mcp_headers) for req in mcp_requests
]),
asyncio.gather(*[
client.post("/worker/tools/invoke", json=req) for req in worker_requests
]),
)
total_time = time.time() - start_time
assert all(r.status_code == 200 for r in mcp_responses)
assert all(r.status_code == 200 for r in worker_responses)
# Called the tools four times concurrently (2 MCP + 2 Worker). Each tool call takes 1 second to execute.
# Since the server should be able to execute the tools in parallel, the total time should be around 1 second
max_expected_time = delay_seconds + 0.5 # Allow 0.5s overhead for mixed routes
assert total_time < max_expected_time
finally:
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
process.wait()

View file

@ -329,7 +329,7 @@ class TestMCPApp:
"""Test _create_and_run_server method with mocked dependencies."""
with (
patch("arcade_mcp_server.mcp_app.create_arcade_mcp") as mock_create,
patch("arcade_mcp_server.mcp_app.uvicorn") as mock_uvicorn,
patch("arcade_mcp_server.mcp_app.serve_with_force_quit") as mock_serve,
):
mock_fastapi_app = Mock()
mock_create.return_value = mock_fastapi_app
@ -343,19 +343,17 @@ class TestMCPApp:
mcp_settings=mcp_app._mcp_settings,
debug=False,
)
mock_uvicorn.run.assert_called_once_with(
mock_fastapi_app,
mock_serve.assert_called_once_with(
app=mock_fastapi_app,
host="127.0.0.1",
port=8000,
log_level="info",
reload=False,
lifespan="on",
)
# Test with DEBUG log level
with (
patch("arcade_mcp_server.mcp_app.create_arcade_mcp") as mock_create,
patch("arcade_mcp_server.mcp_app.uvicorn") as mock_uvicorn,
patch("arcade_mcp_server.mcp_app.serve_with_force_quit") as mock_serve,
):
mock_fastapi_app = Mock()
mock_create.return_value = mock_fastapi_app
@ -368,13 +366,11 @@ class TestMCPApp:
mcp_settings=mcp_app._mcp_settings,
debug=True,
)
mock_uvicorn.run.assert_called_once_with(
mock_fastapi_app,
mock_serve.assert_called_once_with(
app=mock_fastapi_app,
host="192.168.1.1",
port=9000,
log_level="debug",
reload=False,
lifespan="on",
)
def test_run_with_reload_spawns_child_process(self, mcp_app: MCPApp):

View file

@ -136,6 +136,8 @@ omit = [
"*/test_*",
"*/__pycache__/*",
]
parallel = true
patch = ["subprocess"]
[tool.coverage.report]
exclude_lines = [