diff --git a/.gitignore b/.gitignore index 1f19c235..79d71009 100644 --- a/.gitignore +++ b/.gitignore @@ -179,3 +179,7 @@ cython_debug/ # Docs libs/arcade-mcp-server/site/* + +# exclude OAuth tokens +.oauth_tokens.json +.oauth_client.json diff --git a/examples/mcp_servers/telemetry_passback/.env.example b/examples/mcp_servers/telemetry_passback/.env.example new file mode 100644 index 00000000..5de24524 --- /dev/null +++ b/examples/mcp_servers/telemetry_passback/.env.example @@ -0,0 +1,12 @@ +OPENAI_API_KEY=sk-... + +# Arcade auth (required — get your key at https://www.arcade.dev) +ARCADE_API_KEY= +ARCADE_USER_ID= +ARCADE_API_URL=https://api.arcade.dev + +# Optional: Galileo integration +GALILEO_API_KEY= +GALILEO_PROJECT= +GALILEO_LOG_STREAM=default +GALILEO_OTLP_ENDPOINT=https://app.galileo.ai/api/galileo/otel/traces diff --git a/examples/mcp_servers/telemetry_passback/README.md b/examples/mcp_servers/telemetry_passback/README.md new file mode 100644 index 00000000..5ea1d65a --- /dev/null +++ b/examples/mcp_servers/telemetry_passback/README.md @@ -0,0 +1,177 @@ +# SEP-2448: MCP server execution telemetry — Reference Implementation + +End-to-end reference implementation of **SEP-2448 `serverExecutionTelemetry`** — cross-organization distributed tracing via MCP. + +## Overview + +This example demonstrates how an MCP server can **pass back OpenTelemetry spans** to the calling client, enabling full distributed tracing across organizational boundaries. Without this capability, the server side of an MCP tool call is a black box — you can see *that* it was called, but not *what happened inside*. + +The example includes three components: + +1. **Server** (`server.py`) — An Arcade MCP server with Gmail tools that uses `TelemetryPassbackMiddleware` to collect and return spans. This shows how a **vendor adopts** the SEP. +2. **Agent** (`agent.py`) — A LangChain ReAct agent that requests span passback, receives server spans, and ingests them into Jaeger/Galileo. This shows how a **consumer uses** the SEP. +3. **Jaeger** (`docker-compose.yml`) — Local trace collector and UI for visualizing the stitched traces. + +## Prerequisites + +- Python 3.11+ +- [uv](https://docs.astral.sh/uv/) package manager +- Docker (for Jaeger) +- An [Arcade](https://www.arcade.dev) account ([quickstart](https://docs.arcade.dev/en/get-started/quickstarts/mcp-server-quickstart)) +- An OpenAI API key (for the LangChain agent) + +## Setup + +```bash +cd examples/mcp_servers/telemetry_passback + +# Copy env file and add your keys +cp .env.example .env +# Edit .env: set OPENAI_API_KEY, ARCADE_API_KEY, ARCADE_USER_ID + +# Install dependencies +uv sync + +# Start Jaeger +docker compose up -d +``` + +## Usage + +The server and agent run as **separate processes**. Start the server first, then run the agent in another terminal. + +### Start the Server + +```bash +# Terminal 1 +uv run python src/telemetry_passback/server.py +``` + +The server listens at `http://127.0.0.1:8000/mcp` with OAuth 2.1 resource server auth via Arcade. + +### Run the Agent + +In a separate terminal. On first run, the MCP SDK will open your browser for OAuth authorization (one-time). + +#### Act 1 — "The Black Box" (no passback) + +```bash +uv run python src/telemetry_passback/agent.py --no-passback "List my 3 most recent emails" +``` + +Open Jaeger at [http://localhost:16686](http://localhost:16686): you see agent LLM reasoning spans + one opaque `mcp.call_tool` CLIENT span. The tool call took ~3 seconds but there's no way to tell why. Is it the LLM? The network? Auth? The Gmail API? Everything inside the server is invisible. + +#### Act 2 — "The Revelation" (with passback) + +```bash +uv run python src/telemetry_passback/agent.py --detailed "List my 3 most recent emails" +``` + +Same call, but now the span tree reveals the server's internal structure: + +``` +mcp-gmail-agent +├── LangChain agent reasoning +├── ChatOpenAI (LLM decides to call tool) +├── mcp.call_tool list_emails (CLIENT) +│ └── tools/call list_emails (SERVER) ← FROM SPAN PASSBACK +│ ├── auth.validate (50ms) +│ ├── gmail.list_messages (400ms) +│ │ └── GET messages (HTTP) +│ ├── gmail.fetch_details (1.6s) ← bottleneck! +│ │ ├── GET messages/abc (HTTP, 520ms) +│ │ ├── GET messages/def (HTTP, 510ms) +│ │ └── GET messages/ghi (HTTP, 530ms) +│ └── format_response (5ms) +└── ChatOpenAI (LLM — final answer) +``` + +Now the consumer can see exactly what's happening: auth is fast, listing is fine, but **detail fetching is sequential** — three HTTP calls in a waterfall. Armed with this information, the consumer can: + +- **File an informed bug report** to the server vendor: "your `list_emails` has an N+1 in detail fetching — each email triggers a sequential HTTP call" +- **Adjust their usage**: request fewer emails, use a query filter to reduce N +- **Make an informed vendor choice**: compare span trees across MCP server providers + +This is the core value of the SEP — **the consumer doesn't need access to the server's code or deployment to understand its performance characteristics**. + +### Granularity Control + +The `--detailed` flag demonstrates the SEP's span filtering. Without it, the server returns only top-level phase spans (auth, list, fetch, format). With `--detailed`, the full tree including HTTP child spans is returned. This lets the server vendor control how much internal detail is exposed. + +```bash +# Top-level phases only (default) +uv run python src/telemetry_passback/agent.py "List my 3 most recent emails" + +# Full span tree including HTTP child spans +uv run python src/telemetry_passback/agent.py --detailed "List my 3 most recent emails" +``` + +### CLI Options + +| Flag | Default | Description | +|------|---------|-------------| +| `query` | `"List my 5 most recent emails"` | The question to ask the agent | +| `--detailed` | `false` | Request full span tree | +| `--no-passback` | `false` | Disable span passback (Act 1 — server is a black box) | +| `--server-url` | `http://127.0.0.1:8000/mcp` | MCP server URL | + +## Expected Results in Jaeger + +Open [http://localhost:16686](http://localhost:16686) and search for service **`mcp-gmail-agent`**. + +| Mode | What you see | +|------|-------------| +| `--no-passback` | Only agent-side spans: LLM calls + opaque `mcp.call_tool`. Server is a black box. | +| Default | Server phase spans stitched into the same trace: `auth.validate`, `gmail.list_messages`, `gmail.fetch_details`, `format_response`. | +| `--detailed` | Full span tree: phase spans plus HTTP child spans under each phase, revealing the sequential N+1 pattern in `gmail.fetch_details`. | + +## Architecture + +``` +┌─────────────────────────┐ HTTP (streamable) ┌──────────────────────────┐ +│ agent.py │ ───────────────────────>│ server.py │ +│ (LangChain ReAct) │ :8000/mcp │ (Arcade MCP Server) │ +│ │ │ │ +│ OAuth 2.1 via MCP SDK │ traceparent in _meta │ OAuth 2.1 (Arcade) │ +│ OTel → Jaeger/Galileo │ ───────────────────────>│ OTel (internal only) │ +│ │ spans back in _meta │ TelemetryPassback MW │ +│ │ <───────────────────────│ │ +└─────────────────────────┘ └──────────────────────────┘ + │ │ + └──────────── Stitched trace in Jaeger ───────────────┘ +``` + +### How It Works + +**Server side** (`server.py`): +1. Validates Bearer tokens via `ArcadeResourceServerAuth` (OAuth 2.1, RFC 9728 discovery) +2. `TelemetryPassbackMiddleware` intercepts `tools/call` requests +3. Reads `_meta.traceparent` and `_meta.otel.traces.{request, detailed}` +4. Creates a SERVER span under the client's trace (via traceparent propagation) +5. Tool function creates logical-phase spans with `gen_ai.*` semantic conventions +6. httpx auto-instrumentation creates HTTP child spans for Gmail API calls +7. Middleware serializes to OTLP JSON and attaches to `response._meta.otel.traces` + +**Client side** (`agent.py`): +1. MCP SDK handles OAuth 2.1 automatically (discovers auth server on 401, PKCE flow, token caching) +2. Connects to the server via streamable HTTP, detects `serverExecutionTelemetry` capability +3. For each tool call, creates a CLIENT span and injects `traceparent` in `_meta` +4. Sends `_meta.otel.traces.request: true` to opt into span passback +5. Receives server spans in response `_meta.otel.traces.resourceSpans` +6. POSTs OTLP JSON to Jaeger for trace stitching +7. Optionally exports to Galileo (protobuf) if `GALILEO_API_KEY` is set + +## Configuration + +Copy `.env.example` to `.env`: + +| Variable | Default | Description | +|----------|---------|-------------| +| `OPENAI_API_KEY` | (required) | OpenAI API key for the LangChain agent | +| `ARCADE_API_KEY` | (required) | Arcade API key | +| `ARCADE_USER_ID` | (required) | Your Arcade account email | +| `ARCADE_API_URL` | `https://api.arcade.dev` | Arcade API endpoint | +| `GALILEO_API_KEY` | (optional) | Enables export to Galileo alongside Jaeger | +| `GALILEO_PROJECT` | (optional) | Galileo project name | +| `GALILEO_LOG_STREAM` | `default` | Galileo log stream | +| `GALILEO_OTLP_ENDPOINT` | `https://app.galileo.ai/api/galileo/otel/traces` | Galileo OTLP endpoint | diff --git a/examples/mcp_servers/telemetry_passback/docker-compose.yml b/examples/mcp_servers/telemetry_passback/docker-compose.yml new file mode 100644 index 00000000..5dda6ac1 --- /dev/null +++ b/examples/mcp_servers/telemetry_passback/docker-compose.yml @@ -0,0 +1,7 @@ +services: + jaeger: + image: jaegertracing/all-in-one:latest + ports: + - "16686:16686" # Jaeger UI + - "4317:4317" # OTLP gRPC + - "4318:4318" # OTLP HTTP diff --git a/examples/mcp_servers/telemetry_passback/pyproject.toml b/examples/mcp_servers/telemetry_passback/pyproject.toml new file mode 100644 index 00000000..c9bbb652 --- /dev/null +++ b/examples/mcp_servers/telemetry_passback/pyproject.toml @@ -0,0 +1,42 @@ +[project] +name = "telemetry-passback" +version = "0.1.0" +description = "SEP-0000 reference implementation: cross-org distributed tracing via MCP serverExecutionTelemetry" +requires-python = ">=3.11" +dependencies = [ + "arcade-mcp-server>=1.18.0,<2.0.0", + "mcp>=1.26.0", + "langchain>=1.2.13", + "langchain-community>=0.4.1", + "langchain-openai>=1.1.12", + "langgraph>=1.1.3", + "opentelemetry-api", + "opentelemetry-sdk", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-instrumentation-langchain", + "opentelemetry-instrumentation-httpx", + "httpx>=0.28.1", + "python-dotenv>=1.2.2", + "protobuf", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "mypy>=1.0.0", + "ruff>=0.1.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/telemetry_passback"] + +# Use local arcade-mcp-server (TelemetryPassbackMiddleware is not yet published) +[tool.uv.sources] +arcade-mcp-server = { path = "../../../libs/arcade-mcp-server/", editable = true } +arcade-serve = { path = "../../../libs/arcade-serve/", editable = true } diff --git a/examples/mcp_servers/telemetry_passback/src/telemetry_passback/__init__.py b/examples/mcp_servers/telemetry_passback/src/telemetry_passback/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/mcp_servers/telemetry_passback/src/telemetry_passback/agent.py b/examples/mcp_servers/telemetry_passback/src/telemetry_passback/agent.py new file mode 100644 index 00000000..505d9df0 --- /dev/null +++ b/examples/mcp_servers/telemetry_passback/src/telemetry_passback/agent.py @@ -0,0 +1,581 @@ +"""LangChain ReAct agent with SEP-2448: MCP server execution telemetry. + +Consumer-side reference implementation of cross-org distributed tracing: + +* Connects to an Arcade MCP Gmail server via streamable HTTP. +* Authenticates using MCP OAuth 2.1 (automatic via the MCP SDK). +* Detects the ``serverExecutionTelemetry`` capability. +* Dynamically discovers tools and wraps them with span passback. +* Passes ``traceparent`` + requests span passback via ``_meta.otel``. +* Receives server spans and ingests them into Jaeger / Galileo. +* Handles Google OAuth authorization flow for Gmail (one-time consent). + +Two-act demo: + Act 1 (--no-passback): Opaque tool call -- server is a black box + Act 2 (default): Rich span tree via passback reveals server internals +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import os +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path +from typing import Any +from urllib.parse import parse_qs, urlparse + +import httpx +from dotenv import load_dotenv +from langchain.agents import create_agent +from langchain_core.tools import StructuredTool +from langchain_openai import ChatOpenAI +from mcp import ClientSession +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.streamable_http import streamable_http_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter as OTLPHTTPSpanExporter, +) +from opentelemetry.instrumentation.langchain import LangchainInstrumentor +from opentelemetry.instrumentation.langchain.callback_handler import ( + TraceloopCallbackHandler, +) +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.trace import SpanKind, set_span_in_context +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from pydantic import Field, create_model + +_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +load_dotenv(_PROJECT_ROOT / ".env") + +log = logging.getLogger(__name__) + +JAEGER_GRPC = "localhost:4317" +JAEGER_HTTP = "http://localhost:4318/v1/traces" +JAEGER_UI = "http://localhost:16686" + +_passback_stats: dict[str, Any] = {"span_count": 0, "truncated": False, "dropped": 0} + + +# -------------------------------------- +# MCP OAuth 2.1 (handled by the MCP SDK +# -------------------------------------- + +_OAUTH_TOKEN_FILE = _PROJECT_ROOT / ".oauth_tokens.json" +_OAUTH_CLIENT_FILE = _PROJECT_ROOT / ".oauth_client.json" +_CALLBACK_PORT = 9905 + + +class FileTokenStorage(TokenStorage): + """Persist OAuth tokens and client registration to disk between runs.""" + + async def get_tokens(self) -> OAuthToken | None: + if _OAUTH_TOKEN_FILE.exists(): + return OAuthToken.model_validate_json(_OAUTH_TOKEN_FILE.read_text()) + return None + + async def set_tokens(self, tokens: OAuthToken) -> None: + _OAUTH_TOKEN_FILE.write_text(tokens.model_dump_json()) + + async def get_client_info(self) -> OAuthClientInformationFull | None: + if _OAUTH_CLIENT_FILE.exists(): + return OAuthClientInformationFull.model_validate_json(_OAUTH_CLIENT_FILE.read_text()) + return None + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + _OAUTH_CLIENT_FILE.write_text(client_info.model_dump_json()) + + +async def _handle_redirect(authorization_url: str) -> None: + """Open the browser for OAuth consent.""" + print("\n Opening browser for authorization...") + print(f" URL: {authorization_url}\n") + webbrowser.open(authorization_url) + + +async def _handle_callback() -> tuple[str, str | None]: + """Start a local HTTP server, wait for the OAuth redirect, extract the code.""" + loop = asyncio.get_event_loop() + future: asyncio.Future[tuple[str, str | None]] = loop.create_future() + + class _Handler(BaseHTTPRequestHandler): + def do_GET(self) -> None: + qs = parse_qs(urlparse(self.path).query) + code = qs.get("code", [None])[0] + state = qs.get("state", [None])[0] + if code: + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write( + b"

