Tool Error Handling (#539)
# Improvements to Arcade TDK Error Handling
I tried my very best to not make any breaking changes in this PR. So,
you will notice various "Deprecation" notices throughout.
### Instructions for PR reviewers
1. Pull down this PR's branch
2. Pull down the Engine's tool error handling PR's branch
3. Update your installed arcadepy to have the following:
- In `arcadepy/resources/tools/tools.py`, if you want to test out
including stacktraces, then you need to update `ToolsResource.execute`
to accept a `include_error_stacktrace` argument and also include the
"include_error_stacktrace" argument to the POST to the Engine inside of
the function's execute method's body.
- In `arcadepy/types/execute_tool_response.py` add the following enum
```py
class ErrorKind(str, Enum):
"""Error kind that is comprised of
- the who (toolkit, tool, upstream)
- the when (load time, definition parsing time, runtime)
- the what (bad_definition, bad_input, bad_output, retry,
context_required, fatal, etc.)"""
TOOLKIT_LOAD_FAILED = "TOOLKIT_LOAD_FAILED"
TOOL_DEFINITION_BAD_DEFINITION = "TOOL_DEFINITION_BAD_DEFINITION"
TOOL_DEFINITION_BAD_INPUT_SCHEMA = "TOOL_DEFINITION_BAD_INPUT_SCHEMA"
TOOL_DEFINITION_BAD_OUTPUT_SCHEMA = "TOOL_DEFINITION_BAD_OUTPUT_SCHEMA"
TOOL_RUNTIME_BAD_INPUT_VALUE = "TOOL_RUNTIME_BAD_INPUT_VALUE"
TOOL_RUNTIME_BAD_OUTPUT_VALUE = "TOOL_RUNTIME_BAD_OUTPUT_VALUE"
TOOL_RUNTIME_RETRY = "TOOL_RUNTIME_RETRY"
TOOL_RUNTIME_CONTEXT_REQUIRED = "TOOL_RUNTIME_CONTEXT_REQUIRED"
TOOL_RUNTIME_FATAL = "TOOL_RUNTIME_FATAL"
UPSTREAM_RUNTIME_BAD_REQUEST = "UPSTREAM_RUNTIME_BAD_REQUEST"
UPSTREAM_RUNTIME_AUTH_ERROR = "UPSTREAM_RUNTIME_AUTH_ERROR"
UPSTREAM_RUNTIME_NOT_FOUND = "UPSTREAM_RUNTIME_NOT_FOUND"
UPSTREAM_RUNTIME_VALIDATION_ERROR = "UPSTREAM_RUNTIME_VALIDATION_ERROR"
UPSTREAM_RUNTIME_RATE_LIMIT = "UPSTREAM_RUNTIME_RATE_LIMIT"
UPSTREAM_RUNTIME_SERVER_ERROR = "UPSTREAM_RUNTIME_SERVER_ERROR"
UPSTREAM_RUNTIME_UNMAPPED = "UPSTREAM_RUNTIME_UNMAPPED"
UNKNOWN = "UNKNOWN"
```
- In `arcadepy/types/execute_tool_response.py` add the following fields
to OutputError:
```py
kind: ErrorKind
status_code: Optional[int] = None
stacktrace: Optional[str] = None
extra: Optional[dict[str, Any]] = None
```
### Example Client Usage
```py
# Example of handling an upstream rate limit
error = response.output.error
if error and error.kind == ErrorKind.UPSTREAM_RUNTIME_RATE_LIMIT:
sleep_time = error.retry_after_ms / 1000
time.sleep(sleep_time)
# and then execute again
```
```py
# Examples of determining what type of runtime error it is
error = response.output.error
if error:
is_retryable_error = error.kind == ErrorKind.TOOL_RUNTIME_RETRY
is_a_bug_in_the_tool = error.kind == ErrorKind.TOOL_RUNTIME_FATAL
is_additional_context_required = error.kind == ErrorKind.TOOL_RUNTIME_CONTEXT_REQUIRED
```
### Example Tool Usage
```py
# EXAMPLE 1 letting Arcade handle upstream error handling for you
reddit_client.post(params) # Arcade's httpx adapter will handle error handling for you!
# ------------------------------------
# EXAMPLE 2 handling upstream bad request yourself, but letting Arcade handle the rest
try:
reddit_client.post(params)
except httpx.HTTPStatusError as e:
if e.status_code == 400:
raise UpstreamError("My extra custom message) from e
raise
```
```py
# EXAMPLE 1 letting Arcade handle it for you
risky_element = my_risky_list[42] # Arcade will raise a FatalToolError for you
# ------------------------------------
# EXAMPLE 2 handling it yourself for extra flexibility
try:
risky_element = my_risky_list[42]
except IndexError as e:
raise FatalToolError("My extra custom message") from e
```
### Non-runtime Error Message Examples
Example ToolkitLoadError Messages:
```
- [TOOLKIT_LOAD_FAILED] ToolkitLoadError when loading toolkit 'sample_tool': Could not import module mock_module. Reason: Mock import error
- [TOOLKIT_LOAD_FAILED] ToolkitLoadError when loading toolkit 'test_toolkit': Tool 'ValidTool' in toolkit 'test_toolkit' already exists in the catalog.
```
Example ToolDefinitionError Messages
```
- [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_missing_description': Tool 'tool_missing_description' is missing a description
- [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_invalid_secret_type': Secret keys must be strings (error in tool ToolWithInvalidSecretType).
- [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_empty_secret': Secrets must have a non-empty key (error in tool ToolWithEmptySecret).
- [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_invalid_metadata_type': Metadata must be strings (error in tool ToolWithInvalidMetadataType).
- [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_metadata_requiring_auth_without_auth': Tool ToolWithMetadataRequiringAuthWithoutAuth declares metadata key 'client_id', which requires that the tool has an auth requirement, but no auth requirement was provided. Please specify an auth requirement.
- [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_empty_metadata': Metadata must have a non-empty key (error in tool ToolWithEmptyMetadata).
- [TOOL_DEFINITION_BAD_DEFINITION] ToolDefinitionError in definition of tool 'tool_with_unsupported_param_type': Unsupported parameter type: <class 'test_catalog.MyFancyTestClass'>
```
Example ToolInputSchemaError Messages
```
- [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_missing_input_parameter_annotation': Parameter 'input_text' is missing a description
- [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_no_type_annotation': Parameter param has no type annotation.
- [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_invalid_param_name': Invalid parameter name: '123invalid' is not a valid identifier. Identifiers must start with a letter or underscore, and can only contain letters, digits, or underscores.
- [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_too_many_annotations': Parameter param: Annotated[str, 'name', 'desc', 'extra'] has too many string annotations. Expected 0, 1, or 2, got 3.
- [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_required_union_param': Parameter param is a union type. Only optional types are supported.
- [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_non_callable_default_factory': Default factory for parameter param: Annotated[str, 'Parameter'] = FieldInfo(annotation=NoneType, required=False, default_factory=str) is not callable.
- [TOOL_DEFINITION_BAD_INPUT_SCHEMA] ToolInputSchemaError in definition of tool 'tool_with_multiple_tool_contexts': Only one ToolContext parameter is supported, but tool tool_with_multiple_tool_contexts has multiple.
```
Example ToolOutputSchemaError Messages
```
- [TOOL_DEFINITION_BAD_OUTPUT_SCHEMA] ToolOutputSchemaError in definition of tool 'tool_missing_return_type_hint': Tool 'ToolMissingReturnTypeHint' must have a return type
- [TOOL_DEFINITION_BAD_OUTPUT_SCHEMA] ToolOutputSchemaError in definition of tool 'tool_with_unsupported_output_type': Unsupported output type '<class 'test_catalog.MyFancyTestClass'>'. Only built-in Python types, TypedDicts, Pydantic models, and standard collections are supported as tool output types.
```
### Runtime Error Message Examples
Example Tool Runtime Error Messages
```
- [TOOL_RUNTIME_FATAL] FatalToolError during execution of tool 'get_posts_in_subreddit': list index out of range
- [TOOL_RUNTIME_CONTEXT_REQUIRED] ContextRequiredToolError during execution of tool 'get_posts_in_subreddit': Ambiguous username. Please provide a more specific username
- [TOOL_RUNTIME_RETRY] RetryableToolError during execution of tool 'get_posts_in_subreddit': Retry with subreddit=learnpython or subreddit=learnprogramming
```
Example Upstream Runtime Error Messages
```
- [UPSTREAM_RUNTIME_RATE_LIMIT] UpstreamRateLimitError during execution of tool 'get_posts_in_subreddit': 429 Client Error: Too Many Requests
- [UPSTREAM_RUNTIME_BAD_REQUEST] UpstreamError during execution of tool 'get_posts_in_subreddit': 400 Client Error: Bad request. Missing 'id' parameter.
- [UPSTREAM_RUNTIME_BAD_REQUEST] UpstreamError during execution of tool 'search_files': Upstream Google API error: Invalid value '-23'. Values must be within the range: [value: 1\n, value: 1000\n]
```
This commit is contained in:
parent
a97450b3af
commit
f4558ef3a8
26 changed files with 2348 additions and 157 deletions
|
|
@ -27,7 +27,12 @@ from pydantic_core import PydanticUndefined
|
|||
|
||||
from arcade_core.annotations import Inferrable
|
||||
from arcade_core.auth import OAuth2, ToolAuthorization
|
||||
from arcade_core.errors import ToolDefinitionError
|
||||
from arcade_core.errors import (
|
||||
ToolDefinitionError,
|
||||
ToolInputSchemaError,
|
||||
ToolkitLoadError,
|
||||
ToolOutputSchemaError,
|
||||
)
|
||||
from arcade_core.schema import (
|
||||
TOOL_NAME_SEPARATOR,
|
||||
FullyQualifiedName,
|
||||
|
|
@ -224,7 +229,9 @@ class ToolCatalog(BaseModel):
|
|||
fully_qualified_name = definition.get_fully_qualified_name()
|
||||
|
||||
if fully_qualified_name in self._tools:
|
||||
raise KeyError(f"Tool '{definition.name}' already exists in the catalog.")
|
||||
raise ToolkitLoadError(
|
||||
f"Tool '{definition.name}' in toolkit '{toolkit_name}' already exists in the catalog."
|
||||
)
|
||||
|
||||
if str(fully_qualified_name).lower() in self._disabled_tools:
|
||||
logger.info(f"Tool '{fully_qualified_name!s}' is disabled and will not be cataloged.")
|
||||
|
|
@ -270,20 +277,26 @@ class ToolCatalog(BaseModel):
|
|||
tool_func = getattr(module, tool_name)
|
||||
self.add_tool(tool_func, toolkit, module)
|
||||
|
||||
except ToolDefinitionError as e:
|
||||
raise e.with_context(tool_name) from e
|
||||
except ToolkitLoadError as e:
|
||||
raise e.with_context(toolkit.name) from e
|
||||
except ImportError as e:
|
||||
raise ToolkitLoadError(
|
||||
f"Could not import module {module_name}. Reason: {e}"
|
||||
).with_context(tool_name)
|
||||
except AttributeError as e:
|
||||
raise ToolDefinitionError(
|
||||
f"Could not import tool {tool_name} in module {module_name}. Reason: {e}"
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ToolDefinitionError(f"Could not import module {module_name}. Reason: {e}")
|
||||
).with_context(tool_name)
|
||||
except TypeError as e:
|
||||
raise ToolDefinitionError(
|
||||
f"Type error encountered while adding tool {tool_name} from {module_name}. Reason: {e}"
|
||||
)
|
||||
).with_context(tool_name)
|
||||
except Exception as e:
|
||||
raise ToolDefinitionError(
|
||||
f"Error encountered while adding tool {tool_name} from {module_name}. Reason: {e}"
|
||||
)
|
||||
).with_context(tool_name)
|
||||
|
||||
def __getitem__(self, name: FullyQualifiedName) -> MaterializedTool:
|
||||
return self.get_tool(name)
|
||||
|
|
@ -392,11 +405,11 @@ class ToolCatalog(BaseModel):
|
|||
# Hard requirement: tools must have descriptions
|
||||
tool_description = getattr(tool, "__tool_description__", None)
|
||||
if not tool_description:
|
||||
raise ToolDefinitionError(f"Tool {raw_tool_name} is missing a description")
|
||||
raise ToolDefinitionError(f"Tool '{raw_tool_name}' is missing a description")
|
||||
|
||||
# If the function returns a value, it must have a type annotation
|
||||
if does_function_return_value(tool) and tool.__annotations__.get("return") is None:
|
||||
raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation")
|
||||
raise ToolOutputSchemaError(f"Tool '{raw_tool_name}' must have a return type")
|
||||
|
||||
auth_requirement = create_auth_requirement(tool)
|
||||
secrets_requirement = create_secrets_requirement(tool)
|
||||
|
|
@ -438,7 +451,7 @@ def create_input_definition(func: Callable) -> ToolInput:
|
|||
for _, param in inspect.signature(func, follow_wrapped=True).parameters.items():
|
||||
if param.annotation is ToolContext:
|
||||
if tool_context_param_name is not None:
|
||||
raise ToolDefinitionError(
|
||||
raise ToolInputSchemaError(
|
||||
f"Only one ToolContext parameter is supported, but tool {func.__name__} has multiple."
|
||||
)
|
||||
|
||||
|
|
@ -635,7 +648,7 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
|
|||
"""
|
||||
annotation = param.annotation
|
||||
if annotation == inspect.Parameter.empty:
|
||||
raise ToolDefinitionError(f"Parameter {param} has no type annotation.")
|
||||
raise ToolInputSchemaError(f"Parameter {param} has no type annotation.")
|
||||
|
||||
# Get the majority of the param info from either the Pydantic Field() or regular inspection
|
||||
if isinstance(param.default, FieldInfo):
|
||||
|
|
@ -654,7 +667,7 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
|
|||
elif len(str_annotations) == 2:
|
||||
new_name = str_annotations[0]
|
||||
if not new_name.isidentifier():
|
||||
raise ToolDefinitionError(
|
||||
raise ToolInputSchemaError(
|
||||
f"Invalid parameter name: '{new_name}' is not a valid identifier. "
|
||||
"Identifiers must start with a letter or underscore, "
|
||||
"and can only contain letters, digits, or underscores."
|
||||
|
|
@ -662,7 +675,7 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
|
|||
param_info.name = new_name
|
||||
param_info.description = str_annotations[1]
|
||||
else:
|
||||
raise ToolDefinitionError(
|
||||
raise ToolInputSchemaError(
|
||||
f"Parameter {param} has too many string annotations. Expected 0, 1, or 2, got {len(str_annotations)}."
|
||||
)
|
||||
|
||||
|
|
@ -677,10 +690,10 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
|
|||
|
||||
# Final reality check
|
||||
if param_info.description is None:
|
||||
raise ToolDefinitionError(f"Parameter {param_info.name} is missing a description")
|
||||
raise ToolInputSchemaError(f"Parameter '{param_info.name}' is missing a description")
|
||||
|
||||
if wire_type_info.wire_type is None:
|
||||
raise ToolDefinitionError(f"Unknown parameter type: {param_info.field_type}")
|
||||
raise ToolInputSchemaError(f"Unknown parameter type: {param_info.field_type}")
|
||||
|
||||
return ToolParamInfo.from_param_info(param_info, wire_type_info, is_inferrable)
|
||||
|
||||
|
|
@ -875,7 +888,7 @@ def extract_python_param_info(param: inspect.Parameter) -> ParamInfo:
|
|||
# Union types are not currently supported
|
||||
# (other than optional, which is handled above)
|
||||
if is_union(field_type):
|
||||
raise ToolDefinitionError(
|
||||
raise ToolInputSchemaError(
|
||||
f"Parameter {param.name} is a union type. Only optional types are supported."
|
||||
)
|
||||
|
||||
|
|
@ -895,7 +908,7 @@ def extract_pydantic_param_info(param: inspect.Parameter) -> ParamInfo:
|
|||
if callable(param.default.default_factory):
|
||||
default_value = param.default.default_factory()
|
||||
else:
|
||||
raise ToolDefinitionError(f"Default factory for parameter {param} is not callable.")
|
||||
raise ToolInputSchemaError(f"Default factory for parameter {param} is not callable.")
|
||||
|
||||
# If the param is Annotated[], unwrap the annotation to get the "real" type
|
||||
# Otherwise, use the literal type
|
||||
|
|
@ -988,7 +1001,6 @@ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel]
|
|||
input_model = create_model(f"{snake_to_pascal_case(func.__name__)}Input", **input_fields) # type: ignore[call-overload]
|
||||
|
||||
output_model = determine_output_model(func)
|
||||
|
||||
return input_model, output_model
|
||||
|
||||
|
||||
|
|
@ -1033,10 +1045,16 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901
|
|||
|
||||
# If the return annotation has a description, use it
|
||||
if description:
|
||||
return create_model(
|
||||
output_model_name,
|
||||
result=(field_type, Field(description=str(description))),
|
||||
)
|
||||
try:
|
||||
return create_model(
|
||||
output_model_name,
|
||||
result=(field_type, Field(description=str(description))),
|
||||
)
|
||||
except Exception:
|
||||
raise ToolOutputSchemaError(
|
||||
f"Unsupported output type '{field_type}'. Only built-in Python types, TypedDicts, "
|
||||
"Pydantic models, and standard collections are supported as tool output types."
|
||||
)
|
||||
|
||||
# If the return annotation is a Union type
|
||||
origin = return_annotation.__origin__
|
||||
|
|
|
|||
|
|
@ -1,103 +1,378 @@
|
|||
import traceback
|
||||
from typing import Optional
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolkitError(Exception):
|
||||
class ErrorKind(str, Enum):
|
||||
"""Error kind that is comprised of
|
||||
- the who (toolkit, tool, upstream)
|
||||
- the when (load time, definition parsing time, runtime)
|
||||
- the what (bad_definition, bad_input, bad_output, retry, context_required, fatal, etc.)"""
|
||||
|
||||
TOOLKIT_LOAD_FAILED = "TOOLKIT_LOAD_FAILED"
|
||||
TOOL_DEFINITION_BAD_DEFINITION = "TOOL_DEFINITION_BAD_DEFINITION"
|
||||
TOOL_DEFINITION_BAD_INPUT_SCHEMA = "TOOL_DEFINITION_BAD_INPUT_SCHEMA"
|
||||
TOOL_DEFINITION_BAD_OUTPUT_SCHEMA = "TOOL_DEFINITION_BAD_OUTPUT_SCHEMA"
|
||||
TOOL_RUNTIME_BAD_INPUT_VALUE = "TOOL_RUNTIME_BAD_INPUT_VALUE"
|
||||
TOOL_RUNTIME_BAD_OUTPUT_VALUE = "TOOL_RUNTIME_BAD_OUTPUT_VALUE"
|
||||
TOOL_RUNTIME_RETRY = "TOOL_RUNTIME_RETRY"
|
||||
TOOL_RUNTIME_CONTEXT_REQUIRED = "TOOL_RUNTIME_CONTEXT_REQUIRED"
|
||||
TOOL_RUNTIME_FATAL = "TOOL_RUNTIME_FATAL"
|
||||
UPSTREAM_RUNTIME_BAD_REQUEST = "UPSTREAM_RUNTIME_BAD_REQUEST"
|
||||
UPSTREAM_RUNTIME_AUTH_ERROR = "UPSTREAM_RUNTIME_AUTH_ERROR"
|
||||
UPSTREAM_RUNTIME_NOT_FOUND = "UPSTREAM_RUNTIME_NOT_FOUND"
|
||||
UPSTREAM_RUNTIME_VALIDATION_ERROR = "UPSTREAM_RUNTIME_VALIDATION_ERROR"
|
||||
UPSTREAM_RUNTIME_RATE_LIMIT = "UPSTREAM_RUNTIME_RATE_LIMIT"
|
||||
UPSTREAM_RUNTIME_SERVER_ERROR = "UPSTREAM_RUNTIME_SERVER_ERROR"
|
||||
UPSTREAM_RUNTIME_UNMAPPED = "UPSTREAM_RUNTIME_UNMAPPED"
|
||||
UNKNOWN = "UNKNOWN"
|
||||
|
||||
|
||||
class ToolkitError(Exception, ABC):
|
||||
"""
|
||||
Base class for all errors related to toolkits.
|
||||
Base class for all Arcade errors.
|
||||
|
||||
Note: This class is an abstract class and cannot be instantiated directly.
|
||||
|
||||
These errors are ultimately converted to the ToolCallError schema.
|
||||
Attributes expected from subclasses:
|
||||
message : str # user-facing error message
|
||||
kind : ErrorKind # the error kind
|
||||
can_retry : bool # whether the operation can be retried
|
||||
developer_message : str | None # developer-facing error details
|
||||
status_code : int | None # HTTP status code when relevant
|
||||
additional_prompt_content : str | None # content for retry prompts
|
||||
retry_after_ms : int | None # milliseconds to wait before retry
|
||||
stacktrace : str | None # stacktrace information
|
||||
extra : dict[str, Any] | None # arbitrary structured metadata
|
||||
"""
|
||||
|
||||
pass
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "ToolkitError":
|
||||
abs_methods = getattr(cls, "__abstractmethods__", None)
|
||||
if abs_methods:
|
||||
raise TypeError(f"Can't instantiate abstract class {cls.__name__}")
|
||||
return super().__new__(cls)
|
||||
|
||||
@abstractmethod
|
||||
def create_message_prefix(self, name: str) -> str:
|
||||
pass
|
||||
|
||||
def with_context(self, name: str) -> "ToolkitError":
|
||||
"""
|
||||
Add context to the error message.
|
||||
|
||||
Args:
|
||||
name: The name of the tool or toolkit that caused the error.
|
||||
|
||||
Returns:
|
||||
The error with the context added to the message.
|
||||
"""
|
||||
prefix = self.create_message_prefix(name)
|
||||
self.message = f"{prefix}{self.message}" # type: ignore[has-type]
|
||||
if hasattr(self, "developer_message") and self.developer_message: # type: ignore[has-type]
|
||||
self.developer_message = f"{prefix}{self.developer_message}" # type: ignore[has-type]
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def is_toolkit_error(self) -> bool:
|
||||
"""Check if this error originated from loading a toolkit."""
|
||||
return hasattr(self, "kind") and self.kind.name.startswith("TOOLKIT_")
|
||||
|
||||
@property
|
||||
def is_tool_error(self) -> bool:
|
||||
"""Check if this error originated from a tool."""
|
||||
return hasattr(self, "kind") and self.kind.name.startswith("TOOL_")
|
||||
|
||||
@property
|
||||
def is_upstream_error(self) -> bool:
|
||||
"""Check if this error originated from an upstream service."""
|
||||
return hasattr(self, "kind") and self.kind.name.startswith("UPSTREAM_")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.message
|
||||
|
||||
|
||||
class ToolkitLoadError(ToolkitError):
|
||||
"""
|
||||
Raised when there is an error loading a toolkit.
|
||||
Raised while importing / loading a toolkit package
|
||||
(e.g. missing dependency, SyntaxError in module top-level code).
|
||||
"""
|
||||
|
||||
pass
|
||||
kind: ErrorKind = ErrorKind.TOOLKIT_LOAD_FAILED
|
||||
can_retry: bool = False
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def create_message_prefix(self, toolkit_name: str) -> str:
|
||||
return f"[{self.kind.value}] {type(self).__name__} when loading toolkit '{toolkit_name}': "
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
class ToolError(ToolkitError):
|
||||
"""
|
||||
Base class for all errors related to tools.
|
||||
Any error related to an Arcade tool.
|
||||
|
||||
Note: This class is an abstract class and cannot be instantiated directly.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# ------ definition-time errors (tool developer's responsibility) ------
|
||||
class ToolDefinitionError(ToolError):
|
||||
"""
|
||||
Raised when there is an error in the definition of a tool.
|
||||
Raised when there is an error in the definition/signature of a tool.
|
||||
"""
|
||||
|
||||
pass
|
||||
kind: ErrorKind = ErrorKind.TOOL_DEFINITION_BAD_DEFINITION
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def create_message_prefix(self, tool_name: str) -> str:
|
||||
return f"[{self.kind.value}] {type(self).__name__} in definition of tool '{tool_name}': "
|
||||
|
||||
|
||||
class ToolInputSchemaError(ToolDefinitionError):
|
||||
"""Raised when there is an error in the schema of a tool's input parameter."""
|
||||
|
||||
kind: ErrorKind = ErrorKind.TOOL_DEFINITION_BAD_INPUT_SCHEMA
|
||||
|
||||
|
||||
class ToolOutputSchemaError(ToolDefinitionError):
|
||||
"""Raised when there is an error in the schema of a tool's output parameter."""
|
||||
|
||||
kind: ErrorKind = ErrorKind.TOOL_DEFINITION_BAD_OUTPUT_SCHEMA
|
||||
|
||||
|
||||
# ------ runtime errors ------
|
||||
class ToolRuntimeError(ToolError, RuntimeError):
|
||||
"""
|
||||
Any failure starting from when the tool call begins until the tool call returns.
|
||||
|
||||
Note: This class should typically not be instantiated directly, but rather subclassed.
|
||||
"""
|
||||
|
||||
kind: ErrorKind = ErrorKind.TOOL_RUNTIME_FATAL
|
||||
can_retry: bool = False
|
||||
status_code: int | None = None
|
||||
extra: dict[str, Any] | None = None
|
||||
|
||||
class ToolRuntimeError(RuntimeError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
developer_message: Optional[str] = None,
|
||||
developer_message: str | None = None,
|
||||
*,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.developer_message = developer_message
|
||||
self.developer_message = developer_message # type: ignore[assignment]
|
||||
self.extra = extra
|
||||
|
||||
def traceback_info(self) -> str | None:
|
||||
# return the traceback information of the parent exception
|
||||
def create_message_prefix(self, tool_name: str) -> str:
|
||||
return f"[{self.kind.value}] {type(self).__name__} during execution of tool '{tool_name}': "
|
||||
|
||||
def stacktrace(self) -> str | None:
|
||||
if self.__cause__:
|
||||
return "\n".join(traceback.format_exception(self.__cause__))
|
||||
return None
|
||||
|
||||
def traceback_info(self) -> str | None:
|
||||
"""DEPRECATED: Use stacktrace() instead.
|
||||
|
||||
This method is deprecated and will be removed in a future major version.
|
||||
"""
|
||||
return self.stacktrace()
|
||||
|
||||
# wire-format helper
|
||||
def to_payload(self) -> dict[str, Any]:
|
||||
return {
|
||||
"message": self.message,
|
||||
"developer_message": self.developer_message,
|
||||
"kind": self.kind,
|
||||
"can_retry": self.can_retry,
|
||||
"status_code": self.status_code,
|
||||
**(self.extra or {}),
|
||||
}
|
||||
|
||||
|
||||
# 1. ------ serialization errors ------
|
||||
class ToolSerializationError(ToolRuntimeError):
|
||||
"""
|
||||
Raised when there is an error serializing/marshalling the tool call arguments or return value.
|
||||
|
||||
Note: This class is not intended to be instantiated directly, but rather subclassed.
|
||||
"""
|
||||
|
||||
|
||||
class ToolInputError(ToolSerializationError):
|
||||
"""
|
||||
Raised when there is an error parsing a tool call argument.
|
||||
"""
|
||||
|
||||
kind: ErrorKind = ErrorKind.TOOL_RUNTIME_BAD_INPUT_VALUE
|
||||
status_code: int = 400
|
||||
|
||||
|
||||
class ToolOutputError(ToolSerializationError):
|
||||
"""
|
||||
Raised when there is an error serializing a tool call return value.
|
||||
"""
|
||||
|
||||
kind: ErrorKind = ErrorKind.TOOL_RUNTIME_BAD_OUTPUT_VALUE
|
||||
status_code: int = 500
|
||||
|
||||
|
||||
# 2. ------ tool-body errors ------
|
||||
class ToolExecutionError(ToolRuntimeError):
|
||||
"""
|
||||
Raised when there is an error executing a tool.
|
||||
"""
|
||||
DEPRECATED: Raised when there is an error executing a tool.
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RetryableToolError(ToolExecutionError):
|
||||
"""
|
||||
Raised when a tool error is retryable.
|
||||
ToolExecutionError is deprecated and will be removed in a future major version.
|
||||
Use more specific error types instead:
|
||||
- RetryableToolError for retryable errors
|
||||
- ContextRequiredToolError for errors requiring user context
|
||||
- FatalToolError for fatal/unexpected errors
|
||||
- UpstreamError for upstream service errors
|
||||
- UpstreamRateLimitError for upstream rate limiting errors
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
developer_message: Optional[str] = None,
|
||||
additional_prompt_content: Optional[str] = None,
|
||||
retry_after_ms: Optional[int] = None,
|
||||
developer_message: str | None = None,
|
||||
*,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, developer_message)
|
||||
if type(self) is ToolExecutionError:
|
||||
warnings.warn(
|
||||
"ToolExecutionError is deprecated and will be removed in a future major version. "
|
||||
"Use more specific error types instead: RetryableToolError, ContextRequiredToolError, "
|
||||
"FatalToolError, UpstreamError, or UpstreamRateLimitError.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(message, developer_message=developer_message, extra=extra)
|
||||
|
||||
|
||||
class RetryableToolError(ToolExecutionError):
|
||||
"""
|
||||
Raised when a tool execution error is retryable.
|
||||
"""
|
||||
|
||||
kind: ErrorKind = ErrorKind.TOOL_RUNTIME_RETRY
|
||||
can_retry: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
developer_message: str | None = None,
|
||||
additional_prompt_content: str | None = None, # TODO: Make required in next major version
|
||||
retry_after_ms: int | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, developer_message=developer_message, extra=extra)
|
||||
self.additional_prompt_content = additional_prompt_content
|
||||
self.retry_after_ms = retry_after_ms
|
||||
|
||||
|
||||
class ToolSerializationError(ToolRuntimeError):
|
||||
class ContextRequiredToolError(ToolExecutionError):
|
||||
"""
|
||||
Raised when there is an error executing a tool.
|
||||
Raised when the combination of additional content from the tool AND
|
||||
additional context from the end-user/orchestrator is required before retrying the tool.
|
||||
|
||||
This is typically used when an argument provided to the tool is invalid in some way,
|
||||
and immediately prompting an LLM to retry the tool call is not desired.
|
||||
"""
|
||||
|
||||
pass
|
||||
kind: ErrorKind = ErrorKind.TOOL_RUNTIME_CONTEXT_REQUIRED
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
additional_prompt_content: str,
|
||||
developer_message: str | None = None,
|
||||
*,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, developer_message=developer_message, extra=extra)
|
||||
self.additional_prompt_content = additional_prompt_content
|
||||
|
||||
|
||||
class ToolInputError(ToolSerializationError):
|
||||
class FatalToolError(ToolExecutionError):
|
||||
"""
|
||||
Raised when there is an error in the input to a tool.
|
||||
Raised when there is an unexpected or unknown error executing a tool.
|
||||
"""
|
||||
|
||||
pass
|
||||
status_code: int = 500
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
developer_message: str | None = None,
|
||||
*,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, developer_message=developer_message, extra=extra)
|
||||
|
||||
|
||||
class ToolOutputError(ToolSerializationError):
|
||||
# 3. ------ upstream errors in tool body------
|
||||
class UpstreamError(ToolExecutionError):
|
||||
"""
|
||||
Raised when there is an error in the output of a tool.
|
||||
Error from an upstream service/API during tool execution.
|
||||
|
||||
This class handles all upstream failures except rate limiting.
|
||||
The status_code and extra dict provide details about the specific error type.
|
||||
"""
|
||||
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
developer_message: str | None = None,
|
||||
*,
|
||||
status_code: int,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, developer_message=developer_message, extra=extra)
|
||||
self.status_code = status_code
|
||||
# Determine retryability based on status code
|
||||
self.can_retry = status_code >= 500 or status_code == 429
|
||||
# Set appropriate error kind based on status
|
||||
if status_code in (401, 403):
|
||||
self.kind = ErrorKind.UPSTREAM_RUNTIME_AUTH_ERROR
|
||||
elif status_code == 404:
|
||||
self.kind = ErrorKind.UPSTREAM_RUNTIME_NOT_FOUND
|
||||
elif status_code == 429:
|
||||
self.kind = ErrorKind.UPSTREAM_RUNTIME_RATE_LIMIT
|
||||
elif status_code >= 500:
|
||||
self.kind = ErrorKind.UPSTREAM_RUNTIME_SERVER_ERROR
|
||||
elif 400 <= status_code < 500:
|
||||
self.kind = ErrorKind.UPSTREAM_RUNTIME_BAD_REQUEST
|
||||
else:
|
||||
self.kind = ErrorKind.UPSTREAM_RUNTIME_UNMAPPED
|
||||
|
||||
|
||||
class UpstreamRateLimitError(UpstreamError):
|
||||
"""
|
||||
Rate limit error from an upstream service.
|
||||
|
||||
Special case of UpstreamError that includes retry_after_ms information.
|
||||
"""
|
||||
|
||||
kind: ErrorKind = ErrorKind.UPSTREAM_RUNTIME_RATE_LIMIT
|
||||
can_retry: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
retry_after_ms: int,
|
||||
developer_message: str | None = None,
|
||||
*,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, status_code=429, developer_message=developer_message, extra=extra)
|
||||
self.retry_after_ms = retry_after_ms
|
||||
|
|
|
|||
|
|
@ -6,11 +6,9 @@ from typing import Any
|
|||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from arcade_core.errors import (
|
||||
RetryableToolError,
|
||||
ToolInputError,
|
||||
ToolOutputError,
|
||||
ToolRuntimeError,
|
||||
ToolSerializationError,
|
||||
)
|
||||
from arcade_core.output import output_factory
|
||||
from arcade_core.schema import (
|
||||
|
|
@ -69,31 +67,26 @@ class ToolExecutor:
|
|||
# return the output
|
||||
return output_factory.success(data=output, logs=tool_call_logs)
|
||||
|
||||
except RetryableToolError as e:
|
||||
return output_factory.fail_retry(
|
||||
message=e.message,
|
||||
developer_message=e.developer_message,
|
||||
additional_prompt_content=e.additional_prompt_content,
|
||||
retry_after_ms=e.retry_after_ms,
|
||||
)
|
||||
|
||||
except ToolSerializationError as e:
|
||||
return output_factory.fail(message=e.message, developer_message=e.developer_message)
|
||||
|
||||
# should catch all tool exceptions due to the try/except in the tool decorator
|
||||
except ToolRuntimeError as e:
|
||||
e.with_context(func.__name__)
|
||||
return output_factory.fail(
|
||||
message=e.message,
|
||||
developer_message=e.developer_message,
|
||||
traceback_info=e.traceback_info(),
|
||||
stacktrace=e.stacktrace(),
|
||||
additional_prompt_content=getattr(e, "additional_prompt_content", None),
|
||||
retry_after_ms=getattr(e, "retry_after_ms", None),
|
||||
kind=e.kind,
|
||||
can_retry=e.can_retry,
|
||||
status_code=e.status_code,
|
||||
extra=e.extra,
|
||||
)
|
||||
|
||||
# if we get here we're in trouble
|
||||
except Exception as e:
|
||||
return output_factory.fail(
|
||||
message="Error in execution",
|
||||
message=f"Error in execution of '{func.__name__}'",
|
||||
developer_message=str(e),
|
||||
traceback_info=traceback.format_exc(),
|
||||
stacktrace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from arcade_core.errors import ErrorKind
|
||||
from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput
|
||||
from arcade_core.utils import coerce_empty_list_to_none
|
||||
|
||||
|
|
@ -61,15 +62,26 @@ class ToolOutputFactory:
|
|||
*,
|
||||
message: str,
|
||||
developer_message: str | None = None,
|
||||
traceback_info: str | None = None,
|
||||
stacktrace: str | None = None,
|
||||
logs: list[ToolCallLog] | None = None,
|
||||
additional_prompt_content: str | None = None,
|
||||
retry_after_ms: int | None = None,
|
||||
kind: ErrorKind = ErrorKind.UNKNOWN,
|
||||
can_retry: bool = False,
|
||||
status_code: int | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
) -> ToolCallOutput:
|
||||
return ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message=message,
|
||||
developer_message=developer_message,
|
||||
can_retry=False,
|
||||
traceback_info=traceback_info,
|
||||
can_retry=can_retry,
|
||||
additional_prompt_content=additional_prompt_content,
|
||||
retry_after_ms=retry_after_ms,
|
||||
stacktrace=stacktrace,
|
||||
kind=kind,
|
||||
status_code=status_code,
|
||||
extra=extra,
|
||||
),
|
||||
logs=coerce_empty_list_to_none(logs),
|
||||
)
|
||||
|
|
@ -81,9 +93,17 @@ class ToolOutputFactory:
|
|||
developer_message: str | None = None,
|
||||
additional_prompt_content: str | None = None,
|
||||
retry_after_ms: int | None = None,
|
||||
traceback_info: str | None = None,
|
||||
stacktrace: str | None = None,
|
||||
logs: list[ToolCallLog] | None = None,
|
||||
kind: ErrorKind = ErrorKind.TOOL_RUNTIME_RETRY,
|
||||
status_code: int = 500,
|
||||
extra: dict[str, Any] | None = None,
|
||||
) -> ToolCallOutput:
|
||||
"""
|
||||
DEPRECATED: Use ToolOutputFactory.fail instead.
|
||||
This method will be removed in version 3.0.0
|
||||
"""
|
||||
|
||||
return ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message=message,
|
||||
|
|
@ -91,7 +111,10 @@ class ToolOutputFactory:
|
|||
can_retry=True,
|
||||
additional_prompt_content=additional_prompt_content,
|
||||
retry_after_ms=retry_after_ms,
|
||||
traceback_info=traceback_info,
|
||||
stacktrace=stacktrace,
|
||||
kind=kind,
|
||||
status_code=status_code,
|
||||
extra=extra,
|
||||
),
|
||||
logs=coerce_empty_list_to_none(logs),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from typing import Any, Literal
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from arcade_core.errors import ErrorKind
|
||||
|
||||
# allow for custom tool name separator
|
||||
TOOL_NAME_SEPARATOR = os.getenv("ARCADE_TOOL_NAME_SEPARATOR", ".")
|
||||
|
||||
|
|
@ -390,6 +392,8 @@ class ToolCallError(BaseModel):
|
|||
|
||||
message: str
|
||||
"""The user-facing error message."""
|
||||
kind: ErrorKind
|
||||
"""The error kind that uniquely identifies the kind of error."""
|
||||
developer_message: str | None = None
|
||||
"""The developer-facing error details."""
|
||||
can_retry: bool = False
|
||||
|
|
@ -398,8 +402,27 @@ class ToolCallError(BaseModel):
|
|||
"""Additional content to be included in the retry prompt."""
|
||||
retry_after_ms: int | None = None
|
||||
"""The number of milliseconds (if any) to wait before retrying the tool call."""
|
||||
traceback_info: str | None = None
|
||||
"""The traceback information for the tool call."""
|
||||
stacktrace: str | None = None
|
||||
"""The stacktrace information for the tool call."""
|
||||
status_code: int | None = None
|
||||
"""The HTTP status code of the error."""
|
||||
extra: dict[str, Any] | None = None
|
||||
"""Additional information about the error."""
|
||||
|
||||
@property
|
||||
def is_toolkit_error(self) -> bool:
|
||||
"""Check if this error originated from loading a toolkit."""
|
||||
return self.kind.name.startswith("TOOLKIT_")
|
||||
|
||||
@property
|
||||
def is_tool_error(self) -> bool:
|
||||
"""Check if this error originated from a tool."""
|
||||
return self.kind.name.startswith("TOOL_")
|
||||
|
||||
@property
|
||||
def is_upstream_error(self) -> bool:
|
||||
"""Check if this error originated from an upstream service."""
|
||||
return self.kind.name.startswith("UPSTREAM_")
|
||||
|
||||
|
||||
class ToolCallRequiresAuthorization(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "arcade-core"
|
||||
version = "2.3.0"
|
||||
version = "2.4.0"
|
||||
description = "Arcade Core - Core library for Arcade platform"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
|
|
|||
|
|
@ -153,8 +153,8 @@ class BaseWorker(Worker):
|
|||
logger.debug(
|
||||
f"{execution_id} | duration: {duration_ms}ms | Tool output: {output.value}"
|
||||
)
|
||||
if output.error.traceback_info:
|
||||
logger.debug(f"{execution_id} | Tool traceback: {output.error.traceback_info}")
|
||||
if output.error.stacktrace:
|
||||
logger.debug(f"{execution_id} | Tool traceback: {output.error.stacktrace}")
|
||||
else:
|
||||
logger.info(
|
||||
f"{execution_id} | Tool {tool_fqname} version {tool_request.tool.version} success"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "arcade-serve"
|
||||
version = "2.0.0"
|
||||
version = "2.1.0"
|
||||
description = "Arcade Serve - Serving infrastructure for Arcade tools and workers"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
|
@ -19,7 +19,7 @@ classifiers = [
|
|||
]
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"arcade-core>=2.0.0,<3.0.0",
|
||||
"arcade-core>=2.4.0,<3.0.0",
|
||||
"fastapi>=0.115.3",
|
||||
"uvicorn>=0.30.0",
|
||||
"watchfiles>=1.0.5",
|
||||
|
|
|
|||
5
libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py
Normal file
5
libs/arcade-tdk/arcade_tdk/error_adapters/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from arcade_tdk.error_adapters.base import ErrorAdapter
|
||||
from arcade_tdk.providers.google import GoogleErrorAdapter
|
||||
from arcade_tdk.providers.http import HTTPErrorAdapter
|
||||
|
||||
__all__ = ["ErrorAdapter", "HTTPErrorAdapter", "GoogleErrorAdapter"]
|
||||
22
libs/arcade-tdk/arcade_tdk/error_adapters/base.py
Normal file
22
libs/arcade-tdk/arcade_tdk/error_adapters/base.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from arcade_tdk.errors import ToolRuntimeError
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ErrorAdapter(Protocol):
|
||||
"""
|
||||
Plugin that translates vendor-specific exceptions / responses into
|
||||
the appropriate Arcade Errors.
|
||||
"""
|
||||
|
||||
slug: str # for logging & metrics
|
||||
|
||||
def from_exception(self, exc: Exception) -> ToolRuntimeError | None:
|
||||
"""
|
||||
Translate an exception raised by an SDK, HTTP client, etc.
|
||||
into a `ToolRuntimeError` subclass.
|
||||
"""
|
||||
...
|
||||
15
libs/arcade-tdk/arcade_tdk/error_adapters/utils.py
Normal file
15
libs/arcade-tdk/arcade_tdk/error_adapters/utils.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from arcade_tdk.auth import Google, ToolAuthorization
|
||||
from arcade_tdk.error_adapters import ErrorAdapter, GoogleErrorAdapter
|
||||
|
||||
|
||||
def get_adapter_for_auth_provider(auth_provider: ToolAuthorization | None) -> ErrorAdapter | None:
|
||||
"""
|
||||
Get an error adapter from an auth provider.
|
||||
"""
|
||||
if not auth_provider:
|
||||
return None
|
||||
|
||||
if isinstance(auth_provider, Google):
|
||||
return GoogleErrorAdapter()
|
||||
|
||||
return None
|
||||
|
|
@ -1,16 +1,34 @@
|
|||
from arcade_core.errors import RetryableToolError, ToolExecutionError, ToolRuntimeError
|
||||
from arcade_core.errors import (
|
||||
ContextRequiredToolError,
|
||||
ErrorKind,
|
||||
FatalToolError,
|
||||
RetryableToolError,
|
||||
ToolExecutionError,
|
||||
ToolRuntimeError,
|
||||
UpstreamError,
|
||||
UpstreamRateLimitError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ErrorKind",
|
||||
"FatalToolError",
|
||||
"RetryableToolError",
|
||||
"SDKError",
|
||||
"ToolExecutionError",
|
||||
"ToolRuntimeError",
|
||||
"UpstreamError",
|
||||
"UpstreamRateLimitError",
|
||||
"ContextRequiredToolError",
|
||||
"WeightError",
|
||||
]
|
||||
|
||||
|
||||
class SDKError(Exception):
|
||||
"""Base class for all SDK errors."""
|
||||
"""
|
||||
DEPRECATED: Base class for all SDK errors.
|
||||
|
||||
SDKError is deprecated and will be removed in a future major version.
|
||||
"""
|
||||
|
||||
|
||||
class WeightError(SDKError):
|
||||
|
|
|
|||
0
libs/arcade-tdk/arcade_tdk/providers/__init__.py
Normal file
0
libs/arcade-tdk/arcade_tdk/providers/__init__.py
Normal file
3
libs/arcade-tdk/arcade_tdk/providers/google/__init__.py
Normal file
3
libs/arcade-tdk/arcade_tdk/providers/google/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from arcade_tdk.providers.google.error_adapter import GoogleErrorAdapter
|
||||
|
||||
__all__ = ["GoogleErrorAdapter"]
|
||||
228
libs/arcade-tdk/arcade_tdk/providers/google/error_adapter.py
Normal file
228
libs/arcade-tdk/arcade_tdk/providers/google/error_adapter.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from arcade_core.errors import (
|
||||
ToolRuntimeError,
|
||||
UpstreamError,
|
||||
UpstreamRateLimitError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GoogleErrorAdapter:
|
||||
"""Error adapter for Google's API Python Client library."""
|
||||
|
||||
slug = "_google_api_client"
|
||||
|
||||
def _sanitize_uri(self, uri: str) -> str:
|
||||
"""Strip query params and fragments from URI for privacy."""
|
||||
|
||||
parsed = urlparse(uri)
|
||||
return f"{parsed.scheme}://{parsed.netloc.strip('/')}/{parsed.path.strip('/')}"
|
||||
|
||||
def _parse_retry_after(self, error: Any) -> int:
|
||||
"""
|
||||
Extract retry-after from Google API errors.
|
||||
Returns milliseconds to wait before retry.
|
||||
Defaults to 1000ms if not found.
|
||||
|
||||
Args:
|
||||
error: The Google client error to parse
|
||||
|
||||
Returns:
|
||||
The number of milliseconds to wait before retry
|
||||
"""
|
||||
if hasattr(error, "resp") and hasattr(error.resp, "headers"):
|
||||
headers = error.resp.headers
|
||||
|
||||
retry_after = headers.get("Retry-After", headers.get("retry-after"))
|
||||
if retry_after:
|
||||
try:
|
||||
# If it's a number, it's seconds
|
||||
if retry_after.isdigit():
|
||||
return int(retry_after) * 1000
|
||||
# Otherwise try to parse as date
|
||||
dt = datetime.strptime(retry_after, "%a, %d %b %Y %H:%M:%S %Z")
|
||||
return int((dt - datetime.now(timezone.utc)).total_seconds() * 1000)
|
||||
except Exception:
|
||||
# TODO: Log?
|
||||
return 1000
|
||||
|
||||
return 1000
|
||||
|
||||
def _map_http_error(self, error: Any) -> ToolRuntimeError | None:
|
||||
"""Map Google HttpError to appropriate ToolRuntimeError."""
|
||||
status_code = error.status_code
|
||||
reason = str(error.reason) if error.reason else f"HTTP {status_code} error"
|
||||
|
||||
message = f"Upstream Google API error: {reason}"
|
||||
|
||||
developer_message = None
|
||||
if error.error_details:
|
||||
# str error details are added to the message
|
||||
if isinstance(error.error_details, str):
|
||||
message = f"{message} - Details: {error.error_details}"
|
||||
else:
|
||||
# structured error details are added to the developer message
|
||||
developer_message = f"Upstream Google API error details: {error.error_details}"
|
||||
|
||||
# Build extra metadata
|
||||
extra = {
|
||||
"service": self.slug,
|
||||
}
|
||||
|
||||
# Try to extract request details if available
|
||||
if hasattr(error, "uri"):
|
||||
extra["endpoint"] = self._sanitize_uri(error.uri)
|
||||
if hasattr(error, "method_"):
|
||||
extra["http_method"] = error.method_.upper()
|
||||
|
||||
# Special case for rate limiting
|
||||
if status_code == 429:
|
||||
return UpstreamRateLimitError(
|
||||
retry_after_ms=self._parse_retry_after(error),
|
||||
message=message,
|
||||
developer_message=developer_message,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
return UpstreamError(
|
||||
message=message,
|
||||
status_code=status_code,
|
||||
developer_message=developer_message,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
def _handle_http_errors(self, exc: Exception, errors_module: Any) -> ToolRuntimeError | None:
|
||||
"""Handle HttpError and its subclasses."""
|
||||
if isinstance(exc, errors_module.HttpError):
|
||||
return self._map_http_error(exc)
|
||||
|
||||
if isinstance(exc, errors_module.BatchError):
|
||||
# BatchError might not have status_code, so handle carefully
|
||||
if hasattr(exc, "resp") and hasattr(exc.resp, "status"):
|
||||
exc.status_code = exc.resp.status
|
||||
return self._map_http_error(exc)
|
||||
else:
|
||||
# No status code available, treat as server error
|
||||
extra = {
|
||||
"service": "google_api",
|
||||
"error_type": "BatchError",
|
||||
}
|
||||
return UpstreamError(
|
||||
message=f"Upstream Google API batch operation failed: {exc.reason}",
|
||||
status_code=500,
|
||||
extra=extra,
|
||||
)
|
||||
return None
|
||||
|
||||
def _handle_other_errors(self, exc: Exception, errors_module: Any) -> ToolRuntimeError | None:
|
||||
"""Handle non-HTTP Google API errors."""
|
||||
if isinstance(exc, errors_module.InvalidJsonError):
|
||||
return UpstreamError(
|
||||
message="Upstream Google API returned invalid JSON response",
|
||||
status_code=502,
|
||||
developer_message=str(exc),
|
||||
extra={
|
||||
"service": self.slug,
|
||||
"error_type": "InvalidJsonError",
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(exc, errors_module.UnknownApiNameOrVersion):
|
||||
return UpstreamError(
|
||||
message="Upstream Google API error: Unknown API name or version",
|
||||
status_code=404,
|
||||
developer_message=str(exc),
|
||||
extra={
|
||||
"service": self.slug,
|
||||
"error_type": "UnknownApiNameOrVersion",
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(exc, errors_module.UnacceptableMimeTypeError):
|
||||
return UpstreamError(
|
||||
message="Upstream Google API error: Unacceptable MIME type for this operation",
|
||||
status_code=400,
|
||||
developer_message=str(exc),
|
||||
extra={
|
||||
"service": self.slug,
|
||||
"error_type": "UnacceptableMimeTypeError",
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(exc, errors_module.MediaUploadSizeError):
|
||||
return UpstreamError(
|
||||
message="Upstream Google API error: Media file size exceeds allowed limit",
|
||||
status_code=400,
|
||||
developer_message=str(exc),
|
||||
extra={
|
||||
"service": self.slug,
|
||||
"error_type": "MediaUploadSizeError",
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(exc, errors_module.InvalidChunkSizeError):
|
||||
return UpstreamError(
|
||||
message="Upstream Google API error: Invalid chunk size specified",
|
||||
developer_message=str(exc),
|
||||
status_code=400,
|
||||
extra={
|
||||
"service": self.slug,
|
||||
"error_type": "InvalidChunkSizeError",
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(exc, errors_module.InvalidNotificationError):
|
||||
return UpstreamError(
|
||||
message="Upstream Google API error: Invalid notification configuration",
|
||||
developer_message=str(exc),
|
||||
status_code=400,
|
||||
extra={
|
||||
"service": self.slug,
|
||||
"error_type": "InvalidNotificationError",
|
||||
},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def from_exception(self, exc: Exception) -> ToolRuntimeError | None:
|
||||
"""
|
||||
Translate a Google API client exception into a ToolRuntimeError.
|
||||
"""
|
||||
# Lazy import the Google API client errors module to avoid import errors for toolkits that don't use googleapiclient
|
||||
try:
|
||||
from googleapiclient import errors
|
||||
except ImportError:
|
||||
logger.info(
|
||||
f"'googleapiclient' is not installed in the toolkit's environment, "
|
||||
f"so the '{self.slug}' adapter was not used to handle the upstream error"
|
||||
)
|
||||
return None
|
||||
|
||||
# Try HTTP errors first
|
||||
result = self._handle_http_errors(exc, errors)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# Then try other error types
|
||||
result = self._handle_other_errors(exc, errors)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# Failsafe for any unhandled Google API client errors that are not mapped above
|
||||
if hasattr(exc, "__module__") and exc.__module__ == "googleapiclient.errors":
|
||||
return UpstreamError(
|
||||
message=f"Upstream Google API error: {exc}",
|
||||
status_code=500,
|
||||
extra={
|
||||
"service": self.slug,
|
||||
"error_type": exc.__class__.__name__,
|
||||
},
|
||||
)
|
||||
|
||||
# Not a Google API client error
|
||||
return None
|
||||
3
libs/arcade-tdk/arcade_tdk/providers/http/__init__.py
Normal file
3
libs/arcade-tdk/arcade_tdk/providers/http/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from arcade_tdk.providers.http.error_adapter import HTTPErrorAdapter
|
||||
|
||||
__all__ = ["HTTPErrorAdapter"]
|
||||
200
libs/arcade-tdk/arcade_tdk/providers/http/error_adapter.py
Normal file
200
libs/arcade-tdk/arcade_tdk/providers/http/error_adapter.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from arcade_core.errors import (
|
||||
UpstreamError,
|
||||
UpstreamRateLimitError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RATE_HEADERS = ("retry-after", "x-ratelimit-reset", "x-ratelimit-reset-ms")
|
||||
|
||||
|
||||
class BaseHTTPErrorMapper:
|
||||
"""Base class for HTTP error mapping functionality."""
|
||||
|
||||
def _parse_retry_ms(self, headers: dict[str, str]) -> int:
|
||||
"""
|
||||
Parses a rate limit header and returns the number
|
||||
of milliseconds until the rate limit resets.
|
||||
|
||||
Args:
|
||||
headers: A dictionary of HTTP headers.
|
||||
|
||||
Returns:
|
||||
The number of milliseconds until the rate limit resets.
|
||||
Defaults to 1000ms if a rate limit header is not found or cannot be parsed.
|
||||
"""
|
||||
val = next((headers.get(h) for h in RATE_HEADERS if headers.get(h)), None)
|
||||
# No rate limit header found
|
||||
if val is None:
|
||||
return 1_000
|
||||
# Rate limit header is a number of seconds
|
||||
if val.isdigit():
|
||||
key = next((h for h in RATE_HEADERS if headers.get(h) == val), "")
|
||||
if key.endswith("ms"):
|
||||
return int(val)
|
||||
return int(val) * 1_000
|
||||
# Rate limit header is an absolute date
|
||||
try:
|
||||
dt = datetime.strptime(val, "%a, %d %b %Y %H:%M:%S %Z")
|
||||
return int((dt - datetime.now(timezone.utc)).total_seconds() * 1_000)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to parse rate limit header: {val}. Defaulting to 1000ms.")
|
||||
return 1_000
|
||||
|
||||
def _sanitize_uri(self, uri: str) -> str:
|
||||
"""Strip query params and fragments from URI for privacy."""
|
||||
|
||||
parsed = urlparse(uri)
|
||||
return f"{parsed.scheme}://{parsed.netloc.strip('/')}/{parsed.path.strip('/')}"
|
||||
|
||||
def _build_extra_metadata(
|
||||
self, request_url: str | None = None, request_method: str | None = None
|
||||
) -> dict[str, str]:
|
||||
"""Build extra metadata for error reporting."""
|
||||
extra = {
|
||||
"service": HTTPErrorAdapter.slug,
|
||||
}
|
||||
|
||||
if request_url:
|
||||
extra["endpoint"] = self._sanitize_uri(request_url)
|
||||
|
||||
if request_method:
|
||||
extra["http_method"] = request_method.upper()
|
||||
|
||||
return extra
|
||||
|
||||
def _map_status_to_error(
|
||||
self,
|
||||
status: int,
|
||||
headers: dict[str, str],
|
||||
msg: str,
|
||||
request_url: str | None = None,
|
||||
request_method: str | None = None,
|
||||
) -> UpstreamError:
|
||||
"""Map HTTP status code to appropriate Arcade error."""
|
||||
extra = self._build_extra_metadata(request_url, request_method)
|
||||
|
||||
# Special case for rate limiting
|
||||
if status == 429:
|
||||
return UpstreamRateLimitError(
|
||||
retry_after_ms=self._parse_retry_ms(headers),
|
||||
message=msg,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
return UpstreamError(message=msg, status_code=status, extra=extra)
|
||||
|
||||
|
||||
class _HTTPXExceptionHandler:
|
||||
"""Handler for httpx-specific exceptions."""
|
||||
|
||||
def handle_exception(self, exc: Any, mapper: BaseHTTPErrorMapper) -> UpstreamError | None:
|
||||
"""Handle httpx HTTPStatusError exceptions.
|
||||
|
||||
Args:
|
||||
exc: An httpx.HTTPStatusError exception
|
||||
mapper: The BaseHTTPErrorMapper instance to use for mapping
|
||||
|
||||
Returns:
|
||||
An Arcade error instance or None if not an httpx exception
|
||||
"""
|
||||
# Lazy import httpx types locally to avoid import errors for toolkits that don't use httpx
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
if not isinstance(exc, httpx.HTTPStatusError):
|
||||
return None
|
||||
|
||||
response = exc.response
|
||||
request_url = None
|
||||
request_method = None
|
||||
if hasattr(exc, "request") and exc.request:
|
||||
request_url = str(exc.request.url)
|
||||
request_method = exc.request.method
|
||||
|
||||
return mapper._map_status_to_error(
|
||||
response.status_code,
|
||||
dict(response.headers),
|
||||
str(exc),
|
||||
request_url=request_url,
|
||||
request_method=request_method,
|
||||
)
|
||||
|
||||
|
||||
class _RequestsExceptionHandler:
|
||||
"""Handler for requests-specific exceptions."""
|
||||
|
||||
def handle_exception(self, exc: Any, mapper: BaseHTTPErrorMapper) -> UpstreamError | None:
|
||||
"""Handle requests library exceptions.
|
||||
|
||||
Args:
|
||||
exc: A requests.exceptions.HTTPError exception
|
||||
mapper: The BaseHTTPErrorMapper instance to use for mapping
|
||||
|
||||
Returns:
|
||||
An Arcade error instance or None if not a requests exception
|
||||
"""
|
||||
# Lazy import requests types locally to avoid import errors for toolkits that don't use requests
|
||||
try:
|
||||
from requests.exceptions import HTTPError # type: ignore[import-untyped]
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
if not isinstance(exc, HTTPError):
|
||||
return None
|
||||
|
||||
response = getattr(exc, "response", None)
|
||||
if response is None:
|
||||
return None
|
||||
|
||||
# Extract request information
|
||||
request_url = None
|
||||
request_method = None
|
||||
if hasattr(response, "request") and response.request:
|
||||
request_url = response.request.url
|
||||
request_method = response.request.method
|
||||
elif hasattr(response, "url"):
|
||||
request_url = response.url
|
||||
|
||||
return mapper._map_status_to_error(
|
||||
response.status_code,
|
||||
dict(response.headers),
|
||||
str(exc),
|
||||
request_url=request_url,
|
||||
request_method=request_method,
|
||||
)
|
||||
|
||||
|
||||
class HTTPErrorAdapter(BaseHTTPErrorMapper):
|
||||
"""Main HTTP error adapter that supports multiple HTTP libraries."""
|
||||
|
||||
slug = "_http"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._httpx_handler = _HTTPXExceptionHandler()
|
||||
self._requests_handler = _RequestsExceptionHandler()
|
||||
|
||||
def from_exception(self, exc: Exception) -> UpstreamError | None:
|
||||
"""Convert HTTP library exceptions into Arcade errors."""
|
||||
|
||||
httpx_result = self._httpx_handler.handle_exception(exc, self)
|
||||
if httpx_result is not None:
|
||||
return httpx_result
|
||||
|
||||
requests_result = self._requests_handler.handle_exception(exc, self)
|
||||
if requests_result is not None:
|
||||
return requests_result
|
||||
|
||||
logger.info(
|
||||
f"Exception type '{type(exc).__name__}' was not handled by the '{self.slug}' adapter. "
|
||||
f"Either the exception is not from a supported HTTP library (httpx, requests) or "
|
||||
f"the required library is not installed in the toolkit's environment."
|
||||
)
|
||||
return None
|
||||
|
|
@ -1,21 +1,104 @@
|
|||
import functools
|
||||
import inspect
|
||||
from typing import Any, Callable, TypeVar, Union
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from arcade_tdk.auth import ToolAuthorization
|
||||
from arcade_tdk.errors import ToolExecutionError
|
||||
from arcade_tdk.error_adapters import ErrorAdapter
|
||||
from arcade_tdk.error_adapters.utils import get_adapter_for_auth_provider
|
||||
from arcade_tdk.errors import (
|
||||
FatalToolError,
|
||||
ToolRuntimeError,
|
||||
)
|
||||
from arcade_tdk.providers.http import HTTPErrorAdapter
|
||||
from arcade_tdk.utils import snake_to_pascal_case
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _build_adapter_chain(
|
||||
adapters: list[ErrorAdapter] | None, auth_provider: ToolAuthorization | None
|
||||
) -> list[ErrorAdapter]:
|
||||
"""
|
||||
Build the adapter chain for error handling.
|
||||
|
||||
Args:
|
||||
adapters: User-provided list of error adapters
|
||||
auth_provider: The auth provider for the tool
|
||||
|
||||
Returns:
|
||||
A deduplicated list of error adapters with the HTTP adapter as fallback
|
||||
|
||||
Raises:
|
||||
ValueError: If any adapter doesn't follow the ErrorAdapter protocol
|
||||
"""
|
||||
adapter_chain = adapters or []
|
||||
|
||||
# Validate that all adapters follow the ErrorAdapter protocol
|
||||
if not all(isinstance(adapter, ErrorAdapter) for adapter in adapter_chain):
|
||||
invalid_adapters = [
|
||||
type(adapter).__name__
|
||||
for adapter in adapter_chain
|
||||
if not isinstance(adapter, ErrorAdapter)
|
||||
]
|
||||
raise ValueError(
|
||||
f"All adapters must follow the ErrorAdapter protocol. "
|
||||
f"Invalid adapters: {', '.join(invalid_adapters)}"
|
||||
)
|
||||
|
||||
# Add the adapter that is mapped to the tool's auth provider if it exists
|
||||
if auth_adapter := get_adapter_for_auth_provider(auth_provider):
|
||||
adapter_chain.append(auth_adapter)
|
||||
|
||||
# Always add HTTP adapter as the final adapter fallback
|
||||
adapter_chain.append(HTTPErrorAdapter())
|
||||
|
||||
# Remove duplicates from the adapter chain, preserving order
|
||||
seen_types = set()
|
||||
deduplicated_chain = []
|
||||
for adapter in adapter_chain:
|
||||
adapter_type = type(adapter)
|
||||
if adapter_type not in seen_types:
|
||||
seen_types.add(adapter_type)
|
||||
deduplicated_chain.append(adapter)
|
||||
|
||||
return deduplicated_chain
|
||||
|
||||
|
||||
def _raise_as_arcade_error(
|
||||
exception: Exception, adapter_chain: list[ErrorAdapter], tool_name: str, func_name: str
|
||||
) -> None:
|
||||
"""
|
||||
Try to translate an exception using the adapter chain, then raise the translated error.
|
||||
If no adapter can translate the exception, a FatalToolError is raised.
|
||||
|
||||
Args:
|
||||
exception: The exception to translate to an Arcade Error
|
||||
adapter_chain: List of error adapters to try
|
||||
tool_name: The tool's display name for error messages
|
||||
func_name: The function name for developer messages
|
||||
|
||||
Raises:
|
||||
ToolRuntimeError or some subclass thereof
|
||||
"""
|
||||
for adapter in adapter_chain:
|
||||
mapped = adapter.from_exception(exception)
|
||||
if isinstance(mapped, ToolRuntimeError):
|
||||
raise mapped from exception
|
||||
|
||||
raise FatalToolError(
|
||||
message=f"{exception!s}",
|
||||
developer_message=f"{exception!s}",
|
||||
) from exception
|
||||
|
||||
|
||||
def tool(
|
||||
func: Callable | None = None,
|
||||
desc: str | None = None,
|
||||
name: str | None = None,
|
||||
requires_auth: Union[ToolAuthorization, None] = None,
|
||||
requires_secrets: Union[list[str], None] = None,
|
||||
requires_metadata: Union[list[str], None] = None,
|
||||
requires_auth: ToolAuthorization | None = None,
|
||||
requires_secrets: list[str] | None = None,
|
||||
requires_metadata: list[str] | None = None,
|
||||
adapters: list[ErrorAdapter] | None = None,
|
||||
) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func_name = str(getattr(func, "__name__", None))
|
||||
|
|
@ -27,22 +110,19 @@ def tool(
|
|||
func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined]
|
||||
func.__tool_requires_metadata__ = requires_metadata # type: ignore[attr-defined]
|
||||
|
||||
adapter_chain = _build_adapter_chain(adapters, requires_auth)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def func_with_error_handling(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
# make sure developer raised ToolExecutionError is not
|
||||
# reraised incorrectly.
|
||||
except ToolExecutionError:
|
||||
except ToolRuntimeError:
|
||||
# re-raise as-is if it is already an Arcade Error
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ToolExecutionError(
|
||||
message=f"Error in execution of {tool_name}",
|
||||
developer_message=f"Error in {func_name}: {e!s}",
|
||||
) from e
|
||||
_raise_as_arcade_error(e, adapter_chain, tool_name, func_name)
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -50,13 +130,11 @@ def tool(
|
|||
def func_with_error_handling(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except ToolExecutionError:
|
||||
except ToolRuntimeError:
|
||||
# re-raise as-is if it is already an Arcade Error
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ToolExecutionError(
|
||||
message=f"Error in execution of {tool_name}",
|
||||
developer_message=f"Error in {func_name}: {e!s}",
|
||||
) from e
|
||||
_raise_as_arcade_error(e, adapter_chain, tool_name, func_name)
|
||||
|
||||
return func_with_error_handling
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "arcade-tdk"
|
||||
version = "2.2.0"
|
||||
version = "2.3.0"
|
||||
description = "Arcade TDK - Toolkit Development Kit for building Arcade tools"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
|
@ -19,7 +19,7 @@ classifiers = [
|
|||
]
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"arcade-core>=2.3.0,<3.0.0",
|
||||
"arcade-core>=2.4.0,<3.0.0",
|
||||
"pydantic>=2.7.0",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,18 @@
|
|||
from typing import Annotated, Union
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from arcade_core.catalog import ToolCatalog
|
||||
from arcade_core.errors import ToolDefinitionError
|
||||
from arcade_core.schema import FullyQualifiedName
|
||||
from arcade_core.errors import (
|
||||
ToolDefinitionError,
|
||||
ToolInputSchemaError,
|
||||
ToolkitLoadError,
|
||||
ToolOutputSchemaError,
|
||||
)
|
||||
from arcade_core.schema import FullyQualifiedName, ToolContext
|
||||
from arcade_core.toolkit import Toolkit
|
||||
from arcade_tdk import tool
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -16,6 +23,141 @@ def sample_tool() -> str:
|
|||
return "Hello, world!"
|
||||
|
||||
|
||||
@tool
|
||||
def valid_tool(input_text: Annotated[str, "The text to process"]) -> str:
|
||||
"""
|
||||
A test tool that processes input text.
|
||||
|
||||
Args:
|
||||
input_text: The text to process
|
||||
|
||||
Returns:
|
||||
The processed text
|
||||
"""
|
||||
return f"Processed: {input_text}"
|
||||
|
||||
|
||||
@tool
|
||||
def tool_with_missing_input_parameter_annotation(input_text: str) -> str:
|
||||
"""
|
||||
A test tool that processes input text.
|
||||
|
||||
Args:
|
||||
input_text: The text to process
|
||||
|
||||
Returns:
|
||||
The processed text
|
||||
"""
|
||||
return f"Processed: {input_text}"
|
||||
|
||||
|
||||
# Invalid tool examples for testing error cases
|
||||
|
||||
|
||||
# ToolDefinitionError cases
|
||||
def tool_missing_description(input_text: Annotated[str, "The text to process"]) -> str:
|
||||
return f"Processed: {input_text}"
|
||||
|
||||
|
||||
@tool(requires_secrets=[123]) # type: ignore[misc]
|
||||
def tool_with_invalid_secret_type(input_text: Annotated[str, "The text"]) -> str:
|
||||
"""A tool with invalid secret type."""
|
||||
return f"Processed: {input_text}"
|
||||
|
||||
|
||||
@tool(requires_secrets=[""])
|
||||
def tool_with_empty_secret(input_text: Annotated[str, "The text"]) -> str:
|
||||
"""A tool with empty secret."""
|
||||
return f"Processed: {input_text}"
|
||||
|
||||
|
||||
@tool(requires_metadata=[123]) # type: ignore[misc]
|
||||
def tool_with_invalid_metadata_type(input_text: Annotated[str, "The text"]) -> str:
|
||||
"""A tool with invalid metadata type."""
|
||||
return f"Processed: {input_text}"
|
||||
|
||||
|
||||
@tool(requires_metadata=["client_id"]) # Requires auth but no auth provided
|
||||
def tool_with_metadata_requiring_auth_without_auth(input_text: Annotated[str, "The text"]) -> str:
|
||||
"""A tool with metadata requiring auth but no auth provided."""
|
||||
return f"Processed: {input_text}"
|
||||
|
||||
|
||||
@tool(requires_metadata=[""])
|
||||
def tool_with_empty_metadata(input_text: Annotated[str, "The text"]) -> str:
|
||||
"""A tool with empty metadata."""
|
||||
return f"Processed: {input_text}"
|
||||
|
||||
|
||||
class MyFancyTestClass:
|
||||
pass
|
||||
|
||||
|
||||
@tool
|
||||
def tool_with_unsupported_param_type(
|
||||
param: Annotated[MyFancyTestClass, "A class that is a parameter"],
|
||||
) -> str:
|
||||
"""A tool with unsupported parameter type."""
|
||||
return "result"
|
||||
|
||||
|
||||
# ToolInputSchemaError cases
|
||||
@tool
|
||||
def tool_with_no_type_annotation(param) -> str: # type: ignore[no-untyped-def]
|
||||
"""A tool with untyped parameter."""
|
||||
return f"Result: {param}"
|
||||
|
||||
|
||||
@tool
|
||||
def tool_with_invalid_param_name(param: Annotated[str, "123invalid", "Description"]) -> str:
|
||||
"""A tool with invalid parameter name."""
|
||||
return f"Result: {param}"
|
||||
|
||||
|
||||
@tool
|
||||
def tool_with_too_many_annotations(param: Annotated[str, "name", "desc", "extra"]) -> str:
|
||||
"""A tool with an input parameter that has too many annotations."""
|
||||
return f"Result: {param}"
|
||||
|
||||
|
||||
@tool
|
||||
def tool_with_required_union_param(param: Annotated[Union[str, int], "Union parameter"]) -> str:
|
||||
"""A tool with an input parameter that is a non-optional union type."""
|
||||
return f"Result: {param}"
|
||||
|
||||
|
||||
def non_callable_factory():
|
||||
raise RuntimeError("This should not be called")
|
||||
|
||||
|
||||
@tool
|
||||
def tool_with_non_callable_default_factory(
|
||||
param: Annotated[str, "Parameter"] = Field(default_factory="not_callable"), # type: ignore[arg-type]
|
||||
) -> str:
|
||||
"""A tool with an input parameter that has a non-callable default factory."""
|
||||
return f"Result: {param}"
|
||||
|
||||
|
||||
@tool
|
||||
def tool_with_multiple_tool_contexts(ctx1: ToolContext, ctx2: ToolContext) -> str:
|
||||
"""A tool with multiple input parameters that are ToolContext."""
|
||||
return "result"
|
||||
|
||||
|
||||
@tool
|
||||
def tool_missing_return_type_hint(input_text: Annotated[str, "The text to process"]):
|
||||
"""A tool without return type hint."""
|
||||
return f"Processed: {input_text}"
|
||||
|
||||
|
||||
@tool
|
||||
def tool_with_unsupported_output_type(
|
||||
input_text: Annotated[str, "The text to process"],
|
||||
) -> Annotated[MyFancyTestClass, "THe output type"]:
|
||||
"""A tool with an output parameter type that is not supported."""
|
||||
return MyFancyTestClass()
|
||||
|
||||
|
||||
def test_add_tool_with_empty_toolkit_name_raises():
|
||||
catalog = ToolCatalog()
|
||||
with pytest.raises(ValueError):
|
||||
|
|
@ -98,6 +240,35 @@ def test_add_toolkit_type_error():
|
|||
catalog.add_toolkit(mock_toolkit)
|
||||
|
||||
|
||||
def test_add_toolkit_import_module_error():
|
||||
catalog = ToolCatalog()
|
||||
|
||||
# Create a mock toolkit with an invalid tool
|
||||
|
||||
mock_toolkit = Toolkit(
|
||||
name="mock_toolkit",
|
||||
description="A mock toolkit",
|
||||
version="0.0.1",
|
||||
package_name="mock_toolkit",
|
||||
)
|
||||
mock_toolkit.tools = {"mock_module": ["sample_tool"]}
|
||||
|
||||
# Mock the import_module and getattr functions
|
||||
with (
|
||||
patch("arcade_core.catalog.import_module") as mock_import,
|
||||
):
|
||||
mock_import.side_effect = ImportError("Mock import error")
|
||||
|
||||
# Assert that ToolkitLoadError is raised
|
||||
with pytest.raises(ToolkitLoadError) as exc_info:
|
||||
catalog.add_toolkit(mock_toolkit)
|
||||
|
||||
# Check that the error message contains the expected substring
|
||||
assert "Could not import module mock_module. Reason: Mock import error" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_by_name():
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_tool(sample_tool, "sample_toolkit")
|
||||
|
|
@ -190,3 +361,157 @@ def test_add_tool_with_disabled_toolkit(monkeypatch):
|
|||
)
|
||||
)
|
||||
assert len(catalog._tools) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_name, expected_error_type, expected_error_substring",
|
||||
[
|
||||
# ToolDefinitionError cases
|
||||
(
|
||||
"tool_missing_description",
|
||||
ToolDefinitionError,
|
||||
"Tool 'tool_missing_description' is missing a description",
|
||||
),
|
||||
(
|
||||
"tool_with_invalid_secret_type",
|
||||
ToolDefinitionError,
|
||||
"Secret keys must be strings (error in tool ToolWithInvalidSecretType)",
|
||||
),
|
||||
(
|
||||
"tool_with_empty_secret",
|
||||
ToolDefinitionError,
|
||||
"Secrets must have a non-empty key (error in tool ToolWithEmptySecret)",
|
||||
),
|
||||
(
|
||||
"tool_with_invalid_metadata_type",
|
||||
ToolDefinitionError,
|
||||
"Metadata must be strings (error in tool ToolWithInvalidMetadataType)",
|
||||
),
|
||||
(
|
||||
"tool_with_metadata_requiring_auth_without_auth",
|
||||
ToolDefinitionError,
|
||||
"Tool ToolWithMetadataRequiringAuthWithoutAuth declares metadata key 'client_id'",
|
||||
),
|
||||
(
|
||||
"tool_with_empty_metadata",
|
||||
ToolDefinitionError,
|
||||
"Metadata must have a non-empty key (error in tool ToolWithEmptyMetadata)",
|
||||
),
|
||||
(
|
||||
"tool_with_unsupported_param_type",
|
||||
ToolDefinitionError,
|
||||
"Unsupported parameter type: <class 'test_catalog.MyFancyTestClass'>",
|
||||
),
|
||||
# ToolInputSchemaError cases
|
||||
(
|
||||
"tool_with_missing_input_parameter_annotation",
|
||||
ToolInputSchemaError,
|
||||
"Parameter 'input_text' is missing a description",
|
||||
),
|
||||
(
|
||||
"tool_with_no_type_annotation",
|
||||
ToolInputSchemaError,
|
||||
"Parameter param has no type annotation",
|
||||
),
|
||||
(
|
||||
"tool_with_invalid_param_name",
|
||||
ToolInputSchemaError,
|
||||
"Invalid parameter name: '123invalid' is not a valid identifier",
|
||||
),
|
||||
(
|
||||
"tool_with_too_many_annotations",
|
||||
ToolInputSchemaError,
|
||||
"Parameter param: Annotated[str, 'name', 'desc', 'extra'] has too many string annotations. Expected 0, 1, or 2, got 3",
|
||||
),
|
||||
(
|
||||
"tool_with_required_union_param",
|
||||
ToolInputSchemaError,
|
||||
"Parameter param is a union type. Only optional types are supported",
|
||||
),
|
||||
(
|
||||
"tool_with_non_callable_default_factory",
|
||||
ToolInputSchemaError,
|
||||
"Default factory for parameter param: Annotated[str, 'Parameter'] = FieldInfo(annotation=NoneType, required=False, default_factory=str) is not callable.",
|
||||
),
|
||||
(
|
||||
"tool_with_multiple_tool_contexts",
|
||||
ToolInputSchemaError,
|
||||
"Only one ToolContext parameter is supported, but tool tool_with_multiple_tool_contexts has multiple",
|
||||
),
|
||||
(
|
||||
"tool_missing_return_type_hint",
|
||||
ToolOutputSchemaError,
|
||||
"Tool 'ToolMissingReturnTypeHint' must have a return type",
|
||||
),
|
||||
(
|
||||
"tool_with_unsupported_output_type",
|
||||
ToolOutputSchemaError,
|
||||
"Unsupported output type '<class 'test_catalog.MyFancyTestClass'>'",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_add_toolkit_with_invalid_tools(
|
||||
tool_name: str, expected_error_type: type, expected_error_substring: str
|
||||
):
|
||||
"""Test that add_toolkit raises the correct error for various invalid tool definitions."""
|
||||
catalog = ToolCatalog()
|
||||
|
||||
# Create a toolkit that references our test tool
|
||||
test_toolkit = Toolkit(
|
||||
name="test_toolkit",
|
||||
description="A test toolkit",
|
||||
version="1.0.0",
|
||||
package_name="test_toolkit",
|
||||
)
|
||||
test_toolkit.tools = {"tests.core.test_catalog": [tool_name]}
|
||||
|
||||
# Mock the import_module to return the current module
|
||||
import sys
|
||||
|
||||
current_module = sys.modules[__name__]
|
||||
|
||||
with patch("arcade_core.catalog.import_module") as mock_import:
|
||||
mock_import.return_value = current_module
|
||||
|
||||
# Add the toolkit and expect the specified error
|
||||
with pytest.raises(expected_error_type) as exc_info:
|
||||
catalog.add_toolkit(test_toolkit)
|
||||
|
||||
# Check that the error message contains the expected substring
|
||||
actual_error_message = str(exc_info.value)
|
||||
# Adjust for Python 3.11 and below where Annotated is returned as "typing.Annotated"
|
||||
if "typing.Annotated" in actual_error_message:
|
||||
expected_error_substring = expected_error_substring.replace(
|
||||
"Annotated", "typing.Annotated"
|
||||
)
|
||||
assert expected_error_substring in actual_error_message
|
||||
|
||||
|
||||
def test_add_toolkit_with_duplicate_tool():
|
||||
"""Test that add_toolkit raises ToolkitLoadError when a tool already exists in the catalog."""
|
||||
catalog = ToolCatalog()
|
||||
|
||||
test_toolkit = Toolkit(
|
||||
name="test_toolkit",
|
||||
description="A test toolkit",
|
||||
version="1.0.0",
|
||||
package_name="test_toolkit",
|
||||
)
|
||||
test_toolkit.tools = {"tests.core.test_catalog": ["valid_tool", "valid_tool"]}
|
||||
|
||||
# Mock the import_module to return the current module
|
||||
import sys
|
||||
|
||||
current_module = sys.modules[__name__]
|
||||
|
||||
with patch("arcade_core.catalog.import_module") as mock_import:
|
||||
mock_import.return_value = current_module
|
||||
|
||||
# Adding the toolkit should raise ToolkitLoadError for duplicate tool
|
||||
with pytest.raises(ToolkitLoadError) as exc_info:
|
||||
catalog.add_toolkit(test_toolkit)
|
||||
|
||||
# Check that the error message contains the expected substring
|
||||
assert "Tool 'ValidTool' in toolkit 'test_toolkit' already exists in the catalog." in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,20 @@ from typing import Annotated
|
|||
|
||||
import pytest
|
||||
from arcade_core.catalog import ToolCatalog
|
||||
from arcade_core.errors import (
|
||||
ContextRequiredToolError,
|
||||
ErrorKind,
|
||||
ToolRuntimeError,
|
||||
UpstreamError,
|
||||
UpstreamRateLimitError,
|
||||
)
|
||||
from arcade_core.executor import ToolExecutor
|
||||
from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput, ToolContext
|
||||
from arcade_tdk import tool
|
||||
from arcade_tdk.errors import RetryableToolError, ToolExecutionError
|
||||
from arcade_tdk.errors import (
|
||||
RetryableToolError,
|
||||
ToolExecutionError,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
|
|
@ -25,13 +35,13 @@ def simple_deprecated_tool(inp: Annotated[str, "input"]) -> Annotated[str, "outp
|
|||
@tool
|
||||
def retryable_error_tool() -> Annotated[str, "output"]:
|
||||
"""Tool that raises a retryable error"""
|
||||
raise RetryableToolError("test", "test", "test", 1000)
|
||||
raise RetryableToolError("test", "test developer message", "additional prompt content", 1000)
|
||||
|
||||
|
||||
@tool
|
||||
def exec_error_tool() -> Annotated[str, "output"]:
|
||||
def tool_execution_error_tool() -> Annotated[str, "output"]:
|
||||
"""Tool that raises an error"""
|
||||
raise ToolExecutionError("test", "test")
|
||||
raise ToolExecutionError("test", "test developer message")
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -40,6 +50,34 @@ def unexpected_error_tool() -> Annotated[str, "output"]:
|
|||
raise RuntimeError("test")
|
||||
|
||||
|
||||
@tool
|
||||
def context_required_error_tool() -> Annotated[str, "output"]:
|
||||
"""Tool that raises a context required error"""
|
||||
raise ContextRequiredToolError(
|
||||
"test", additional_prompt_content="need the user to clarify something"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
def upstream_error_tool() -> Annotated[str, "output"]:
|
||||
"""Tool that raises an upstream error"""
|
||||
# TODO: or test raising a httpx error? Do these types of tests belong in adapter tests?
|
||||
raise UpstreamError("test", status_code=400)
|
||||
|
||||
|
||||
@tool
|
||||
def upstream_ratelimit_error_tool() -> Annotated[str, "output"]:
|
||||
"""Tool that raises an upstream error"""
|
||||
# TODO: or test raising a httpx error? Do these types of tests belong in adapter tests?
|
||||
raise UpstreamRateLimitError("test", 1000)
|
||||
|
||||
|
||||
@tool
|
||||
def tool_runtime_error_tool() -> Annotated[str, "output"]:
|
||||
"""Tool that raises a tool runtime error"""
|
||||
raise ToolRuntimeError("test", "test developer message")
|
||||
|
||||
|
||||
@tool
|
||||
def bad_output_error_tool() -> Annotated[str, "output"]:
|
||||
"""tool that returns a bad output type"""
|
||||
|
|
@ -77,17 +115,24 @@ def dict_output_tool() -> Annotated[dict, "Returns a plain dict"]:
|
|||
|
||||
|
||||
# ---- Test Driver ----
|
||||
|
||||
tools = [
|
||||
simple_tool,
|
||||
simple_deprecated_tool,
|
||||
retryable_error_tool,
|
||||
tool_execution_error_tool,
|
||||
unexpected_error_tool,
|
||||
context_required_error_tool,
|
||||
upstream_error_tool,
|
||||
upstream_ratelimit_error_tool,
|
||||
tool_runtime_error_tool,
|
||||
bad_output_error_tool,
|
||||
typeddict_output_tool,
|
||||
list_typeddict_output_tool,
|
||||
dict_output_tool,
|
||||
]
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_tool(simple_tool, "simple_toolkit")
|
||||
catalog.add_tool(simple_deprecated_tool, "simple_toolkit")
|
||||
catalog.add_tool(retryable_error_tool, "simple_toolkit")
|
||||
catalog.add_tool(exec_error_tool, "simple_toolkit")
|
||||
catalog.add_tool(unexpected_error_tool, "simple_toolkit")
|
||||
catalog.add_tool(bad_output_error_tool, "simple_toolkit")
|
||||
catalog.add_tool(typeddict_output_tool, "simple_toolkit")
|
||||
catalog.add_tool(list_typeddict_output_tool, "simple_toolkit")
|
||||
catalog.add_tool(dict_output_tool, "simple_toolkit")
|
||||
for tool_func in tools:
|
||||
catalog.add_tool(tool_func, "simple_toolkit")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -114,21 +159,24 @@ catalog.add_tool(dict_output_tool, "simple_toolkit")
|
|||
{},
|
||||
ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="test",
|
||||
developer_message="test",
|
||||
additional_prompt_content="test",
|
||||
message="[TOOL_RUNTIME_RETRY] RetryableToolError during execution of tool 'retryable_error_tool': test",
|
||||
kind=ErrorKind.TOOL_RUNTIME_RETRY,
|
||||
developer_message="[TOOL_RUNTIME_RETRY] RetryableToolError during execution of tool 'retryable_error_tool': test developer message",
|
||||
additional_prompt_content="additional prompt content",
|
||||
retry_after_ms=1000,
|
||||
can_retry=True,
|
||||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
exec_error_tool,
|
||||
tool_execution_error_tool,
|
||||
{},
|
||||
ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="test",
|
||||
developer_message="test",
|
||||
message="[TOOL_RUNTIME_FATAL] ToolExecutionError during execution of tool 'tool_execution_error_tool': test",
|
||||
kind=ErrorKind.TOOL_RUNTIME_FATAL,
|
||||
developer_message="[TOOL_RUNTIME_FATAL] ToolExecutionError during execution of tool 'tool_execution_error_tool': test developer message",
|
||||
can_retry=False,
|
||||
)
|
||||
),
|
||||
),
|
||||
|
|
@ -137,8 +185,11 @@ catalog.add_tool(dict_output_tool, "simple_toolkit")
|
|||
{},
|
||||
ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="Error in execution of UnexpectedErrorTool",
|
||||
developer_message="Error in unexpected_error_tool: test",
|
||||
message="[TOOL_RUNTIME_FATAL] FatalToolError during execution of tool 'unexpected_error_tool': test",
|
||||
kind=ErrorKind.TOOL_RUNTIME_FATAL,
|
||||
developer_message="[TOOL_RUNTIME_FATAL] FatalToolError during execution of tool 'unexpected_error_tool': test",
|
||||
can_retry=False,
|
||||
status_code=500,
|
||||
)
|
||||
),
|
||||
),
|
||||
|
|
@ -147,17 +198,71 @@ catalog.add_tool(dict_output_tool, "simple_toolkit")
|
|||
{"inp": {"test": "test"}}, # takes in a string not a dict
|
||||
ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="Error in tool input deserialization",
|
||||
message="[TOOL_RUNTIME_BAD_INPUT_VALUE] ToolInputError during execution of tool 'simple_tool': Error in tool input deserialization",
|
||||
kind=ErrorKind.TOOL_RUNTIME_BAD_INPUT_VALUE,
|
||||
status_code=400,
|
||||
developer_message=None, # can't gaurantee this will be the same
|
||||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
context_required_error_tool,
|
||||
{},
|
||||
ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="[TOOL_RUNTIME_CONTEXT_REQUIRED] ContextRequiredToolError during execution of tool 'context_required_error_tool': test",
|
||||
kind=ErrorKind.TOOL_RUNTIME_CONTEXT_REQUIRED,
|
||||
developer_message=None,
|
||||
additional_prompt_content="need the user to clarify something",
|
||||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
upstream_error_tool,
|
||||
{},
|
||||
ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="[UPSTREAM_RUNTIME_BAD_REQUEST] UpstreamError during execution of tool 'upstream_error_tool': test",
|
||||
kind=ErrorKind.UPSTREAM_RUNTIME_BAD_REQUEST,
|
||||
status_code=400,
|
||||
developer_message=None,
|
||||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
upstream_ratelimit_error_tool,
|
||||
{},
|
||||
ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="[UPSTREAM_RUNTIME_RATE_LIMIT] UpstreamRateLimitError during execution of tool 'upstream_ratelimit_error_tool': test",
|
||||
kind=ErrorKind.UPSTREAM_RUNTIME_RATE_LIMIT,
|
||||
status_code=429,
|
||||
developer_message=None,
|
||||
retry_after_ms=1000,
|
||||
can_retry=True,
|
||||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
tool_runtime_error_tool,
|
||||
{},
|
||||
ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="[TOOL_RUNTIME_FATAL] ToolRuntimeError during execution of tool 'tool_runtime_error_tool': test",
|
||||
kind=ErrorKind.TOOL_RUNTIME_FATAL,
|
||||
developer_message="[TOOL_RUNTIME_FATAL] ToolRuntimeError during execution of tool 'tool_runtime_error_tool': test developer message",
|
||||
can_retry=False,
|
||||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
bad_output_error_tool,
|
||||
{},
|
||||
ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="Failed to serialize tool output",
|
||||
message="[TOOL_RUNTIME_BAD_OUTPUT_VALUE] ToolOutputError during execution of tool 'bad_output_error_tool': Failed to serialize tool output",
|
||||
kind=ErrorKind.TOOL_RUNTIME_BAD_OUTPUT_VALUE,
|
||||
status_code=500,
|
||||
developer_message=None, # can't gaurantee this will be the same
|
||||
)
|
||||
),
|
||||
|
|
@ -190,6 +295,10 @@ catalog.add_tool(dict_output_tool, "simple_toolkit")
|
|||
"exec_error_tool",
|
||||
"unexpected_error_tool",
|
||||
"invalid_input_type",
|
||||
"context_required_error_tool",
|
||||
"upstream_error_tool",
|
||||
"upstream_ratelimit_error_tool",
|
||||
"tool_runtime_error_tool",
|
||||
"bad_output_type",
|
||||
"typeddict_output",
|
||||
"list_typeddict_output",
|
||||
|
|
@ -213,20 +322,28 @@ async def test_tool_executor(tool_func, inputs, expected_output):
|
|||
check_output(output, expected_output)
|
||||
|
||||
|
||||
def check_output(output: ToolCallOutput, expected_output: ToolCallOutput):
|
||||
# execution error in tool
|
||||
if output.error:
|
||||
assert output.error.message == expected_output.error.message
|
||||
if expected_output.error.developer_message:
|
||||
assert output.error.developer_message == expected_output.error.developer_message
|
||||
if expected_output.error.traceback_info:
|
||||
assert output.error.traceback_info == expected_output.error.traceback_info
|
||||
assert output.error.can_retry == expected_output.error.can_retry
|
||||
def check_output_error(output_error: ToolCallError, expected_error: ToolCallError):
|
||||
assert output_error.message == expected_error.message, "message mismatch"
|
||||
assert output_error.kind == expected_error.kind, "kind mismatch"
|
||||
if expected_error.developer_message:
|
||||
assert (
|
||||
output.error.additional_prompt_content
|
||||
== expected_output.error.additional_prompt_content
|
||||
)
|
||||
assert output.error.retry_after_ms == expected_output.error.retry_after_ms
|
||||
output_error.developer_message == expected_error.developer_message
|
||||
), "developer message mismatch"
|
||||
assert output_error.can_retry == expected_error.can_retry, "can retry mismatch"
|
||||
assert (
|
||||
output_error.additional_prompt_content == expected_error.additional_prompt_content
|
||||
), "additional prompt content mismatch"
|
||||
assert output_error.retry_after_ms == expected_error.retry_after_ms, "retry after ms mismatch"
|
||||
if expected_error.stacktrace:
|
||||
assert output_error.stacktrace == expected_error.stacktrace, "stacktrace mismatch"
|
||||
assert output_error.status_code == expected_error.status_code, "status code mismatch"
|
||||
assert output_error.extra == expected_error.extra, "extra mismatch"
|
||||
|
||||
|
||||
def check_output(output: ToolCallOutput, expected_output: ToolCallOutput):
|
||||
# error in ToolCallOutput
|
||||
if output.error:
|
||||
check_output_error(output.error, expected_output.error)
|
||||
|
||||
# normal tool execution
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ Tests for ToolCallOutput schema validation with complex types.
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from arcade_core.errors import ErrorKind
|
||||
from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -139,6 +140,7 @@ class TestToolCallOutputValidation:
|
|||
message="Partial failure",
|
||||
developer_message="Some items failed to process",
|
||||
can_retry=True,
|
||||
kind=ErrorKind.TOOL_RUNTIME_RETRY,
|
||||
)
|
||||
)
|
||||
assert output.error.message == "Partial failure"
|
||||
|
|
|
|||
510
libs/tests/sdk/test_google_adapter.py
Normal file
510
libs/tests/sdk/test_google_adapter.py
Normal file
|
|
@ -0,0 +1,510 @@
|
|||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from arcade_core.errors import UpstreamError, UpstreamRateLimitError
|
||||
from arcade_tdk.providers.google.error_adapter import GoogleErrorAdapter
|
||||
|
||||
|
||||
class TestGoogleErrorAdapter:
|
||||
"""Test the Google error adapter functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
self.adapter = GoogleErrorAdapter()
|
||||
|
||||
def _create_mock_errors_module(self):
|
||||
"""Create a mock errors module with all necessary error classes."""
|
||||
|
||||
class MockHttpError(Exception):
|
||||
pass
|
||||
|
||||
class MockBatchError(Exception):
|
||||
pass
|
||||
|
||||
class MockInvalidJsonError(Exception):
|
||||
pass
|
||||
|
||||
class MockUnknownApiNameOrVersion(Exception):
|
||||
pass
|
||||
|
||||
class MockUnacceptableMimeTypeError(Exception):
|
||||
pass
|
||||
|
||||
class MockMediaUploadSizeError(Exception):
|
||||
pass
|
||||
|
||||
class MockInvalidChunkSizeError(Exception):
|
||||
pass
|
||||
|
||||
class MockInvalidNotificationError(Exception):
|
||||
pass
|
||||
|
||||
mock_errors = Mock()
|
||||
mock_errors.HttpError = MockHttpError
|
||||
mock_errors.BatchError = MockBatchError
|
||||
mock_errors.InvalidJsonError = MockInvalidJsonError
|
||||
mock_errors.UnknownApiNameOrVersion = MockUnknownApiNameOrVersion
|
||||
mock_errors.UnacceptableMimeTypeError = MockUnacceptableMimeTypeError
|
||||
mock_errors.MediaUploadSizeError = MockMediaUploadSizeError
|
||||
mock_errors.InvalidChunkSizeError = MockInvalidChunkSizeError
|
||||
mock_errors.InvalidNotificationError = MockInvalidNotificationError
|
||||
|
||||
return mock_errors
|
||||
|
||||
def test_adapter_slug(self):
|
||||
"""Test that the adapter has the correct slug."""
|
||||
assert GoogleErrorAdapter.slug == "_google_api_client"
|
||||
|
||||
def test_sanitize_uri_removes_query_params(self):
|
||||
"""Test URI sanitization removes query parameters."""
|
||||
uri = "https://www.googleapis.com/drive/v3/files/123?key=secret&fields=id,name"
|
||||
result = self.adapter._sanitize_uri(uri)
|
||||
assert result == "https://www.googleapis.com/drive/v3/files/123"
|
||||
|
||||
def test_sanitize_uri_removes_fragments(self):
|
||||
"""Test URI sanitization removes fragments."""
|
||||
uri = "https://www.googleapis.com/gmail/v1/users/me/messages#inbox"
|
||||
result = self.adapter._sanitize_uri(uri)
|
||||
assert result == "https://www.googleapis.com/gmail/v1/users/me/messages"
|
||||
|
||||
def test_sanitize_uri_handles_trailing_slashes(self):
|
||||
"""Test URI sanitization handles trailing slashes."""
|
||||
uri = "https://www.googleapis.com///sheets/v4/spreadsheets///"
|
||||
result = self.adapter._sanitize_uri(uri)
|
||||
assert result == "https://www.googleapis.com/sheets/v4/spreadsheets"
|
||||
|
||||
def test_parse_retry_after_with_seconds(self):
|
||||
"""Test parsing retry-after header with seconds value."""
|
||||
mock_error = Mock()
|
||||
mock_error.resp = Mock()
|
||||
mock_error.resp.headers = {"Retry-After": "120"}
|
||||
|
||||
result = self.adapter._parse_retry_after(mock_error)
|
||||
assert result == 120_000
|
||||
|
||||
def test_parse_retry_after_with_lowercase_header(self):
|
||||
"""Test parsing retry-after header with lowercase key."""
|
||||
mock_error = Mock()
|
||||
mock_error.resp = Mock()
|
||||
mock_error.resp.headers = {"retry-after": "60"}
|
||||
|
||||
result = self.adapter._parse_retry_after(mock_error)
|
||||
assert result == 60_000
|
||||
|
||||
def test_parse_retry_after_with_date_format(self):
|
||||
"""Test parsing retry-after header with absolute date format."""
|
||||
future_date = "Mon, 01 Jan 2029 12:00:00 GMT"
|
||||
mock_error = Mock()
|
||||
mock_error.resp = Mock()
|
||||
mock_error.resp.headers = {"Retry-After": future_date}
|
||||
|
||||
with patch("arcade_tdk.providers.google.error_adapter.datetime") as mock_datetime:
|
||||
parsed_date = datetime(2029, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
mock_datetime.strptime.return_value = parsed_date
|
||||
|
||||
# Mock datetime.now() to return a time before the parsed date
|
||||
current_time = datetime(2029, 1, 1, 11, 58, 0, tzinfo=timezone.utc)
|
||||
mock_datetime.now.return_value = current_time
|
||||
mock_datetime.timezone = timezone
|
||||
|
||||
result = self.adapter._parse_retry_after(mock_error)
|
||||
assert result == 120_000 # 2 minute diff
|
||||
|
||||
def test_parse_retry_after_no_headers(self):
|
||||
"""Test parsing retry-after when no headers are present."""
|
||||
mock_error = Mock()
|
||||
mock_error.resp = Mock()
|
||||
mock_error.resp.headers = {}
|
||||
|
||||
result = self.adapter._parse_retry_after(mock_error)
|
||||
assert result == 1_000
|
||||
|
||||
def test_parse_retry_after_no_resp_attribute(self):
|
||||
"""Test parsing retry-after when error has no resp attribute."""
|
||||
mock_error = Mock()
|
||||
del mock_error.resp
|
||||
|
||||
result = self.adapter._parse_retry_after(mock_error)
|
||||
assert result == 1_000 # defaults to 1 second
|
||||
|
||||
def test_parse_retry_after_invalid_date(self):
|
||||
"""Test parsing retry-after with invalid date format falls back to default."""
|
||||
mock_error = Mock()
|
||||
mock_error.resp = Mock()
|
||||
mock_error.resp.headers = {"Retry-After": "invalid-date"}
|
||||
|
||||
result = self.adapter._parse_retry_after(mock_error)
|
||||
assert result == 1_000
|
||||
|
||||
def test_map_http_error_basic(self):
|
||||
"""Test mapping basic HTTP error."""
|
||||
mock_error = Mock()
|
||||
mock_error.status_code = 404
|
||||
mock_error.reason = "Not Found"
|
||||
mock_error.error_details = None
|
||||
mock_error.uri = "https://www.googleapis.com/drive/v3/files/missing"
|
||||
mock_error.method_ = "get"
|
||||
|
||||
result = self.adapter._map_http_error(mock_error)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert not isinstance(result, UpstreamRateLimitError)
|
||||
assert result.status_code == 404
|
||||
assert result.message == "Upstream Google API error: Not Found"
|
||||
assert result.extra["service"] == "_google_api_client"
|
||||
assert result.extra["endpoint"] == "https://www.googleapis.com/drive/v3/files/missing"
|
||||
assert result.extra["http_method"] == "GET"
|
||||
|
||||
def test_map_http_error_with_string_details(self):
|
||||
"""Test mapping HTTP error with string error details."""
|
||||
mock_error = Mock()
|
||||
mock_error.status_code = 400
|
||||
mock_error.reason = "Bad Request"
|
||||
mock_error.error_details = "Invalid field value"
|
||||
mock_error.uri = "https://www.googleapis.com/sheets/v4/spreadsheets"
|
||||
mock_error.method_ = "post"
|
||||
|
||||
result = self.adapter._map_http_error(mock_error)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 400
|
||||
assert "Invalid field value" in result.message
|
||||
assert result.extra["service"] == "_google_api_client"
|
||||
assert result.extra["http_method"] == "POST"
|
||||
|
||||
def test_map_http_error_with_structured_details(self):
|
||||
"""Test mapping HTTP error with structured error details."""
|
||||
mock_error = Mock()
|
||||
mock_error.status_code = 403
|
||||
mock_error.reason = "Forbidden"
|
||||
mock_error.error_details = {"error": {"code": 403, "message": "Insufficient permissions"}}
|
||||
mock_error.uri = "https://www.googleapis.com/drive/v3/files"
|
||||
mock_error.method_ = "delete"
|
||||
|
||||
result = self.adapter._map_http_error(mock_error)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 403
|
||||
assert result.message == "Upstream Google API error: Forbidden"
|
||||
assert "Upstream Google API error details" in result.developer_message
|
||||
assert result.extra["http_method"] == "DELETE"
|
||||
|
||||
def test_map_http_error_rate_limit(self):
|
||||
"""Test mapping 429 rate limit error."""
|
||||
mock_error = Mock()
|
||||
mock_error.status_code = 429
|
||||
mock_error.reason = "Too Many Requests"
|
||||
mock_error.error_details = None
|
||||
mock_error.uri = "https://www.googleapis.com/gmail/v1/users/me/messages"
|
||||
mock_error.method_ = "get"
|
||||
mock_error.resp = Mock()
|
||||
mock_error.resp.headers = {"Retry-After": "30"}
|
||||
|
||||
result = self.adapter._map_http_error(mock_error)
|
||||
|
||||
assert isinstance(result, UpstreamRateLimitError)
|
||||
assert result.retry_after_ms == 30_000
|
||||
assert result.message == "Upstream Google API error: Too Many Requests"
|
||||
assert result.extra["service"] == "_google_api_client"
|
||||
assert result.extra["endpoint"] == "https://www.googleapis.com/gmail/v1/users/me/messages"
|
||||
assert result.extra["http_method"] == "GET"
|
||||
|
||||
def test_map_http_error_no_reason(self):
|
||||
"""Test mapping HTTP error with no reason."""
|
||||
mock_error = Mock()
|
||||
mock_error.status_code = 500
|
||||
mock_error.reason = None
|
||||
mock_error.error_details = None
|
||||
mock_error.uri = "https://www.googleapis.com/calendar/v3/calendars"
|
||||
mock_error.method_ = "post"
|
||||
|
||||
result = self.adapter._map_http_error(mock_error)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 500
|
||||
assert result.message == "Upstream Google API error: HTTP 500 error"
|
||||
|
||||
def test_map_http_error_missing_attributes(self):
|
||||
"""Test mapping HTTP error without uri and method attributes."""
|
||||
mock_error = Mock()
|
||||
mock_error.status_code = 503
|
||||
mock_error.reason = "Service Unavailable"
|
||||
mock_error.error_details = None
|
||||
|
||||
# Remove uri and method_ attributes
|
||||
if hasattr(mock_error, "uri"):
|
||||
del mock_error.uri
|
||||
if hasattr(mock_error, "method_"):
|
||||
del mock_error.method_
|
||||
|
||||
result = self.adapter._map_http_error(mock_error)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 503
|
||||
assert result.extra["service"] == "_google_api_client"
|
||||
assert "endpoint" not in result.extra
|
||||
assert "http_method" not in result.extra
|
||||
|
||||
def test_handle_http_errors_with_http_error(self):
|
||||
"""Test handling HttpError exceptions."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
|
||||
# Create mock error instance
|
||||
mock_error = mock_errors.HttpError()
|
||||
mock_error.status_code = 401
|
||||
mock_error.reason = "Unauthorized"
|
||||
mock_error.error_details = None
|
||||
mock_error.uri = "https://www.googleapis.com/drive/v3/files"
|
||||
mock_error.method_ = "get"
|
||||
|
||||
result = self.adapter._handle_http_errors(mock_error, mock_errors)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 401
|
||||
assert result.message == "Upstream Google API error: Unauthorized"
|
||||
|
||||
def test_handle_http_errors_with_batch_error_with_status(self):
|
||||
"""Test handling BatchError with response status."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
|
||||
# Create mock error instance
|
||||
mock_error = mock_errors.BatchError()
|
||||
mock_error.reason = "Batch operation failed"
|
||||
mock_error.error_details = None
|
||||
mock_error.resp = Mock()
|
||||
mock_error.resp.status = 400
|
||||
|
||||
result = self.adapter._handle_http_errors(mock_error, mock_errors)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 400
|
||||
assert result.message == "Upstream Google API error: Batch operation failed"
|
||||
|
||||
def test_handle_http_errors_with_batch_error_no_status(self):
|
||||
"""Test handling BatchError without response status."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
|
||||
# Create mock error instance
|
||||
mock_error = mock_errors.BatchError()
|
||||
mock_error.reason = "Batch operation failed"
|
||||
|
||||
result = self.adapter._handle_http_errors(mock_error, mock_errors)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 500
|
||||
assert (
|
||||
result.message == "Upstream Google API batch operation failed: Batch operation failed"
|
||||
)
|
||||
assert result.extra["service"] == "google_api"
|
||||
assert result.extra["error_type"] == "BatchError"
|
||||
|
||||
def test_handle_http_errors_unhandled_exception(self):
|
||||
"""Test handling non-HTTP exceptions returns None."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
|
||||
# Create a non-HTTP exception
|
||||
mock_error = ValueError("Not an HTTP error")
|
||||
|
||||
result = self.adapter._handle_http_errors(mock_error, mock_errors)
|
||||
assert result is None
|
||||
|
||||
def test_handle_other_errors_invalid_json_error(self):
|
||||
"""Test handling InvalidJsonError."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
mock_error = mock_errors.InvalidJsonError("Invalid JSON response")
|
||||
|
||||
result = self.adapter._handle_other_errors(mock_error, mock_errors)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 502
|
||||
assert result.message == "Upstream Google API returned invalid JSON response"
|
||||
assert result.developer_message == "Invalid JSON response"
|
||||
assert result.extra["service"] == "_google_api_client"
|
||||
assert result.extra["error_type"] == "InvalidJsonError"
|
||||
|
||||
def test_handle_other_errors_unknown_api_name_or_version(self):
|
||||
"""Test handling UnknownApiNameOrVersion."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
mock_error = mock_errors.UnknownApiNameOrVersion("Unknown API: nonexistent/v1")
|
||||
|
||||
result = self.adapter._handle_other_errors(mock_error, mock_errors)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 404
|
||||
assert result.message == "Upstream Google API error: Unknown API name or version"
|
||||
assert result.developer_message == "Unknown API: nonexistent/v1"
|
||||
assert result.extra["error_type"] == "UnknownApiNameOrVersion"
|
||||
|
||||
def test_handle_other_errors_unacceptable_mime_type_error(self):
|
||||
"""Test handling UnacceptableMimeTypeError."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
mock_error = mock_errors.UnacceptableMimeTypeError("MIME type not supported")
|
||||
|
||||
result = self.adapter._handle_other_errors(mock_error, mock_errors)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 400
|
||||
assert (
|
||||
result.message == "Upstream Google API error: Unacceptable MIME type for this operation"
|
||||
)
|
||||
assert result.developer_message == "MIME type not supported"
|
||||
assert result.extra["error_type"] == "UnacceptableMimeTypeError"
|
||||
|
||||
def test_handle_other_errors_media_upload_size_error(self):
|
||||
"""Test handling MediaUploadSizeError."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
mock_error = mock_errors.MediaUploadSizeError("File size exceeds 5GB limit")
|
||||
|
||||
result = self.adapter._handle_other_errors(mock_error, mock_errors)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 400
|
||||
assert result.message == "Upstream Google API error: Media file size exceeds allowed limit"
|
||||
assert result.developer_message == "File size exceeds 5GB limit"
|
||||
assert result.extra["error_type"] == "MediaUploadSizeError"
|
||||
|
||||
def test_handle_other_errors_invalid_chunk_size_error(self):
|
||||
"""Test handling InvalidChunkSizeError."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
mock_error = mock_errors.InvalidChunkSizeError("Chunk size must be multiple of 256KB")
|
||||
|
||||
result = self.adapter._handle_other_errors(mock_error, mock_errors)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 400
|
||||
assert result.message == "Upstream Google API error: Invalid chunk size specified"
|
||||
assert result.developer_message == "Chunk size must be multiple of 256KB"
|
||||
assert result.extra["error_type"] == "InvalidChunkSizeError"
|
||||
|
||||
def test_handle_other_errors_invalid_notification_error(self):
|
||||
"""Test handling InvalidNotificationError."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
mock_error = mock_errors.InvalidNotificationError("Invalid webhook URL")
|
||||
|
||||
result = self.adapter._handle_other_errors(mock_error, mock_errors)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 400
|
||||
assert result.message == "Upstream Google API error: Invalid notification configuration"
|
||||
assert result.developer_message == "Invalid webhook URL"
|
||||
assert result.extra["error_type"] == "InvalidNotificationError"
|
||||
|
||||
def test_handle_other_errors_unhandled_exception(self):
|
||||
"""Test handling non-Google API exceptions returns None."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
|
||||
# Create a non-Google API exception
|
||||
mock_error = ValueError("Not a Google API error")
|
||||
|
||||
result = self.adapter._handle_other_errors(mock_error, mock_errors)
|
||||
assert result is None
|
||||
|
||||
def test_from_exception_googleapiclient_not_installed(self, caplog):
|
||||
"""Test handling when googleapiclient is not installed."""
|
||||
with (
|
||||
patch("arcade_tdk.providers.google.error_adapter.logger") as mock_logger,
|
||||
patch.dict("sys.modules", {"googleapiclient": None}),
|
||||
patch(
|
||||
"builtins.__import__",
|
||||
side_effect=ImportError("No module named 'googleapiclient'"),
|
||||
),
|
||||
):
|
||||
mock_exc = Exception("test exception")
|
||||
result = self.adapter.from_exception(mock_exc)
|
||||
|
||||
assert result is None
|
||||
mock_logger.info.assert_called_once()
|
||||
warning_message = mock_logger.info.call_args[0][0]
|
||||
assert "'googleapiclient' is not installed" in warning_message
|
||||
assert "_google_api_client" in warning_message
|
||||
|
||||
def test_from_exception_http_error_handling(self):
|
||||
"""Test full from_exception flow with HTTP error."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
|
||||
# Create mock error instance
|
||||
mock_error = mock_errors.HttpError()
|
||||
mock_error.status_code = 403
|
||||
mock_error.reason = "Forbidden"
|
||||
mock_error.error_details = None
|
||||
mock_error.uri = "https://www.googleapis.com/drive/v3/files"
|
||||
mock_error.method_ = "get"
|
||||
|
||||
# Create mock googleapiclient module
|
||||
mock_googleapiclient = Mock()
|
||||
mock_googleapiclient.errors = mock_errors
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{"googleapiclient": mock_googleapiclient, "googleapiclient.errors": mock_errors},
|
||||
):
|
||||
result = self.adapter.from_exception(mock_error)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 403
|
||||
assert result.message == "Upstream Google API error: Forbidden"
|
||||
|
||||
def test_from_exception_other_error_handling(self):
|
||||
"""Test full from_exception flow with other error types."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
mock_error = mock_errors.InvalidJsonError("Invalid JSON")
|
||||
|
||||
# Create mock googleapiclient module
|
||||
mock_googleapiclient = Mock()
|
||||
mock_googleapiclient.errors = mock_errors
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{"googleapiclient": mock_googleapiclient, "googleapiclient.errors": mock_errors},
|
||||
):
|
||||
result = self.adapter.from_exception(mock_error)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 502
|
||||
assert result.message == "Upstream Google API returned invalid JSON response"
|
||||
|
||||
def test_from_exception_fallback_for_unhandled_google_error(self):
|
||||
"""Test fallback handling for unhandled Google API errors."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
|
||||
# Create an unhandled Google API error
|
||||
class MockUnhandledError(Exception):
|
||||
pass
|
||||
|
||||
mock_error = MockUnhandledError("Some unhandled Google error")
|
||||
mock_error.__module__ = "googleapiclient.errors"
|
||||
|
||||
# Create mock googleapiclient module
|
||||
mock_googleapiclient = Mock()
|
||||
mock_googleapiclient.errors = mock_errors
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{"googleapiclient": mock_googleapiclient, "googleapiclient.errors": mock_errors},
|
||||
):
|
||||
result = self.adapter.from_exception(mock_error)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 500
|
||||
assert result.message == "Upstream Google API error: Some unhandled Google error"
|
||||
assert result.extra["service"] == "_google_api_client"
|
||||
assert result.extra["error_type"] == "MockUnhandledError"
|
||||
|
||||
def test_from_exception_non_google_error(self):
|
||||
"""Test handling non-Google API errors returns None."""
|
||||
mock_errors = self._create_mock_errors_module()
|
||||
|
||||
# Create a non-Google API error
|
||||
mock_error = ValueError("Not a Google error")
|
||||
mock_error.__module__ = "builtins"
|
||||
|
||||
# Create mock googleapiclient module
|
||||
mock_googleapiclient = Mock()
|
||||
mock_googleapiclient.errors = mock_errors
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{"googleapiclient": mock_googleapiclient, "googleapiclient.errors": mock_errors},
|
||||
):
|
||||
result = self.adapter.from_exception(mock_error)
|
||||
|
||||
assert result is None
|
||||
333
libs/tests/sdk/test_httpx_adapter.py
Normal file
333
libs/tests/sdk/test_httpx_adapter.py
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from arcade_core.errors import UpstreamError, UpstreamRateLimitError
|
||||
from arcade_tdk.providers.http.error_adapter import BaseHTTPErrorMapper, HTTPErrorAdapter
|
||||
|
||||
|
||||
class TestBaseHTTPErrorMapper:
|
||||
"""Test the base HTTP error mapper functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
self.mapper = BaseHTTPErrorMapper()
|
||||
|
||||
def test_parse_retry_ms_with_retry_after_seconds(self):
|
||||
"""Test parsing retry-after header with seconds value."""
|
||||
headers = {"retry-after": "60"}
|
||||
result = self.mapper._parse_retry_ms(headers)
|
||||
assert result == 60_000
|
||||
|
||||
def test_parse_retry_ms_with_x_ratelimit_reset(self):
|
||||
"""Test parsing x-ratelimit-reset header with seconds value."""
|
||||
headers = {"x-ratelimit-reset": "120"}
|
||||
result = self.mapper._parse_retry_ms(headers)
|
||||
assert result == 120_000
|
||||
|
||||
def test_parse_retry_ms_with_x_ratelimit_reset_ms(self):
|
||||
"""Test parsing x-ratelimit-reset-ms header with milliseconds value."""
|
||||
headers = {"x-ratelimit-reset-ms": "5000"}
|
||||
result = self.mapper._parse_retry_ms(headers)
|
||||
assert result == 5_000
|
||||
|
||||
def test_parse_retry_ms_with_date_format(self):
|
||||
"""Test parsing retry header with absolute date format."""
|
||||
future_date = "Mon, 01 Jan 2029 12:00:00 GMT"
|
||||
headers = {"retry-after": future_date}
|
||||
|
||||
with patch("arcade_tdk.providers.http.error_adapter.datetime") as mock_datetime:
|
||||
parsed_date = datetime(2029, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
mock_datetime.strptime.return_value = parsed_date
|
||||
|
||||
# Mock datetime.now() to return a time before the parsed date
|
||||
current_time = datetime(2029, 1, 1, 11, 59, 0, tzinfo=timezone.utc)
|
||||
mock_datetime.now.return_value = current_time
|
||||
|
||||
mock_datetime.timezone = timezone
|
||||
|
||||
result = self.mapper._parse_retry_ms(headers)
|
||||
assert result == 60_000 # 1 minute diff
|
||||
|
||||
def test_parse_retry_ms_no_headers(self):
|
||||
"""Test parsing retry when no rate limit headers are present."""
|
||||
headers = {"content-type": "application/json"}
|
||||
result = self.mapper._parse_retry_ms(headers)
|
||||
assert result == 1_000
|
||||
|
||||
def test_parse_retry_ms_invalid_date(self):
|
||||
"""Test parsing retry with invalid date format falls back to default."""
|
||||
headers = {"retry-after": "invalid-date"}
|
||||
result = self.mapper._parse_retry_ms(headers)
|
||||
assert result == 1_000
|
||||
|
||||
def test_sanitize_uri_removes_query_params(self):
|
||||
"""Test URI sanitization removes query parameters."""
|
||||
uri = "https://api.example.com/users/123?token=secret&filter=active"
|
||||
result = self.mapper._sanitize_uri(uri)
|
||||
assert result == "https://api.example.com/users/123"
|
||||
|
||||
def test_sanitize_uri_removes_fragments(self):
|
||||
"""Test URI sanitization removes fragments."""
|
||||
uri = "https://api.example.com/users#section"
|
||||
result = self.mapper._sanitize_uri(uri)
|
||||
assert result == "https://api.example.com/users"
|
||||
|
||||
def test_sanitize_uri_handles_trailing_slashes(self):
|
||||
"""Test URI sanitization handles trailing slashes."""
|
||||
uri = "https://api.example.com///path///"
|
||||
result = self.mapper._sanitize_uri(uri)
|
||||
assert result == "https://api.example.com/path"
|
||||
|
||||
def test_build_extra_metadata_basic(self):
|
||||
"""Test building extra metadata without request info."""
|
||||
result = self.mapper._build_extra_metadata()
|
||||
assert result == {"service": "_http"}
|
||||
|
||||
def test_build_extra_metadata_with_url_and_method(self):
|
||||
"""Test building extra metadata with URL and method."""
|
||||
result = self.mapper._build_extra_metadata(
|
||||
request_url="https://api.example.com/test?secret=123", request_method="post"
|
||||
)
|
||||
expected = {
|
||||
"service": "_http",
|
||||
"endpoint": "https://api.example.com/test",
|
||||
"http_method": "POST",
|
||||
}
|
||||
assert result == expected
|
||||
|
||||
def test_map_status_to_error_rate_limit(self):
|
||||
"""Test mapping 429 status to rate limit error."""
|
||||
headers = {"retry-after": "30"}
|
||||
result = self.mapper._map_status_to_error(
|
||||
status=429,
|
||||
headers=headers,
|
||||
msg="Rate limit exceeded",
|
||||
request_url="https://api.example.com/test",
|
||||
request_method="GET",
|
||||
)
|
||||
|
||||
assert isinstance(result, UpstreamRateLimitError)
|
||||
assert result.retry_after_ms == 30_000
|
||||
assert result.message == "Rate limit exceeded"
|
||||
assert result.extra["service"] == "_http"
|
||||
assert result.extra["endpoint"] == "https://api.example.com/test"
|
||||
assert result.extra["http_method"] == "GET"
|
||||
|
||||
def test_map_status_to_error_generic(self):
|
||||
"""Test mapping generic HTTP status to upstream error."""
|
||||
headers = {}
|
||||
result = self.mapper._map_status_to_error(
|
||||
status=500,
|
||||
headers=headers,
|
||||
msg="Internal server error",
|
||||
request_url="https://api.example.com/test",
|
||||
request_method="POST",
|
||||
)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert not isinstance(result, UpstreamRateLimitError)
|
||||
assert result.status_code == 500
|
||||
assert result.message == "Internal server error"
|
||||
assert result.extra["service"] == "_http"
|
||||
assert result.extra["endpoint"] == "https://api.example.com/test"
|
||||
assert result.extra["http_method"] == "POST"
|
||||
|
||||
|
||||
class TestHTTPErrorAdapter:
|
||||
"""Test the main HTTP error adapter."""
|
||||
|
||||
def setup_method(self):
|
||||
self.adapter = HTTPErrorAdapter()
|
||||
|
||||
def test_httpx_not_installed(self):
|
||||
"""Test handling when httpx is not installed."""
|
||||
with patch.object(self.adapter._httpx_handler, "handle_exception") as mock_handle:
|
||||
# Simulate what happens when httpx is not installed (returns None)
|
||||
mock_handle.return_value = None
|
||||
|
||||
mock_exc = Exception("test exception")
|
||||
|
||||
result = self.adapter.from_exception(mock_exc)
|
||||
assert result is None
|
||||
|
||||
def test_requests_not_installed(self):
|
||||
"""Test handling when requests is not installed."""
|
||||
with patch.object(self.adapter._requests_handler, "handle_exception") as mock_handle:
|
||||
# Simulate what happens when requests is not installed (returns None)
|
||||
mock_handle.return_value = None
|
||||
|
||||
mock_exc = Exception("test exception")
|
||||
|
||||
result = self.adapter.from_exception(mock_exc)
|
||||
assert result is None
|
||||
|
||||
def test_httpx_http_status_error_handling(self):
|
||||
"""Test handling httpx HTTPStatusError."""
|
||||
|
||||
# Create a mock HTTPStatusError class and make our exception inherit from it
|
||||
class MockHTTPStatusError(Exception):
|
||||
pass
|
||||
|
||||
# Create the exception instance that inherits from our mock class
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 404
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.url = "https://api.example.com/users/123"
|
||||
mock_request.method = "GET"
|
||||
|
||||
mock_exc = MockHTTPStatusError("404 Client Error: Not Found")
|
||||
mock_exc.response = mock_response
|
||||
mock_exc.request = mock_request
|
||||
|
||||
with patch("httpx.HTTPStatusError", MockHTTPStatusError):
|
||||
result = self.adapter.from_exception(mock_exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 404
|
||||
assert result.message == "404 Client Error: Not Found"
|
||||
assert result.extra["service"] == "_http"
|
||||
assert result.extra["endpoint"] == "https://api.example.com/users/123"
|
||||
assert result.extra["http_method"] == "GET"
|
||||
|
||||
def test_httpx_rate_limit_handling(self):
|
||||
"""Test handling httpx 429 rate limit."""
|
||||
|
||||
# Create a mock HTTPStatusError class
|
||||
class MockHTTPStatusError(Exception):
|
||||
pass
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 429
|
||||
mock_response.headers = {"retry-after": "60", "content-type": "application/json"}
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.url = "https://api.example.com/upload"
|
||||
mock_request.method = "POST"
|
||||
|
||||
mock_exc = MockHTTPStatusError("429 Too Many Requests")
|
||||
mock_exc.response = mock_response
|
||||
mock_exc.request = mock_request
|
||||
|
||||
with patch("httpx.HTTPStatusError", MockHTTPStatusError):
|
||||
result = self.adapter.from_exception(mock_exc)
|
||||
|
||||
assert isinstance(result, UpstreamRateLimitError)
|
||||
assert result.retry_after_ms == 60_000
|
||||
assert result.message == "429 Too Many Requests"
|
||||
assert result.extra["service"] == "_http"
|
||||
assert result.extra["endpoint"] == "https://api.example.com/upload"
|
||||
assert result.extra["http_method"] == "POST"
|
||||
|
||||
def test_requests_http_error_handling(self):
|
||||
"""Test handling requests HTTPError."""
|
||||
|
||||
# Create a mock HTTPError class
|
||||
class MockHTTPError(Exception):
|
||||
pass
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 403
|
||||
mock_response.headers = {"www-authenticate": "Bearer"}
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.url = "https://api.example.com/protected"
|
||||
mock_request.method = "GET"
|
||||
|
||||
mock_response.request = mock_request
|
||||
|
||||
mock_exc = MockHTTPError("403 Forbidden")
|
||||
mock_exc.response = mock_response
|
||||
|
||||
with patch("requests.exceptions.HTTPError", MockHTTPError):
|
||||
result = self.adapter.from_exception(mock_exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 403
|
||||
assert result.message == "403 Forbidden"
|
||||
assert result.extra["service"] == "_http"
|
||||
assert result.extra["endpoint"] == "https://api.example.com/protected"
|
||||
assert result.extra["http_method"] == "GET"
|
||||
|
||||
def test_requests_http_error_with_url_fallback(self):
|
||||
"""Test handling requests HTTPError when request is not available but response.url is."""
|
||||
|
||||
# Create a mock HTTPError class
|
||||
class MockHTTPError(Exception):
|
||||
pass
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.headers = {}
|
||||
mock_response.request = None # No request object
|
||||
mock_response.url = "https://api.example.com/server-error"
|
||||
|
||||
mock_exc = MockHTTPError("500 Internal Server Error")
|
||||
mock_exc.response = mock_response
|
||||
|
||||
with patch("requests.exceptions.HTTPError", MockHTTPError):
|
||||
result = self.adapter.from_exception(mock_exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 500
|
||||
assert result.message == "500 Internal Server Error"
|
||||
assert result.extra["service"] == "_http"
|
||||
assert result.extra["endpoint"] == "https://api.example.com/server-error"
|
||||
assert "http_method" not in result.extra # No method available
|
||||
|
||||
def test_requests_http_error_no_response(self):
|
||||
"""Test handling requests HTTPError with no response."""
|
||||
|
||||
# Create a mock HTTPError class
|
||||
class MockHTTPError(Exception):
|
||||
pass
|
||||
|
||||
mock_exc = MockHTTPError("No response")
|
||||
mock_exc.response = None
|
||||
|
||||
with patch("requests.exceptions.HTTPError", MockHTTPError):
|
||||
result = self.adapter.from_exception(mock_exc)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_unhandled_exception_logs_warning(self, caplog):
|
||||
"""Test that unhandled exceptions log a warning."""
|
||||
with caplog.at_level(logging.INFO):
|
||||
unknown_exc = ValueError("Some unrelated error")
|
||||
result = self.adapter.from_exception(unknown_exc)
|
||||
|
||||
assert result is None
|
||||
assert len(caplog.records) == 1
|
||||
assert "ValueError" in caplog.records[0].message
|
||||
assert "_http" in caplog.records[0].message
|
||||
assert "not handled" in caplog.records[0].message
|
||||
|
||||
def test_httpx_without_request_info(self):
|
||||
"""Test handling httpx exception without request information."""
|
||||
|
||||
# Create a mock HTTPStatusError class
|
||||
class MockHTTPStatusError(Exception):
|
||||
pass
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.headers = {}
|
||||
|
||||
mock_exc = MockHTTPStatusError("400 Bad Request")
|
||||
mock_exc.response = mock_response
|
||||
mock_exc.request = None
|
||||
|
||||
with patch("httpx.HTTPStatusError", MockHTTPStatusError):
|
||||
result = self.adapter.from_exception(mock_exc)
|
||||
|
||||
assert isinstance(result, UpstreamError)
|
||||
assert result.status_code == 400
|
||||
assert result.message == "400 Bad Request"
|
||||
assert result.extra["service"] == "_http"
|
||||
assert "endpoint" not in result.extra
|
||||
assert "http_method" not in result.extra
|
||||
|
||||
def test_adapter_slug(self):
|
||||
"""Test that the adapter has the correct slug."""
|
||||
assert HTTPErrorAdapter.slug == "_http"
|
||||
|
|
@ -2,7 +2,7 @@ from typing import Annotated
|
|||
|
||||
import pytest
|
||||
from arcade_core.catalog import ToolCatalog
|
||||
from arcade_core.errors import ToolDefinitionError
|
||||
from arcade_core.errors import ToolDefinitionError, ToolInputSchemaError
|
||||
from arcade_core.schema import ToolContext, ToolMetadataKey
|
||||
from arcade_tdk import tool
|
||||
|
||||
|
|
@ -149,7 +149,7 @@ def func_with_metadata_and_auth_dependency():
|
|||
),
|
||||
pytest.param(
|
||||
func_with_invalid_renamed_param,
|
||||
ToolDefinitionError,
|
||||
ToolInputSchemaError,
|
||||
id=func_with_invalid_renamed_param.__name__,
|
||||
),
|
||||
pytest.param(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "arcade-ai"
|
||||
version = "2.1.4"
|
||||
version = "2.2.0"
|
||||
description = "Arcade.dev - Tool Calling platform for Agents"
|
||||
readme = "README.md"
|
||||
license = {file = "LICENSE"}
|
||||
|
|
@ -21,7 +21,7 @@ requires-python = ">=3.10"
|
|||
|
||||
dependencies = [
|
||||
# CLI dependencies
|
||||
"arcade-core>=2.2.2,<3.0.0",
|
||||
"arcade-core>=2.4.0,<3.0.0",
|
||||
"typer==0.10.0",
|
||||
"rich==13.9.4",
|
||||
"Jinja2==3.1.6",
|
||||
|
|
@ -40,9 +40,9 @@ all = [
|
|||
"pytz>=2024.1",
|
||||
"python-dateutil>=2.8.2",
|
||||
# serve
|
||||
"arcade-serve>=2.0.0,<3.0.0",
|
||||
"arcade-serve>=2.1.0,<3.0.0",
|
||||
# tdk
|
||||
"arcade-tdk>=2.0.0,<3.0.0",
|
||||
"arcade-tdk>=2.3.0,<3.0.0",
|
||||
]
|
||||
# Evals also depends on arcade-core and openai, but they are already required deps
|
||||
evals = [
|
||||
|
|
|
|||
Loading…
Reference in a new issue