diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index d47cc4d5..fd8dc21f 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -40,6 +40,7 @@ from arcade.core.utils import ( does_function_return_value, first_or_none, is_string_literal, + is_union, snake_to_pascal_case, ) from arcade.sdk.annotations import Inferrable @@ -447,12 +448,24 @@ def extract_python_param_info(param: inspect.Parameter) -> ParamInfo: original_type = annotation.__args__[0] if get_origin(annotation) is Annotated else annotation field_type = original_type - # Unwrap Optional types + # Handle optional types + # Both Optional[T] and T | None are supported is_optional = False - if get_origin(field_type) is Union and type(None) in get_args(field_type): + if ( + is_union(field_type) + and len(get_args(field_type)) == 2 + and type(None) in get_args(field_type) + ): field_type = next(arg for arg in get_args(field_type) if arg is not type(None)) is_optional = True + # Union types are not currently supported + # (other than optional, which is handled above) + if is_union(field_type): + raise ToolDefinitionError( + f"Parameter {param.name} is a union type. Only optional types are supported." + ) + return ParamInfo( name=param.name, default=param.default if param.default is not inspect.Parameter.empty else None, diff --git a/arcade/arcade/core/utils.py b/arcade/arcade/core/utils.py index 5fdb5bc6..61fd1868 100644 --- a/arcade/arcade/core/utils.py +++ b/arcade/arcade/core/utils.py @@ -2,7 +2,8 @@ import ast import inspect import re from collections.abc import Iterable -from typing import Any, Callable, Literal, Optional, TypeVar, get_args, get_origin +from types import UnionType +from typing import Any, Callable, Literal, Optional, TypeVar, Union, get_args, get_origin T = TypeVar("T") @@ -44,6 +45,13 @@ def is_string_literal(_type: type) -> bool: return get_origin(_type) is Literal and all(isinstance(arg, str) for arg in get_args(_type)) +def is_union(_type: type) -> bool: + """ + Returns True if the given type is a union, i.e. a Union[T1, T2, ...] or T1 | T2 | ... etc. + """ + return get_origin(_type) in {Union, UnionType} + + def does_function_return_value(func: Callable) -> bool: """ Returns True if the given function returns a value, i.e. if it has a return statement with a value. diff --git a/arcade/arcade/sdk/eval/eval.py b/arcade/arcade/sdk/eval/eval.py index 54fe9eee..5a3c3a2b 100644 --- a/arcade/arcade/sdk/eval/eval.py +++ b/arcade/arcade/sdk/eval/eval.py @@ -13,7 +13,6 @@ except ImportError: ) from arcade.client.client import Arcade, AsyncArcade -from arcade.core.config import config from arcade.sdk.error import WeightError if TYPE_CHECKING: @@ -423,6 +422,8 @@ class EvalSuite: """ Initialize the client instance for the EvalSuite. """ + from arcade.core.config import config + if self.max_concurrent > 1: self._client = AsyncArcade( api_key=config.api.key, diff --git a/arcade/tests/utils/test_utils_casing.py b/arcade/tests/core/utils/test_casing.py similarity index 100% rename from arcade/tests/utils/test_utils_casing.py rename to arcade/tests/core/utils/test_casing.py diff --git a/arcade/tests/core/utils/test_is_union.py b/arcade/tests/core/utils/test_is_union.py new file mode 100644 index 00000000..68ab581b --- /dev/null +++ b/arcade/tests/core/utils/test_is_union.py @@ -0,0 +1,22 @@ +from typing import Optional, Union + +import pytest + +from arcade.core.utils import is_union + + +@pytest.mark.parametrize( + "type_input, expected", + [ + (Union[int, str], True), + (Optional[int], True), # Optional[int] is equivalent to Union[int, None] + (int | str, True), + (int | None, True), # int | None is equivalent to Optional[int] + (int, False), + (str, False), + (list, False), + (dict, False), + ], +) +def test_is_union(type_input, expected): + assert is_union(type_input) == expected diff --git a/arcade/tests/tool/test_create_tool_definition_errors.py b/arcade/tests/tool/test_create_tool_definition_errors.py index fe1b8ef2..959ea184 100644 --- a/arcade/tests/tool/test_create_tool_definition_errors.py +++ b/arcade/tests/tool/test_create_tool_definition_errors.py @@ -31,6 +31,11 @@ def func_with_unsupported_param(param1: complex): pass +@tool(desc="A function with a union parameter (illegal)") +def func_with_union_param(param1: str | int): + pass + + @tool(desc="A function with multiple context parameters (illegal)") def func_with_multiple_context_params(context: ToolContext, context2: ToolContext): pass @@ -64,6 +69,11 @@ def func_with_multiple_context_params(context: ToolContext, context2: ToolContex ToolDefinitionError, id=func_with_unsupported_param.__name__, ), + pytest.param( + func_with_union_param, + ToolDefinitionError, + id=func_with_union_param.__name__, + ), pytest.param( func_with_multiple_context_params, ToolDefinitionError,