From f4558ef3a8a1c827c1f0c5f9cab3f3f82f6410df Mon Sep 17 00:00:00 2001 From: Eric Gustin <34000337+EricGustin@users.noreply.github.com> Date: Wed, 10 Sep 2025 10:45:18 -0700 Subject: [PATCH] Tool Error Handling (#539) # Improvements to Arcade TDK Error Handling I tried my very best to not make any breaking changes in this PR. So, you will notice various "Deprecation" notices throughout. ### Instructions for PR reviewers 1. Pull down this PR's branch 2. Pull down the Engine's tool error handling PR's branch 3. Update your installed arcadepy to have the following: - In `arcadepy/resources/tools/tools.py`, if you want to test out including stacktraces, then you need to update `ToolsResource.execute` to accept a `include_error_stacktrace` argument and also include the "include_error_stacktrace" argument to the POST to the Engine inside of the function's execute method's body. - In `arcadepy/types/execute_tool_response.py` add the following enum ```py class ErrorKind(str, Enum): """Error kind that is comprised of - the who (toolkit, tool, upstream) - the when (load time, definition parsing time, runtime) - the what (bad_definition, bad_input, bad_output, retry, context_required, fatal, etc.)""" TOOLKIT_LOAD_FAILED = "TOOLKIT_LOAD_FAILED" TOOL_DEFINITION_BAD_DEFINITION = "TOOL_DEFINITION_BAD_DEFINITION" TOOL_DEFINITION_BAD_INPUT_SCHEMA = "TOOL_DEFINITION_BAD_INPUT_SCHEMA" TOOL_DEFINITION_BAD_OUTPUT_SCHEMA = "TOOL_DEFINITION_BAD_OUTPUT_SCHEMA" TOOL_RUNTIME_BAD_INPUT_VALUE = "TOOL_RUNTIME_BAD_INPUT_VALUE" TOOL_RUNTIME_BAD_OUTPUT_VALUE = "TOOL_RUNTIME_BAD_OUTPUT_VALUE" TOOL_RUNTIME_RETRY = "TOOL_RUNTIME_RETRY" TOOL_RUNTIME_CONTEXT_REQUIRED = "TOOL_RUNTIME_CONTEXT_REQUIRED" TOOL_RUNTIME_FATAL = "TOOL_RUNTIME_FATAL" UPSTREAM_RUNTIME_BAD_REQUEST = "UPSTREAM_RUNTIME_BAD_REQUEST" UPSTREAM_RUNTIME_AUTH_ERROR = "UPSTREAM_RUNTIME_AUTH_ERROR" UPSTREAM_RUNTIME_NOT_FOUND = "UPSTREAM_RUNTIME_NOT_FOUND" UPSTREAM_RUNTIME_VALIDATION_ERROR = "UPSTREAM_RUNTIME_VALIDATION_ERROR" UPSTREAM_RUNTIME_RATE_LIMIT = "UPSTREAM_RUNTIME_RATE_LIMIT" UPSTREAM_RUNTIME_SERVER_ERROR = "UPSTREAM_RUNTIME_SERVER_ERROR" UPSTREAM_RUNTIME_UNMAPPED = "UPSTREAM_RUNTIME_UNMAPPED" UNKNOWN = "UNKNOWN" ``` - In `arcadepy/types/execute_tool_response.py` add the following fields to OutputError: ```py kind: ErrorKind status_code: Optional[int] = None stacktrace: Optional[str] = None extra: Optional[dict[str, Any]] = None ``` ### Example Client Usage ```py # Example of handling an upstream rate limit error = response.output.error if error and error.kind == ErrorKind.UPSTREAM_RUNTIME_RATE_LIMIT: sleep_time = error.retry_after_ms / 1000 time.sleep(sleep_time) # and then execute again ``` ```py # Examples of determining what type of runtime error it is error = response.output.error if error: is_retryable_error = error.kind == ErrorKind.TOOL_RUNTIME_RETRY is_a_bug_in_the_tool = error.kind == ErrorKind.TOOL_RUNTIME_FATAL is_additional_context_required = error.kind == ErrorKind.TOOL_RUNTIME_CONTEXT_REQUIRED ``` ### Example Tool Usage ```py # EXAMPLE 1 letting Arcade handle upstream error handling for you reddit_client.post(params) # Arcade's httpx adapter will handle error handling for you! # ------------------------------------ # EXAMPLE 2 handling upstream bad request yourself, but letting Arcade handle the rest try: reddit_client.post(params) except httpx.HTTPStatusError as e: if e.status_code == 400: raise UpstreamError("My extra custom message) from e raise ``` ```py # EXAMPLE 1 letting Arcade handle it for you risky_element = my_risky_list[42] # Arcade will raise a FatalToolError for you # ------------------------------------ # EXAMPLE 2 handling it yourself for extra flexibility try: risky_element = my_risky_list[42] except IndexError as e: raise FatalToolError("My extra custom message") from e ``` ### Non-runtime Error Message Examples Example ToolkitLoadError Messages: ``` - [TOOLKIT_LOAD_FAILED] ToolkitLoadError when loading toolkit 'sample_tool': Could not import module mock_module. Reason: Mock import error - [TOOLKIT_LOAD_FAILED] ToolkitLoadError when loading toolkit 'test_toolkit': Tool 'ValidTool' in toolkit 'test_toolkit' already exists in the catalog. ``` Example ToolDefinitionError Messages ``` - [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_missing_description': Tool 'tool_missing_description' is missing a description - [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_invalid_secret_type': Secret keys must be strings (error in tool ToolWithInvalidSecretType). - [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_empty_secret': Secrets must have a non-empty key (error in tool ToolWithEmptySecret). - [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_invalid_metadata_type': Metadata must be strings (error in tool ToolWithInvalidMetadataType). - [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_metadata_requiring_auth_without_auth': Tool ToolWithMetadataRequiringAuthWithoutAuth declares metadata key 'client_id', which requires that the tool has an auth requirement, but no auth requirement was provided. Please specify an auth requirement. - [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_empty_metadata': Metadata must have a non-empty key (error in tool ToolWithEmptyMetadata). - [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_unsupported_param_type': Unsupported parameter type: ``` Example ToolInputSchemaError Messages ``` - [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_missing_input_parameter_annotation': Parameter 'input_text' is missing a description - [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_no_type_annotation': Parameter param has no type annotation. - [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_invalid_param_name': Invalid parameter name: '123invalid' is not a valid identifier. Identifiers must start with a letter or underscore, and can only contain letters, digits, or underscores. - [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_too_many_annotations': Parameter param: Annotated[str, 'name', 'desc', 'extra'] has too many string annotations. Expected 0, 1, or 2, got 3. - [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_required_union_param': Parameter param is a union type. Only optional types are supported. - [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_non_callable_default_factory': Default factory for parameter param: Annotated[str, 'Parameter'] = FieldInfo(annotation=NoneType, required=False, default_factory=str) is not callable. - [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_multiple_tool_contexts': Only one ToolContext parameter is supported, but tool tool_with_multiple_tool_contexts has multiple. ``` Example ToolOutputSchemaError Messages ``` - [TOOL_DEFINITION_BAD_OUTPUT_SCHEMA] ToolOutputSchemaError in definition of tool 'tool_missing_return_type_hint': Tool 'ToolMissingReturnTypeHint' must have a return type - [TOOL_DEFINITION_BAD_OUTPUT_SCHEMA] ToolOutputSchemaError in definition of tool 'tool_with_unsupported_output_type': Unsupported output type ''. Only built-in Python types, TypedDicts, Pydantic models, and standard collections are supported as tool output types. ``` ### Runtime Error Message Examples Example Tool Runtime Error Messages ``` - [TOOL_RUNTIME_FATAL] FatalToolError during execution of tool 'get_posts_in_subreddit': list index out of range - [TOOL_RUNTIME_CONTEXT_REQUIRED] ContextRequiredToolError during execution of tool 'get_posts_in_subreddit': Ambiguous username. Please provide a more specific username - [TOOL_RUNTIME_RETRY] RetryableToolError during execution of tool 'get_posts_in_subreddit': Retry with subreddit=learnpython or subreddit=learnprogramming ``` Example Upstream Runtime Error Messages ``` - [UPSTREAM_RUNTIME_RATE_LIMIT] UpstreamRateLimitError during execution of tool 'get_posts_in_subreddit': 429 Client Error: Too Many Requests - [UPSTREAM_RUNTIME_BAD_REQUEST] UpstreamError during execution of tool 'get_posts_in_subreddit': 400 Client Error: Bad request. Missing 'id' parameter. - [UPSTREAM_RUNTIME_BAD_REQUEST] UpstreamError during execution of tool 'search_files': Upstream Google API error: Invalid value '-23'. Values must be within the range: [value: 1\n, value: 1000\n] ``` --- libs/arcade-core/arcade_core/catalog.py | 62 ++- libs/arcade-core/arcade_core/errors.py | 351 ++++++++++-- libs/arcade-core/arcade_core/executor.py | 27 +- libs/arcade-core/arcade_core/output.py | 35 +- libs/arcade-core/arcade_core/schema.py | 27 +- libs/arcade-core/pyproject.toml | 2 +- libs/arcade-serve/arcade_serve/core/base.py | 4 +- libs/arcade-serve/pyproject.toml | 4 +- .../arcade_tdk/error_adapters/__init__.py | 5 + .../arcade_tdk/error_adapters/base.py | 22 + .../arcade_tdk/error_adapters/utils.py | 15 + libs/arcade-tdk/arcade_tdk/errors.py | 22 +- .../arcade_tdk/providers/__init__.py | 0 .../arcade_tdk/providers/google/__init__.py | 3 + .../providers/google/error_adapter.py | 228 ++++++++ .../arcade_tdk/providers/http/__init__.py | 3 + .../providers/http/error_adapter.py | 200 +++++++ libs/arcade-tdk/arcade_tdk/tool.py | 114 +++- libs/arcade-tdk/pyproject.toml | 4 +- libs/tests/core/test_catalog.py | 329 ++++++++++- libs/tests/core/test_executor.py | 191 +++++-- libs/tests/core/test_schema_validation.py | 2 + libs/tests/sdk/test_google_adapter.py | 510 ++++++++++++++++++ libs/tests/sdk/test_httpx_adapter.py | 333 ++++++++++++ .../test_create_tool_definition_errors.py | 4 +- pyproject.toml | 8 +- 26 files changed, 2348 insertions(+), 157 deletions(-) create mode 100644 libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py create mode 100644 libs/arcade-tdk/arcade_tdk/error_adapters/base.py create mode 100644 libs/arcade-tdk/arcade_tdk/error_adapters/utils.py create mode 100644 libs/arcade-tdk/arcade_tdk/providers/__init__.py create mode 100644 libs/arcade-tdk/arcade_tdk/providers/google/__init__.py create mode 100644 libs/arcade-tdk/arcade_tdk/providers/google/error_adapter.py create mode 100644 libs/arcade-tdk/arcade_tdk/providers/http/__init__.py create mode 100644 libs/arcade-tdk/arcade_tdk/providers/http/error_adapter.py create mode 100644 libs/tests/sdk/test_google_adapter.py create mode 100644 libs/tests/sdk/test_httpx_adapter.py diff --git a/libs/arcade-core/arcade_core/catalog.py b/libs/arcade-core/arcade_core/catalog.py index 0893915e..0323da09 100644 --- a/libs/arcade-core/arcade_core/catalog.py +++ b/libs/arcade-core/arcade_core/catalog.py @@ -27,7 +27,12 @@ from pydantic_core import PydanticUndefined from arcade_core.annotations import Inferrable from arcade_core.auth import OAuth2, ToolAuthorization -from arcade_core.errors import ToolDefinitionError +from arcade_core.errors import ( + ToolDefinitionError, + ToolInputSchemaError, + ToolkitLoadError, + ToolOutputSchemaError, +) from arcade_core.schema import ( TOOL_NAME_SEPARATOR, FullyQualifiedName, @@ -224,7 +229,9 @@ class ToolCatalog(BaseModel): fully_qualified_name = definition.get_fully_qualified_name() if fully_qualified_name in self._tools: - raise KeyError(f"Tool '{definition.name}' already exists in the catalog.") + raise ToolkitLoadError( + f"Tool '{definition.name}' in toolkit '{toolkit_name}' already exists in the catalog." + ) if str(fully_qualified_name).lower() in self._disabled_tools: logger.info(f"Tool '{fully_qualified_name!s}' is disabled and will not be cataloged.") @@ -270,20 +277,26 @@ class ToolCatalog(BaseModel): tool_func = getattr(module, tool_name) self.add_tool(tool_func, toolkit, module) + except ToolDefinitionError as e: + raise e.with_context(tool_name) from e + except ToolkitLoadError as e: + raise e.with_context(toolkit.name) from e + except ImportError as e: + raise ToolkitLoadError( + f"Could not import module {module_name}. Reason: {e}" + ).with_context(tool_name) except AttributeError as e: raise ToolDefinitionError( f"Could not import tool {tool_name} in module {module_name}. Reason: {e}" - ) - except ImportError as e: - raise ToolDefinitionError(f"Could not import module {module_name}. Reason: {e}") + ).with_context(tool_name) except TypeError as e: raise ToolDefinitionError( f"Type error encountered while adding tool {tool_name} from {module_name}. Reason: {e}" - ) + ).with_context(tool_name) except Exception as e: raise ToolDefinitionError( f"Error encountered while adding tool {tool_name} from {module_name}. Reason: {e}" - ) + ).with_context(tool_name) def __getitem__(self, name: FullyQualifiedName) -> MaterializedTool: return self.get_tool(name) @@ -392,11 +405,11 @@ class ToolCatalog(BaseModel): # Hard requirement: tools must have descriptions tool_description = getattr(tool, "__tool_description__", None) if not tool_description: - raise ToolDefinitionError(f"Tool {raw_tool_name} is missing a description") + raise ToolDefinitionError(f"Tool '{raw_tool_name}' is missing a description") # If the function returns a value, it must have a type annotation if does_function_return_value(tool) and tool.__annotations__.get("return") is None: - raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation") + raise ToolOutputSchemaError(f"Tool '{raw_tool_name}' must have a return type") auth_requirement = create_auth_requirement(tool) secrets_requirement = create_secrets_requirement(tool) @@ -438,7 +451,7 @@ def create_input_definition(func: Callable) -> ToolInput: for _, param in inspect.signature(func, follow_wrapped=True).parameters.items(): if param.annotation is ToolContext: if tool_context_param_name is not None: - raise ToolDefinitionError( + raise ToolInputSchemaError( f"Only one ToolContext parameter is supported, but tool {func.__name__} has multiple." ) @@ -635,7 +648,7 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo: """ annotation = param.annotation if annotation == inspect.Parameter.empty: - raise ToolDefinitionError(f"Parameter {param} has no type annotation.") + raise ToolInputSchemaError(f"Parameter {param} has no type annotation.") # Get the majority of the param info from either the Pydantic Field() or regular inspection if isinstance(param.default, FieldInfo): @@ -654,7 +667,7 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo: elif len(str_annotations) == 2: new_name = str_annotations[0] if not new_name.isidentifier(): - raise ToolDefinitionError( + raise ToolInputSchemaError( f"Invalid parameter name: '{new_name}' is not a valid identifier. " "Identifiers must start with a letter or underscore, " "and can only contain letters, digits, or underscores." @@ -662,7 +675,7 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo: param_info.name = new_name param_info.description = str_annotations[1] else: - raise ToolDefinitionError( + raise ToolInputSchemaError( f"Parameter {param} has too many string annotations. Expected 0, 1, or 2, got {len(str_annotations)}." ) @@ -677,10 +690,10 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo: # Final reality check if param_info.description is None: - raise ToolDefinitionError(f"Parameter {param_info.name} is missing a description") + raise ToolInputSchemaError(f"Parameter '{param_info.name}' is missing a description") if wire_type_info.wire_type is None: - raise ToolDefinitionError(f"Unknown parameter type: {param_info.field_type}") + raise ToolInputSchemaError(f"Unknown parameter type: {param_info.field_type}") return ToolParamInfo.from_param_info(param_info, wire_type_info, is_inferrable) @@ -875,7 +888,7 @@ def extract_python_param_info(param: inspect.Parameter) -> ParamInfo: # Union types are not currently supported # (other than optional, which is handled above) if is_union(field_type): - raise ToolDefinitionError( + raise ToolInputSchemaError( f"Parameter {param.name} is a union type. Only optional types are supported." ) @@ -895,7 +908,7 @@ def extract_pydantic_param_info(param: inspect.Parameter) -> ParamInfo: if callable(param.default.default_factory): default_value = param.default.default_factory() else: - raise ToolDefinitionError(f"Default factory for parameter {param} is not callable.") + raise ToolInputSchemaError(f"Default factory for parameter {param} is not callable.") # If the param is Annotated[], unwrap the annotation to get the "real" type # Otherwise, use the literal type @@ -988,7 +1001,6 @@ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel] input_model = create_model(f"{snake_to_pascal_case(func.__name__)}Input", **input_fields) # type: ignore[call-overload] output_model = determine_output_model(func) - return input_model, output_model @@ -1033,10 +1045,16 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901 # If the return annotation has a description, use it if description: - return create_model( - output_model_name, - result=(field_type, Field(description=str(description))), - ) + try: + return create_model( + output_model_name, + result=(field_type, Field(description=str(description))), + ) + except Exception: + raise ToolOutputSchemaError( + f"Unsupported output type '{field_type}'. Only built-in Python types, TypedDicts, " + "Pydantic models, and standard collections are supported as tool output types." + ) # If the return annotation is a Union type origin = return_annotation.__origin__ diff --git a/libs/arcade-core/arcade_core/errors.py b/libs/arcade-core/arcade_core/errors.py index 35ff6afc..a9fed151 100644 --- a/libs/arcade-core/arcade_core/errors.py +++ b/libs/arcade-core/arcade_core/errors.py @@ -1,103 +1,378 @@ import traceback -from typing import Optional +import warnings +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any -class ToolkitError(Exception): +class ErrorKind(str, Enum): + """Error kind that is comprised of + - the who (toolkit, tool, upstream) + - the when (load time, definition parsing time, runtime) + - the what (bad_definition, bad_input, bad_output, retry, context_required, fatal, etc.)""" + + TOOLKIT_LOAD_FAILED = "TOOLKIT_LOAD_FAILED" + TOOL_DEFINITION_BAD_DEFINITION = "TOOL_DEFINITION_BAD_DEFINITION" + TOOL_DEFINITION_BAD_INPUT_SCHEMA = "TOOL_DEFINITION_BAD_INPUT_SCHEMA" + TOOL_DEFINITION_BAD_OUTPUT_SCHEMA = "TOOL_DEFINITION_BAD_OUTPUT_SCHEMA" + TOOL_RUNTIME_BAD_INPUT_VALUE = "TOOL_RUNTIME_BAD_INPUT_VALUE" + TOOL_RUNTIME_BAD_OUTPUT_VALUE = "TOOL_RUNTIME_BAD_OUTPUT_VALUE" + TOOL_RUNTIME_RETRY = "TOOL_RUNTIME_RETRY" + TOOL_RUNTIME_CONTEXT_REQUIRED = "TOOL_RUNTIME_CONTEXT_REQUIRED" + TOOL_RUNTIME_FATAL = "TOOL_RUNTIME_FATAL" + UPSTREAM_RUNTIME_BAD_REQUEST = "UPSTREAM_RUNTIME_BAD_REQUEST" + UPSTREAM_RUNTIME_AUTH_ERROR = "UPSTREAM_RUNTIME_AUTH_ERROR" + UPSTREAM_RUNTIME_NOT_FOUND = "UPSTREAM_RUNTIME_NOT_FOUND" + UPSTREAM_RUNTIME_VALIDATION_ERROR = "UPSTREAM_RUNTIME_VALIDATION_ERROR" + UPSTREAM_RUNTIME_RATE_LIMIT = "UPSTREAM_RUNTIME_RATE_LIMIT" + UPSTREAM_RUNTIME_SERVER_ERROR = "UPSTREAM_RUNTIME_SERVER_ERROR" + UPSTREAM_RUNTIME_UNMAPPED = "UPSTREAM_RUNTIME_UNMAPPED" + UNKNOWN = "UNKNOWN" + + +class ToolkitError(Exception, ABC): """ - Base class for all errors related to toolkits. + Base class for all Arcade errors. + + Note: This class is an abstract class and cannot be instantiated directly. + + These errors are ultimately converted to the ToolCallError schema. + Attributes expected from subclasses: + message : str # user-facing error message + kind : ErrorKind # the error kind + can_retry : bool # whether the operation can be retried + developer_message : str | None # developer-facing error details + status_code : int | None # HTTP status code when relevant + additional_prompt_content : str | None # content for retry prompts + retry_after_ms : int | None # milliseconds to wait before retry + stacktrace : str | None # stacktrace information + extra : dict[str, Any] | None # arbitrary structured metadata """ - pass + def __new__(cls, *args: Any, **kwargs: Any) -> "ToolkitError": + abs_methods = getattr(cls, "__abstractmethods__", None) + if abs_methods: + raise TypeError(f"Can't instantiate abstract class {cls.__name__}") + return super().__new__(cls) + + @abstractmethod + def create_message_prefix(self, name: str) -> str: + pass + + def with_context(self, name: str) -> "ToolkitError": + """ + Add context to the error message. + + Args: + name: The name of the tool or toolkit that caused the error. + + Returns: + The error with the context added to the message. + """ + prefix = self.create_message_prefix(name) + self.message = f"{prefix}{self.message}" # type: ignore[has-type] + if hasattr(self, "developer_message") and self.developer_message: # type: ignore[has-type] + self.developer_message = f"{prefix}{self.developer_message}" # type: ignore[has-type] + + return self + + @property + def is_toolkit_error(self) -> bool: + """Check if this error originated from loading a toolkit.""" + return hasattr(self, "kind") and self.kind.name.startswith("TOOLKIT_") + + @property + def is_tool_error(self) -> bool: + """Check if this error originated from a tool.""" + return hasattr(self, "kind") and self.kind.name.startswith("TOOL_") + + @property + def is_upstream_error(self) -> bool: + """Check if this error originated from an upstream service.""" + return hasattr(self, "kind") and self.kind.name.startswith("UPSTREAM_") + + def __str__(self) -> str: + return self.message class ToolkitLoadError(ToolkitError): """ - Raised when there is an error loading a toolkit. + Raised while importing / loading a toolkit package + (e.g. missing dependency, SyntaxError in module top-level code). """ - pass + kind: ErrorKind = ErrorKind.TOOLKIT_LOAD_FAILED + can_retry: bool = False + + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message + + def create_message_prefix(self, toolkit_name: str) -> str: + return f"[{self.kind.value}] {type(self).__name__} when loading toolkit '{toolkit_name}': " -class ToolError(Exception): +class ToolError(ToolkitError): """ - Base class for all errors related to tools. + Any error related to an Arcade tool. + + Note: This class is an abstract class and cannot be instantiated directly. """ - pass - +# ------ definition-time errors (tool developer's responsibility) ------ class ToolDefinitionError(ToolError): """ - Raised when there is an error in the definition of a tool. + Raised when there is an error in the definition/signature of a tool. """ - pass + kind: ErrorKind = ErrorKind.TOOL_DEFINITION_BAD_DEFINITION + + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message + + def create_message_prefix(self, tool_name: str) -> str: + return f"[{self.kind.value}] {type(self).__name__} in definition of tool '{tool_name}': " + + +class ToolInputSchemaError(ToolDefinitionError): + """Raised when there is an error in the schema of a tool's input parameter.""" + + kind: ErrorKind = ErrorKind.TOOL_DEFINITION_BAD_INPUT_SCHEMA + + +class ToolOutputSchemaError(ToolDefinitionError): + """Raised when there is an error in the schema of a tool's output parameter.""" + + kind: ErrorKind = ErrorKind.TOOL_DEFINITION_BAD_OUTPUT_SCHEMA # ------ runtime errors ------ +class ToolRuntimeError(ToolError, RuntimeError): + """ + Any failure starting from when the tool call begins until the tool call returns. + Note: This class should typically not be instantiated directly, but rather subclassed. + """ + + kind: ErrorKind = ErrorKind.TOOL_RUNTIME_FATAL + can_retry: bool = False + status_code: int | None = None + extra: dict[str, Any] | None = None -class ToolRuntimeError(RuntimeError): def __init__( self, message: str, - developer_message: Optional[str] = None, + developer_message: str | None = None, + *, + extra: dict[str, Any] | None = None, ): super().__init__(message) self.message = message - self.developer_message = developer_message + self.developer_message = developer_message # type: ignore[assignment] + self.extra = extra - def traceback_info(self) -> str | None: - # return the traceback information of the parent exception + def create_message_prefix(self, tool_name: str) -> str: + return f"[{self.kind.value}] {type(self).__name__} during execution of tool '{tool_name}': " + + def stacktrace(self) -> str | None: if self.__cause__: return "\n".join(traceback.format_exception(self.__cause__)) return None + def traceback_info(self) -> str | None: + """DEPRECATED: Use stacktrace() instead. + This method is deprecated and will be removed in a future major version. + """ + return self.stacktrace() + + # wire-format helper + def to_payload(self) -> dict[str, Any]: + return { + "message": self.message, + "developer_message": self.developer_message, + "kind": self.kind, + "can_retry": self.can_retry, + "status_code": self.status_code, + **(self.extra or {}), + } + + +# 1. ------ serialization errors ------ +class ToolSerializationError(ToolRuntimeError): + """ + Raised when there is an error serializing/marshalling the tool call arguments or return value. + + Note: This class is not intended to be instantiated directly, but rather subclassed. + """ + + +class ToolInputError(ToolSerializationError): + """ + Raised when there is an error parsing a tool call argument. + """ + + kind: ErrorKind = ErrorKind.TOOL_RUNTIME_BAD_INPUT_VALUE + status_code: int = 400 + + +class ToolOutputError(ToolSerializationError): + """ + Raised when there is an error serializing a tool call return value. + """ + + kind: ErrorKind = ErrorKind.TOOL_RUNTIME_BAD_OUTPUT_VALUE + status_code: int = 500 + + +# 2. ------ tool-body errors ------ class ToolExecutionError(ToolRuntimeError): """ - Raised when there is an error executing a tool. - """ + DEPRECATED: Raised when there is an error executing a tool. - pass - - -class RetryableToolError(ToolExecutionError): - """ - Raised when a tool error is retryable. + ToolExecutionError is deprecated and will be removed in a future major version. + Use more specific error types instead: + - RetryableToolError for retryable errors + - ContextRequiredToolError for errors requiring user context + - FatalToolError for fatal/unexpected errors + - UpstreamError for upstream service errors + - UpstreamRateLimitError for upstream rate limiting errors """ def __init__( self, message: str, - developer_message: Optional[str] = None, - additional_prompt_content: Optional[str] = None, - retry_after_ms: Optional[int] = None, + developer_message: str | None = None, + *, + extra: dict[str, Any] | None = None, ): - super().__init__(message, developer_message) + if type(self) is ToolExecutionError: + warnings.warn( + "ToolExecutionError is deprecated and will be removed in a future major version. " + "Use more specific error types instead: RetryableToolError, ContextRequiredToolError, " + "FatalToolError, UpstreamError, or UpstreamRateLimitError.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(message, developer_message=developer_message, extra=extra) + + +class RetryableToolError(ToolExecutionError): + """ + Raised when a tool execution error is retryable. + """ + + kind: ErrorKind = ErrorKind.TOOL_RUNTIME_RETRY + can_retry: bool = True + + def __init__( + self, + message: str, + developer_message: str | None = None, + additional_prompt_content: str | None = None, # TODO: Make required in next major version + retry_after_ms: int | None = None, + extra: dict[str, Any] | None = None, + ): + super().__init__(message, developer_message=developer_message, extra=extra) self.additional_prompt_content = additional_prompt_content self.retry_after_ms = retry_after_ms -class ToolSerializationError(ToolRuntimeError): +class ContextRequiredToolError(ToolExecutionError): """ - Raised when there is an error executing a tool. + Raised when the combination of additional content from the tool AND + additional context from the end-user/orchestrator is required before retrying the tool. + + This is typically used when an argument provided to the tool is invalid in some way, + and immediately prompting an LLM to retry the tool call is not desired. """ - pass + kind: ErrorKind = ErrorKind.TOOL_RUNTIME_CONTEXT_REQUIRED + + def __init__( + self, + message: str, + additional_prompt_content: str, + developer_message: str | None = None, + *, + extra: dict[str, Any] | None = None, + ): + super().__init__(message, developer_message=developer_message, extra=extra) + self.additional_prompt_content = additional_prompt_content -class ToolInputError(ToolSerializationError): +class FatalToolError(ToolExecutionError): """ - Raised when there is an error in the input to a tool. + Raised when there is an unexpected or unknown error executing a tool. """ - pass + status_code: int = 500 + + def __init__( + self, + message: str, + developer_message: str | None = None, + *, + extra: dict[str, Any] | None = None, + ): + super().__init__(message, developer_message=developer_message, extra=extra) -class ToolOutputError(ToolSerializationError): +# 3. ------ upstream errors in tool body------ +class UpstreamError(ToolExecutionError): """ - Raised when there is an error in the output of a tool. + Error from an upstream service/API during tool execution. + + This class handles all upstream failures except rate limiting. + The status_code and extra dict provide details about the specific error type. """ - pass + def __init__( + self, + message: str, + developer_message: str | None = None, + *, + status_code: int, + extra: dict[str, Any] | None = None, + ): + super().__init__(message, developer_message=developer_message, extra=extra) + self.status_code = status_code + # Determine retryability based on status code + self.can_retry = status_code >= 500 or status_code == 429 + # Set appropriate error kind based on status + if status_code in (401, 403): + self.kind = ErrorKind.UPSTREAM_RUNTIME_AUTH_ERROR + elif status_code == 404: + self.kind = ErrorKind.UPSTREAM_RUNTIME_NOT_FOUND + elif status_code == 429: + self.kind = ErrorKind.UPSTREAM_RUNTIME_RATE_LIMIT + elif status_code >= 500: + self.kind = ErrorKind.UPSTREAM_RUNTIME_SERVER_ERROR + elif 400 <= status_code < 500: + self.kind = ErrorKind.UPSTREAM_RUNTIME_BAD_REQUEST + else: + self.kind = ErrorKind.UPSTREAM_RUNTIME_UNMAPPED + + +class UpstreamRateLimitError(UpstreamError): + """ + Rate limit error from an upstream service. + + Special case of UpstreamError that includes retry_after_ms information. + """ + + kind: ErrorKind = ErrorKind.UPSTREAM_RUNTIME_RATE_LIMIT + can_retry: bool = True + + def __init__( + self, + message: str, + retry_after_ms: int, + developer_message: str | None = None, + *, + extra: dict[str, Any] | None = None, + ): + super().__init__(message, status_code=429, developer_message=developer_message, extra=extra) + self.retry_after_ms = retry_after_ms diff --git a/libs/arcade-core/arcade_core/executor.py b/libs/arcade-core/arcade_core/executor.py index a9c67020..3b2cca59 100644 --- a/libs/arcade-core/arcade_core/executor.py +++ b/libs/arcade-core/arcade_core/executor.py @@ -6,11 +6,9 @@ from typing import Any from pydantic import BaseModel, ValidationError from arcade_core.errors import ( - RetryableToolError, ToolInputError, ToolOutputError, ToolRuntimeError, - ToolSerializationError, ) from arcade_core.output import output_factory from arcade_core.schema import ( @@ -69,31 +67,26 @@ class ToolExecutor: # return the output return output_factory.success(data=output, logs=tool_call_logs) - except RetryableToolError as e: - return output_factory.fail_retry( - message=e.message, - developer_message=e.developer_message, - additional_prompt_content=e.additional_prompt_content, - retry_after_ms=e.retry_after_ms, - ) - - except ToolSerializationError as e: - return output_factory.fail(message=e.message, developer_message=e.developer_message) - - # should catch all tool exceptions due to the try/except in the tool decorator except ToolRuntimeError as e: + e.with_context(func.__name__) return output_factory.fail( message=e.message, developer_message=e.developer_message, - traceback_info=e.traceback_info(), + stacktrace=e.stacktrace(), + additional_prompt_content=getattr(e, "additional_prompt_content", None), + retry_after_ms=getattr(e, "retry_after_ms", None), + kind=e.kind, + can_retry=e.can_retry, + status_code=e.status_code, + extra=e.extra, ) # if we get here we're in trouble except Exception as e: return output_factory.fail( - message="Error in execution", + message=f"Error in execution of '{func.__name__}'", developer_message=str(e), - traceback_info=traceback.format_exc(), + stacktrace=traceback.format_exc(), ) @staticmethod diff --git a/libs/arcade-core/arcade_core/output.py b/libs/arcade-core/arcade_core/output.py index bb32e3b4..d5863802 100644 --- a/libs/arcade-core/arcade_core/output.py +++ b/libs/arcade-core/arcade_core/output.py @@ -1,7 +1,8 @@ -from typing import TypeVar +from typing import Any, TypeVar from pydantic import BaseModel +from arcade_core.errors import ErrorKind from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput from arcade_core.utils import coerce_empty_list_to_none @@ -61,15 +62,26 @@ class ToolOutputFactory: *, message: str, developer_message: str | None = None, - traceback_info: str | None = None, + stacktrace: str | None = None, logs: list[ToolCallLog] | None = None, + additional_prompt_content: str | None = None, + retry_after_ms: int | None = None, + kind: ErrorKind = ErrorKind.UNKNOWN, + can_retry: bool = False, + status_code: int | None = None, + extra: dict[str, Any] | None = None, ) -> ToolCallOutput: return ToolCallOutput( error=ToolCallError( message=message, developer_message=developer_message, - can_retry=False, - traceback_info=traceback_info, + can_retry=can_retry, + additional_prompt_content=additional_prompt_content, + retry_after_ms=retry_after_ms, + stacktrace=stacktrace, + kind=kind, + status_code=status_code, + extra=extra, ), logs=coerce_empty_list_to_none(logs), ) @@ -81,9 +93,17 @@ class ToolOutputFactory: developer_message: str | None = None, additional_prompt_content: str | None = None, retry_after_ms: int | None = None, - traceback_info: str | None = None, + stacktrace: str | None = None, logs: list[ToolCallLog] | None = None, + kind: ErrorKind = ErrorKind.TOOL_RUNTIME_RETRY, + status_code: int = 500, + extra: dict[str, Any] | None = None, ) -> ToolCallOutput: + """ + DEPRECATED: Use ToolOutputFactory.fail instead. + This method will be removed in version 3.0.0 + """ + return ToolCallOutput( error=ToolCallError( message=message, @@ -91,7 +111,10 @@ class ToolOutputFactory: can_retry=True, additional_prompt_content=additional_prompt_content, retry_after_ms=retry_after_ms, - traceback_info=traceback_info, + stacktrace=stacktrace, + kind=kind, + status_code=status_code, + extra=extra, ), logs=coerce_empty_list_to_none(logs), ) diff --git a/libs/arcade-core/arcade_core/schema.py b/libs/arcade-core/arcade_core/schema.py index 28708e8c..64e12738 100644 --- a/libs/arcade-core/arcade_core/schema.py +++ b/libs/arcade-core/arcade_core/schema.py @@ -5,6 +5,8 @@ from typing import Any, Literal from pydantic import BaseModel, Field +from arcade_core.errors import ErrorKind + # allow for custom tool name separator TOOL_NAME_SEPARATOR = os.getenv("ARCADE_TOOL_NAME_SEPARATOR", ".") @@ -390,6 +392,8 @@ class ToolCallError(BaseModel): message: str """The user-facing error message.""" + kind: ErrorKind + """The error kind that uniquely identifies the kind of error.""" developer_message: str | None = None """The developer-facing error details.""" can_retry: bool = False @@ -398,8 +402,27 @@ class ToolCallError(BaseModel): """Additional content to be included in the retry prompt.""" retry_after_ms: int | None = None """The number of milliseconds (if any) to wait before retrying the tool call.""" - traceback_info: str | None = None - """The traceback information for the tool call.""" + stacktrace: str | None = None + """The stacktrace information for the tool call.""" + status_code: int | None = None + """The HTTP status code of the error.""" + extra: dict[str, Any] | None = None + """Additional information about the error.""" + + @property + def is_toolkit_error(self) -> bool: + """Check if this error originated from loading a toolkit.""" + return self.kind.name.startswith("TOOLKIT_") + + @property + def is_tool_error(self) -> bool: + """Check if this error originated from a tool.""" + return self.kind.name.startswith("TOOL_") + + @property + def is_upstream_error(self) -> bool: + """Check if this error originated from an upstream service.""" + return self.kind.name.startswith("UPSTREAM_") class ToolCallRequiresAuthorization(BaseModel): diff --git a/libs/arcade-core/pyproject.toml b/libs/arcade-core/pyproject.toml index 055f38cf..2814e854 100644 --- a/libs/arcade-core/pyproject.toml +++ b/libs/arcade-core/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-core" -version = "2.3.0" +version = "2.4.0" description = "Arcade Core - Core library for Arcade platform" readme = "README.md" license = {text = "MIT"} diff --git a/libs/arcade-serve/arcade_serve/core/base.py b/libs/arcade-serve/arcade_serve/core/base.py index e205416a..d24248a5 100644 --- a/libs/arcade-serve/arcade_serve/core/base.py +++ b/libs/arcade-serve/arcade_serve/core/base.py @@ -153,8 +153,8 @@ class BaseWorker(Worker): logger.debug( f"{execution_id} | duration: {duration_ms}ms | Tool output: {output.value}" ) - if output.error.traceback_info: - logger.debug(f"{execution_id} | Tool traceback: {output.error.traceback_info}") + if output.error.stacktrace: + logger.debug(f"{execution_id} | Tool traceback: {output.error.stacktrace}") else: logger.info( f"{execution_id} | Tool {tool_fqname} version {tool_request.tool.version} success" diff --git a/libs/arcade-serve/pyproject.toml b/libs/arcade-serve/pyproject.toml index 7da31f18..40d0c0f2 100644 --- a/libs/arcade-serve/pyproject.toml +++ b/libs/arcade-serve/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-serve" -version = "2.0.0" +version = "2.1.0" description = "Arcade Serve - Serving infrastructure for Arcade tools and workers" readme = "README.md" license = {text = "MIT"} @@ -19,7 +19,7 @@ classifiers = [ ] requires-python = ">=3.10" dependencies = [ - "arcade-core>=2.0.0,<3.0.0", + "arcade-core>=2.4.0,<3.0.0", "fastapi>=0.115.3", "uvicorn>=0.30.0", "watchfiles>=1.0.5", diff --git a/libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py b/libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py new file mode 100644 index 00000000..471dc0b9 --- /dev/null +++ b/libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py @@ -0,0 +1,5 @@ +from arcade_tdk.error_adapters.base import ErrorAdapter +from arcade_tdk.providers.google import GoogleErrorAdapter +from arcade_tdk.providers.http import HTTPErrorAdapter + +__all__ = ["ErrorAdapter", "HTTPErrorAdapter", "GoogleErrorAdapter"] diff --git a/libs/arcade-tdk/arcade_tdk/error_adapters/base.py b/libs/arcade-tdk/arcade_tdk/error_adapters/base.py new file mode 100644 index 00000000..541bd254 --- /dev/null +++ b/libs/arcade-tdk/arcade_tdk/error_adapters/base.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from arcade_tdk.errors import ToolRuntimeError + + +@runtime_checkable +class ErrorAdapter(Protocol): + """ + Plugin that translates vendor-specific exceptions / responses into + the appropriate Arcade Errors. + """ + + slug: str # for logging & metrics + + def from_exception(self, exc: Exception) -> ToolRuntimeError | None: + """ + Translate an exception raised by an SDK, HTTP client, etc. + into a `ToolRuntimeError` subclass. + """ + ... diff --git a/libs/arcade-tdk/arcade_tdk/error_adapters/utils.py b/libs/arcade-tdk/arcade_tdk/error_adapters/utils.py new file mode 100644 index 00000000..4a7de8fe --- /dev/null +++ b/libs/arcade-tdk/arcade_tdk/error_adapters/utils.py @@ -0,0 +1,15 @@ +from arcade_tdk.auth import Google, ToolAuthorization +from arcade_tdk.error_adapters import ErrorAdapter, GoogleErrorAdapter + + +def get_adapter_for_auth_provider(auth_provider: ToolAuthorization | None) -> ErrorAdapter | None: + """ + Get an error adapter from an auth provider. + """ + if not auth_provider: + return None + + if isinstance(auth_provider, Google): + return GoogleErrorAdapter() + + return None diff --git a/libs/arcade-tdk/arcade_tdk/errors.py b/libs/arcade-tdk/arcade_tdk/errors.py index dc3444a0..dcd90a5f 100644 --- a/libs/arcade-tdk/arcade_tdk/errors.py +++ b/libs/arcade-tdk/arcade_tdk/errors.py @@ -1,16 +1,34 @@ -from arcade_core.errors import RetryableToolError, ToolExecutionError, ToolRuntimeError +from arcade_core.errors import ( + ContextRequiredToolError, + ErrorKind, + FatalToolError, + RetryableToolError, + ToolExecutionError, + ToolRuntimeError, + UpstreamError, + UpstreamRateLimitError, +) __all__ = [ + "ErrorKind", + "FatalToolError", "RetryableToolError", "SDKError", "ToolExecutionError", "ToolRuntimeError", + "UpstreamError", + "UpstreamRateLimitError", + "ContextRequiredToolError", "WeightError", ] class SDKError(Exception): - """Base class for all SDK errors.""" + """ + DEPRECATED: Base class for all SDK errors. + + SDKError is deprecated and will be removed in a future major version. + """ class WeightError(SDKError): diff --git a/libs/arcade-tdk/arcade_tdk/providers/__init__.py b/libs/arcade-tdk/arcade_tdk/providers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/arcade-tdk/arcade_tdk/providers/google/__init__.py b/libs/arcade-tdk/arcade_tdk/providers/google/__init__.py new file mode 100644 index 00000000..b667dced --- /dev/null +++ b/libs/arcade-tdk/arcade_tdk/providers/google/__init__.py @@ -0,0 +1,3 @@ +from arcade_tdk.providers.google.error_adapter import GoogleErrorAdapter + +__all__ = ["GoogleErrorAdapter"] diff --git a/libs/arcade-tdk/arcade_tdk/providers/google/error_adapter.py b/libs/arcade-tdk/arcade_tdk/providers/google/error_adapter.py new file mode 100644 index 00000000..6baf4768 --- /dev/null +++ b/libs/arcade-tdk/arcade_tdk/providers/google/error_adapter.py @@ -0,0 +1,228 @@ +import logging +from datetime import datetime, timezone +from typing import Any +from urllib.parse import urlparse + +from arcade_core.errors import ( + ToolRuntimeError, + UpstreamError, + UpstreamRateLimitError, +) + +logger = logging.getLogger(__name__) + + +class GoogleErrorAdapter: + """Error adapter for Google's API Python Client library.""" + + slug = "_google_api_client" + + def _sanitize_uri(self, uri: str) -> str: + """Strip query params and fragments from URI for privacy.""" + + parsed = urlparse(uri) + return f"{parsed.scheme}://{parsed.netloc.strip('/')}/{parsed.path.strip('/')}" + + def _parse_retry_after(self, error: Any) -> int: + """ + Extract retry-after from Google API errors. + Returns milliseconds to wait before retry. + Defaults to 1000ms if not found. + + Args: + error: The Google client error to parse + + Returns: + The number of milliseconds to wait before retry + """ + if hasattr(error, "resp") and hasattr(error.resp, "headers"): + headers = error.resp.headers + + retry_after = headers.get("Retry-After", headers.get("retry-after")) + if retry_after: + try: + # If it's a number, it's seconds + if retry_after.isdigit(): + return int(retry_after) * 1000 + # Otherwise try to parse as date + dt = datetime.strptime(retry_after, "%a, %d %b %Y %H:%M:%S %Z") + return int((dt - datetime.now(timezone.utc)).total_seconds() * 1000) + except Exception: + # TODO: Log? + return 1000 + + return 1000 + + def _map_http_error(self, error: Any) -> ToolRuntimeError | None: + """Map Google HttpError to appropriate ToolRuntimeError.""" + status_code = error.status_code + reason = str(error.reason) if error.reason else f"HTTP {status_code} error" + + message = f"Upstream Google API error: {reason}" + + developer_message = None + if error.error_details: + # str error details are added to the message + if isinstance(error.error_details, str): + message = f"{message} - Details: {error.error_details}" + else: + # structured error details are added to the developer message + developer_message = f"Upstream Google API error details: {error.error_details}" + + # Build extra metadata + extra = { + "service": self.slug, + } + + # Try to extract request details if available + if hasattr(error, "uri"): + extra["endpoint"] = self._sanitize_uri(error.uri) + if hasattr(error, "method_"): + extra["http_method"] = error.method_.upper() + + # Special case for rate limiting + if status_code == 429: + return UpstreamRateLimitError( + retry_after_ms=self._parse_retry_after(error), + message=message, + developer_message=developer_message, + extra=extra, + ) + + return UpstreamError( + message=message, + status_code=status_code, + developer_message=developer_message, + extra=extra, + ) + + def _handle_http_errors(self, exc: Exception, errors_module: Any) -> ToolRuntimeError | None: + """Handle HttpError and its subclasses.""" + if isinstance(exc, errors_module.HttpError): + return self._map_http_error(exc) + + if isinstance(exc, errors_module.BatchError): + # BatchError might not have status_code, so handle carefully + if hasattr(exc, "resp") and hasattr(exc.resp, "status"): + exc.status_code = exc.resp.status + return self._map_http_error(exc) + else: + # No status code available, treat as server error + extra = { + "service": "google_api", + "error_type": "BatchError", + } + return UpstreamError( + message=f"Upstream Google API batch operation failed: {exc.reason}", + status_code=500, + extra=extra, + ) + return None + + def _handle_other_errors(self, exc: Exception, errors_module: Any) -> ToolRuntimeError | None: + """Handle non-HTTP Google API errors.""" + if isinstance(exc, errors_module.InvalidJsonError): + return UpstreamError( + message="Upstream Google API returned invalid JSON response", + status_code=502, + developer_message=str(exc), + extra={ + "service": self.slug, + "error_type": "InvalidJsonError", + }, + ) + + if isinstance(exc, errors_module.UnknownApiNameOrVersion): + return UpstreamError( + message="Upstream Google API error: Unknown API name or version", + status_code=404, + developer_message=str(exc), + extra={ + "service": self.slug, + "error_type": "UnknownApiNameOrVersion", + }, + ) + + if isinstance(exc, errors_module.UnacceptableMimeTypeError): + return UpstreamError( + message="Upstream Google API error: Unacceptable MIME type for this operation", + status_code=400, + developer_message=str(exc), + extra={ + "service": self.slug, + "error_type": "UnacceptableMimeTypeError", + }, + ) + + if isinstance(exc, errors_module.MediaUploadSizeError): + return UpstreamError( + message="Upstream Google API error: Media file size exceeds allowed limit", + status_code=400, + developer_message=str(exc), + extra={ + "service": self.slug, + "error_type": "MediaUploadSizeError", + }, + ) + + if isinstance(exc, errors_module.InvalidChunkSizeError): + return UpstreamError( + message="Upstream Google API error: Invalid chunk size specified", + developer_message=str(exc), + status_code=400, + extra={ + "service": self.slug, + "error_type": "InvalidChunkSizeError", + }, + ) + + if isinstance(exc, errors_module.InvalidNotificationError): + return UpstreamError( + message="Upstream Google API error: Invalid notification configuration", + developer_message=str(exc), + status_code=400, + extra={ + "service": self.slug, + "error_type": "InvalidNotificationError", + }, + ) + + return None + + def from_exception(self, exc: Exception) -> ToolRuntimeError | None: + """ + Translate a Google API client exception into a ToolRuntimeError. + """ + # Lazy import the Google API client errors module to avoid import errors for toolkits that don't use googleapiclient + try: + from googleapiclient import errors + except ImportError: + logger.info( + f"'googleapiclient' is not installed in the toolkit's environment, " + f"so the '{self.slug}' adapter was not used to handle the upstream error" + ) + return None + + # Try HTTP errors first + result = self._handle_http_errors(exc, errors) + if result: + return result + + # Then try other error types + result = self._handle_other_errors(exc, errors) + if result: + return result + + # Failsafe for any unhandled Google API client errors that are not mapped above + if hasattr(exc, "__module__") and exc.__module__ == "googleapiclient.errors": + return UpstreamError( + message=f"Upstream Google API error: {exc}", + status_code=500, + extra={ + "service": self.slug, + "error_type": exc.__class__.__name__, + }, + ) + + # Not a Google API client error + return None diff --git a/libs/arcade-tdk/arcade_tdk/providers/http/__init__.py b/libs/arcade-tdk/arcade_tdk/providers/http/__init__.py new file mode 100644 index 00000000..705278ec --- /dev/null +++ b/libs/arcade-tdk/arcade_tdk/providers/http/__init__.py @@ -0,0 +1,3 @@ +from arcade_tdk.providers.http.error_adapter import HTTPErrorAdapter + +__all__ = ["HTTPErrorAdapter"] diff --git a/libs/arcade-tdk/arcade_tdk/providers/http/error_adapter.py b/libs/arcade-tdk/arcade_tdk/providers/http/error_adapter.py new file mode 100644 index 00000000..1fa56beb --- /dev/null +++ b/libs/arcade-tdk/arcade_tdk/providers/http/error_adapter.py @@ -0,0 +1,200 @@ +import logging +from datetime import datetime, timezone +from typing import Any +from urllib.parse import urlparse + +from arcade_core.errors import ( + UpstreamError, + UpstreamRateLimitError, +) + +logger = logging.getLogger(__name__) + +RATE_HEADERS = ("retry-after", "x-ratelimit-reset", "x-ratelimit-reset-ms") + + +class BaseHTTPErrorMapper: + """Base class for HTTP error mapping functionality.""" + + def _parse_retry_ms(self, headers: dict[str, str]) -> int: + """ + Parses a rate limit header and returns the number + of milliseconds until the rate limit resets. + + Args: + headers: A dictionary of HTTP headers. + + Returns: + The number of milliseconds until the rate limit resets. + Defaults to 1000ms if a rate limit header is not found or cannot be parsed. + """ + val = next((headers.get(h) for h in RATE_HEADERS if headers.get(h)), None) + # No rate limit header found + if val is None: + return 1_000 + # Rate limit header is a number of seconds + if val.isdigit(): + key = next((h for h in RATE_HEADERS if headers.get(h) == val), "") + if key.endswith("ms"): + return int(val) + return int(val) * 1_000 + # Rate limit header is an absolute date + try: + dt = datetime.strptime(val, "%a, %d %b %Y %H:%M:%S %Z") + return int((dt - datetime.now(timezone.utc)).total_seconds() * 1_000) + except Exception: + logger.warning(f"Failed to parse rate limit header: {val}. Defaulting to 1000ms.") + return 1_000 + + def _sanitize_uri(self, uri: str) -> str: + """Strip query params and fragments from URI for privacy.""" + + parsed = urlparse(uri) + return f"{parsed.scheme}://{parsed.netloc.strip('/')}/{parsed.path.strip('/')}" + + def _build_extra_metadata( + self, request_url: str | None = None, request_method: str | None = None + ) -> dict[str, str]: + """Build extra metadata for error reporting.""" + extra = { + "service": HTTPErrorAdapter.slug, + } + + if request_url: + extra["endpoint"] = self._sanitize_uri(request_url) + + if request_method: + extra["http_method"] = request_method.upper() + + return extra + + def _map_status_to_error( + self, + status: int, + headers: dict[str, str], + msg: str, + request_url: str | None = None, + request_method: str | None = None, + ) -> UpstreamError: + """Map HTTP status code to appropriate Arcade error.""" + extra = self._build_extra_metadata(request_url, request_method) + + # Special case for rate limiting + if status == 429: + return UpstreamRateLimitError( + retry_after_ms=self._parse_retry_ms(headers), + message=msg, + extra=extra, + ) + + return UpstreamError(message=msg, status_code=status, extra=extra) + + +class _HTTPXExceptionHandler: + """Handler for httpx-specific exceptions.""" + + def handle_exception(self, exc: Any, mapper: BaseHTTPErrorMapper) -> UpstreamError | None: + """Handle httpx HTTPStatusError exceptions. + + Args: + exc: An httpx.HTTPStatusError exception + mapper: The BaseHTTPErrorMapper instance to use for mapping + + Returns: + An Arcade error instance or None if not an httpx exception + """ + # Lazy import httpx types locally to avoid import errors for toolkits that don't use httpx + try: + import httpx + except ImportError: + return None + + if not isinstance(exc, httpx.HTTPStatusError): + return None + + response = exc.response + request_url = None + request_method = None + if hasattr(exc, "request") and exc.request: + request_url = str(exc.request.url) + request_method = exc.request.method + + return mapper._map_status_to_error( + response.status_code, + dict(response.headers), + str(exc), + request_url=request_url, + request_method=request_method, + ) + + +class _RequestsExceptionHandler: + """Handler for requests-specific exceptions.""" + + def handle_exception(self, exc: Any, mapper: BaseHTTPErrorMapper) -> UpstreamError | None: + """Handle requests library exceptions. + + Args: + exc: A requests.exceptions.HTTPError exception + mapper: The BaseHTTPErrorMapper instance to use for mapping + + Returns: + An Arcade error instance or None if not a requests exception + """ + # Lazy import requests types locally to avoid import errors for toolkits that don't use requests + try: + from requests.exceptions import HTTPError # type: ignore[import-untyped] + except ImportError: + return None + + if not isinstance(exc, HTTPError): + return None + + response = getattr(exc, "response", None) + if response is None: + return None + + # Extract request information + request_url = None + request_method = None + if hasattr(response, "request") and response.request: + request_url = response.request.url + request_method = response.request.method + elif hasattr(response, "url"): + request_url = response.url + + return mapper._map_status_to_error( + response.status_code, + dict(response.headers), + str(exc), + request_url=request_url, + request_method=request_method, + ) + + +class HTTPErrorAdapter(BaseHTTPErrorMapper): + """Main HTTP error adapter that supports multiple HTTP libraries.""" + + slug = "_http" + + def __init__(self) -> None: + self._httpx_handler = _HTTPXExceptionHandler() + self._requests_handler = _RequestsExceptionHandler() + + def from_exception(self, exc: Exception) -> UpstreamError | None: + """Convert HTTP library exceptions into Arcade errors.""" + + httpx_result = self._httpx_handler.handle_exception(exc, self) + if httpx_result is not None: + return httpx_result + + requests_result = self._requests_handler.handle_exception(exc, self) + if requests_result is not None: + return requests_result + + logger.info( + f"Exception type '{type(exc).__name__}' was not handled by the '{self.slug}' adapter. " + f"Either the exception is not from a supported HTTP library (httpx, requests) or " + f"the required library is not installed in the toolkit's environment." + ) + return None diff --git a/libs/arcade-tdk/arcade_tdk/tool.py b/libs/arcade-tdk/arcade_tdk/tool.py index 5a42abcf..e6fd057d 100644 --- a/libs/arcade-tdk/arcade_tdk/tool.py +++ b/libs/arcade-tdk/arcade_tdk/tool.py @@ -1,21 +1,104 @@ import functools import inspect -from typing import Any, Callable, TypeVar, Union +from typing import Any, Callable, TypeVar from arcade_tdk.auth import ToolAuthorization -from arcade_tdk.errors import ToolExecutionError +from arcade_tdk.error_adapters import ErrorAdapter +from arcade_tdk.error_adapters.utils import get_adapter_for_auth_provider +from arcade_tdk.errors import ( + FatalToolError, + ToolRuntimeError, +) +from arcade_tdk.providers.http import HTTPErrorAdapter from arcade_tdk.utils import snake_to_pascal_case T = TypeVar("T") +def _build_adapter_chain( + adapters: list[ErrorAdapter] | None, auth_provider: ToolAuthorization | None +) -> list[ErrorAdapter]: + """ + Build the adapter chain for error handling. + + Args: + adapters: User-provided list of error adapters + auth_provider: The auth provider for the tool + + Returns: + A deduplicated list of error adapters with the HTTP adapter as fallback + + Raises: + ValueError: If any adapter doesn't follow the ErrorAdapter protocol + """ + adapter_chain = adapters or [] + + # Validate that all adapters follow the ErrorAdapter protocol + if not all(isinstance(adapter, ErrorAdapter) for adapter in adapter_chain): + invalid_adapters = [ + type(adapter).__name__ + for adapter in adapter_chain + if not isinstance(adapter, ErrorAdapter) + ] + raise ValueError( + f"All adapters must follow the ErrorAdapter protocol. " + f"Invalid adapters: {', '.join(invalid_adapters)}" + ) + + # Add the adapter that is mapped to the tool's auth provider if it exists + if auth_adapter := get_adapter_for_auth_provider(auth_provider): + adapter_chain.append(auth_adapter) + + # Always add HTTP adapter as the final adapter fallback + adapter_chain.append(HTTPErrorAdapter()) + + # Remove duplicates from the adapter chain, preserving order + seen_types = set() + deduplicated_chain = [] + for adapter in adapter_chain: + adapter_type = type(adapter) + if adapter_type not in seen_types: + seen_types.add(adapter_type) + deduplicated_chain.append(adapter) + + return deduplicated_chain + + +def _raise_as_arcade_error( + exception: Exception, adapter_chain: list[ErrorAdapter], tool_name: str, func_name: str +) -> None: + """ + Try to translate an exception using the adapter chain, then raise the translated error. + If no adapter can translate the exception, a FatalToolError is raised. + + Args: + exception: The exception to translate to an Arcade Error + adapter_chain: List of error adapters to try + tool_name: The tool's display name for error messages + func_name: The function name for developer messages + + Raises: + ToolRuntimeError or some subclass thereof + """ + for adapter in adapter_chain: + mapped = adapter.from_exception(exception) + if isinstance(mapped, ToolRuntimeError): + raise mapped from exception + + raise FatalToolError( + message=f"{exception!s}", + developer_message=f"{exception!s}", + ) from exception + + def tool( func: Callable | None = None, desc: str | None = None, name: str | None = None, - requires_auth: Union[ToolAuthorization, None] = None, - requires_secrets: Union[list[str], None] = None, - requires_metadata: Union[list[str], None] = None, + requires_auth: ToolAuthorization | None = None, + requires_secrets: list[str] | None = None, + requires_metadata: list[str] | None = None, + adapters: list[ErrorAdapter] | None = None, ) -> Callable: def decorator(func: Callable) -> Callable: func_name = str(getattr(func, "__name__", None)) @@ -27,22 +110,19 @@ def tool( func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined] func.__tool_requires_metadata__ = requires_metadata # type: ignore[attr-defined] + adapter_chain = _build_adapter_chain(adapters, requires_auth) + if inspect.iscoroutinefunction(func): @functools.wraps(func) async def func_with_error_handling(*args: Any, **kwargs: Any) -> Any: try: return await func(*args, **kwargs) - - # make sure developer raised ToolExecutionError is not - # reraised incorrectly. - except ToolExecutionError: + except ToolRuntimeError: + # re-raise as-is if it is already an Arcade Error raise except Exception as e: - raise ToolExecutionError( - message=f"Error in execution of {tool_name}", - developer_message=f"Error in {func_name}: {e!s}", - ) from e + _raise_as_arcade_error(e, adapter_chain, tool_name, func_name) else: @@ -50,13 +130,11 @@ def tool( def func_with_error_handling(*args: Any, **kwargs: Any) -> Any: try: return func(*args, **kwargs) - except ToolExecutionError: + except ToolRuntimeError: + # re-raise as-is if it is already an Arcade Error raise except Exception as e: - raise ToolExecutionError( - message=f"Error in execution of {tool_name}", - developer_message=f"Error in {func_name}: {e!s}", - ) from e + _raise_as_arcade_error(e, adapter_chain, tool_name, func_name) return func_with_error_handling diff --git a/libs/arcade-tdk/pyproject.toml b/libs/arcade-tdk/pyproject.toml index c1b3b405..d7e2cd95 100644 --- a/libs/arcade-tdk/pyproject.toml +++ b/libs/arcade-tdk/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-tdk" -version = "2.2.0" +version = "2.3.0" description = "Arcade TDK - Toolkit Development Kit for building Arcade tools" readme = "README.md" license = {text = "MIT"} @@ -19,7 +19,7 @@ classifiers = [ ] requires-python = ">=3.10" dependencies = [ - "arcade-core>=2.3.0,<3.0.0", + "arcade-core>=2.4.0,<3.0.0", "pydantic>=2.7.0", ] diff --git a/libs/tests/core/test_catalog.py b/libs/tests/core/test_catalog.py index 945bff00..15f16128 100644 --- a/libs/tests/core/test_catalog.py +++ b/libs/tests/core/test_catalog.py @@ -1,11 +1,18 @@ +from typing import Annotated, Union from unittest.mock import MagicMock, patch import pytest from arcade_core.catalog import ToolCatalog -from arcade_core.errors import ToolDefinitionError -from arcade_core.schema import FullyQualifiedName +from arcade_core.errors import ( + ToolDefinitionError, + ToolInputSchemaError, + ToolkitLoadError, + ToolOutputSchemaError, +) +from arcade_core.schema import FullyQualifiedName, ToolContext from arcade_core.toolkit import Toolkit from arcade_tdk import tool +from pydantic import Field @tool @@ -16,6 +23,141 @@ def sample_tool() -> str: return "Hello, world!" +@tool +def valid_tool(input_text: Annotated[str, "The text to process"]) -> str: + """ + A test tool that processes input text. + + Args: + input_text: The text to process + + Returns: + The processed text + """ + return f"Processed: {input_text}" + + +@tool +def tool_with_missing_input_parameter_annotation(input_text: str) -> str: + """ + A test tool that processes input text. + + Args: + input_text: The text to process + + Returns: + The processed text + """ + return f"Processed: {input_text}" + + +# Invalid tool examples for testing error cases + + +# ToolDefinitionError cases +def tool_missing_description(input_text: Annotated[str, "The text to process"]) -> str: + return f"Processed: {input_text}" + + +@tool(requires_secrets=[123]) # type: ignore[misc] +def tool_with_invalid_secret_type(input_text: Annotated[str, "The text"]) -> str: + """A tool with invalid secret type.""" + return f"Processed: {input_text}" + + +@tool(requires_secrets=[""]) +def tool_with_empty_secret(input_text: Annotated[str, "The text"]) -> str: + """A tool with empty secret.""" + return f"Processed: {input_text}" + + +@tool(requires_metadata=[123]) # type: ignore[misc] +def tool_with_invalid_metadata_type(input_text: Annotated[str, "The text"]) -> str: + """A tool with invalid metadata type.""" + return f"Processed: {input_text}" + + +@tool(requires_metadata=["client_id"]) # Requires auth but no auth provided +def tool_with_metadata_requiring_auth_without_auth(input_text: Annotated[str, "The text"]) -> str: + """A tool with metadata requiring auth but no auth provided.""" + return f"Processed: {input_text}" + + +@tool(requires_metadata=[""]) +def tool_with_empty_metadata(input_text: Annotated[str, "The text"]) -> str: + """A tool with empty metadata.""" + return f"Processed: {input_text}" + + +class MyFancyTestClass: + pass + + +@tool +def tool_with_unsupported_param_type( + param: Annotated[MyFancyTestClass, "A class that is a parameter"], +) -> str: + """A tool with unsupported parameter type.""" + return "result" + + +# ToolInputSchemaError cases +@tool +def tool_with_no_type_annotation(param) -> str: # type: ignore[no-untyped-def] + """A tool with untyped parameter.""" + return f"Result: {param}" + + +@tool +def tool_with_invalid_param_name(param: Annotated[str, "123invalid", "Description"]) -> str: + """A tool with invalid parameter name.""" + return f"Result: {param}" + + +@tool +def tool_with_too_many_annotations(param: Annotated[str, "name", "desc", "extra"]) -> str: + """A tool with an input parameter that has too many annotations.""" + return f"Result: {param}" + + +@tool +def tool_with_required_union_param(param: Annotated[Union[str, int], "Union parameter"]) -> str: + """A tool with an input parameter that is a non-optional union type.""" + return f"Result: {param}" + + +def non_callable_factory(): + raise RuntimeError("This should not be called") + + +@tool +def tool_with_non_callable_default_factory( + param: Annotated[str, "Parameter"] = Field(default_factory="not_callable"), # type: ignore[arg-type] +) -> str: + """A tool with an input parameter that has a non-callable default factory.""" + return f"Result: {param}" + + +@tool +def tool_with_multiple_tool_contexts(ctx1: ToolContext, ctx2: ToolContext) -> str: + """A tool with multiple input parameters that are ToolContext.""" + return "result" + + +@tool +def tool_missing_return_type_hint(input_text: Annotated[str, "The text to process"]): + """A tool without return type hint.""" + return f"Processed: {input_text}" + + +@tool +def tool_with_unsupported_output_type( + input_text: Annotated[str, "The text to process"], +) -> Annotated[MyFancyTestClass, "THe output type"]: + """A tool with an output parameter type that is not supported.""" + return MyFancyTestClass() + + def test_add_tool_with_empty_toolkit_name_raises(): catalog = ToolCatalog() with pytest.raises(ValueError): @@ -98,6 +240,35 @@ def test_add_toolkit_type_error(): catalog.add_toolkit(mock_toolkit) +def test_add_toolkit_import_module_error(): + catalog = ToolCatalog() + + # Create a mock toolkit with an invalid tool + + mock_toolkit = Toolkit( + name="mock_toolkit", + description="A mock toolkit", + version="0.0.1", + package_name="mock_toolkit", + ) + mock_toolkit.tools = {"mock_module": ["sample_tool"]} + + # Mock the import_module and getattr functions + with ( + patch("arcade_core.catalog.import_module") as mock_import, + ): + mock_import.side_effect = ImportError("Mock import error") + + # Assert that ToolkitLoadError is raised + with pytest.raises(ToolkitLoadError) as exc_info: + catalog.add_toolkit(mock_toolkit) + + # Check that the error message contains the expected substring + assert "Could not import module mock_module. Reason: Mock import error" in str( + exc_info.value + ) + + def test_get_tool_by_name(): catalog = ToolCatalog() catalog.add_tool(sample_tool, "sample_toolkit") @@ -190,3 +361,157 @@ def test_add_tool_with_disabled_toolkit(monkeypatch): ) ) assert len(catalog._tools) == 0 + + +@pytest.mark.parametrize( + "tool_name, expected_error_type, expected_error_substring", + [ + # ToolDefinitionError cases + ( + "tool_missing_description", + ToolDefinitionError, + "Tool 'tool_missing_description' is missing a description", + ), + ( + "tool_with_invalid_secret_type", + ToolDefinitionError, + "Secret keys must be strings (error in tool ToolWithInvalidSecretType)", + ), + ( + "tool_with_empty_secret", + ToolDefinitionError, + "Secrets must have a non-empty key (error in tool ToolWithEmptySecret)", + ), + ( + "tool_with_invalid_metadata_type", + ToolDefinitionError, + "Metadata must be strings (error in tool ToolWithInvalidMetadataType)", + ), + ( + "tool_with_metadata_requiring_auth_without_auth", + ToolDefinitionError, + "Tool ToolWithMetadataRequiringAuthWithoutAuth declares metadata key 'client_id'", + ), + ( + "tool_with_empty_metadata", + ToolDefinitionError, + "Metadata must have a non-empty key (error in tool ToolWithEmptyMetadata)", + ), + ( + "tool_with_unsupported_param_type", + ToolDefinitionError, + "Unsupported parameter type: ", + ), + # ToolInputSchemaError cases + ( + "tool_with_missing_input_parameter_annotation", + ToolInputSchemaError, + "Parameter 'input_text' is missing a description", + ), + ( + "tool_with_no_type_annotation", + ToolInputSchemaError, + "Parameter param has no type annotation", + ), + ( + "tool_with_invalid_param_name", + ToolInputSchemaError, + "Invalid parameter name: '123invalid' is not a valid identifier", + ), + ( + "tool_with_too_many_annotations", + ToolInputSchemaError, + "Parameter param: Annotated[str, 'name', 'desc', 'extra'] has too many string annotations. Expected 0, 1, or 2, got 3", + ), + ( + "tool_with_required_union_param", + ToolInputSchemaError, + "Parameter param is a union type. Only optional types are supported", + ), + ( + "tool_with_non_callable_default_factory", + ToolInputSchemaError, + "Default factory for parameter param: Annotated[str, 'Parameter'] = FieldInfo(annotation=NoneType, required=False, default_factory=str) is not callable.", + ), + ( + "tool_with_multiple_tool_contexts", + ToolInputSchemaError, + "Only one ToolContext parameter is supported, but tool tool_with_multiple_tool_contexts has multiple", + ), + ( + "tool_missing_return_type_hint", + ToolOutputSchemaError, + "Tool 'ToolMissingReturnTypeHint' must have a return type", + ), + ( + "tool_with_unsupported_output_type", + ToolOutputSchemaError, + "Unsupported output type ''", + ), + ], +) +def test_add_toolkit_with_invalid_tools( + tool_name: str, expected_error_type: type, expected_error_substring: str +): + """Test that add_toolkit raises the correct error for various invalid tool definitions.""" + catalog = ToolCatalog() + + # Create a toolkit that references our test tool + test_toolkit = Toolkit( + name="test_toolkit", + description="A test toolkit", + version="1.0.0", + package_name="test_toolkit", + ) + test_toolkit.tools = {"tests.core.test_catalog": [tool_name]} + + # Mock the import_module to return the current module + import sys + + current_module = sys.modules[__name__] + + with patch("arcade_core.catalog.import_module") as mock_import: + mock_import.return_value = current_module + + # Add the toolkit and expect the specified error + with pytest.raises(expected_error_type) as exc_info: + catalog.add_toolkit(test_toolkit) + + # Check that the error message contains the expected substring + actual_error_message = str(exc_info.value) + # Adjust for Python 3.11 and below where Annotated is returned as "typing.Annotated" + if "typing.Annotated" in actual_error_message: + expected_error_substring = expected_error_substring.replace( + "Annotated", "typing.Annotated" + ) + assert expected_error_substring in actual_error_message + + +def test_add_toolkit_with_duplicate_tool(): + """Test that add_toolkit raises ToolkitLoadError when a tool already exists in the catalog.""" + catalog = ToolCatalog() + + test_toolkit = Toolkit( + name="test_toolkit", + description="A test toolkit", + version="1.0.0", + package_name="test_toolkit", + ) + test_toolkit.tools = {"tests.core.test_catalog": ["valid_tool", "valid_tool"]} + + # Mock the import_module to return the current module + import sys + + current_module = sys.modules[__name__] + + with patch("arcade_core.catalog.import_module") as mock_import: + mock_import.return_value = current_module + + # Adding the toolkit should raise ToolkitLoadError for duplicate tool + with pytest.raises(ToolkitLoadError) as exc_info: + catalog.add_toolkit(test_toolkit) + + # Check that the error message contains the expected substring + assert "Tool 'ValidTool' in toolkit 'test_toolkit' already exists in the catalog." in str( + exc_info.value + ) diff --git a/libs/tests/core/test_executor.py b/libs/tests/core/test_executor.py index 330d98c6..fa896349 100644 --- a/libs/tests/core/test_executor.py +++ b/libs/tests/core/test_executor.py @@ -2,10 +2,20 @@ from typing import Annotated import pytest from arcade_core.catalog import ToolCatalog +from arcade_core.errors import ( + ContextRequiredToolError, + ErrorKind, + ToolRuntimeError, + UpstreamError, + UpstreamRateLimitError, +) from arcade_core.executor import ToolExecutor from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput, ToolContext from arcade_tdk import tool -from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from arcade_tdk.errors import ( + RetryableToolError, + ToolExecutionError, +) from typing_extensions import TypedDict @@ -25,13 +35,13 @@ def simple_deprecated_tool(inp: Annotated[str, "input"]) -> Annotated[str, "outp @tool def retryable_error_tool() -> Annotated[str, "output"]: """Tool that raises a retryable error""" - raise RetryableToolError("test", "test", "test", 1000) + raise RetryableToolError("test", "test developer message", "additional prompt content", 1000) @tool -def exec_error_tool() -> Annotated[str, "output"]: +def tool_execution_error_tool() -> Annotated[str, "output"]: """Tool that raises an error""" - raise ToolExecutionError("test", "test") + raise ToolExecutionError("test", "test developer message") @tool @@ -40,6 +50,34 @@ def unexpected_error_tool() -> Annotated[str, "output"]: raise RuntimeError("test") +@tool +def context_required_error_tool() -> Annotated[str, "output"]: + """Tool that raises a context required error""" + raise ContextRequiredToolError( + "test", additional_prompt_content="need the user to clarify something" + ) + + +@tool +def upstream_error_tool() -> Annotated[str, "output"]: + """Tool that raises an upstream error""" + # TODO: or test raising a httpx error? Do these types of tests belong in adapter tests? + raise UpstreamError("test", status_code=400) + + +@tool +def upstream_ratelimit_error_tool() -> Annotated[str, "output"]: + """Tool that raises an upstream error""" + # TODO: or test raising a httpx error? Do these types of tests belong in adapter tests? + raise UpstreamRateLimitError("test", 1000) + + +@tool +def tool_runtime_error_tool() -> Annotated[str, "output"]: + """Tool that raises a tool runtime error""" + raise ToolRuntimeError("test", "test developer message") + + @tool def bad_output_error_tool() -> Annotated[str, "output"]: """tool that returns a bad output type""" @@ -77,17 +115,24 @@ def dict_output_tool() -> Annotated[dict, "Returns a plain dict"]: # ---- Test Driver ---- - +tools = [ + simple_tool, + simple_deprecated_tool, + retryable_error_tool, + tool_execution_error_tool, + unexpected_error_tool, + context_required_error_tool, + upstream_error_tool, + upstream_ratelimit_error_tool, + tool_runtime_error_tool, + bad_output_error_tool, + typeddict_output_tool, + list_typeddict_output_tool, + dict_output_tool, +] catalog = ToolCatalog() -catalog.add_tool(simple_tool, "simple_toolkit") -catalog.add_tool(simple_deprecated_tool, "simple_toolkit") -catalog.add_tool(retryable_error_tool, "simple_toolkit") -catalog.add_tool(exec_error_tool, "simple_toolkit") -catalog.add_tool(unexpected_error_tool, "simple_toolkit") -catalog.add_tool(bad_output_error_tool, "simple_toolkit") -catalog.add_tool(typeddict_output_tool, "simple_toolkit") -catalog.add_tool(list_typeddict_output_tool, "simple_toolkit") -catalog.add_tool(dict_output_tool, "simple_toolkit") +for tool_func in tools: + catalog.add_tool(tool_func, "simple_toolkit") @pytest.mark.asyncio @@ -114,21 +159,24 @@ catalog.add_tool(dict_output_tool, "simple_toolkit") {}, ToolCallOutput( error=ToolCallError( - message="test", - developer_message="test", - additional_prompt_content="test", + message="[TOOL_RUNTIME_RETRY] RetryableToolError during execution of tool 'retryable_error_tool': test", + kind=ErrorKind.TOOL_RUNTIME_RETRY, + developer_message="[TOOL_RUNTIME_RETRY] RetryableToolError during execution of tool 'retryable_error_tool': test developer message", + additional_prompt_content="additional prompt content", retry_after_ms=1000, can_retry=True, ) ), ), ( - exec_error_tool, + tool_execution_error_tool, {}, ToolCallOutput( error=ToolCallError( - message="test", - developer_message="test", + message="[TOOL_RUNTIME_FATAL] ToolExecutionError during execution of tool 'tool_execution_error_tool': test", + kind=ErrorKind.TOOL_RUNTIME_FATAL, + developer_message="[TOOL_RUNTIME_FATAL] ToolExecutionError during execution of tool 'tool_execution_error_tool': test developer message", + can_retry=False, ) ), ), @@ -137,8 +185,11 @@ catalog.add_tool(dict_output_tool, "simple_toolkit") {}, ToolCallOutput( error=ToolCallError( - message="Error in execution of UnexpectedErrorTool", - developer_message="Error in unexpected_error_tool: test", + message="[TOOL_RUNTIME_FATAL] FatalToolError during execution of tool 'unexpected_error_tool': test", + kind=ErrorKind.TOOL_RUNTIME_FATAL, + developer_message="[TOOL_RUNTIME_FATAL] FatalToolError during execution of tool 'unexpected_error_tool': test", + can_retry=False, + status_code=500, ) ), ), @@ -147,17 +198,71 @@ catalog.add_tool(dict_output_tool, "simple_toolkit") {"inp": {"test": "test"}}, # takes in a string not a dict ToolCallOutput( error=ToolCallError( - message="Error in tool input deserialization", + message="[TOOL_RUNTIME_BAD_INPUT_VALUE] ToolInputError during execution of tool 'simple_tool': Error in tool input deserialization", + kind=ErrorKind.TOOL_RUNTIME_BAD_INPUT_VALUE, + status_code=400, developer_message=None, # can't gaurantee this will be the same ) ), ), + ( + context_required_error_tool, + {}, + ToolCallOutput( + error=ToolCallError( + message="[TOOL_RUNTIME_CONTEXT_REQUIRED] ContextRequiredToolError during execution of tool 'context_required_error_tool': test", + kind=ErrorKind.TOOL_RUNTIME_CONTEXT_REQUIRED, + developer_message=None, + additional_prompt_content="need the user to clarify something", + ) + ), + ), + ( + upstream_error_tool, + {}, + ToolCallOutput( + error=ToolCallError( + message="[UPSTREAM_RUNTIME_BAD_REQUEST] UpstreamError during execution of tool 'upstream_error_tool': test", + kind=ErrorKind.UPSTREAM_RUNTIME_BAD_REQUEST, + status_code=400, + developer_message=None, + ) + ), + ), + ( + upstream_ratelimit_error_tool, + {}, + ToolCallOutput( + error=ToolCallError( + message="[UPSTREAM_RUNTIME_RATE_LIMIT] UpstreamRateLimitError during execution of tool 'upstream_ratelimit_error_tool': test", + kind=ErrorKind.UPSTREAM_RUNTIME_RATE_LIMIT, + status_code=429, + developer_message=None, + retry_after_ms=1000, + can_retry=True, + ) + ), + ), + ( + tool_runtime_error_tool, + {}, + ToolCallOutput( + error=ToolCallError( + message="[TOOL_RUNTIME_FATAL] ToolRuntimeError during execution of tool 'tool_runtime_error_tool': test", + kind=ErrorKind.TOOL_RUNTIME_FATAL, + developer_message="[TOOL_RUNTIME_FATAL] ToolRuntimeError during execution of tool 'tool_runtime_error_tool': test developer message", + can_retry=False, + ) + ), + ), ( bad_output_error_tool, {}, ToolCallOutput( error=ToolCallError( - message="Failed to serialize tool output", + message="[TOOL_RUNTIME_BAD_OUTPUT_VALUE] ToolOutputError during execution of tool 'bad_output_error_tool': Failed to serialize tool output", + kind=ErrorKind.TOOL_RUNTIME_BAD_OUTPUT_VALUE, + status_code=500, developer_message=None, # can't gaurantee this will be the same ) ), @@ -190,6 +295,10 @@ catalog.add_tool(dict_output_tool, "simple_toolkit") "exec_error_tool", "unexpected_error_tool", "invalid_input_type", + "context_required_error_tool", + "upstream_error_tool", + "upstream_ratelimit_error_tool", + "tool_runtime_error_tool", "bad_output_type", "typeddict_output", "list_typeddict_output", @@ -213,20 +322,28 @@ async def test_tool_executor(tool_func, inputs, expected_output): check_output(output, expected_output) -def check_output(output: ToolCallOutput, expected_output: ToolCallOutput): - # execution error in tool - if output.error: - assert output.error.message == expected_output.error.message - if expected_output.error.developer_message: - assert output.error.developer_message == expected_output.error.developer_message - if expected_output.error.traceback_info: - assert output.error.traceback_info == expected_output.error.traceback_info - assert output.error.can_retry == expected_output.error.can_retry +def check_output_error(output_error: ToolCallError, expected_error: ToolCallError): + assert output_error.message == expected_error.message, "message mismatch" + assert output_error.kind == expected_error.kind, "kind mismatch" + if expected_error.developer_message: assert ( - output.error.additional_prompt_content - == expected_output.error.additional_prompt_content - ) - assert output.error.retry_after_ms == expected_output.error.retry_after_ms + output_error.developer_message == expected_error.developer_message + ), "developer message mismatch" + assert output_error.can_retry == expected_error.can_retry, "can retry mismatch" + assert ( + output_error.additional_prompt_content == expected_error.additional_prompt_content + ), "additional prompt content mismatch" + assert output_error.retry_after_ms == expected_error.retry_after_ms, "retry after ms mismatch" + if expected_error.stacktrace: + assert output_error.stacktrace == expected_error.stacktrace, "stacktrace mismatch" + assert output_error.status_code == expected_error.status_code, "status code mismatch" + assert output_error.extra == expected_error.extra, "extra mismatch" + + +def check_output(output: ToolCallOutput, expected_output: ToolCallOutput): + # error in ToolCallOutput + if output.error: + check_output_error(output.error, expected_output.error) # normal tool execution else: diff --git a/libs/tests/core/test_schema_validation.py b/libs/tests/core/test_schema_validation.py index d0070419..6458d5bb 100644 --- a/libs/tests/core/test_schema_validation.py +++ b/libs/tests/core/test_schema_validation.py @@ -3,6 +3,7 @@ Tests for ToolCallOutput schema validation with complex types. """ import pytest +from arcade_core.errors import ErrorKind from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput from pydantic import ValidationError @@ -139,6 +140,7 @@ class TestToolCallOutputValidation: message="Partial failure", developer_message="Some items failed to process", can_retry=True, + kind=ErrorKind.TOOL_RUNTIME_RETRY, ) ) assert output.error.message == "Partial failure" diff --git a/libs/tests/sdk/test_google_adapter.py b/libs/tests/sdk/test_google_adapter.py new file mode 100644 index 00000000..33beed89 --- /dev/null +++ b/libs/tests/sdk/test_google_adapter.py @@ -0,0 +1,510 @@ +from datetime import datetime, timezone +from unittest.mock import Mock, patch + +from arcade_core.errors import UpstreamError, UpstreamRateLimitError +from arcade_tdk.providers.google.error_adapter import GoogleErrorAdapter + + +class TestGoogleErrorAdapter: + """Test the Google error adapter functionality.""" + + def setup_method(self): + self.adapter = GoogleErrorAdapter() + + def _create_mock_errors_module(self): + """Create a mock errors module with all necessary error classes.""" + + class MockHttpError(Exception): + pass + + class MockBatchError(Exception): + pass + + class MockInvalidJsonError(Exception): + pass + + class MockUnknownApiNameOrVersion(Exception): + pass + + class MockUnacceptableMimeTypeError(Exception): + pass + + class MockMediaUploadSizeError(Exception): + pass + + class MockInvalidChunkSizeError(Exception): + pass + + class MockInvalidNotificationError(Exception): + pass + + mock_errors = Mock() + mock_errors.HttpError = MockHttpError + mock_errors.BatchError = MockBatchError + mock_errors.InvalidJsonError = MockInvalidJsonError + mock_errors.UnknownApiNameOrVersion = MockUnknownApiNameOrVersion + mock_errors.UnacceptableMimeTypeError = MockUnacceptableMimeTypeError + mock_errors.MediaUploadSizeError = MockMediaUploadSizeError + mock_errors.InvalidChunkSizeError = MockInvalidChunkSizeError + mock_errors.InvalidNotificationError = MockInvalidNotificationError + + return mock_errors + + def test_adapter_slug(self): + """Test that the adapter has the correct slug.""" + assert GoogleErrorAdapter.slug == "_google_api_client" + + def test_sanitize_uri_removes_query_params(self): + """Test URI sanitization removes query parameters.""" + uri = "https://www.googleapis.com/drive/v3/files/123?key=secret&fields=id,name" + result = self.adapter._sanitize_uri(uri) + assert result == "https://www.googleapis.com/drive/v3/files/123" + + def test_sanitize_uri_removes_fragments(self): + """Test URI sanitization removes fragments.""" + uri = "https://www.googleapis.com/gmail/v1/users/me/messages#inbox" + result = self.adapter._sanitize_uri(uri) + assert result == "https://www.googleapis.com/gmail/v1/users/me/messages" + + def test_sanitize_uri_handles_trailing_slashes(self): + """Test URI sanitization handles trailing slashes.""" + uri = "https://www.googleapis.com///sheets/v4/spreadsheets///" + result = self.adapter._sanitize_uri(uri) + assert result == "https://www.googleapis.com/sheets/v4/spreadsheets" + + def test_parse_retry_after_with_seconds(self): + """Test parsing retry-after header with seconds value.""" + mock_error = Mock() + mock_error.resp = Mock() + mock_error.resp.headers = {"Retry-After": "120"} + + result = self.adapter._parse_retry_after(mock_error) + assert result == 120_000 + + def test_parse_retry_after_with_lowercase_header(self): + """Test parsing retry-after header with lowercase key.""" + mock_error = Mock() + mock_error.resp = Mock() + mock_error.resp.headers = {"retry-after": "60"} + + result = self.adapter._parse_retry_after(mock_error) + assert result == 60_000 + + def test_parse_retry_after_with_date_format(self): + """Test parsing retry-after header with absolute date format.""" + future_date = "Mon, 01 Jan 2029 12:00:00 GMT" + mock_error = Mock() + mock_error.resp = Mock() + mock_error.resp.headers = {"Retry-After": future_date} + + with patch("arcade_tdk.providers.google.error_adapter.datetime") as mock_datetime: + parsed_date = datetime(2029, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + mock_datetime.strptime.return_value = parsed_date + + # Mock datetime.now() to return a time before the parsed date + current_time = datetime(2029, 1, 1, 11, 58, 0, tzinfo=timezone.utc) + mock_datetime.now.return_value = current_time + mock_datetime.timezone = timezone + + result = self.adapter._parse_retry_after(mock_error) + assert result == 120_000 # 2 minute diff + + def test_parse_retry_after_no_headers(self): + """Test parsing retry-after when no headers are present.""" + mock_error = Mock() + mock_error.resp = Mock() + mock_error.resp.headers = {} + + result = self.adapter._parse_retry_after(mock_error) + assert result == 1_000 + + def test_parse_retry_after_no_resp_attribute(self): + """Test parsing retry-after when error has no resp attribute.""" + mock_error = Mock() + del mock_error.resp + + result = self.adapter._parse_retry_after(mock_error) + assert result == 1_000 # defaults to 1 second + + def test_parse_retry_after_invalid_date(self): + """Test parsing retry-after with invalid date format falls back to default.""" + mock_error = Mock() + mock_error.resp = Mock() + mock_error.resp.headers = {"Retry-After": "invalid-date"} + + result = self.adapter._parse_retry_after(mock_error) + assert result == 1_000 + + def test_map_http_error_basic(self): + """Test mapping basic HTTP error.""" + mock_error = Mock() + mock_error.status_code = 404 + mock_error.reason = "Not Found" + mock_error.error_details = None + mock_error.uri = "https://www.googleapis.com/drive/v3/files/missing" + mock_error.method_ = "get" + + result = self.adapter._map_http_error(mock_error) + + assert isinstance(result, UpstreamError) + assert not isinstance(result, UpstreamRateLimitError) + assert result.status_code == 404 + assert result.message == "Upstream Google API error: Not Found" + assert result.extra["service"] == "_google_api_client" + assert result.extra["endpoint"] == "https://www.googleapis.com/drive/v3/files/missing" + assert result.extra["http_method"] == "GET" + + def test_map_http_error_with_string_details(self): + """Test mapping HTTP error with string error details.""" + mock_error = Mock() + mock_error.status_code = 400 + mock_error.reason = "Bad Request" + mock_error.error_details = "Invalid field value" + mock_error.uri = "https://www.googleapis.com/sheets/v4/spreadsheets" + mock_error.method_ = "post" + + result = self.adapter._map_http_error(mock_error) + + assert isinstance(result, UpstreamError) + assert result.status_code == 400 + assert "Invalid field value" in result.message + assert result.extra["service"] == "_google_api_client" + assert result.extra["http_method"] == "POST" + + def test_map_http_error_with_structured_details(self): + """Test mapping HTTP error with structured error details.""" + mock_error = Mock() + mock_error.status_code = 403 + mock_error.reason = "Forbidden" + mock_error.error_details = {"error": {"code": 403, "message": "Insufficient permissions"}} + mock_error.uri = "https://www.googleapis.com/drive/v3/files" + mock_error.method_ = "delete" + + result = self.adapter._map_http_error(mock_error) + + assert isinstance(result, UpstreamError) + assert result.status_code == 403 + assert result.message == "Upstream Google API error: Forbidden" + assert "Upstream Google API error details" in result.developer_message + assert result.extra["http_method"] == "DELETE" + + def test_map_http_error_rate_limit(self): + """Test mapping 429 rate limit error.""" + mock_error = Mock() + mock_error.status_code = 429 + mock_error.reason = "Too Many Requests" + mock_error.error_details = None + mock_error.uri = "https://www.googleapis.com/gmail/v1/users/me/messages" + mock_error.method_ = "get" + mock_error.resp = Mock() + mock_error.resp.headers = {"Retry-After": "30"} + + result = self.adapter._map_http_error(mock_error) + + assert isinstance(result, UpstreamRateLimitError) + assert result.retry_after_ms == 30_000 + assert result.message == "Upstream Google API error: Too Many Requests" + assert result.extra["service"] == "_google_api_client" + assert result.extra["endpoint"] == "https://www.googleapis.com/gmail/v1/users/me/messages" + assert result.extra["http_method"] == "GET" + + def test_map_http_error_no_reason(self): + """Test mapping HTTP error with no reason.""" + mock_error = Mock() + mock_error.status_code = 500 + mock_error.reason = None + mock_error.error_details = None + mock_error.uri = "https://www.googleapis.com/calendar/v3/calendars" + mock_error.method_ = "post" + + result = self.adapter._map_http_error(mock_error) + + assert isinstance(result, UpstreamError) + assert result.status_code == 500 + assert result.message == "Upstream Google API error: HTTP 500 error" + + def test_map_http_error_missing_attributes(self): + """Test mapping HTTP error without uri and method attributes.""" + mock_error = Mock() + mock_error.status_code = 503 + mock_error.reason = "Service Unavailable" + mock_error.error_details = None + + # Remove uri and method_ attributes + if hasattr(mock_error, "uri"): + del mock_error.uri + if hasattr(mock_error, "method_"): + del mock_error.method_ + + result = self.adapter._map_http_error(mock_error) + + assert isinstance(result, UpstreamError) + assert result.status_code == 503 + assert result.extra["service"] == "_google_api_client" + assert "endpoint" not in result.extra + assert "http_method" not in result.extra + + def test_handle_http_errors_with_http_error(self): + """Test handling HttpError exceptions.""" + mock_errors = self._create_mock_errors_module() + + # Create mock error instance + mock_error = mock_errors.HttpError() + mock_error.status_code = 401 + mock_error.reason = "Unauthorized" + mock_error.error_details = None + mock_error.uri = "https://www.googleapis.com/drive/v3/files" + mock_error.method_ = "get" + + result = self.adapter._handle_http_errors(mock_error, mock_errors) + + assert isinstance(result, UpstreamError) + assert result.status_code == 401 + assert result.message == "Upstream Google API error: Unauthorized" + + def test_handle_http_errors_with_batch_error_with_status(self): + """Test handling BatchError with response status.""" + mock_errors = self._create_mock_errors_module() + + # Create mock error instance + mock_error = mock_errors.BatchError() + mock_error.reason = "Batch operation failed" + mock_error.error_details = None + mock_error.resp = Mock() + mock_error.resp.status = 400 + + result = self.adapter._handle_http_errors(mock_error, mock_errors) + + assert isinstance(result, UpstreamError) + assert result.status_code == 400 + assert result.message == "Upstream Google API error: Batch operation failed" + + def test_handle_http_errors_with_batch_error_no_status(self): + """Test handling BatchError without response status.""" + mock_errors = self._create_mock_errors_module() + + # Create mock error instance + mock_error = mock_errors.BatchError() + mock_error.reason = "Batch operation failed" + + result = self.adapter._handle_http_errors(mock_error, mock_errors) + + assert isinstance(result, UpstreamError) + assert result.status_code == 500 + assert ( + result.message == "Upstream Google API batch operation failed: Batch operation failed" + ) + assert result.extra["service"] == "google_api" + assert result.extra["error_type"] == "BatchError" + + def test_handle_http_errors_unhandled_exception(self): + """Test handling non-HTTP exceptions returns None.""" + mock_errors = self._create_mock_errors_module() + + # Create a non-HTTP exception + mock_error = ValueError("Not an HTTP error") + + result = self.adapter._handle_http_errors(mock_error, mock_errors) + assert result is None + + def test_handle_other_errors_invalid_json_error(self): + """Test handling InvalidJsonError.""" + mock_errors = self._create_mock_errors_module() + mock_error = mock_errors.InvalidJsonError("Invalid JSON response") + + result = self.adapter._handle_other_errors(mock_error, mock_errors) + + assert isinstance(result, UpstreamError) + assert result.status_code == 502 + assert result.message == "Upstream Google API returned invalid JSON response" + assert result.developer_message == "Invalid JSON response" + assert result.extra["service"] == "_google_api_client" + assert result.extra["error_type"] == "InvalidJsonError" + + def test_handle_other_errors_unknown_api_name_or_version(self): + """Test handling UnknownApiNameOrVersion.""" + mock_errors = self._create_mock_errors_module() + mock_error = mock_errors.UnknownApiNameOrVersion("Unknown API: nonexistent/v1") + + result = self.adapter._handle_other_errors(mock_error, mock_errors) + + assert isinstance(result, UpstreamError) + assert result.status_code == 404 + assert result.message == "Upstream Google API error: Unknown API name or version" + assert result.developer_message == "Unknown API: nonexistent/v1" + assert result.extra["error_type"] == "UnknownApiNameOrVersion" + + def test_handle_other_errors_unacceptable_mime_type_error(self): + """Test handling UnacceptableMimeTypeError.""" + mock_errors = self._create_mock_errors_module() + mock_error = mock_errors.UnacceptableMimeTypeError("MIME type not supported") + + result = self.adapter._handle_other_errors(mock_error, mock_errors) + + assert isinstance(result, UpstreamError) + assert result.status_code == 400 + assert ( + result.message == "Upstream Google API error: Unacceptable MIME type for this operation" + ) + assert result.developer_message == "MIME type not supported" + assert result.extra["error_type"] == "UnacceptableMimeTypeError" + + def test_handle_other_errors_media_upload_size_error(self): + """Test handling MediaUploadSizeError.""" + mock_errors = self._create_mock_errors_module() + mock_error = mock_errors.MediaUploadSizeError("File size exceeds 5GB limit") + + result = self.adapter._handle_other_errors(mock_error, mock_errors) + + assert isinstance(result, UpstreamError) + assert result.status_code == 400 + assert result.message == "Upstream Google API error: Media file size exceeds allowed limit" + assert result.developer_message == "File size exceeds 5GB limit" + assert result.extra["error_type"] == "MediaUploadSizeError" + + def test_handle_other_errors_invalid_chunk_size_error(self): + """Test handling InvalidChunkSizeError.""" + mock_errors = self._create_mock_errors_module() + mock_error = mock_errors.InvalidChunkSizeError("Chunk size must be multiple of 256KB") + + result = self.adapter._handle_other_errors(mock_error, mock_errors) + + assert isinstance(result, UpstreamError) + assert result.status_code == 400 + assert result.message == "Upstream Google API error: Invalid chunk size specified" + assert result.developer_message == "Chunk size must be multiple of 256KB" + assert result.extra["error_type"] == "InvalidChunkSizeError" + + def test_handle_other_errors_invalid_notification_error(self): + """Test handling InvalidNotificationError.""" + mock_errors = self._create_mock_errors_module() + mock_error = mock_errors.InvalidNotificationError("Invalid webhook URL") + + result = self.adapter._handle_other_errors(mock_error, mock_errors) + + assert isinstance(result, UpstreamError) + assert result.status_code == 400 + assert result.message == "Upstream Google API error: Invalid notification configuration" + assert result.developer_message == "Invalid webhook URL" + assert result.extra["error_type"] == "InvalidNotificationError" + + def test_handle_other_errors_unhandled_exception(self): + """Test handling non-Google API exceptions returns None.""" + mock_errors = self._create_mock_errors_module() + + # Create a non-Google API exception + mock_error = ValueError("Not a Google API error") + + result = self.adapter._handle_other_errors(mock_error, mock_errors) + assert result is None + + def test_from_exception_googleapiclient_not_installed(self, caplog): + """Test handling when googleapiclient is not installed.""" + with ( + patch("arcade_tdk.providers.google.error_adapter.logger") as mock_logger, + patch.dict("sys.modules", {"googleapiclient": None}), + patch( + "builtins.__import__", + side_effect=ImportError("No module named 'googleapiclient'"), + ), + ): + mock_exc = Exception("test exception") + result = self.adapter.from_exception(mock_exc) + + assert result is None + mock_logger.info.assert_called_once() + warning_message = mock_logger.info.call_args[0][0] + assert "'googleapiclient' is not installed" in warning_message + assert "_google_api_client" in warning_message + + def test_from_exception_http_error_handling(self): + """Test full from_exception flow with HTTP error.""" + mock_errors = self._create_mock_errors_module() + + # Create mock error instance + mock_error = mock_errors.HttpError() + mock_error.status_code = 403 + mock_error.reason = "Forbidden" + mock_error.error_details = None + mock_error.uri = "https://www.googleapis.com/drive/v3/files" + mock_error.method_ = "get" + + # Create mock googleapiclient module + mock_googleapiclient = Mock() + mock_googleapiclient.errors = mock_errors + + with patch.dict( + "sys.modules", + {"googleapiclient": mock_googleapiclient, "googleapiclient.errors": mock_errors}, + ): + result = self.adapter.from_exception(mock_error) + + assert isinstance(result, UpstreamError) + assert result.status_code == 403 + assert result.message == "Upstream Google API error: Forbidden" + + def test_from_exception_other_error_handling(self): + """Test full from_exception flow with other error types.""" + mock_errors = self._create_mock_errors_module() + mock_error = mock_errors.InvalidJsonError("Invalid JSON") + + # Create mock googleapiclient module + mock_googleapiclient = Mock() + mock_googleapiclient.errors = mock_errors + + with patch.dict( + "sys.modules", + {"googleapiclient": mock_googleapiclient, "googleapiclient.errors": mock_errors}, + ): + result = self.adapter.from_exception(mock_error) + + assert isinstance(result, UpstreamError) + assert result.status_code == 502 + assert result.message == "Upstream Google API returned invalid JSON response" + + def test_from_exception_fallback_for_unhandled_google_error(self): + """Test fallback handling for unhandled Google API errors.""" + mock_errors = self._create_mock_errors_module() + + # Create an unhandled Google API error + class MockUnhandledError(Exception): + pass + + mock_error = MockUnhandledError("Some unhandled Google error") + mock_error.__module__ = "googleapiclient.errors" + + # Create mock googleapiclient module + mock_googleapiclient = Mock() + mock_googleapiclient.errors = mock_errors + + with patch.dict( + "sys.modules", + {"googleapiclient": mock_googleapiclient, "googleapiclient.errors": mock_errors}, + ): + result = self.adapter.from_exception(mock_error) + + assert isinstance(result, UpstreamError) + assert result.status_code == 500 + assert result.message == "Upstream Google API error: Some unhandled Google error" + assert result.extra["service"] == "_google_api_client" + assert result.extra["error_type"] == "MockUnhandledError" + + def test_from_exception_non_google_error(self): + """Test handling non-Google API errors returns None.""" + mock_errors = self._create_mock_errors_module() + + # Create a non-Google API error + mock_error = ValueError("Not a Google error") + mock_error.__module__ = "builtins" + + # Create mock googleapiclient module + mock_googleapiclient = Mock() + mock_googleapiclient.errors = mock_errors + + with patch.dict( + "sys.modules", + {"googleapiclient": mock_googleapiclient, "googleapiclient.errors": mock_errors}, + ): + result = self.adapter.from_exception(mock_error) + + assert result is None diff --git a/libs/tests/sdk/test_httpx_adapter.py b/libs/tests/sdk/test_httpx_adapter.py new file mode 100644 index 00000000..940ae5bc --- /dev/null +++ b/libs/tests/sdk/test_httpx_adapter.py @@ -0,0 +1,333 @@ +import logging +from datetime import datetime, timezone +from unittest.mock import Mock, patch + +from arcade_core.errors import UpstreamError, UpstreamRateLimitError +from arcade_tdk.providers.http.error_adapter import BaseHTTPErrorMapper, HTTPErrorAdapter + + +class TestBaseHTTPErrorMapper: + """Test the base HTTP error mapper functionality.""" + + def setup_method(self): + self.mapper = BaseHTTPErrorMapper() + + def test_parse_retry_ms_with_retry_after_seconds(self): + """Test parsing retry-after header with seconds value.""" + headers = {"retry-after": "60"} + result = self.mapper._parse_retry_ms(headers) + assert result == 60_000 + + def test_parse_retry_ms_with_x_ratelimit_reset(self): + """Test parsing x-ratelimit-reset header with seconds value.""" + headers = {"x-ratelimit-reset": "120"} + result = self.mapper._parse_retry_ms(headers) + assert result == 120_000 + + def test_parse_retry_ms_with_x_ratelimit_reset_ms(self): + """Test parsing x-ratelimit-reset-ms header with milliseconds value.""" + headers = {"x-ratelimit-reset-ms": "5000"} + result = self.mapper._parse_retry_ms(headers) + assert result == 5_000 + + def test_parse_retry_ms_with_date_format(self): + """Test parsing retry header with absolute date format.""" + future_date = "Mon, 01 Jan 2029 12:00:00 GMT" + headers = {"retry-after": future_date} + + with patch("arcade_tdk.providers.http.error_adapter.datetime") as mock_datetime: + parsed_date = datetime(2029, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + mock_datetime.strptime.return_value = parsed_date + + # Mock datetime.now() to return a time before the parsed date + current_time = datetime(2029, 1, 1, 11, 59, 0, tzinfo=timezone.utc) + mock_datetime.now.return_value = current_time + + mock_datetime.timezone = timezone + + result = self.mapper._parse_retry_ms(headers) + assert result == 60_000 # 1 minute diff + + def test_parse_retry_ms_no_headers(self): + """Test parsing retry when no rate limit headers are present.""" + headers = {"content-type": "application/json"} + result = self.mapper._parse_retry_ms(headers) + assert result == 1_000 + + def test_parse_retry_ms_invalid_date(self): + """Test parsing retry with invalid date format falls back to default.""" + headers = {"retry-after": "invalid-date"} + result = self.mapper._parse_retry_ms(headers) + assert result == 1_000 + + def test_sanitize_uri_removes_query_params(self): + """Test URI sanitization removes query parameters.""" + uri = "https://api.example.com/users/123?token=secret&filter=active" + result = self.mapper._sanitize_uri(uri) + assert result == "https://api.example.com/users/123" + + def test_sanitize_uri_removes_fragments(self): + """Test URI sanitization removes fragments.""" + uri = "https://api.example.com/users#section" + result = self.mapper._sanitize_uri(uri) + assert result == "https://api.example.com/users" + + def test_sanitize_uri_handles_trailing_slashes(self): + """Test URI sanitization handles trailing slashes.""" + uri = "https://api.example.com///path///" + result = self.mapper._sanitize_uri(uri) + assert result == "https://api.example.com/path" + + def test_build_extra_metadata_basic(self): + """Test building extra metadata without request info.""" + result = self.mapper._build_extra_metadata() + assert result == {"service": "_http"} + + def test_build_extra_metadata_with_url_and_method(self): + """Test building extra metadata with URL and method.""" + result = self.mapper._build_extra_metadata( + request_url="https://api.example.com/test?secret=123", request_method="post" + ) + expected = { + "service": "_http", + "endpoint": "https://api.example.com/test", + "http_method": "POST", + } + assert result == expected + + def test_map_status_to_error_rate_limit(self): + """Test mapping 429 status to rate limit error.""" + headers = {"retry-after": "30"} + result = self.mapper._map_status_to_error( + status=429, + headers=headers, + msg="Rate limit exceeded", + request_url="https://api.example.com/test", + request_method="GET", + ) + + assert isinstance(result, UpstreamRateLimitError) + assert result.retry_after_ms == 30_000 + assert result.message == "Rate limit exceeded" + assert result.extra["service"] == "_http" + assert result.extra["endpoint"] == "https://api.example.com/test" + assert result.extra["http_method"] == "GET" + + def test_map_status_to_error_generic(self): + """Test mapping generic HTTP status to upstream error.""" + headers = {} + result = self.mapper._map_status_to_error( + status=500, + headers=headers, + msg="Internal server error", + request_url="https://api.example.com/test", + request_method="POST", + ) + + assert isinstance(result, UpstreamError) + assert not isinstance(result, UpstreamRateLimitError) + assert result.status_code == 500 + assert result.message == "Internal server error" + assert result.extra["service"] == "_http" + assert result.extra["endpoint"] == "https://api.example.com/test" + assert result.extra["http_method"] == "POST" + + +class TestHTTPErrorAdapter: + """Test the main HTTP error adapter.""" + + def setup_method(self): + self.adapter = HTTPErrorAdapter() + + def test_httpx_not_installed(self): + """Test handling when httpx is not installed.""" + with patch.object(self.adapter._httpx_handler, "handle_exception") as mock_handle: + # Simulate what happens when httpx is not installed (returns None) + mock_handle.return_value = None + + mock_exc = Exception("test exception") + + result = self.adapter.from_exception(mock_exc) + assert result is None + + def test_requests_not_installed(self): + """Test handling when requests is not installed.""" + with patch.object(self.adapter._requests_handler, "handle_exception") as mock_handle: + # Simulate what happens when requests is not installed (returns None) + mock_handle.return_value = None + + mock_exc = Exception("test exception") + + result = self.adapter.from_exception(mock_exc) + assert result is None + + def test_httpx_http_status_error_handling(self): + """Test handling httpx HTTPStatusError.""" + + # Create a mock HTTPStatusError class and make our exception inherit from it + class MockHTTPStatusError(Exception): + pass + + # Create the exception instance that inherits from our mock class + mock_response = Mock() + mock_response.status_code = 404 + mock_response.headers = {"content-type": "application/json"} + + mock_request = Mock() + mock_request.url = "https://api.example.com/users/123" + mock_request.method = "GET" + + mock_exc = MockHTTPStatusError("404 Client Error: Not Found") + mock_exc.response = mock_response + mock_exc.request = mock_request + + with patch("httpx.HTTPStatusError", MockHTTPStatusError): + result = self.adapter.from_exception(mock_exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == 404 + assert result.message == "404 Client Error: Not Found" + assert result.extra["service"] == "_http" + assert result.extra["endpoint"] == "https://api.example.com/users/123" + assert result.extra["http_method"] == "GET" + + def test_httpx_rate_limit_handling(self): + """Test handling httpx 429 rate limit.""" + + # Create a mock HTTPStatusError class + class MockHTTPStatusError(Exception): + pass + + mock_response = Mock() + mock_response.status_code = 429 + mock_response.headers = {"retry-after": "60", "content-type": "application/json"} + + mock_request = Mock() + mock_request.url = "https://api.example.com/upload" + mock_request.method = "POST" + + mock_exc = MockHTTPStatusError("429 Too Many Requests") + mock_exc.response = mock_response + mock_exc.request = mock_request + + with patch("httpx.HTTPStatusError", MockHTTPStatusError): + result = self.adapter.from_exception(mock_exc) + + assert isinstance(result, UpstreamRateLimitError) + assert result.retry_after_ms == 60_000 + assert result.message == "429 Too Many Requests" + assert result.extra["service"] == "_http" + assert result.extra["endpoint"] == "https://api.example.com/upload" + assert result.extra["http_method"] == "POST" + + def test_requests_http_error_handling(self): + """Test handling requests HTTPError.""" + + # Create a mock HTTPError class + class MockHTTPError(Exception): + pass + + mock_response = Mock() + mock_response.status_code = 403 + mock_response.headers = {"www-authenticate": "Bearer"} + + mock_request = Mock() + mock_request.url = "https://api.example.com/protected" + mock_request.method = "GET" + + mock_response.request = mock_request + + mock_exc = MockHTTPError("403 Forbidden") + mock_exc.response = mock_response + + with patch("requests.exceptions.HTTPError", MockHTTPError): + result = self.adapter.from_exception(mock_exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == 403 + assert result.message == "403 Forbidden" + assert result.extra["service"] == "_http" + assert result.extra["endpoint"] == "https://api.example.com/protected" + assert result.extra["http_method"] == "GET" + + def test_requests_http_error_with_url_fallback(self): + """Test handling requests HTTPError when request is not available but response.url is.""" + + # Create a mock HTTPError class + class MockHTTPError(Exception): + pass + + mock_response = Mock() + mock_response.status_code = 500 + mock_response.headers = {} + mock_response.request = None # No request object + mock_response.url = "https://api.example.com/server-error" + + mock_exc = MockHTTPError("500 Internal Server Error") + mock_exc.response = mock_response + + with patch("requests.exceptions.HTTPError", MockHTTPError): + result = self.adapter.from_exception(mock_exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == 500 + assert result.message == "500 Internal Server Error" + assert result.extra["service"] == "_http" + assert result.extra["endpoint"] == "https://api.example.com/server-error" + assert "http_method" not in result.extra # No method available + + def test_requests_http_error_no_response(self): + """Test handling requests HTTPError with no response.""" + + # Create a mock HTTPError class + class MockHTTPError(Exception): + pass + + mock_exc = MockHTTPError("No response") + mock_exc.response = None + + with patch("requests.exceptions.HTTPError", MockHTTPError): + result = self.adapter.from_exception(mock_exc) + + assert result is None + + def test_unhandled_exception_logs_warning(self, caplog): + """Test that unhandled exceptions log a warning.""" + with caplog.at_level(logging.INFO): + unknown_exc = ValueError("Some unrelated error") + result = self.adapter.from_exception(unknown_exc) + + assert result is None + assert len(caplog.records) == 1 + assert "ValueError" in caplog.records[0].message + assert "_http" in caplog.records[0].message + assert "not handled" in caplog.records[0].message + + def test_httpx_without_request_info(self): + """Test handling httpx exception without request information.""" + + # Create a mock HTTPStatusError class + class MockHTTPStatusError(Exception): + pass + + mock_response = Mock() + mock_response.status_code = 400 + mock_response.headers = {} + + mock_exc = MockHTTPStatusError("400 Bad Request") + mock_exc.response = mock_response + mock_exc.request = None + + with patch("httpx.HTTPStatusError", MockHTTPStatusError): + result = self.adapter.from_exception(mock_exc) + + assert isinstance(result, UpstreamError) + assert result.status_code == 400 + assert result.message == "400 Bad Request" + assert result.extra["service"] == "_http" + assert "endpoint" not in result.extra + assert "http_method" not in result.extra + + def test_adapter_slug(self): + """Test that the adapter has the correct slug.""" + assert HTTPErrorAdapter.slug == "_http" diff --git a/libs/tests/tool/test_create_tool_definition_errors.py b/libs/tests/tool/test_create_tool_definition_errors.py index da90337d..e84b6557 100644 --- a/libs/tests/tool/test_create_tool_definition_errors.py +++ b/libs/tests/tool/test_create_tool_definition_errors.py @@ -2,7 +2,7 @@ from typing import Annotated import pytest from arcade_core.catalog import ToolCatalog -from arcade_core.errors import ToolDefinitionError +from arcade_core.errors import ToolDefinitionError, ToolInputSchemaError from arcade_core.schema import ToolContext, ToolMetadataKey from arcade_tdk import tool @@ -149,7 +149,7 @@ def func_with_metadata_and_auth_dependency(): ), pytest.param( func_with_invalid_renamed_param, - ToolDefinitionError, + ToolInputSchemaError, id=func_with_invalid_renamed_param.__name__, ), pytest.param( diff --git a/pyproject.toml b/pyproject.toml index 7f15cfaf..45b30363 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-ai" -version = "2.1.4" +version = "2.2.0" description = "Arcade.dev - Tool Calling platform for Agents" readme = "README.md" license = {file = "LICENSE"} @@ -21,7 +21,7 @@ requires-python = ">=3.10" dependencies = [ # CLI dependencies - "arcade-core>=2.2.2,<3.0.0", + "arcade-core>=2.4.0,<3.0.0", "typer==0.10.0", "rich==13.9.4", "Jinja2==3.1.6", @@ -40,9 +40,9 @@ all = [ "pytz>=2024.1", "python-dateutil>=2.8.2", # serve - "arcade-serve>=2.0.0,<3.0.0", + "arcade-serve>=2.1.0,<3.0.0", # tdk - "arcade-tdk>=2.0.0,<3.0.0", + "arcade-tdk>=2.3.0,<3.0.0", ] # Evals also depends on arcade-core and openai, but they are already required deps evals = [