feat: Add TelemetryPassbackMiddleware for serverExecutionTelemetry capability (#797)
**Implements**: [SEP-2448: server execution telemetry] (https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2448) **Description:** **The Observability Gap (The Problem)** MCP clients propagate trace context to servers, but server-side execution remains a black box. The client sees a single tools/call or resources/read span; everything the server does (auth checks, policy evaluation, API calls, sub-tool invocations) is invisible. In cross-organization deployments, clients and servers use separate observability backends with no shared collector access, making traditional span export useless. <img width="1015" height="450" alt="Screenshot 2026-03-23 at 3 43 21 PM" src="https://github.com/user-attachments/assets/58c817b5-fee6-46a3-9877-d523a25368ad" /> **Server Execution Telemetry (The Solution)** Servers advertise serverExecutionTelemetry and return a curated slice of their execution spans directly in _meta.otel of the response. Clients ingest these verbatim OTLP spans into their own collector, stitching server-side execution into their distributed trace; no shared infrastructure required. The black box becomes transparent. <img width="945" height="574" alt="Screenshot 2026-03-23 at 3 43 44 PM" src="https://github.com/user-attachments/assets/38d97c94-aa73-4e62-9b4e-3264600e5ed0" /> . **Summary:** Implement MCP serverExecutionTelemetry capability that enables cross-organization distributed tracing by returning server-side OpenTelemetry spans to clients inline via _meta.otel.traces. Server-side (middleware): - TelemetryPassbackMiddleware intercepts tools/call and resources/read - ContextVarSpanCollector isolates span collection per-request via ContextVar - Propagates traceparent from client request for distributed trace stitching - Serializes collected spans to verbatim OTLP JSON (resourceSpans format), directly POSTable to /v1/traces - Top-level span filtering by default; full span tree via detailed opt-in - Middleware advertises capabilities via get_capabilities() on the Middleware base class - Provisional API: FutureWarning emitted until SEP-2448 is ratified Client-side (reference agent): - LangChain ReAct agent connects to MCP server via streamable_http_client with OAuth 2.1 - Detects serverExecutionTelemetry capability at initialization - Dynamically wraps discovered MCP tools with traceparent propagation and _meta.otel span request - Ingests returned server spans into Jaeger (OTLP JSON) and Galileo (OTLP protobuf) - Two-act demo: --no-passback (black box) vs default (full server-side visibility) Dependencies: - opentelemetry-api and opentelemetry-sdk added to arcade-mcp-server Bump arcade-mcp-server version to 1.18.0.
This commit is contained in:
parent
9bbdbe2b46
commit
78c8e6fb99
14 changed files with 1999 additions and 6 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -179,3 +179,7 @@ cython_debug/
|
|||
|
||||
# Docs
|
||||
libs/arcade-mcp-server/site/*
|
||||
|
||||
# exclude OAuth tokens
|
||||
.oauth_tokens.json
|
||||
.oauth_client.json
|
||||
|
|
|
|||
12
examples/mcp_servers/telemetry_passback/.env.example
Normal file
12
examples/mcp_servers/telemetry_passback/.env.example
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
OPENAI_API_KEY=sk-...
|
||||
|
||||
# Arcade auth (required — get your key at https://www.arcade.dev)
|
||||
ARCADE_API_KEY=
|
||||
ARCADE_USER_ID=
|
||||
ARCADE_API_URL=https://api.arcade.dev
|
||||
|
||||
# Optional: Galileo integration
|
||||
GALILEO_API_KEY=
|
||||
GALILEO_PROJECT=
|
||||
GALILEO_LOG_STREAM=default
|
||||
GALILEO_OTLP_ENDPOINT=https://app.galileo.ai/api/galileo/otel/traces
|
||||
177
examples/mcp_servers/telemetry_passback/README.md
Normal file
177
examples/mcp_servers/telemetry_passback/README.md
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
# SEP-2448: MCP server execution telemetry — Reference Implementation
|
||||
|
||||
End-to-end reference implementation of **SEP-2448 `serverExecutionTelemetry`** — cross-organization distributed tracing via MCP.
|
||||
|
||||
## Overview
|
||||
|
||||
This example demonstrates how an MCP server can **pass back OpenTelemetry spans** to the calling client, enabling full distributed tracing across organizational boundaries. Without this capability, the server side of an MCP tool call is a black box — you can see *that* it was called, but not *what happened inside*.
|
||||
|
||||
The example includes three components:
|
||||
|
||||
1. **Server** (`server.py`) — An Arcade MCP server with Gmail tools that uses `TelemetryPassbackMiddleware` to collect and return spans. This shows how a **vendor adopts** the SEP.
|
||||
2. **Agent** (`agent.py`) — A LangChain ReAct agent that requests span passback, receives server spans, and ingests them into Jaeger/Galileo. This shows how a **consumer uses** the SEP.
|
||||
3. **Jaeger** (`docker-compose.yml`) — Local trace collector and UI for visualizing the stitched traces.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.11+
|
||||
- [uv](https://docs.astral.sh/uv/) package manager
|
||||
- Docker (for Jaeger)
|
||||
- An [Arcade](https://www.arcade.dev) account ([quickstart](https://docs.arcade.dev/en/get-started/quickstarts/mcp-server-quickstart))
|
||||
- An OpenAI API key (for the LangChain agent)
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
cd examples/mcp_servers/telemetry_passback
|
||||
|
||||
# Copy env file and add your keys
|
||||
cp .env.example .env
|
||||
# Edit .env: set OPENAI_API_KEY, ARCADE_API_KEY, ARCADE_USER_ID
|
||||
|
||||
# Install dependencies
|
||||
uv sync
|
||||
|
||||
# Start Jaeger
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The server and agent run as **separate processes**. Start the server first, then run the agent in another terminal.
|
||||
|
||||
### Start the Server
|
||||
|
||||
```bash
|
||||
# Terminal 1
|
||||
uv run python src/telemetry_passback/server.py
|
||||
```
|
||||
|
||||
The server listens at `http://127.0.0.1:8000/mcp` with OAuth 2.1 resource server auth via Arcade.
|
||||
|
||||
### Run the Agent
|
||||
|
||||
In a separate terminal. On first run, the MCP SDK will open your browser for OAuth authorization (one-time).
|
||||
|
||||
#### Act 1 — "The Black Box" (no passback)
|
||||
|
||||
```bash
|
||||
uv run python src/telemetry_passback/agent.py --no-passback "List my 3 most recent emails"
|
||||
```
|
||||
|
||||
Open Jaeger at [http://localhost:16686](http://localhost:16686): you see agent LLM reasoning spans + one opaque `mcp.call_tool` CLIENT span. The tool call took ~3 seconds but there's no way to tell why. Is it the LLM? The network? Auth? The Gmail API? Everything inside the server is invisible.
|
||||
|
||||
#### Act 2 — "The Revelation" (with passback)
|
||||
|
||||
```bash
|
||||
uv run python src/telemetry_passback/agent.py --detailed "List my 3 most recent emails"
|
||||
```
|
||||
|
||||
Same call, but now the span tree reveals the server's internal structure:
|
||||
|
||||
```
|
||||
mcp-gmail-agent
|
||||
├── LangChain agent reasoning
|
||||
├── ChatOpenAI (LLM decides to call tool)
|
||||
├── mcp.call_tool list_emails (CLIENT)
|
||||
│ └── tools/call list_emails (SERVER) ← FROM SPAN PASSBACK
|
||||
│ ├── auth.validate (50ms)
|
||||
│ ├── gmail.list_messages (400ms)
|
||||
│ │ └── GET messages (HTTP)
|
||||
│ ├── gmail.fetch_details (1.6s) ← bottleneck!
|
||||
│ │ ├── GET messages/abc (HTTP, 520ms)
|
||||
│ │ ├── GET messages/def (HTTP, 510ms)
|
||||
│ │ └── GET messages/ghi (HTTP, 530ms)
|
||||
│ └── format_response (5ms)
|
||||
└── ChatOpenAI (LLM — final answer)
|
||||
```
|
||||
|
||||
Now the consumer can see exactly what's happening: auth is fast, listing is fine, but **detail fetching is sequential** — three HTTP calls in a waterfall. Armed with this information, the consumer can:
|
||||
|
||||
- **File an informed bug report** to the server vendor: "your `list_emails` has an N+1 in detail fetching — each email triggers a sequential HTTP call"
|
||||
- **Adjust their usage**: request fewer emails, use a query filter to reduce N
|
||||
- **Make an informed vendor choice**: compare span trees across MCP server providers
|
||||
|
||||
This is the core value of the SEP — **the consumer doesn't need access to the server's code or deployment to understand its performance characteristics**.
|
||||
|
||||
### Granularity Control
|
||||
|
||||
The `--detailed` flag demonstrates the SEP's span filtering. Without it, the server returns only top-level phase spans (auth, list, fetch, format). With `--detailed`, the full tree including HTTP child spans is returned. This lets the server vendor control how much internal detail is exposed.
|
||||
|
||||
```bash
|
||||
# Top-level phases only (default)
|
||||
uv run python src/telemetry_passback/agent.py "List my 3 most recent emails"
|
||||
|
||||
# Full span tree including HTTP child spans
|
||||
uv run python src/telemetry_passback/agent.py --detailed "List my 3 most recent emails"
|
||||
```
|
||||
|
||||
### CLI Options
|
||||
|
||||
| Flag | Default | Description |
|
||||
|------|---------|-------------|
|
||||
| `query` | `"List my 5 most recent emails"` | The question to ask the agent |
|
||||
| `--detailed` | `false` | Request full span tree |
|
||||
| `--no-passback` | `false` | Disable span passback (Act 1 — server is a black box) |
|
||||
| `--server-url` | `http://127.0.0.1:8000/mcp` | MCP server URL |
|
||||
|
||||
## Expected Results in Jaeger
|
||||
|
||||
Open [http://localhost:16686](http://localhost:16686) and search for service **`mcp-gmail-agent`**.
|
||||
|
||||
| Mode | What you see |
|
||||
|------|-------------|
|
||||
| `--no-passback` | Only agent-side spans: LLM calls + opaque `mcp.call_tool`. Server is a black box. |
|
||||
| Default | Server phase spans stitched into the same trace: `auth.validate`, `gmail.list_messages`, `gmail.fetch_details`, `format_response`. |
|
||||
| `--detailed` | Full span tree: phase spans plus HTTP child spans under each phase, revealing the sequential N+1 pattern in `gmail.fetch_details`. |
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────────────┐ HTTP (streamable) ┌──────────────────────────┐
|
||||
│ agent.py │ ───────────────────────>│ server.py │
|
||||
│ (LangChain ReAct) │ :8000/mcp │ (Arcade MCP Server) │
|
||||
│ │ │ │
|
||||
│ OAuth 2.1 via MCP SDK │ traceparent in _meta │ OAuth 2.1 (Arcade) │
|
||||
│ OTel → Jaeger/Galileo │ ───────────────────────>│ OTel (internal only) │
|
||||
│ │ spans back in _meta │ TelemetryPassback MW │
|
||||
│ │ <───────────────────────│ │
|
||||
└─────────────────────────┘ └──────────────────────────┘
|
||||
│ │
|
||||
└──────────── Stitched trace in Jaeger ───────────────┘
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
**Server side** (`server.py`):
|
||||
1. Validates Bearer tokens via `ArcadeResourceServerAuth` (OAuth 2.1, RFC 9728 discovery)
|
||||
2. `TelemetryPassbackMiddleware` intercepts `tools/call` requests
|
||||
3. Reads `_meta.traceparent` and `_meta.otel.traces.{request, detailed}`
|
||||
4. Creates a SERVER span under the client's trace (via traceparent propagation)
|
||||
5. Tool function creates logical-phase spans with `gen_ai.*` semantic conventions
|
||||
6. httpx auto-instrumentation creates HTTP child spans for Gmail API calls
|
||||
7. Middleware serializes to OTLP JSON and attaches to `response._meta.otel.traces`
|
||||
|
||||
**Client side** (`agent.py`):
|
||||
1. MCP SDK handles OAuth 2.1 automatically (discovers auth server on 401, PKCE flow, token caching)
|
||||
2. Connects to the server via streamable HTTP, detects `serverExecutionTelemetry` capability
|
||||
3. For each tool call, creates a CLIENT span and injects `traceparent` in `_meta`
|
||||
4. Sends `_meta.otel.traces.request: true` to opt into span passback
|
||||
5. Receives server spans in response `_meta.otel.traces.resourceSpans`
|
||||
6. POSTs OTLP JSON to Jaeger for trace stitching
|
||||
7. Optionally exports to Galileo (protobuf) if `GALILEO_API_KEY` is set
|
||||
|
||||
## Configuration
|
||||
|
||||
Copy `.env.example` to `.env`:
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `OPENAI_API_KEY` | (required) | OpenAI API key for the LangChain agent |
|
||||
| `ARCADE_API_KEY` | (required) | Arcade API key |
|
||||
| `ARCADE_USER_ID` | (required) | Your Arcade account email |
|
||||
| `ARCADE_API_URL` | `https://api.arcade.dev` | Arcade API endpoint |
|
||||
| `GALILEO_API_KEY` | (optional) | Enables export to Galileo alongside Jaeger |
|
||||
| `GALILEO_PROJECT` | (optional) | Galileo project name |
|
||||
| `GALILEO_LOG_STREAM` | `default` | Galileo log stream |
|
||||
| `GALILEO_OTLP_ENDPOINT` | `https://app.galileo.ai/api/galileo/otel/traces` | Galileo OTLP endpoint |
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
services:
|
||||
jaeger:
|
||||
image: jaegertracing/all-in-one:latest
|
||||
ports:
|
||||
- "16686:16686" # Jaeger UI
|
||||
- "4317:4317" # OTLP gRPC
|
||||
- "4318:4318" # OTLP HTTP
|
||||
42
examples/mcp_servers/telemetry_passback/pyproject.toml
Normal file
42
examples/mcp_servers/telemetry_passback/pyproject.toml
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
[project]
|
||||
name = "telemetry-passback"
|
||||
version = "0.1.0"
|
||||
description = "SEP-0000 reference implementation: cross-org distributed tracing via MCP serverExecutionTelemetry"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"arcade-mcp-server>=1.18.0,<2.0.0",
|
||||
"mcp>=1.26.0",
|
||||
"langchain>=1.2.13",
|
||||
"langchain-community>=0.4.1",
|
||||
"langchain-openai>=1.1.12",
|
||||
"langgraph>=1.1.3",
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-otlp-proto-grpc",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-instrumentation-langchain",
|
||||
"opentelemetry-instrumentation-httpx",
|
||||
"httpx>=0.28.1",
|
||||
"python-dotenv>=1.2.2",
|
||||
"protobuf",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0.0",
|
||||
"pytest-asyncio>=0.21.0",
|
||||
"mypy>=1.0.0",
|
||||
"ruff>=0.1.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/telemetry_passback"]
|
||||
|
||||
# Use local arcade-mcp-server (TelemetryPassbackMiddleware is not yet published)
|
||||
[tool.uv.sources]
|
||||
arcade-mcp-server = { path = "../../../libs/arcade-mcp-server/", editable = true }
|
||||
arcade-serve = { path = "../../../libs/arcade-serve/", editable = true }
|
||||
|
|
@ -0,0 +1,581 @@
|
|||
"""LangChain ReAct agent with SEP-2448: MCP server execution telemetry.
|
||||
|
||||
Consumer-side reference implementation of cross-org distributed tracing:
|
||||
|
||||
* Connects to an Arcade MCP Gmail server via streamable HTTP.
|
||||
* Authenticates using MCP OAuth 2.1 (automatic via the MCP SDK).
|
||||
* Detects the ``serverExecutionTelemetry`` capability.
|
||||
* Dynamically discovers tools and wraps them with span passback.
|
||||
* Passes ``traceparent`` + requests span passback via ``_meta.otel``.
|
||||
* Receives server spans and ingests them into Jaeger / Galileo.
|
||||
* Handles Google OAuth authorization flow for Gmail (one-time consent).
|
||||
|
||||
Two-act demo:
|
||||
Act 1 (--no-passback): Opaque tool call -- server is a black box
|
||||
Act 2 (default): Rich span tree via passback reveals server internals
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langchain_openai import ChatOpenAI
|
||||
from mcp import ClientSession
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
OTLPSpanExporter as OTLPHTTPSpanExporter,
|
||||
)
|
||||
from opentelemetry.instrumentation.langchain import LangchainInstrumentor
|
||||
from opentelemetry.instrumentation.langchain.callback_handler import (
|
||||
TraceloopCallbackHandler,
|
||||
)
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.trace import SpanKind, set_span_in_context
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from pydantic import Field, create_model
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
load_dotenv(_PROJECT_ROOT / ".env")
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
JAEGER_GRPC = "localhost:4317"
|
||||
JAEGER_HTTP = "http://localhost:4318/v1/traces"
|
||||
JAEGER_UI = "http://localhost:16686"
|
||||
|
||||
_passback_stats: dict[str, Any] = {"span_count": 0, "truncated": False, "dropped": 0}
|
||||
|
||||
|
||||
# --------------------------------------
|
||||
# MCP OAuth 2.1 (handled by the MCP SDK
|
||||
# --------------------------------------
|
||||
|
||||
_OAUTH_TOKEN_FILE = _PROJECT_ROOT / ".oauth_tokens.json"
|
||||
_OAUTH_CLIENT_FILE = _PROJECT_ROOT / ".oauth_client.json"
|
||||
_CALLBACK_PORT = 9905
|
||||
|
||||
|
||||
class FileTokenStorage(TokenStorage):
|
||||
"""Persist OAuth tokens and client registration to disk between runs."""
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
if _OAUTH_TOKEN_FILE.exists():
|
||||
return OAuthToken.model_validate_json(_OAUTH_TOKEN_FILE.read_text())
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
_OAUTH_TOKEN_FILE.write_text(tokens.model_dump_json())
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
if _OAUTH_CLIENT_FILE.exists():
|
||||
return OAuthClientInformationFull.model_validate_json(_OAUTH_CLIENT_FILE.read_text())
|
||||
return None
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
_OAUTH_CLIENT_FILE.write_text(client_info.model_dump_json())
|
||||
|
||||
|
||||
async def _handle_redirect(authorization_url: str) -> None:
|
||||
"""Open the browser for OAuth consent."""
|
||||
print("\n Opening browser for authorization...")
|
||||
print(f" URL: {authorization_url}\n")
|
||||
webbrowser.open(authorization_url)
|
||||
|
||||
|
||||
async def _handle_callback() -> tuple[str, str | None]:
|
||||
"""Start a local HTTP server, wait for the OAuth redirect, extract the code."""
|
||||
loop = asyncio.get_event_loop()
|
||||
future: asyncio.Future[tuple[str, str | None]] = loop.create_future()
|
||||
|
||||
class _Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self) -> None:
|
||||
qs = parse_qs(urlparse(self.path).query)
|
||||
code = qs.get("code", [None])[0]
|
||||
state = qs.get("state", [None])[0]
|
||||
if code:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(
|
||||
b"<h2>Authorization successful!</h2><p>You can close this tab.</p>"
|
||||
)
|
||||
loop.call_soon_threadsafe(future.set_result, (code, state))
|
||||
else:
|
||||
error = qs.get("error", ["unknown"])[0]
|
||||
self.send_response(400)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(f"<h2>Authorization failed: {error}</h2>".encode())
|
||||
loop.call_soon_threadsafe(
|
||||
future.set_exception, RuntimeError(f"OAuth error: {error}")
|
||||
)
|
||||
|
||||
def log_message(self, fmt: str, *args: Any) -> None:
|
||||
pass
|
||||
|
||||
server = HTTPServer(("127.0.0.1", _CALLBACK_PORT), _Handler)
|
||||
|
||||
def _serve() -> None:
|
||||
server.handle_request()
|
||||
server.server_close()
|
||||
|
||||
await loop.run_in_executor(None, _serve)
|
||||
return await future
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Span ingestion helpers
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _count_spans(resource_spans: list[dict[str, Any]]) -> int:
|
||||
return sum(len(ss.get("spans", [])) for rs in resource_spans for ss in rs.get("scopeSpans", []))
|
||||
|
||||
|
||||
def _hex_ids_to_base64(resource_spans: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert hex trace/span IDs to base64 for protobuf ``ParseDict``."""
|
||||
import base64
|
||||
import copy
|
||||
|
||||
_ID_FIELDS = ("traceId", "spanId", "parentSpanId")
|
||||
converted = copy.deepcopy(resource_spans)
|
||||
for rs in converted:
|
||||
for ss in rs.get("scopeSpans", []):
|
||||
for span in ss.get("spans", []):
|
||||
for fld in _ID_FIELDS:
|
||||
if span.get(fld):
|
||||
span[fld] = base64.b64encode(bytes.fromhex(span[fld])).decode()
|
||||
return converted
|
||||
|
||||
|
||||
def ingest_spans_json(
|
||||
otlp_json: dict[str, Any],
|
||||
endpoint: str = JAEGER_HTTP,
|
||||
headers: dict[str, str] | None = None,
|
||||
label: str = "collector",
|
||||
) -> None:
|
||||
"""POST OTLP JSON spans to an OTLP HTTP endpoint (e.g. Jaeger)."""
|
||||
count = _count_spans(otlp_json.get("resourceSpans", []))
|
||||
hdrs = {"Content-Type": "application/json", **(headers or {})}
|
||||
try:
|
||||
httpx.post(endpoint, json=otlp_json, headers=hdrs).raise_for_status()
|
||||
log.info("Ingested %d server span(s) into %s", count, label)
|
||||
except httpx.ConnectError:
|
||||
log.exception("Could not connect to %s at %s", label, endpoint)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
log.exception(
|
||||
"%s returned %d: %s", label, exc.response.status_code, exc.response.text[:200]
|
||||
)
|
||||
|
||||
|
||||
def ingest_spans_protobuf(
|
||||
resource_spans: list[dict[str, Any]],
|
||||
endpoint: str,
|
||||
headers: dict[str, str],
|
||||
label: str = "Galileo",
|
||||
) -> None:
|
||||
"""POST OTLP protobuf spans (required by Galileo)."""
|
||||
try:
|
||||
from google.protobuf.json_format import ParseDict
|
||||
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
|
||||
ExportTraceServiceRequest,
|
||||
)
|
||||
except ImportError:
|
||||
log.warning("Skipping %s protobuf export (missing protobuf deps)", label)
|
||||
return
|
||||
if not resource_spans:
|
||||
return
|
||||
b64_spans = _hex_ids_to_base64(resource_spans)
|
||||
body = ParseDict({"resourceSpans": b64_spans}, ExportTraceServiceRequest()).SerializeToString()
|
||||
try:
|
||||
resp = httpx.post(
|
||||
endpoint, content=body, headers={"Content-Type": "application/x-protobuf", **headers}
|
||||
)
|
||||
log.info(
|
||||
"Exported %d server span(s) to %s (HTTP %d)",
|
||||
_count_spans(resource_spans),
|
||||
label,
|
||||
resp.status_code,
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
log.exception("Could not connect to %s at %s", label, endpoint)
|
||||
|
||||
|
||||
# ----------------
|
||||
# Galileo config
|
||||
# ----------------
|
||||
|
||||
|
||||
def _galileo_config() -> tuple[str, dict[str, str]] | None:
|
||||
"""Return ``(endpoint, headers)`` for Galileo, or ``None`` if not configured."""
|
||||
api_key = os.environ.get("GALILEO_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
return (
|
||||
os.environ.get("GALILEO_OTLP_ENDPOINT", "https://api.galileo.ai/otel/v1/traces"),
|
||||
{
|
||||
"Galileo-API-Key": api_key,
|
||||
"project": os.environ.get("GALILEO_PROJECT", "mcp-cross-org-observability"),
|
||||
"logstream": os.environ.get("GALILEO_LOG_STREAM", "default"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------
|
||||
# Tracing setup
|
||||
# ---------------
|
||||
|
||||
|
||||
def setup_tracing() -> TracerProvider:
|
||||
"""Jaeger (gRPC) + optional Galileo (OTLP HTTP) + LangChain auto-instrumentation."""
|
||||
resource = Resource.create({"service.name": "mcp-gmail-agent"})
|
||||
provider = TracerProvider(resource=resource)
|
||||
|
||||
provider.add_span_processor(
|
||||
BatchSpanProcessor(OTLPSpanExporter(endpoint=JAEGER_GRPC, insecure=True))
|
||||
)
|
||||
|
||||
gc = _galileo_config()
|
||||
if gc:
|
||||
provider.add_span_processor(
|
||||
BatchSpanProcessor(OTLPHTTPSpanExporter(endpoint=gc[0], headers=gc[1]))
|
||||
)
|
||||
log.info("Galileo tracing enabled -> %s", gc[0])
|
||||
|
||||
trace.set_tracer_provider(provider)
|
||||
LangchainInstrumentor().instrument()
|
||||
return provider
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="LangChain agent with SEP-0000 span passback")
|
||||
p.add_argument("query", nargs="?", default="List my 5 most recent emails")
|
||||
p.add_argument("--detailed", action="store_true", help="Request full span tree")
|
||||
p.add_argument(
|
||||
"--no-passback",
|
||||
action="store_true",
|
||||
help="Disable span passback (before SEP-0000 -- server is a black box)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--server-url",
|
||||
default="http://127.0.0.1:8000/mcp",
|
||||
help="MCP server URL (default: http://127.0.0.1:8000/mcp)",
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Instrumentor span look-up
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _find_instrumentor_span(tool_name: str) -> trace.Span | None:
|
||||
"""Locate the LangChain instrumentor's span for the active tool invocation."""
|
||||
try:
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
|
||||
mgr = BaseCallbackManager(handlers=[])
|
||||
for handler in mgr.inheritable_handlers:
|
||||
if isinstance(handler, TraceloopCallbackHandler):
|
||||
for run_id in reversed(list(handler.spans)):
|
||||
holder = handler.spans[run_id]
|
||||
span = getattr(holder, "span", None)
|
||||
if span and tool_name in getattr(span, "name", ""):
|
||||
return span
|
||||
except Exception:
|
||||
log.debug("Could not look up instrumentor span for %s", tool_name, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Passback span processing
|
||||
# --------------------------
|
||||
|
||||
|
||||
def _extract_trace_id(resource_spans: list[dict[str, Any]]) -> str | None:
|
||||
"""Return the first traceId found in resource_spans, or None."""
|
||||
for rs in resource_spans:
|
||||
for ss in rs.get("scopeSpans", []):
|
||||
for sp in ss.get("spans", []):
|
||||
tid = sp.get("traceId")
|
||||
if tid:
|
||||
return tid
|
||||
return None
|
||||
|
||||
|
||||
def _process_passback_spans(meta: dict[str, Any] | None) -> None:
|
||||
"""Extract server spans from response ``_meta`` and ingest into collectors."""
|
||||
if not meta:
|
||||
return
|
||||
otel_data = meta.get("otel") if isinstance(meta, dict) else getattr(meta, "otel", None)
|
||||
if not otel_data:
|
||||
return
|
||||
traces = otel_data.get("traces", {})
|
||||
resource_spans = traces.get("resourceSpans")
|
||||
if not resource_spans:
|
||||
return
|
||||
|
||||
span_count = _count_spans(resource_spans)
|
||||
truncated = traces.get("truncated", False)
|
||||
dropped = traces.get("droppedSpanCount", 0)
|
||||
|
||||
_passback_stats["span_count"] += span_count
|
||||
_passback_stats["truncated"] = _passback_stats["truncated"] or truncated
|
||||
_passback_stats["dropped"] += dropped
|
||||
if not _passback_stats.get("trace_id"):
|
||||
_passback_stats["trace_id"] = _extract_trace_id(resource_spans) or ""
|
||||
|
||||
print(f" Server-side spans: {span_count} received and ingested")
|
||||
if truncated:
|
||||
print(f" ({dropped} additional spans available with --detailed)")
|
||||
|
||||
ingest_spans_json({"resourceSpans": resource_spans}, endpoint=JAEGER_HTTP, label="Jaeger")
|
||||
|
||||
gc = _galileo_config()
|
||||
if gc:
|
||||
ingest_spans_protobuf(resource_spans, endpoint=gc[0], headers=gc[1])
|
||||
|
||||
|
||||
# ----------------------------------------------------
|
||||
# OAuth helpers (Gmail tool authorization via Arcade)
|
||||
# ----------------------------------------------------
|
||||
|
||||
|
||||
def _extract_auth_url(result: Any) -> str | None:
|
||||
"""Check if a tool result contains an authorization URL (OAuth required)."""
|
||||
for item in result.content:
|
||||
text = getattr(item, "text", None)
|
||||
if text and "authorization_url" in text:
|
||||
try:
|
||||
data = json.loads(text)
|
||||
return data.get("authorization_url")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
structured = getattr(result, "structuredContent", None)
|
||||
if isinstance(structured, dict):
|
||||
return structured.get("authorization_url")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dynamic MCP tool wrappers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_mcp_tools(
|
||||
session: ClientSession,
|
||||
mcp_tools: list[Any],
|
||||
tracer: trace.Tracer,
|
||||
propagator: TraceContextTextMapPropagator,
|
||||
detailed: bool,
|
||||
passback: bool = True,
|
||||
) -> list[StructuredTool]:
|
||||
"""Create LangChain tools from MCP tool definitions, with optional span passback."""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
for mcp_tool in mcp_tools:
|
||||
name = mcp_tool.name
|
||||
desc = mcp_tool.description or f"MCP tool: {name}"
|
||||
schema = mcp_tool.inputSchema
|
||||
props = schema.get("properties", {})
|
||||
required = set(schema.get("required", []))
|
||||
|
||||
def _make_fn(t_name: str) -> Any:
|
||||
async def _call(**kwargs: str) -> str:
|
||||
parent_ctx = None
|
||||
inst_span = _find_instrumentor_span(t_name)
|
||||
if inst_span:
|
||||
parent_ctx = set_span_in_context(inst_span)
|
||||
|
||||
with tracer.start_as_current_span(
|
||||
f"mcp.call_tool {t_name}",
|
||||
context=parent_ctx,
|
||||
kind=SpanKind.CLIENT,
|
||||
attributes={
|
||||
"mcp.tool": t_name,
|
||||
"mcp.server": "mcp-gmail-server",
|
||||
"gen_ai.operation.name": "execute_tool",
|
||||
"gen_ai.tool.name": f"mcp.call_tool {t_name}",
|
||||
"gen_ai.tool.call.arguments": json.dumps(kwargs),
|
||||
},
|
||||
) as span:
|
||||
carrier: dict[str, str] = {}
|
||||
propagator.inject(carrier)
|
||||
|
||||
meta: dict[str, Any] = {
|
||||
"traceparent": carrier.get("traceparent", ""),
|
||||
}
|
||||
if passback:
|
||||
meta["otel"] = {"traces": {"request": True, "detailed": detailed}}
|
||||
|
||||
result = await session.call_tool(t_name, arguments=kwargs, meta=meta)
|
||||
|
||||
# Handle Gmail OAuth if needed (one-time consent)
|
||||
auth_url = _extract_auth_url(result)
|
||||
if auth_url:
|
||||
span.set_attribute("mcp.auth.required", True)
|
||||
print(f"\n Authorization required. Open this URL:\n\n {auth_url}\n")
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
input,
|
||||
" Press Enter after authorizing...",
|
||||
)
|
||||
result = await session.call_tool(t_name, arguments=kwargs, meta=meta)
|
||||
|
||||
text = result.content[0].text if result.content else ""
|
||||
span.set_attribute("gen_ai.tool.call.result", text[:500])
|
||||
|
||||
if passback:
|
||||
_process_passback_spans(result.meta)
|
||||
else:
|
||||
print(" Server-side spans: NONE (passback not requested)")
|
||||
|
||||
return text
|
||||
|
||||
return _call
|
||||
|
||||
fields = {}
|
||||
for pname, pinfo in props.items():
|
||||
pdesc = pinfo.get("description", "")
|
||||
if pname in required:
|
||||
fields[pname] = (str, Field(description=pdesc))
|
||||
else:
|
||||
fields[pname] = (str, Field(default="", description=pdesc))
|
||||
|
||||
tools.append(
|
||||
StructuredTool(
|
||||
name=name,
|
||||
description=desc,
|
||||
coroutine=_make_fn(name),
|
||||
args_schema=create_model(f"{name}Args", **fields),
|
||||
)
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
logging.basicConfig(level=logging.INFO, format=" %(message)s")
|
||||
|
||||
args = parse_args()
|
||||
provider = setup_tracing()
|
||||
tracer = trace.get_tracer("mcp-gmail-agent")
|
||||
propagator = TraceContextTextMapPropagator()
|
||||
passback = not args.no_passback
|
||||
|
||||
if args.no_passback:
|
||||
mode = "Act 1: The Black Box (no passback)"
|
||||
elif args.detailed:
|
||||
mode = "Act 2: The Revelation (detailed span tree)"
|
||||
else:
|
||||
mode = "Act 2: The Revelation (span passback)"
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {mode}")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
server_url = args.server_url
|
||||
|
||||
# MCP SDK handles OAuth 2.1 automatically:
|
||||
# On 401 → discovers auth server via RFC 9728 → PKCE flow → caches tokens
|
||||
oauth_auth = OAuthClientProvider(
|
||||
server_url=server_url,
|
||||
client_metadata=OAuthClientMetadata(
|
||||
client_name="mcp-gmail-agent",
|
||||
redirect_uris=["http://127.0.0.1:9905/callback"],
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
token_endpoint_auth_method="none", # noqa: S106 - OAuth 2.1 public client (no secret)
|
||||
),
|
||||
storage=FileTokenStorage(),
|
||||
redirect_handler=_handle_redirect,
|
||||
callback_handler=_handle_callback,
|
||||
)
|
||||
http_client = httpx.AsyncClient(auth=oauth_auth)
|
||||
print(f" Connecting to MCP server at {server_url} ...")
|
||||
|
||||
async with (
|
||||
streamable_http_client(url=server_url, http_client=http_client) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
init = await session.initialize()
|
||||
|
||||
telemetry_cap = getattr(init.capabilities, "serverExecutionTelemetry", None)
|
||||
print(f" Server: {init.serverInfo.name} v{init.serverInfo.version}")
|
||||
print(f" serverExecutionTelemetry: {telemetry_cap is not None}")
|
||||
if telemetry_cap:
|
||||
print(f" Capability: {telemetry_cap}")
|
||||
|
||||
discovered = await session.list_tools()
|
||||
print(f" Tools: {[t.name for t in discovered.tools]}\n")
|
||||
|
||||
lc_tools = build_mcp_tools(
|
||||
session,
|
||||
discovered.tools,
|
||||
tracer,
|
||||
propagator,
|
||||
detailed=args.detailed,
|
||||
passback=passback,
|
||||
)
|
||||
|
||||
agent = create_agent(ChatOpenAI(model="gpt-4o-mini"), lc_tools)
|
||||
|
||||
print(f" Query: {args.query}")
|
||||
if passback:
|
||||
print(f" Detailed: {args.detailed}")
|
||||
print()
|
||||
|
||||
result = await agent.ainvoke({"messages": [("user", args.query)]})
|
||||
|
||||
print(f"\n Agent: {result['messages'][-1].content}\n")
|
||||
|
||||
provider.force_flush()
|
||||
provider.shutdown()
|
||||
|
||||
trace_id = _passback_stats.get("trace_id", "")
|
||||
if trace_id:
|
||||
print(f" Jaeger UI: {JAEGER_UI}/trace/{trace_id}")
|
||||
else:
|
||||
print(f" Jaeger UI: {JAEGER_UI} (search for service mcp-gmail-agent)")
|
||||
print(" Service: mcp-gmail-agent")
|
||||
if args.no_passback:
|
||||
print(" Expected: only agent-side spans (server is a black box)")
|
||||
elif args.detailed:
|
||||
print(
|
||||
" Expected: full span tree -- auth.validate, gmail.list_messages,"
|
||||
" gmail.fetch_details (with HTTP child spans), format_response"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
" Expected: server phases visible -- auth.validate,"
|
||||
" gmail.list_messages, gmail.fetch_details, format_response"
|
||||
)
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -0,0 +1,260 @@
|
|||
"""Gmail MCP server with SEP-2448 MCP server execution telemetry.
|
||||
|
||||
Vendor-side reference implementation of cross-org distributed tracing:
|
||||
|
||||
* Advertises ``serverExecutionTelemetry`` via ``TelemetryPassbackMiddleware``.
|
||||
* Gmail tools (list_emails, send_email) using Google OAuth via Arcade.
|
||||
* Rich server-side instrumentation creates logical-phase spans
|
||||
(auth, API calls, formatting) that the middleware automatically collects.
|
||||
* ``HTTPXClientInstrumentor`` auto-instruments Gmail API HTTP calls as child spans.
|
||||
* Does NOT export spans externally — simulates a vendor with its own backend.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from email.mime.text import MIMEText
|
||||
from pathlib import Path
|
||||
from typing import Annotated, cast
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(Path(__file__).resolve().parent.parent.parent / ".env")
|
||||
|
||||
from arcade_mcp_server import Context, MCPApp # noqa: E402
|
||||
from arcade_mcp_server.auth import Google # noqa: E402
|
||||
from arcade_mcp_server.mcp_app import TransportType # noqa: E402
|
||||
from arcade_mcp_server.middleware.telemetry import TelemetryPassbackMiddleware # noqa: E402
|
||||
from arcade_mcp_server.resource_server import ( # noqa: E402
|
||||
AuthorizationServerEntry,
|
||||
ResourceServerAuth,
|
||||
)
|
||||
from arcade_mcp_server.resource_server.base import ResourceOwner # noqa: E402
|
||||
from opentelemetry import trace # noqa: E402
|
||||
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor # noqa: E402
|
||||
from opentelemetry.sdk.trace import TracerProvider # noqa: E402
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# OpenTelemetry — server-internal only (no external exporter)
|
||||
# ------------------------------------------------------------
|
||||
|
||||
provider = TracerProvider()
|
||||
telemetry_mw = TelemetryPassbackMiddleware(
|
||||
service_name="mcp-gmail-server",
|
||||
tracer_provider=provider,
|
||||
)
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
tracer = trace.get_tracer("mcp-gmail-server")
|
||||
_log = logging.getLogger("mcp-gmail-server")
|
||||
|
||||
|
||||
async def _async_request_hook(span, request):
|
||||
"""Capture request method + URL with gen_ai semantic conventions."""
|
||||
method = request.method.decode() if isinstance(request.method, bytes) else str(request.method)
|
||||
url = str(request.url)
|
||||
endpoint = url.split("?")[0].split("/")[-1] or "/"
|
||||
span.update_name(f"{method} {endpoint}")
|
||||
span.set_attribute("gen_ai.system", "mcp")
|
||||
span.set_attribute("gen_ai.operation.name", "execute_tool")
|
||||
span.set_attribute("gen_ai.tool.name", f"{method} {endpoint}")
|
||||
span.set_attribute("gen_ai.tool.call.arguments", json.dumps({"method": method, "url": url}))
|
||||
|
||||
|
||||
async def _async_response_hook(span, request, response):
|
||||
"""Capture response status with gen_ai semantic conventions."""
|
||||
method = request.method.decode() if isinstance(request.method, bytes) else str(request.method)
|
||||
url = str(request.url)
|
||||
span.set_attribute(
|
||||
"gen_ai.tool.call.result",
|
||||
json.dumps({
|
||||
"status": response.status_code,
|
||||
"url": url,
|
||||
"method": method,
|
||||
}),
|
||||
)
|
||||
|
||||
|
||||
HTTPXClientInstrumentor().instrument(
|
||||
async_request_hook=_async_request_hook,
|
||||
async_response_hook=_async_response_hook,
|
||||
)
|
||||
|
||||
# -----------------------
|
||||
# Arcade MCP application
|
||||
# -----------------------
|
||||
|
||||
GMAIL_API = "https://gmail.googleapis.com/gmail/v1/users/me"
|
||||
GMAIL_READONLY_SCOPE = "https://www.googleapis.com/auth/gmail.readonly"
|
||||
GMAIL_SEND_SCOPE = "https://www.googleapis.com/auth/gmail.send"
|
||||
|
||||
# --------------------------------------------
|
||||
# Resource server auth (OAuth 2.1 via Arcade)
|
||||
# --------------------------------------------
|
||||
|
||||
CANONICAL_URL = "http://127.0.0.1:8000/mcp"
|
||||
|
||||
|
||||
class ArcadeResourceServerAuth(ResourceServerAuth):
|
||||
"""ResourceServerAuth that uses the ``email`` claim as user_id.
|
||||
|
||||
Arcade's tool authorization identifies users by email, but the default
|
||||
``JWKSTokenValidator`` uses the ``sub`` claim (a UUID). This override
|
||||
swaps user_id to the email so Arcade can match the authorized user.
|
||||
"""
|
||||
|
||||
async def validate_token(self, token: str) -> ResourceOwner:
|
||||
owner = await super().validate_token(token)
|
||||
email = owner.claims.get("email")
|
||||
if email:
|
||||
owner.user_id = email
|
||||
return owner
|
||||
|
||||
|
||||
resource_server_auth = ArcadeResourceServerAuth(
|
||||
canonical_url=CANONICAL_URL,
|
||||
authorization_servers=[
|
||||
AuthorizationServerEntry(
|
||||
authorization_server_url="https://cloud.arcade.dev/oauth2",
|
||||
issuer="https://cloud.arcade.dev/oauth2",
|
||||
jwks_uri="https://cloud.arcade.dev/.well-known/jwks/oauth2",
|
||||
algorithm="Ed25519",
|
||||
expected_audiences=[CANONICAL_URL],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
app = MCPApp(
|
||||
name="mcp_gmail_server",
|
||||
version="0.1.0",
|
||||
instructions=(
|
||||
"Gmail server with cross-org observability " "via SEP-0000 serverExecutionTelemetry."
|
||||
),
|
||||
auth=resource_server_auth,
|
||||
middleware=[telemetry_mw],
|
||||
)
|
||||
|
||||
# ------
|
||||
# Tools
|
||||
# ------
|
||||
|
||||
|
||||
@app.tool(requires_auth=Google(scopes=[GMAIL_READONLY_SCOPE]))
|
||||
async def list_emails(
|
||||
context: Context,
|
||||
max_results: Annotated[int, "Maximum number of emails to return"] = 5,
|
||||
query: Annotated[str, "Gmail search query (e.g. 'is:unread')"] = "",
|
||||
) -> Annotated[dict, "Dict with 'emails' key containing list of recent emails"]:
|
||||
"""List recent emails from the user's Gmail inbox."""
|
||||
token = context.get_auth_token_or_empty()
|
||||
|
||||
with tracer.start_as_current_span("auth.validate") as auth_span:
|
||||
auth_span.set_attribute("gen_ai.system", "mcp")
|
||||
auth_span.set_attribute("gen_ai.operation.name", "execute_tool")
|
||||
auth_span.set_attribute("gen_ai.tool.name", "auth.validate")
|
||||
auth_span.set_attribute("auth.method", "oauth2_bearer")
|
||||
|
||||
params: dict = {"maxResults": min(max(max_results, 1), 20)}
|
||||
if query:
|
||||
params["q"] = query
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
with tracer.start_as_current_span("gmail.list_messages") as list_span:
|
||||
list_span.set_attribute("gen_ai.system", "mcp")
|
||||
list_span.set_attribute("gen_ai.operation.name", "execute_tool")
|
||||
list_span.set_attribute("gen_ai.tool.name", "gmail.list_messages")
|
||||
list_span.set_attribute("gmail.max_results", params["maxResults"])
|
||||
list_resp = await client.get(
|
||||
f"{GMAIL_API}/messages",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
params=params,
|
||||
)
|
||||
list_resp.raise_for_status()
|
||||
messages = list_resp.json().get("messages", [])
|
||||
list_span.set_attribute("gmail.message_count", len(messages))
|
||||
|
||||
with tracer.start_as_current_span("gmail.fetch_details") as details_span:
|
||||
details_span.set_attribute("gen_ai.system", "mcp")
|
||||
details_span.set_attribute("gen_ai.operation.name", "execute_tool")
|
||||
details_span.set_attribute("gen_ai.tool.name", "gmail.fetch_details")
|
||||
details_span.set_attribute("gmail.fetch_count", len(messages))
|
||||
|
||||
results = []
|
||||
for msg_ref in messages:
|
||||
detail_resp = await client.get(
|
||||
f"{GMAIL_API}/messages/{msg_ref['id']}",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
params={
|
||||
"format": "metadata",
|
||||
"metadataHeaders": ["Subject", "From"],
|
||||
},
|
||||
)
|
||||
detail_resp.raise_for_status()
|
||||
msg = detail_resp.json()
|
||||
hdrs = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
|
||||
results.append({
|
||||
"id": msg["id"],
|
||||
"subject": hdrs.get("Subject", "(no subject)"),
|
||||
"from": hdrs.get("From", "(unknown)"),
|
||||
"snippet": msg.get("snippet", ""),
|
||||
})
|
||||
|
||||
with tracer.start_as_current_span("format_response") as fmt_span:
|
||||
fmt_span.set_attribute("gen_ai.system", "mcp")
|
||||
fmt_span.set_attribute("gen_ai.operation.name", "execute_tool")
|
||||
fmt_span.set_attribute("gen_ai.tool.name", "format_response")
|
||||
fmt_span.set_attribute("email.count", len(results))
|
||||
|
||||
return {"emails": results}
|
||||
|
||||
|
||||
@app.tool(requires_auth=Google(scopes=[GMAIL_SEND_SCOPE]))
|
||||
async def send_email(
|
||||
context: Context,
|
||||
to: Annotated[str, "Recipient email address"],
|
||||
subject: Annotated[str, "Email subject line"],
|
||||
body: Annotated[str, "Email body text"],
|
||||
) -> Annotated[dict, "Send confirmation with message ID"]:
|
||||
"""Send an email via the user's Gmail account."""
|
||||
token = context.get_auth_token_or_empty()
|
||||
|
||||
with tracer.start_as_current_span("auth.validate") as auth_span:
|
||||
auth_span.set_attribute("gen_ai.system", "mcp")
|
||||
auth_span.set_attribute("gen_ai.operation.name", "execute_tool")
|
||||
auth_span.set_attribute("gen_ai.tool.name", "auth.validate")
|
||||
auth_span.set_attribute("auth.method", "oauth2_bearer")
|
||||
|
||||
mime = MIMEText(body)
|
||||
mime["to"] = to
|
||||
mime["subject"] = subject
|
||||
raw = base64.urlsafe_b64encode(mime.as_bytes()).decode()
|
||||
|
||||
with tracer.start_as_current_span("gmail.send_message") as send_span:
|
||||
send_span.set_attribute("gen_ai.system", "mcp")
|
||||
send_span.set_attribute("gen_ai.operation.name", "execute_tool")
|
||||
send_span.set_attribute("gen_ai.tool.name", "gmail.send_message")
|
||||
send_span.set_attribute("gmail.recipient", to)
|
||||
send_span.set_attribute("gmail.subject", subject)
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{GMAIL_API}/messages/send",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={"raw": raw},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
|
||||
with tracer.start_as_current_span("format_response") as fmt_span:
|
||||
fmt_span.set_attribute("gen_ai.system", "mcp")
|
||||
fmt_span.set_attribute("gen_ai.operation.name", "execute_tool")
|
||||
fmt_span.set_attribute("gen_ai.tool.name", "format_response")
|
||||
fmt_span.set_attribute("email.message_id", result.get("id", ""))
|
||||
|
||||
return {"message_id": result.get("id", ""), "status": "sent"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
transport = sys.argv[1] if len(sys.argv) > 1 else "http"
|
||||
app.run(transport=cast(TransportType, transport), host="127.0.0.1", port=8000)
|
||||
|
|
@ -7,6 +7,7 @@ from arcade_mcp_server.middleware.base import (
|
|||
)
|
||||
from arcade_mcp_server.middleware.error_handling import ErrorHandlingMiddleware
|
||||
from arcade_mcp_server.middleware.logging import LoggingMiddleware
|
||||
from arcade_mcp_server.middleware.telemetry import TelemetryPassbackMiddleware
|
||||
|
||||
__all__ = [
|
||||
"CallNext",
|
||||
|
|
@ -14,4 +15,5 @@ __all__ = [
|
|||
"LoggingMiddleware",
|
||||
"Middleware",
|
||||
"MiddlewareContext",
|
||||
"TelemetryPassbackMiddleware",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -74,6 +74,13 @@ class Middleware:
|
|||
function to invoke the next handler in the chain.
|
||||
"""
|
||||
|
||||
def get_capabilities(self) -> dict[str, Any]:
|
||||
"""Return extra server capabilities to advertise.
|
||||
|
||||
Override in subclasses to declare capabilities (e.g. serverExecutionTelemetry).
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: MiddlewareContext[T],
|
||||
|
|
|
|||
361
libs/arcade-mcp-server/arcade_mcp_server/middleware/telemetry.py
Normal file
361
libs/arcade-mcp-server/arcade_mcp_server/middleware/telemetry.py
Normal file
|
|
@ -0,0 +1,361 @@
|
|||
"""Telemetry passback middleware for MCP server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import context as otel_context
|
||||
from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor, TracerProvider
|
||||
from opentelemetry.trace import SpanKind, StatusCode
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
from arcade_mcp_server.middleware.base import CallNext, Middleware, MiddlewareContext
|
||||
from arcade_mcp_server.types import JSONRPCResponse
|
||||
|
||||
logger = logging.getLogger("arcade.mcp.telemetry")
|
||||
|
||||
|
||||
# Per-request span collection via ContextVar
|
||||
_request_spans: ContextVar[list[ReadableSpan] | None] = ContextVar("_request_spans", default=None)
|
||||
|
||||
|
||||
class ContextVarSpanCollector(SpanProcessor):
|
||||
"""Collect ended spans into a per-request ContextVar bucket."""
|
||||
|
||||
def on_start(self, span: Any, parent_context: Any = None) -> None:
|
||||
pass
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
bucket = _request_spans.get()
|
||||
if bucket is not None:
|
||||
bucket.append(span)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def force_flush(self, timeout_millis: int | None = None) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def start_collecting() -> list[ReadableSpan]:
|
||||
"""Begin collecting spans for the current async context."""
|
||||
bucket: list[ReadableSpan] = []
|
||||
_request_spans.set(bucket)
|
||||
return bucket
|
||||
|
||||
|
||||
def stop_collecting() -> list[ReadableSpan]:
|
||||
"""Stop collecting and return the collected spans."""
|
||||
bucket = _request_spans.get() or []
|
||||
_request_spans.set(None)
|
||||
return bucket
|
||||
|
||||
|
||||
# Span filtering
|
||||
def filter_top_level_spans(spans: list[ReadableSpan]) -> list[ReadableSpan]:
|
||||
"""Keep only the root span and its direct children (depth <= 1)."""
|
||||
span_ids = {s.context.span_id for s in spans if s.context}
|
||||
root = None
|
||||
for s in spans:
|
||||
if s.parent and s.parent.span_id not in span_ids:
|
||||
root = s
|
||||
break
|
||||
if root is None:
|
||||
return spans
|
||||
root_id = root.context.span_id
|
||||
return [
|
||||
s
|
||||
for s in spans
|
||||
if s.context.span_id == root_id or (s.parent and s.parent.span_id == root_id)
|
||||
]
|
||||
|
||||
|
||||
# OTLP JSON serialization
|
||||
|
||||
_SPAN_KIND_MAP = {
|
||||
SpanKind.INTERNAL: 1,
|
||||
SpanKind.SERVER: 2,
|
||||
SpanKind.CLIENT: 3,
|
||||
SpanKind.PRODUCER: 4,
|
||||
SpanKind.CONSUMER: 5,
|
||||
}
|
||||
_STATUS_CODE_MAP = {StatusCode.UNSET: 0, StatusCode.OK: 1, StatusCode.ERROR: 2}
|
||||
|
||||
|
||||
def _ns(ns: int | None) -> str:
|
||||
return str(ns) if ns is not None else "0"
|
||||
|
||||
|
||||
def _attrs_to_kv(attrs: dict[str, Any] | None) -> list[dict[str, Any]]:
|
||||
if not attrs:
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
for key, val in attrs.items():
|
||||
if isinstance(val, bool):
|
||||
out.append({"key": key, "value": {"boolValue": val}})
|
||||
elif isinstance(val, int):
|
||||
out.append({"key": key, "value": {"intValue": str(val)}})
|
||||
elif isinstance(val, float):
|
||||
out.append({"key": key, "value": {"doubleValue": val}})
|
||||
elif isinstance(val, str):
|
||||
out.append({"key": key, "value": {"stringValue": val}})
|
||||
elif isinstance(val, (list, tuple)):
|
||||
arr: list[dict[str, Any]] = []
|
||||
for v in val:
|
||||
if isinstance(v, str):
|
||||
arr.append({"stringValue": v})
|
||||
elif isinstance(v, bool):
|
||||
arr.append({"boolValue": v})
|
||||
elif isinstance(v, int):
|
||||
arr.append({"intValue": str(v)})
|
||||
elif isinstance(v, float):
|
||||
arr.append({"doubleValue": v})
|
||||
out.append({"key": key, "value": {"arrayValue": {"values": arr}}})
|
||||
else:
|
||||
out.append({"key": key, "value": {"stringValue": str(val)}})
|
||||
return out
|
||||
|
||||
|
||||
def spans_to_otlp_json(
|
||||
spans: list[ReadableSpan],
|
||||
service_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Convert ReadableSpan objects to OTLP JSON (ExportTraceServiceRequest)."""
|
||||
otlp_spans: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
ctx = span.context
|
||||
if ctx is None:
|
||||
continue
|
||||
rec: dict[str, Any] = {
|
||||
"traceId": format(ctx.trace_id, "032x"),
|
||||
"spanId": format(ctx.span_id, "016x"),
|
||||
"name": span.name,
|
||||
"kind": _SPAN_KIND_MAP.get(span.kind, 0),
|
||||
"startTimeUnixNano": _ns(span.start_time),
|
||||
"endTimeUnixNano": _ns(span.end_time),
|
||||
"attributes": _attrs_to_kv(dict(span.attributes) if span.attributes else None),
|
||||
"status": {"code": _STATUS_CODE_MAP.get(span.status.status_code, 0)},
|
||||
}
|
||||
if span.parent and span.parent.span_id:
|
||||
rec["parentSpanId"] = format(span.parent.span_id, "016x")
|
||||
if span.events:
|
||||
rec["events"] = [
|
||||
{
|
||||
"timeUnixNano": _ns(ev.timestamp),
|
||||
"name": ev.name,
|
||||
"attributes": _attrs_to_kv(dict(ev.attributes) if ev.attributes else None),
|
||||
}
|
||||
for ev in span.events
|
||||
]
|
||||
otlp_spans.append(rec)
|
||||
|
||||
return {
|
||||
"resourceSpans": [
|
||||
{
|
||||
"resource": {
|
||||
"attributes": [{"key": "service.name", "value": {"stringValue": service_name}}],
|
||||
},
|
||||
"scopeSpans": [{"scope": {"name": "mcp-telemetry-passback"}, "spans": otlp_spans}],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class TelemetryPassbackMiddleware(Middleware):
|
||||
"""MCP middleware implementing SEP serverExecutionTelemetry.
|
||||
|
||||
.. warning:: **Provisional API** — This middleware implements a draft SEP
|
||||
(SEP-2448: server execution telemetry) that has not yet been finalized in
|
||||
the MCP specification. The capability schema, response payload shape,
|
||||
and constructor signature may change in a future release once the SEP is
|
||||
ratified. Pin your ``arcade-mcp-server`` version if you need stability.
|
||||
|
||||
Intercepts tools/call and resources/read to:
|
||||
|
||||
1. Propagate traceparent from the client for distributed tracing.
|
||||
2. Collect server-side spans during request execution (per-request
|
||||
isolation via ContextVar).
|
||||
3. Return spans in the response _meta.otel.traces.resourceSpans.
|
||||
"""
|
||||
|
||||
def __init__(self, service_name: str, tracer_provider: TracerProvider) -> None:
|
||||
warnings.warn(
|
||||
"TelemetryPassbackMiddleware implements a draft version of SEP "
|
||||
"(SEP-2448: server execution telemetry) that is not yet finalized. "
|
||||
"The API may change in a future release.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._service_name = service_name
|
||||
self._tracer = tracer_provider.get_tracer(service_name)
|
||||
self._propagator = TraceContextTextMapPropagator()
|
||||
self._collector = ContextVarSpanCollector()
|
||||
tracer_provider.add_span_processor(self._collector)
|
||||
|
||||
def get_capabilities(self) -> dict[str, Any]:
|
||||
"""Return the serverExecutionTelemetry capability dict."""
|
||||
return {
|
||||
"serverExecutionTelemetry": {
|
||||
"version": "2026-03-01",
|
||||
"signals": {"traces": {"supported": True}},
|
||||
}
|
||||
}
|
||||
|
||||
def _extract_otel_meta(self, context: MiddlewareContext[Any]) -> dict[str, Any]:
|
||||
"""Extract traceparent and otel request flags from _meta."""
|
||||
mcp_ctx = context.mcp_context
|
||||
if mcp_ctx is None:
|
||||
return {}
|
||||
session = getattr(mcp_ctx, "_session", None)
|
||||
if session is None:
|
||||
return {}
|
||||
meta = getattr(session, "_request_meta", None)
|
||||
if meta is None:
|
||||
return {}
|
||||
|
||||
otel_meta = getattr(meta, "otel", None)
|
||||
traces = otel_meta.get("traces", {}) if isinstance(otel_meta, dict) else {}
|
||||
|
||||
return {
|
||||
"traceparent": getattr(meta, "traceparent", None),
|
||||
"request": traces.get("request", False) if isinstance(traces, dict) else False,
|
||||
"detailed": traces.get("detailed", False) if isinstance(traces, dict) else False,
|
||||
}
|
||||
|
||||
def _parent_context(self, traceparent: str | None) -> Any:
|
||||
if traceparent:
|
||||
return self._propagator.extract(carrier={"traceparent": traceparent})
|
||||
return otel_context.get_current()
|
||||
|
||||
def _attach_spans(self, response: Any, collected: list[ReadableSpan], detailed: bool) -> Any:
|
||||
if not detailed:
|
||||
filtered = filter_top_level_spans(collected)
|
||||
dropped = len(collected) - len(filtered)
|
||||
else:
|
||||
filtered = collected
|
||||
dropped = 0
|
||||
|
||||
payload = {
|
||||
"traces": {
|
||||
"resourceSpans": spans_to_otlp_json(filtered, self._service_name).get(
|
||||
"resourceSpans", []
|
||||
),
|
||||
"truncated": dropped > 0,
|
||||
"droppedSpanCount": dropped,
|
||||
}
|
||||
}
|
||||
|
||||
if isinstance(response, JSONRPCResponse) and response.result is not None:
|
||||
result = response.result
|
||||
if not isinstance(result, dict) and hasattr(result, "meta"):
|
||||
meta = result.meta or {}
|
||||
meta["otel"] = payload
|
||||
result.meta = meta
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_text(response: Any) -> str | None:
|
||||
"""Extract text content from a JSONRPCResponse."""
|
||||
if not isinstance(response, JSONRPCResponse) or response.result is None:
|
||||
return None
|
||||
result = response.result
|
||||
for attr in ("content", "contents"):
|
||||
items = getattr(result, attr, None)
|
||||
if items and isinstance(items, list):
|
||||
parts = [p for c in items if (p := getattr(c, "text", None))]
|
||||
if parts:
|
||||
return "\n".join(str(p) for p in parts)
|
||||
return None
|
||||
|
||||
async def _handle_with_telemetry(
|
||||
self,
|
||||
context: MiddlewareContext[Any],
|
||||
call_next: CallNext[Any, Any],
|
||||
span_name: str,
|
||||
span_attributes: dict[str, Any] | None = None,
|
||||
tool_arguments: str | None = None,
|
||||
) -> Any:
|
||||
"""Common telemetry logic for tools/call and resources/read."""
|
||||
otel = self._extract_otel_meta(context)
|
||||
if not otel.get("request", False):
|
||||
return await call_next(context)
|
||||
|
||||
start_collecting()
|
||||
|
||||
try:
|
||||
with self._tracer.start_as_current_span(
|
||||
span_name,
|
||||
context=self._parent_context(otel.get("traceparent")),
|
||||
kind=SpanKind.SERVER,
|
||||
attributes=span_attributes or {},
|
||||
) as span:
|
||||
if tool_arguments:
|
||||
span.set_attribute(
|
||||
"gen_ai.input.messages",
|
||||
json.dumps([{"role": "user", "content": tool_arguments}]),
|
||||
)
|
||||
response = await call_next(context)
|
||||
output = self._extract_response_text(response)
|
||||
if output:
|
||||
span.set_attribute("gen_ai.tool.call.result", output[:500])
|
||||
span.set_attribute(
|
||||
"gen_ai.output.messages",
|
||||
json.dumps([{"role": "assistant", "content": output}]),
|
||||
)
|
||||
finally:
|
||||
collected = stop_collecting()
|
||||
|
||||
return self._attach_spans(response, collected, otel.get("detailed", False))
|
||||
|
||||
async def on_call_tool(
|
||||
self,
|
||||
context: MiddlewareContext[Any],
|
||||
call_next: CallNext[Any, Any],
|
||||
) -> Any:
|
||||
"""Intercept tools/call to collect and return server spans."""
|
||||
msg = context.message
|
||||
params = msg.get("params", {}) if isinstance(msg, dict) else {}
|
||||
name = params.get("name", "")
|
||||
arguments = params.get("arguments", {})
|
||||
return await self._handle_with_telemetry(
|
||||
context,
|
||||
call_next,
|
||||
span_name=f"tools/call {name}",
|
||||
span_attributes={
|
||||
"mcp.method": "tools/call",
|
||||
"mcp.tool": name,
|
||||
"gen_ai.system": "mcp",
|
||||
"gen_ai.operation.name": "execute_tool",
|
||||
"gen_ai.tool.name": name,
|
||||
"gen_ai.tool.call.arguments": json.dumps(arguments) if arguments else "",
|
||||
},
|
||||
tool_arguments=json.dumps(arguments) if arguments else None,
|
||||
)
|
||||
|
||||
async def on_read_resource(
|
||||
self,
|
||||
context: MiddlewareContext[Any],
|
||||
call_next: CallNext[Any, Any],
|
||||
) -> Any:
|
||||
"""Intercept resources/read to collect and return server spans."""
|
||||
msg = context.message
|
||||
params = msg.get("params", {}) if isinstance(msg, dict) else {}
|
||||
uri = params.get("uri", "")
|
||||
return await self._handle_with_telemetry(
|
||||
context,
|
||||
call_next,
|
||||
span_name=f"resources/read {uri}",
|
||||
span_attributes={
|
||||
"mcp.method": "resources/read",
|
||||
"mcp.resource.uri": uri,
|
||||
"gen_ai.system": "mcp",
|
||||
"gen_ai.operation.name": "execute_tool",
|
||||
"gen_ai.tool.name": f"resources/read {uri}",
|
||||
},
|
||||
tool_arguments=json.dumps({"uri": uri}) if uri else None,
|
||||
)
|
||||
|
|
@ -202,6 +202,7 @@ class MCPServer:
|
|||
|
||||
# Middleware chain
|
||||
self.middleware: list[Middleware] = []
|
||||
self._extra_capabilities: dict[str, Any] = {}
|
||||
self._init_middleware(middleware)
|
||||
|
||||
# Lifespan management
|
||||
|
|
@ -319,6 +320,10 @@ class MCPServer:
|
|||
if custom_middleware:
|
||||
self.middleware.extend(custom_middleware)
|
||||
|
||||
# Collect capabilities from middleware that declare them
|
||||
for mw in self.middleware:
|
||||
self._extra_capabilities.update(mw.get_capabilities())
|
||||
|
||||
def _register_handlers(self) -> dict[str, Callable]:
|
||||
"""Register method handlers."""
|
||||
return {
|
||||
|
|
@ -681,14 +686,18 @@ class MCPServer:
|
|||
if session:
|
||||
session.set_client_params(message.params)
|
||||
|
||||
caps_kwargs: dict[str, Any] = {
|
||||
"tools": {"listChanged": True},
|
||||
"logging": {},
|
||||
"prompts": {"listChanged": True},
|
||||
"resources": {"subscribe": True, "listChanged": True},
|
||||
}
|
||||
if self._extra_capabilities:
|
||||
caps_kwargs.update(self._extra_capabilities)
|
||||
|
||||
result = InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(
|
||||
tools={"listChanged": True},
|
||||
logging={},
|
||||
prompts={"listChanged": True},
|
||||
resources={"subscribe": True, "listChanged": True},
|
||||
),
|
||||
capabilities=ServerCapabilities(**caps_kwargs),
|
||||
serverInfo=Implementation(
|
||||
name=self.name,
|
||||
version=self.version,
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ dependencies = [
|
|||
"pydantic-settings>=2.10.1",
|
||||
"joserfc>=1.5.0",
|
||||
"httpx>=0.27.0,<1.0.0",
|
||||
"opentelemetry-api>=1.20.0",
|
||||
"opentelemetry-sdk>=1.20.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
529
libs/tests/arcade_mcp_server/test_telemetry_middleware.py
Normal file
529
libs/tests/arcade_mcp_server/test_telemetry_middleware.py
Normal file
|
|
@ -0,0 +1,529 @@
|
|||
"""Tests for Telemetry Passback Middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
|
||||
from opentelemetry.trace import SpanKind, StatusCode
|
||||
|
||||
from arcade_mcp_server.middleware.base import MiddlewareContext
|
||||
from arcade_mcp_server.middleware.telemetry import (
|
||||
ContextVarSpanCollector,
|
||||
TelemetryPassbackMiddleware,
|
||||
_attrs_to_kv,
|
||||
_ns,
|
||||
_request_spans,
|
||||
filter_top_level_spans,
|
||||
spans_to_otlp_json,
|
||||
start_collecting,
|
||||
stop_collecting,
|
||||
)
|
||||
from arcade_mcp_server.types import (
|
||||
CallToolResult,
|
||||
JSONRPCResponse,
|
||||
ReadResourceResult,
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_span(
|
||||
*,
|
||||
name: str = "test-span",
|
||||
trace_id: int = 0xABCD,
|
||||
span_id: int = 0x1234,
|
||||
parent_span_id: int | None = None,
|
||||
kind: SpanKind = SpanKind.INTERNAL,
|
||||
start_time: int = 1000,
|
||||
end_time: int = 2000,
|
||||
status_code: StatusCode = StatusCode.UNSET,
|
||||
attributes: dict | None = None,
|
||||
events: list | None = None,
|
||||
) -> ReadableSpan:
|
||||
"""Build a lightweight mock ReadableSpan."""
|
||||
span = MagicMock(spec=ReadableSpan)
|
||||
ctx = MagicMock()
|
||||
ctx.trace_id = trace_id
|
||||
ctx.span_id = span_id
|
||||
span.context = ctx
|
||||
span.name = name
|
||||
span.kind = kind
|
||||
span.start_time = start_time
|
||||
span.end_time = end_time
|
||||
|
||||
status = MagicMock()
|
||||
status.status_code = status_code
|
||||
span.status = status
|
||||
|
||||
if parent_span_id is not None:
|
||||
parent = MagicMock()
|
||||
parent.span_id = parent_span_id
|
||||
span.parent = parent
|
||||
else:
|
||||
span.parent = None
|
||||
|
||||
span.attributes = attributes
|
||||
span.events = events or []
|
||||
return span
|
||||
|
||||
|
||||
def _make_tool_response(*, text: str = "hello") -> JSONRPCResponse[CallToolResult]:
|
||||
"""Build a JSONRPCResponse wrapping a CallToolResult."""
|
||||
return JSONRPCResponse(
|
||||
id="1",
|
||||
result=CallToolResult(
|
||||
content=[TextContent(type="text", text=text)],
|
||||
isError=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_resource_response(*, text: str = "hello") -> JSONRPCResponse[ReadResourceResult]:
|
||||
"""Build a JSONRPCResponse wrapping a ReadResourceResult."""
|
||||
return JSONRPCResponse(
|
||||
id="1",
|
||||
result=ReadResourceResult(
|
||||
contents=[TextResourceContents(uri="file:///test", text=text)],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# ContextVarSpanCollector
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestContextVarSpanCollector:
|
||||
"""Test ContextVarSpanCollector class."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_contextvar(self):
|
||||
"""Ensure the ContextVar is clean before and after each test."""
|
||||
_request_spans.set(None)
|
||||
yield
|
||||
_request_spans.set(None)
|
||||
|
||||
def test_spans_collected_when_bucket_active(self):
|
||||
collector = ContextVarSpanCollector()
|
||||
bucket = start_collecting()
|
||||
span = _make_span()
|
||||
collector.on_end(span)
|
||||
assert span in bucket
|
||||
|
||||
def test_spans_not_collected_when_bucket_inactive(self):
|
||||
collector = ContextVarSpanCollector()
|
||||
span = _make_span()
|
||||
collector.on_end(span)
|
||||
# No active bucket; span is silently dropped
|
||||
assert _request_spans.get() is None
|
||||
|
||||
def test_on_start_is_noop(self):
|
||||
collector = ContextVarSpanCollector()
|
||||
collector.on_start(MagicMock())
|
||||
|
||||
def test_shutdown_is_noop(self):
|
||||
collector = ContextVarSpanCollector()
|
||||
collector.shutdown()
|
||||
|
||||
def test_force_flush_returns_true(self):
|
||||
collector = ContextVarSpanCollector()
|
||||
assert collector.force_flush() is True
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# start_collecting / stop_collecting
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestCollectingLifecycle:
|
||||
"""Test start_collecting / stop_collecting."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_contextvar(self):
|
||||
_request_spans.set(None)
|
||||
yield
|
||||
_request_spans.set(None)
|
||||
|
||||
def test_lifecycle(self):
|
||||
bucket = start_collecting()
|
||||
assert bucket == []
|
||||
assert _request_spans.get() is bucket
|
||||
span = _make_span()
|
||||
bucket.append(span)
|
||||
result = stop_collecting()
|
||||
assert span in result
|
||||
assert _request_spans.get() is None
|
||||
|
||||
def test_idempotent_stop(self):
|
||||
start_collecting()
|
||||
stop_collecting()
|
||||
result = stop_collecting()
|
||||
assert result == []
|
||||
|
||||
def test_isolation_between_calls(self):
|
||||
bucket1 = start_collecting()
|
||||
bucket1.append(_make_span(name="a"))
|
||||
stop_collecting()
|
||||
|
||||
bucket2 = start_collecting()
|
||||
assert bucket2 == []
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# filter_top_level_spans
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestFilterTopLevelSpans:
|
||||
def test_root_and_children_kept_grandchildren_dropped(self):
|
||||
root = _make_span(name="root", span_id=1, parent_span_id=999)
|
||||
child = _make_span(name="child", span_id=2, parent_span_id=1)
|
||||
grandchild = _make_span(name="grandchild", span_id=3, parent_span_id=2)
|
||||
result = filter_top_level_spans([root, child, grandchild])
|
||||
names = {s.name for s in result}
|
||||
assert names == {"root", "child"}
|
||||
|
||||
def test_empty_list(self):
|
||||
assert filter_top_level_spans([]) == []
|
||||
|
||||
def test_single_span_no_parent(self):
|
||||
span = _make_span(name="only", span_id=1)
|
||||
result = filter_top_level_spans([span])
|
||||
assert result == [span]
|
||||
|
||||
def test_single_span_with_external_parent(self):
|
||||
span = _make_span(name="only", span_id=1, parent_span_id=999)
|
||||
result = filter_top_level_spans([span])
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "only"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _attrs_to_kv
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestAttrsToKv:
|
||||
def test_bool_value(self):
|
||||
result = _attrs_to_kv({"flag": True})
|
||||
assert result == [{"key": "flag", "value": {"boolValue": True}}]
|
||||
|
||||
def test_int_value(self):
|
||||
result = _attrs_to_kv({"count": 42})
|
||||
assert result == [{"key": "count", "value": {"intValue": "42"}}]
|
||||
|
||||
def test_float_value(self):
|
||||
result = _attrs_to_kv({"rate": 3.14})
|
||||
assert result == [{"key": "rate", "value": {"doubleValue": 3.14}}]
|
||||
|
||||
def test_str_value(self):
|
||||
result = _attrs_to_kv({"name": "foo"})
|
||||
assert result == [{"key": "name", "value": {"stringValue": "foo"}}]
|
||||
|
||||
def test_list_value(self):
|
||||
result = _attrs_to_kv({"tags": ["a", "b"]})
|
||||
assert result == [
|
||||
{"key": "tags", "value": {"arrayValue": {"values": [{"stringValue": "a"}, {"stringValue": "b"}]}}}
|
||||
]
|
||||
|
||||
def test_fallback_to_string(self):
|
||||
result = _attrs_to_kv({"obj": object()})
|
||||
assert len(result) == 1
|
||||
assert "stringValue" in result[0]["value"]
|
||||
|
||||
def test_empty_attrs(self):
|
||||
assert _attrs_to_kv({}) == []
|
||||
|
||||
def test_none_attrs(self):
|
||||
assert _attrs_to_kv(None) == []
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# spans_to_otlp_json
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSpansToOtlpJson:
|
||||
def test_full_serialization(self):
|
||||
event = MagicMock()
|
||||
event.timestamp = 1500
|
||||
event.name = "log"
|
||||
event.attributes = {"msg": "hi"}
|
||||
|
||||
span = _make_span(
|
||||
name="op",
|
||||
trace_id=0xABCDEF,
|
||||
span_id=0x123456,
|
||||
parent_span_id=0x111,
|
||||
kind=SpanKind.SERVER,
|
||||
start_time=1000,
|
||||
end_time=2000,
|
||||
status_code=StatusCode.OK,
|
||||
attributes={"k": "v"},
|
||||
events=[event],
|
||||
)
|
||||
|
||||
result = spans_to_otlp_json([span], "test-service")
|
||||
assert "resourceSpans" in result
|
||||
rs = result["resourceSpans"][0]
|
||||
assert rs["resource"]["attributes"][0]["value"]["stringValue"] == "test-service"
|
||||
otlp_span = rs["scopeSpans"][0]["spans"][0]
|
||||
assert otlp_span["name"] == "op"
|
||||
assert otlp_span["kind"] == 2 # SERVER
|
||||
assert otlp_span["status"]["code"] == 1 # OK
|
||||
assert "parentSpanId" in otlp_span
|
||||
assert len(otlp_span["events"]) == 1
|
||||
|
||||
def test_span_without_parent(self):
|
||||
span = _make_span(name="root", span_id=1)
|
||||
result = spans_to_otlp_json([span], "svc")
|
||||
otlp_span = result["resourceSpans"][0]["scopeSpans"][0]["spans"][0]
|
||||
assert "parentSpanId" not in otlp_span
|
||||
|
||||
def test_span_with_no_context_skipped(self):
|
||||
span = _make_span()
|
||||
span.context = None
|
||||
result = spans_to_otlp_json([span], "svc")
|
||||
assert result["resourceSpans"][0]["scopeSpans"][0]["spans"] == []
|
||||
|
||||
def test_status_error(self):
|
||||
span = _make_span(status_code=StatusCode.ERROR)
|
||||
result = spans_to_otlp_json([span], "svc")
|
||||
otlp_span = result["resourceSpans"][0]["scopeSpans"][0]["spans"][0]
|
||||
assert otlp_span["status"]["code"] == 2
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _ns
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestNs:
|
||||
def test_zero(self):
|
||||
assert _ns(0) == "0"
|
||||
|
||||
def test_none(self):
|
||||
assert _ns(None) == "0"
|
||||
|
||||
def test_positive(self):
|
||||
assert _ns(123456) == "123456"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TelemetryPassbackMiddleware
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def _make_otel_context(*, request: bool = True, detailed: bool = False, traceparent: str | None = None):
|
||||
"""Build a MiddlewareContext with otel meta configured."""
|
||||
session = MagicMock()
|
||||
meta = MagicMock()
|
||||
meta.traceparent = traceparent
|
||||
meta.otel = {"traces": {"request": request, "detailed": detailed}}
|
||||
session._request_meta = meta
|
||||
|
||||
mcp_ctx = MagicMock()
|
||||
mcp_ctx._session = session
|
||||
return mcp_ctx
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::FutureWarning")
|
||||
class TestTelemetryPassbackMiddleware:
|
||||
"""Test TelemetryPassbackMiddleware class."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_contextvar(self):
|
||||
_request_spans.set(None)
|
||||
yield
|
||||
_request_spans.set(None)
|
||||
|
||||
@pytest.fixture
|
||||
def tracer_provider(self):
|
||||
return TracerProvider()
|
||||
|
||||
@pytest.fixture
|
||||
def middleware(self, tracer_provider):
|
||||
return TelemetryPassbackMiddleware(
|
||||
service_name="test-svc",
|
||||
tracer_provider=tracer_provider,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def tool_call_context(self):
|
||||
"""Context for a tools/call with otel request=True."""
|
||||
return MiddlewareContext(
|
||||
message={"method": "tools/call", "params": {"name": "my_tool", "arguments": {"x": 1}}},
|
||||
mcp_context=_make_otel_context(request=True, detailed=False),
|
||||
method="tools/call",
|
||||
request_id="req-1",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def tool_call_context_no_request(self):
|
||||
"""Context for a tools/call with otel request=False."""
|
||||
return MiddlewareContext(
|
||||
message={"method": "tools/call", "params": {"name": "my_tool", "arguments": {}}},
|
||||
mcp_context=_make_otel_context(request=False),
|
||||
method="tools/call",
|
||||
request_id="req-2",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def resource_read_context(self):
|
||||
"""Context for resources/read with otel request=True."""
|
||||
return MiddlewareContext(
|
||||
message={"method": "resources/read", "params": {"uri": "file:///foo.txt"}},
|
||||
mcp_context=_make_otel_context(request=True, detailed=False),
|
||||
method="resources/read",
|
||||
request_id="req-3",
|
||||
)
|
||||
|
||||
# --- get_capabilities ---
|
||||
|
||||
def test_get_capabilities(self, middleware):
|
||||
caps = middleware.get_capabilities()
|
||||
assert "serverExecutionTelemetry" in caps
|
||||
assert caps["serverExecutionTelemetry"]["signals"]["traces"]["supported"] is True
|
||||
|
||||
# --- on_call_tool: request=False passthrough ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_no_request_passthrough(self, middleware, tool_call_context_no_request):
|
||||
sentinel = object()
|
||||
|
||||
async def handler(ctx):
|
||||
return sentinel
|
||||
|
||||
result = await middleware.on_call_tool(tool_call_context_no_request, handler)
|
||||
assert result is sentinel
|
||||
|
||||
# --- on_call_tool: request=True attaches spans ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_with_request_attaches_spans(self, middleware, tool_call_context):
|
||||
resp = _make_tool_response(text="tool output")
|
||||
|
||||
async def handler(ctx):
|
||||
return resp
|
||||
|
||||
result = await middleware.on_call_tool(tool_call_context, handler)
|
||||
meta = result.result.meta
|
||||
assert meta is not None
|
||||
assert "otel" in meta
|
||||
assert "traces" in meta["otel"]
|
||||
assert "resourceSpans" in meta["otel"]["traces"]
|
||||
|
||||
# --- on_read_resource ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_read_resource_attaches_spans(self, middleware, resource_read_context):
|
||||
resp = _make_resource_response(text="resource content")
|
||||
|
||||
async def handler(ctx):
|
||||
return resp
|
||||
|
||||
result = await middleware.on_read_resource(resource_read_context, handler)
|
||||
meta = result.result.meta
|
||||
assert meta is not None
|
||||
assert "otel" in meta
|
||||
assert "traces" in meta["otel"]
|
||||
|
||||
# --- detailed=True vs detailed=False ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detailed_true_returns_all_spans(self, tracer_provider):
|
||||
mw = TelemetryPassbackMiddleware(service_name="svc", tracer_provider=tracer_provider)
|
||||
|
||||
context = MiddlewareContext(
|
||||
message={"method": "tools/call", "params": {"name": "t", "arguments": {}}},
|
||||
mcp_context=_make_otel_context(request=True, detailed=True),
|
||||
method="tools/call",
|
||||
)
|
||||
|
||||
resp = _make_tool_response()
|
||||
|
||||
async def handler(ctx):
|
||||
return resp
|
||||
|
||||
result = await mw.on_call_tool(context, handler)
|
||||
otel = result.result.meta["otel"]["traces"]
|
||||
assert otel["truncated"] is False
|
||||
assert otel["droppedSpanCount"] == 0
|
||||
|
||||
# --- traceparent propagation ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_traceparent_propagation(self, tracer_provider):
|
||||
mw = TelemetryPassbackMiddleware(service_name="svc", tracer_provider=tracer_provider)
|
||||
|
||||
context = MiddlewareContext(
|
||||
message={"method": "tools/call", "params": {"name": "t", "arguments": {}}},
|
||||
mcp_context=_make_otel_context(
|
||||
request=True,
|
||||
detailed=True,
|
||||
traceparent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
|
||||
),
|
||||
method="tools/call",
|
||||
)
|
||||
|
||||
resp = _make_tool_response()
|
||||
|
||||
async def handler(ctx):
|
||||
return resp
|
||||
|
||||
result = await mw.on_call_tool(context, handler)
|
||||
assert result.result.meta is not None
|
||||
assert "otel" in result.result.meta
|
||||
|
||||
# --- error in call_next — spans still cleaned up ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_in_call_next_cleans_up_spans(self, middleware, tool_call_context):
|
||||
async def handler(ctx):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await middleware.on_call_tool(tool_call_context, handler)
|
||||
|
||||
assert _request_spans.get() is None
|
||||
|
||||
# --- _extract_response_text ---
|
||||
|
||||
def test_extract_response_text_with_content(self, middleware):
|
||||
resp = _make_tool_response(text="hello")
|
||||
assert middleware._extract_response_text(resp) == "hello"
|
||||
|
||||
def test_extract_response_text_with_contents(self, middleware):
|
||||
resp = _make_resource_response(text="world")
|
||||
assert middleware._extract_response_text(resp) == "world"
|
||||
|
||||
def test_extract_response_text_no_text(self, middleware):
|
||||
# Empty text is falsy, filtered out by _extract_response_text
|
||||
resp = _make_tool_response(text="")
|
||||
assert middleware._extract_response_text(resp) is None
|
||||
|
||||
def test_extract_response_text_non_jsonrpc(self, middleware):
|
||||
assert middleware._extract_response_text({"not": "a response"}) is None
|
||||
|
||||
def test_extract_response_text_none_result(self, middleware):
|
||||
resp = JSONRPCResponse(id="1", result={})
|
||||
assert middleware._extract_response_text(resp) is None
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Middleware base class default get_capabilities
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestMiddlewareBaseGetCapabilities:
|
||||
def test_default_returns_empty_dict(self):
|
||||
from arcade_mcp_server.middleware.base import Middleware
|
||||
|
||||
mw = Middleware()
|
||||
assert mw.get_capabilities() == {}
|
||||
Loading…
Reference in a new issue