SDK support for lists (arrays) in tool inputs & outputs (#36)
Working: - Declare tool functions that have `list[str]` (etc) as input parameter or output values - Engine can call these functions! <img width="1195" alt="image" src="https://github.com/user-attachments/assets/2aeb3c98-950a-4e2f-a8c7-39102e3fb7f0">
This commit is contained in:
parent
345b685b08
commit
5726778f11
8 changed files with 328 additions and 149 deletions
|
|
@ -46,7 +46,19 @@ from arcade.core.utils import (
|
|||
from arcade.sdk.annotations import Inferrable
|
||||
from arcade.sdk.auth import Google, OAuth2, SlackUser, ToolAuthorization
|
||||
|
||||
WireType = Literal["string", "integer", "float", "boolean", "json"]
|
||||
InnerWireType = Literal["string", "integer", "number", "boolean", "json"]
|
||||
WireType = Union[InnerWireType, Literal["array"]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class WireTypeInfo:
|
||||
"""
|
||||
Represents the wire type information for a value, including its inner type if it's a list.
|
||||
"""
|
||||
|
||||
wire_type: WireType
|
||||
inner_wire_type: InnerWireType | None = None
|
||||
enum_values: list[str] | None = None
|
||||
|
||||
|
||||
class ToolMeta(BaseModel):
|
||||
|
|
@ -233,17 +245,6 @@ def create_input_definition(func: Callable) -> ToolInputs:
|
|||
|
||||
tool_field_info = extract_field_info(param)
|
||||
|
||||
is_enum = False
|
||||
enum_values: list[str] = []
|
||||
|
||||
# Special case: Literal["string1", "string2"] can be enumerated on the wire
|
||||
if is_string_literal(tool_field_info.field_type):
|
||||
is_enum = True
|
||||
enum_values = [str(e) for e in get_args(tool_field_info.field_type)]
|
||||
elif issubclass(tool_field_info.field_type, Enum):
|
||||
is_enum = True
|
||||
enum_values = [e.value for e in tool_field_info.field_type]
|
||||
|
||||
# If the field has a default value, it is not required
|
||||
# If the field is optional, it is not required
|
||||
has_default_value = tool_field_info.default is not None
|
||||
|
|
@ -256,8 +257,9 @@ def create_input_definition(func: Callable) -> ToolInputs:
|
|||
required=is_required,
|
||||
inferrable=tool_field_info.is_inferrable,
|
||||
value_schema=ValueSchema(
|
||||
val_type=tool_field_info.wire_type,
|
||||
enum=enum_values if is_enum else None,
|
||||
val_type=tool_field_info.wire_type_info.wire_type,
|
||||
inner_val_type=tool_field_info.wire_type_info.inner_wire_type,
|
||||
enum=tool_field_info.wire_type_info.enum_values,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -291,7 +293,7 @@ def create_output_definition(func: Callable) -> ToolOutput:
|
|||
return_type = next(arg for arg in get_args(return_type) if arg is not type(None))
|
||||
is_optional = True
|
||||
|
||||
wire_type = get_wire_type(return_type)
|
||||
wire_type_info = get_wire_type_info(return_type)
|
||||
|
||||
available_modes = ["value", "error"]
|
||||
|
||||
|
|
@ -301,7 +303,11 @@ def create_output_definition(func: Callable) -> ToolOutput:
|
|||
return ToolOutput(
|
||||
description=description,
|
||||
available_modes=available_modes,
|
||||
value_schema=ValueSchema(val_type=wire_type),
|
||||
value_schema=ValueSchema(
|
||||
val_type=wire_type_info.wire_type,
|
||||
inner_val_type=wire_type_info.inner_wire_type,
|
||||
enum=wire_type_info.enum_values,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -329,14 +335,17 @@ class ToolParamInfo:
|
|||
default: Any
|
||||
original_type: type
|
||||
field_type: type
|
||||
wire_type: WireType
|
||||
wire_type_info: WireTypeInfo
|
||||
description: str | None = None
|
||||
is_optional: bool = True
|
||||
is_inferrable: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_param_info(
|
||||
cls, param_info: ParamInfo, wire_type: WireType, is_inferrable: bool = True
|
||||
cls,
|
||||
param_info: ParamInfo,
|
||||
wire_type_info: WireTypeInfo,
|
||||
is_inferrable: bool = True,
|
||||
) -> "ToolParamInfo":
|
||||
return cls(
|
||||
name=param_info.name,
|
||||
|
|
@ -345,7 +354,7 @@ class ToolParamInfo:
|
|||
field_type=param_info.field_type,
|
||||
description=param_info.description,
|
||||
is_optional=param_info.is_optional,
|
||||
wire_type=wire_type,
|
||||
wire_type_info=wire_type_info,
|
||||
is_inferrable=is_inferrable,
|
||||
)
|
||||
|
||||
|
|
@ -362,7 +371,7 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
|
|||
if isinstance(param.default, FieldInfo):
|
||||
param_info = extract_pydantic_param_info(param)
|
||||
else:
|
||||
param_info = extract_regular_param_info(param)
|
||||
param_info = extract_python_param_info(param)
|
||||
|
||||
metadata = getattr(annotation, "__metadata__", [])
|
||||
str_annotations = [m for m in metadata if isinstance(m, str)]
|
||||
|
|
@ -386,24 +395,59 @@ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
|
|||
# Params are inferrable by default
|
||||
is_inferrable = inferrable_annotation.value if inferrable_annotation else True
|
||||
|
||||
# Get the wire type
|
||||
wire_type = (
|
||||
get_wire_type(str)
|
||||
if is_string_literal(param_info.field_type)
|
||||
else get_wire_type(param_info.field_type)
|
||||
)
|
||||
# Get the wire (serialization) type information for the type
|
||||
wire_type_info = get_wire_type_info(param_info.field_type)
|
||||
|
||||
# Final reality check
|
||||
if param_info.description is None:
|
||||
raise ToolDefinitionError(f"Parameter {param_info.name} is missing a description")
|
||||
|
||||
if wire_type is None:
|
||||
if wire_type_info.wire_type is None:
|
||||
raise ToolDefinitionError(f"Unknown parameter type: {param_info.field_type}")
|
||||
|
||||
return ToolParamInfo.from_param_info(param_info, wire_type, is_inferrable)
|
||||
return ToolParamInfo.from_param_info(param_info, wire_type_info, is_inferrable)
|
||||
|
||||
|
||||
def extract_regular_param_info(param: inspect.Parameter) -> ParamInfo:
|
||||
def get_wire_type_info(_type: type) -> WireTypeInfo:
|
||||
"""
|
||||
Get the wire type information for a given type.
|
||||
"""
|
||||
|
||||
# Is this a list type?
|
||||
# If so, get the inner (enclosed) type
|
||||
is_list = get_origin(_type) is list
|
||||
if is_list:
|
||||
inner_type = get_args(_type)[0]
|
||||
inner_wire_type = cast(
|
||||
InnerWireType,
|
||||
get_wire_type(str) if is_string_literal(inner_type) else get_wire_type(inner_type),
|
||||
)
|
||||
else:
|
||||
inner_wire_type = None
|
||||
|
||||
# Get the outer wire type
|
||||
wire_type = get_wire_type(str) if is_string_literal(_type) else get_wire_type(_type)
|
||||
|
||||
# Handle enums (known/fixed lists of values)
|
||||
is_enum = False
|
||||
enum_values: list[str] = []
|
||||
|
||||
type_to_check = inner_type if is_list else _type
|
||||
|
||||
# Special case: Literal["string1", "string2"] can be enumerated on the wire
|
||||
if is_string_literal(type_to_check):
|
||||
is_enum = True
|
||||
enum_values = [str(e) for e in get_args(type_to_check)]
|
||||
|
||||
# Special case: Enum can be enumerated on the wire
|
||||
elif issubclass(type_to_check, Enum):
|
||||
is_enum = True
|
||||
enum_values = [e.value for e in type_to_check]
|
||||
|
||||
return WireTypeInfo(wire_type, inner_wire_type, enum_values if is_enum else None)
|
||||
|
||||
|
||||
def extract_python_param_info(param: inspect.Parameter) -> ParamInfo:
|
||||
# If the param is Annotated[], unwrap the annotation to get the "real" type
|
||||
# Otherwise, use the literal type
|
||||
annotation = param.annotation
|
||||
|
|
@ -465,28 +509,34 @@ def get_wire_type(
|
|||
"""
|
||||
Mapping between Python types and HTTP/JSON types
|
||||
"""
|
||||
type_mapping = {
|
||||
type_mapping: dict[type, WireType] = {
|
||||
str: "string",
|
||||
bool: "boolean",
|
||||
int: "integer",
|
||||
float: "float",
|
||||
float: "number",
|
||||
dict: "json",
|
||||
}
|
||||
|
||||
outer_type_mapping: dict[type, WireType] = {
|
||||
list: "array",
|
||||
dict: "json",
|
||||
list: "json",
|
||||
BaseModel: "json",
|
||||
}
|
||||
|
||||
wire_type = type_mapping.get(_type)
|
||||
if wire_type:
|
||||
return cast(Literal["string", "integer", "float", "boolean", "json"], wire_type)
|
||||
elif hasattr(_type, "__origin__"):
|
||||
# account for "list[str]" and "dict[str, int]" and "Optional[str]" and other typing types
|
||||
origin = _type.__origin__
|
||||
if origin in [list, dict]:
|
||||
return "json"
|
||||
elif issubclass(_type, Enum):
|
||||
return wire_type
|
||||
|
||||
if hasattr(_type, "__origin__"):
|
||||
wire_type = outer_type_mapping.get(cast(type, get_origin(_type)))
|
||||
if wire_type:
|
||||
return wire_type
|
||||
|
||||
if issubclass(_type, Enum):
|
||||
return "string"
|
||||
elif issubclass(_type, BaseModel):
|
||||
|
||||
if issubclass(_type, BaseModel):
|
||||
return "json"
|
||||
|
||||
raise ToolDefinitionError(f"Unsupported parameter type: {_type}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,12 @@ from pydantic import AnyUrl, BaseModel, Field
|
|||
class ValueSchema(BaseModel):
|
||||
"""Value schema for input parameters and outputs."""
|
||||
|
||||
val_type: Literal["string", "integer", "float", "boolean", "json"]
|
||||
val_type: Literal["string", "integer", "number", "boolean", "json", "array"]
|
||||
"""The type of the value."""
|
||||
|
||||
inner_val_type: Optional[Literal["string", "integer", "number", "boolean", "json"]] = None
|
||||
"""The type of the inner value, if the value is a list."""
|
||||
|
||||
enum: Optional[list[str]] = None
|
||||
"""The list of possible values for the value, if it is a closed list."""
|
||||
|
||||
|
|
|
|||
|
|
@ -48,7 +48,15 @@ def does_function_return_value(func: Callable) -> bool:
|
|||
"""
|
||||
Returns True if the given function returns a value, i.e. if it has a return statement with a value.
|
||||
"""
|
||||
source = inspect.getsource(func)
|
||||
try:
|
||||
source: Optional[str] = inspect.getsource(func)
|
||||
except OSError:
|
||||
# Workaround for parameterized unit tests that use a dynamically-generated function
|
||||
source = getattr(func, "__source__", None)
|
||||
|
||||
if source is None:
|
||||
raise ValueError("Source code not found")
|
||||
|
||||
tree = ast.parse(source)
|
||||
|
||||
class ReturnVisitor(ast.NodeVisitor):
|
||||
|
|
|
|||
|
|
@ -63,11 +63,6 @@ def func_with_google_auth_requirement():
|
|||
|
||||
|
||||
### Tests on input params
|
||||
@tool(desc="A function with an input parameter")
|
||||
def func_with_param(context: Annotated[str, "First param"]):
|
||||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with a non-inferrable input parameter")
|
||||
def func_with_non_inferrable_param(param1: Annotated[str, "First param", Inferrable(False)]):
|
||||
pass
|
||||
|
|
@ -79,23 +74,13 @@ def func_with_renamed_param(param1: Annotated[str, "ParamOne", "First param"]):
|
|||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with every possible input parameter")
|
||||
def func_with_every_param(
|
||||
param1: Annotated[str, "a string"],
|
||||
param2: Annotated[int, "an integer"],
|
||||
param3: Annotated[float, "a float"],
|
||||
param4: Annotated[bool, "a boolean"],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class TestEnum(Enum):
|
||||
class MyEnum(Enum):
|
||||
FOO_BAR = "foo bar"
|
||||
BAZ = "baz"
|
||||
|
||||
|
||||
@tool(desc="A function that takes an enum")
|
||||
def func_with_enum_param(param1: Annotated[TestEnum, "an enum"]):
|
||||
def func_with_enum_param(param1: Annotated[MyEnum, "an enum"]):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -142,8 +127,25 @@ def func_with_mixed_params(
|
|||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with a list[str] parameter")
|
||||
def func_with_list_param(param1: Annotated[list[str], "A list of strings"]):
|
||||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with a list[float] parameter")
|
||||
def func_with_list_float_param(param1: Annotated[list[float], "A list of floats"]):
|
||||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with a list of enums parameter")
|
||||
def func_with_list_of_enums_param(param1: Annotated[list[MyEnum], "A list of enums"]):
|
||||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with a complex parameter type")
|
||||
def func_with_complex_param(param1: Annotated[list[str], "A list of strings"]):
|
||||
def func_with_complex_param(
|
||||
param1: Annotated[dict[str, list[int]], "A dictionary with lists of integers"],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -153,14 +155,19 @@ def func_with_context(my_context: ToolContext):
|
|||
|
||||
|
||||
### Tests on output/return values
|
||||
@tool(desc="A function that performs an action without returning anything")
|
||||
def func_with_no_return():
|
||||
pass
|
||||
@tool(desc="A function that returns a list of strings")
|
||||
def func_with_list_return() -> list[str]:
|
||||
return ["output1", "output2"]
|
||||
|
||||
|
||||
@tool(desc="A function that returns a value")
|
||||
def func_with_value_return() -> str:
|
||||
return "output"
|
||||
@tool(desc="A function that returns a known list of string literals")
|
||||
def func_with_known_list_return() -> Literal["value1", "value2"]:
|
||||
return "value1"
|
||||
|
||||
|
||||
@tool(desc="A function that returns an enum")
|
||||
def func_with_enum_return() -> MyEnum:
|
||||
return MyEnum.FOO_BAR
|
||||
|
||||
|
||||
@tool(desc="A function with an annotated return type")
|
||||
|
|
@ -174,7 +181,7 @@ def func_with_optional_return() -> Optional[str]:
|
|||
|
||||
|
||||
@tool(desc="A function with a complex return type")
|
||||
def func_with_complex_return() -> list[dict[str, str]]:
|
||||
def func_with_complex_return() -> dict[str, str]:
|
||||
return [{"key": "value"}]
|
||||
|
||||
|
||||
|
|
@ -244,30 +251,6 @@ def func_with_complex_return() -> list[dict[str, str]]:
|
|||
id="func_with_google_auth_requirement",
|
||||
),
|
||||
# Tests on input params
|
||||
pytest.param(
|
||||
func_with_value_return,
|
||||
{
|
||||
"inputs": ToolInputs(parameters=[]),
|
||||
},
|
||||
id="func_with_no_params",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_param,
|
||||
{
|
||||
"inputs": ToolInputs(
|
||||
parameters=[
|
||||
InputParameter(
|
||||
name="context", # Nothing special about this name, parameters can be named anything
|
||||
description="First param",
|
||||
inferrable=True, # Defaults to true
|
||||
required=True,
|
||||
value_schema=ValueSchema(val_type="string", enum=None),
|
||||
)
|
||||
]
|
||||
),
|
||||
},
|
||||
id="func_with_param",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_non_inferrable_param,
|
||||
{
|
||||
|
|
@ -302,44 +285,6 @@ def func_with_complex_return() -> list[dict[str, str]]:
|
|||
},
|
||||
id="func_with_renamed_param",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_every_param,
|
||||
{
|
||||
"inputs": ToolInputs(
|
||||
parameters=[
|
||||
InputParameter(
|
||||
name="param1",
|
||||
description="a string",
|
||||
inferrable=True,
|
||||
required=True,
|
||||
value_schema=ValueSchema(val_type="string", enum=None),
|
||||
),
|
||||
InputParameter(
|
||||
name="param2",
|
||||
description="an integer",
|
||||
inferrable=True,
|
||||
required=True,
|
||||
value_schema=ValueSchema(val_type="integer", enum=None),
|
||||
),
|
||||
InputParameter(
|
||||
name="param3",
|
||||
description="a float",
|
||||
inferrable=True,
|
||||
required=True,
|
||||
value_schema=ValueSchema(val_type="float", enum=None),
|
||||
),
|
||||
InputParameter(
|
||||
name="param4",
|
||||
description="a boolean",
|
||||
inferrable=True,
|
||||
required=True,
|
||||
value_schema=ValueSchema(val_type="boolean", enum=None),
|
||||
),
|
||||
]
|
||||
),
|
||||
},
|
||||
id="func_with_every_param",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_enum_param,
|
||||
{
|
||||
|
|
@ -497,7 +442,7 @@ def func_with_complex_return() -> list[dict[str, str]]:
|
|||
id="func_with_mixed_params",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_complex_param,
|
||||
func_with_list_param,
|
||||
{
|
||||
"inputs": ToolInputs(
|
||||
parameters=[
|
||||
|
|
@ -506,6 +451,63 @@ def func_with_complex_return() -> list[dict[str, str]]:
|
|||
description="A list of strings",
|
||||
inferrable=True,
|
||||
required=True,
|
||||
value_schema=ValueSchema(
|
||||
val_type="array", inner_val_type="string", enum=None
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
},
|
||||
id="func_with_list_param",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_list_float_param,
|
||||
{
|
||||
"inputs": ToolInputs(
|
||||
parameters=[
|
||||
InputParameter(
|
||||
name="param1",
|
||||
description="A list of floats",
|
||||
inferrable=True,
|
||||
required=True,
|
||||
value_schema=ValueSchema(
|
||||
val_type="array", inner_val_type="number", enum=None
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
},
|
||||
id="func_with_list_float_param",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_list_of_enums_param,
|
||||
{
|
||||
"inputs": ToolInputs(
|
||||
parameters=[
|
||||
InputParameter(
|
||||
name="param1",
|
||||
description="A list of enums",
|
||||
inferrable=True,
|
||||
required=True,
|
||||
value_schema=ValueSchema(
|
||||
val_type="array", inner_val_type="string", enum=["foo bar", "baz"]
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
},
|
||||
id="func_with_list_of_enums_param",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_complex_param,
|
||||
{
|
||||
"inputs": ToolInputs(
|
||||
parameters=[
|
||||
InputParameter(
|
||||
name="param1",
|
||||
description="A dictionary with lists of integers",
|
||||
inferrable=True,
|
||||
required=True,
|
||||
value_schema=ValueSchema(val_type="json", enum=None),
|
||||
)
|
||||
]
|
||||
|
|
@ -524,25 +526,40 @@ def func_with_complex_return() -> list[dict[str, str]]:
|
|||
),
|
||||
# Tests on output values
|
||||
pytest.param(
|
||||
func_with_no_return,
|
||||
{
|
||||
"output": ToolOutput(
|
||||
available_modes=["null"], description="No description provided."
|
||||
),
|
||||
},
|
||||
id="func_with_no_return",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_value_return,
|
||||
func_with_list_return,
|
||||
{
|
||||
"inputs": ToolInputs(parameters=[]),
|
||||
"output": ToolOutput(
|
||||
value_schema=ValueSchema(val_type="string", enum=None),
|
||||
value_schema=ValueSchema(val_type="array", inner_val_type="string", enum=None),
|
||||
available_modes=["value", "error"],
|
||||
description="No description provided.",
|
||||
),
|
||||
},
|
||||
id="func_with_value_return",
|
||||
id="func_with_list_return",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_known_list_return,
|
||||
{
|
||||
"inputs": ToolInputs(parameters=[]),
|
||||
"output": ToolOutput(
|
||||
value_schema=ValueSchema(val_type="string", enum=["value1", "value2"]),
|
||||
available_modes=["value", "error"],
|
||||
description="No description provided.",
|
||||
),
|
||||
},
|
||||
id="func_with_known_list_return",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_enum_return,
|
||||
{
|
||||
"inputs": ToolInputs(parameters=[]),
|
||||
"output": ToolOutput(
|
||||
value_schema=ValueSchema(val_type="string", enum=["foo bar", "baz"]),
|
||||
available_modes=["value", "error"],
|
||||
description="No description provided.",
|
||||
),
|
||||
},
|
||||
id="func_with_enum_return",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_annotated_return,
|
||||
|
|
@ -592,5 +609,5 @@ def test_create_tool_def(func_under_test, expected_tool_def_fields):
|
|||
|
||||
|
||||
def tool_version_is_set_correctly():
|
||||
tool_def = ToolCatalog.create_tool_definition(func_with_no_return, "abcd1236")
|
||||
tool_def = ToolCatalog.create_tool_definition(func_with_description, "abcd1236")
|
||||
assert tool_def.version == "abcd1236"
|
||||
|
|
|
|||
79
arcade/tests/tool/test_create_tool_definition_new.py
Normal file
79
arcade/tests/tool/test_create_tool_definition_new.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
|
||||
from arcade.core.catalog import ToolCatalog, get_wire_type
|
||||
from arcade.sdk import tool
|
||||
|
||||
|
||||
class Case:
|
||||
def __init__(self, input_type: type, output_type: type | None):
|
||||
self.input_type = input_type
|
||||
self.output_type = output_type
|
||||
|
||||
def __str__(self):
|
||||
return f"Case(input_type={self.input_type}, output_type={self.output_type})"
|
||||
|
||||
|
||||
primitives = [bool, float, int, str]
|
||||
|
||||
test_cases = [
|
||||
Case(input_type=input_type, output_type=output_type)
|
||||
for input_type in [*primitives, []]
|
||||
for output_type in [*primitives, None]
|
||||
] + [
|
||||
Case(input_type=[primitives[i] for i in range(n)], output_type=output_type)
|
||||
for n in range(2, len(primitives) + 1)
|
||||
for output_type in [*primitives, None]
|
||||
]
|
||||
|
||||
|
||||
# Generate tool functions dynamically
|
||||
def generate_tool_function(input_types: list[type], output_type: type | None):
|
||||
input_annotation = ", ".join(
|
||||
[
|
||||
f"param{i}: Annotated[{input_type.__name__}, 'Param {i + 1}']"
|
||||
for i, input_type in enumerate(input_types)
|
||||
]
|
||||
)
|
||||
output_annotation = f" -> {output_type.__name__}" if output_type else ""
|
||||
|
||||
func_code = f"""
|
||||
@tool(desc="Generated function with input and output types")
|
||||
def generated_func({input_annotation}){output_annotation}:
|
||||
pass
|
||||
"""
|
||||
local_vars = {}
|
||||
exec(func_code, {"tool": tool, "Annotated": Annotated}, local_vars) # noqa: S102
|
||||
generated_func = local_vars.get("generated_func")
|
||||
generated_func.__source__ = func_code # Attach the source code to the function
|
||||
return generated_func
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_case", test_cases, ids=[str(tc) for tc in test_cases])
|
||||
def test_create_tool_def2(test_case):
|
||||
input_types = (
|
||||
test_case.input_type if isinstance(test_case.input_type, list) else [test_case.input_type]
|
||||
)
|
||||
output_type = test_case.output_type
|
||||
|
||||
# Generate the function dynamically
|
||||
generated_func = generate_tool_function(input_types, output_type)
|
||||
|
||||
assert generated_func is not None, "generated_func was not created"
|
||||
|
||||
# Create tool definition using the generated function
|
||||
tool_def = ToolCatalog.create_tool_definition(generated_func, "1.0")
|
||||
|
||||
for i, input_type in enumerate(input_types):
|
||||
param = tool_def.inputs.parameters[i]
|
||||
assert (
|
||||
param.value_schema.val_type == get_wire_type(input_type)
|
||||
), f"Parameter {param.name} has value type {param.value_schema.val_type} but {input_type} was expected at index {i}"
|
||||
|
||||
if output_type:
|
||||
assert tool_def.output.value_schema.val_type == get_wire_type(
|
||||
output_type
|
||||
), f"Output has value type {tool_def.output.val_type} but {output_type} was expected"
|
||||
else:
|
||||
assert tool_def.output.value_schema is None, "Output is not None"
|
||||
|
|
@ -258,12 +258,14 @@ def read_products(
|
|||
name="cols",
|
||||
description="The columns to return",
|
||||
required=False,
|
||||
value_schema=ValueSchema(val_type="json", enum=None),
|
||||
value_schema=ValueSchema(
|
||||
val_type="array", inner_val_type="string", enum=None
|
||||
),
|
||||
),
|
||||
]
|
||||
),
|
||||
"output": ToolOutput(
|
||||
value_schema=ValueSchema(val_type="json", enum=None),
|
||||
value_schema=ValueSchema(val_type="array", inner_val_type="json", enum=None),
|
||||
available_modes=["value", "error"],
|
||||
description="Data with the selected columns",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -4,14 +4,18 @@
|
|||
"primitives": {
|
||||
// All supported primitive data types
|
||||
"type": "string",
|
||||
"enum": ["string", "integer", "float", "boolean", "json"]
|
||||
"enum": ["string", "integer", "number", "boolean", "json"]
|
||||
},
|
||||
"value_schema": {
|
||||
// Represents the schema of a value (e.g. function input parameter value)
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"val_type": {
|
||||
"$ref": "#/$defs/primitives"
|
||||
"oneOf": [{ "$ref": "#/$defs/primitives" }, { "type": "string", "enum": ["array"] }]
|
||||
},
|
||||
"inner_val_type": {
|
||||
"$ref": "#/$defs/primitives",
|
||||
"description": "If the value type is a list, the type of the list values."
|
||||
},
|
||||
"enum": {
|
||||
"oneOf": [
|
||||
|
|
@ -26,7 +30,13 @@
|
|||
}
|
||||
},
|
||||
"required": ["val_type"],
|
||||
"additionalProperties": false
|
||||
"additionalProperties": false,
|
||||
"if": {
|
||||
"properties": { "val_type": { "const": "array" } }
|
||||
},
|
||||
"then": {
|
||||
"required": ["inner_val_type"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
|
|
|
|||
|
|
@ -52,3 +52,13 @@ def sqrt(
|
|||
Get the square root of a number
|
||||
"""
|
||||
return math.sqrt(a)
|
||||
|
||||
|
||||
@tool
|
||||
def sum_list(
|
||||
numbers: Annotated[list[float], "The list of numbers"],
|
||||
) -> Annotated[float, "The sum of the numbers in the list"]:
|
||||
"""
|
||||
Sum all numbers in a list
|
||||
"""
|
||||
return sum(numbers)
|
||||
|
|
|
|||
Loading…
Reference in a new issue