Pretty print result classes
This commit is contained in:
parent
09d70c074d
commit
64e263b614
7 changed files with 338 additions and 1 deletions
8
Makefile
8
Makefile
|
|
@ -18,6 +18,14 @@ mypy:
|
|||
tests:
|
||||
uv run pytest
|
||||
|
||||
.PHONY: snapshots-fix
|
||||
snapshots-fix:
|
||||
uv run pytest --inline-snapshot=fix
|
||||
|
||||
.PHONY: snapshots-create
|
||||
snapshots-create:
|
||||
uv run pytest --inline-snapshot=create
|
||||
|
||||
.PHONY: old_version_tests
|
||||
old_version_tests:
|
||||
UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m pytest
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ dev = [
|
|||
"mkdocstrings[python]>=0.28.0",
|
||||
"coverage>=7.6.12",
|
||||
"playwright==1.50.0",
|
||||
"inline-snapshot>=0.20.7",
|
||||
]
|
||||
[tool.uv.workspace]
|
||||
members = ["agents"]
|
||||
|
|
@ -116,4 +117,7 @@ filterwarnings = [
|
|||
]
|
||||
markers = [
|
||||
"allow_call_model_methods: mark test as allowing calls to real model implementations",
|
||||
]
|
||||
]
|
||||
|
||||
[tool.inline-snapshot]
|
||||
format-command="ruff format --stdin-filename {filename}"
|
||||
|
|
@ -17,6 +17,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
|||
from .logger import logger
|
||||
from .stream_events import StreamEvent
|
||||
from .tracing import Trace
|
||||
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._run_impl import QueueCompleteSentinel
|
||||
|
|
@ -89,6 +90,9 @@ class RunResult(RunResultBase):
|
|||
"""The last agent that was run."""
|
||||
return self._last_agent
|
||||
|
||||
def __str__(self) -> str:
|
||||
return pretty_print_result(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunResultStreaming(RunResultBase):
|
||||
|
|
@ -216,3 +220,6 @@ class RunResultStreaming(RunResultBase):
|
|||
|
||||
if self._output_guardrails_task and not self._output_guardrails_task.done():
|
||||
self._output_guardrails_task.cancel()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return pretty_print_run_result_streaming(self)
|
||||
|
|
|
|||
56
src/agents/util/_pretty_print.py
Normal file
56
src/agents/util/_pretty_print.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..result import RunResult, RunResultBase, RunResultStreaming
|
||||
|
||||
|
||||
def _indent(text: str, indent_level: int) -> str:
|
||||
indent_string = " " * indent_level
|
||||
return "\n".join(f"{indent_string}{line}" for line in text.splitlines())
|
||||
|
||||
|
||||
def _final_output_str(result: "RunResultBase") -> str:
|
||||
if result.final_output is None:
|
||||
return "None"
|
||||
elif isinstance(result.final_output, str):
|
||||
return result.final_output
|
||||
elif isinstance(result.final_output, BaseModel):
|
||||
return result.final_output.model_dump_json(indent=2)
|
||||
else:
|
||||
return str(result.final_output)
|
||||
|
||||
|
||||
def pretty_print_result(result: "RunResult") -> str:
|
||||
output = "RunResult:"
|
||||
output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)'
|
||||
output += (
|
||||
f"\n- Final output ({type(result.final_output).__name__}):\n"
|
||||
f"{_indent(_final_output_str(result), 2)}"
|
||||
)
|
||||
output += f"\n- {len(result.new_items)} new item(s)"
|
||||
output += f"\n- {len(result.raw_responses)} raw response(s)"
|
||||
output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)"
|
||||
output += f"\n- {len(result.output_guardrail_results)} output guardrail result(s)"
|
||||
output += "\n(See `RunResult` for more details)"
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str:
|
||||
output = "RunResultStreaming:"
|
||||
output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)'
|
||||
output += f"\n- Current turn: {result.current_turn}"
|
||||
output += f"\n- Max turns: {result.max_turns}"
|
||||
output += f"\n- Is complete: {result.is_complete}"
|
||||
output += (
|
||||
f"\n- Final output ({type(result.final_output).__name__}):\n"
|
||||
f"{_indent(_final_output_str(result), 2)}"
|
||||
)
|
||||
output += f"\n- {len(result.new_items)} new item(s)"
|
||||
output += f"\n- {len(result.raw_responses)} raw response(s)"
|
||||
output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)"
|
||||
output += f"\n- {len(result.output_guardrail_results)} output guardrail result(s)"
|
||||
output += "\n(See `RunResultStreaming` for more details)"
|
||||
return output
|
||||
25
tests/README.md
Normal file
25
tests/README.md
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
# Tests
|
||||
|
||||
Before running any tests, make sure you have `uv` installed (and ideally run `make sync` after).
|
||||
|
||||
## Running tests
|
||||
|
||||
```
|
||||
make tests
|
||||
```
|
||||
|
||||
## Snapshots
|
||||
|
||||
We use [inline-snapshots](https://15r10nk.github.io/inline-snapshot/latest/) for some tests. If your code adds new snapshot tests or breaks existing ones, you can fix/create them. After fixing/creating snapshots, run `make tests` again to verify the tests pass.
|
||||
|
||||
### Fixing snapshots
|
||||
|
||||
```
|
||||
make snapshots-fix
|
||||
```
|
||||
|
||||
### Creating snapshots
|
||||
|
||||
```
|
||||
make snapshots-update
|
||||
```
|
||||
201
tests/test_pretty_print.py
Normal file
201
tests/test_pretty_print.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from inline_snapshot import snapshot
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent, Runner
|
||||
from agents.agent_output import _WRAPPER_DICT_KEY
|
||||
from agents.util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
|
||||
from tests.fake_model import FakeModel
|
||||
|
||||
from .test_responses import get_final_output_message, get_text_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pretty_result():
|
||||
model = FakeModel()
|
||||
model.set_next_output([get_text_message("Hi there")])
|
||||
|
||||
agent = Agent(name="test_agent", model=model)
|
||||
result = await Runner.run(agent, input="Hello")
|
||||
|
||||
assert pretty_print_result(result) == snapshot("""\
|
||||
RunResult:
|
||||
- Last agent: Agent(name="test_agent", ...)
|
||||
- Final output (str):
|
||||
Hi there
|
||||
- 1 new item(s)
|
||||
- 1 raw response(s)
|
||||
- 0 input guardrail result(s)
|
||||
- 0 output guardrail result(s)
|
||||
(See `RunResult` for more details)\
|
||||
""")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pretty_run_result_streaming():
|
||||
model = FakeModel()
|
||||
model.set_next_output([get_text_message("Hi there")])
|
||||
|
||||
agent = Agent(name="test_agent", model=model)
|
||||
result = Runner.run_streamed(agent, input="Hello")
|
||||
async for _ in result.stream_events():
|
||||
pass
|
||||
|
||||
assert pretty_print_run_result_streaming(result) == snapshot("""\
|
||||
RunResultStreaming:
|
||||
- Current agent: Agent(name="test_agent", ...)
|
||||
- Current turn: 1
|
||||
- Max turns: 10
|
||||
- Is complete: True
|
||||
- Final output (str):
|
||||
Hi there
|
||||
- 1 new item(s)
|
||||
- 1 raw response(s)
|
||||
- 0 input guardrail result(s)
|
||||
- 0 output guardrail result(s)
|
||||
(See `RunResultStreaming` for more details)\
|
||||
""")
|
||||
|
||||
|
||||
class Foo(BaseModel):
|
||||
bar: str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pretty_run_result_structured_output():
|
||||
model = FakeModel()
|
||||
model.set_next_output(
|
||||
[
|
||||
get_text_message("Test"),
|
||||
get_final_output_message(Foo(bar="Hi there").model_dump_json()),
|
||||
]
|
||||
)
|
||||
|
||||
agent = Agent(name="test_agent", model=model, output_type=Foo)
|
||||
result = await Runner.run(agent, input="Hello")
|
||||
|
||||
assert pretty_print_result(result) == snapshot("""\
|
||||
RunResult:
|
||||
- Last agent: Agent(name="test_agent", ...)
|
||||
- Final output (Foo):
|
||||
{
|
||||
"bar": "Hi there"
|
||||
}
|
||||
- 2 new item(s)
|
||||
- 1 raw response(s)
|
||||
- 0 input guardrail result(s)
|
||||
- 0 output guardrail result(s)
|
||||
(See `RunResult` for more details)\
|
||||
""")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pretty_run_result_streaming_structured_output():
|
||||
model = FakeModel()
|
||||
model.set_next_output(
|
||||
[
|
||||
get_text_message("Test"),
|
||||
get_final_output_message(Foo(bar="Hi there").model_dump_json()),
|
||||
]
|
||||
)
|
||||
|
||||
agent = Agent(name="test_agent", model=model, output_type=Foo)
|
||||
result = Runner.run_streamed(agent, input="Hello")
|
||||
|
||||
async for _ in result.stream_events():
|
||||
pass
|
||||
|
||||
assert pretty_print_run_result_streaming(result) == snapshot("""\
|
||||
RunResultStreaming:
|
||||
- Current agent: Agent(name="test_agent", ...)
|
||||
- Current turn: 1
|
||||
- Max turns: 10
|
||||
- Is complete: True
|
||||
- Final output (Foo):
|
||||
{
|
||||
"bar": "Hi there"
|
||||
}
|
||||
- 2 new item(s)
|
||||
- 1 raw response(s)
|
||||
- 0 input guardrail result(s)
|
||||
- 0 output guardrail result(s)
|
||||
(See `RunResultStreaming` for more details)\
|
||||
""")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pretty_run_result_list_structured_output():
|
||||
model = FakeModel()
|
||||
model.set_next_output(
|
||||
[
|
||||
get_text_message("Test"),
|
||||
get_final_output_message(
|
||||
json.dumps(
|
||||
{
|
||||
_WRAPPER_DICT_KEY: [
|
||||
Foo(bar="Hi there").model_dump(),
|
||||
Foo(bar="Hi there 2").model_dump(),
|
||||
]
|
||||
}
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
agent = Agent(name="test_agent", model=model, output_type=list[Foo])
|
||||
result = await Runner.run(agent, input="Hello")
|
||||
|
||||
assert pretty_print_result(result) == snapshot("""\
|
||||
RunResult:
|
||||
- Last agent: Agent(name="test_agent", ...)
|
||||
- Final output (list):
|
||||
[Foo(bar='Hi there'), Foo(bar='Hi there 2')]
|
||||
- 2 new item(s)
|
||||
- 1 raw response(s)
|
||||
- 0 input guardrail result(s)
|
||||
- 0 output guardrail result(s)
|
||||
(See `RunResult` for more details)\
|
||||
""")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pretty_run_result_streaming_list_structured_output():
|
||||
model = FakeModel()
|
||||
model.set_next_output(
|
||||
[
|
||||
get_text_message("Test"),
|
||||
get_final_output_message(
|
||||
json.dumps(
|
||||
{
|
||||
_WRAPPER_DICT_KEY: [
|
||||
Foo(bar="Test").model_dump(),
|
||||
Foo(bar="Test 2").model_dump(),
|
||||
]
|
||||
}
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
agent = Agent(name="test_agent", model=model, output_type=list[Foo])
|
||||
result = Runner.run_streamed(agent, input="Hello")
|
||||
|
||||
async for _ in result.stream_events():
|
||||
pass
|
||||
|
||||
assert pretty_print_run_result_streaming(result) == snapshot("""\
|
||||
RunResultStreaming:
|
||||
- Current agent: Agent(name="test_agent", ...)
|
||||
- Current turn: 1
|
||||
- Max turns: 10
|
||||
- Is complete: True
|
||||
- Final output (list):
|
||||
[Foo(bar='Test'), Foo(bar='Test 2')]
|
||||
- 2 new item(s)
|
||||
- 1 raw response(s)
|
||||
- 0 input guardrail result(s)
|
||||
- 0 output guardrail result(s)
|
||||
(See `RunResultStreaming` for more details)\
|
||||
""")
|
||||
36
uv.lock
36
uv.lock
|
|
@ -1,4 +1,5 @@
|
|||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.9"
|
||||
|
||||
[[package]]
|
||||
|
|
@ -25,6 +26,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asttokens"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "babel"
|
||||
version = "2.17.0"
|
||||
|
|
@ -239,6 +249,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "executing"
|
||||
version = "2.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ghp-import"
|
||||
version = "2.1.0"
|
||||
|
|
@ -391,6 +410,21 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inline-snapshot"
|
||||
version = "0.20.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "asttokens" },
|
||||
{ name = "executing" },
|
||||
{ name = "rich" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b0/41/9bd2ecd10ef789e8aff6fb68dcc7677dc31b33b2d27c306c0d40fc982fbc/inline_snapshot-0.20.7.tar.gz", hash = "sha256:d55bbb6254d0727dc304729ca7998cde1c1e984c4bf50281514aa9d727a56cf2", size = 92643 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/01/8f/1bf23da63ad1a0b14ca2d9114700123ef76732e375548f4f9ca94052817e/inline_snapshot-0.20.7-py3-none-any.whl", hash = "sha256:2df6dd8710d1f0def2c1f9d6c25fd03d7beba01f3addf52fc370343d9ee9959f", size = 48108 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jinja2"
|
||||
version = "3.1.6"
|
||||
|
|
@ -796,6 +830,7 @@ dependencies = [
|
|||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "coverage" },
|
||||
{ name = "inline-snapshot" },
|
||||
{ name = "mkdocs" },
|
||||
{ name = "mkdocs-material" },
|
||||
{ name = "mkdocstrings", extra = ["python"] },
|
||||
|
|
@ -821,6 +856,7 @@ requires-dist = [
|
|||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "coverage", specifier = ">=7.6.12" },
|
||||
{ name = "inline-snapshot", specifier = ">=0.20.7" },
|
||||
{ name = "mkdocs", specifier = ">=1.6.0" },
|
||||
{ name = "mkdocs-material", specifier = ">=9.6.0" },
|
||||
{ name = "mkdocstrings", extras = ["python"], specifier = ">=0.28.0" },
|
||||
|
|
|
|||
Loading…
Reference in a new issue