[TOO-145]GQL Error Adaptor (#692)

# 🎫
[TOO-145](https://linear.app/arcadedev/issue/TOO-145/adding-gql-error-interceptor-in-tdk)
- add GraphQL-specific error adapter, expose it, and slot it ahead of
HTTP fallback
- map gql query/transport errors (incl. rate-limit handoff) without
logging sensitive data
- cover adapter + tool chain behavior with new unit tests (pathload,
dedupe, edge cases)



closes TOO-145

---------

Co-authored-by: Francisco Liberal <francisco@arcade.dev>
This commit is contained in:
jottakka 2025-11-25 23:00:26 -03:00 committed by GitHub
parent 6d8d71b9ac
commit 2389d89471
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 542 additions and 1 deletions

View file

@ -1,5 +1,6 @@
from arcade_tdk.error_adapters.base import ErrorAdapter
from arcade_tdk.providers.google import GoogleErrorAdapter
from arcade_tdk.providers.graphql import GraphQLErrorAdapter
from arcade_tdk.providers.http import HTTPErrorAdapter
from arcade_tdk.providers.microsoft import MicrosoftGraphErrorAdapter
from arcade_tdk.providers.slack import SlackErrorAdapter
@ -7,6 +8,7 @@ from arcade_tdk.providers.slack import SlackErrorAdapter
__all__ = [
"ErrorAdapter",
"GoogleErrorAdapter",
"GraphQLErrorAdapter",
"HTTPErrorAdapter",
"MicrosoftGraphErrorAdapter",
"SlackErrorAdapter",

View file

@ -0,0 +1,3 @@
from .error_adapter import GraphQLErrorAdapter
__all__ = ["GraphQLErrorAdapter"]

View file

@ -0,0 +1,186 @@
import importlib
import logging
from functools import lru_cache
from http import HTTPStatus
from typing import Any
from arcade_core.errors import ToolRuntimeError, UpstreamError
from arcade_tdk.providers.http.error_adapter import BaseHTTPErrorMapper
logger = logging.getLogger(__name__)
# Standard Apollo/GraphQL error codes mapped to HTTP status codes
_GQL_CODE_TO_STATUS = {
"UNAUTHENTICATED": 401,
"NOT_AUTHENTICATED": 401,
"FORBIDDEN": 403,
"ACCESS_DENIED": 403,
"NOT_FOUND": 404,
"BAD_USER_INPUT": 400,
"GRAPHQL_VALIDATION_FAILED": 400,
"GRAPHQL_PARSE_FAILED": 400,
"INTERNAL_SERVER_ERROR": 500,
}
@lru_cache(maxsize=1)
def _load_gql_transport_errors() -> (
tuple[type[Any], type[Any], type[Any], type[Any], type[Any]] | None
):
"""Import gql transport exceptions lazily and cache the result."""
try:
module = importlib.import_module("gql.transport.exceptions")
except ImportError:
logger.debug("gql not installed; GraphQL adapter disabled")
return None
else:
return (
module.TransportError,
module.TransportQueryError,
module.TransportServerError,
module.TransportConnectionFailed,
module.TransportProtocolError,
)
def _extract_error_message(message: Any) -> str:
"""Return the error message or a fallback."""
if not message:
return "Unknown GraphQL error"
try:
return str(message) or "Unknown GraphQL error"
except Exception:
return "Unknown GraphQL error"
class GraphQLErrorAdapter(BaseHTTPErrorMapper):
"""Error adapter for GraphQL clients (specifically 'gql' library)."""
slug = "_graphql"
def from_exception(self, exc: Exception) -> ToolRuntimeError | None:
"""Translate a gql exception into a ToolRuntimeError."""
gql_types = _load_gql_transport_errors()
if not gql_types:
return None
(
TransportError,
TransportQueryError,
TransportServerError,
TransportConnectionFailed,
TransportProtocolError,
) = gql_types
# GraphQL errors in response (HTTP 200 with errors array)
if isinstance(exc, TransportQueryError):
return self._handle_query_error(exc)
# HTTP-level errors (4xx, 5xx) - these can have rate limit headers
if isinstance(exc, TransportServerError):
return self._handle_transport_error(exc)
# Network/protocol errors - simple 502
if isinstance(exc, (TransportConnectionFailed, TransportProtocolError)):
return UpstreamError(
message=f"Upstream GraphQL error: {type(exc).__name__}",
status_code=HTTPStatus.BAD_GATEWAY.value,
developer_message=str(exc),
extra={"service": self.slug, "error_type": type(exc).__name__},
)
# Catch-all for unknown TransportError subclasses
if isinstance(exc, TransportError):
return self._handle_transport_error(exc)
return None
def _handle_query_error(self, exc: Any) -> UpstreamError:
"""Handle TransportQueryError (GraphQL errors in response body)."""
errors_list = exc.errors or []
logger.debug("GraphQL query errors: %s", errors_list)
messages = [_extract_error_message(e.get("message")) for e in errors_list]
joined = "; ".join(messages) if messages else "Unknown GraphQL error"
# Extract error codes and map to HTTP status
codes: list[str] = []
status = HTTPStatus.UNPROCESSABLE_ENTITY.value
for e in errors_list:
ext = e.get("extensions") if isinstance(e, dict) else None
code = ext.get("code") if isinstance(ext, dict) else None
if isinstance(code, str):
codes.append(code)
mapped = _GQL_CODE_TO_STATUS.get(code)
if mapped and mapped > status:
status = mapped
unique_codes = sorted(set(codes))
return UpstreamError(
message=f"Upstream GraphQL error: {joined}",
status_code=status,
developer_message=f"GraphQL error codes: {', '.join(unique_codes)}"
if unique_codes
else "GraphQL error",
extra={
"service": self.slug,
"error_type": "TransportQueryError",
"gql_error_codes": unique_codes,
},
)
def _handle_transport_error(self, exc: Any) -> UpstreamError:
"""Handle TransportServerError and other transport errors."""
status = getattr(exc, "code", None)
if not isinstance(status, int):
status = HTTPStatus.INTERNAL_SERVER_ERROR.value
# Extract headers for rate limit detection (check exc and __cause__)
headers = self._get_headers(exc) or self._get_headers(exc.__cause__)
# Extract URL from __cause__ (aiohttp/httpx/requests store it there)
url, method = self._get_request_info(exc.__cause__)
return self._map_status_to_error(
status=status,
headers=headers or {},
msg=f"Upstream GraphQL error: {_extract_error_message(str(exc))}",
request_url=url,
request_method=method,
)
def _get_headers(self, obj: Any) -> dict[str, str] | None:
"""Extract headers from an object if available."""
if obj and hasattr(obj, "response") and hasattr(obj.response, "headers"):
return {k.lower(): v for k, v in obj.response.headers.items()}
return None
def _get_request_info(self, cause: Any) -> tuple[str | None, str | None]:
"""Extract URL and method from the __cause__ exception."""
if not cause:
return None, None
# aiohttp: request_info.url
if hasattr(cause, "request_info"):
ri = cause.request_info
url = getattr(ri, "url", None) or getattr(ri, "real_url", None)
return (str(url), getattr(ri, "method", None)) if url else (None, None)
# httpx/requests: response.request.url
if hasattr(cause, "response") and hasattr(cause.response, "request"):
req = cause.response.request
url = getattr(req, "url", None)
return (str(url), getattr(req, "method", None)) if url else (None, None)
return None, None
def _build_extra_metadata(
self, request_url: str | None = None, request_method: str | None = None
) -> dict[str, str]:
"""Override to use GraphQL service slug."""
extra = super()._build_extra_metadata(request_url, request_method)
extra["service"] = self.slug
return extra

View file

@ -10,6 +10,7 @@ from arcade_tdk.errors import (
FatalToolError,
ToolRuntimeError,
)
from arcade_tdk.providers.graphql import GraphQLErrorAdapter
from arcade_tdk.providers.http import HTTPErrorAdapter
from arcade_tdk.utils import snake_to_pascal_case
@ -52,6 +53,9 @@ def _build_adapter_chain(
if auth_adapter := get_adapter_for_auth_provider(auth_provider):
adapter_chain.append(auth_adapter)
# Always add GraphQL adapter (it will no-op if gql is not installed)
adapter_chain.append(GraphQLErrorAdapter())
# Always add HTTP adapter as the final adapter fallback
adapter_chain.append(HTTPErrorAdapter())

View file

@ -1,6 +1,6 @@
[project]
name = "arcade-tdk"
version = "3.0.1"
version = "3.1.0"
description = "Arcade TDK - Toolkit Development Kit for building Arcade tools"
readme = "README.md"
license = {text = "MIT"}

View file

@ -0,0 +1,283 @@
from __future__ import annotations
import sys
from collections.abc import Iterator
from http import HTTPStatus
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytest
from arcade_core.errors import UpstreamError, UpstreamRateLimitError
LIBS_DIR = Path(__file__).resolve().parents[2]
TDK_SRC = LIBS_DIR / "arcade-tdk"
if str(TDK_SRC) not in sys.path:
sys.path.insert(0, str(TDK_SRC))
import arcade_tdk.providers.graphql.error_adapter as gql_adapter # noqa: E402
# --- Dummy exception classes for testing ---
class DummyTransportError(Exception):
def __init__(self, message: str, code: int | None = None) -> None:
super().__init__(message)
self.code = code
class DummyTransportQueryError(Exception):
def __init__(self, errors: list[dict[str, Any]] | None = None) -> None:
super().__init__("query error")
self.errors = errors
class DummyResponse:
def __init__(self, headers: dict[str, str] | None = None) -> None:
self.headers = headers or {}
class DummyTransportServerError(Exception):
def __init__(
self, message: str, code: int | None = None, headers: dict[str, str] | None = None
):
super().__init__(message)
self.code = code
if headers is not None:
self.response = DummyResponse(headers)
class DummyTransportConnectionFailed(DummyTransportError):
pass
class DummyTransportProtocolError(DummyTransportError):
pass
@pytest.fixture(autouse=True)
def reset_cache() -> Iterator[None]:
"""Clear cached gql import state between tests."""
gql_adapter._load_gql_transport_errors.cache_clear()
yield
gql_adapter._load_gql_transport_errors.cache_clear()
def _patch_loader() -> Any:
"""Patch the loader to return our dummy classes."""
return patch.object(
gql_adapter,
"_load_gql_transport_errors",
return_value=(
DummyTransportError,
DummyTransportQueryError,
DummyTransportServerError,
DummyTransportConnectionFailed,
DummyTransportProtocolError,
),
)
class TestGraphQLErrorAdapter:
# --- Import/caching tests ---
def test_skips_when_gql_not_installed(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Should return None and cache the import failure."""
call_count = {"n": 0}
def fake_import(name: str) -> None:
call_count["n"] += 1
raise ImportError("no gql")
monkeypatch.setattr(gql_adapter.importlib, "import_module", fake_import)
adapter = gql_adapter.GraphQLErrorAdapter()
assert adapter.from_exception(Exception("x")) is None
assert adapter.from_exception(Exception("y")) is None
assert call_count["n"] == 1 # Only tried once
def test_ignores_non_gql_exceptions(self) -> None:
"""Non-gql exceptions should return None."""
with _patch_loader():
adapter = gql_adapter.GraphQLErrorAdapter()
assert adapter.from_exception(RuntimeError("not gql")) is None
# --- TransportQueryError tests ---
def test_query_error_extracts_messages_and_codes(self) -> None:
"""Should extract messages and map error codes to status."""
errors = [
{"message": "Not authorized", "extensions": {"code": "FORBIDDEN"}},
{"message": "Server error", "extensions": {"code": "INTERNAL_SERVER_ERROR"}},
]
exc = DummyTransportQueryError(errors=errors)
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamError)
assert result.status_code == HTTPStatus.INTERNAL_SERVER_ERROR # Highest mapped status
assert "Not authorized" in result.message
assert "Server error" in result.message
assert result.extra["gql_error_codes"] == ["FORBIDDEN", "INTERNAL_SERVER_ERROR"]
def test_query_error_defaults_when_empty(self) -> None:
"""Should handle empty/missing errors gracefully."""
exc = DummyTransportQueryError(errors=None)
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamError)
assert result.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert "Unknown GraphQL error" in result.message
def test_query_error_deduplicates_codes(self) -> None:
"""Duplicate error codes should be deduplicated."""
errors = [
{"message": "A", "extensions": {"code": "FORBIDDEN"}},
{"message": "B", "extensions": {"code": "FORBIDDEN"}},
]
exc = DummyTransportQueryError(errors=errors)
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert result.extra["gql_error_codes"] == ["FORBIDDEN"]
# --- TransportServerError tests ---
def test_server_error_detects_rate_limit(self) -> None:
"""Should detect rate limits from status + headers."""
exc = DummyTransportServerError(
message="Too many requests",
code=429,
headers={"retry-after": "5"},
)
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamRateLimitError)
assert result.retry_after_ms == 5000
def test_server_error_defaults_to_500(self) -> None:
"""Should default to 500 when no status code."""
exc = DummyTransportServerError("Server error", code=None)
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamError)
assert result.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
def test_server_error_extracts_headers_from_cause(self) -> None:
"""Should extract headers from __cause__ if not on exception."""
exc = DummyTransportServerError("Error", code=429)
# No headers on exc, but on __cause__
cause = Exception("inner")
cause.response = DummyResponse({"retry-after": "10"}) # type: ignore
exc.__cause__ = cause
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamRateLimitError)
assert result.retry_after_ms == 10000
def test_server_error_extracts_url_from_cause_aiohttp(self) -> None:
"""Should extract URL from __cause__ (aiohttp pattern)."""
exc = DummyTransportServerError("Error", code=500)
# aiohttp style: request_info.url
class FakeRequestInfo:
url = "https://api.github.com/graphql"
method = "POST"
cause = Exception("inner")
cause.request_info = FakeRequestInfo() # type: ignore
exc.__cause__ = cause
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamError)
assert result.extra is not None
assert result.extra.get("endpoint") == "https://api.github.com/graphql"
assert result.extra.get("http_method") == "POST"
def test_server_error_extracts_url_from_cause_httpx(self) -> None:
"""Should extract URL from __cause__ (httpx/requests pattern)."""
exc = DummyTransportServerError("Error", code=500)
# httpx style: response.request.url
class FakeRequest:
url = "https://api.stripe.com/graphql"
method = "POST"
class FakeResponse:
request = FakeRequest()
cause = Exception("inner")
cause.response = FakeResponse() # type: ignore
exc.__cause__ = cause
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamError)
assert result.extra is not None
assert result.extra.get("endpoint") == "https://api.stripe.com/graphql"
assert result.extra.get("http_method") == "POST"
# --- Connection/Protocol error tests ---
def test_connection_failed_returns_502(self) -> None:
"""Connection failures should map to 502."""
exc = DummyTransportConnectionFailed("Connection refused")
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamError)
assert result.status_code == HTTPStatus.BAD_GATEWAY
assert result.extra["error_type"] == "DummyTransportConnectionFailed"
def test_protocol_error_returns_502(self) -> None:
"""Protocol errors should map to 502."""
exc = DummyTransportProtocolError("Invalid response")
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamError)
assert result.status_code == HTTPStatus.BAD_GATEWAY
assert result.extra["error_type"] == "DummyTransportProtocolError"
# --- Generic TransportError catch-all ---
def test_generic_transport_error_handled(self) -> None:
"""Unknown TransportError subclasses should be caught."""
exc = DummyTransportError("Unknown error", code=503)
with _patch_loader():
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
assert isinstance(result, UpstreamError)
assert result.status_code == 503
# --- Edge cases ---
def test_extract_message_handles_bad_str(self) -> None:
"""Should handle objects that fail str()."""
class BadStr:
def __str__(self) -> str:
raise ValueError("nope")
assert gql_adapter._extract_error_message(BadStr()) == "Unknown GraphQL error"
def test_extract_message_handles_empty(self) -> None:
"""Should handle empty/None messages."""
assert gql_adapter._extract_error_message(None) == "Unknown GraphQL error"
assert gql_adapter._extract_error_message("") == "Unknown GraphQL error"

View file

@ -0,0 +1,63 @@
from __future__ import annotations
import importlib
import sys
from pathlib import Path
import pytest
from arcade_core.errors import ToolRuntimeError
LIBS_DIR = Path(__file__).resolve().parents[2]
TDK_SRC = LIBS_DIR / "arcade-tdk"
if str(TDK_SRC) not in sys.path:
sys.path.insert(0, str(TDK_SRC))
import arcade_tdk.error_adapters as error_adapters # noqa: E402
import arcade_tdk.providers.graphql as graphql_provider # noqa: E402
from arcade_tdk.providers.graphql import GraphQLErrorAdapter # noqa: E402
from arcade_tdk.providers.http import HTTPErrorAdapter # noqa: E402
tool_module = importlib.import_module("arcade_tdk.tool")
class DummyAdapter:
slug = "_dummy"
def from_exception(self, exc: Exception) -> ToolRuntimeError | None:
return None # pragma: no cover - trivial
def test_graphql_adapter_is_exported() -> None:
assert "GraphQLErrorAdapter" in graphql_provider.__all__
assert graphql_provider.GraphQLErrorAdapter is GraphQLErrorAdapter
assert "GraphQLErrorAdapter" in error_adapters.__all__
def test_adapter_chain_appends_graphql_before_http(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(tool_module, "get_adapter_for_auth_provider", lambda auth: None)
chain = tool_module._build_adapter_chain([], auth_provider=None)
assert isinstance(chain[-2], GraphQLErrorAdapter)
assert isinstance(chain[-1], HTTPErrorAdapter)
def test_adapter_chain_deduplicates_graphql_and_http(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(tool_module, "get_adapter_for_auth_provider", lambda auth: None)
chain = tool_module._build_adapter_chain(
[GraphQLErrorAdapter(), HTTPErrorAdapter()], auth_provider=None
)
types = [type(adapter) for adapter in chain]
assert types.count(GraphQLErrorAdapter) == 1
assert types.count(HTTPErrorAdapter) == 1
def test_adapter_chain_includes_auth_adapter_before_graphql(
monkeypatch: pytest.MonkeyPatch,
) -> None:
dummy_auth = DummyAdapter()
monkeypatch.setattr(tool_module, "get_adapter_for_auth_provider", lambda auth: dummy_auth)
chain = tool_module._build_adapter_chain([], auth_provider=object())
assert chain[0] is dummy_auth
assert isinstance(chain[1], GraphQLErrorAdapter)