This PR does four things: 1. Executes synchronous tool calls in a thread pool, allowing up to 4 + (number of CPUs) executions in parallel. 2. Makes force quitting possible, even if there are active connections, either via a double SIGINT/SIGTERM or via a single SIGINT/SIGTERM followed by expiry of the graceful shutdown timeout. 3. Sets `timeout_graceful_shutdown` to the `ARCADE_UVICORN_TIMEOUT_GRACEFUL_SHUTDOWN` env var if set, else defaults to 15. 4. Disables the worker health check span to reduce noise. Tradeoffs: Since this PR introduces executing synchronous tools via `await asyncio.to_thread(func, **func_args)`, there is no way for the thread to be killed until it finishes. The ramification of this is that the force quitting logic that is also implemented in this PR has to be very harsh (`os._exit(1)`), just in case there is a sync tool actively executing. This means that `MCPApp` teardown logic will not execute when force quitting is required, although this was already the case because we weren't previously able to force quit! This tradeoff is justified for now since "parallel" tool executions will relieve us of many worker timeouts that we are seeing in prod. Future work: Minimize/eliminate the need for `os._exit(1)` such that `MCPApp` teardown logic will always execute, even when force quitting. The solution will likely be moving away from `await asyncio.to_thread(func, **func_args)` (while maintaining "parallelism") and then utilizing the `TaskTrackerMiddleware` introduced in this PR to cancel all of the active HTTP requests. Resolves PLT-713
57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
import asyncio
|
|
import threading
|
|
from collections.abc import Awaitable, Callable
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.types import ASGIApp
|
|
|
|
|
|
class TaskTrackerMiddleware(BaseHTTPMiddleware):
    """Remember the asyncio task behind each in-flight HTTP request.

    The recorded tasks can be cancelled in bulk through
    ``cancel_all_tasks``, which is what the force-quit path relies on.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Wrap *app* and start with an empty set of tracked tasks."""
        super().__init__(app)
        # Tasks currently inside ``dispatch``; every access goes through ``_lock``.
        self._active_tasks: set[asyncio.Task] = set()
        self._lock = threading.Lock()

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Handle *request* while keeping its task registered as active.

        Args:
            request: The incoming request, passed on unchanged.
            call_next: Callable that forwards the request down the stack.

        Returns:
            Whatever response ``call_next`` produces.
        """
        running = asyncio.current_task()

        with self._lock:
            if running is not None:
                self._active_tasks.add(running)

        try:
            return await call_next(request)
        finally:
            # Runs on success, failure and cancellation alike, so a task
            # never stays registered once ``call_next`` is finished.
            # NOTE(review): tracking stops as soon as ``call_next`` returns;
            # if Starlette sends a streaming body after ``dispatch`` returns,
            # that phase would not be covered - confirm.
            with self._lock:
                if running is not None:
                    self._active_tasks.discard(running)

    def cancel_all_tasks(self) -> int:
        """Request cancellation of every tracked request task still running.

        Call this only from the thread that runs the asyncio event loop
        (never from a background thread), because it invokes
        ``task.cancel()`` directly.

        Returns:
            int: How many tasks had ``cancel()`` called on them; tasks
            that were already done are skipped and not counted.
        """
        # Snapshot under the lock so the live set can change while we iterate.
        with self._lock:
            snapshot = list(self._active_tasks)

        requested = 0
        for tracked in snapshot:
            if tracked.done():
                continue
            tracked.cancel()
            requested += 1

        return requested
|