Improve Pydantic and TypedDict support and add a bunch of tests.
1. Fixed the test failure where a TypedDict was being serialized as a list
of tuples instead of a dict by:
- Adding proper handling for BaseModel instances in output.py
- Converting BaseModel results (from TypedDict conversion) to dicts using
model_dump()
- Handling lists containing BaseModel objects
2. Fixed None handling to ensure None results are converted to empty
strings as expected
3. Updated schema.py to allow dict and list types in
ToolCallOutput.value
4. Added new tests:
- TypedDict output execution tests
- Output factory tests
- Executor tests with TypedDict support
- Schema validation tests
The key changes were:
- In ``arcade_core/output.py``: Added BaseModel conversion logic in the
success method
- In ``arcade_core/schema.py``: Changed ToolCallOutput.value type from
list[str] to list to support complex types
TODO
- [ ] Confirm engine compatibility without changes made to engine
---------
Co-authored-by: Eric Gustin <eric@arcade.dev>
132 lines
4.1 KiB
Python
from typing import Any
|
|
|
|
import pytest
|
|
from arcade_core.output import ToolOutputFactory
|
|
from pydantic import BaseModel
|
|
|
|
|
|
@pytest.fixture
def output_factory():
    """Provide a new ``ToolOutputFactory`` instance to each test that requests it."""
    factory = ToolOutputFactory()
    return factory
|
|
|
|
|
|
class SampleOutputModel(BaseModel):
    """Minimal Pydantic model carrying an arbitrary ``result`` payload.

    The tests wrap a value in this model before passing it to
    ``ToolOutputFactory.success``.
    """

    result: Any
|
|
|
|
|
|
@pytest.mark.parametrize(
    "data, expected_value",
    [
        (None, ""),
        ("success", "success"),
        ("", ""),
        (123, 123),
        (0, 0),
        (123.45, 123.45),
        # NOTE: ``==`` treats True/1 and False/0 as equal, so these two cases do
        # not by themselves prove that bools are preserved as bools.
        (True, True),
        (False, False),
    ],
)
def test_success(output_factory, data, expected_value):
    """Check ``success`` output for scalar payloads wrapped in ``SampleOutputModel``.

    ``output.value`` is expected to equal the wrapped scalar. ``None`` is passed
    to ``success`` directly (not wrapped) and is expected to yield ``""``.
    """
    data_obj = SampleOutputModel(result=data) if data is not None else None
    output = output_factory.success(data=data_obj)
    assert output.value == expected_value
    assert output.error is None
|
|
|
|
|
|
# (input payload, expected ``output.value``) pairs for container results.
# Expected values are separate literals so they never alias the inputs.
_COMPLEX_TYPE_CASES = [
    # Dict types (simulating TypedDict at runtime)
    ({"name": "test", "value": 123}, {"name": "test", "value": 123}),
    ({}, {}),
    ({"nested": {"key": "value"}}, {"nested": {"key": "value"}}),
    # List types
    (["a", "b", "c"], ["a", "b", "c"]),
    ([1, 2, 3], [1, 2, 3]),
    ([], []),
    # List of dicts (simulating list[TypedDict])
    (
        [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}],
        [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}],
    ),
    ([{}], [{}]),
    # Mixed lists
    ([1, "two", 3.0, True], [1, "two", 3.0, True]),
]


@pytest.mark.parametrize("data, expected_value", _COMPLEX_TYPE_CASES)
def test_success_complex_types(output_factory, data, expected_value):
    """Test that dict and list types are properly handled by ToolOutputFactory."""
    wrapped = SampleOutputModel(result=data)
    tool_output = output_factory.success(data=wrapped)
    assert tool_output.value == expected_value
    assert tool_output.error is None
|
|
|
|
|
|
def test_success_with_basemodel_direct(output_factory):
    """Test that BaseModel instances are converted to dict via model_dump()."""

    # Named without a ``Test`` prefix so it doesn't read like a pytest test class.
    class DirectPayload(BaseModel):
        name: str
        value: int

    payload = DirectPayload(name="test", value=42)
    expected = {"name": "test", "value": 42}

    output = output_factory.success(data=payload)

    assert output.value == expected
    assert output.error is None
|
|
|
|
|
|
def test_success_raw_dict(output_factory):
    """Test that raw dict values (not wrapped in a model) are handled correctly.

    The dict is passed directly as ``data`` and should come back unchanged in
    ``output.value``. The caller's dict must not be modified.
    """
    raw_dict = {"key": "value", "number": 123}
    # Independent literal: comparing against ``raw_dict`` itself would still pass
    # if the factory mutated the input in place.
    expected = {"key": "value", "number": 123}

    output = output_factory.success(data=raw_dict)

    assert output.value == expected
    assert raw_dict == expected  # input left untouched
    assert output.error is None
|
|
|
|
|
|
def test_success_raw_list(output_factory):
    """Test that raw list values (not wrapped in a model) are handled correctly.

    The list of dicts is passed directly as ``data`` and should come back
    unchanged in ``output.value``. The caller's list must not be modified.
    """
    raw_list = [{"id": 1}, {"id": 2}, {"id": 3}]
    # Independent literal: comparing against ``raw_list`` itself would still pass
    # if the factory mutated the input (or its nested dicts) in place.
    expected = [{"id": 1}, {"id": 2}, {"id": 3}]

    output = output_factory.success(data=raw_list)

    assert output.value == expected
    assert raw_list == expected  # input left untouched
    assert output.error is None
|
|
|
|
|
|
@pytest.mark.parametrize(
    "message, developer_message",
    [
        ("Error occurred", None),
        ("Error occurred", "Detailed error message"),
    ],
)
def test_fail(output_factory, message, developer_message):
    """``fail`` stores both messages on ``output.error`` and marks it non-retryable."""
    error = output_factory.fail(message=message, developer_message=developer_message).error

    assert error is not None
    assert error.message == message
    assert error.developer_message == developer_message
    assert error.can_retry is False
|
|
|
|
|
|
@pytest.mark.parametrize(
    "message, developer_message, additional_prompt_content, retry_after_ms",
    [
        ("Retry error", None, None, None),
        ("Retry error", "Retrying", "Please try again with this additional data: foobar", 1000),
    ],
)
def test_fail_retry(
    output_factory, message, developer_message, additional_prompt_content, retry_after_ms
):
    """``fail_retry`` echoes every argument onto ``output.error`` and marks it retryable."""
    retry_kwargs = {
        "message": message,
        "developer_message": developer_message,
        "additional_prompt_content": additional_prompt_content,
        "retry_after_ms": retry_after_ms,
    }
    error = output_factory.fail_retry(**retry_kwargs).error

    assert error is not None
    assert error.can_retry is True
    # Each keyword passed in should be readable back under the same attribute name.
    for field_name, expected in retry_kwargs.items():
        assert getattr(error, field_name) == expected, field_name
|