From 2389d89471fa6fa3df1b5378200a2b0a1994269e Mon Sep 17 00:00:00 2001 From: jottakka Date: Tue, 25 Nov 2025 23:00:26 -0300 Subject: [PATCH] [TOO-145]GQL Error Adaptor (#692) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # 🎫 [TOO-145](https://linear.app/arcadedev/issue/TOO-145/adding-gql-error-interceptor-in-tdk) - add GraphQL-specific error adapter, expose it, and slot it ahead of HTTP fallback - map gql query/transport errors (incl. rate-limit handoff) without logging sensitive data - cover adapter + tool chain behavior with new unit tests (pathload, dedupe, edge cases) closes TOO-145 --------- Co-authored-by: Francisco Liberal --- .../arcade_tdk/error_adapters/__init__.py | 2 + .../arcade_tdk/providers/graphql/__init__.py | 3 + .../providers/graphql/error_adapter.py | 186 ++++++++++++ libs/arcade-tdk/arcade_tdk/tool.py | 4 + libs/arcade-tdk/pyproject.toml | 2 +- libs/tests/sdk/test_graphql_adapter.py | 283 ++++++++++++++++++ libs/tests/sdk/test_graphql_tooling.py | 63 ++++ 7 files changed, 542 insertions(+), 1 deletion(-) create mode 100644 libs/arcade-tdk/arcade_tdk/providers/graphql/__init__.py create mode 100644 libs/arcade-tdk/arcade_tdk/providers/graphql/error_adapter.py create mode 100644 libs/tests/sdk/test_graphql_adapter.py create mode 100644 libs/tests/sdk/test_graphql_tooling.py diff --git a/libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py b/libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py index 2da2a646..06714fb2 100644 --- a/libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py +++ b/libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py @@ -1,5 +1,6 @@ from arcade_tdk.error_adapters.base import ErrorAdapter from arcade_tdk.providers.google import GoogleErrorAdapter +from arcade_tdk.providers.graphql import GraphQLErrorAdapter from arcade_tdk.providers.http import HTTPErrorAdapter from arcade_tdk.providers.microsoft import MicrosoftGraphErrorAdapter from arcade_tdk.providers.slack import SlackErrorAdapter @@ -7,6 +8,7 @@ from arcade_tdk.providers.slack import SlackErrorAdapter __all__ = [ "ErrorAdapter", "GoogleErrorAdapter", + "GraphQLErrorAdapter", "HTTPErrorAdapter", "MicrosoftGraphErrorAdapter", "SlackErrorAdapter", diff --git a/libs/arcade-tdk/arcade_tdk/providers/graphql/__init__.py b/libs/arcade-tdk/arcade_tdk/providers/graphql/__init__.py new file mode 100644 index 00000000..aef5e2a4 --- /dev/null +++ b/libs/arcade-tdk/arcade_tdk/providers/graphql/__init__.py @@ -0,0 +1,3 @@ +from .error_adapter import GraphQLErrorAdapter + +__all__ = ["GraphQLErrorAdapter"] diff --git a/libs/arcade-tdk/arcade_tdk/providers/graphql/error_adapter.py b/libs/arcade-tdk/arcade_tdk/providers/graphql/error_adapter.py new file mode 100644 index 00000000..808cc3ae --- /dev/null +++ b/libs/arcade-tdk/arcade_tdk/providers/graphql/error_adapter.py @@ -0,0 +1,186 @@ +import importlib +import logging +from functools import lru_cache +from http import HTTPStatus +from typing import Any + +from arcade_core.errors import ToolRuntimeError, UpstreamError + +from arcade_tdk.providers.http.error_adapter import BaseHTTPErrorMapper + +logger = logging.getLogger(__name__) + +# Standard Apollo/GraphQL error codes mapped to HTTP status codes +_GQL_CODE_TO_STATUS = { + "UNAUTHENTICATED": 401, + "NOT_AUTHENTICATED": 401, + "FORBIDDEN": 403, + "ACCESS_DENIED": 403, + "NOT_FOUND": 404, + "BAD_USER_INPUT": 400, + "GRAPHQL_VALIDATION_FAILED": 400, + "GRAPHQL_PARSE_FAILED": 400, + "INTERNAL_SERVER_ERROR": 500, +} + + +@lru_cache(maxsize=1) +def _load_gql_transport_errors() -> ( + tuple[type[Any], type[Any], type[Any], type[Any], type[Any]] | None +): + """Import gql transport exceptions lazily and cache the result.""" + try: + module = importlib.import_module("gql.transport.exceptions") + except ImportError: + logger.debug("gql not installed; GraphQL adapter disabled") + return None + else: + return ( + module.TransportError, + module.TransportQueryError, + module.TransportServerError, + module.TransportConnectionFailed, + module.TransportProtocolError, + ) + + +def _extract_error_message(message: Any) -> str: + """Return the error message or a fallback.""" + if not message: + return "Unknown GraphQL error" + try: + return str(message) or "Unknown GraphQL error" + except Exception: + return "Unknown GraphQL error" + + +class GraphQLErrorAdapter(BaseHTTPErrorMapper): + """Error adapter for GraphQL clients (specifically 'gql' library).""" + + slug = "_graphql" + + def from_exception(self, exc: Exception) -> ToolRuntimeError | None: + """Translate a gql exception into a ToolRuntimeError.""" + gql_types = _load_gql_transport_errors() + if not gql_types: + return None + + ( + TransportError, + TransportQueryError, + TransportServerError, + TransportConnectionFailed, + TransportProtocolError, + ) = gql_types + + # GraphQL errors in response (HTTP 200 with errors array) + if isinstance(exc, TransportQueryError): + return self._handle_query_error(exc) + + # HTTP-level errors (4xx, 5xx) - these can have rate limit headers + if isinstance(exc, TransportServerError): + return self._handle_transport_error(exc) + + # Network/protocol errors - simple 502 + if isinstance(exc, (TransportConnectionFailed, TransportProtocolError)): + return UpstreamError( + message=f"Upstream GraphQL error: {type(exc).__name__}", + status_code=HTTPStatus.BAD_GATEWAY.value, + developer_message=str(exc), + extra={"service": self.slug, "error_type": type(exc).__name__}, + ) + + # Catch-all for unknown TransportError subclasses + if isinstance(exc, TransportError): + return self._handle_transport_error(exc) + + return None + + def _handle_query_error(self, exc: Any) -> UpstreamError: + """Handle TransportQueryError (GraphQL errors in response body).""" + errors_list = exc.errors or [] + logger.debug("GraphQL query errors: %s", errors_list) + + messages = [_extract_error_message(e.get("message")) for e in errors_list] + joined = "; ".join(messages) if messages else "Unknown GraphQL error" + + # Extract error codes and map to HTTP status + codes: list[str] = [] + status = HTTPStatus.UNPROCESSABLE_ENTITY.value + + for e in errors_list: + ext = e.get("extensions") if isinstance(e, dict) else None + code = ext.get("code") if isinstance(ext, dict) else None + if isinstance(code, str): + codes.append(code) + mapped = _GQL_CODE_TO_STATUS.get(code) + if mapped and mapped > status: + status = mapped + + unique_codes = sorted(set(codes)) + + return UpstreamError( + message=f"Upstream GraphQL error: {joined}", + status_code=status, + developer_message=f"GraphQL error codes: {', '.join(unique_codes)}" + if unique_codes + else "GraphQL error", + extra={ + "service": self.slug, + "error_type": "TransportQueryError", + "gql_error_codes": unique_codes, + }, + ) + + def _handle_transport_error(self, exc: Any) -> UpstreamError: + """Handle TransportServerError and other transport errors.""" + status = getattr(exc, "code", None) + if not isinstance(status, int): + status = HTTPStatus.INTERNAL_SERVER_ERROR.value + + # Extract headers for rate limit detection (check exc and __cause__) + headers = self._get_headers(exc) or self._get_headers(exc.__cause__) + + # Extract URL from __cause__ (aiohttp/httpx/requests store it there) + url, method = self._get_request_info(exc.__cause__) + + return self._map_status_to_error( + status=status, + headers=headers or {}, + msg=f"Upstream GraphQL error: {_extract_error_message(str(exc))}", + request_url=url, + request_method=method, + ) + + def _get_headers(self, obj: Any) -> dict[str, str] | None: + """Extract headers from an object if available.""" + if obj and hasattr(obj, "response") and hasattr(obj.response, "headers"): + return {k.lower(): v for k, v in obj.response.headers.items()} + return None + + def _get_request_info(self, cause: Any) -> tuple[str | None, str | None]: + """Extract URL and method from the __cause__ exception.""" + if not cause: + return None, None + + # aiohttp: request_info.url + if hasattr(cause, "request_info"): + ri = cause.request_info + url = getattr(ri, "url", None) or getattr(ri, "real_url", None) + return (str(url), getattr(ri, "method", None)) if url else (None, None) + + # httpx/requests: response.request.url + if hasattr(cause, "response") and hasattr(cause.response, "request"): + req = cause.response.request + url = getattr(req, "url", None) + return (str(url), getattr(req, "method", None)) if url else (None, None) + + return None, None + + def _build_extra_metadata( + self, request_url: str | None = None, request_method: str | None = None + ) -> dict[str, str]: + """Override to use GraphQL service slug.""" + extra = super()._build_extra_metadata(request_url, request_method) + extra["service"] = self.slug + return extra diff --git a/libs/arcade-tdk/arcade_tdk/tool.py b/libs/arcade-tdk/arcade_tdk/tool.py index 943d6a81..efa4b89d 100644 --- a/libs/arcade-tdk/arcade_tdk/tool.py +++ b/libs/arcade-tdk/arcade_tdk/tool.py @@ -10,6 +10,7 @@ from arcade_tdk.errors import ( FatalToolError, ToolRuntimeError, ) +from arcade_tdk.providers.graphql import GraphQLErrorAdapter from arcade_tdk.providers.http import HTTPErrorAdapter from arcade_tdk.utils import snake_to_pascal_case @@ -52,6 +53,9 @@ def _build_adapter_chain( if auth_adapter := get_adapter_for_auth_provider(auth_provider): adapter_chain.append(auth_adapter) + # Always add GraphQL adapter (it will no-op if gql is not installed) + adapter_chain.append(GraphQLErrorAdapter()) + # Always add HTTP adapter as the final adapter fallback adapter_chain.append(HTTPErrorAdapter()) diff --git a/libs/arcade-tdk/pyproject.toml b/libs/arcade-tdk/pyproject.toml index 9503654f..db639afc 100644 --- a/libs/arcade-tdk/pyproject.toml +++ b/libs/arcade-tdk/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-tdk" -version = "3.0.1" +version = "3.1.0" description = "Arcade TDK - Toolkit Development Kit for building Arcade tools" readme = "README.md" license = {text = "MIT"} diff --git a/libs/tests/sdk/test_graphql_adapter.py b/libs/tests/sdk/test_graphql_adapter.py new file mode 100644 index 00000000..de6e44dd --- /dev/null +++ b/libs/tests/sdk/test_graphql_adapter.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import sys +from collections.abc import Iterator +from http import HTTPStatus +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest +from arcade_core.errors import UpstreamError, UpstreamRateLimitError + +LIBS_DIR = Path(__file__).resolve().parents[2] +TDK_SRC = LIBS_DIR / "arcade-tdk" +if str(TDK_SRC) not in sys.path: + sys.path.insert(0, str(TDK_SRC)) + +import arcade_tdk.providers.graphql.error_adapter as gql_adapter # noqa: E402 + +# --- Dummy exception classes for testing --- + + +class DummyTransportError(Exception): + def __init__(self, message: str, code: int | None = None) -> None: + super().__init__(message) + self.code = code + + +class DummyTransportQueryError(Exception): + def __init__(self, errors: list[dict[str, Any]] | None = None) -> None: + super().__init__("query error") + self.errors = errors + + +class DummyResponse: + def __init__(self, headers: dict[str, str] | None = None) -> None: + self.headers = headers or {} + + +class DummyTransportServerError(Exception): + def __init__( + self, message: str, code: int | None = None, headers: dict[str, str] | None = None + ): + super().__init__(message) + self.code = code + if headers is not None: + self.response = DummyResponse(headers) + + +class DummyTransportConnectionFailed(DummyTransportError): + pass + + +class DummyTransportProtocolError(DummyTransportError): + pass + + +@pytest.fixture(autouse=True) +def reset_cache() -> Iterator[None]: + """Clear cached gql import state between tests.""" + gql_adapter._load_gql_transport_errors.cache_clear() + yield + gql_adapter._load_gql_transport_errors.cache_clear() + + +def _patch_loader() -> Any: + """Patch the loader to return our dummy classes.""" + return patch.object( + gql_adapter, + "_load_gql_transport_errors", + return_value=( + DummyTransportError, + DummyTransportQueryError, + DummyTransportServerError, + DummyTransportConnectionFailed, + DummyTransportProtocolError, + ), + ) + + +class TestGraphQLErrorAdapter: + # --- Import/caching tests --- + + def test_skips_when_gql_not_installed(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Should return None and cache the import failure.""" + call_count = {"n": 0} + + def fake_import(name: str) -> None: + call_count["n"] += 1 + raise ImportError("no gql") + + monkeypatch.setattr(gql_adapter.importlib, "import_module", fake_import) + adapter = gql_adapter.GraphQLErrorAdapter() + + assert adapter.from_exception(Exception("x")) is None + assert adapter.from_exception(Exception("y")) is None + assert call_count["n"] == 1 # Only tried once + + def test_ignores_non_gql_exceptions(self) -> None: + """Non-gql exceptions should return None.""" + with _patch_loader(): + adapter = gql_adapter.GraphQLErrorAdapter() + assert adapter.from_exception(RuntimeError("not gql")) is None + + # --- TransportQueryError tests --- + + def test_query_error_extracts_messages_and_codes(self) -> None: + """Should extract messages and map error codes to status.""" + errors = [ + {"message": "Not authorized", "extensions": {"code": "FORBIDDEN"}}, + {"message": "Server error", "extensions": {"code": "INTERNAL_SERVER_ERROR"}}, + ] + exc = DummyTransportQueryError(errors=errors) + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == HTTPStatus.INTERNAL_SERVER_ERROR # Highest mapped status + assert "Not authorized" in result.message + assert "Server error" in result.message + assert result.extra["gql_error_codes"] == ["FORBIDDEN", "INTERNAL_SERVER_ERROR"] + + def test_query_error_defaults_when_empty(self) -> None: + """Should handle empty/missing errors gracefully.""" + exc = DummyTransportQueryError(errors=None) + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + assert "Unknown GraphQL error" in result.message + + def test_query_error_deduplicates_codes(self) -> None: + """Duplicate error codes should be deduplicated.""" + errors = [ + {"message": "A", "extensions": {"code": "FORBIDDEN"}}, + {"message": "B", "extensions": {"code": "FORBIDDEN"}}, + ] + exc = DummyTransportQueryError(errors=errors) + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert result.extra["gql_error_codes"] == ["FORBIDDEN"] + + # --- TransportServerError tests --- + + def test_server_error_detects_rate_limit(self) -> None: + """Should detect rate limits from status + headers.""" + exc = DummyTransportServerError( + message="Too many requests", + code=429, + headers={"retry-after": "5"}, + ) + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamRateLimitError) + assert result.retry_after_ms == 5000 + + def test_server_error_defaults_to_500(self) -> None: + """Should default to 500 when no status code.""" + exc = DummyTransportServerError("Server error", code=None) + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + def test_server_error_extracts_headers_from_cause(self) -> None: + """Should extract headers from __cause__ if not on exception.""" + exc = DummyTransportServerError("Error", code=429) + # No headers on exc, but on __cause__ + cause = Exception("inner") + cause.response = DummyResponse({"retry-after": "10"}) # type: ignore + exc.__cause__ = cause + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamRateLimitError) + assert result.retry_after_ms == 10000 + + def test_server_error_extracts_url_from_cause_aiohttp(self) -> None: + """Should extract URL from __cause__ (aiohttp pattern).""" + exc = DummyTransportServerError("Error", code=500) + + # aiohttp style: request_info.url + class FakeRequestInfo: + url = "https://api.github.com/graphql" + method = "POST" + + cause = Exception("inner") + cause.request_info = FakeRequestInfo() # type: ignore + exc.__cause__ = cause + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamError) + assert result.extra is not None + assert result.extra.get("endpoint") == "https://api.github.com/graphql" + assert result.extra.get("http_method") == "POST" + + def test_server_error_extracts_url_from_cause_httpx(self) -> None: + """Should extract URL from __cause__ (httpx/requests pattern).""" + exc = DummyTransportServerError("Error", code=500) + + # httpx style: response.request.url + class FakeRequest: + url = "https://api.stripe.com/graphql" + method = "POST" + + class FakeResponse: + request = FakeRequest() + + cause = Exception("inner") + cause.response = FakeResponse() # type: ignore + exc.__cause__ = cause + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamError) + assert result.extra is not None + assert result.extra.get("endpoint") == "https://api.stripe.com/graphql" + assert result.extra.get("http_method") == "POST" + + # --- Connection/Protocol error tests --- + + def test_connection_failed_returns_502(self) -> None: + """Connection failures should map to 502.""" + exc = DummyTransportConnectionFailed("Connection refused") + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == HTTPStatus.BAD_GATEWAY + assert result.extra["error_type"] == "DummyTransportConnectionFailed" + + def test_protocol_error_returns_502(self) -> None: + """Protocol errors should map to 502.""" + exc = DummyTransportProtocolError("Invalid response") + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == HTTPStatus.BAD_GATEWAY + assert result.extra["error_type"] == "DummyTransportProtocolError" + + # --- Generic TransportError catch-all --- + + def test_generic_transport_error_handled(self) -> None: + """Unknown TransportError subclasses should be caught.""" + exc = DummyTransportError("Unknown error", code=503) + + with _patch_loader(): + result = gql_adapter.GraphQLErrorAdapter().from_exception(exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == 503 + + # --- Edge cases --- + + def test_extract_message_handles_bad_str(self) -> None: + """Should handle objects that fail str().""" + + class BadStr: + def __str__(self) -> str: + raise ValueError("nope") + + assert gql_adapter._extract_error_message(BadStr()) == "Unknown GraphQL error" + + def test_extract_message_handles_empty(self) -> None: + """Should handle empty/None messages.""" + assert gql_adapter._extract_error_message(None) == "Unknown GraphQL error" + assert gql_adapter._extract_error_message("") == "Unknown GraphQL error" diff --git a/libs/tests/sdk/test_graphql_tooling.py b/libs/tests/sdk/test_graphql_tooling.py new file mode 100644 index 00000000..64dc404c --- /dev/null +++ b/libs/tests/sdk/test_graphql_tooling.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import importlib +import sys +from pathlib import Path + +import pytest +from arcade_core.errors import ToolRuntimeError + +LIBS_DIR = Path(__file__).resolve().parents[2] +TDK_SRC = LIBS_DIR / "arcade-tdk" +if str(TDK_SRC) not in sys.path: + sys.path.insert(0, str(TDK_SRC)) + +import arcade_tdk.error_adapters as error_adapters # noqa: E402 +import arcade_tdk.providers.graphql as graphql_provider # noqa: E402 +from arcade_tdk.providers.graphql import GraphQLErrorAdapter # noqa: E402 +from arcade_tdk.providers.http import HTTPErrorAdapter # noqa: E402 + +tool_module = importlib.import_module("arcade_tdk.tool") + + +class DummyAdapter: + slug = "_dummy" + + def from_exception(self, exc: Exception) -> ToolRuntimeError | None: + return None # pragma: no cover - trivial + + +def test_graphql_adapter_is_exported() -> None: + assert "GraphQLErrorAdapter" in graphql_provider.__all__ + assert graphql_provider.GraphQLErrorAdapter is GraphQLErrorAdapter + assert "GraphQLErrorAdapter" in error_adapters.__all__ + + +def test_adapter_chain_appends_graphql_before_http(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(tool_module, "get_adapter_for_auth_provider", lambda auth: None) + chain = tool_module._build_adapter_chain([], auth_provider=None) + + assert isinstance(chain[-2], GraphQLErrorAdapter) + assert isinstance(chain[-1], HTTPErrorAdapter) + + +def test_adapter_chain_deduplicates_graphql_and_http(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(tool_module, "get_adapter_for_auth_provider", lambda auth: None) + chain = tool_module._build_adapter_chain( + [GraphQLErrorAdapter(), HTTPErrorAdapter()], auth_provider=None + ) + + types = [type(adapter) for adapter in chain] + assert types.count(GraphQLErrorAdapter) == 1 + assert types.count(HTTPErrorAdapter) == 1 + + +def test_adapter_chain_includes_auth_adapter_before_graphql( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dummy_auth = DummyAdapter() + monkeypatch.setattr(tool_module, "get_adapter_for_auth_provider", lambda auth: dummy_auth) + chain = tool_module._build_adapter_chain([], auth_provider=object()) + + assert chain[0] is dummy_auth + assert isinstance(chain[1], GraphQLErrorAdapter)