Fix CI: config errors, Python 3.11 union type errors (#45)

Fixes 2 issues that were causing CI to fail:
- Loading `config` in `eval.py` breaks because no API key can be found
in CI
- Python 3.11+ changed `Union` to `UnionType`
This commit is contained in:
Nate Barbettini 2024-09-19 12:07:28 -07:00 committed by GitHub
parent 43198a3a9b
commit 739cc957f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 58 additions and 4 deletions

View file

@ -40,6 +40,7 @@ from arcade.core.utils import (
does_function_return_value,
first_or_none,
is_string_literal,
is_union,
snake_to_pascal_case,
)
from arcade.sdk.annotations import Inferrable
@ -447,12 +448,24 @@ def extract_python_param_info(param: inspect.Parameter) -> ParamInfo:
original_type = annotation.__args__[0] if get_origin(annotation) is Annotated else annotation
field_type = original_type
# Unwrap Optional types
# Handle optional types
# Both Optional[T] and T | None are supported
is_optional = False
if get_origin(field_type) is Union and type(None) in get_args(field_type):
if (
is_union(field_type)
and len(get_args(field_type)) == 2
and type(None) in get_args(field_type)
):
field_type = next(arg for arg in get_args(field_type) if arg is not type(None))
is_optional = True
# Union types are not currently supported
# (other than optional, which is handled above)
if is_union(field_type):
raise ToolDefinitionError(
f"Parameter {param.name} is a union type. Only optional types are supported."
)
return ParamInfo(
name=param.name,
default=param.default if param.default is not inspect.Parameter.empty else None,

View file

@ -2,7 +2,8 @@ import ast
import inspect
import re
from collections.abc import Iterable
from typing import Any, Callable, Literal, Optional, TypeVar, get_args, get_origin
from types import UnionType
from typing import Any, Callable, Literal, Optional, TypeVar, Union, get_args, get_origin
T = TypeVar("T")
@ -44,6 +45,13 @@ def is_string_literal(_type: type) -> bool:
return get_origin(_type) is Literal and all(isinstance(arg, str) for arg in get_args(_type))
def is_union(_type: type) -> bool:
"""
Returns True if the given type is a union, i.e. a Union[T1, T2, ...] or T1 | T2 | ... etc.
"""
return get_origin(_type) in {Union, UnionType}
def does_function_return_value(func: Callable) -> bool:
"""
Returns True if the given function returns a value, i.e. if it has a return statement with a value.

View file

@ -13,7 +13,6 @@ except ImportError:
)
from arcade.client.client import Arcade, AsyncArcade
from arcade.core.config import config
from arcade.sdk.error import WeightError
if TYPE_CHECKING:
@ -423,6 +422,8 @@ class EvalSuite:
"""
Initialize the client instance for the EvalSuite.
"""
from arcade.core.config import config
if self.max_concurrent > 1:
self._client = AsyncArcade(
api_key=config.api.key,

View file

@ -0,0 +1,22 @@
from typing import Optional, Union
import pytest
from arcade.core.utils import is_union
@pytest.mark.parametrize(
"type_input, expected",
[
(Union[int, str], True),
(Optional[int], True), # Optional[int] is equivalent to Union[int, None]
(int | str, True),
(int | None, True), # int | None is equivalent to Optional[int]
(int, False),
(str, False),
(list, False),
(dict, False),
],
)
def test_is_union(type_input, expected):
assert is_union(type_input) == expected

View file

@ -31,6 +31,11 @@ def func_with_unsupported_param(param1: complex):
pass
@tool(desc="A function with a union parameter (illegal)")
def func_with_union_param(param1: str | int):
pass
@tool(desc="A function with multiple context parameters (illegal)")
def func_with_multiple_context_params(context: ToolContext, context2: ToolContext):
pass
@ -64,6 +69,11 @@ def func_with_multiple_context_params(context: ToolContext, context2: ToolContex
ToolDefinitionError,
id=func_with_unsupported_param.__name__,
),
pytest.param(
func_with_union_param,
ToolDefinitionError,
id=func_with_union_param.__name__,
),
pytest.param(
func_with_multiple_context_params,
ToolDefinitionError,