Improve Pydantic and TypedDict support and add a bunch of tests.
1. Fixed the test failure where a TypedDict was being serialized as a list
of tuples instead of a dict by:
- Adding proper handling for BaseModel instances in the output.py file
- Converting BaseModel results (from TypedDict conversion) to dicts using
model_dump()
- Handling lists containing BaseModel objects
2. Fixed None handling to ensure None results are converted to empty
strings as expected
3. Updated schema.py to allow dict and list types in
ToolCallOutput.value
4. Added new tests:
- TypedDict output execution tests
- Output factory tests
- Executor tests with TypedDict support
- Schema validation tests
The key changes were:
- In ``arcade_core/output.py``: Added BaseModel conversion logic in the
success method
- In ``arcade_core/schema.py``: Changed ToolCallOutput.value type from
list[str] to list to support complex types
TODO
- [ ] Confirm the engine remains compatible without requiring any engine changes
---------
Co-authored-by: Eric Gustin <eric@arcade.dev>
100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
from typing import Any, TypeVar

from pydantic import BaseModel

from arcade_core.schema import ToolCallError, ToolCallLog, ToolCallOutput
from arcade_core.utils import coerce_empty_list_to_none
# Generic type variable used to annotate the ``data`` argument of
# ToolOutputFactory.success.
T = TypeVar("T")
|
|
|
|
|
|
def _to_serializable(value: object) -> Any:
    """Recursively convert Pydantic models into plain Python containers.

    ``BaseModel`` instances are replaced by their ``model_dump()`` dict. Lists
    and dicts are walked, so models nested inside them are converted too.
    Every other value is returned unchanged.
    """
    if isinstance(value, BaseModel):
        return value.model_dump()
    if isinstance(value, list):
        return [_to_serializable(item) for item in value]
    if isinstance(value, dict):
        return {key: _to_serializable(item) for key, item in value.items()}
    return value


class ToolOutputFactory:
    """
    Factory providing a unified way to build ``ToolCallOutput`` results from tools.

    The factory holds no state. Nothing enforces a single instance; this
    module creates a shared ``output_factory`` instance for callers to use.
    """

    def success(
        self,
        *,
        data: T | None = None,
        logs: list[ToolCallLog] | None = None,
    ) -> ToolCallOutput:
        """
        Build a successful ``ToolCallOutput`` from a tool's return value.

        ``data`` is normalized as follows:

        * ``None`` becomes ``""``.
        * An object exposing a ``result`` attribute is unwrapped to that
          attribute. A ``None`` result becomes ``""``. Any other result has its
          Pydantic models (top-level, or nested in lists/dicts) converted to
          dicts; result types not otherwise handled are passed through as-is.
        * A ``BaseModel`` is converted with ``model_dump()``.
        * ``str``, ``int``, ``float``, ``bool``, ``list`` and ``dict`` values are
          kept. Pydantic models nested inside lists/dicts are converted to dicts,
          the same way as for ``result`` values.

        Args:
            data: The raw value returned by the tool.
            logs: Optional log entries. They are passed through
                ``coerce_empty_list_to_none`` before being stored.

        Returns:
            A ``ToolCallOutput`` carrying the normalized ``value`` and ``logs``.

        Raises:
            ValueError: If ``data`` is of a type not listed above.
        """
        value: str | int | float | bool | dict | list
        if data is None:
            value = ""
        elif hasattr(data, "result"):
            # This check runs before the BaseModel branch, so a model that
            # exposes a `result` attribute is unwrapped rather than dumped whole.
            result = getattr(data, "result")
            # Handle a None result the same way as None data.
            value = "" if result is None else _to_serializable(result)
        elif isinstance(data, BaseModel):
            value = data.model_dump()
        elif isinstance(data, (str, int, float, bool, list, dict)):
            # Lists/dicts may contain Pydantic models; convert them so direct
            # returns are treated consistently with `result`-wrapped returns.
            value = _to_serializable(data)
        else:
            raise ValueError(f"Unsupported data output type: {type(data)}")

        return ToolCallOutput(
            value=value,
            logs=coerce_empty_list_to_none(logs),
        )

    def fail(
        self,
        *,
        message: str,
        developer_message: str | None = None,
        traceback_info: str | None = None,
        logs: list[ToolCallLog] | None = None,
    ) -> ToolCallOutput:
        """
        Build an error ``ToolCallOutput`` that is marked as not retryable.

        Args:
            message: Error message stored on the ``ToolCallError``.
            developer_message: Optional extra message, stored as
                ``developer_message`` on the ``ToolCallError``.
            traceback_info: Optional traceback text to attach to the error.
            logs: Optional log entries. They are passed through
                ``coerce_empty_list_to_none`` before being stored.

        Returns:
            A ``ToolCallOutput`` whose ``error`` has ``can_retry=False``.
        """
        return ToolCallOutput(
            error=ToolCallError(
                message=message,
                developer_message=developer_message,
                can_retry=False,
                traceback_info=traceback_info,
            ),
            logs=coerce_empty_list_to_none(logs),
        )

    def fail_retry(
        self,
        *,
        message: str,
        developer_message: str | None = None,
        additional_prompt_content: str | None = None,
        retry_after_ms: int | None = None,
        traceback_info: str | None = None,
        logs: list[ToolCallLog] | None = None,
    ) -> ToolCallOutput:
        """
        Build an error ``ToolCallOutput`` that is marked as retryable.

        Args:
            message: Error message stored on the ``ToolCallError``.
            developer_message: Optional extra message, stored as
                ``developer_message`` on the ``ToolCallError``.
            additional_prompt_content: Optional text passed through unchanged to
                ``ToolCallError.additional_prompt_content``.
            retry_after_ms: Optional value passed through unchanged to
                ``ToolCallError.retry_after_ms``.
            traceback_info: Optional traceback text to attach to the error.
            logs: Optional log entries. They are passed through
                ``coerce_empty_list_to_none`` before being stored.

        Returns:
            A ``ToolCallOutput`` whose ``error`` has ``can_retry=True``.
        """
        return ToolCallOutput(
            error=ToolCallError(
                message=message,
                developer_message=developer_message,
                can_retry=True,
                additional_prompt_content=additional_prompt_content,
                retry_after_ms=retry_after_ms,
                traceback_info=traceback_info,
            ),
            logs=coerce_empty_list_to_none(logs),
        )
|
|
|
|
|
|
# Module-level shared instance of ToolOutputFactory.
output_factory = ToolOutputFactory()
|