SDK support for lists (arrays) in tool inputs & outputs (#36)

Working:
- Declare tool functions that have `list[str]` (etc.) as input parameters
or output values
- Engine can call these functions!

<img width="1195" alt="image"
src="https://github.com/user-attachments/assets/2aeb3c98-950a-4e2f-a8c7-39102e3fb7f0">
This commit is contained in:
Nate Barbettini 2024-09-12 16:31:12 -07:00 committed by GitHub
parent 345b685b08
commit 5726778f11
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 328 additions and 149 deletions

View file

@ -46,7 +46,19 @@ from arcade.core.utils import (
from arcade.sdk.annotations import Inferrable
from arcade.sdk.auth import Google, OAuth2, SlackUser, ToolAuthorization
WireType = Literal["string", "integer", "float", "boolean", "json"]
InnerWireType = Literal["string", "integer", "number", "boolean", "json"]
WireType = Union[InnerWireType, Literal["array"]]
@dataclass
class WireTypeInfo:
    """
    Represents the wire type information for a value, including its inner type if it's a list.
    """

    # Wire type of the value itself
    wire_type: WireType
    # Wire type of the list elements; stays None when the value is not a list
    inner_wire_type: InnerWireType | None = None
    # Fixed set of allowed values (string Literal or Enum); None when not enumerable
    enum_values: list[str] | None = None
class ToolMeta(BaseModel):
@ -233,17 +245,6 @@ def create_input_definition(func: Callable) -> ToolInputs:
tool_field_info = extract_field_info(param)
is_enum = False
enum_values: list[str] = []
# Special case: Literal["string1", "string2"] can be enumerated on the wire
if is_string_literal(tool_field_info.field_type):
is_enum = True
enum_values = [str(e) for e in get_args(tool_field_info.field_type)]
elif issubclass(tool_field_info.field_type, Enum):
is_enum = True
enum_values = [e.value for e in tool_field_info.field_type]
# If the field has a default value, it is not required
# If the field is optional, it is not required
has_default_value = tool_field_info.default is not None
@ -256,8 +257,9 @@ def create_input_definition(func: Callable) -> ToolInputs:
required=is_required,
inferrable=tool_field_info.is_inferrable,
value_schema=ValueSchema(
val_type=tool_field_info.wire_type,
enum=enum_values if is_enum else None,
val_type=tool_field_info.wire_type_info.wire_type,
inner_val_type=tool_field_info.wire_type_info.inner_wire_type,
enum=tool_field_info.wire_type_info.enum_values,
),
)
)
@ -291,7 +293,7 @@ def create_output_definition(func: Callable) -> ToolOutput:
return_type = next(arg for arg in get_args(return_type) if arg is not type(None))
is_optional = True
wire_type = get_wire_type(return_type)
wire_type_info = get_wire_type_info(return_type)
available_modes = ["value", "error"]
@ -301,7 +303,11 @@ def create_output_definition(func: Callable) -> ToolOutput:
return ToolOutput(
description=description,
available_modes=available_modes,
value_schema=ValueSchema(val_type=wire_type),
value_schema=ValueSchema(
val_type=wire_type_info.wire_type,
inner_val_type=wire_type_info.inner_wire_type,
enum=wire_type_info.enum_values,
),
)
@ -329,14 +335,17 @@ class ToolParamInfo:
default: Any
original_type: type
field_type: type
wire_type: WireType
wire_type_info: WireTypeInfo
description: str | None = None
is_optional: bool = True
is_inferrable: bool = True
@classmethod
def from_param_info(
cls, param_info: ParamInfo, wire_type: WireType, is_inferrable: bool = True
cls,
param_info: ParamInfo,
wire_type_info: WireTypeInfo,
is_inferrable: bool = True,
) -> "ToolParamInfo":
return cls(
name=param_info.name,
@ -345,7 +354,7 @@ class ToolParamInfo:
field_type=param_info.field_type,
description=param_info.description,
is_optional=param_info.is_optional,
wire_type=wire_type,
wire_type_info=wire_type_info,
is_inferrable=is_inferrable,
)
@ -362,7 +371,7 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
if isinstance(param.default, FieldInfo):
param_info = extract_pydantic_param_info(param)
else:
param_info = extract_regular_param_info(param)
param_info = extract_python_param_info(param)
metadata = getattr(annotation, "__metadata__", [])
str_annotations = [m for m in metadata if isinstance(m, str)]
@ -386,24 +395,59 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
# Params are inferrable by default
is_inferrable = inferrable_annotation.value if inferrable_annotation else True
# Get the wire type
wire_type = (
get_wire_type(str)
if is_string_literal(param_info.field_type)
else get_wire_type(param_info.field_type)
)
# Get the wire (serialization) type information for the type
wire_type_info = get_wire_type_info(param_info.field_type)
# Final reality check
if param_info.description is None:
raise ToolDefinitionError(f"Parameter {param_info.name} is missing a description")
if wire_type is None:
if wire_type_info.wire_type is None:
raise ToolDefinitionError(f"Unknown parameter type: {param_info.field_type}")
return ToolParamInfo.from_param_info(param_info, wire_type, is_inferrable)
return ToolParamInfo.from_param_info(param_info, wire_type_info, is_inferrable)
def extract_regular_param_info(param: inspect.Parameter) -> ParamInfo:
def get_wire_type_info(_type: type) -> WireTypeInfo:
    """
    Get the wire type information for a given type.

    For a ``list[X]`` type, the outer wire type is computed from the list type
    itself and the inner wire type from ``X``. Enumerated values are collected
    from ``X`` (or from ``_type`` when it is not a list) when that type is a
    string ``Literal`` or an ``Enum`` subclass; otherwise ``enum_values`` is None.

    Raises:
        ToolDefinitionError: If a list type does not declare its element type.
            ``get_wire_type`` is also called here and raises on its own for
            types it does not support.
    """
    # Is this a list type?
    # If so, get the inner (enclosed) type
    is_list = get_origin(_type) is list
    inner_wire_type: InnerWireType | None = None
    type_to_check = _type
    if is_list:
        type_args = get_args(_type)
        if not type_args:
            # e.g. bare typing.List: there is no element type to describe on the wire
            raise ToolDefinitionError(
                f"List type {_type} must declare its element type, e.g. list[str]"
            )
        type_to_check = type_args[0]
        inner_wire_type = cast(
            InnerWireType,
            get_wire_type(str)
            if is_string_literal(type_to_check)
            else get_wire_type(type_to_check),
        )

    # Get the outer wire type
    wire_type = get_wire_type(str) if is_string_literal(_type) else get_wire_type(_type)

    # Handle enums (known/fixed lists of values)
    enum_values: list[str] | None = None

    # Special case: Literal["string1", "string2"] can be enumerated on the wire
    if is_string_literal(type_to_check):
        enum_values = [str(e) for e in get_args(type_to_check)]
    # Special case: Enum can be enumerated on the wire.
    # issubclass() raises TypeError for non-class objects (parameterized generics,
    # unions, ...), so confirm we actually hold a class before asking.
    elif isinstance(type_to_check, type) and issubclass(type_to_check, Enum):
        enum_values = [e.value for e in type_to_check]

    return WireTypeInfo(wire_type, inner_wire_type, enum_values)
