Added RunErrorDetails object for MaxTurnsExceeded exception (#743)

### Summary

Introduced the `RunErrorDetails` object to retrieve partial results from a
run interrupted by a `MaxTurnsExceeded` exception. In this proposal, the
`RunErrorDetails` object contains all the fields of `RunResult`, with
`final_output` set to `None` and `output_guardrail_results` set to an
empty list. We could decide to return less information.

@rm-openai At the moment, the exception doesn't include the
`RunErrorDetails` object in streaming mode. Do you have any
suggestions on how to deal with this in the `_check_errors` function of
the `agents/result.py` file?

### Test plan

I have not implemented any tests yet, but if needed I can
implement a basic test that retrieves the partial data.

### Issue number

This PR is an attempt to solve issue #719.

### Checks

- [ ] I've added new tests (if relevant)
- [ ] I've added/updated the relevant documentation
- [ ] I've run `make lint` and `make format`
- [ ] I've made sure tests pass
This commit is contained in:
Daniele Morotti 2025-05-29 22:11:33 +02:00 committed by GitHub
parent d46e2ec35b
commit 71968625cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 167 additions and 23 deletions

View file

@ -14,6 +14,7 @@ from .exceptions import (
MaxTurnsExceeded,
ModelBehaviorError,
OutputGuardrailTripwireTriggered,
RunErrorDetails,
UserError,
)
from .guardrail import (
@ -204,6 +205,7 @@ __all__ = [
"AgentHooks",
"RunContextWrapper",
"TContext",
"RunErrorDetails",
"RunResult",
"RunResultStreaming",
"RunConfig",

View file

@ -1,11 +1,39 @@
from typing import TYPE_CHECKING
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from .agent import Agent
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ModelResponse, RunItem, TResponseInputItem
from .run_context import RunContextWrapper
from .util._pretty_print import pretty_print_run_error_details
@dataclass
class RunErrorDetails:
    """Data collected from an agent run when an exception occurs.

    An instance is attached to an `AgentsException` as its `run_data`
    attribute so callers can inspect what the run produced before it failed.
    """

    # The input of the run: either a plain string or a list of input items.
    input: str | list[TResponseInputItem]
    # Items generated during the run up to the point of failure.
    new_items: list[RunItem]
    # Model responses collected during the run.
    raw_responses: list[ModelResponse]
    # The agent that was current when the exception occurred.
    last_agent: Agent[Any]
    # The context wrapper used for the run.
    context_wrapper: RunContextWrapper[Any]
    # Results of the input guardrails for the run.
    input_guardrail_results: list[InputGuardrailResult]
    # Results of the output guardrails for the run.
    output_guardrail_results: list[OutputGuardrailResult]

    def __str__(self) -> str:
        """Return the human-readable summary built by the pretty-printer."""
        summary = pretty_print_run_error_details(self)
        return summary
class AgentsException(Exception):
    """Base class for all exceptions in the Agents SDK.

    Every instance starts with `run_data` set to `None`; code that catches the
    exception may later assign a `RunErrorDetails` describing the failed run.
    """

    # Partial run data attached after the fact, or None if nothing was attached.
    run_data: RunErrorDetails | None

    def __init__(self, *args: object) -> None:
        """Initialise the exception with `args` and no run data."""
        self.run_data = None
        super().__init__(*args)
class MaxTurnsExceeded(AgentsException):
@ -15,6 +43,7 @@ class MaxTurnsExceeded(AgentsException):
def __init__(self, message: str):
self.message = message
super().__init__(message)
class ModelBehaviorError(AgentsException):
@ -26,6 +55,7 @@ class ModelBehaviorError(AgentsException):
def __init__(self, message: str):
self.message = message
super().__init__(message)
class UserError(AgentsException):
@ -35,15 +65,16 @@ class UserError(AgentsException):
def __init__(self, message: str):
self.message = message
super().__init__(message)
class InputGuardrailTripwireTriggered(AgentsException):
"""Exception raised when a guardrail tripwire is triggered."""
guardrail_result: "InputGuardrailResult"
guardrail_result: InputGuardrailResult
"""The result data of the guardrail that was triggered."""
def __init__(self, guardrail_result: "InputGuardrailResult"):
def __init__(self, guardrail_result: InputGuardrailResult):
self.guardrail_result = guardrail_result
super().__init__(
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
@ -53,10 +84,10 @@ class InputGuardrailTripwireTriggered(AgentsException):
class OutputGuardrailTripwireTriggered(AgentsException):
"""Exception raised when a guardrail tripwire is triggered."""
guardrail_result: "OutputGuardrailResult"
guardrail_result: OutputGuardrailResult
"""The result data of the guardrail that was triggered."""
def __init__(self, guardrail_result: "OutputGuardrailResult"):
def __init__(self, guardrail_result: OutputGuardrailResult):
self.guardrail_result = guardrail_result
super().__init__(
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"

View file

@ -11,14 +11,22 @@ from typing_extensions import TypeVar
from ._run_impl import QueueCompleteSentinel
from .agent import Agent
from .agent_output import AgentOutputSchemaBase
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .exceptions import (
AgentsException,
InputGuardrailTripwireTriggered,
MaxTurnsExceeded,
RunErrorDetails,
)
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
from .run_context import RunContextWrapper
from .stream_events import StreamEvent
from .tracing import Trace
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
from .util._pretty_print import (
pretty_print_result,
pretty_print_run_result_streaming,
)
if TYPE_CHECKING:
from ._run_impl import QueueCompleteSentinel
@ -206,31 +214,53 @@ class RunResultStreaming(RunResultBase):
if self._stored_exception:
raise self._stored_exception
def _create_error_details(self) -> RunErrorDetails:
    """Build a `RunErrorDetails` snapshot from this object's current attributes.

    The snapshot's `last_agent` is taken from `self.current_agent`; all other
    fields are copied from the attribute of the same name.
    """
    details = RunErrorDetails(
        input=self.input,
        new_items=self.new_items,
        raw_responses=self.raw_responses,
        last_agent=self.current_agent,
        context_wrapper=self.context_wrapper,
        input_guardrail_results=self.input_guardrail_results,
        output_guardrail_results=self.output_guardrail_results,
    )
    return details
# Inspect the streaming run's state and record any detected failure in
# `self._stored_exception`. Nothing is raised here; the exception is only stored.
# NOTE(review): this span is a rendered diff without +/- markers. The direct
# `self._stored_exception = <ExcType>(...)` assignments and the `exc` variable
# blocks appear to be the pre-change lines shown next to their replacements
# (`max_turns_exc`, `tripwire_exc`, `run_impl_exc`, ...) -- confirm against the repo.
def _check_errors(self):
# Turn budget exceeded: store a MaxTurnsExceeded carrying the current run data.
if self.current_turn > self.max_turns:
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
max_turns_exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
max_turns_exc.run_data = self._create_error_details()
self._stored_exception = max_turns_exc
# Fetch all the completed guardrail results from the queue and raise if needed
# (the tripwire exception is stored in `_stored_exception`, not raised, in this method)
while not self._input_guardrail_queue.empty():
guardrail_result = self._input_guardrail_queue.get_nowait()
if guardrail_result.output.tripwire_triggered:
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result)
tripwire_exc.run_data = self._create_error_details()
self._stored_exception = tripwire_exc
# Check the tasks for any exceptions
# For each finished task, an AgentsException that has no `run_data` yet gets
# the current error details attached before the exception is stored.
if self._run_impl_task and self._run_impl_task.done():
exc = self._run_impl_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
run_impl_exc = self._run_impl_task.exception()
if run_impl_exc and isinstance(run_impl_exc, Exception):
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
run_impl_exc.run_data = self._create_error_details()
self._stored_exception = run_impl_exc
# Same handling for the input-guardrails task.
if self._input_guardrails_task and self._input_guardrails_task.done():
exc = self._input_guardrails_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
in_guard_exc = self._input_guardrails_task.exception()
if in_guard_exc and isinstance(in_guard_exc, Exception):
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
in_guard_exc.run_data = self._create_error_details()
self._stored_exception = in_guard_exc
# Same handling for the output-guardrails task.
if self._output_guardrails_task and self._output_guardrails_task.done():
exc = self._output_guardrails_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
out_guard_exc = self._output_guardrails_task.exception()
if out_guard_exc and isinstance(out_guard_exc, Exception):
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
out_guard_exc.run_data = self._create_error_details()
self._stored_exception = out_guard_exc
def _cleanup_tasks(self):
if self._run_impl_task and not self._run_impl_task.done():
@ -244,3 +274,4 @@ class RunResultStreaming(RunResultBase):
def __str__(self) -> str:
return pretty_print_run_result_streaming(self)

View file

@ -1,3 +1,4 @@
from __future__ import annotations
import asyncio
@ -26,6 +27,7 @@ from .exceptions import (
MaxTurnsExceeded,
ModelBehaviorError,
OutputGuardrailTripwireTriggered,
RunErrorDetails,
)
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
from .handoffs import Handoff, HandoffInputFilter, handoff
@ -208,7 +210,9 @@ class Runner:
data={"max_turns": max_turns},
),
)
raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
raise MaxTurnsExceeded(
f"Max turns ({max_turns}) exceeded"
)
logger.debug(
f"Running agent {current_agent.name} (turn {current_turn})",
@ -283,6 +287,17 @@ class Runner:
raise AgentsException(
f"Unknown next step type: {type(turn_result.next_step)}"
)
except AgentsException as exc:
exc.run_data = RunErrorDetails(
input=original_input,
new_items=generated_items,
raw_responses=model_responses,
last_agent=current_agent,
context_wrapper=context_wrapper,
input_guardrail_results=input_guardrail_results,
output_guardrail_results=[]
)
raise
finally:
if current_span:
current_span.finish(reset_current=True)
@ -609,6 +624,19 @@ class Runner:
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
elif isinstance(turn_result.next_step, NextStepRunAgain):
pass
except AgentsException as exc:
streamed_result.is_complete = True
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
exc.run_data = RunErrorDetails(
input=streamed_result.input,
new_items=streamed_result.new_items,
raw_responses=streamed_result.raw_responses,
last_agent=current_agent,
context_wrapper=context_wrapper,
input_guardrail_results=streamed_result.input_guardrail_results,
output_guardrail_results=streamed_result.output_guardrail_results,
)
raise
except Exception as e:
if current_span:
_error_tracing.attach_error_to_span(

View file

@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING:
from ..exceptions import RunErrorDetails
from ..result import RunResult, RunResultBase, RunResultStreaming
@ -38,6 +39,17 @@ def pretty_print_result(result: "RunResult") -> str:
return output
def pretty_print_run_error_details(result: "RunErrorDetails") -> str:
    """Return a short multi-line, human-readable summary of `result`.

    The summary names the last agent and reports how many new items, raw
    responses and input guardrail results the run data holds.
    """
    summary_lines = [
        "RunErrorDetails:",
        f'- Last agent: Agent(name="{result.last_agent.name}", ...)',
        f"- {len(result.new_items)} new item(s)",
        f"- {len(result.raw_responses)} raw response(s)",
        f"- {len(result.input_guardrail_results)} input guardrail result(s)",
        "(See `RunErrorDetails` for more details)",
    ]
    return "\n".join(summary_lines)
def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str:
output = "RunResultStreaming:"
output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)'

View file

@ -0,0 +1,44 @@
import json
import pytest
from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner
from .fake_model import FakeModel
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
@pytest.mark.asyncio
async def test_run_error_includes_data():
    """A non-streamed run that exceeds `max_turns` exposes partial data on the exception."""
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model, tools=[get_function_tool("foo", "res")])
    # Turn 1 issues a tool call, so the run needs a second turn that max_turns=1 forbids.
    first_turn = [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))]
    second_turn = [get_text_message("done")]
    fake_model.add_multiple_turn_outputs([first_turn, second_turn])

    with pytest.raises(MaxTurnsExceeded) as exc_info:
        await Runner.run(agent, input="hello", max_turns=1)

    details = exc_info.value.run_data
    assert isinstance(details, RunErrorDetails)
    assert details.last_agent == agent
    assert len(details.raw_responses) == 1
    assert len(details.new_items) > 0
@pytest.mark.asyncio
async def test_streamed_run_error_includes_data():
    """A streamed run that exceeds `max_turns` exposes partial data on the exception."""
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model, tools=[get_function_tool("foo", "res")])
    # Turn 1 issues a tool call, so the run needs a second turn that max_turns=1 forbids.
    first_turn = [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))]
    second_turn = [get_text_message("done")]
    fake_model.add_multiple_turn_outputs([first_turn, second_turn])

    streamed = Runner.run_streamed(agent, input="hello", max_turns=1)
    with pytest.raises(MaxTurnsExceeded) as exc_info:
        # Drain the event stream; the error surfaces while iterating.
        async for _event in streamed.stream_events():
            pass

    details = exc_info.value.run_data
    assert isinstance(details, RunErrorDetails)
    assert details.last_agent == agent
    assert len(details.raw_responses) == 1
    assert len(details.new_items) > 0

View file

@ -168,10 +168,6 @@ async def test_tool_call_error():
"children": [
{
"type": "agent",
"error": {
"message": "Error in agent run",
"data": {"error": "Invalid JSON input for tool foo: bad_json"},
},
"data": {
"name": "test_agent",
"handoffs": [],