Improve Typedict and Basemodel support (#523)
Improve Pydantic and Typedict support and add a bunch of tets.
1. Fixed the test failure where TypeDict was being serialized as a list
of tuples instead of a dict by:
- Adding proper handling for BaseModel instances in the output.py file
- Converting BaseModel results (from TypeDict conversion) to dicts using
model_dump()
- Handling lists containing BaseModel objects
2. Fixed None handling to ensure None results are converted to empty
strings as expected
3. Updated the schema.py to allow dict and list types in
ToolCallOutput.value
4. new tests
- TypeDict output execution tests
- Output factory tests
- Executor tests with TypeDict support
- Schema validation tests
The key changes were:
- In ``arcade_core/output.py``: Added BaseModel conversion logic in the
success method
- In ``arcade_core/schema.py``: Changed ToolCallOutput.value type from
list[str] to list to support complex types
TODO
- [ ] Confirm engine compatibility without changes made to engine
---------
Co-authored-by: Eric Gustin <eric@arcade.dev>
This commit is contained in:
parent
e0ddc0ce90
commit
e188fc6ae9
8 changed files with 663 additions and 31 deletions
|
|
@ -59,19 +59,36 @@ def display_tool_details(tool: ToolDefinition, worker: bool = False) -> None: #
|
|||
inputs_table.add_column("Type", style="magenta")
|
||||
inputs_table.add_column("Required", style="yellow")
|
||||
inputs_table.add_column("Description", style="white")
|
||||
inputs_table.add_column("Default", style="blue")
|
||||
|
||||
for param in inputs:
|
||||
# Since InputParameter does not have a default field, we use "N/A"
|
||||
default_value = "N/A"
|
||||
if param.value_schema.enum:
|
||||
default_value = f"One of {param.value_schema.enum}"
|
||||
# Format the type string properly
|
||||
type_str = _format_type_string(param.value_schema)
|
||||
|
||||
# Add the main parameter row
|
||||
inputs_table.add_row(
|
||||
param.name,
|
||||
param.value_schema.val_type,
|
||||
type_str,
|
||||
str(param.required),
|
||||
param.description or "",
|
||||
default_value,
|
||||
)
|
||||
|
||||
# If this is a json type with properties, show them
|
||||
if (
|
||||
param.value_schema.val_type == "json"
|
||||
and hasattr(param.value_schema, "properties")
|
||||
and param.value_schema.properties
|
||||
):
|
||||
_add_nested_properties(inputs_table, param.value_schema.properties, indent=1)
|
||||
# Handle arrays with inner properties
|
||||
elif (
|
||||
param.value_schema.val_type == "array"
|
||||
and hasattr(param.value_schema, "inner_properties")
|
||||
and param.value_schema.inner_properties
|
||||
):
|
||||
_add_nested_properties(
|
||||
inputs_table, param.value_schema.inner_properties, indent=1, is_array_item=True
|
||||
)
|
||||
|
||||
inputs_panel = Panel(
|
||||
inputs_table,
|
||||
title="Input Parameters",
|
||||
|
|
@ -241,7 +258,7 @@ def _add_nested_properties(
|
|||
is_array_item: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Recursively add nested properties to the output table.
|
||||
Recursively add nested properties to the table.
|
||||
|
||||
Args:
|
||||
table: The Rich table to add rows to
|
||||
|
|
@ -253,11 +270,14 @@ def _add_nested_properties(
|
|||
|
||||
# Show array item indicator if needed
|
||||
if is_array_item and indent > 0:
|
||||
table.add_row(
|
||||
f"{indent_prefix[:-2]}[item]",
|
||||
"",
|
||||
"[dim]Each item in array:[/dim]",
|
||||
)
|
||||
# Get column count from the table
|
||||
num_columns = len(table.columns)
|
||||
|
||||
# Create a row with the array indicator in the first column and empty strings for the rest
|
||||
row_data = [f"{indent_prefix[:-2]}[item]"] + [""] * (num_columns - 1)
|
||||
if num_columns >= 3:
|
||||
row_data[2] = "[dim]Each item in array:[/dim]"
|
||||
table.add_row(*row_data)
|
||||
|
||||
for prop_name, prop_schema in properties.items():
|
||||
# Format the type string
|
||||
|
|
@ -269,11 +289,19 @@ def _add_nested_properties(
|
|||
if hasattr(prop_schema, "description") and prop_schema.description:
|
||||
description = prop_schema.description
|
||||
|
||||
table.add_row(
|
||||
f"{indent_prefix}{prop_name}",
|
||||
type_str,
|
||||
f"[dim]{description}[/dim]" if description else "",
|
||||
)
|
||||
# Create row data based on number of columns
|
||||
num_columns = len(table.columns)
|
||||
row_data = [f"{indent_prefix}{prop_name}", type_str]
|
||||
|
||||
# For input parameter tables (4 columns), add empty required column
|
||||
if num_columns == 4:
|
||||
row_data.append("") # Empty "Required" column for nested properties
|
||||
row_data.append(f"[dim]{description}[/dim]" if description else "")
|
||||
# For output tables (3 columns), just add description
|
||||
elif num_columns == 3:
|
||||
row_data.append(f"[dim]{description}[/dim]" if description else "")
|
||||
|
||||
table.add_row(*row_data)
|
||||
|
||||
# Recursively add nested properties if this is a json type with properties
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -483,7 +483,11 @@ def create_output_definition(func: Callable) -> ToolOutput:
|
|||
)
|
||||
|
||||
if hasattr(return_type, "__metadata__"):
|
||||
description = return_type.__metadata__[0] if return_type.__metadata__ else None # type: ignore[assignment]
|
||||
description = (
|
||||
return_type.__metadata__[0]
|
||||
if return_type.__metadata__
|
||||
else "No description provided for return type."
|
||||
)
|
||||
return_type = return_type.__origin__
|
||||
|
||||
# Unwrap Optional types
|
||||
|
|
@ -792,6 +796,7 @@ def extract_properties(type_to_check: type) -> dict[str, WireTypeInfo] | None:
|
|||
field_type = next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||
|
||||
# Get wire type info recursively
|
||||
# field_type cannot be None here due to the check above
|
||||
wire_info = get_wire_type_info(field_type)
|
||||
properties[field_name] = wire_info
|
||||
|
||||
|
|
@ -973,7 +978,9 @@ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel]
|
|||
tool_field_info = extract_field_info(param)
|
||||
param_fields = {
|
||||
"default": tool_field_info.default,
|
||||
"description": tool_field_info.description,
|
||||
"description": tool_field_info.description
|
||||
if tool_field_info.description
|
||||
else "No description provided.",
|
||||
# TODO more here?
|
||||
}
|
||||
input_fields[name] = (tool_field_info.field_type, Field(**param_fields))
|
||||
|
|
@ -991,8 +998,14 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901
|
|||
"""
|
||||
return_annotation = inspect.signature(func).return_annotation
|
||||
output_model_name = f"{snake_to_pascal_case(func.__name__)}Output"
|
||||
|
||||
# If the return annotation is empty, create a model with no fields
|
||||
if return_annotation is inspect.Signature.empty:
|
||||
return create_model(output_model_name)
|
||||
|
||||
# If the return annotation has an __origin__ attribute
|
||||
# and does not have a __metadata__ attribute.
|
||||
# This is the case for TypedDicts.
|
||||
elif hasattr(return_annotation, "__origin__"):
|
||||
if hasattr(return_annotation, "__metadata__"):
|
||||
field_type = return_annotation.__args__[0]
|
||||
|
|
@ -1008,15 +1021,24 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901
|
|||
)
|
||||
return create_model(
|
||||
output_model_name,
|
||||
result=(typeddict_model, Field(description=str(description))),
|
||||
result=(
|
||||
typeddict_model,
|
||||
Field(
|
||||
description=str(description)
|
||||
if description
|
||||
else "No description provided."
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# If the return annotation has a description, use it
|
||||
if description:
|
||||
return create_model(
|
||||
output_model_name,
|
||||
result=(field_type, Field(description=str(description))),
|
||||
)
|
||||
# Handle Union types
|
||||
|
||||
# If the return annotation is a Union type
|
||||
origin = return_annotation.__origin__
|
||||
if origin is typing.Union:
|
||||
# For union types, create a model with the first non-None argument
|
||||
|
|
@ -1037,10 +1059,15 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901
|
|||
)
|
||||
return create_model(
|
||||
output_model_name,
|
||||
result=(arg, Field(description="No description provided.")),
|
||||
result=(
|
||||
arg,
|
||||
Field(description="No description provided."),
|
||||
),
|
||||
)
|
||||
# when the return_annotation has an __origin__ attribute
|
||||
|
||||
# If the return annotation has an __origin__ attribute
|
||||
# and does not have a __metadata__ attribute.
|
||||
# This is the case for TypedDicts.
|
||||
return create_model(
|
||||
output_model_name,
|
||||
result=(
|
||||
|
|
@ -1049,7 +1076,7 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901
|
|||
),
|
||||
)
|
||||
else:
|
||||
# Check if return type is TypedDict
|
||||
# If the return annotation is a TypedDict
|
||||
if is_typeddict(return_annotation):
|
||||
typeddict_model = create_model_from_typeddict(return_annotation, output_model_name)
|
||||
return create_model(
|
||||
|
|
@ -1060,10 +1087,13 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901
|
|||
),
|
||||
)
|
||||
|
||||
# Handle simple return types (like str)
|
||||
# If the return annotation is a simple type (like str)
|
||||
return create_model(
|
||||
output_model_name,
|
||||
result=(return_annotation, Field(description="No description provided.")),
|
||||
result=(
|
||||
return_annotation,
|
||||
Field(description="No description provided."),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,14 +25,27 @@ class ToolOutputFactory:
|
|||
|
||||
The executor guarantees that `data` is either a string, a dict, or None.
|
||||
"""
|
||||
value: str | int | float | bool | dict | list[str] | None
|
||||
value: str | int | float | bool | dict | list | None
|
||||
if data is None:
|
||||
value = ""
|
||||
elif hasattr(data, "result"):
|
||||
value = getattr(data, "result", "")
|
||||
result = getattr(data, "result", "")
|
||||
# Handle None result the same way as None data
|
||||
if result is None:
|
||||
value = ""
|
||||
# If the result is a BaseModel (e.g., from TypedDict conversion), convert to dict
|
||||
elif isinstance(result, BaseModel):
|
||||
value = result.model_dump()
|
||||
# If the result is a list, check if it contains BaseModel objects
|
||||
elif isinstance(result, list):
|
||||
value = [
|
||||
item.model_dump() if isinstance(item, BaseModel) else item for item in result
|
||||
]
|
||||
else:
|
||||
value = result
|
||||
elif isinstance(data, BaseModel):
|
||||
value = data.model_dump()
|
||||
elif isinstance(data, (str, int, float, bool, list)):
|
||||
elif isinstance(data, (str, int, float, bool, list, dict)):
|
||||
value = data
|
||||
else:
|
||||
raise ValueError(f"Unsupported data output type: {type(data)}")
|
||||
|
|
|
|||
|
|
@ -418,7 +418,7 @@ class ToolCallRequiresAuthorization(BaseModel):
|
|||
class ToolCallOutput(BaseModel):
|
||||
"""The output of a tool invocation."""
|
||||
|
||||
value: str | int | float | bool | dict | list[str] | None = None
|
||||
value: str | int | float | bool | dict | list | None = None
|
||||
"""The value returned by the tool."""
|
||||
logs: list[ToolCallLog] | None = None
|
||||
"""The logs that occurred during the tool invocation."""
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from arcade_core.executor import ToolExecutor
|
|||
from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput, ToolContext
|
||||
from arcade_tdk import tool
|
||||
from arcade_tdk.errors import RetryableToolError, ToolExecutionError
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -45,6 +46,36 @@ def bad_output_error_tool() -> Annotated[str, "output"]:
|
|||
return {"output": "test"}
|
||||
|
||||
|
||||
# TypedDict output tools
|
||||
class ResultDict(TypedDict):
|
||||
"""Result dictionary."""
|
||||
|
||||
status: str
|
||||
count: int
|
||||
items: list[str]
|
||||
|
||||
|
||||
@tool
|
||||
def typeddict_output_tool() -> Annotated[ResultDict, "Returns a TypedDict"]:
|
||||
"""Tool that returns a TypedDict."""
|
||||
return ResultDict(status="success", count=3, items=["a", "b", "c"])
|
||||
|
||||
|
||||
@tool
|
||||
def list_typeddict_output_tool() -> Annotated[list[ResultDict], "Returns list of TypedDict"]:
|
||||
"""Tool that returns a list of TypedDict."""
|
||||
return [
|
||||
ResultDict(status="first", count=1, items=["x"]),
|
||||
ResultDict(status="second", count=2, items=["y", "z"]),
|
||||
]
|
||||
|
||||
|
||||
@tool
|
||||
def dict_output_tool() -> Annotated[dict, "Returns a plain dict"]:
|
||||
"""Tool that returns a plain dict."""
|
||||
return {"key": "value", "number": 42, "nested": {"inner": "data"}}
|
||||
|
||||
|
||||
# ---- Test Driver ----
|
||||
|
||||
catalog = ToolCatalog()
|
||||
|
|
@ -54,6 +85,9 @@ catalog.add_tool(retryable_error_tool, "simple_toolkit")
|
|||
catalog.add_tool(exec_error_tool, "simple_toolkit")
|
||||
catalog.add_tool(unexpected_error_tool, "simple_toolkit")
|
||||
catalog.add_tool(bad_output_error_tool, "simple_toolkit")
|
||||
catalog.add_tool(typeddict_output_tool, "simple_toolkit")
|
||||
catalog.add_tool(list_typeddict_output_tool, "simple_toolkit")
|
||||
catalog.add_tool(dict_output_tool, "simple_toolkit")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -128,6 +162,26 @@ catalog.add_tool(bad_output_error_tool, "simple_toolkit")
|
|||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
typeddict_output_tool,
|
||||
{},
|
||||
ToolCallOutput(value={"status": "success", "count": 3, "items": ["a", "b", "c"]}),
|
||||
),
|
||||
(
|
||||
list_typeddict_output_tool,
|
||||
{},
|
||||
ToolCallOutput(
|
||||
value=[
|
||||
{"status": "first", "count": 1, "items": ["x"]},
|
||||
{"status": "second", "count": 2, "items": ["y", "z"]},
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
dict_output_tool,
|
||||
{},
|
||||
ToolCallOutput(value={"key": "value", "number": 42, "nested": {"inner": "data"}}),
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"simple_tool",
|
||||
|
|
@ -137,6 +191,9 @@ catalog.add_tool(bad_output_error_tool, "simple_toolkit")
|
|||
"unexpected_error_tool",
|
||||
"invalid_input_type",
|
||||
"bad_output_type",
|
||||
"typeddict_output",
|
||||
"list_typeddict_output",
|
||||
"dict_output",
|
||||
],
|
||||
)
|
||||
async def test_tool_executor(tool_func, inputs, expected_output):
|
||||
|
|
|
|||
|
|
@ -35,6 +35,64 @@ def test_success(output_factory, data, expected_value):
|
|||
assert output.error is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data, expected_value",
|
||||
[
|
||||
# Dict types (simulating TypedDict at runtime)
|
||||
({"name": "test", "value": 123}, {"name": "test", "value": 123}),
|
||||
({}, {}),
|
||||
({"nested": {"key": "value"}}, {"nested": {"key": "value"}}),
|
||||
# List types
|
||||
(["a", "b", "c"], ["a", "b", "c"]),
|
||||
([1, 2, 3], [1, 2, 3]),
|
||||
([], []),
|
||||
# List of dicts (simulating list[TypedDict])
|
||||
(
|
||||
[{"id": 1, "name": "a"}, {"id": 2, "name": "b"}],
|
||||
[{"id": 1, "name": "a"}, {"id": 2, "name": "b"}],
|
||||
),
|
||||
([{}], [{}]),
|
||||
# Mixed lists
|
||||
([1, "two", 3.0, True], [1, "two", 3.0, True]),
|
||||
],
|
||||
)
|
||||
def test_success_complex_types(output_factory, data, expected_value):
|
||||
"""Test that dict and list types are properly handled by ToolOutputFactory."""
|
||||
data_obj = SampleOutputModel(result=data)
|
||||
output = output_factory.success(data=data_obj)
|
||||
assert output.value == expected_value
|
||||
assert output.error is None
|
||||
|
||||
|
||||
def test_success_with_basemodel_direct(output_factory):
|
||||
"""Test that BaseModel instances are converted to dict via model_dump()."""
|
||||
|
||||
class TestModel(BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
model = TestModel(name="test", value=42)
|
||||
output = output_factory.success(data=model)
|
||||
assert output.value == {"name": "test", "value": 42}
|
||||
assert output.error is None
|
||||
|
||||
|
||||
def test_success_raw_dict(output_factory):
|
||||
"""Test that raw dict values (not wrapped in model) are handled correctly."""
|
||||
raw_dict = {"key": "value", "number": 123}
|
||||
output = output_factory.success(data=raw_dict)
|
||||
assert output.value == raw_dict
|
||||
assert output.error is None
|
||||
|
||||
|
||||
def test_success_raw_list(output_factory):
|
||||
"""Test that raw list values (not wrapped in model) are handled correctly."""
|
||||
raw_list = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
output = output_factory.success(data=raw_list)
|
||||
assert output.value == raw_list
|
||||
assert output.error is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message, developer_message",
|
||||
[
|
||||
|
|
|
|||
196
libs/tests/core/test_schema_validation.py
Normal file
196
libs/tests/core/test_schema_validation.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
"""
|
||||
Tests for ToolCallOutput schema validation with complex types.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
class TestToolCallOutputValidation:
|
||||
"""Test ToolCallOutput validation with various data types."""
|
||||
|
||||
def test_basic_types(self):
|
||||
"""Test that basic types are validated correctly."""
|
||||
# String
|
||||
output = ToolCallOutput(value="test string")
|
||||
assert output.value == "test string"
|
||||
|
||||
# Integer
|
||||
output = ToolCallOutput(value=42)
|
||||
assert output.value == 42
|
||||
|
||||
# Float
|
||||
output = ToolCallOutput(value=3.14)
|
||||
assert output.value == 3.14
|
||||
|
||||
# Boolean
|
||||
output = ToolCallOutput(value=True)
|
||||
assert output.value is True
|
||||
|
||||
# None
|
||||
output = ToolCallOutput(value=None)
|
||||
assert output.value is None
|
||||
|
||||
def test_dict_types(self):
|
||||
"""Test that dict types are validated correctly."""
|
||||
# Simple dict
|
||||
output = ToolCallOutput(value={"key": "value"})
|
||||
assert output.value == {"key": "value"}
|
||||
|
||||
# Nested dict
|
||||
output = ToolCallOutput(value={"outer": {"inner": "value"}})
|
||||
assert output.value == {"outer": {"inner": "value"}}
|
||||
|
||||
# Empty dict
|
||||
output = ToolCallOutput(value={})
|
||||
assert output.value == {}
|
||||
|
||||
# Dict with mixed types
|
||||
output = ToolCallOutput(
|
||||
value={
|
||||
"string": "text",
|
||||
"number": 123,
|
||||
"float": 45.6,
|
||||
"bool": True,
|
||||
"null": None,
|
||||
"list": [1, 2, 3],
|
||||
"dict": {"nested": "value"},
|
||||
}
|
||||
)
|
||||
assert output.value["string"] == "text"
|
||||
assert output.value["number"] == 123
|
||||
assert output.value["list"] == [1, 2, 3]
|
||||
|
||||
def test_list_types(self):
|
||||
"""Test that list types are validated correctly."""
|
||||
# List of strings (original type)
|
||||
output = ToolCallOutput(value=["a", "b", "c"])
|
||||
assert output.value == ["a", "b", "c"]
|
||||
|
||||
# List of integers
|
||||
output = ToolCallOutput(value=[1, 2, 3])
|
||||
assert output.value == [1, 2, 3]
|
||||
|
||||
# List of dicts (TypedDict at runtime)
|
||||
output = ToolCallOutput(value=[{"id": 1, "name": "first"}, {"id": 2, "name": "second"}])
|
||||
assert output.value == [{"id": 1, "name": "first"}, {"id": 2, "name": "second"}]
|
||||
|
||||
# Mixed type list
|
||||
output = ToolCallOutput(value=[1, "two", 3.0, True, None, {"key": "value"}])
|
||||
assert len(output.value) == 6
|
||||
assert output.value[5] == {"key": "value"}
|
||||
|
||||
# Empty list
|
||||
output = ToolCallOutput(value=[])
|
||||
assert output.value == []
|
||||
|
||||
# Nested lists
|
||||
output = ToolCallOutput(value=[[1, 2], [3, 4], [5, 6]])
|
||||
assert output.value == [[1, 2], [3, 4], [5, 6]]
|
||||
|
||||
def test_complex_nested_structures(self):
|
||||
"""Test complex nested structures that might come from TypedDict."""
|
||||
# Simulate a complex API response structure
|
||||
complex_data = {
|
||||
"status": "success",
|
||||
"data": {
|
||||
"users": [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Alice",
|
||||
"roles": ["admin", "user"],
|
||||
"metadata": {"last_login": "2024-01-01", "active": True},
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Bob",
|
||||
"roles": ["user"],
|
||||
"metadata": {"last_login": "2024-01-02", "active": False},
|
||||
},
|
||||
],
|
||||
"total": 2,
|
||||
"page_info": {"page": 1, "per_page": 10, "has_next": False},
|
||||
},
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
output = ToolCallOutput(value=complex_data)
|
||||
assert output.value == complex_data
|
||||
assert output.value["data"]["users"][0]["name"] == "Alice"
|
||||
assert output.value["data"]["page_info"]["has_next"] is False
|
||||
|
||||
def test_error_and_logs_with_value(self):
|
||||
"""Test that error and logs can coexist with different value types."""
|
||||
# With dict value and logs
|
||||
output = ToolCallOutput(
|
||||
value={"result": "success"},
|
||||
logs=[
|
||||
ToolCallLog(message="Processing started", level="info"),
|
||||
ToolCallLog(message="Deprecation warning", level="warning", subtype="deprecation"),
|
||||
],
|
||||
)
|
||||
assert output.value == {"result": "success"}
|
||||
assert len(output.logs) == 2
|
||||
|
||||
# With list value and error
|
||||
output = ToolCallOutput(
|
||||
error=ToolCallError(
|
||||
message="Partial failure",
|
||||
developer_message="Some items failed to process",
|
||||
can_retry=True,
|
||||
)
|
||||
)
|
||||
assert output.error.message == "Partial failure"
|
||||
assert output.value is None
|
||||
|
||||
def test_unsupported_types_still_fail(self):
|
||||
"""Test that truly unsupported types still fail validation."""
|
||||
|
||||
# Custom object (not dict, list, or basic type)
|
||||
class CustomClass:
|
||||
def __init__(self):
|
||||
self.data = "test"
|
||||
|
||||
# This should fail because CustomClass instance is not a supported type
|
||||
# Note: This test is about Pydantic validation, not the output factory
|
||||
# The output factory would catch this earlier
|
||||
with pytest.raises(ValidationError):
|
||||
# Directly creating with an unsupported type should fail
|
||||
ToolCallOutput(value=CustomClass())
|
||||
|
||||
def test_very_large_structures(self):
|
||||
"""Test that large structures are handled properly."""
|
||||
# Large list of dicts
|
||||
large_list = [{"id": i, "value": f"item_{i}"} for i in range(1000)]
|
||||
output = ToolCallOutput(value=large_list)
|
||||
assert len(output.value) == 1000
|
||||
assert output.value[500]["id"] == 500
|
||||
|
||||
# Deeply nested structure
|
||||
deep_dict = {"level1": {"level2": {"level3": {"level4": {"level5": "deep_value"}}}}}
|
||||
output = ToolCallOutput(value=deep_dict)
|
||||
assert output.value["level1"]["level2"]["level3"]["level4"]["level5"] == "deep_value"
|
||||
|
||||
def test_json_serializable(self):
|
||||
"""Test that all supported types are JSON serializable."""
|
||||
import json
|
||||
|
||||
test_cases = [
|
||||
{"type": "string"},
|
||||
["list", "of", "strings"],
|
||||
[{"id": 1}, {"id": 2}],
|
||||
{"nested": {"data": [1, 2, 3]}},
|
||||
123,
|
||||
45.6,
|
||||
True,
|
||||
None,
|
||||
]
|
||||
|
||||
for test_value in test_cases:
|
||||
output = ToolCallOutput(value=test_value)
|
||||
# This should not raise an exception
|
||||
json_str = json.dumps(output.model_dump())
|
||||
# And we should be able to parse it back
|
||||
parsed = json.loads(json_str)
|
||||
assert parsed["value"] == test_value
|
||||
250
libs/tests/core/test_typeddict_output_execution.py
Normal file
250
libs/tests/core/test_typeddict_output_execution.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
"""
|
||||
End-to-end tests for TypedDict output execution through the entire tool pipeline.
|
||||
"""
|
||||
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import pytest
|
||||
from arcade_core.catalog import ToolCatalog, create_func_models
|
||||
from arcade_core.executor import ToolExecutor
|
||||
from arcade_core.schema import ToolContext
|
||||
from arcade_tdk import tool
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
# Define various TypedDict structures for testing
|
||||
class SimpleDict(TypedDict):
|
||||
"""A simple typed dictionary."""
|
||||
|
||||
name: str
|
||||
value: int
|
||||
|
||||
|
||||
class NestedDict(TypedDict):
|
||||
"""A nested typed dictionary."""
|
||||
|
||||
id: int
|
||||
info: SimpleDict
|
||||
tags: list[str]
|
||||
|
||||
|
||||
class OptionalFieldsDict(TypedDict, total=False):
|
||||
"""TypedDict with optional fields."""
|
||||
|
||||
required_field: str
|
||||
optional_field: int
|
||||
|
||||
|
||||
# Define test tools
|
||||
@tool
|
||||
def returns_typeddict() -> Annotated[SimpleDict, "Returns a simple TypedDict"]:
|
||||
"""Tool that returns a TypedDict."""
|
||||
return SimpleDict(name="test", value=42)
|
||||
|
||||
|
||||
@tool
|
||||
def returns_list_of_typeddict() -> Annotated[list[SimpleDict], "Returns list of TypedDict"]:
|
||||
"""Tool that returns a list of TypedDict."""
|
||||
return [
|
||||
SimpleDict(name="item1", value=1),
|
||||
SimpleDict(name="item2", value=2),
|
||||
SimpleDict(name="item3", value=3),
|
||||
]
|
||||
|
||||
|
||||
@tool
|
||||
def returns_optional_typeddict(
|
||||
return_none: Annotated[bool, "Whether to return None"] = False,
|
||||
) -> Annotated[Optional[SimpleDict], "Returns optional TypedDict"]:
|
||||
"""Tool that returns an optional TypedDict."""
|
||||
if return_none:
|
||||
return None
|
||||
return SimpleDict(name="optional", value=100)
|
||||
|
||||
|
||||
@tool
|
||||
def returns_nested_typeddict() -> Annotated[NestedDict, "Returns nested TypedDict"]:
|
||||
"""Tool that returns a nested TypedDict."""
|
||||
return NestedDict(id=1, info=SimpleDict(name="nested", value=99), tags=["tag1", "tag2", "tag3"])
|
||||
|
||||
|
||||
@tool
|
||||
def accepts_and_returns_typeddict(
|
||||
data: Annotated[SimpleDict, "Input TypedDict"],
|
||||
) -> Annotated[SimpleDict, "Modified TypedDict"]:
|
||||
"""Tool that accepts and returns a TypedDict."""
|
||||
return SimpleDict(name=f"Modified: {data['name']}", value=data["value"] * 2)
|
||||
|
||||
|
||||
@tool
|
||||
def returns_dict_list() -> Annotated[list[dict], "Returns list of dicts"]:
|
||||
"""Tool that returns a list of dictionaries including TypedDicts."""
|
||||
return [
|
||||
{"type": "plain", "value": 42},
|
||||
{"name": "string", "data": "test"},
|
||||
SimpleDict(name="typed", value=99),
|
||||
]
|
||||
|
||||
|
||||
class TestTypeDictOutputExecution:
|
||||
"""Test TypedDict outputs through the full execution pipeline."""
|
||||
|
||||
@pytest.fixture
|
||||
def catalog(self):
|
||||
return ToolCatalog()
|
||||
|
||||
@pytest.fixture
|
||||
def context(self):
|
||||
return ToolContext()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_typeddict(self, catalog, context):
|
||||
"""Test executing a tool that returns a TypedDict."""
|
||||
# Create tool definition
|
||||
definition = catalog.create_tool_definition(
|
||||
returns_typeddict, toolkit_name="test", toolkit_version="1.0.0"
|
||||
)
|
||||
|
||||
# Create models
|
||||
input_model, output_model = create_func_models(returns_typeddict)
|
||||
|
||||
# Execute tool
|
||||
result = await ToolExecutor.run(
|
||||
func=returns_typeddict,
|
||||
definition=definition,
|
||||
input_model=input_model,
|
||||
output_model=output_model,
|
||||
context=context,
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result.error is None
|
||||
assert result.value == {"name": "test", "value": 42}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_list_of_typeddict(self, catalog, context):
|
||||
"""Test executing a tool that returns a list of TypedDict."""
|
||||
definition = catalog.create_tool_definition(
|
||||
returns_list_of_typeddict, toolkit_name="test", toolkit_version="1.0.0"
|
||||
)
|
||||
|
||||
input_model, output_model = create_func_models(returns_list_of_typeddict)
|
||||
|
||||
result = await ToolExecutor.run(
|
||||
func=returns_list_of_typeddict,
|
||||
definition=definition,
|
||||
input_model=input_model,
|
||||
output_model=output_model,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.error is None
|
||||
assert result.value == [
|
||||
{"name": "item1", "value": 1},
|
||||
{"name": "item2", "value": 2},
|
||||
{"name": "item3", "value": 3},
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_optional_typeddict(self, catalog, context):
|
||||
"""Test executing a tool that returns an optional TypedDict."""
|
||||
definition = catalog.create_tool_definition(
|
||||
returns_optional_typeddict, toolkit_name="test", toolkit_version="1.0.0"
|
||||
)
|
||||
|
||||
input_model, output_model = create_func_models(returns_optional_typeddict)
|
||||
|
||||
# Test returning a value
|
||||
result = await ToolExecutor.run(
|
||||
func=returns_optional_typeddict,
|
||||
definition=definition,
|
||||
input_model=input_model,
|
||||
output_model=output_model,
|
||||
context=context,
|
||||
return_none=False,
|
||||
)
|
||||
|
||||
assert result.error is None
|
||||
assert result.value == {"name": "optional", "value": 100}
|
||||
|
||||
# Test returning None
|
||||
result_none = await ToolExecutor.run(
|
||||
func=returns_optional_typeddict,
|
||||
definition=definition,
|
||||
input_model=input_model,
|
||||
output_model=output_model,
|
||||
context=context,
|
||||
return_none=True,
|
||||
)
|
||||
|
||||
assert result_none.error is None
|
||||
assert result_none.value == "" # None is converted to empty string
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_nested_typeddict(self, catalog, context):
|
||||
"""Test executing a tool that returns a nested TypedDict."""
|
||||
definition = catalog.create_tool_definition(
|
||||
returns_nested_typeddict, toolkit_name="test", toolkit_version="1.0.0"
|
||||
)
|
||||
|
||||
input_model, output_model = create_func_models(returns_nested_typeddict)
|
||||
|
||||
result = await ToolExecutor.run(
|
||||
func=returns_nested_typeddict,
|
||||
definition=definition,
|
||||
input_model=input_model,
|
||||
output_model=output_model,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.error is None
|
||||
assert result.value == {
|
||||
"id": 1,
|
||||
"info": {"name": "nested", "value": 99},
|
||||
"tags": ["tag1", "tag2", "tag3"],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accepts_and_returns_typeddict(self, catalog, context):
|
||||
"""Test executing a tool that accepts and returns TypedDict."""
|
||||
definition = catalog.create_tool_definition(
|
||||
accepts_and_returns_typeddict, toolkit_name="test", toolkit_version="1.0.0"
|
||||
)
|
||||
|
||||
input_model, output_model = create_func_models(accepts_and_returns_typeddict)
|
||||
|
||||
result = await ToolExecutor.run(
|
||||
func=accepts_and_returns_typeddict,
|
||||
definition=definition,
|
||||
input_model=input_model,
|
||||
output_model=output_model,
|
||||
context=context,
|
||||
data={"name": "input", "value": 21},
|
||||
)
|
||||
|
||||
assert result.error is None
|
||||
assert result.value == {"name": "Modified: input", "value": 42}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_dict_list(self, catalog, context):
|
||||
"""Test executing a tool that returns a list of dicts."""
|
||||
definition = catalog.create_tool_definition(
|
||||
returns_dict_list, toolkit_name="test", toolkit_version="1.0.0"
|
||||
)
|
||||
|
||||
input_model, output_model = create_func_models(returns_dict_list)
|
||||
|
||||
result = await ToolExecutor.run(
|
||||
func=returns_dict_list,
|
||||
definition=definition,
|
||||
input_model=input_model,
|
||||
output_model=output_model,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.error is None
|
||||
assert result.value == [
|
||||
{"type": "plain", "value": 42},
|
||||
{"name": "string", "data": "test"},
|
||||
{"name": "typed", "value": 99}, # TypedDict becomes regular dict at runtime
|
||||
]
|
||||
Loading…
Reference in a new issue