openai-agents-python/src/agents/result.py
2025-03-11 09:42:28 -07:00

220 lines
8.1 KiB
Python

from __future__ import annotations
import abc
import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast
from typing_extensions import TypeVar
from ._run_impl import QueueCompleteSentinel
from .agent import Agent
from .agent_output import AgentOutputSchema
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
from .stream_events import StreamEvent
from .tracing import Trace
if TYPE_CHECKING:
from ._run_impl import QueueCompleteSentinel
from .agent import Agent
# Type variable for ``RunResultBase.final_output_as``: its return type follows the ``cls``
# argument passed in.
T = TypeVar("T")
@dataclass
class RunResultBase(abc.ABC):
    """Fields and helpers shared by ``RunResult`` and ``RunResultStreaming``."""

    # The original input items, i.e. the items before run() was called. Handoff input filters
    # may have mutated it, so it can differ from what the caller originally passed in.
    input: str | list[TResponseInputItem]

    # Items generated during the agent run: new messages, tool calls and their outputs, etc.
    new_items: list[RunItem]

    # The raw LLM responses generated by the model during the agent run.
    raw_responses: list[ModelResponse]

    # The output of the last agent.
    final_output: Any

    # Guardrail results for the input messages.
    input_guardrail_results: list[InputGuardrailResult]

    # Guardrail results for the final output of the agent.
    output_guardrail_results: list[OutputGuardrailResult]

    @property
    @abc.abstractmethod
    def last_agent(self) -> Agent[Any]:
        """The agent that ran last. Each concrete result class supplies its own implementation."""

    def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -> T:
        """Return ``final_output`` typed as ``cls``.

        By default this is only a cast for the type checker; no runtime check is made.

        Args:
            cls: The type to cast the final output to.
            raise_if_incorrect_type: When True, check with ``isinstance`` that the final output
                really is an instance of ``cls``.

        Returns:
            The final output, unchanged, typed as ``cls``.

        Raises:
            TypeError: If ``raise_if_incorrect_type`` is True and the final output is not an
                instance of ``cls``.
        """
        if raise_if_incorrect_type:
            if not isinstance(self.final_output, cls):
                raise TypeError(f"Final output is not of type {cls.__name__}")
        return cast(T, self.final_output)

    def to_input_list(self) -> list[TResponseInputItem]:
        """Build a new input list: the original input followed by every newly generated item,
        each converted back into input-item form."""
        prefix = ItemHelpers.input_to_new_input_list(self.input)
        return [*prefix, *(run_item.to_input_item() for run_item in self.new_items)]
@dataclass
class RunResult(RunResultBase):
    """Result of an agent run whose last agent is stored as a plain field.

    The streaming variant is different: its last agent updates as the run progresses.
    """

    # Backing storage for the ``last_agent`` property.
    _last_agent: Agent[Any]

    @property
    def last_agent(self) -> Agent[Any]:
        """The agent that ran last, as recorded when this result was built."""
        agent = self._last_agent
        return agent
