Fix CI: config errors, Python 3.11 union type errors (#45)
Fixes 2 issues that were causing CI to fail: - Loading `config` in `eval.py` breaks because no API key can be found in CI - Python 3.11+ changed `Union` to `UnionType`
This commit is contained in:
parent
43198a3a9b
commit
739cc957f1
6 changed files with 58 additions and 4 deletions
|
|
@ -40,6 +40,7 @@ from arcade.core.utils import (
|
|||
does_function_return_value,
|
||||
first_or_none,
|
||||
is_string_literal,
|
||||
is_union,
|
||||
snake_to_pascal_case,
|
||||
)
|
||||
from arcade.sdk.annotations import Inferrable
|
||||
|
|
@ -447,12 +448,24 @@ def extract_python_param_info(param: inspect.Parameter) -> ParamInfo:
|
|||
original_type = annotation.__args__[0] if get_origin(annotation) is Annotated else annotation
|
||||
field_type = original_type
|
||||
|
||||
# Unwrap Optional types
|
||||
# Handle optional types
|
||||
# Both Optional[T] and T | None are supported
|
||||
is_optional = False
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
if (
|
||||
is_union(field_type)
|
||||
and len(get_args(field_type)) == 2
|
||||
and type(None) in get_args(field_type)
|
||||
):
|
||||
field_type = next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||
is_optional = True
|
||||
|
||||
# Union types are not currently supported
|
||||
# (other than optional, which is handled above)
|
||||
if is_union(field_type):
|
||||
raise ToolDefinitionError(
|
||||
f"Parameter {param.name} is a union type. Only optional types are supported."
|
||||
)
|
||||
|
||||
return ParamInfo(
|
||||
name=param.name,
|
||||
default=param.default if param.default is not inspect.Parameter.empty else None,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import ast
|
|||
import inspect
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Literal, Optional, TypeVar, get_args, get_origin
|
||||
from types import UnionType
|
||||
from typing import Any, Callable, Literal, Optional, TypeVar, Union, get_args, get_origin
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
|
@ -44,6 +45,13 @@ def is_string_literal(_type: type) -> bool:
|
|||
return get_origin(_type) is Literal and all(isinstance(arg, str) for arg in get_args(_type))
|
||||
|
||||
|
||||
def is_union(_type: type) -> bool:
|
||||
"""
|
||||
Returns True if the given type is a union, i.e. a Union[T1, T2, ...] or T1 | T2 | ... etc.
|
||||
"""
|
||||
return get_origin(_type) in {Union, UnionType}
|
||||
|
||||
|
||||
def does_function_return_value(func: Callable) -> bool:
|
||||
"""
|
||||
Returns True if the given function returns a value, i.e. if it has a return statement with a value.
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ except ImportError:
|
|||
)
|
||||
|
||||
from arcade.client.client import Arcade, AsyncArcade
|
||||
from arcade.core.config import config
|
||||
from arcade.sdk.error import WeightError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -423,6 +422,8 @@ class EvalSuite:
|
|||
"""
|
||||
Initialize the client instance for the EvalSuite.
|
||||
"""
|
||||
from arcade.core.config import config
|
||||
|
||||
if self.max_concurrent > 1:
|
||||
self._client = AsyncArcade(
|
||||
api_key=config.api.key,
|
||||
|
|
|
|||
22
arcade/tests/core/utils/test_is_union.py
Normal file
22
arcade/tests/core/utils/test_is_union.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from arcade.core.utils import is_union
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"type_input, expected",
|
||||
[
|
||||
(Union[int, str], True),
|
||||
(Optional[int], True), # Optional[int] is equivalent to Union[int, None]
|
||||
(int | str, True),
|
||||
(int | None, True), # int | None is equivalent to Optional[int]
|
||||
(int, False),
|
||||
(str, False),
|
||||
(list, False),
|
||||
(dict, False),
|
||||
],
|
||||
)
|
||||
def test_is_union(type_input, expected):
|
||||
assert is_union(type_input) == expected
|
||||
|
|
@ -31,6 +31,11 @@ def func_with_unsupported_param(param1: complex):
|
|||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with a union parameter (illegal)")
|
||||
def func_with_union_param(param1: str | int):
|
||||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with multiple context parameters (illegal)")
|
||||
def func_with_multiple_context_params(context: ToolContext, context2: ToolContext):
|
||||
pass
|
||||
|
|
@ -64,6 +69,11 @@ def func_with_multiple_context_params(context: ToolContext, context2: ToolContex
|
|||
ToolDefinitionError,
|
||||
id=func_with_unsupported_param.__name__,
|
||||
),
|
||||
pytest.param(
|
||||
func_with_union_param,
|
||||
ToolDefinitionError,
|
||||
id=func_with_union_param.__name__,
|
||||
),
|
||||
pytest.param(
|
||||
func_with_multiple_context_params,
|
||||
ToolDefinitionError,
|
||||
|
|
|
|||
Loading…
Reference in a new issue