Added RunErrorDetails object for MaxTurnsExceeded exception (#743)
### Summary Introduced the `RunErrorDetails` object to get partial results from a run interrupted by `MaxTurnsExceeded` exception. In this proposal the `RunErrorDetails` object contains all the fields from `RunResult` with `final_output` set to `None` and `output_guardrail_results` set to an empty list. We can decide to return less information. @rm-openai At the moment the exception doesn't return the `RunErrorDetails` object for the streaming mode. Do you have any suggestions on how to deal with it? In the `_check_errors` function of `agents/result.py` file. ### Test plan I have not implemented any tests currently, but if needed I can implement a basic test to retrieve partial data. ### Issue number This PR is an attempt to solve issue #719 ### Checks - [✅ ] I've added new tests (if relevant) - [ ] I've added/updated the relevant documentation - [ ✅] I've run `make lint` and `make format` - [ ✅] I've made sure tests pass
This commit is contained in:
parent
d46e2ec35b
commit
71968625cc
7 changed files with 167 additions and 23 deletions
|
|
@ -14,6 +14,7 @@ from .exceptions import (
|
|||
MaxTurnsExceeded,
|
||||
ModelBehaviorError,
|
||||
OutputGuardrailTripwireTriggered,
|
||||
RunErrorDetails,
|
||||
UserError,
|
||||
)
|
||||
from .guardrail import (
|
||||
|
|
@ -204,6 +205,7 @@ __all__ = [
|
|||
"AgentHooks",
|
||||
"RunContextWrapper",
|
||||
"TContext",
|
||||
"RunErrorDetails",
|
||||
"RunResult",
|
||||
"RunResultStreaming",
|
||||
"RunConfig",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,39 @@
|
|||
from typing import TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agent import Agent
|
||||
from .guardrail import InputGuardrailResult, OutputGuardrailResult
|
||||
from .items import ModelResponse, RunItem, TResponseInputItem
|
||||
from .run_context import RunContextWrapper
|
||||
|
||||
from .util._pretty_print import pretty_print_run_error_details
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunErrorDetails:
|
||||
"""Data collected from an agent run when an exception occurs."""
|
||||
input: str | list[TResponseInputItem]
|
||||
new_items: list[RunItem]
|
||||
raw_responses: list[ModelResponse]
|
||||
last_agent: Agent[Any]
|
||||
context_wrapper: RunContextWrapper[Any]
|
||||
input_guardrail_results: list[InputGuardrailResult]
|
||||
output_guardrail_results: list[OutputGuardrailResult]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return pretty_print_run_error_details(self)
|
||||
|
||||
|
||||
class AgentsException(Exception):
|
||||
"""Base class for all exceptions in the Agents SDK."""
|
||||
run_data: RunErrorDetails | None
|
||||
|
||||
def __init__(self, *args: object) -> None:
|
||||
super().__init__(*args)
|
||||
self.run_data = None
|
||||
|
||||
|
||||
class MaxTurnsExceeded(AgentsException):
|
||||
|
|
@ -15,6 +43,7 @@ class MaxTurnsExceeded(AgentsException):
|
|||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelBehaviorError(AgentsException):
|
||||
|
|
@ -26,6 +55,7 @@ class ModelBehaviorError(AgentsException):
|
|||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserError(AgentsException):
|
||||
|
|
@ -35,15 +65,16 @@ class UserError(AgentsException):
|
|||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InputGuardrailTripwireTriggered(AgentsException):
|
||||
"""Exception raised when a guardrail tripwire is triggered."""
|
||||
|
||||
guardrail_result: "InputGuardrailResult"
|
||||
guardrail_result: InputGuardrailResult
|
||||
"""The result data of the guardrail that was triggered."""
|
||||
|
||||
def __init__(self, guardrail_result: "InputGuardrailResult"):
|
||||
def __init__(self, guardrail_result: InputGuardrailResult):
|
||||
self.guardrail_result = guardrail_result
|
||||
super().__init__(
|
||||
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
|
||||
|
|
@ -53,10 +84,10 @@ class InputGuardrailTripwireTriggered(AgentsException):
|
|||
class OutputGuardrailTripwireTriggered(AgentsException):
|
||||
"""Exception raised when a guardrail tripwire is triggered."""
|
||||
|
||||
guardrail_result: "OutputGuardrailResult"
|
||||
guardrail_result: OutputGuardrailResult
|
||||
"""The result data of the guardrail that was triggered."""
|
||||
|
||||
def __init__(self, guardrail_result: "OutputGuardrailResult"):
|
||||
def __init__(self, guardrail_result: OutputGuardrailResult):
|
||||
self.guardrail_result = guardrail_result
|
||||
super().__init__(
|
||||
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
|
||||
|
|
|
|||
|
|
@ -11,14 +11,22 @@ from typing_extensions import TypeVar
|
|||
from ._run_impl import QueueCompleteSentinel
|
||||
from .agent import Agent
|
||||
from .agent_output import AgentOutputSchemaBase
|
||||
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
|
||||
from .exceptions import (
|
||||
AgentsException,
|
||||
InputGuardrailTripwireTriggered,
|
||||
MaxTurnsExceeded,
|
||||
RunErrorDetails,
|
||||
)
|
||||
from .guardrail import InputGuardrailResult, OutputGuardrailResult
|
||||
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
||||
from .logger import logger
|
||||
from .run_context import RunContextWrapper
|
||||
from .stream_events import StreamEvent
|
||||
from .tracing import Trace
|
||||
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
|
||||
from .util._pretty_print import (
|
||||
pretty_print_result,
|
||||
pretty_print_run_result_streaming,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._run_impl import QueueCompleteSentinel
|
||||
|
|
@ -206,31 +214,53 @@ class RunResultStreaming(RunResultBase):
|
|||
if self._stored_exception:
|
||||
raise self._stored_exception
|
||||
|
||||
def _create_error_details(self) -> RunErrorDetails:
|
||||
"""Return a `RunErrorDetails` object considering the current attributes of the class."""
|
||||
return RunErrorDetails(
|
||||
input=self.input,
|
||||
new_items=self.new_items,
|
||||
raw_responses=self.raw_responses,
|
||||
last_agent=self.current_agent,
|
||||
context_wrapper=self.context_wrapper,
|
||||
input_guardrail_results=self.input_guardrail_results,
|
||||
output_guardrail_results=self.output_guardrail_results,
|
||||
)
|
||||
|
||||
def _check_errors(self):
|
||||
if self.current_turn > self.max_turns:
|
||||
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
|
||||
max_turns_exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
|
||||
max_turns_exc.run_data = self._create_error_details()
|
||||
self._stored_exception = max_turns_exc
|
||||
|
||||
# Fetch all the completed guardrail results from the queue and raise if needed
|
||||
while not self._input_guardrail_queue.empty():
|
||||
guardrail_result = self._input_guardrail_queue.get_nowait()
|
||||
if guardrail_result.output.tripwire_triggered:
|
||||
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
|
||||
tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result)
|
||||
tripwire_exc.run_data = self._create_error_details()
|
||||
self._stored_exception = tripwire_exc
|
||||
|
||||
# Check the tasks for any exceptions
|
||||
if self._run_impl_task and self._run_impl_task.done():
|
||||
exc = self._run_impl_task.exception()
|
||||
if exc and isinstance(exc, Exception):
|
||||
self._stored_exception = exc
|
||||
run_impl_exc = self._run_impl_task.exception()
|
||||
if run_impl_exc and isinstance(run_impl_exc, Exception):
|
||||
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
|
||||
run_impl_exc.run_data = self._create_error_details()
|
||||
self._stored_exception = run_impl_exc
|
||||
|
||||
if self._input_guardrails_task and self._input_guardrails_task.done():
|
||||
exc = self._input_guardrails_task.exception()
|
||||
if exc and isinstance(exc, Exception):
|
||||
self._stored_exception = exc
|
||||
in_guard_exc = self._input_guardrails_task.exception()
|
||||
if in_guard_exc and isinstance(in_guard_exc, Exception):
|
||||
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
|
||||
in_guard_exc.run_data = self._create_error_details()
|
||||
self._stored_exception = in_guard_exc
|
||||
|
||||
if self._output_guardrails_task and self._output_guardrails_task.done():
|
||||
exc = self._output_guardrails_task.exception()
|
||||
if exc and isinstance(exc, Exception):
|
||||
self._stored_exception = exc
|
||||
out_guard_exc = self._output_guardrails_task.exception()
|
||||
if out_guard_exc and isinstance(out_guard_exc, Exception):
|
||||
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
|
||||
out_guard_exc.run_data = self._create_error_details()
|
||||
self._stored_exception = out_guard_exc
|
||||
|
||||
def _cleanup_tasks(self):
|
||||
if self._run_impl_task and not self._run_impl_task.done():
|
||||
|
|
@ -244,3 +274,4 @@ class RunResultStreaming(RunResultBase):
|
|||
|
||||
def __str__(self) -> str:
|
||||
return pretty_print_run_result_streaming(self)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -26,6 +27,7 @@ from .exceptions import (
|
|||
MaxTurnsExceeded,
|
||||
ModelBehaviorError,
|
||||
OutputGuardrailTripwireTriggered,
|
||||
RunErrorDetails,
|
||||
)
|
||||
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
|
||||
from .handoffs import Handoff, HandoffInputFilter, handoff
|
||||
|
|
@ -208,7 +210,9 @@ class Runner:
|
|||
data={"max_turns": max_turns},
|
||||
),
|
||||
)
|
||||
raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
|
||||
raise MaxTurnsExceeded(
|
||||
f"Max turns ({max_turns}) exceeded"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Running agent {current_agent.name} (turn {current_turn})",
|
||||
|
|
@ -283,6 +287,17 @@ class Runner:
|
|||
raise AgentsException(
|
||||
f"Unknown next step type: {type(turn_result.next_step)}"
|
||||
)
|
||||
except AgentsException as exc:
|
||||
exc.run_data = RunErrorDetails(
|
||||
input=original_input,
|
||||
new_items=generated_items,
|
||||
raw_responses=model_responses,
|
||||
last_agent=current_agent,
|
||||
context_wrapper=context_wrapper,
|
||||
input_guardrail_results=input_guardrail_results,
|
||||
output_guardrail_results=[]
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if current_span:
|
||||
current_span.finish(reset_current=True)
|
||||
|
|
@ -609,6 +624,19 @@ class Runner:
|
|||
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
|
||||
elif isinstance(turn_result.next_step, NextStepRunAgain):
|
||||
pass
|
||||
except AgentsException as exc:
|
||||
streamed_result.is_complete = True
|
||||
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
|
||||
exc.run_data = RunErrorDetails(
|
||||
input=streamed_result.input,
|
||||
new_items=streamed_result.new_items,
|
||||
raw_responses=streamed_result.raw_responses,
|
||||
last_agent=current_agent,
|
||||
context_wrapper=context_wrapper,
|
||||
input_guardrail_results=streamed_result.input_guardrail_results,
|
||||
output_guardrail_results=streamed_result.output_guardrail_results,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
if current_span:
|
||||
_error_tracing.attach_error_to_span(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
|
|||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..exceptions import RunErrorDetails
|
||||
from ..result import RunResult, RunResultBase, RunResultStreaming
|
||||
|
||||
|
||||
|
|
@ -38,6 +39,17 @@ def pretty_print_result(result: "RunResult") -> str:
|
|||
return output
|
||||
|
||||
|
||||
def pretty_print_run_error_details(result: "RunErrorDetails") -> str:
|
||||
output = "RunErrorDetails:"
|
||||
output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)'
|
||||
output += f"\n- {len(result.new_items)} new item(s)"
|
||||
output += f"\n- {len(result.raw_responses)} raw response(s)"
|
||||
output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)"
|
||||
output += "\n(See `RunErrorDetails` for more details)"
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str:
|
||||
output = "RunResultStreaming:"
|
||||
output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)'
|
||||
|
|
|
|||
44
tests/test_run_error_details.py
Normal file
44
tests/test_run_error_details.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner
|
||||
|
||||
from .fake_model import FakeModel
|
||||
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_error_includes_data():
|
||||
model = FakeModel()
|
||||
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
|
||||
model.add_multiple_turn_outputs([
|
||||
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
|
||||
[get_text_message("done")],
|
||||
])
|
||||
with pytest.raises(MaxTurnsExceeded) as exc:
|
||||
await Runner.run(agent, input="hello", max_turns=1)
|
||||
data = exc.value.run_data
|
||||
assert isinstance(data, RunErrorDetails)
|
||||
assert data.last_agent == agent
|
||||
assert len(data.raw_responses) == 1
|
||||
assert len(data.new_items) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streamed_run_error_includes_data():
|
||||
model = FakeModel()
|
||||
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
|
||||
model.add_multiple_turn_outputs([
|
||||
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
|
||||
[get_text_message("done")],
|
||||
])
|
||||
result = Runner.run_streamed(agent, input="hello", max_turns=1)
|
||||
with pytest.raises(MaxTurnsExceeded) as exc:
|
||||
async for _ in result.stream_events():
|
||||
pass
|
||||
data = exc.value.run_data
|
||||
assert isinstance(data, RunErrorDetails)
|
||||
assert data.last_agent == agent
|
||||
assert len(data.raw_responses) == 1
|
||||
assert len(data.new_items) > 0
|
||||
|
|
@ -168,10 +168,6 @@ async def test_tool_call_error():
|
|||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"error": {
|
||||
"message": "Error in agent run",
|
||||
"data": {"error": "Invalid JSON input for tool foo: bad_json"},
|
||||
},
|
||||
"data": {
|
||||
"name": "test_agent",
|
||||
"handoffs": [],
|
||||
|
|
|
|||
Loading…
Reference in a new issue