[TOO-145]GQL Error Adaptor (#692)
# 🎫 [TOO-145](https://linear.app/arcadedev/issue/TOO-145/adding-gql-error-interceptor-in-tdk) - add GraphQL-specific error adapter, expose it, and slot it ahead of HTTP fallback - map gql query/transport errors (incl. rate-limit handoff) without logging sensitive data - cover adapter + tool chain behavior with new unit tests (pathload, dedupe, edge cases) closes TOO-145 --------- Co-authored-by: Francisco Liberal <francisco@arcade.dev>
This commit is contained in:
parent
6d8d71b9ac
commit
2389d89471
7 changed files with 542 additions and 1 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from arcade_tdk.error_adapters.base import ErrorAdapter
|
||||
from arcade_tdk.providers.google import GoogleErrorAdapter
|
||||
from arcade_tdk.providers.graphql import GraphQLErrorAdapter
|
||||
from arcade_tdk.providers.http import HTTPErrorAdapter
|
||||
from arcade_tdk.providers.microsoft import MicrosoftGraphErrorAdapter
|
||||
from arcade_tdk.providers.slack import SlackErrorAdapter
|
||||
|
|
@ -7,6 +8,7 @@ from arcade_tdk.providers.slack import SlackErrorAdapter
|
|||
__all__ = [
|
||||
"ErrorAdapter",
|
||||
"GoogleErrorAdapter",
|
||||
"GraphQLErrorAdapter",
|
||||
"HTTPErrorAdapter",
|
||||
"MicrosoftGraphErrorAdapter",
|
||||
"SlackErrorAdapter",
|
||||
|
|
|
|||
3
libs/arcade-tdk/arcade_tdk/providers/graphql/__init__.py
Normal file
3
libs/arcade-tdk/arcade_tdk/providers/graphql/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .error_adapter import GraphQLErrorAdapter
|
||||
|
||||
__all__ = ["GraphQLErrorAdapter"]
|
||||
186
libs/arcade-tdk/arcade_tdk/providers/graphql/error_adapter.py
Normal file
186
libs/arcade-tdk/arcade_tdk/providers/graphql/error_adapter.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
import importlib
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
from arcade_core.errors import ToolRuntimeError, UpstreamError
|
||||
|
||||
from arcade_tdk.providers.http.error_adapter import BaseHTTPErrorMapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Standard Apollo/GraphQL error codes mapped to HTTP status codes
|
||||
_GQL_CODE_TO_STATUS = {
|
||||
"UNAUTHENTICATED": 401,
|
||||
"NOT_AUTHENTICATED": 401,
|
||||
"FORBIDDEN": 403,
|
||||
"ACCESS_DENIED": 403,
|
||||
"NOT_FOUND": 404,
|
||||
"BAD_USER_INPUT": 400,
|
||||
"GRAPHQL_VALIDATION_FAILED": 400,
|
||||
"GRAPHQL_PARSE_FAILED": 400,
|
||||
"INTERNAL_SERVER_ERROR": 500,
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _load_gql_transport_errors() -> (
|
||||
tuple[type[Any], type[Any], type[Any], type[Any], type[Any]] | None
|
||||
):
|
||||
"""Import gql transport exceptions lazily and cache the result."""
|
||||
try:
|
||||
module = importlib.import_module("gql.transport.exceptions")
|
||||
except ImportError:
|
||||
logger.debug("gql not installed; GraphQL adapter disabled")
|
||||
return None
|
||||
else:
|
||||
return (
|
||||
module.TransportError,
|
||||
module.TransportQueryError,
|
||||
module.TransportServerError,
|
||||
module.TransportConnectionFailed,
|
||||
module.TransportProtocolError,
|
||||
)
|
||||
|
||||
|
||||
def _extract_error_message(message: Any) -> str:
|
||||
"""Return the error message or a fallback."""
|
||||
if not message:
|
||||
return "Unknown GraphQL error"
|
||||
try:
|
||||
return str(message) or "Unknown GraphQL error"
|
||||
except Exception:
|
||||
return "Unknown GraphQL error"
|
||||
|
||||
|
||||
class GraphQLErrorAdapter(BaseHTTPErrorMapper):
|
||||
"""Error adapter for GraphQL clients (specifically 'gql' library)."""
|
||||
|
||||
slug = "_graphql"
|
||||
|
||||
def from_exception(self, exc: Exception) -> ToolRuntimeError | None:
|
||||
"""Translate a gql exception into a ToolRuntimeError."""
|
||||
gql_types = _load_gql_transport_errors()
|
||||
if not gql_types:
|
||||
return None
|
||||
|
||||
(
|
||||
TransportError,
|
||||
TransportQueryError,
|
||||
TransportServerError,
|
||||
TransportConnectionFailed,
|
||||
TransportProtocolError,
|
||||
) = gql_types
|
||||
|
||||
# GraphQL errors in response (HTTP 200 with errors array)
|
||||
if isinstance(exc, TransportQueryError):
|
||||
return self._handle_query_error(exc)
|
||||
|
||||
# HTTP-level errors (4xx, 5xx) - these can have rate limit headers
|
||||
if isinstance(exc, TransportServerError):
|
||||
return self._handle_transport_error(exc)
|
||||
|
||||
# Network/protocol errors - simple 502
|
||||
if isinstance(exc, (TransportConnectionFailed, TransportProtocolError)):
|
||||
return UpstreamError(
|
||||
message=f"Upstream GraphQL error: {type(exc).__name__}",
|
||||
status_code=HTTPStatus.BAD_GATEWAY.value,
|
||||
developer_message=str(exc),
|
||||
extra={"service": self.slug, "error_type": type(exc).__name__},
|
||||
)
|
||||
|
||||
# Catch-all for unknown TransportError subclasses
|
||||
if isinstance(exc, TransportError):
|
||||
return self._handle_transport_error(exc)
|
||||
|
||||
return None
|
||||
|
||||
def _handle_query_error(self, exc: Any) -> UpstreamError:
|
||||
"""Handle TransportQueryError (GraphQL errors in response body)."""
|
||||
errors_list = exc.errors or []
|
||||
logger.debug("GraphQL query errors: %s", errors_list)
|
||||
|
||||
messages = [_extract_error_message(e.get("message")) for e in errors_list]
|
||||
joined = "; ".join(messages) if messages else "Unknown GraphQL error"
|
||||
|
||||
# Extract error codes and map to HTTP status
|
||||
codes: list[str] = []
|
||||
status = HTTPStatus.UNPROCESSABLE_ENTITY.value
|
||||
|
||||
for e in errors_list:
|
||||
ext = e.get("extensions") if isinstance(e, dict) else None
|
||||
code = ext.get("code") if isinstance(ext, dict) else None
|
||||
if isinstance(code, str):
|
||||
codes.append(code)
|
||||
mapped = _GQL_CODE_TO_STATUS.get(code)
|
||||
if mapped and mapped > status:
|
||||
status = mapped
|
||||
|
||||
unique_codes = sorted(set(codes))
|
||||
|
||||
return UpstreamError(
|
||||
message=f"Upstream GraphQL error: {joined}",
|
||||
status_code=status,
|
||||
developer_message=f"GraphQL error codes: {', '.join(unique_codes)}"
|
||||
if unique_codes
|
||||
else "GraphQL error",
|
||||
extra={
|
||||
"service": self.slug,
|
||||
"error_type": "TransportQueryError",
|
||||
"gql_error_codes": unique_codes,
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_transport_error(self, exc: Any) -> UpstreamError:
|
||||
"""Handle TransportServerError and other transport errors."""
|
||||
status = getattr(exc, "code", None)
|
||||
if not isinstance(status, int):
|
||||
status = HTTPStatus.INTERNAL_SERVER_ERROR.value
|
||||
|
||||
# Extract headers for rate limit detection (check exc and __cause__)
|
||||
headers = self._get_headers(exc) or self._get_headers(exc.__cause__)
|
||||
|
||||
# Extract URL from __cause__ (aiohttp/httpx/requests store it there)
|
||||
url, method = self._get_request_info(exc.__cause__)
|
||||
|
||||
return self._map_status_to_error(
|
||||
status=status,
|
||||
headers=headers or {},
|
||||
msg=f"Upstream GraphQL error: {_extract_error_message(str(exc))}",
|
||||
request_url=url,
|
||||
request_method=method,
|
||||
)
|
||||
|
||||
def _get_headers(self, obj: Any) -> dict[str, str] | None:
|
||||
"""Extract headers from an object if available."""
|
||||
if obj and hasattr(obj, "response") and hasattr(obj.response, "headers"):
|
||||
return {k.lower(): v for k, v in obj.response.headers.items()}
|
||||
return None
|
||||
|
||||
def _get_request_info(self, cause: Any) -> tuple[str | None, str | None]:
|
||||
"""Extract URL and method from the __cause__ exception."""
|
||||
if not cause:
|
||||
return None, None
|
||||
|
||||
# aiohttp: request_info.url
|
||||
if hasattr(cause, "request_info"):
|
||||
ri = cause.request_info
|
||||
url = getattr(ri, "url", None) or getattr(ri, "real_url", None)
|
||||
return (str(url), getattr(ri, "method", None)) if url else (None, None)
|
||||
|
||||
# httpx/requests: response.request.url
|
||||
if hasattr(cause, "response") and hasattr(cause.response, "request"):
|
||||
req = cause.response.request
|
||||
url = getattr(req, "url", None)
|
||||
return (str(url), getattr(req, "method", None)) if url else (None, None)
|
||||
|
||||
return None, None
|
||||
|
||||
def _build_extra_metadata(
|
||||
self, request_url: str | None = None, request_method: str | None = None
|
||||
) -> dict[str, str]:
|
||||
"""Override to use GraphQL service slug."""
|
||||
extra = super()._build_extra_metadata(request_url, request_method)
|
||||
extra["service"] = self.slug
|
||||
return extra
|
||||
|
|
@ -10,6 +10,7 @@ from arcade_tdk.errors import (
|
|||
FatalToolError,
|
||||
ToolRuntimeError,
|
||||
)
|
||||
from arcade_tdk.providers.graphql import GraphQLErrorAdapter
|
||||
from arcade_tdk.providers.http import HTTPErrorAdapter
|
||||
from arcade_tdk.utils import snake_to_pascal_case
|
||||
|
||||
|
|
@ -52,6 +53,9 @@ def _build_adapter_chain(
|
|||
if auth_adapter := get_adapter_for_auth_provider(auth_provider):
|
||||
adapter_chain.append(auth_adapter)
|
||||
|
||||
# Always add GraphQL adapter (it will no-op if gql is not installed)
|
||||
adapter_chain.append(GraphQLErrorAdapter())
|
||||
|
||||
# Always add HTTP adapter as the final adapter fallback
|
||||
adapter_chain.append(HTTPErrorAdapter())
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "arcade-tdk"
|
||||
version = "3.0.1"
|
||||
version = "3.1.0"
|
||||
description = "Arcade TDK - Toolkit Development Kit for building Arcade tools"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
|
|
|||
283
libs/tests/sdk/test_graphql_adapter.py
Normal file
283
libs/tests/sdk/test_graphql_adapter.py
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from arcade_core.errors import UpstreamError, UpstreamRateLimitError
|
||||
|
||||
LIBS_DIR = Path(__file__).resolve().parents[2]
|
||||
TDK_SRC = LIBS_DIR / "arcade-tdk"
|
||||
if str(TDK_SRC) not in sys.path:
|
||||
sys.path.insert(0, str(TDK_SRC))
|
||||
|
||||
import arcade_tdk.providers.graphql.error_adapter as gql_adapter # noqa: E402
|
||||
|
||||
# --- Dummy exception classes for testing ---
|
||||
|
||||
|
||||
class DummyTransportError(Exception):
|
||||
def __init__(self, message: str, code: int | None = None) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
|
||||
|
||||
class DummyTransportQueryError(Exception):
|
||||
def __init__(self, errors: list[dict[str, Any]] | None = None) -> None:
|
||||
super().__init__("query error")
|
||||
self.errors = errors
|
||||
|
||||
|
||||
class DummyResponse:
|
||||
def __init__(self, headers: dict[str, str] | None = None) -> None:
|
||||
self.headers = headers or {}
|
||||
|
||||
|
||||
class DummyTransportServerError(Exception):
|
||||
def __init__(
|
||||
self, message: str, code: int | None = None, headers: dict[str, str] | None = None
|
||||
):
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
if headers is not None:
|
||||
self.response = DummyResponse(headers)
|
||||
|
||||
|
||||
class DummyTransportConnectionFailed(DummyTransportError):
|
||||
pass
|
||||
|
||||
|
||||
class DummyTransportProtocolError(DummyTransportError):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_cache() -> Iterator[None]:
|
||||
"""Clear cached gql import state between tests."""
|
||||
gql_adapter._load_gql_transport_errors.cache_clear()
|
||||
yield
|
||||
gql_adapter._load_gql_transport_errors.cache_clear()
|
||||
|
||||
|
||||
def _patch_loader() -> Any:
|
||||
"""Patch the loader to return our dummy classes."""
|
||||
return patch.object(
|
||||
gql_adapter,
|
||||
"_load_gql_transport_errors",
|
||||
return_value=(
|
||||
DummyTransportError,
|
||||
DummyTransportQueryError,
|
||||
DummyTransportServerError,
|
||||
DummyTransportConnectionFailed,
|
||||
DummyTransportProtocolError,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestGraphQLErrorAdapter:
|
||||
# --- Import/caching tests ---
|
||||
|
||||
def test_skips_when_gql_not_installed(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Should return None and cache the import failure."""
|
||||
call_count = {"n": 0}
|
||||
|
||||
def fake_import(name: str) -> None:
|
||||
call_count["n"] += 1
|
||||
raise ImportError("no gql")
|
||||
|
||||
monkeypatch.setattr(gql_adapter.importlib, "import_module", fake_import)
|
||||
adapter = gql_adapter.GraphQLErrorAdapter()
|
||||
|
||||
assert adapter.from_exception(Exception("x")) is None
|
||||
assert adapter.from_exception(Exception("y")) is None
|
||||
assert call_count["n"] == 1 # Only tried once
|
||||
|
||||
def test_ignores_non_gql_exceptions(self) -> None:
|
||||
"""Non-gql exceptions should return None."""
|
||||
with _patch_loader():
|
||||
adapter = gql_adapter.GraphQLErrorAdapter()
|
||||
assert adapter.from_exception(RuntimeError("not gql")) is None
|
||||
|
||||
# --- TransportQueryError tests ---
|
||||
|
||||
def test_query_error_extracts_messages_and_codes(self) -> None:
|
||||
"""Should extract messages and map error codes to status."""
|
||||
errors = [
|
||||
{"message": "Not authorized", "extensions": {"code": "FORBIDDEN"}},
|
||||
{"message": "Server error", "extensions": {"code": "INTERNAL_SERVER_ERROR"}},
|
||||
]
|
||||
exc = DummyTransportQueryError(errors=errors)
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == HTTPStatus.INTERNAL_SERVER_ERROR # Highest mapped status
|
||||
assert "Not authorized" in result.message
|
||||
assert "Server error" in result.message
|
||||
assert result.extra["gql_error_codes"] == ["FORBIDDEN", "INTERNAL_SERVER_ERROR"]
|
||||
|
||||
def test_query_error_defaults_when_empty(self) -> None:
|
||||
"""Should handle empty/missing errors gracefully."""
|
||||
exc = DummyTransportQueryError(errors=None)
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
|
||||
assert "Unknown GraphQL error" in result.message
|
||||
|
||||
def test_query_error_deduplicates_codes(self) -> None:
|
||||
"""Duplicate error codes should be deduplicated."""
|
||||
errors = [
|
||||
{"message": "A", "extensions": {"code": "FORBIDDEN"}},
|
||||
{"message": "B", "extensions": {"code": "FORBIDDEN"}},
|
||||
]
|
||||
exc = DummyTransportQueryError(errors=errors)
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert result.extra["gql_error_codes"] == ["FORBIDDEN"]
|
||||
|
||||
# --- TransportServerError tests ---
|
||||
|
||||
def test_server_error_detects_rate_limit(self) -> None:
|
||||
"""Should detect rate limits from status + headers."""
|
||||
exc = DummyTransportServerError(
|
||||
message="Too many requests",
|
||||
code=429,
|
||||
headers={"retry-after": "5"},
|
||||
)
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamRateLimitError)
|
||||
assert result.retry_after_ms == 5000
|
||||
|
||||
def test_server_error_defaults_to_500(self) -> None:
|
||||
"""Should default to 500 when no status code."""
|
||||
exc = DummyTransportServerError("Server error", code=None)
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
|
||||
def test_server_error_extracts_headers_from_cause(self) -> None:
|
||||
"""Should extract headers from __cause__ if not on exception."""
|
||||
exc = DummyTransportServerError("Error", code=429)
|
||||
# No headers on exc, but on __cause__
|
||||
cause = Exception("inner")
|
||||
cause.response = DummyResponse({"retry-after": "10"}) # type: ignore
|
||||
exc.__cause__ = cause
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamRateLimitError)
|
||||
assert result.retry_after_ms == 10000
|
||||
|
||||
def test_server_error_extracts_url_from_cause_aiohttp(self) -> None:
|
||||
"""Should extract URL from __cause__ (aiohttp pattern)."""
|
||||
exc = DummyTransportServerError("Error", code=500)
|
||||
|
||||
# aiohttp style: request_info.url
|
||||
class FakeRequestInfo:
|
||||
url = "https://api.github.com/graphql"
|
||||
method = "POST"
|
||||
|
||||
cause = Exception("inner")
|
||||
cause.request_info = FakeRequestInfo() # type: ignore
|
||||
exc.__cause__ = cause
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.extra is not None
|
||||
assert result.extra.get("endpoint") == "https://api.github.com/graphql"
|
||||
assert result.extra.get("http_method") == "POST"
|
||||
|
||||
def test_server_error_extracts_url_from_cause_httpx(self) -> None:
|
||||
"""Should extract URL from __cause__ (httpx/requests pattern)."""
|
||||
exc = DummyTransportServerError("Error", code=500)
|
||||
|
||||
# httpx style: response.request.url
|
||||
class FakeRequest:
|
||||
url = "https://api.stripe.com/graphql"
|
||||
method = "POST"
|
||||
|
||||
class FakeResponse:
|
||||
request = FakeRequest()
|
||||
|
||||
cause = Exception("inner")
|
||||
cause.response = FakeResponse() # type: ignore
|
||||
exc.__cause__ = cause
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.extra is not None
|
||||
assert result.extra.get("endpoint") == "https://api.stripe.com/graphql"
|
||||
assert result.extra.get("http_method") == "POST"
|
||||
|
||||
# --- Connection/Protocol error tests ---
|
||||
|
||||
def test_connection_failed_returns_502(self) -> None:
|
||||
"""Connection failures should map to 502."""
|
||||
exc = DummyTransportConnectionFailed("Connection refused")
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == HTTPStatus.BAD_GATEWAY
|
||||
assert result.extra["error_type"] == "DummyTransportConnectionFailed"
|
||||
|
||||
def test_protocol_error_returns_502(self) -> None:
|
||||
"""Protocol errors should map to 502."""
|
||||
exc = DummyTransportProtocolError("Invalid response")
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == HTTPStatus.BAD_GATEWAY
|
||||
assert result.extra["error_type"] == "DummyTransportProtocolError"
|
||||
|
||||
# --- Generic TransportError catch-all ---
|
||||
|
||||
def test_generic_transport_error_handled(self) -> None:
|
||||
"""Unknown TransportError subclasses should be caught."""
|
||||
exc = DummyTransportError("Unknown error", code=503)
|
||||
|
||||
with _patch_loader():
|
||||
result = gql_adapter.GraphQLErrorAdapter().from_exception(exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 503
|
||||
|
||||
# --- Edge cases ---
|
||||
|
||||
def test_extract_message_handles_bad_str(self) -> None:
|
||||
"""Should handle objects that fail str()."""
|
||||
|
||||
class BadStr:
|
||||
def __str__(self) -> str:
|
||||
raise ValueError("nope")
|
||||
|
||||
assert gql_adapter._extract_error_message(BadStr()) == "Unknown GraphQL error"
|
||||
|
||||
def test_extract_message_handles_empty(self) -> None:
|
||||
"""Should handle empty/None messages."""
|
||||
assert gql_adapter._extract_error_message(None) == "Unknown GraphQL error"
|
||||
assert gql_adapter._extract_error_message("") == "Unknown GraphQL error"
|
||||
63
libs/tests/sdk/test_graphql_tooling.py
Normal file
63
libs/tests/sdk/test_graphql_tooling.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from arcade_core.errors import ToolRuntimeError
|
||||
|
||||
LIBS_DIR = Path(__file__).resolve().parents[2]
|
||||
TDK_SRC = LIBS_DIR / "arcade-tdk"
|
||||
if str(TDK_SRC) not in sys.path:
|
||||
sys.path.insert(0, str(TDK_SRC))
|
||||
|
||||
import arcade_tdk.error_adapters as error_adapters # noqa: E402
|
||||
import arcade_tdk.providers.graphql as graphql_provider # noqa: E402
|
||||
from arcade_tdk.providers.graphql import GraphQLErrorAdapter # noqa: E402
|
||||
from arcade_tdk.providers.http import HTTPErrorAdapter # noqa: E402
|
||||
|
||||
tool_module = importlib.import_module("arcade_tdk.tool")
|
||||
|
||||
|
||||
class DummyAdapter:
|
||||
slug = "_dummy"
|
||||
|
||||
def from_exception(self, exc: Exception) -> ToolRuntimeError | None:
|
||||
return None # pragma: no cover - trivial
|
||||
|
||||
|
||||
def test_graphql_adapter_is_exported() -> None:
|
||||
assert "GraphQLErrorAdapter" in graphql_provider.__all__
|
||||
assert graphql_provider.GraphQLErrorAdapter is GraphQLErrorAdapter
|
||||
assert "GraphQLErrorAdapter" in error_adapters.__all__
|
||||
|
||||
|
||||
def test_adapter_chain_appends_graphql_before_http(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(tool_module, "get_adapter_for_auth_provider", lambda auth: None)
|
||||
chain = tool_module._build_adapter_chain([], auth_provider=None)
|
||||
|
||||
assert isinstance(chain[-2], GraphQLErrorAdapter)
|
||||
assert isinstance(chain[-1], HTTPErrorAdapter)
|
||||
|
||||
|
||||
def test_adapter_chain_deduplicates_graphql_and_http(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(tool_module, "get_adapter_for_auth_provider", lambda auth: None)
|
||||
chain = tool_module._build_adapter_chain(
|
||||
[GraphQLErrorAdapter(), HTTPErrorAdapter()], auth_provider=None
|
||||
)
|
||||
|
||||
types = [type(adapter) for adapter in chain]
|
||||
assert types.count(GraphQLErrorAdapter) == 1
|
||||
assert types.count(HTTPErrorAdapter) == 1
|
||||
|
||||
|
||||
def test_adapter_chain_includes_auth_adapter_before_graphql(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
dummy_auth = DummyAdapter()
|
||||
monkeypatch.setattr(tool_module, "get_adapter_for_auth_provider", lambda auth: dummy_auth)
|
||||
chain = tool_module._build_adapter_chain([], auth_provider=object())
|
||||
|
||||
assert chain[0] is dummy_auth
|
||||
assert isinstance(chain[1], GraphQLErrorAdapter)
|
||||
Loading…
Reference in a new issue