Worker Stability (#688)
This PR does three things: 1. Executes synchronous tool calls in thread pool allowing for up to 4 + # of CPUs executions in parallel. 2. Makes force quitting via double SIGINT/SIGTERM possible and via single SIGINT/SIGTERM + graceful shutdown timeout expiry possible, even if there are active connections. 3. Sets `timeout_graceful_shutdown` to `ARCADE_UVICORN_TIMEOUT_GRACEFUL_SHUTDOWN` env var if set, else defaults to 15. 4. Disable the worker health check span to reduce noise Tradeoffs: Since this PR introduces executing synchronous tools via `await asyncio.to_thread(func, **func_args)`, this means that there is no way for the thread to be killed until it finishes. The ramifications of this is that the force quitting logic that is also implemented in this PR has to be very harsh `os._exit(1)` just in case there is a sync tool actively executing. This means that `MCPApp` teardown logic will not execute when force quitting is required. Although this was already the case because we weren't previously able to force quit! This tradeoff is justified for now since "parallel" tool executions will relieve us of many worker timeouts that we are seeing in prod. Future work: Minimize/eliminate the need for `os._exit(1)` such that `MCPApp` teardown logic will always execute, even when force quitting. The solution will likely be moving away from `await asyncio.to_thread(func, **func_args)` (while maintaining "parallelism" and then utilize the `TaskTrackerMiddleware` introduced in this PR to cancel all of the active HTTP requests. Resolves PLT-713
This commit is contained in:
parent
b4720c2988
commit
5602578b2f
13 changed files with 480 additions and 40 deletions
|
|
@ -59,7 +59,7 @@ class ToolExecutor:
|
|||
if asyncio.iscoroutinefunction(func):
|
||||
results = await func(**func_args)
|
||||
else:
|
||||
results = func(**func_args)
|
||||
results = await asyncio.to_thread(func, **func_args)
|
||||
|
||||
# serialize the output model
|
||||
output = await ToolExecutor._serialize_output(output_model, results)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "arcade-core"
|
||||
version = "3.3.4"
|
||||
version = "3.3.5"
|
||||
description = "Arcade Core - Core library for Arcade platform"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from pathlib import Path
|
|||
from types import ModuleType
|
||||
from typing import Any, Callable, Literal, ParamSpec, TypeVar, cast
|
||||
|
||||
import uvicorn
|
||||
from arcade_core.catalog import MaterializedTool, ToolCatalog, ToolDefinitionError
|
||||
from arcade_tdk.auth import ToolAuthorization
|
||||
from arcade_tdk.error_adapters import ErrorAdapter
|
||||
|
|
@ -29,7 +28,7 @@ from arcade_mcp_server.server import MCPServer
|
|||
from arcade_mcp_server.settings import MCPSettings, ServerSettings
|
||||
from arcade_mcp_server.types import Prompt, PromptMessage, Resource
|
||||
from arcade_mcp_server.usage import ServerTracker
|
||||
from arcade_mcp_server.worker import create_arcade_mcp
|
||||
from arcade_mcp_server.worker import create_arcade_mcp, serve_with_force_quit
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
|
@ -410,13 +409,14 @@ class MCPApp:
|
|||
port=port,
|
||||
tool_count=len(self._catalog),
|
||||
)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level=log_level,
|
||||
reload=False, # MCPApp handles its own reload via parent/child process pattern
|
||||
lifespan="on",
|
||||
|
||||
asyncio.run(
|
||||
serve_with_force_quit(
|
||||
app=app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level=log_level,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -7,16 +7,19 @@ MCP Server endpoints over HTTP/SSE. MCP is always enabled in this integrated mod
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
import os
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from types import FrameType
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from arcade_core.catalog import ToolCatalog
|
||||
from arcade_serve.fastapi import FastAPIWorker, TaskTrackerMiddleware
|
||||
from arcade_serve.fastapi.telemetry import OTELHandler
|
||||
from arcade_serve.fastapi.worker import FastAPIWorker
|
||||
from fastapi import FastAPI
|
||||
from loguru import logger
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
|
|
@ -26,6 +29,44 @@ from arcade_mcp_server.settings import MCPSettings
|
|||
from arcade_mcp_server.transports.http_session_manager import HTTPSessionManager
|
||||
|
||||
|
||||
class CustomUvicornServer(uvicorn.Server):
|
||||
"""Uvicorn server with force quit support on double SIGINT/SIGTERM."""
|
||||
|
||||
def __init__(self, config: uvicorn.Config, task_tracker: TaskTrackerMiddleware):
|
||||
super().__init__(config)
|
||||
self.task_tracker = task_tracker
|
||||
self._signal_count = 0
|
||||
|
||||
def handle_exit(self, sig: int, frame: FrameType | None) -> None:
|
||||
"""
|
||||
Handle termination signals with force quit on second signal.
|
||||
|
||||
First signal (SIGINT/SIGTERM): Graceful shutdown
|
||||
Second signal: Force quit with os._exit(1)
|
||||
"""
|
||||
self._signal_count += 1
|
||||
|
||||
if self._signal_count == 1:
|
||||
logger.info("Shutting down gracefully. Press Ctrl+C again to force quit.")
|
||||
self.should_exit = True
|
||||
else:
|
||||
logger.warning("Force quit triggered - exiting immediately")
|
||||
os._exit(1)
|
||||
|
||||
async def _wait_tasks_to_complete(self) -> None:
|
||||
try:
|
||||
# Let Uvicorn's normal wait process run
|
||||
await super()._wait_tasks_to_complete()
|
||||
except asyncio.CancelledError:
|
||||
# If we're cancelled (graceful shutdown time expired), then
|
||||
# we need to cancel the active HTTP request tasks that we are tracking
|
||||
logger.warning("Force quit triggered - cancelling all active requests")
|
||||
cancelled = self.task_tracker.cancel_all_tasks()
|
||||
logger.info(f"Cancelled {cancelled} active request(s)")
|
||||
self.force_exit = True
|
||||
os._exit(1)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_lifespan(
|
||||
catalog: ToolCatalog,
|
||||
|
|
@ -125,6 +166,17 @@ def create_arcade_mcp(
|
|||
lifespan=lifespan,
|
||||
)
|
||||
otel_handler.instrument_app(app)
|
||||
|
||||
task_tracker = TaskTrackerMiddleware(app)
|
||||
app.state.task_tracker = task_tracker
|
||||
|
||||
# Since this middleware tracks all HTTP requests, it must be added first
|
||||
@app.middleware("http")
|
||||
async def track_tasks_middleware(
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
return await task_tracker.dispatch(request, call_next)
|
||||
|
||||
app.add_middleware(AddTrailingSlashToPathMiddleware)
|
||||
|
||||
# Worker endpoints
|
||||
|
|
@ -267,6 +319,32 @@ def create_arcade_mcp_factory() -> FastAPI:
|
|||
)
|
||||
|
||||
|
||||
async def serve_with_force_quit(
|
||||
app: FastAPI,
|
||||
host: str,
|
||||
port: int,
|
||||
log_level: str,
|
||||
) -> None:
|
||||
"""Serve the FastAPI app with force quit capability."""
|
||||
timeout_graceful_shutdown = int(
|
||||
os.environ.get("ARCADE_UVICORN_TIMEOUT_GRACEFUL_SHUTDOWN", "15")
|
||||
)
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level=log_level,
|
||||
lifespan="on",
|
||||
timeout_graceful_shutdown=timeout_graceful_shutdown,
|
||||
)
|
||||
|
||||
task_tracker = app.state.task_tracker
|
||||
server = CustomUvicornServer(config, task_tracker)
|
||||
|
||||
await server.serve()
|
||||
|
||||
|
||||
def run_arcade_mcp(
|
||||
catalog: ToolCatalog,
|
||||
host: str = "127.0.0.1",
|
||||
|
|
@ -286,11 +364,14 @@ def run_arcade_mcp(
|
|||
This is used for module execution (`arcade mcp` and `python -m arcade_mcp_server`) only.
|
||||
MCPApp has its own reload mechanism.
|
||||
"""
|
||||
import os
|
||||
|
||||
log_level = "debug" if debug else "info"
|
||||
|
||||
if reload:
|
||||
# TODO: This reload path uses uvicorn.run(), which bypasses serve_with_force_quit().
|
||||
# This means that the server will not be able to force quit when there are active
|
||||
# tool executions or active connections with MCP clients. For this reason, prefer
|
||||
# to use MCPApp.run() for reload mode.
|
||||
|
||||
# Set env vars for the app factory to read later
|
||||
os.environ["ARCADE_MCP_DEBUG"] = str(debug)
|
||||
os.environ["ARCADE_MCP_OTEL_ENABLE"] = str(otel_enable)
|
||||
|
|
@ -327,11 +408,11 @@ def run_arcade_mcp(
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level=log_level,
|
||||
reload=reload,
|
||||
lifespan="on",
|
||||
asyncio.run(
|
||||
serve_with_force_quit(
|
||||
app=app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level=log_level,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||
|
||||
[project]
|
||||
name = "arcade-mcp-server"
|
||||
version = "1.7.3"
|
||||
version = "1.8.0"
|
||||
description = "Model Context Protocol (MCP) server framework for Arcade.dev"
|
||||
readme = "README.md"
|
||||
authors = [{ name = "Arcade.dev" }]
|
||||
|
|
|
|||
|
|
@ -96,6 +96,4 @@ class HealthCheckComponent(WorkerComponent):
|
|||
"""
|
||||
Handle the request to check the health of the worker.
|
||||
"""
|
||||
tracer = trace.get_tracer(__name__)
|
||||
with tracer.start_as_current_span("HealthCheck"):
|
||||
return self.worker.health_check()
|
||||
return self.worker.health_check()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from .task_tracker import TaskTrackerMiddleware
|
||||
from .worker import FastAPIWorker
|
||||
|
||||
__all__ = ["FastAPIWorker"]
|
||||
__all__ = ["FastAPIWorker", "TaskTrackerMiddleware"]
|
||||
|
|
|
|||
57
libs/arcade-serve/arcade_serve/fastapi/task_tracker.py
Normal file
57
libs/arcade-serve/arcade_serve/fastapi/task_tracker.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import asyncio
|
||||
import threading
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
|
||||
class TaskTrackerMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that tracks active HTTP request tasks for force quit functionality."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
self._active_tasks: set[asyncio.Task] = set()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
"""Track the current task while handling the request."""
|
||||
task = asyncio.current_task()
|
||||
|
||||
with self._lock:
|
||||
if task:
|
||||
self._active_tasks.add(task)
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
finally:
|
||||
with self._lock:
|
||||
if task:
|
||||
self._active_tasks.discard(task)
|
||||
|
||||
def cancel_all_tasks(self) -> int:
|
||||
"""
|
||||
Cancel all tracked (active) HTTP request tasks.
|
||||
|
||||
This method must be called from within the asyncio event loop's thread
|
||||
(not from background thread) because it calls task.cancel()
|
||||
|
||||
Returns:
|
||||
int: Number of tasks successfully cancelled.
|
||||
"""
|
||||
# Make a copy to avoid mutation during iteration
|
||||
with self._lock:
|
||||
tasks_to_cancel = list(self._active_tasks)
|
||||
|
||||
cancelled_count = 0
|
||||
for task in tasks_to_cancel:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
cancelled_count += 1
|
||||
|
||||
return cancelled_count
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "arcade-serve"
|
||||
version = "3.0.0"
|
||||
version = "3.1.0"
|
||||
description = "Arcade Serve - Serving infrastructure for Arcade tools and workers"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
import asyncio
|
||||
import time
|
||||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import Context, tool
|
||||
|
||||
|
||||
@tool
|
||||
async def slow_async_tool(
|
||||
context: Context, delay_seconds: Annotated[float, "Delay in seconds"]
|
||||
) -> str:
|
||||
"""A tool that takes time to execute (async)."""
|
||||
await asyncio.sleep(delay_seconds)
|
||||
return f"Completed async task after {delay_seconds}s"
|
||||
|
||||
|
||||
@tool
|
||||
def slow_sync_tool(context: Context, delay_seconds: Annotated[float, "Delay in seconds"]) -> str:
|
||||
"""A tool that takes time to execute (sync)."""
|
||||
time.sleep(delay_seconds)
|
||||
return f"Completed sync task after {delay_seconds}s"
|
||||
|
|
@ -322,7 +322,7 @@ async def test_stdio_e2e():
|
|||
assert "result" in list_tools_response
|
||||
assert "tools" in list_tools_response["result"]
|
||||
tools = list_tools_response["result"]["tools"]
|
||||
assert len(tools) == 7
|
||||
assert len(tools) == 9
|
||||
|
||||
# 5. Call logging_tool
|
||||
logging_id = client.send_request(
|
||||
|
|
@ -558,7 +558,7 @@ async def test_http_e2e():
|
|||
assert "result" in list_tools_data
|
||||
assert "tools" in list_tools_data["result"]
|
||||
tools = list_tools_data["result"]["tools"]
|
||||
assert len(tools) == 7
|
||||
assert len(tools) == 9
|
||||
|
||||
# 5. Call logging_tool
|
||||
logging_request = build_jsonrpc_request(
|
||||
|
|
@ -628,3 +628,287 @@ async def test_http_e2e():
|
|||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
process.wait()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_mcp_concurrent_tool_execution():
|
||||
"""Test that multiple tools can execute concurrently via the /mcp route."""
|
||||
process, port = start_mcp_server("http")
|
||||
assert port is not None
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
|
||||
try:
|
||||
wait_for_http_server_ready(port, timeout=10)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(base_url=base_url, timeout=30.0, headers=headers) as client:
|
||||
# Initialize the connection with the server
|
||||
init_request = build_jsonrpc_request(
|
||||
"initialize",
|
||||
{
|
||||
"protocolVersion": "2025-06-18",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.0"},
|
||||
},
|
||||
request_id=1,
|
||||
)
|
||||
|
||||
init_response = await client.post("/mcp", json=init_request)
|
||||
assert init_response.status_code == 200
|
||||
session_id = init_response.headers.get("mcp-session-id")
|
||||
assert session_id is not None
|
||||
|
||||
client.headers.update({"Mcp-Session-Id": session_id})
|
||||
|
||||
init_notif = build_jsonrpc_request(
|
||||
"notifications/initialized", params=None, request_id=None
|
||||
)
|
||||
await client.post("/mcp", json=init_notif)
|
||||
|
||||
# Call the tool three times concurrently. Each tool call takes 1 second to execute.
|
||||
# Since the server should be able to execute the tools in parallel, the total time should be around 1 second
|
||||
delay_seconds = 1.0
|
||||
|
||||
tool_requests = [
|
||||
build_jsonrpc_request(
|
||||
"tools/call",
|
||||
{
|
||||
"name": "Server_SlowAsyncTool",
|
||||
"arguments": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
request_id=10,
|
||||
),
|
||||
build_jsonrpc_request(
|
||||
"tools/call",
|
||||
{
|
||||
"name": "Server_SlowSyncTool",
|
||||
"arguments": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
request_id=11,
|
||||
),
|
||||
build_jsonrpc_request(
|
||||
"tools/call",
|
||||
{
|
||||
"name": "Server_SlowSyncTool",
|
||||
"arguments": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
request_id=12,
|
||||
),
|
||||
]
|
||||
|
||||
start_time = time.time()
|
||||
responses = await asyncio.gather(*[
|
||||
client.post("/mcp", json=req) for req in tool_requests
|
||||
])
|
||||
total_time = time.time() - start_time
|
||||
|
||||
assert all(r.status_code == 200 for r in responses), "All requests should succeed"
|
||||
|
||||
for idx, response in enumerate(responses):
|
||||
data = response.json()
|
||||
assert data["jsonrpc"] == "2.0"
|
||||
assert data["id"] == idx + 10
|
||||
assert "result" in data
|
||||
assert "error" not in data
|
||||
assert f"after {delay_seconds}s" in data["result"]["content"][0]["text"]
|
||||
|
||||
# If parallel, should take ~1s, not ~3s
|
||||
max_expected_time = delay_seconds + 0.5 # Allow 0.5s overhead
|
||||
assert total_time < max_expected_time
|
||||
|
||||
finally:
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
process.wait()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_worker_concurrent_tool_execution():
|
||||
"""Test that multiple tools can execute concurrently via the /worker/tools/invoke route."""
|
||||
process, port = start_mcp_server("http")
|
||||
assert port is not None
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
|
||||
try:
|
||||
wait_for_http_server_ready(port, timeout=10)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(base_url=base_url, timeout=30.0, headers=headers) as client:
|
||||
# Call the tool three times concurrently. Each tool call takes 1 second to execute.
|
||||
# Since the server should be able to execute the tools in parallel, the total time should be around 1 second
|
||||
delay_seconds = 1.0
|
||||
|
||||
tool_requests = [
|
||||
{
|
||||
"execution_id": "worker_exec_0",
|
||||
"tool": {
|
||||
"toolkit": "Server",
|
||||
"name": "SlowAsyncTool",
|
||||
},
|
||||
"inputs": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
{
|
||||
"execution_id": "worker_exec_1",
|
||||
"tool": {
|
||||
"toolkit": "Server",
|
||||
"name": "SlowSyncTool",
|
||||
},
|
||||
"inputs": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
{
|
||||
"execution_id": "worker_exec_2",
|
||||
"tool": {
|
||||
"toolkit": "Server",
|
||||
"name": "SlowSyncTool",
|
||||
},
|
||||
"inputs": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
]
|
||||
|
||||
start_time = time.time()
|
||||
responses = await asyncio.gather(*[
|
||||
client.post("/worker/tools/invoke", json=req) for req in tool_requests
|
||||
])
|
||||
total_time = time.time() - start_time
|
||||
|
||||
assert all(r.status_code == 200 for r in responses), "All requests should succeed"
|
||||
|
||||
for idx, response in enumerate(responses):
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["execution_id"] == f"worker_exec_{idx}"
|
||||
assert data["output"]["value"] is not None
|
||||
assert f"after {delay_seconds}s" in data["output"]["value"]
|
||||
|
||||
# If parallel, should take ~1s, not ~3s
|
||||
max_expected_time = delay_seconds + 0.5 # Allow 0.5s overhead
|
||||
assert total_time < max_expected_time
|
||||
|
||||
finally:
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
process.wait()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_mixed_route_concurrent_execution():
|
||||
"""Test concurrent tool execution across both MCP and Worker routes simultaneously."""
|
||||
process, port = start_mcp_server("http")
|
||||
assert port is not None
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
|
||||
try:
|
||||
wait_for_http_server_ready(port, timeout=10)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(base_url=base_url, timeout=30.0, headers=headers) as client:
|
||||
# First, set up the client-server connection for the /mcp route
|
||||
init_request = build_jsonrpc_request(
|
||||
"initialize",
|
||||
{
|
||||
"protocolVersion": "2025-06-18",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.0"},
|
||||
},
|
||||
request_id=1,
|
||||
)
|
||||
init_response = await client.post("/mcp", json=init_request)
|
||||
session_id = init_response.headers.get("mcp-session-id")
|
||||
|
||||
mcp_headers = {**headers, "Mcp-Session-Id": session_id}
|
||||
|
||||
delay_seconds = 1.0
|
||||
|
||||
await client.post(
|
||||
"/mcp",
|
||||
json=build_jsonrpc_request("notifications/initialized", None, None),
|
||||
headers=mcp_headers,
|
||||
)
|
||||
|
||||
# Prepare the tool calls for both routes
|
||||
mcp_requests = [
|
||||
build_jsonrpc_request(
|
||||
"tools/call",
|
||||
{
|
||||
"name": "Server_SlowAsyncTool",
|
||||
"arguments": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
request_id=10,
|
||||
),
|
||||
build_jsonrpc_request(
|
||||
"tools/call",
|
||||
{
|
||||
"name": "Server_SlowSyncTool",
|
||||
"arguments": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
request_id=11,
|
||||
),
|
||||
]
|
||||
|
||||
worker_requests = [
|
||||
{
|
||||
"execution_id": "worker_exec_0",
|
||||
"tool": {
|
||||
"toolkit": "Server",
|
||||
"name": "SlowAsyncTool",
|
||||
},
|
||||
"inputs": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
{
|
||||
"execution_id": "worker_exec_1",
|
||||
"tool": {
|
||||
"toolkit": "Server",
|
||||
"name": "SlowSyncTool",
|
||||
},
|
||||
"inputs": {"delay_seconds": delay_seconds},
|
||||
},
|
||||
]
|
||||
|
||||
# Execute
|
||||
start_time = time.time()
|
||||
mcp_responses, worker_responses = await asyncio.gather(
|
||||
asyncio.gather(*[
|
||||
client.post("/mcp", json=req, headers=mcp_headers) for req in mcp_requests
|
||||
]),
|
||||
asyncio.gather(*[
|
||||
client.post("/worker/tools/invoke", json=req) for req in worker_requests
|
||||
]),
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
|
||||
assert all(r.status_code == 200 for r in mcp_responses)
|
||||
assert all(r.status_code == 200 for r in worker_responses)
|
||||
|
||||
# Called the tools four times concurrently (2 MCP + 2 Worker). Each tool call takes 1 second to execute.
|
||||
# Since the server should be able to execute the tools in parallel, the total time should be around 1 second
|
||||
max_expected_time = delay_seconds + 0.5 # Allow 0.5s overhead for mixed routes
|
||||
assert total_time < max_expected_time
|
||||
|
||||
finally:
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
process.wait()
|
||||
|
|
|
|||
|
|
@ -329,7 +329,7 @@ class TestMCPApp:
|
|||
"""Test _create_and_run_server method with mocked dependencies."""
|
||||
with (
|
||||
patch("arcade_mcp_server.mcp_app.create_arcade_mcp") as mock_create,
|
||||
patch("arcade_mcp_server.mcp_app.uvicorn") as mock_uvicorn,
|
||||
patch("arcade_mcp_server.mcp_app.serve_with_force_quit") as mock_serve,
|
||||
):
|
||||
mock_fastapi_app = Mock()
|
||||
mock_create.return_value = mock_fastapi_app
|
||||
|
|
@ -343,19 +343,17 @@ class TestMCPApp:
|
|||
mcp_settings=mcp_app._mcp_settings,
|
||||
debug=False,
|
||||
)
|
||||
mock_uvicorn.run.assert_called_once_with(
|
||||
mock_fastapi_app,
|
||||
mock_serve.assert_called_once_with(
|
||||
app=mock_fastapi_app,
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
log_level="info",
|
||||
reload=False,
|
||||
lifespan="on",
|
||||
)
|
||||
|
||||
# Test with DEBUG log level
|
||||
with (
|
||||
patch("arcade_mcp_server.mcp_app.create_arcade_mcp") as mock_create,
|
||||
patch("arcade_mcp_server.mcp_app.uvicorn") as mock_uvicorn,
|
||||
patch("arcade_mcp_server.mcp_app.serve_with_force_quit") as mock_serve,
|
||||
):
|
||||
mock_fastapi_app = Mock()
|
||||
mock_create.return_value = mock_fastapi_app
|
||||
|
|
@ -368,13 +366,11 @@ class TestMCPApp:
|
|||
mcp_settings=mcp_app._mcp_settings,
|
||||
debug=True,
|
||||
)
|
||||
mock_uvicorn.run.assert_called_once_with(
|
||||
mock_fastapi_app,
|
||||
mock_serve.assert_called_once_with(
|
||||
app=mock_fastapi_app,
|
||||
host="192.168.1.1",
|
||||
port=9000,
|
||||
log_level="debug",
|
||||
reload=False,
|
||||
lifespan="on",
|
||||
)
|
||||
|
||||
def test_run_with_reload_spawns_child_process(self, mcp_app: MCPApp):
|
||||
|
|
|
|||
|
|
@ -136,6 +136,8 @@ omit = [
|
|||
"*/test_*",
|
||||
"*/__pycache__/*",
|
||||
]
|
||||
parallel = true
|
||||
patch = ["subprocess"]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
|
|
|
|||
Loading…
Reference in a new issue