@dataclass
class RunResultStreaming(RunResultBase):
    """The result of an agent run in streaming mode. Use the ``stream_events`` method to
    receive semantic events as they are generated.

    Streaming raises:
    - MaxTurnsExceeded if the agent exceeds the ``max_turns`` limit.
    - InputGuardrailTripwireTriggered if an input guardrail's tripwire is triggered.
    - Any ``Exception`` raised by the background run task or the guardrail tasks.
    """

    # The agent that is currently running.
    current_agent: Agent[Any]

    # The current turn number.
    current_turn: int

    # The maximum number of turns the agent can run for.
    max_turns: int

    # The final output of the agent. This is None until the agent has finished running.
    final_output: Any

    _current_agent_output_schema: AgentOutputSchema | None = field(repr=False)

    _trace: Trace | None = field(repr=False)

    # Whether the agent has finished running.
    is_complete: bool = False

    # Queues that the background run loop writes to.
    _event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
        default_factory=asyncio.Queue, repr=False
    )
    _input_guardrail_queue: asyncio.Queue[InputGuardrailResult] = field(
        default_factory=asyncio.Queue, repr=False
    )

    # The asyncio tasks we are waiting on. They are checked for failures in `_check_errors`
    # and cancelled in `_cleanup_tasks`.
    _run_impl_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)

    # The error that ends the stream. `stream_events` re-raises it after cleanup.
    _stored_exception: Exception | None = field(default=None, repr=False)

    @property
    def last_agent(self) -> Agent[Any]:
        """The last agent that was run. Updates as the agent run progresses, so the true last agent
        is only available after the agent run is complete.
        """
        return self.current_agent

    async def stream_events(self) -> AsyncIterator[StreamEvent]:
        """Stream deltas for new items as they are generated. We're using the types from the
        OpenAI Responses API, so these are semantic events: each event has a `type` field that
        describes the type of the event, along with the data for that event.

        When the loop ends, the trace (if any) is finished and the background tasks are
        cancelled.

        Raises:
            MaxTurnsExceeded: If the agent exceeds the max_turns limit.
            InputGuardrailTripwireTriggered: If an input guardrail's tripwire is triggered.
            Exception: Any other exception raised by the background run or guardrail tasks.
        """
        while True:
            self._check_errors()
            if self._stored_exception is not None:
                logger.debug("Breaking due to stored exception")
                self.is_complete = True
                break

            if self.is_complete and self._event_queue.empty():
                break

            try:
                item = await self._event_queue.get()
            except asyncio.CancelledError:
                # NOTE(review): cancellation of the consuming task is swallowed here. The
                # generator cleans up and finishes normally, so an enclosing asyncio.timeout()
                # or TaskGroup never sees CancelledError. Confirm this is intended.
                break

            if isinstance(item, QueueCompleteSentinel):
                self._event_queue.task_done()
                # Check for errors, in case the queue was completed due to an exception.
                self._check_errors()
                break

            yield item
            self._event_queue.task_done()

        if self._trace:
            self._trace.finish(reset_current=True)

        self._cleanup_tasks()

        if self._stored_exception is not None:
            raise self._stored_exception

    def _background_tasks(self) -> tuple[asyncio.Task[Any] | None, ...]:
        """Return the background tasks in a fixed order: run loop, input guardrails, output
        guardrails. Entries may be None when a task was never started."""
        return (self._run_impl_task, self._input_guardrails_task, self._output_guardrails_task)

    @staticmethod
    def _task_exception(task: asyncio.Task[Any] | None) -> Exception | None:
        """Return the ``Exception`` a finished task raised, or None.

        None is returned when the task is missing, still running, cancelled, finished cleanly,
        or raised a ``BaseException`` that is not an ``Exception``. Cancelled tasks are skipped
        because ``Task.exception()`` raises ``CancelledError`` for them instead of returning.
        """
        if task is None or not task.done() or task.cancelled():
            return None
        exc = task.exception()
        return exc if isinstance(exc, Exception) else None

    def _check_errors(self) -> None:
        """Record, without raising, any error that should end the stream.

        Later checks overwrite earlier ones. The order is: max turns, then input guardrail
        tripwires, then failures of the run, input-guardrail and output-guardrail tasks.
        """
        if self.current_turn > self.max_turns:
            self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")

        # Fetch all the completed guardrail results from the queue and record any tripwire.
        while not self._input_guardrail_queue.empty():
            guardrail_result = self._input_guardrail_queue.get_nowait()
            if guardrail_result.output.tripwire_triggered:
                self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)

        # Check the tasks for any exceptions.
        for task in self._background_tasks():
            exc = self._task_exception(task)
            if exc is not None:
                self._stored_exception = exc

    def _cleanup_tasks(self) -> None:
        """Cancel every background task that is still running, each exactly once."""
        for task in self._background_tasks():
            if task is not None and not task.done():
                task.cancel()