def extract_python_param_info(param: inspect.Parameter) -> ParamInfo:
# If the param is Annotated[], unwrap the annotation to get the "real" type
# Otherwise, use the literal type
annotation = param.annotation
@ -465,28 +509,34 @@ def get_wire_type(
"""
Mapping between Python types and HTTP/JSON types
"""
type_mapping = {
type_mapping: dict[type, WireType] = {
str: "string",
bool: "boolean",
int: "integer",
float: "float",
float: "number",
dict: "json",
}
outer_type_mapping: dict[type, WireType] = {
list: "array",
dict: "json",
list: "json",
BaseModel: "json",
}
wire_type = type_mapping.get(_type)
if wire_type:
return cast(Literal["string", "integer", "float", "boolean", "json"], wire_type)
elif hasattr(_type, "__origin__"):
# account for "list[str]" and "dict[str, int]" and "Optional[str]" and other typing types
origin = _type.__origin__
if origin in [list, dict]:
return "json"
elif issubclass(_type, Enum):
return wire_type
if hasattr(_type, "__origin__"):
wire_type = outer_type_mapping.get(cast(type, get_origin(_type)))
if wire_type:
return wire_type
if issubclass(_type, Enum):
return "string"
elif issubclass(_type, BaseModel):
if issubclass(_type, BaseModel):
return "json"
raise ToolDefinitionError(f"Unsupported parameter type: {_type}")

View file

@ -6,9 +6,12 @@ from pydantic import AnyUrl, BaseModel, Field
class ValueSchema(BaseModel):
"""Value schema for input parameters and outputs."""
val_type: Literal["string", "integer", "float", "boolean", "json"]
val_type: Literal["string", "integer", "number", "boolean", "json", "array"]
"""The type of the value."""
inner_val_type: Optional[Literal["string", "integer", "number", "boolean", "json"]] = None
"""The type of the inner value, if the value is a list."""
enum: Optional[list[str]] = None
"""The list of possible values for the value, if it is a closed list."""

View file

@ -48,7 +48,15 @@ def does_function_return_value(func: Callable) -> bool:
"""
Returns True if the given function returns a value, i.e. if it has a return statement with a value.
"""
source = inspect.getsource(func)
try:
source: Optional[str] = inspect.getsource(func)
except OSError:
# Workaround for parameterized unit tests that use a dynamically-generated function
source = getattr(func, "__source__", None)
if source is None:
raise ValueError("Source code not found")
tree = ast.parse(source)
class ReturnVisitor(ast.NodeVisitor):

View file

@ -63,11 +63,6 @@ def func_with_google_auth_requirement():
### Tests on input params
@tool(desc="A function with an input parameter")
def func_with_param(context: Annotated[str, "First param"]):
pass
@tool(desc="A function with a non-inferrable input parameter")
def func_with_non_inferrable_param(param1: Annotated[str, "First param", Inferrable(False)]):
pass
@ -79,23 +74,13 @@ def func_with_renamed_param(param1: Annotated[str, "ParamOne", "First param"]):
pass
@tool(desc="A function with every possible input parameter")
def func_with_every_param(
param1: Annotated[str, "a string"],
param2: Annotated[int, "an integer"],
param3: Annotated[float, "a float"],
param4: Annotated[bool, "a boolean"],
):
pass
class TestEnum(Enum):
class MyEnum(Enum):
FOO_BAR = "foo bar"
BAZ = "baz"
@tool(desc="A function that takes an enum")
def func_with_enum_param(param1: Annotated[TestEnum, "an enum"]):
def func_with_enum_param(param1: Annotated[MyEnum, "an enum"]):
pass
@ -142,8 +127,25 @@ def func_with_mixed_params(
pass
@tool(desc="A function with a list[str] parameter")
def func_with_list_param(param1: Annotated[list[str], "A list of strings"]):
pass
@tool(desc="A function with a list[float] parameter")
def func_with_list_float_param(param1: Annotated[list[float], "A list of floats"]):
pass
@tool(desc="A function with a list of enums parameter")
def func_with_list_of_enums_param(param1: Annotated[list[MyEnum], "A list of enums"]):
pass
@tool(desc="A function with a complex parameter type")
def func_with_complex_param(param1: Annotated[list[str], "A list of strings"]):
def func_with_complex_param(
param1: Annotated[dict[str, list[int]], "A dictionary with lists of integers"],
):
pass
@ -153,14 +155,19 @@ def func_with_context(my_context: ToolContext):
### Tests on output/return values
@tool(desc="A function that performs an action without returning anything")
def func_with_no_return():
pass
@tool(desc="A function that returns a list of strings")
def func_with_list_return() -> list[str]:
return ["output1", "output2"]
@tool(desc="A function that returns a value")
def func_with_value_return() -> str:
return "output"
@tool(desc="A function that returns a known list of string literals")
def func_with_known_list_return() -> Literal["value1", "value2"]:
return "value1"
@tool(desc="A function that returns an enum")
def func_with_enum_return() -> MyEnum:
return MyEnum.FOO_BAR
@tool(desc="A function with an annotated return type")
@ -174,7 +181,7 @@ def func_with_optional_return() -> Optional[str]:
@tool(desc="A function with a complex return type")
def func_with_complex_return() -> list[dict[str, str]]:
def func_with_complex_return() -> dict[str, str]:
    # Return a dict so the value agrees with the declared dict[str, str] return type
    return {"key": "value"}
@ -244,30 +251,6 @@ def func_with_complex_return() -> list[dict[str, str]]:
id="func_with_google_auth_requirement",
),
# Tests on input params
pytest.param(
func_with_value_return,
{
"inputs": ToolInputs(parameters=[]),
},
id="func_with_no_params",
),
pytest.param(
func_with_param,
{
"inputs": ToolInputs(
parameters=[
InputParameter(
name="context", # Nothing special about this name, parameters can be named anything
description="First param",
inferrable=True, # Defaults to true
required=True,
value_schema=ValueSchema(val_type="string", enum=None),
)
]
),
},
id="func_with_param",
),
pytest.param(
func_with_non_inferrable_param,
{
@ -302,44 +285,6 @@ def func_with_complex_return() -> list[dict[str, str]]:
},
id="func_with_renamed_param",
),
pytest.param(
func_with_every_param,
{
"inputs": ToolInputs(
parameters=[
InputParameter(
name="param1",
description="a string",
inferrable=True,
required=True,
value_schema=ValueSchema(val_type="string", enum=None),
),
InputParameter(
name="param2",
description="an integer",
inferrable=True,
required=True,
value_schema=ValueSchema(val_type="integer", enum=None),
),
InputParameter(
name="param3",
description="a float",
inferrable=True,
required=True,
value_schema=ValueSchema(val_type="float", enum=None),
),
InputParameter(
name="param4",
description="a boolean",
inferrable=True,
required=True,
value_schema=ValueSchema(val_type="boolean", enum=None),
),
]
),
},
id="func_with_every_param",
),
pytest.param(
func_with_enum_param,
{
@ -497,7 +442,7 @@ def func_with_complex_return() -> list[dict[str, str]]:
id="func_with_mixed_params",
),
pytest.param(
func_with_complex_param,
func_with_list_param,
{
"inputs": ToolInputs(
parameters=[
@ -506,6 +451,63 @@ def func_with_complex_return() -> list[dict[str, str]]:
description="A list of strings",
inferrable=True,
required=True,
value_schema=ValueSchema(
val_type="array", inner_val_type="string", enum=None
),
)
]
),
},
id="func_with_list_param",
),
pytest.param(
func_with_list_float_param,
{
"inputs": ToolInputs(
parameters=[
InputParameter(
name="param1",
description="A list of floats",
inferrable=True,
required=True,
value_schema=ValueSchema(
val_type="array", inner_val_type="number", enum=None
),
)
]
),
},
id="func_with_list_float_param",
),
pytest.param(
func_with_list_of_enums_param,
{
"inputs": ToolInputs(
parameters=[
InputParameter(
name="param1",
description="A list of enums",
inferrable=True,
required=True,
value_schema=ValueSchema(
val_type="array", inner_val_type="string", enum=["foo bar", "baz"]
),
)
]
),
},
id="func_with_list_of_enums_param",
),
pytest.param(
func_with_complex_param,
{
"inputs": ToolInputs(
parameters=[
InputParameter(
name="param1",
description="A dictionary with lists of integers",
inferrable=True,
required=True,
value_schema=ValueSchema(val_type="json", enum=None),
)
]
@ -524,25 +526,40 @@ def func_with_complex_return() -> list[dict[str, str]]:
),
# Tests on output values
pytest.param(
func_with_no_return,
{
"output": ToolOutput(
available_modes=["null"], description="No description provided."
),
},
id="func_with_no_return",
),
pytest.param(
func_with_value_return,
func_with_list_return,
{
"inputs": ToolInputs(parameters=[]),
"output": ToolOutput(
value_schema=ValueSchema(val_type="string", enum=None),
value_schema=ValueSchema(val_type="array", inner_val_type="string", enum=None),
available_modes=["value", "error"],
description="No description provided.",
),
},
id="func_with_value_return",
id="func_with_list_return",
),
pytest.param(
func_with_known_list_return,
{
"inputs": ToolInputs(parameters=[]),
"output": ToolOutput(
value_schema=ValueSchema(val_type="string", enum=["value1", "value2"]),
available_modes=["value", "error"],
description="No description provided.",
),
},
id="func_with_known_list_return",
),
pytest.param(
func_with_enum_return,
{
"inputs": ToolInputs(parameters=[]),
"output": ToolOutput(
value_schema=ValueSchema(val_type="string", enum=["foo bar", "baz"]),
available_modes=["value", "error"],
description="No description provided.",
),
},
id="func_with_enum_return",
),
pytest.param(
func_with_annotated_return,
@ -592,5 +609,5 @@ def test_create_tool_def(func_under_test, expected_tool_def_fields):
def tool_version_is_set_correctly():
tool_def = ToolCatalog.create_tool_definition(func_with_no_return, "abcd1236")
tool_def = ToolCatalog.create_tool_definition(func_with_description, "abcd1236")
assert tool_def.version == "abcd1236"

View file

@ -0,0 +1,79 @@
from typing import Annotated
import pytest
from arcade.core.catalog import ToolCatalog, get_wire_type
from arcade.sdk import tool
class Case:
def __init__(self, input_type: type, output_type: type | None):
self.input_type = input_type
self.output_type = output_type
def __str__(self):
return f"Case(input_type={self.input_type}, output_type={self.output_type})"
primitives = [bool, float, int, str]
test_cases = [
Case(input_type=input_type, output_type=output_type)
for input_type in [*primitives, []]
for output_type in [*primitives, None]
] + [
Case(input_type=[primitives[i] for i in range(n)], output_type=output_type)
for n in range(2, len(primitives) + 1)
for output_type in [*primitives, None]
]
# Generate tool functions dynamically
def generate_tool_function(input_types: list[type], output_type: type | None):
    """
    Build a ``@tool``-decorated function from source text and return it.

    Each entry of ``input_types`` becomes one ``Annotated`` parameter named
    ``param0``, ``param1``, ...; ``output_type`` becomes the return annotation,
    or is omitted when it is None.

    The generated source is attached to the function as ``__source__`` because
    a function created with ``exec`` has no source file to read it back from.
    """
    input_annotation = ", ".join(
        f"param{i}: Annotated[{input_type.__name__}, 'Param {i + 1}']"
        for i, input_type in enumerate(input_types)
    )
    output_annotation = f" -> {output_type.__name__}" if output_type else ""

    # The function body must be indented for the generated source to compile
    func_code = f"""
@tool(desc="Generated function with input and output types")
def generated_func({input_annotation}){output_annotation}:
    pass
"""

    local_vars = {}
    exec(func_code, {"tool": tool, "Annotated": Annotated}, local_vars)  # noqa: S102

    # Index directly: a missing definition should fail here with a clear KeyError
    generated_func = local_vars["generated_func"]
    generated_func.__source__ = func_code  # Attach the source code to the function
    return generated_func
@pytest.mark.parametrize("test_case", test_cases, ids=[str(tc) for tc in test_cases])
def test_create_tool_def2(test_case):
    """
    Check that a generated tool's definition reports the expected wire types.

    Every input parameter and the output (when one is declared) must carry the
    wire type that ``get_wire_type`` yields for the corresponding Python type.
    """
    # A case holds either a single input type or a list of them; normalize to a list
    input_types = (
        test_case.input_type if isinstance(test_case.input_type, list) else [test_case.input_type]
    )
    output_type = test_case.output_type

    # Generate the function dynamically
    generated_func = generate_tool_function(input_types, output_type)
    assert generated_func is not None, "generated_func was not created"

    # Create tool definition using the generated function
    tool_def = ToolCatalog.create_tool_definition(generated_func, "1.0")

    for i, input_type in enumerate(input_types):
        param = tool_def.inputs.parameters[i]
        assert (
            param.value_schema.val_type == get_wire_type(input_type)
        ), f"Parameter {param.name} has value type {param.value_schema.val_type} but {input_type} was expected at index {i}"

    if output_type:
        # The message reads the same attribute path as the assertion, so a failure
        # reports the mismatch instead of raising AttributeError while formatting
        assert tool_def.output.value_schema.val_type == get_wire_type(
            output_type
        ), f"Output has value type {tool_def.output.value_schema.val_type} but {output_type} was expected"
    else:
        assert tool_def.output.value_schema is None, "Output is not None"

View file

@ -258,12 +258,14 @@ def read_products(
name="cols",
description="The columns to return",
required=False,
value_schema=ValueSchema(val_type="json", enum=None),
value_schema=ValueSchema(
val_type="array", inner_val_type="string", enum=None
),
),
]
),
"output": ToolOutput(
value_schema=ValueSchema(val_type="json", enum=None),
value_schema=ValueSchema(val_type="array", inner_val_type="json", enum=None),
available_modes=["value", "error"],
description="Data with the selected columns",
),

View file

@ -4,14 +4,18 @@
"primitives": {
// All supported primitive data types
"type": "string",
"enum": ["string", "integer", "float", "boolean", "json"]
"enum": ["string", "integer", "number", "boolean", "json"]
},
"value_schema": {
// Represents the schema of a value (e.g. function input parameter value)
"type": "object",
"properties": {
"val_type": {
"$ref": "#/$defs/primitives"
"oneOf": [{ "$ref": "#/$defs/primitives" }, { "type": "string", "enum": ["array"] }]
},
"inner_val_type": {
"$ref": "#/$defs/primitives",
"description": "If the value type is a list, the type of the list values."
},
"enum": {
"oneOf": [
@ -26,7 +30,13 @@
}
},
"required": ["val_type"],
"additionalProperties": false
"additionalProperties": false,
"if": {
"properties": { "val_type": { "const": "array" } }
},
"then": {
"required": ["inner_val_type"]
}
}
},
"type": "object",

View file

@ -52,3 +52,13 @@ def sqrt(
Get the square root of a number
"""
return math.sqrt(a)
@tool
def sum_list(
    numbers: Annotated[list[float], "The list of numbers"],
) -> Annotated[float, "The sum of the numbers in the list"]:
    """
    Sum all numbers in a list
    """
    # NOTE(review): built-in sum() returns the int 0 for an empty list — confirm
    # that is acceptable where a float is annotated.
    return sum(numbers)