"""Tests for ``RunResult.final_output_as`` with matching and mismatching types."""
from typing import Any
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from agents import Agent, RunContextWrapper, RunResult
|
|
|
|
|
|
def create_run_result(final_output: Any) -> RunResult:
    """Build a ``RunResult`` whose ``final_output`` is the given value.

    Every other field is a fixed placeholder: the input is the string
    ``"test"``, the item/response/guardrail lists are empty, the last agent
    is an ``Agent`` named ``"test"``, and the context wrapper holds ``None``.
    """
    placeholder_agent = Agent(name="test")
    empty_context = RunContextWrapper(context=None)
    return RunResult(
        input="test",
        new_items=[],
        raw_responses=[],
        final_output=final_output,
        input_guardrail_results=[],
        output_guardrail_results=[],
        _last_agent=placeholder_agent,
        context_wrapper=empty_context,
    )
|
|
|
|
|
|
class Foo(BaseModel):
    """Pydantic model with one integer field, used as a non-builtin output type."""

    # The only field; the tests in this file construct it as Foo(bar=1).
    bar: int
|
|
|
|
|
|
def test_result_cast_typechecks():
    """Casting the final output to its actual type yields an equal value."""
    # (final output stored in the result, type to cast to, expected value)
    cases = (
        (1, int, 1),
        ("test", str, "test"),
        (Foo(bar=1), Foo, Foo(bar=1)),
    )
    for final_output, target_type, expected in cases:
        run_result = create_run_result(final_output)
        assert run_result.final_output_as(target_type) == expected
|
|
|
|
|
|
def test_bad_cast_doesnt_raise():
    """A mismatched cast must not raise when ``raise_if_incorrect_type`` is not passed."""
    # (final output stored in the result, type that does NOT match it)
    mismatches = (
        (1, str),
        ("test", Foo),
    )
    for final_output, wrong_type in mismatches:
        run_result = create_run_result(final_output)
        # Nothing is asserted about the return value: the test passes as long
        # as the call completes without an exception.
        run_result.final_output_as(wrong_type)
|
|
|
|
|
|
def test_bad_cast_with_param_raises():
    """A mismatched cast raises ``TypeError`` when ``raise_if_incorrect_type=True``."""
    # (final output stored in the result, type that does NOT match it)
    mismatches = (
        (1, str),
        ("test", Foo),
        (Foo(bar=1), int),
    )
    for final_output, wrong_type in mismatches:
        # Build the result outside the raises-context so that only the cast
        # itself is allowed to produce the TypeError.
        run_result = create_run_result(final_output)
        with pytest.raises(TypeError):
            run_result.final_output_as(wrong_type, raise_if_incorrect_type=True)
|