From e188fc6ae9c1ccc4176b3fa6fc6e7dc650d4d836 Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Wed, 27 Aug 2025 16:50:09 -0700 Subject: [PATCH] Improve Typedict and Basemodel support (#523) Improve Pydantic and Typedict support and add a bunch of tets. 1. Fixed the test failure where TypeDict was being serialized as a list of tuples instead of a dict by: - Adding proper handling for BaseModel instances in the output.py file - Converting BaseModel results (from TypeDict conversion) to dicts using model_dump() - Handling lists containing BaseModel objects 2. Fixed None handling to ensure None results are converted to empty strings as expected 3. Updated the schema.py to allow dict and list types in ToolCallOutput.value 4. new tests - TypeDict output execution tests - Output factory tests - Executor tests with TypeDict support - Schema validation tests The key changes were: - In ``arcade_core/output.py``: Added BaseModel conversion logic in the success method - In ``arcade_core/schema.py``: Changed ToolCallOutput.value type from list[str] to list to support complex types TODO - [ ] Confirm engine compatibility without changes made to engine --------- Co-authored-by: Eric Gustin --- libs/arcade-cli/arcade_cli/display.py | 64 +++-- libs/arcade-core/arcade_core/catalog.py | 48 +++- libs/arcade-core/arcade_core/output.py | 19 +- libs/arcade-core/arcade_core/schema.py | 2 +- libs/tests/core/test_executor.py | 57 ++++ libs/tests/core/test_output.py | 58 ++++ libs/tests/core/test_schema_validation.py | 196 ++++++++++++++ .../core/test_typeddict_output_execution.py | 250 ++++++++++++++++++ 8 files changed, 663 insertions(+), 31 deletions(-) create mode 100644 libs/tests/core/test_schema_validation.py create mode 100644 libs/tests/core/test_typeddict_output_execution.py diff --git a/libs/arcade-cli/arcade_cli/display.py b/libs/arcade-cli/arcade_cli/display.py index 0d7e13bd..b661a57b 100644 --- a/libs/arcade-cli/arcade_cli/display.py +++ b/libs/arcade-cli/arcade_cli/display.py @@ -59,19 +59,36 @@ def display_tool_details(tool: ToolDefinition, worker: bool = False) -> None: # inputs_table.add_column("Type", style="magenta") inputs_table.add_column("Required", style="yellow") inputs_table.add_column("Description", style="white") - inputs_table.add_column("Default", style="blue") + for param in inputs: - # Since InputParameter does not have a default field, we use "N/A" - default_value = "N/A" - if param.value_schema.enum: - default_value = f"One of {param.value_schema.enum}" + # Format the type string properly + type_str = _format_type_string(param.value_schema) + + # Add the main parameter row inputs_table.add_row( param.name, - param.value_schema.val_type, + type_str, str(param.required), param.description or "", - default_value, ) + + # If this is a json type with properties, show them + if ( + param.value_schema.val_type == "json" + and hasattr(param.value_schema, "properties") + and param.value_schema.properties + ): + _add_nested_properties(inputs_table, param.value_schema.properties, indent=1) + # Handle arrays with inner properties + elif ( + param.value_schema.val_type == "array" + and hasattr(param.value_schema, "inner_properties") + and param.value_schema.inner_properties + ): + _add_nested_properties( + inputs_table, param.value_schema.inner_properties, indent=1, is_array_item=True + ) + inputs_panel = Panel( inputs_table, title="Input Parameters", @@ -241,7 +258,7 @@ def _add_nested_properties( is_array_item: bool = False, ) -> None: """ - Recursively add nested properties to the output table. + Recursively add nested properties to the table. Args: table: The Rich table to add rows to @@ -253,11 +270,14 @@ def _add_nested_properties( # Show array item indicator if needed if is_array_item and indent > 0: - table.add_row( - f"{indent_prefix[:-2]}[item]", - "", - "[dim]Each item in array:[/dim]", - ) + # Get column count from the table + num_columns = len(table.columns) + + # Create a row with the array indicator in the first column and empty strings for the rest + row_data = [f"{indent_prefix[:-2]}[item]"] + [""] * (num_columns - 1) + if num_columns >= 3: + row_data[2] = "[dim]Each item in array:[/dim]" + table.add_row(*row_data) for prop_name, prop_schema in properties.items(): # Format the type string @@ -269,11 +289,19 @@ def _add_nested_properties( if hasattr(prop_schema, "description") and prop_schema.description: description = prop_schema.description - table.add_row( - f"{indent_prefix}{prop_name}", - type_str, - f"[dim]{description}[/dim]" if description else "", - ) + # Create row data based on number of columns + num_columns = len(table.columns) + row_data = [f"{indent_prefix}{prop_name}", type_str] + + # For input parameter tables (4 columns), add empty required column + if num_columns == 4: + row_data.append("") # Empty "Required" column for nested properties + row_data.append(f"[dim]{description}[/dim]" if description else "") + # For output tables (3 columns), just add description + elif num_columns == 3: + row_data.append(f"[dim]{description}[/dim]" if description else "") + + table.add_row(*row_data) # Recursively add nested properties if this is a json type with properties if ( diff --git a/libs/arcade-core/arcade_core/catalog.py b/libs/arcade-core/arcade_core/catalog.py index 80c1e364..0893915e 100644 --- a/libs/arcade-core/arcade_core/catalog.py +++ b/libs/arcade-core/arcade_core/catalog.py @@ -483,7 +483,11 @@ def create_output_definition(func: Callable) -> ToolOutput: ) if hasattr(return_type, "__metadata__"): - description = return_type.__metadata__[0] if return_type.__metadata__ else None # type: ignore[assignment] + description = ( + return_type.__metadata__[0] + if return_type.__metadata__ + else "No description provided for return type." + ) return_type = return_type.__origin__ # Unwrap Optional types @@ -792,6 +796,7 @@ def extract_properties(type_to_check: type) -> dict[str, WireTypeInfo] | None: field_type = next(arg for arg in get_args(field_type) if arg is not type(None)) # Get wire type info recursively + # field_type cannot be None here due to the check above wire_info = get_wire_type_info(field_type) properties[field_name] = wire_info @@ -973,7 +978,9 @@ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel] tool_field_info = extract_field_info(param) param_fields = { "default": tool_field_info.default, - "description": tool_field_info.description, + "description": tool_field_info.description + if tool_field_info.description + else "No description provided.", # TODO more here? } input_fields[name] = (tool_field_info.field_type, Field(**param_fields)) @@ -991,8 +998,14 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901 """ return_annotation = inspect.signature(func).return_annotation output_model_name = f"{snake_to_pascal_case(func.__name__)}Output" + + # If the return annotation is empty, create a model with no fields if return_annotation is inspect.Signature.empty: return create_model(output_model_name) + + # If the return annotation has an __origin__ attribute + # and does not have a __metadata__ attribute. + # This is the case for TypedDicts. elif hasattr(return_annotation, "__origin__"): if hasattr(return_annotation, "__metadata__"): field_type = return_annotation.__args__[0] @@ -1008,15 +1021,24 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901 ) return create_model( output_model_name, - result=(typeddict_model, Field(description=str(description))), + result=( + typeddict_model, + Field( + description=str(description) + if description + else "No description provided." + ), + ), ) + # If the return annotation has a description, use it if description: return create_model( output_model_name, result=(field_type, Field(description=str(description))), ) - # Handle Union types + + # If the return annotation is a Union type origin = return_annotation.__origin__ if origin is typing.Union: # For union types, create a model with the first non-None argument @@ -1037,10 +1059,15 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901 ) return create_model( output_model_name, - result=(arg, Field(description="No description provided.")), + result=( + arg, + Field(description="No description provided."), + ), ) - # when the return_annotation has an __origin__ attribute + + # If the return annotation has an __origin__ attribute # and does not have a __metadata__ attribute. + # This is the case for TypedDicts. return create_model( output_model_name, result=( @@ -1049,7 +1076,7 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901 ), ) else: - # Check if return type is TypedDict + # If the return annotation is a TypedDict if is_typeddict(return_annotation): typeddict_model = create_model_from_typeddict(return_annotation, output_model_name) return create_model( @@ -1060,10 +1087,13 @@ def determine_output_model(func: Callable) -> type[BaseModel]: # noqa: C901 ), ) - # Handle simple return types (like str) + # If the return annotation is a simple type (like str) return create_model( output_model_name, - result=(return_annotation, Field(description="No description provided.")), + result=( + return_annotation, + Field(description="No description provided."), + ), ) diff --git a/libs/arcade-core/arcade_core/output.py b/libs/arcade-core/arcade_core/output.py index 0b9b45b7..bb32e3b4 100644 --- a/libs/arcade-core/arcade_core/output.py +++ b/libs/arcade-core/arcade_core/output.py @@ -25,14 +25,27 @@ class ToolOutputFactory: The executor guarantees that `data` is either a string, a dict, or None. """ - value: str | int | float | bool | dict | list[str] | None + value: str | int | float | bool | dict | list | None if data is None: value = "" elif hasattr(data, "result"): - value = getattr(data, "result", "") + result = getattr(data, "result", "") + # Handle None result the same way as None data + if result is None: + value = "" + # If the result is a BaseModel (e.g., from TypedDict conversion), convert to dict + elif isinstance(result, BaseModel): + value = result.model_dump() + # If the result is a list, check if it contains BaseModel objects + elif isinstance(result, list): + value = [ + item.model_dump() if isinstance(item, BaseModel) else item for item in result + ] + else: + value = result elif isinstance(data, BaseModel): value = data.model_dump() - elif isinstance(data, (str, int, float, bool, list)): + elif isinstance(data, (str, int, float, bool, list, dict)): value = data else: raise ValueError(f"Unsupported data output type: {type(data)}") diff --git a/libs/arcade-core/arcade_core/schema.py b/libs/arcade-core/arcade_core/schema.py index 029dc198..28708e8c 100644 --- a/libs/arcade-core/arcade_core/schema.py +++ b/libs/arcade-core/arcade_core/schema.py @@ -418,7 +418,7 @@ class ToolCallRequiresAuthorization(BaseModel): class ToolCallOutput(BaseModel): """The output of a tool invocation.""" - value: str | int | float | bool | dict | list[str] | None = None + value: str | int | float | bool | dict | list | None = None """The value returned by the tool.""" logs: list[ToolCallLog] | None = None """The logs that occurred during the tool invocation.""" diff --git a/libs/tests/core/test_executor.py b/libs/tests/core/test_executor.py index c5136652..330d98c6 100644 --- a/libs/tests/core/test_executor.py +++ b/libs/tests/core/test_executor.py @@ -6,6 +6,7 @@ from arcade_core.executor import ToolExecutor from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput, ToolContext from arcade_tdk import tool from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from typing_extensions import TypedDict @tool @@ -45,6 +46,36 @@ def bad_output_error_tool() -> Annotated[str, "output"]: return {"output": "test"} +# TypedDict output tools +class ResultDict(TypedDict): + """Result dictionary.""" + + status: str + count: int + items: list[str] + + +@tool +def typeddict_output_tool() -> Annotated[ResultDict, "Returns a TypedDict"]: + """Tool that returns a TypedDict.""" + return ResultDict(status="success", count=3, items=["a", "b", "c"]) + + +@tool +def list_typeddict_output_tool() -> Annotated[list[ResultDict], "Returns list of TypedDict"]: + """Tool that returns a list of TypedDict.""" + return [ + ResultDict(status="first", count=1, items=["x"]), + ResultDict(status="second", count=2, items=["y", "z"]), + ] + + +@tool +def dict_output_tool() -> Annotated[dict, "Returns a plain dict"]: + """Tool that returns a plain dict.""" + return {"key": "value", "number": 42, "nested": {"inner": "data"}} + + # ---- Test Driver ---- catalog = ToolCatalog() @@ -54,6 +85,9 @@ catalog.add_tool(retryable_error_tool, "simple_toolkit") catalog.add_tool(exec_error_tool, "simple_toolkit") catalog.add_tool(unexpected_error_tool, "simple_toolkit") catalog.add_tool(bad_output_error_tool, "simple_toolkit") +catalog.add_tool(typeddict_output_tool, "simple_toolkit") +catalog.add_tool(list_typeddict_output_tool, "simple_toolkit") +catalog.add_tool(dict_output_tool, "simple_toolkit") @pytest.mark.asyncio @@ -128,6 +162,26 @@ catalog.add_tool(bad_output_error_tool, "simple_toolkit") ) ), ), + ( + typeddict_output_tool, + {}, + ToolCallOutput(value={"status": "success", "count": 3, "items": ["a", "b", "c"]}), + ), + ( + list_typeddict_output_tool, + {}, + ToolCallOutput( + value=[ + {"status": "first", "count": 1, "items": ["x"]}, + {"status": "second", "count": 2, "items": ["y", "z"]}, + ] + ), + ), + ( + dict_output_tool, + {}, + ToolCallOutput(value={"key": "value", "number": 42, "nested": {"inner": "data"}}), + ), ], ids=[ "simple_tool", @@ -137,6 +191,9 @@ catalog.add_tool(bad_output_error_tool, "simple_toolkit") "unexpected_error_tool", "invalid_input_type", "bad_output_type", + "typeddict_output", + "list_typeddict_output", + "dict_output", ], ) async def test_tool_executor(tool_func, inputs, expected_output): diff --git a/libs/tests/core/test_output.py b/libs/tests/core/test_output.py index bca1c18d..7b390144 100644 --- a/libs/tests/core/test_output.py +++ b/libs/tests/core/test_output.py @@ -35,6 +35,64 @@ def test_success(output_factory, data, expected_value): assert output.error is None +@pytest.mark.parametrize( + "data, expected_value", + [ + # Dict types (simulating TypedDict at runtime) + ({"name": "test", "value": 123}, {"name": "test", "value": 123}), + ({}, {}), + ({"nested": {"key": "value"}}, {"nested": {"key": "value"}}), + # List types + (["a", "b", "c"], ["a", "b", "c"]), + ([1, 2, 3], [1, 2, 3]), + ([], []), + # List of dicts (simulating list[TypedDict]) + ( + [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}], + [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}], + ), + ([{}], [{}]), + # Mixed lists + ([1, "two", 3.0, True], [1, "two", 3.0, True]), + ], +) +def test_success_complex_types(output_factory, data, expected_value): + """Test that dict and list types are properly handled by ToolOutputFactory.""" + data_obj = SampleOutputModel(result=data) + output = output_factory.success(data=data_obj) + assert output.value == expected_value + assert output.error is None + + +def test_success_with_basemodel_direct(output_factory): + """Test that BaseModel instances are converted to dict via model_dump().""" + + class TestModel(BaseModel): + name: str + value: int + + model = TestModel(name="test", value=42) + output = output_factory.success(data=model) + assert output.value == {"name": "test", "value": 42} + assert output.error is None + + +def test_success_raw_dict(output_factory): + """Test that raw dict values (not wrapped in model) are handled correctly.""" + raw_dict = {"key": "value", "number": 123} + output = output_factory.success(data=raw_dict) + assert output.value == raw_dict + assert output.error is None + + +def test_success_raw_list(output_factory): + """Test that raw list values (not wrapped in model) are handled correctly.""" + raw_list = [{"id": 1}, {"id": 2}, {"id": 3}] + output = output_factory.success(data=raw_list) + assert output.value == raw_list + assert output.error is None + + @pytest.mark.parametrize( "message, developer_message", [ diff --git a/libs/tests/core/test_schema_validation.py b/libs/tests/core/test_schema_validation.py new file mode 100644 index 00000000..d0070419 --- /dev/null +++ b/libs/tests/core/test_schema_validation.py @@ -0,0 +1,196 @@ +""" +Tests for ToolCallOutput schema validation with complex types. +""" + +import pytest +from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput +from pydantic import ValidationError + + +class TestToolCallOutputValidation: + """Test ToolCallOutput validation with various data types.""" + + def test_basic_types(self): + """Test that basic types are validated correctly.""" + # String + output = ToolCallOutput(value="test string") + assert output.value == "test string" + + # Integer + output = ToolCallOutput(value=42) + assert output.value == 42 + + # Float + output = ToolCallOutput(value=3.14) + assert output.value == 3.14 + + # Boolean + output = ToolCallOutput(value=True) + assert output.value is True + + # None + output = ToolCallOutput(value=None) + assert output.value is None + + def test_dict_types(self): + """Test that dict types are validated correctly.""" + # Simple dict + output = ToolCallOutput(value={"key": "value"}) + assert output.value == {"key": "value"} + + # Nested dict + output = ToolCallOutput(value={"outer": {"inner": "value"}}) + assert output.value == {"outer": {"inner": "value"}} + + # Empty dict + output = ToolCallOutput(value={}) + assert output.value == {} + + # Dict with mixed types + output = ToolCallOutput( + value={ + "string": "text", + "number": 123, + "float": 45.6, + "bool": True, + "null": None, + "list": [1, 2, 3], + "dict": {"nested": "value"}, + } + ) + assert output.value["string"] == "text" + assert output.value["number"] == 123 + assert output.value["list"] == [1, 2, 3] + + def test_list_types(self): + """Test that list types are validated correctly.""" + # List of strings (original type) + output = ToolCallOutput(value=["a", "b", "c"]) + assert output.value == ["a", "b", "c"] + + # List of integers + output = ToolCallOutput(value=[1, 2, 3]) + assert output.value == [1, 2, 3] + + # List of dicts (TypedDict at runtime) + output = ToolCallOutput(value=[{"id": 1, "name": "first"}, {"id": 2, "name": "second"}]) + assert output.value == [{"id": 1, "name": "first"}, {"id": 2, "name": "second"}] + + # Mixed type list + output = ToolCallOutput(value=[1, "two", 3.0, True, None, {"key": "value"}]) + assert len(output.value) == 6 + assert output.value[5] == {"key": "value"} + + # Empty list + output = ToolCallOutput(value=[]) + assert output.value == [] + + # Nested lists + output = ToolCallOutput(value=[[1, 2], [3, 4], [5, 6]]) + assert output.value == [[1, 2], [3, 4], [5, 6]] + + def test_complex_nested_structures(self): + """Test complex nested structures that might come from TypedDict.""" + # Simulate a complex API response structure + complex_data = { + "status": "success", + "data": { + "users": [ + { + "id": 1, + "name": "Alice", + "roles": ["admin", "user"], + "metadata": {"last_login": "2024-01-01", "active": True}, + }, + { + "id": 2, + "name": "Bob", + "roles": ["user"], + "metadata": {"last_login": "2024-01-02", "active": False}, + }, + ], + "total": 2, + "page_info": {"page": 1, "per_page": 10, "has_next": False}, + }, + "errors": [], + } + + output = ToolCallOutput(value=complex_data) + assert output.value == complex_data + assert output.value["data"]["users"][0]["name"] == "Alice" + assert output.value["data"]["page_info"]["has_next"] is False + + def test_error_and_logs_with_value(self): + """Test that error and logs can coexist with different value types.""" + # With dict value and logs + output = ToolCallOutput( + value={"result": "success"}, + logs=[ + ToolCallLog(message="Processing started", level="info"), + ToolCallLog(message="Deprecation warning", level="warning", subtype="deprecation"), + ], + ) + assert output.value == {"result": "success"} + assert len(output.logs) == 2 + + # With list value and error + output = ToolCallOutput( + error=ToolCallError( + message="Partial failure", + developer_message="Some items failed to process", + can_retry=True, + ) + ) + assert output.error.message == "Partial failure" + assert output.value is None + + def test_unsupported_types_still_fail(self): + """Test that truly unsupported types still fail validation.""" + + # Custom object (not dict, list, or basic type) + class CustomClass: + def __init__(self): + self.data = "test" + + # This should fail because CustomClass instance is not a supported type + # Note: This test is about Pydantic validation, not the output factory + # The output factory would catch this earlier + with pytest.raises(ValidationError): + # Directly creating with an unsupported type should fail + ToolCallOutput(value=CustomClass()) + + def test_very_large_structures(self): + """Test that large structures are handled properly.""" + # Large list of dicts + large_list = [{"id": i, "value": f"item_{i}"} for i in range(1000)] + output = ToolCallOutput(value=large_list) + assert len(output.value) == 1000 + assert output.value[500]["id"] == 500 + + # Deeply nested structure + deep_dict = {"level1": {"level2": {"level3": {"level4": {"level5": "deep_value"}}}}} + output = ToolCallOutput(value=deep_dict) + assert output.value["level1"]["level2"]["level3"]["level4"]["level5"] == "deep_value" + + def test_json_serializable(self): + """Test that all supported types are JSON serializable.""" + import json + + test_cases = [ + {"type": "string"}, + ["list", "of", "strings"], + [{"id": 1}, {"id": 2}], + {"nested": {"data": [1, 2, 3]}}, + 123, + 45.6, + True, + None, + ] + + for test_value in test_cases: + output = ToolCallOutput(value=test_value) + # This should not raise an exception + json_str = json.dumps(output.model_dump()) + # And we should be able to parse it back + parsed = json.loads(json_str) + assert parsed["value"] == test_value diff --git a/libs/tests/core/test_typeddict_output_execution.py b/libs/tests/core/test_typeddict_output_execution.py new file mode 100644 index 00000000..8f74c4ae --- /dev/null +++ b/libs/tests/core/test_typeddict_output_execution.py @@ -0,0 +1,250 @@ +""" +End-to-end tests for TypedDict output execution through the entire tool pipeline. +""" + +from typing import Annotated, Optional + +import pytest +from arcade_core.catalog import ToolCatalog, create_func_models +from arcade_core.executor import ToolExecutor +from arcade_core.schema import ToolContext +from arcade_tdk import tool +from typing_extensions import TypedDict + + +# Define various TypedDict structures for testing +class SimpleDict(TypedDict): + """A simple typed dictionary.""" + + name: str + value: int + + +class NestedDict(TypedDict): + """A nested typed dictionary.""" + + id: int + info: SimpleDict + tags: list[str] + + +class OptionalFieldsDict(TypedDict, total=False): + """TypedDict with optional fields.""" + + required_field: str + optional_field: int + + +# Define test tools +@tool +def returns_typeddict() -> Annotated[SimpleDict, "Returns a simple TypedDict"]: + """Tool that returns a TypedDict.""" + return SimpleDict(name="test", value=42) + + +@tool +def returns_list_of_typeddict() -> Annotated[list[SimpleDict], "Returns list of TypedDict"]: + """Tool that returns a list of TypedDict.""" + return [ + SimpleDict(name="item1", value=1), + SimpleDict(name="item2", value=2), + SimpleDict(name="item3", value=3), + ] + + +@tool +def returns_optional_typeddict( + return_none: Annotated[bool, "Whether to return None"] = False, +) -> Annotated[Optional[SimpleDict], "Returns optional TypedDict"]: + """Tool that returns an optional TypedDict.""" + if return_none: + return None + return SimpleDict(name="optional", value=100) + + +@tool +def returns_nested_typeddict() -> Annotated[NestedDict, "Returns nested TypedDict"]: + """Tool that returns a nested TypedDict.""" + return NestedDict(id=1, info=SimpleDict(name="nested", value=99), tags=["tag1", "tag2", "tag3"]) + + +@tool +def accepts_and_returns_typeddict( + data: Annotated[SimpleDict, "Input TypedDict"], +) -> Annotated[SimpleDict, "Modified TypedDict"]: + """Tool that accepts and returns a TypedDict.""" + return SimpleDict(name=f"Modified: {data['name']}", value=data["value"] * 2) + + +@tool +def returns_dict_list() -> Annotated[list[dict], "Returns list of dicts"]: + """Tool that returns a list of dictionaries including TypedDicts.""" + return [ + {"type": "plain", "value": 42}, + {"name": "string", "data": "test"}, + SimpleDict(name="typed", value=99), + ] + + +class TestTypeDictOutputExecution: + """Test TypedDict outputs through the full execution pipeline.""" + + @pytest.fixture + def catalog(self): + return ToolCatalog() + + @pytest.fixture + def context(self): + return ToolContext() + + @pytest.mark.asyncio + async def test_returns_typeddict(self, catalog, context): + """Test executing a tool that returns a TypedDict.""" + # Create tool definition + definition = catalog.create_tool_definition( + returns_typeddict, toolkit_name="test", toolkit_version="1.0.0" + ) + + # Create models + input_model, output_model = create_func_models(returns_typeddict) + + # Execute tool + result = await ToolExecutor.run( + func=returns_typeddict, + definition=definition, + input_model=input_model, + output_model=output_model, + context=context, + ) + + # Verify result + assert result.error is None + assert result.value == {"name": "test", "value": 42} + + @pytest.mark.asyncio + async def test_returns_list_of_typeddict(self, catalog, context): + """Test executing a tool that returns a list of TypedDict.""" + definition = catalog.create_tool_definition( + returns_list_of_typeddict, toolkit_name="test", toolkit_version="1.0.0" + ) + + input_model, output_model = create_func_models(returns_list_of_typeddict) + + result = await ToolExecutor.run( + func=returns_list_of_typeddict, + definition=definition, + input_model=input_model, + output_model=output_model, + context=context, + ) + + assert result.error is None + assert result.value == [ + {"name": "item1", "value": 1}, + {"name": "item2", "value": 2}, + {"name": "item3", "value": 3}, + ] + + @pytest.mark.asyncio + async def test_returns_optional_typeddict(self, catalog, context): + """Test executing a tool that returns an optional TypedDict.""" + definition = catalog.create_tool_definition( + returns_optional_typeddict, toolkit_name="test", toolkit_version="1.0.0" + ) + + input_model, output_model = create_func_models(returns_optional_typeddict) + + # Test returning a value + result = await ToolExecutor.run( + func=returns_optional_typeddict, + definition=definition, + input_model=input_model, + output_model=output_model, + context=context, + return_none=False, + ) + + assert result.error is None + assert result.value == {"name": "optional", "value": 100} + + # Test returning None + result_none = await ToolExecutor.run( + func=returns_optional_typeddict, + definition=definition, + input_model=input_model, + output_model=output_model, + context=context, + return_none=True, + ) + + assert result_none.error is None + assert result_none.value == "" # None is converted to empty string + + @pytest.mark.asyncio + async def test_returns_nested_typeddict(self, catalog, context): + """Test executing a tool that returns a nested TypedDict.""" + definition = catalog.create_tool_definition( + returns_nested_typeddict, toolkit_name="test", toolkit_version="1.0.0" + ) + + input_model, output_model = create_func_models(returns_nested_typeddict) + + result = await ToolExecutor.run( + func=returns_nested_typeddict, + definition=definition, + input_model=input_model, + output_model=output_model, + context=context, + ) + + assert result.error is None + assert result.value == { + "id": 1, + "info": {"name": "nested", "value": 99}, + "tags": ["tag1", "tag2", "tag3"], + } + + @pytest.mark.asyncio + async def test_accepts_and_returns_typeddict(self, catalog, context): + """Test executing a tool that accepts and returns TypedDict.""" + definition = catalog.create_tool_definition( + accepts_and_returns_typeddict, toolkit_name="test", toolkit_version="1.0.0" + ) + + input_model, output_model = create_func_models(accepts_and_returns_typeddict) + + result = await ToolExecutor.run( + func=accepts_and_returns_typeddict, + definition=definition, + input_model=input_model, + output_model=output_model, + context=context, + data={"name": "input", "value": 21}, + ) + + assert result.error is None + assert result.value == {"name": "Modified: input", "value": 42} + + @pytest.mark.asyncio + async def test_returns_dict_list(self, catalog, context): + """Test executing a tool that returns a list of dicts.""" + definition = catalog.create_tool_definition( + returns_dict_list, toolkit_name="test", toolkit_version="1.0.0" + ) + + input_model, output_model = create_func_models(returns_dict_list) + + result = await ToolExecutor.run( + func=returns_dict_list, + definition=definition, + input_model=input_model, + output_model=output_model, + context=context, + ) + + assert result.error is None + assert result.value == [ + {"type": "plain", "value": 42}, + {"name": "string", "data": "test"}, + {"name": "typed", "value": 99}, # TypedDict becomes regular dict at runtime + ]