diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index f664e4e7..02d39ca5 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -46,7 +46,19 @@ from arcade.core.utils import ( from arcade.sdk.annotations import Inferrable from arcade.sdk.auth import Google, OAuth2, SlackUser, ToolAuthorization -WireType = Literal["string", "integer", "float", "boolean", "json"] +InnerWireType = Literal["string", "integer", "number", "boolean", "json"] +WireType = Union[InnerWireType, Literal["array"]] + + +@dataclass +class WireTypeInfo: + """ + Represents the wire type information for a value, including its inner type if it's a list. + """ + + wire_type: WireType + inner_wire_type: InnerWireType | None = None + enum_values: list[str] | None = None class ToolMeta(BaseModel): @@ -233,17 +245,6 @@ def create_input_definition(func: Callable) -> ToolInputs: tool_field_info = extract_field_info(param) - is_enum = False - enum_values: list[str] = [] - - # Special case: Literal["string1", "string2"] can be enumerated on the wire - if is_string_literal(tool_field_info.field_type): - is_enum = True - enum_values = [str(e) for e in get_args(tool_field_info.field_type)] - elif issubclass(tool_field_info.field_type, Enum): - is_enum = True - enum_values = [e.value for e in tool_field_info.field_type] - # If the field has a default value, it is not required # If the field is optional, it is not required has_default_value = tool_field_info.default is not None @@ -256,8 +257,9 @@ def create_input_definition(func: Callable) -> ToolInputs: required=is_required, inferrable=tool_field_info.is_inferrable, value_schema=ValueSchema( - val_type=tool_field_info.wire_type, - enum=enum_values if is_enum else None, + val_type=tool_field_info.wire_type_info.wire_type, + inner_val_type=tool_field_info.wire_type_info.inner_wire_type, + enum=tool_field_info.wire_type_info.enum_values, ), ) ) @@ -291,7 +293,7 @@ def create_output_definition(func: Callable) -> ToolOutput: return_type = next(arg for arg in get_args(return_type) if arg is not type(None)) is_optional = True - wire_type = get_wire_type(return_type) + wire_type_info = get_wire_type_info(return_type) available_modes = ["value", "error"] @@ -301,7 +303,11 @@ def create_output_definition(func: Callable) -> ToolOutput: return ToolOutput( description=description, available_modes=available_modes, - value_schema=ValueSchema(val_type=wire_type), + value_schema=ValueSchema( + val_type=wire_type_info.wire_type, + inner_val_type=wire_type_info.inner_wire_type, + enum=wire_type_info.enum_values, + ), ) @@ -329,14 +335,17 @@ class ToolParamInfo: default: Any original_type: type field_type: type - wire_type: WireType + wire_type_info: WireTypeInfo description: str | None = None is_optional: bool = True is_inferrable: bool = True @classmethod def from_param_info( - cls, param_info: ParamInfo, wire_type: WireType, is_inferrable: bool = True + cls, + param_info: ParamInfo, + wire_type_info: WireTypeInfo, + is_inferrable: bool = True, ) -> "ToolParamInfo": return cls( name=param_info.name, @@ -345,7 +354,7 @@ class ToolParamInfo: field_type=param_info.field_type, description=param_info.description, is_optional=param_info.is_optional, - wire_type=wire_type, + wire_type_info=wire_type_info, is_inferrable=is_inferrable, ) @@ -362,7 +371,7 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo: if isinstance(param.default, FieldInfo): param_info = extract_pydantic_param_info(param) else: - param_info = extract_regular_param_info(param) + param_info = extract_python_param_info(param) metadata = getattr(annotation, "__metadata__", []) str_annotations = [m for m in metadata if isinstance(m, str)] @@ -386,24 +395,59 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo: # Params are inferrable by default is_inferrable = inferrable_annotation.value if inferrable_annotation else True - # Get the wire type - wire_type = ( - get_wire_type(str) - if is_string_literal(param_info.field_type) - else get_wire_type(param_info.field_type) - ) + # Get the wire (serialization) type information for the type + wire_type_info = get_wire_type_info(param_info.field_type) # Final reality check if param_info.description is None: raise ToolDefinitionError(f"Parameter {param_info.name} is missing a description") - if wire_type is None: + if wire_type_info.wire_type is None: raise ToolDefinitionError(f"Unknown parameter type: {param_info.field_type}") - return ToolParamInfo.from_param_info(param_info, wire_type, is_inferrable) + return ToolParamInfo.from_param_info(param_info, wire_type_info, is_inferrable) -def extract_regular_param_info(param: inspect.Parameter) -> ParamInfo: +def get_wire_type_info(_type: type) -> WireTypeInfo: + """ + Get the wire type information for a given type. + """ + + # Is this a list type? + # If so, get the inner (enclosed) type + is_list = get_origin(_type) is list + if is_list: + inner_type = get_args(_type)[0] + inner_wire_type = cast( + InnerWireType, + get_wire_type(str) if is_string_literal(inner_type) else get_wire_type(inner_type), + ) + else: + inner_wire_type = None + + # Get the outer wire type + wire_type = get_wire_type(str) if is_string_literal(_type) else get_wire_type(_type) + + # Handle enums (known/fixed lists of values) + is_enum = False + enum_values: list[str] = [] + + type_to_check = inner_type if is_list else _type + + # Special case: Literal["string1", "string2"] can be enumerated on the wire + if is_string_literal(type_to_check): + is_enum = True + enum_values = [str(e) for e in get_args(type_to_check)] + + # Special case: Enum can be enumerated on the wire + elif issubclass(type_to_check, Enum): + is_enum = True + enum_values = [e.value for e in type_to_check] + + return WireTypeInfo(wire_type, inner_wire_type, enum_values if is_enum else None) + + +def extract_python_param_info(param: inspect.Parameter) -> ParamInfo: # If the param is Annotated[], unwrap the annotation to get the "real" type # Otherwise, use the literal type annotation = param.annotation @@ -465,28 +509,34 @@ def get_wire_type( """ Mapping between Python types and HTTP/JSON types """ - type_mapping = { + type_mapping: dict[type, WireType] = { str: "string", bool: "boolean", int: "integer", - float: "float", + float: "number", + dict: "json", + } + + outer_type_mapping: dict[type, WireType] = { + list: "array", dict: "json", - list: "json", - BaseModel: "json", } wire_type = type_mapping.get(_type) if wire_type: - return cast(Literal["string", "integer", "float", "boolean", "json"], wire_type) - elif hasattr(_type, "__origin__"): - # account for "list[str]" and "dict[str, int]" and "Optional[str]" and other typing types - origin = _type.__origin__ - if origin in [list, dict]: - return "json" - elif issubclass(_type, Enum): + return wire_type + + if hasattr(_type, "__origin__"): + wire_type = outer_type_mapping.get(cast(type, get_origin(_type))) + if wire_type: + return wire_type + + if issubclass(_type, Enum): return "string" - elif issubclass(_type, BaseModel): + + if issubclass(_type, BaseModel): return "json" + raise ToolDefinitionError(f"Unsupported parameter type: {_type}") diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index e8d58ff3..12eda81b 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -6,9 +6,12 @@ from pydantic import AnyUrl, BaseModel, Field class ValueSchema(BaseModel): """Value schema for input parameters and outputs.""" - val_type: Literal["string", "integer", "float", "boolean", "json"] + val_type: Literal["string", "integer", "number", "boolean", "json", "array"] """The type of the value.""" + inner_val_type: Optional[Literal["string", "integer", "number", "boolean", "json"]] = None + """The type of the inner value, if the value is a list.""" + enum: Optional[list[str]] = None """The list of possible values for the value, if it is a closed list.""" diff --git a/arcade/arcade/core/utils.py b/arcade/arcade/core/utils.py index 42a84480..5fdb5bc6 100644 --- a/arcade/arcade/core/utils.py +++ b/arcade/arcade/core/utils.py @@ -48,7 +48,15 @@ def does_function_return_value(func: Callable) -> bool: """ Returns True if the given function returns a value, i.e. if it has a return statement with a value. """ - source = inspect.getsource(func) + try: + source: Optional[str] = inspect.getsource(func) + except OSError: + # Workaround for parameterized unit tests that use a dynamically-generated function + source = getattr(func, "__source__", None) + + if source is None: + raise ValueError("Source code not found") + tree = ast.parse(source) class ReturnVisitor(ast.NodeVisitor): diff --git a/arcade/tests/tool/test_create_tool_definition.py b/arcade/tests/tool/test_create_tool_definition.py index a6f4a648..424303fc 100644 --- a/arcade/tests/tool/test_create_tool_definition.py +++ b/arcade/tests/tool/test_create_tool_definition.py @@ -63,11 +63,6 @@ def func_with_google_auth_requirement(): ### Tests on input params -@tool(desc="A function with an input parameter") -def func_with_param(context: Annotated[str, "First param"]): - pass - - @tool(desc="A function with a non-inferrable input parameter") def func_with_non_inferrable_param(param1: Annotated[str, "First param", Inferrable(False)]): pass @@ -79,23 +74,13 @@ def func_with_renamed_param(param1: Annotated[str, "ParamOne", "First param"]): pass -@tool(desc="A function with every possible input parameter") -def func_with_every_param( - param1: Annotated[str, "a string"], - param2: Annotated[int, "an integer"], - param3: Annotated[float, "a float"], - param4: Annotated[bool, "a boolean"], -): - pass - - -class TestEnum(Enum): +class MyEnum(Enum): FOO_BAR = "foo bar" BAZ = "baz" @tool(desc="A function that takes an enum") -def func_with_enum_param(param1: Annotated[TestEnum, "an enum"]): +def func_with_enum_param(param1: Annotated[MyEnum, "an enum"]): pass @@ -142,8 +127,25 @@ def func_with_mixed_params( pass +@tool(desc="A function with a list[str] parameter") +def func_with_list_param(param1: Annotated[list[str], "A list of strings"]): + pass + + +@tool(desc="A function with a list[float] parameter") +def func_with_list_float_param(param1: Annotated[list[float], "A list of floats"]): + pass + + +@tool(desc="A function with a list of enums parameter") +def func_with_list_of_enums_param(param1: Annotated[list[MyEnum], "A list of enums"]): + pass + + @tool(desc="A function with a complex parameter type") -def func_with_complex_param(param1: Annotated[list[str], "A list of strings"]): +def func_with_complex_param( + param1: Annotated[dict[str, list[int]], "A dictionary with lists of integers"], +): pass @@ -153,14 +155,19 @@ def func_with_context(my_context: ToolContext): ### Tests on output/return values -@tool(desc="A function that performs an action without returning anything") -def func_with_no_return(): - pass +@tool(desc="A function that returns a list of strings") +def func_with_list_return() -> list[str]: + return ["output1", "output2"] -@tool(desc="A function that returns a value") -def func_with_value_return() -> str: - return "output" +@tool(desc="A function that returns a known list of string literals") +def func_with_known_list_return() -> Literal["value1", "value2"]: + return "value1" + + +@tool(desc="A function that returns an enum") +def func_with_enum_return() -> MyEnum: + return MyEnum.FOO_BAR @tool(desc="A function with an annotated return type") @@ -174,7 +181,7 @@ def func_with_optional_return() -> Optional[str]: @tool(desc="A function with a complex return type") -def func_with_complex_return() -> list[dict[str, str]]: +def func_with_complex_return() -> dict[str, str]: return [{"key": "value"}] @@ -244,30 +251,6 @@ def func_with_complex_return() -> list[dict[str, str]]: id="func_with_google_auth_requirement", ), # Tests on input params - pytest.param( - func_with_value_return, - { - "inputs": ToolInputs(parameters=[]), - }, - id="func_with_no_params", - ), - pytest.param( - func_with_param, - { - "inputs": ToolInputs( - parameters=[ - InputParameter( - name="context", # Nothing special about this name, parameters can be named anything - description="First param", - inferrable=True, # Defaults to true - required=True, - value_schema=ValueSchema(val_type="string", enum=None), - ) - ] - ), - }, - id="func_with_param", - ), pytest.param( func_with_non_inferrable_param, { @@ -302,44 +285,6 @@ def func_with_complex_return() -> list[dict[str, str]]: }, id="func_with_renamed_param", ), - pytest.param( - func_with_every_param, - { - "inputs": ToolInputs( - parameters=[ - InputParameter( - name="param1", - description="a string", - inferrable=True, - required=True, - value_schema=ValueSchema(val_type="string", enum=None), - ), - InputParameter( - name="param2", - description="an integer", - inferrable=True, - required=True, - value_schema=ValueSchema(val_type="integer", enum=None), - ), - InputParameter( - name="param3", - description="a float", - inferrable=True, - required=True, - value_schema=ValueSchema(val_type="float", enum=None), - ), - InputParameter( - name="param4", - description="a boolean", - inferrable=True, - required=True, - value_schema=ValueSchema(val_type="boolean", enum=None), - ), - ] - ), - }, - id="func_with_every_param", - ), pytest.param( func_with_enum_param, { @@ -497,7 +442,7 @@ def func_with_complex_return() -> list[dict[str, str]]: id="func_with_mixed_params", ), pytest.param( - func_with_complex_param, + func_with_list_param, { "inputs": ToolInputs( parameters=[ @@ -506,6 +451,63 @@ def func_with_complex_return() -> list[dict[str, str]]: description="A list of strings", inferrable=True, required=True, + value_schema=ValueSchema( + val_type="array", inner_val_type="string", enum=None + ), + ) + ] + ), + }, + id="func_with_list_param", + ), + pytest.param( + func_with_list_float_param, + { + "inputs": ToolInputs( + parameters=[ + InputParameter( + name="param1", + description="A list of floats", + inferrable=True, + required=True, + value_schema=ValueSchema( + val_type="array", inner_val_type="number", enum=None + ), + ) + ] + ), + }, + id="func_with_list_float_param", + ), + pytest.param( + func_with_list_of_enums_param, + { + "inputs": ToolInputs( + parameters=[ + InputParameter( + name="param1", + description="A list of enums", + inferrable=True, + required=True, + value_schema=ValueSchema( + val_type="array", inner_val_type="string", enum=["foo bar", "baz"] + ), + ) + ] + ), + }, + id="func_with_list_of_enums_param", + ), + pytest.param( + func_with_complex_param, + { + "inputs": ToolInputs( + parameters=[ + InputParameter( + name="param1", + description="A dictionary with lists of integers", + inferrable=True, + required=True, value_schema=ValueSchema(val_type="json", enum=None), ) ] @@ -524,25 +526,40 @@ def func_with_complex_return() -> list[dict[str, str]]: ), # Tests on output values pytest.param( - func_with_no_return, - { - "output": ToolOutput( - available_modes=["null"], description="No description provided." - ), - }, - id="func_with_no_return", - ), - pytest.param( - func_with_value_return, + func_with_list_return, { "inputs": ToolInputs(parameters=[]), "output": ToolOutput( - value_schema=ValueSchema(val_type="string", enum=None), + value_schema=ValueSchema(val_type="array", inner_val_type="string", enum=None), available_modes=["value", "error"], description="No description provided.", ), }, - id="func_with_value_return", + id="func_with_list_return", + ), + pytest.param( + func_with_known_list_return, + { + "inputs": ToolInputs(parameters=[]), + "output": ToolOutput( + value_schema=ValueSchema(val_type="string", enum=["value1", "value2"]), + available_modes=["value", "error"], + description="No description provided.", + ), + }, + id="func_with_known_list_return", + ), + pytest.param( + func_with_enum_return, + { + "inputs": ToolInputs(parameters=[]), + "output": ToolOutput( + value_schema=ValueSchema(val_type="string", enum=["foo bar", "baz"]), + available_modes=["value", "error"], + description="No description provided.", + ), + }, + id="func_with_enum_return", ), pytest.param( func_with_annotated_return, @@ -592,5 +609,5 @@ def test_create_tool_def(func_under_test, expected_tool_def_fields): def tool_version_is_set_correctly(): - tool_def = ToolCatalog.create_tool_definition(func_with_no_return, "abcd1236") + tool_def = ToolCatalog.create_tool_definition(func_with_description, "abcd1236") assert tool_def.version == "abcd1236" diff --git a/arcade/tests/tool/test_create_tool_definition_new.py b/arcade/tests/tool/test_create_tool_definition_new.py new file mode 100644 index 00000000..1f727269 --- /dev/null +++ b/arcade/tests/tool/test_create_tool_definition_new.py @@ -0,0 +1,79 @@ +from typing import Annotated + +import pytest + +from arcade.core.catalog import ToolCatalog, get_wire_type +from arcade.sdk import tool + + +class Case: + def __init__(self, input_type: type, output_type: type | None): + self.input_type = input_type + self.output_type = output_type + + def __str__(self): + return f"Case(input_type={self.input_type}, output_type={self.output_type})" + + +primitives = [bool, float, int, str] + +test_cases = [ + Case(input_type=input_type, output_type=output_type) + for input_type in [*primitives, []] + for output_type in [*primitives, None] +] + [ + Case(input_type=[primitives[i] for i in range(n)], output_type=output_type) + for n in range(2, len(primitives) + 1) + for output_type in [*primitives, None] +] + + +# Generate tool functions dynamically +def generate_tool_function(input_types: list[type], output_type: type | None): + input_annotation = ", ".join( + [ + f"param{i}: Annotated[{input_type.__name__}, 'Param {i + 1}']" + for i, input_type in enumerate(input_types) + ] + ) + output_annotation = f" -> {output_type.__name__}" if output_type else "" + + func_code = f""" +@tool(desc="Generated function with input and output types") +def generated_func({input_annotation}){output_annotation}: + pass +""" + local_vars = {} + exec(func_code, {"tool": tool, "Annotated": Annotated}, local_vars) # noqa: S102 + generated_func = local_vars.get("generated_func") + generated_func.__source__ = func_code # Attach the source code to the function + return generated_func + + +@pytest.mark.parametrize("test_case", test_cases, ids=[str(tc) for tc in test_cases]) +def test_create_tool_def2(test_case): + input_types = ( + test_case.input_type if isinstance(test_case.input_type, list) else [test_case.input_type] + ) + output_type = test_case.output_type + + # Generate the function dynamically + generated_func = generate_tool_function(input_types, output_type) + + assert generated_func is not None, "generated_func was not created" + + # Create tool definition using the generated function + tool_def = ToolCatalog.create_tool_definition(generated_func, "1.0") + + for i, input_type in enumerate(input_types): + param = tool_def.inputs.parameters[i] + assert ( + param.value_schema.val_type == get_wire_type(input_type) + ), f"Parameter {param.name} has value type {param.value_schema.val_type} but {input_type} was expected at index {i}" + + if output_type: + assert tool_def.output.value_schema.val_type == get_wire_type( + output_type + ), f"Output has value type {tool_def.output.val_type} but {output_type} was expected" + else: + assert tool_def.output.value_schema is None, "Output is not None" diff --git a/arcade/tests/tool/test_create_tool_definition_pydantic.py b/arcade/tests/tool/test_create_tool_definition_pydantic.py index d3a1272b..cddb00d1 100644 --- a/arcade/tests/tool/test_create_tool_definition_pydantic.py +++ b/arcade/tests/tool/test_create_tool_definition_pydantic.py @@ -258,12 +258,14 @@ def read_products( name="cols", description="The columns to return", required=False, - value_schema=ValueSchema(val_type="json", enum=None), + value_schema=ValueSchema( + val_type="array", inner_val_type="string", enum=None + ), ), ] ), "output": ToolOutput( - value_schema=ValueSchema(val_type="json", enum=None), + value_schema=ValueSchema(val_type="array", inner_val_type="json", enum=None), available_modes=["value", "error"], description="Data with the selected columns", ), diff --git a/schemas/preview/tool_definition.schema.jsonc b/schemas/preview/tool_definition.schema.jsonc index f34b52a6..757a04d0 100644 --- a/schemas/preview/tool_definition.schema.jsonc +++ b/schemas/preview/tool_definition.schema.jsonc @@ -4,14 +4,18 @@ "primitives": { // All supported primitive data types "type": "string", - "enum": ["string", "integer", "float", "boolean", "json"] + "enum": ["string", "integer", "number", "boolean", "json"] }, "value_schema": { // Represents the schema of a value (e.g. function input parameter value) "type": "object", "properties": { "val_type": { - "$ref": "#/$defs/primitives" + "oneOf": [{ "$ref": "#/$defs/primitives" }, { "type": "string", "enum": ["array"] }] + }, + "inner_val_type": { + "$ref": "#/$defs/primitives", + "description": "If the value type is a list, the type of the list values." }, "enum": { "oneOf": [ @@ -26,7 +30,13 @@ } }, "required": ["val_type"], - "additionalProperties": false + "additionalProperties": false, + "if": { + "properties": { "val_type": { "const": "array" } } + }, + "then": { + "required": ["inner_val_type"] + } } }, "type": "object", diff --git a/toolkits/math/arcade_arithmetic/tools/arithmetic.py b/toolkits/math/arcade_arithmetic/tools/arithmetic.py index dde95fad..47ccb01f 100644 --- a/toolkits/math/arcade_arithmetic/tools/arithmetic.py +++ b/toolkits/math/arcade_arithmetic/tools/arithmetic.py @@ -52,3 +52,13 @@ def sqrt( Get the square root of a number """ return math.sqrt(a) + + +@tool +def sum_list( + numbers: Annotated[list[float], "The list of numbers"], +) -> Annotated[float, "The sum of the numbers in the list"]: + """ + Sum all numbers in a list + """ + return sum(numbers)