Authorization successful!

You can close this tab.

" + ) + loop.call_soon_threadsafe(future.set_result, (code, state)) + else: + error = qs.get("error", ["unknown"])[0] + self.send_response(400) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write(f"

Authorization failed: {error}

".encode()) + loop.call_soon_threadsafe( + future.set_exception, RuntimeError(f"OAuth error: {error}") + ) + + def log_message(self, fmt: str, *args: Any) -> None: + pass + + server = HTTPServer(("127.0.0.1", _CALLBACK_PORT), _Handler) + + def _serve() -> None: + server.handle_request() + server.server_close() + + await loop.run_in_executor(None, _serve) + return await future + + +# ----------------------- +# Span ingestion helpers +# ----------------------- + + +def _count_spans(resource_spans: list[dict[str, Any]]) -> int: + return sum(len(ss.get("spans", [])) for rs in resource_spans for ss in rs.get("scopeSpans", [])) + + +def _hex_ids_to_base64(resource_spans: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert hex trace/span IDs to base64 for protobuf ``ParseDict``.""" + import base64 + import copy + + _ID_FIELDS = ("traceId", "spanId", "parentSpanId") + converted = copy.deepcopy(resource_spans) + for rs in converted: + for ss in rs.get("scopeSpans", []): + for span in ss.get("spans", []): + for fld in _ID_FIELDS: + if span.get(fld): + span[fld] = base64.b64encode(bytes.fromhex(span[fld])).decode() + return converted + + +def ingest_spans_json( + otlp_json: dict[str, Any], + endpoint: str = JAEGER_HTTP, + headers: dict[str, str] | None = None, + label: str = "collector", +) -> None: + """POST OTLP JSON spans to an OTLP HTTP endpoint (e.g. Jaeger).""" + count = _count_spans(otlp_json.get("resourceSpans", [])) + hdrs = {"Content-Type": "application/json", **(headers or {})} + try: + httpx.post(endpoint, json=otlp_json, headers=hdrs).raise_for_status() + log.info("Ingested %d server span(s) into %s", count, label) + except httpx.ConnectError: + log.exception("Could not connect to %s at %s", label, endpoint) + except httpx.HTTPStatusError as exc: + log.exception( + "%s returned %d: %s", label, exc.response.status_code, exc.response.text[:200] + ) + + +def ingest_spans_protobuf( + resource_spans: list[dict[str, Any]], + endpoint: str, + headers: dict[str, str], + label: str = "Galileo", +) -> None: + """POST OTLP protobuf spans (required by Galileo).""" + try: + from google.protobuf.json_format import ParseDict + from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( + ExportTraceServiceRequest, + ) + except ImportError: + log.warning("Skipping %s protobuf export (missing protobuf deps)", label) + return + if not resource_spans: + return + b64_spans = _hex_ids_to_base64(resource_spans) + body = ParseDict({"resourceSpans": b64_spans}, ExportTraceServiceRequest()).SerializeToString() + try: + resp = httpx.post( + endpoint, content=body, headers={"Content-Type": "application/x-protobuf", **headers} + ) + log.info( + "Exported %d server span(s) to %s (HTTP %d)", + _count_spans(resource_spans), + label, + resp.status_code, + ) + except httpx.ConnectError: + log.exception("Could not connect to %s at %s", label, endpoint) + + +# ---------------- +# Galileo config +# ---------------- + + +def _galileo_config() -> tuple[str, dict[str, str]] | None: + """Return ``(endpoint, headers)`` for Galileo, or ``None`` if not configured.""" + api_key = os.environ.get("GALILEO_API_KEY") + if not api_key: + return None + return ( + os.environ.get("GALILEO_OTLP_ENDPOINT", "https://api.galileo.ai/otel/v1/traces"), + { + "Galileo-API-Key": api_key, + "project": os.environ.get("GALILEO_PROJECT", "mcp-cross-org-observability"), + "logstream": os.environ.get("GALILEO_LOG_STREAM", "default"), + }, + ) + + +# --------------- +# Tracing setup +# --------------- + + +def setup_tracing() -> TracerProvider: + """Jaeger (gRPC) + optional Galileo (OTLP HTTP) + LangChain auto-instrumentation.""" + resource = Resource.create({"service.name": "mcp-gmail-agent"}) + provider = TracerProvider(resource=resource) + + provider.add_span_processor( + BatchSpanProcessor(OTLPSpanExporter(endpoint=JAEGER_GRPC, insecure=True)) + ) + + gc = _galileo_config() + if gc: + provider.add_span_processor( + BatchSpanProcessor(OTLPHTTPSpanExporter(endpoint=gc[0], headers=gc[1])) + ) + log.info("Galileo tracing enabled -> %s", gc[0]) + + trace.set_tracer_provider(provider) + LangchainInstrumentor().instrument() + return provider + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="LangChain agent with SEP-0000 span passback") + p.add_argument("query", nargs="?", default="List my 5 most recent emails") + p.add_argument("--detailed", action="store_true", help="Request full span tree") + p.add_argument( + "--no-passback", + action="store_true", + help="Disable span passback (before SEP-0000 -- server is a black box)", + ) + p.add_argument( + "--server-url", + default="http://127.0.0.1:8000/mcp", + help="MCP server URL (default: http://127.0.0.1:8000/mcp)", + ) + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Instrumentor span look-up +# --------------------------------------------------------------------------- + + +def _find_instrumentor_span(tool_name: str) -> trace.Span | None: + """Locate the LangChain instrumentor's span for the active tool invocation.""" + try: + from langchain_core.callbacks import BaseCallbackManager + + mgr = BaseCallbackManager(handlers=[]) + for handler in mgr.inheritable_handlers: + if isinstance(handler, TraceloopCallbackHandler): + for run_id in reversed(list(handler.spans)): + holder = handler.spans[run_id] + span = getattr(holder, "span", None) + if span and tool_name in getattr(span, "name", ""): + return span + except Exception: + log.debug("Could not look up instrumentor span for %s", tool_name, exc_info=True) + return None + + +# -------------------------- +# Passback span processing +# -------------------------- + + +def _extract_trace_id(resource_spans: list[dict[str, Any]]) -> str | None: + """Return the first traceId found in resource_spans, or None.""" + for rs in resource_spans: + for ss in rs.get("scopeSpans", []): + for sp in ss.get("spans", []): + tid = sp.get("traceId") + if tid: + return tid + return None + + +def _process_passback_spans(meta: dict[str, Any] | None) -> None: + """Extract server spans from response ``_meta`` and ingest into collectors.""" + if not meta: + return + otel_data = meta.get("otel") if isinstance(meta, dict) else getattr(meta, "otel", None) + if not otel_data: + return + traces = otel_data.get("traces", {}) + resource_spans = traces.get("resourceSpans") + if not resource_spans: + return + + span_count = _count_spans(resource_spans) + truncated = traces.get("truncated", False) + dropped = traces.get("droppedSpanCount", 0) + + _passback_stats["span_count"] += span_count + _passback_stats["truncated"] = _passback_stats["truncated"] or truncated + _passback_stats["dropped"] += dropped + if not _passback_stats.get("trace_id"): + _passback_stats["trace_id"] = _extract_trace_id(resource_spans) or "" + + print(f" Server-side spans: {span_count} received and ingested") + if truncated: + print(f" ({dropped} additional spans available with --detailed)") + + ingest_spans_json({"resourceSpans": resource_spans}, endpoint=JAEGER_HTTP, label="Jaeger") + + gc = _galileo_config() + if gc: + ingest_spans_protobuf(resource_spans, endpoint=gc[0], headers=gc[1]) + + +# ---------------------------------------------------- +# OAuth helpers (Gmail tool authorization via Arcade) +# ---------------------------------------------------- + + +def _extract_auth_url(result: Any) -> str | None: + """Check if a tool result contains an authorization URL (OAuth required).""" + for item in result.content: + text = getattr(item, "text", None) + if text and "authorization_url" in text: + try: + data = json.loads(text) + return data.get("authorization_url") + except (json.JSONDecodeError, TypeError): + pass + structured = getattr(result, "structuredContent", None) + if isinstance(structured, dict): + return structured.get("authorization_url") + return None + + +# --------------------------------------------------------------------------- +# Dynamic MCP tool wrappers +# --------------------------------------------------------------------------- + + +def build_mcp_tools( + session: ClientSession, + mcp_tools: list[Any], + tracer: trace.Tracer, + propagator: TraceContextTextMapPropagator, + detailed: bool, + passback: bool = True, +) -> list[StructuredTool]: + """Create LangChain tools from MCP tool definitions, with optional span passback.""" + tools: list[StructuredTool] = [] + + for mcp_tool in mcp_tools: + name = mcp_tool.name + desc = mcp_tool.description or f"MCP tool: {name}" + schema = mcp_tool.inputSchema + props = schema.get("properties", {}) + required = set(schema.get("required", [])) + + def _make_fn(t_name: str) -> Any: + async def _call(**kwargs: str) -> str: + parent_ctx = None + inst_span = _find_instrumentor_span(t_name) + if inst_span: + parent_ctx = set_span_in_context(inst_span) + + with tracer.start_as_current_span( + f"mcp.call_tool {t_name}", + context=parent_ctx, + kind=SpanKind.CLIENT, + attributes={ + "mcp.tool": t_name, + "mcp.server": "mcp-gmail-server", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.name": f"mcp.call_tool {t_name}", + "gen_ai.tool.call.arguments": json.dumps(kwargs), + }, + ) as span: + carrier: dict[str, str] = {} + propagator.inject(carrier) + + meta: dict[str, Any] = { + "traceparent": carrier.get("traceparent", ""), + } + if passback: + meta["otel"] = {"traces": {"request": True, "detailed": detailed}} + + result = await session.call_tool(t_name, arguments=kwargs, meta=meta) + + # Handle Gmail OAuth if needed (one-time consent) + auth_url = _extract_auth_url(result) + if auth_url: + span.set_attribute("mcp.auth.required", True) + print(f"\n Authorization required. Open this URL:\n\n {auth_url}\n") + await asyncio.get_event_loop().run_in_executor( + None, + input, + " Press Enter after authorizing...", + ) + result = await session.call_tool(t_name, arguments=kwargs, meta=meta) + + text = result.content[0].text if result.content else "" + span.set_attribute("gen_ai.tool.call.result", text[:500]) + + if passback: + _process_passback_spans(result.meta) + else: + print(" Server-side spans: NONE (passback not requested)") + + return text + + return _call + + fields = {} + for pname, pinfo in props.items(): + pdesc = pinfo.get("description", "") + if pname in required: + fields[pname] = (str, Field(description=pdesc)) + else: + fields[pname] = (str, Field(default="", description=pdesc)) + + tools.append( + StructuredTool( + name=name, + description=desc, + coroutine=_make_fn(name), + args_schema=create_model(f"{name}Args", **fields), + ) + ) + + return tools + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main() -> None: + logging.basicConfig(level=logging.INFO, format=" %(message)s") + + args = parse_args() + provider = setup_tracing() + tracer = trace.get_tracer("mcp-gmail-agent") + propagator = TraceContextTextMapPropagator() + passback = not args.no_passback + + if args.no_passback: + mode = "Act 1: The Black Box (no passback)" + elif args.detailed: + mode = "Act 2: The Revelation (detailed span tree)" + else: + mode = "Act 2: The Revelation (span passback)" + + print(f"\n{'=' * 60}") + print(f" {mode}") + print(f"{'=' * 60}\n") + + server_url = args.server_url + + # MCP SDK handles OAuth 2.1 automatically: + # On 401 → discovers auth server via RFC 9728 → PKCE flow → caches tokens + oauth_auth = OAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="mcp-gmail-agent", + redirect_uris=["http://127.0.0.1:9905/callback"], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", # noqa: S106 - OAuth 2.1 public client (no secret) + ), + storage=FileTokenStorage(), + redirect_handler=_handle_redirect, + callback_handler=_handle_callback, + ) + http_client = httpx.AsyncClient(auth=oauth_auth) + print(f" Connecting to MCP server at {server_url} ...") + + async with ( + streamable_http_client(url=server_url, http_client=http_client) as (read, write, _), + ClientSession(read, write) as session, + ): + init = await session.initialize() + + telemetry_cap = getattr(init.capabilities, "serverExecutionTelemetry", None) + print(f" Server: {init.serverInfo.name} v{init.serverInfo.version}") + print(f" serverExecutionTelemetry: {telemetry_cap is not None}") + if telemetry_cap: + print(f" Capability: {telemetry_cap}") + + discovered = await session.list_tools() + print(f" Tools: {[t.name for t in discovered.tools]}\n") + + lc_tools = build_mcp_tools( + session, + discovered.tools, + tracer, + propagator, + detailed=args.detailed, + passback=passback, + ) + + agent = create_agent(ChatOpenAI(model="gpt-4o-mini"), lc_tools) + + print(f" Query: {args.query}") + if passback: + print(f" Detailed: {args.detailed}") + print() + + result = await agent.ainvoke({"messages": [("user", args.query)]}) + + print(f"\n Agent: {result['messages'][-1].content}\n") + + provider.force_flush() + provider.shutdown() + + trace_id = _passback_stats.get("trace_id", "") + if trace_id: + print(f" Jaeger UI: {JAEGER_UI}/trace/{trace_id}") + else: + print(f" Jaeger UI: {JAEGER_UI} (search for service mcp-gmail-agent)") + print(" Service: mcp-gmail-agent") + if args.no_passback: + print(" Expected: only agent-side spans (server is a black box)") + elif args.detailed: + print( + " Expected: full span tree -- auth.validate, gmail.list_messages," + " gmail.fetch_details (with HTTP child spans), format_response" + ) + else: + print( + " Expected: server phases visible -- auth.validate," + " gmail.list_messages, gmail.fetch_details, format_response" + ) + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/mcp_servers/telemetry_passback/src/telemetry_passback/server.py b/examples/mcp_servers/telemetry_passback/src/telemetry_passback/server.py new file mode 100644 index 00000000..18dd6a8c --- /dev/null +++ b/examples/mcp_servers/telemetry_passback/src/telemetry_passback/server.py @@ -0,0 +1,260 @@ +"""Gmail MCP server with SEP-2448 MCP server execution telemetry. + +Vendor-side reference implementation of cross-org distributed tracing: + +* Advertises ``serverExecutionTelemetry`` via ``TelemetryPassbackMiddleware``. +* Gmail tools (list_emails, send_email) using Google OAuth via Arcade. +* Rich server-side instrumentation creates logical-phase spans + (auth, API calls, formatting) that the middleware automatically collects. +* ``HTTPXClientInstrumentor`` auto-instruments Gmail API HTTP calls as child spans. +* Does NOT export spans externally — simulates a vendor with its own backend. +""" + +import base64 +import json +import logging +import sys +from email.mime.text import MIMEText +from pathlib import Path +from typing import Annotated, cast + +import httpx +from dotenv import load_dotenv + +load_dotenv(Path(__file__).resolve().parent.parent.parent / ".env") + +from arcade_mcp_server import Context, MCPApp # noqa: E402 +from arcade_mcp_server.auth import Google # noqa: E402 +from arcade_mcp_server.mcp_app import TransportType # noqa: E402 +from arcade_mcp_server.middleware.telemetry import TelemetryPassbackMiddleware # noqa: E402 +from arcade_mcp_server.resource_server import ( # noqa: E402 + AuthorizationServerEntry, + ResourceServerAuth, +) +from arcade_mcp_server.resource_server.base import ResourceOwner # noqa: E402 +from opentelemetry import trace # noqa: E402 +from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor # noqa: E402 +from opentelemetry.sdk.trace import TracerProvider # noqa: E402 + +# ------------------------------------------------------------ +# OpenTelemetry — server-internal only (no external exporter) +# ------------------------------------------------------------ + +provider = TracerProvider() +telemetry_mw = TelemetryPassbackMiddleware( + service_name="mcp-gmail-server", + tracer_provider=provider, +) +trace.set_tracer_provider(provider) + +tracer = trace.get_tracer("mcp-gmail-server") +_log = logging.getLogger("mcp-gmail-server") + + +async def _async_request_hook(span, request): + """Capture request method + URL with gen_ai semantic conventions.""" + method = request.method.decode() if isinstance(request.method, bytes) else str(request.method) + url = str(request.url) + endpoint = url.split("?")[0].split("/")[-1] or "/" + span.update_name(f"{method} {endpoint}") + span.set_attribute("gen_ai.system", "mcp") + span.set_attribute("gen_ai.operation.name", "execute_tool") + span.set_attribute("gen_ai.tool.name", f"{method} {endpoint}") + span.set_attribute("gen_ai.tool.call.arguments", json.dumps({"method": method, "url": url})) + + +async def _async_response_hook(span, request, response): + """Capture response status with gen_ai semantic conventions.""" + method = request.method.decode() if isinstance(request.method, bytes) else str(request.method) + url = str(request.url) + span.set_attribute( + "gen_ai.tool.call.result", + json.dumps({ + "status": response.status_code, + "url": url, + "method": method, + }), + ) + + +HTTPXClientInstrumentor().instrument( + async_request_hook=_async_request_hook, + async_response_hook=_async_response_hook, +) + +# ----------------------- +# Arcade MCP application +# ----------------------- + +GMAIL_API = "https://gmail.googleapis.com/gmail/v1/users/me" +GMAIL_READONLY_SCOPE = "https://www.googleapis.com/auth/gmail.readonly" +GMAIL_SEND_SCOPE = "https://www.googleapis.com/auth/gmail.send" + +# -------------------------------------------- +# Resource server auth (OAuth 2.1 via Arcade) +# -------------------------------------------- + +CANONICAL_URL = "http://127.0.0.1:8000/mcp" + + +class ArcadeResourceServerAuth(ResourceServerAuth): + """ResourceServerAuth that uses the ``email`` claim as user_id. + + Arcade's tool authorization identifies users by email, but the default + ``JWKSTokenValidator`` uses the ``sub`` claim (a UUID). This override + swaps user_id to the email so Arcade can match the authorized user. + """ + + async def validate_token(self, token: str) -> ResourceOwner: + owner = await super().validate_token(token) + email = owner.claims.get("email") + if email: + owner.user_id = email + return owner + + +resource_server_auth = ArcadeResourceServerAuth( + canonical_url=CANONICAL_URL, + authorization_servers=[ + AuthorizationServerEntry( + authorization_server_url="https://cloud.arcade.dev/oauth2", + issuer="https://cloud.arcade.dev/oauth2", + jwks_uri="https://cloud.arcade.dev/.well-known/jwks/oauth2", + algorithm="Ed25519", + expected_audiences=[CANONICAL_URL], + ), + ], +) + +app = MCPApp( + name="mcp_gmail_server", + version="0.1.0", + instructions=( + "Gmail server with cross-org observability " "via SEP-0000 serverExecutionTelemetry." + ), + auth=resource_server_auth, + middleware=[telemetry_mw], +) + +# ------ +# Tools +# ------ + + +@app.tool(requires_auth=Google(scopes=[GMAIL_READONLY_SCOPE])) +async def list_emails( + context: Context, + max_results: Annotated[int, "Maximum number of emails to return"] = 5, + query: Annotated[str, "Gmail search query (e.g. 'is:unread')"] = "", +) -> Annotated[dict, "Dict with 'emails' key containing list of recent emails"]: + """List recent emails from the user's Gmail inbox.""" + token = context.get_auth_token_or_empty() + + with tracer.start_as_current_span("auth.validate") as auth_span: + auth_span.set_attribute("gen_ai.system", "mcp") + auth_span.set_attribute("gen_ai.operation.name", "execute_tool") + auth_span.set_attribute("gen_ai.tool.name", "auth.validate") + auth_span.set_attribute("auth.method", "oauth2_bearer") + + params: dict = {"maxResults": min(max(max_results, 1), 20)} + if query: + params["q"] = query + + async with httpx.AsyncClient() as client: + with tracer.start_as_current_span("gmail.list_messages") as list_span: + list_span.set_attribute("gen_ai.system", "mcp") + list_span.set_attribute("gen_ai.operation.name", "execute_tool") + list_span.set_attribute("gen_ai.tool.name", "gmail.list_messages") + list_span.set_attribute("gmail.max_results", params["maxResults"]) + list_resp = await client.get( + f"{GMAIL_API}/messages", + headers={"Authorization": f"Bearer {token}"}, + params=params, + ) + list_resp.raise_for_status() + messages = list_resp.json().get("messages", []) + list_span.set_attribute("gmail.message_count", len(messages)) + + with tracer.start_as_current_span("gmail.fetch_details") as details_span: + details_span.set_attribute("gen_ai.system", "mcp") + details_span.set_attribute("gen_ai.operation.name", "execute_tool") + details_span.set_attribute("gen_ai.tool.name", "gmail.fetch_details") + details_span.set_attribute("gmail.fetch_count", len(messages)) + + results = [] + for msg_ref in messages: + detail_resp = await client.get( + f"{GMAIL_API}/messages/{msg_ref['id']}", + headers={"Authorization": f"Bearer {token}"}, + params={ + "format": "metadata", + "metadataHeaders": ["Subject", "From"], + }, + ) + detail_resp.raise_for_status() + msg = detail_resp.json() + hdrs = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])} + results.append({ + "id": msg["id"], + "subject": hdrs.get("Subject", "(no subject)"), + "from": hdrs.get("From", "(unknown)"), + "snippet": msg.get("snippet", ""), + }) + + with tracer.start_as_current_span("format_response") as fmt_span: + fmt_span.set_attribute("gen_ai.system", "mcp") + fmt_span.set_attribute("gen_ai.operation.name", "execute_tool") + fmt_span.set_attribute("gen_ai.tool.name", "format_response") + fmt_span.set_attribute("email.count", len(results)) + + return {"emails": results} + + +@app.tool(requires_auth=Google(scopes=[GMAIL_SEND_SCOPE])) +async def send_email( + context: Context, + to: Annotated[str, "Recipient email address"], + subject: Annotated[str, "Email subject line"], + body: Annotated[str, "Email body text"], +) -> Annotated[dict, "Send confirmation with message ID"]: + """Send an email via the user's Gmail account.""" + token = context.get_auth_token_or_empty() + + with tracer.start_as_current_span("auth.validate") as auth_span: + auth_span.set_attribute("gen_ai.system", "mcp") + auth_span.set_attribute("gen_ai.operation.name", "execute_tool") + auth_span.set_attribute("gen_ai.tool.name", "auth.validate") + auth_span.set_attribute("auth.method", "oauth2_bearer") + + mime = MIMEText(body) + mime["to"] = to + mime["subject"] = subject + raw = base64.urlsafe_b64encode(mime.as_bytes()).decode() + + with tracer.start_as_current_span("gmail.send_message") as send_span: + send_span.set_attribute("gen_ai.system", "mcp") + send_span.set_attribute("gen_ai.operation.name", "execute_tool") + send_span.set_attribute("gen_ai.tool.name", "gmail.send_message") + send_span.set_attribute("gmail.recipient", to) + send_span.set_attribute("gmail.subject", subject) + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{GMAIL_API}/messages/send", + headers={"Authorization": f"Bearer {token}"}, + json={"raw": raw}, + ) + resp.raise_for_status() + result = resp.json() + + with tracer.start_as_current_span("format_response") as fmt_span: + fmt_span.set_attribute("gen_ai.system", "mcp") + fmt_span.set_attribute("gen_ai.operation.name", "execute_tool") + fmt_span.set_attribute("gen_ai.tool.name", "format_response") + fmt_span.set_attribute("email.message_id", result.get("id", "")) + + return {"message_id": result.get("id", ""), "status": "sent"} + + +if __name__ == "__main__": + transport = sys.argv[1] if len(sys.argv) > 1 else "http" + app.run(transport=cast(TransportType, transport), host="127.0.0.1", port=8000) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/middleware/__init__.py b/libs/arcade-mcp-server/arcade_mcp_server/middleware/__init__.py index 858b815a..97774f19 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/middleware/__init__.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/middleware/__init__.py @@ -7,6 +7,7 @@ from arcade_mcp_server.middleware.base import ( ) from arcade_mcp_server.middleware.error_handling import ErrorHandlingMiddleware from arcade_mcp_server.middleware.logging import LoggingMiddleware +from arcade_mcp_server.middleware.telemetry import TelemetryPassbackMiddleware __all__ = [ "CallNext", @@ -14,4 +15,5 @@ __all__ = [ "LoggingMiddleware", "Middleware", "MiddlewareContext", + "TelemetryPassbackMiddleware", ] diff --git a/libs/arcade-mcp-server/arcade_mcp_server/middleware/base.py b/libs/arcade-mcp-server/arcade_mcp_server/middleware/base.py index 93765d95..460af2f0 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/middleware/base.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/middleware/base.py @@ -74,6 +74,13 @@ class Middleware: function to invoke the next handler in the chain. """ + def get_capabilities(self) -> dict[str, Any]: + """Return extra server capabilities to advertise. + + Override in subclasses to declare capabilities (e.g. serverExecutionTelemetry). + """ + return {} + async def __call__( self, context: MiddlewareContext[T], diff --git a/libs/arcade-mcp-server/arcade_mcp_server/middleware/telemetry.py b/libs/arcade-mcp-server/arcade_mcp_server/middleware/telemetry.py new file mode 100644 index 00000000..ed2406bc --- /dev/null +++ b/libs/arcade-mcp-server/arcade_mcp_server/middleware/telemetry.py @@ -0,0 +1,361 @@ +"""Telemetry passback middleware for MCP server.""" + +from __future__ import annotations + +import json +import logging +import warnings +from contextvars import ContextVar +from typing import Any + +from opentelemetry import context as otel_context +from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor, TracerProvider +from opentelemetry.trace import SpanKind, StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + +from arcade_mcp_server.middleware.base import CallNext, Middleware, MiddlewareContext +from arcade_mcp_server.types import JSONRPCResponse + +logger = logging.getLogger("arcade.mcp.telemetry") + + +# Per-request span collection via ContextVar +_request_spans: ContextVar[list[ReadableSpan] | None] = ContextVar("_request_spans", default=None) + + +class ContextVarSpanCollector(SpanProcessor): + """Collect ended spans into a per-request ContextVar bucket.""" + + def on_start(self, span: Any, parent_context: Any = None) -> None: + pass + + def on_end(self, span: ReadableSpan) -> None: + bucket = _request_spans.get() + if bucket is not None: + bucket.append(span) + + def shutdown(self) -> None: + pass + + def force_flush(self, timeout_millis: int | None = None) -> bool: + return True + + +def start_collecting() -> list[ReadableSpan]: + """Begin collecting spans for the current async context.""" + bucket: list[ReadableSpan] = [] + _request_spans.set(bucket) + return bucket + + +def stop_collecting() -> list[ReadableSpan]: + """Stop collecting and return the collected spans.""" + bucket = _request_spans.get() or [] + _request_spans.set(None) + return bucket + + +# Span filtering +def filter_top_level_spans(spans: list[ReadableSpan]) -> list[ReadableSpan]: + """Keep only the root span and its direct children (depth <= 1).""" + span_ids = {s.context.span_id for s in spans if s.context} + root = None + for s in spans: + if s.parent and s.parent.span_id not in span_ids: + root = s + break + if root is None: + return spans + root_id = root.context.span_id + return [ + s + for s in spans + if s.context.span_id == root_id or (s.parent and s.parent.span_id == root_id) + ] + + +# OTLP JSON serialization + +_SPAN_KIND_MAP = { + SpanKind.INTERNAL: 1, + SpanKind.SERVER: 2, + SpanKind.CLIENT: 3, + SpanKind.PRODUCER: 4, + SpanKind.CONSUMER: 5, +} +_STATUS_CODE_MAP = {StatusCode.UNSET: 0, StatusCode.OK: 1, StatusCode.ERROR: 2} + + +def _ns(ns: int | None) -> str: + return str(ns) if ns is not None else "0" + + +def _attrs_to_kv(attrs: dict[str, Any] | None) -> list[dict[str, Any]]: + if not attrs: + return [] + out: list[dict[str, Any]] = [] + for key, val in attrs.items(): + if isinstance(val, bool): + out.append({"key": key, "value": {"boolValue": val}}) + elif isinstance(val, int): + out.append({"key": key, "value": {"intValue": str(val)}}) + elif isinstance(val, float): + out.append({"key": key, "value": {"doubleValue": val}}) + elif isinstance(val, str): + out.append({"key": key, "value": {"stringValue": val}}) + elif isinstance(val, (list, tuple)): + arr: list[dict[str, Any]] = [] + for v in val: + if isinstance(v, str): + arr.append({"stringValue": v}) + elif isinstance(v, bool): + arr.append({"boolValue": v}) + elif isinstance(v, int): + arr.append({"intValue": str(v)}) + elif isinstance(v, float): + arr.append({"doubleValue": v}) + out.append({"key": key, "value": {"arrayValue": {"values": arr}}}) + else: + out.append({"key": key, "value": {"stringValue": str(val)}}) + return out + + +def spans_to_otlp_json( + spans: list[ReadableSpan], + service_name: str, +) -> dict[str, Any]: + """Convert ReadableSpan objects to OTLP JSON (ExportTraceServiceRequest).""" + otlp_spans: list[dict[str, Any]] = [] + for span in spans: + ctx = span.context + if ctx is None: + continue + rec: dict[str, Any] = { + "traceId": format(ctx.trace_id, "032x"), + "spanId": format(ctx.span_id, "016x"), + "name": span.name, + "kind": _SPAN_KIND_MAP.get(span.kind, 0), + "startTimeUnixNano": _ns(span.start_time), + "endTimeUnixNano": _ns(span.end_time), + "attributes": _attrs_to_kv(dict(span.attributes) if span.attributes else None), + "status": {"code": _STATUS_CODE_MAP.get(span.status.status_code, 0)}, + } + if span.parent and span.parent.span_id: + rec["parentSpanId"] = format(span.parent.span_id, "016x") + if span.events: + rec["events"] = [ + { + "timeUnixNano": _ns(ev.timestamp), + "name": ev.name, + "attributes": _attrs_to_kv(dict(ev.attributes) if ev.attributes else None), + } + for ev in span.events + ] + otlp_spans.append(rec) + + return { + "resourceSpans": [ + { + "resource": { + "attributes": [{"key": "service.name", "value": {"stringValue": service_name}}], + }, + "scopeSpans": [{"scope": {"name": "mcp-telemetry-passback"}, "spans": otlp_spans}], + } + ] + } + + +class TelemetryPassbackMiddleware(Middleware): + """MCP middleware implementing SEP serverExecutionTelemetry. + + .. warning:: **Provisional API** — This middleware implements a draft SEP + (SEP-2448: server execution telemetry) that has not yet been finalized in + the MCP specification. The capability schema, response payload shape, + and constructor signature may change in a future release once the SEP is + ratified. Pin your ``arcade-mcp-server`` version if you need stability. + + Intercepts tools/call and resources/read to: + + 1. Propagate traceparent from the client for distributed tracing. + 2. Collect server-side spans during request execution (per-request + isolation via ContextVar). + 3. Return spans in the response _meta.otel.traces.resourceSpans. + """ + + def __init__(self, service_name: str, tracer_provider: TracerProvider) -> None: + warnings.warn( + "TelemetryPassbackMiddleware implements a draft version of SEP " + "(SEP-2448: server execution telemetry) that is not yet finalized. " + "The API may change in a future release.", + FutureWarning, + stacklevel=2, + ) + self._service_name = service_name + self._tracer = tracer_provider.get_tracer(service_name) + self._propagator = TraceContextTextMapPropagator() + self._collector = ContextVarSpanCollector() + tracer_provider.add_span_processor(self._collector) + + def get_capabilities(self) -> dict[str, Any]: + """Return the serverExecutionTelemetry capability dict.""" + return { + "serverExecutionTelemetry": { + "version": "2026-03-01", + "signals": {"traces": {"supported": True}}, + } + } + + def _extract_otel_meta(self, context: MiddlewareContext[Any]) -> dict[str, Any]: + """Extract traceparent and otel request flags from _meta.""" + mcp_ctx = context.mcp_context + if mcp_ctx is None: + return {} + session = getattr(mcp_ctx, "_session", None) + if session is None: + return {} + meta = getattr(session, "_request_meta", None) + if meta is None: + return {} + + otel_meta = getattr(meta, "otel", None) + traces = otel_meta.get("traces", {}) if isinstance(otel_meta, dict) else {} + + return { + "traceparent": getattr(meta, "traceparent", None), + "request": traces.get("request", False) if isinstance(traces, dict) else False, + "detailed": traces.get("detailed", False) if isinstance(traces, dict) else False, + } + + def _parent_context(self, traceparent: str | None) -> Any: + if traceparent: + return self._propagator.extract(carrier={"traceparent": traceparent}) + return otel_context.get_current() + + def _attach_spans(self, response: Any, collected: list[ReadableSpan], detailed: bool) -> Any: + if not detailed: + filtered = filter_top_level_spans(collected) + dropped = len(collected) - len(filtered) + else: + filtered = collected + dropped = 0 + + payload = { + "traces": { + "resourceSpans": spans_to_otlp_json(filtered, self._service_name).get( + "resourceSpans", [] + ), + "truncated": dropped > 0, + "droppedSpanCount": dropped, + } + } + + if isinstance(response, JSONRPCResponse) and response.result is not None: + result = response.result + if not isinstance(result, dict) and hasattr(result, "meta"): + meta = result.meta or {} + meta["otel"] = payload + result.meta = meta + + return response + + @staticmethod + def _extract_response_text(response: Any) -> str | None: + """Extract text content from a JSONRPCResponse.""" + if not isinstance(response, JSONRPCResponse) or response.result is None: + return None + result = response.result + for attr in ("content", "contents"): + items = getattr(result, attr, None) + if items and isinstance(items, list): + parts = [p for c in items if (p := getattr(c, "text", None))] + if parts: + return "\n".join(str(p) for p in parts) + return None + + async def _handle_with_telemetry( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + span_name: str, + span_attributes: dict[str, Any] | None = None, + tool_arguments: str | None = None, + ) -> Any: + """Common telemetry logic for tools/call and resources/read.""" + otel = self._extract_otel_meta(context) + if not otel.get("request", False): + return await call_next(context) + + start_collecting() + + try: + with self._tracer.start_as_current_span( + span_name, + context=self._parent_context(otel.get("traceparent")), + kind=SpanKind.SERVER, + attributes=span_attributes or {}, + ) as span: + if tool_arguments: + span.set_attribute( + "gen_ai.input.messages", + json.dumps([{"role": "user", "content": tool_arguments}]), + ) + response = await call_next(context) + output = self._extract_response_text(response) + if output: + span.set_attribute("gen_ai.tool.call.result", output[:500]) + span.set_attribute( + "gen_ai.output.messages", + json.dumps([{"role": "assistant", "content": output}]), + ) + finally: + collected = stop_collecting() + + return self._attach_spans(response, collected, otel.get("detailed", False)) + + async def on_call_tool( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> Any: + """Intercept tools/call to collect and return server spans.""" + msg = context.message + params = msg.get("params", {}) if isinstance(msg, dict) else {} + name = params.get("name", "") + arguments = params.get("arguments", {}) + return await self._handle_with_telemetry( + context, + call_next, + span_name=f"tools/call {name}", + span_attributes={ + "mcp.method": "tools/call", + "mcp.tool": name, + "gen_ai.system": "mcp", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.name": name, + "gen_ai.tool.call.arguments": json.dumps(arguments) if arguments else "", + }, + tool_arguments=json.dumps(arguments) if arguments else None, + ) + + async def on_read_resource( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> Any: + """Intercept resources/read to collect and return server spans.""" + msg = context.message + params = msg.get("params", {}) if isinstance(msg, dict) else {} + uri = params.get("uri", "") + return await self._handle_with_telemetry( + context, + call_next, + span_name=f"resources/read {uri}", + span_attributes={ + "mcp.method": "resources/read", + "mcp.resource.uri": uri, + "gen_ai.system": "mcp", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.name": f"resources/read {uri}", + }, + tool_arguments=json.dumps({"uri": uri}) if uri else None, + ) diff --git a/libs/arcade-mcp-server/arcade_mcp_server/server.py b/libs/arcade-mcp-server/arcade_mcp_server/server.py index 7055467c..68407a0c 100644 --- a/libs/arcade-mcp-server/arcade_mcp_server/server.py +++ b/libs/arcade-mcp-server/arcade_mcp_server/server.py @@ -202,6 +202,7 @@ class MCPServer: # Middleware chain self.middleware: list[Middleware] = [] + self._extra_capabilities: dict[str, Any] = {} self._init_middleware(middleware) # Lifespan management @@ -319,6 +320,10 @@ class MCPServer: if custom_middleware: self.middleware.extend(custom_middleware) + # Collect capabilities from middleware that declare them + for mw in self.middleware: + self._extra_capabilities.update(mw.get_capabilities()) + def _register_handlers(self) -> dict[str, Callable]: """Register method handlers.""" return { @@ -681,14 +686,18 @@ class MCPServer: if session: session.set_client_params(message.params) + caps_kwargs: dict[str, Any] = { + "tools": {"listChanged": True}, + "logging": {}, + "prompts": {"listChanged": True}, + "resources": {"subscribe": True, "listChanged": True}, + } + if self._extra_capabilities: + caps_kwargs.update(self._extra_capabilities) + result = InitializeResult( protocolVersion=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities( - tools={"listChanged": True}, - logging={}, - prompts={"listChanged": True}, - resources={"subscribe": True, "listChanged": True}, - ), + capabilities=ServerCapabilities(**caps_kwargs), serverInfo=Implementation( name=self.name, version=self.version, diff --git a/libs/arcade-mcp-server/pyproject.toml b/libs/arcade-mcp-server/pyproject.toml index 05275f2f..6c53c514 100644 --- a/libs/arcade-mcp-server/pyproject.toml +++ b/libs/arcade-mcp-server/pyproject.toml @@ -36,6 +36,8 @@ dependencies = [ "pydantic-settings>=2.10.1", "joserfc>=1.5.0", "httpx>=0.27.0,<1.0.0", + "opentelemetry-api>=1.20.0", + "opentelemetry-sdk>=1.20.0", ] [project.optional-dependencies] diff --git a/libs/tests/arcade_mcp_server/test_telemetry_middleware.py b/libs/tests/arcade_mcp_server/test_telemetry_middleware.py new file mode 100644 index 00000000..519aed0a --- /dev/null +++ b/libs/tests/arcade_mcp_server/test_telemetry_middleware.py @@ -0,0 +1,529 @@ +"""Tests for Telemetry Passback Middleware.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider +from opentelemetry.trace import SpanKind, StatusCode + +from arcade_mcp_server.middleware.base import MiddlewareContext +from arcade_mcp_server.middleware.telemetry import ( + ContextVarSpanCollector, + TelemetryPassbackMiddleware, + _attrs_to_kv, + _ns, + _request_spans, + filter_top_level_spans, + spans_to_otlp_json, + start_collecting, + stop_collecting, +) +from arcade_mcp_server.types import ( + CallToolResult, + JSONRPCResponse, + ReadResourceResult, + TextContent, + TextResourceContents, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_span( + *, + name: str = "test-span", + trace_id: int = 0xABCD, + span_id: int = 0x1234, + parent_span_id: int | None = None, + kind: SpanKind = SpanKind.INTERNAL, + start_time: int = 1000, + end_time: int = 2000, + status_code: StatusCode = StatusCode.UNSET, + attributes: dict | None = None, + events: list | None = None, +) -> ReadableSpan: + """Build a lightweight mock ReadableSpan.""" + span = MagicMock(spec=ReadableSpan) + ctx = MagicMock() + ctx.trace_id = trace_id + ctx.span_id = span_id + span.context = ctx + span.name = name + span.kind = kind + span.start_time = start_time + span.end_time = end_time + + status = MagicMock() + status.status_code = status_code + span.status = status + + if parent_span_id is not None: + parent = MagicMock() + parent.span_id = parent_span_id + span.parent = parent + else: + span.parent = None + + span.attributes = attributes + span.events = events or [] + return span + + +def _make_tool_response(*, text: str = "hello") -> JSONRPCResponse[CallToolResult]: + """Build a JSONRPCResponse wrapping a CallToolResult.""" + return JSONRPCResponse( + id="1", + result=CallToolResult( + content=[TextContent(type="text", text=text)], + isError=False, + ), + ) + + +def _make_resource_response(*, text: str = "hello") -> JSONRPCResponse[ReadResourceResult]: + """Build a JSONRPCResponse wrapping a ReadResourceResult.""" + return JSONRPCResponse( + id="1", + result=ReadResourceResult( + contents=[TextResourceContents(uri="file:///test", text=text)], + ), + ) + + +# =========================================================================== +# ContextVarSpanCollector +# =========================================================================== + + +class TestContextVarSpanCollector: + """Test ContextVarSpanCollector class.""" + + @pytest.fixture(autouse=True) + def _reset_contextvar(self): + """Ensure the ContextVar is clean before and after each test.""" + _request_spans.set(None) + yield + _request_spans.set(None) + + def test_spans_collected_when_bucket_active(self): + collector = ContextVarSpanCollector() + bucket = start_collecting() + span = _make_span() + collector.on_end(span) + assert span in bucket + + def test_spans_not_collected_when_bucket_inactive(self): + collector = ContextVarSpanCollector() + span = _make_span() + collector.on_end(span) + # No active bucket; span is silently dropped + assert _request_spans.get() is None + + def test_on_start_is_noop(self): + collector = ContextVarSpanCollector() + collector.on_start(MagicMock()) + + def test_shutdown_is_noop(self): + collector = ContextVarSpanCollector() + collector.shutdown() + + def test_force_flush_returns_true(self): + collector = ContextVarSpanCollector() + assert collector.force_flush() is True + + +# =========================================================================== +# start_collecting / stop_collecting +# =========================================================================== + + +class TestCollectingLifecycle: + """Test start_collecting / stop_collecting.""" + + @pytest.fixture(autouse=True) + def _reset_contextvar(self): + _request_spans.set(None) + yield + _request_spans.set(None) + + def test_lifecycle(self): + bucket = start_collecting() + assert bucket == [] + assert _request_spans.get() is bucket + span = _make_span() + bucket.append(span) + result = stop_collecting() + assert span in result + assert _request_spans.get() is None + + def test_idempotent_stop(self): + start_collecting() + stop_collecting() + result = stop_collecting() + assert result == [] + + def test_isolation_between_calls(self): + bucket1 = start_collecting() + bucket1.append(_make_span(name="a")) + stop_collecting() + + bucket2 = start_collecting() + assert bucket2 == [] + + +# =========================================================================== +# filter_top_level_spans +# =========================================================================== + + +class TestFilterTopLevelSpans: + def test_root_and_children_kept_grandchildren_dropped(self): + root = _make_span(name="root", span_id=1, parent_span_id=999) + child = _make_span(name="child", span_id=2, parent_span_id=1) + grandchild = _make_span(name="grandchild", span_id=3, parent_span_id=2) + result = filter_top_level_spans([root, child, grandchild]) + names = {s.name for s in result} + assert names == {"root", "child"} + + def test_empty_list(self): + assert filter_top_level_spans([]) == [] + + def test_single_span_no_parent(self): + span = _make_span(name="only", span_id=1) + result = filter_top_level_spans([span]) + assert result == [span] + + def test_single_span_with_external_parent(self): + span = _make_span(name="only", span_id=1, parent_span_id=999) + result = filter_top_level_spans([span]) + assert len(result) == 1 + assert result[0].name == "only" + + +# =========================================================================== +# _attrs_to_kv +# =========================================================================== + + +class TestAttrsToKv: + def test_bool_value(self): + result = _attrs_to_kv({"flag": True}) + assert result == [{"key": "flag", "value": {"boolValue": True}}] + + def test_int_value(self): + result = _attrs_to_kv({"count": 42}) + assert result == [{"key": "count", "value": {"intValue": "42"}}] + + def test_float_value(self): + result = _attrs_to_kv({"rate": 3.14}) + assert result == [{"key": "rate", "value": {"doubleValue": 3.14}}] + + def test_str_value(self): + result = _attrs_to_kv({"name": "foo"}) + assert result == [{"key": "name", "value": {"stringValue": "foo"}}] + + def test_list_value(self): + result = _attrs_to_kv({"tags": ["a", "b"]}) + assert result == [ + {"key": "tags", "value": {"arrayValue": {"values": [{"stringValue": "a"}, {"stringValue": "b"}]}}} + ] + + def test_fallback_to_string(self): + result = _attrs_to_kv({"obj": object()}) + assert len(result) == 1 + assert "stringValue" in result[0]["value"] + + def test_empty_attrs(self): + assert _attrs_to_kv({}) == [] + + def test_none_attrs(self): + assert _attrs_to_kv(None) == [] + + +# =========================================================================== +# spans_to_otlp_json +# =========================================================================== + + +class TestSpansToOtlpJson: + def test_full_serialization(self): + event = MagicMock() + event.timestamp = 1500 + event.name = "log" + event.attributes = {"msg": "hi"} + + span = _make_span( + name="op", + trace_id=0xABCDEF, + span_id=0x123456, + parent_span_id=0x111, + kind=SpanKind.SERVER, + start_time=1000, + end_time=2000, + status_code=StatusCode.OK, + attributes={"k": "v"}, + events=[event], + ) + + result = spans_to_otlp_json([span], "test-service") + assert "resourceSpans" in result + rs = result["resourceSpans"][0] + assert rs["resource"]["attributes"][0]["value"]["stringValue"] == "test-service" + otlp_span = rs["scopeSpans"][0]["spans"][0] + assert otlp_span["name"] == "op" + assert otlp_span["kind"] == 2 # SERVER + assert otlp_span["status"]["code"] == 1 # OK + assert "parentSpanId" in otlp_span + assert len(otlp_span["events"]) == 1 + + def test_span_without_parent(self): + span = _make_span(name="root", span_id=1) + result = spans_to_otlp_json([span], "svc") + otlp_span = result["resourceSpans"][0]["scopeSpans"][0]["spans"][0] + assert "parentSpanId" not in otlp_span + + def test_span_with_no_context_skipped(self): + span = _make_span() + span.context = None + result = spans_to_otlp_json([span], "svc") + assert result["resourceSpans"][0]["scopeSpans"][0]["spans"] == [] + + def test_status_error(self): + span = _make_span(status_code=StatusCode.ERROR) + result = spans_to_otlp_json([span], "svc") + otlp_span = result["resourceSpans"][0]["scopeSpans"][0]["spans"][0] + assert otlp_span["status"]["code"] == 2 + + +# =========================================================================== +# _ns +# =========================================================================== + + +class TestNs: + def test_zero(self): + assert _ns(0) == "0" + + def test_none(self): + assert _ns(None) == "0" + + def test_positive(self): + assert _ns(123456) == "123456" + + +# =========================================================================== +# TelemetryPassbackMiddleware +# =========================================================================== + + +def _make_otel_context(*, request: bool = True, detailed: bool = False, traceparent: str | None = None): + """Build a MiddlewareContext with otel meta configured.""" + session = MagicMock() + meta = MagicMock() + meta.traceparent = traceparent + meta.otel = {"traces": {"request": request, "detailed": detailed}} + session._request_meta = meta + + mcp_ctx = MagicMock() + mcp_ctx._session = session + return mcp_ctx + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +class TestTelemetryPassbackMiddleware: + """Test TelemetryPassbackMiddleware class.""" + + @pytest.fixture(autouse=True) + def _reset_contextvar(self): + _request_spans.set(None) + yield + _request_spans.set(None) + + @pytest.fixture + def tracer_provider(self): + return TracerProvider() + + @pytest.fixture + def middleware(self, tracer_provider): + return TelemetryPassbackMiddleware( + service_name="test-svc", + tracer_provider=tracer_provider, + ) + + @pytest.fixture + def tool_call_context(self): + """Context for a tools/call with otel request=True.""" + return MiddlewareContext( + message={"method": "tools/call", "params": {"name": "my_tool", "arguments": {"x": 1}}}, + mcp_context=_make_otel_context(request=True, detailed=False), + method="tools/call", + request_id="req-1", + ) + + @pytest.fixture + def tool_call_context_no_request(self): + """Context for a tools/call with otel request=False.""" + return MiddlewareContext( + message={"method": "tools/call", "params": {"name": "my_tool", "arguments": {}}}, + mcp_context=_make_otel_context(request=False), + method="tools/call", + request_id="req-2", + ) + + @pytest.fixture + def resource_read_context(self): + """Context for resources/read with otel request=True.""" + return MiddlewareContext( + message={"method": "resources/read", "params": {"uri": "file:///foo.txt"}}, + mcp_context=_make_otel_context(request=True, detailed=False), + method="resources/read", + request_id="req-3", + ) + + # --- get_capabilities --- + + def test_get_capabilities(self, middleware): + caps = middleware.get_capabilities() + assert "serverExecutionTelemetry" in caps + assert caps["serverExecutionTelemetry"]["signals"]["traces"]["supported"] is True + + # --- on_call_tool: request=False passthrough --- + + @pytest.mark.asyncio + async def test_on_call_tool_no_request_passthrough(self, middleware, tool_call_context_no_request): + sentinel = object() + + async def handler(ctx): + return sentinel + + result = await middleware.on_call_tool(tool_call_context_no_request, handler) + assert result is sentinel + + # --- on_call_tool: request=True attaches spans --- + + @pytest.mark.asyncio + async def test_on_call_tool_with_request_attaches_spans(self, middleware, tool_call_context): + resp = _make_tool_response(text="tool output") + + async def handler(ctx): + return resp + + result = await middleware.on_call_tool(tool_call_context, handler) + meta = result.result.meta + assert meta is not None + assert "otel" in meta + assert "traces" in meta["otel"] + assert "resourceSpans" in meta["otel"]["traces"] + + # --- on_read_resource --- + + @pytest.mark.asyncio + async def test_on_read_resource_attaches_spans(self, middleware, resource_read_context): + resp = _make_resource_response(text="resource content") + + async def handler(ctx): + return resp + + result = await middleware.on_read_resource(resource_read_context, handler) + meta = result.result.meta + assert meta is not None + assert "otel" in meta + assert "traces" in meta["otel"] + + # --- detailed=True vs detailed=False --- + + @pytest.mark.asyncio + async def test_detailed_true_returns_all_spans(self, tracer_provider): + mw = TelemetryPassbackMiddleware(service_name="svc", tracer_provider=tracer_provider) + + context = MiddlewareContext( + message={"method": "tools/call", "params": {"name": "t", "arguments": {}}}, + mcp_context=_make_otel_context(request=True, detailed=True), + method="tools/call", + ) + + resp = _make_tool_response() + + async def handler(ctx): + return resp + + result = await mw.on_call_tool(context, handler) + otel = result.result.meta["otel"]["traces"] + assert otel["truncated"] is False + assert otel["droppedSpanCount"] == 0 + + # --- traceparent propagation --- + + @pytest.mark.asyncio + async def test_traceparent_propagation(self, tracer_provider): + mw = TelemetryPassbackMiddleware(service_name="svc", tracer_provider=tracer_provider) + + context = MiddlewareContext( + message={"method": "tools/call", "params": {"name": "t", "arguments": {}}}, + mcp_context=_make_otel_context( + request=True, + detailed=True, + traceparent="00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01", + ), + method="tools/call", + ) + + resp = _make_tool_response() + + async def handler(ctx): + return resp + + result = await mw.on_call_tool(context, handler) + assert result.result.meta is not None + assert "otel" in result.result.meta + + # --- error in call_next — spans still cleaned up --- + + @pytest.mark.asyncio + async def test_error_in_call_next_cleans_up_spans(self, middleware, tool_call_context): + async def handler(ctx): + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + await middleware.on_call_tool(tool_call_context, handler) + + assert _request_spans.get() is None + + # --- _extract_response_text --- + + def test_extract_response_text_with_content(self, middleware): + resp = _make_tool_response(text="hello") + assert middleware._extract_response_text(resp) == "hello" + + def test_extract_response_text_with_contents(self, middleware): + resp = _make_resource_response(text="world") + assert middleware._extract_response_text(resp) == "world" + + def test_extract_response_text_no_text(self, middleware): + # Empty text is falsy, filtered out by _extract_response_text + resp = _make_tool_response(text="") + assert middleware._extract_response_text(resp) is None + + def test_extract_response_text_non_jsonrpc(self, middleware): + assert middleware._extract_response_text({"not": "a response"}) is None + + def test_extract_response_text_none_result(self, middleware): + resp = JSONRPCResponse(id="1", result={}) + assert middleware._extract_response_text(resp) is None + + +# =========================================================================== +# Middleware base class default get_capabilities +# =========================================================================== + + +class TestMiddlewareBaseGetCapabilities: + def test_default_returns_empty_dict(self): + from arcade_mcp_server.middleware.base import Middleware + + mw = Middleware() + assert mw.get_capabilities() == {}