Add usage to context in streaming (#595)

This commit is contained in:
Rohan Mehta 2025-04-24 18:20:35 -04:00 committed by GitHub
parent 3bbc7c48cb
commit 8fd7773a5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 28 additions and 13 deletions

View file

@@ -15,6 +15,7 @@ from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .guardrail import InputGuardrailResult, OutputGuardrailResult from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger from .logger import logger
from .run_context import RunContextWrapper
from .stream_events import StreamEvent from .stream_events import StreamEvent
from .tracing import Trace from .tracing import Trace
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
@@ -50,6 +51,9 @@ class RunResultBase(abc.ABC):
output_guardrail_results: list[OutputGuardrailResult] output_guardrail_results: list[OutputGuardrailResult]
"""Guardrail results for the final output of the agent.""" """Guardrail results for the final output of the agent."""
context_wrapper: RunContextWrapper[Any]
"""The context wrapper for the agent run."""
@property @property
@abc.abstractmethod @abc.abstractmethod
def last_agent(self) -> Agent[Any]: def last_agent(self) -> Agent[Any]:
@@ -75,9 +79,7 @@ class RunResultBase(abc.ABC):
def to_input_list(self) -> list[TResponseInputItem]: def to_input_list(self) -> list[TResponseInputItem]:
"""Creates a new input list, merging the original input with all the new items generated.""" """Creates a new input list, merging the original input with all the new items generated."""
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list( original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
self.input
)
new_items = [item.to_input_item() for item in self.new_items] new_items = [item.to_input_item() for item in self.new_items]
return original_items + new_items return original_items + new_items
@@ -206,17 +208,13 @@ class RunResultStreaming(RunResultBase):
def _check_errors(self): def _check_errors(self):
if self.current_turn > self.max_turns: if self.current_turn > self.max_turns:
self._stored_exception = MaxTurnsExceeded( self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
f"Max turns ({self.max_turns}) exceeded"
)
# Fetch all the completed guardrail results from the queue and raise if needed # Fetch all the completed guardrail results from the queue and raise if needed
while not self._input_guardrail_queue.empty(): while not self._input_guardrail_queue.empty():
guardrail_result = self._input_guardrail_queue.get_nowait() guardrail_result = self._input_guardrail_queue.get_nowait()
if guardrail_result.output.tripwire_triggered: if guardrail_result.output.tripwire_triggered:
self._stored_exception = InputGuardrailTripwireTriggered( self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
guardrail_result
)
# Check the tasks for any exceptions # Check the tasks for any exceptions
if self._run_impl_task and self._run_impl_task.done(): if self._run_impl_task and self._run_impl_task.done():

View file

@@ -270,6 +270,7 @@ class Runner:
_last_agent=current_agent, _last_agent=current_agent,
input_guardrail_results=input_guardrail_results, input_guardrail_results=input_guardrail_results,
output_guardrail_results=output_guardrail_results, output_guardrail_results=output_guardrail_results,
context_wrapper=context_wrapper,
) )
elif isinstance(turn_result.next_step, NextStepHandoff): elif isinstance(turn_result.next_step, NextStepHandoff):
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
@@ -423,6 +424,7 @@ class Runner:
output_guardrail_results=[], output_guardrail_results=[],
_current_agent_output_schema=output_schema, _current_agent_output_schema=output_schema,
trace=new_trace, trace=new_trace,
context_wrapper=context_wrapper,
) )
# Kick off the actual agent loop in the background and return the streamed result object. # Kick off the actual agent loop in the background and return the streamed result object.
@@ -696,6 +698,7 @@ class Runner:
usage=usage, usage=usage,
response_id=event.response.id, response_id=event.response.id,
) )
context_wrapper.usage.add(usage)
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

View file

@@ -3,7 +3,8 @@ from __future__ import annotations
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any from typing import Any
from openai.types.responses import Response, ResponseCompletedEvent from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from agents.agent_output import AgentOutputSchemaBase from agents.agent_output import AgentOutputSchemaBase
from agents.handoffs import Handoff from agents.handoffs import Handoff
@@ -33,6 +34,10 @@ class FakeModel(Model):
) )
self.tracing_enabled = tracing_enabled self.tracing_enabled = tracing_enabled
self.last_turn_args: dict[str, Any] = {} self.last_turn_args: dict[str, Any] = {}
self.hardcoded_usage: Usage | None = None
def set_hardcoded_usage(self, usage: Usage):
self.hardcoded_usage = usage
def set_next_output(self, output: list[TResponseOutputItem] | Exception): def set_next_output(self, output: list[TResponseOutputItem] | Exception):
self.turn_outputs.append(output) self.turn_outputs.append(output)
@@ -83,7 +88,7 @@ class FakeModel(Model):
return ModelResponse( return ModelResponse(
output=output, output=output,
usage=Usage(), usage=self.hardcoded_usage or Usage(),
response_id=None, response_id=None,
) )
@@ -123,13 +128,14 @@ class FakeModel(Model):
yield ResponseCompletedEvent( yield ResponseCompletedEvent(
type="response.completed", type="response.completed",
response=get_response_obj(output), response=get_response_obj(output, usage=self.hardcoded_usage),
) )
def get_response_obj( def get_response_obj(
output: list[TResponseOutputItem], output: list[TResponseOutputItem],
response_id: str | None = None, response_id: str | None = None,
usage: Usage | None = None,
) -> Response: ) -> Response:
return Response( return Response(
id=response_id or "123", id=response_id or "123",
@@ -141,4 +147,11 @@ def get_response_obj(
tools=[], tools=[],
top_p=None, top_p=None,
parallel_tool_calls=False, parallel_tool_calls=False,
usage=ResponseUsage(
input_tokens=usage.input_tokens if usage else 0,
output_tokens=usage.output_tokens if usage else 0,
total_tokens=usage.total_tokens if usage else 0,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
),
) )

View file

@@ -3,7 +3,7 @@ from typing import Any
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from agents import Agent, RunResult from agents import Agent, RunContextWrapper, RunResult
def create_run_result(final_output: Any) -> RunResult: def create_run_result(final_output: Any) -> RunResult:
@@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult:
input_guardrail_results=[], input_guardrail_results=[],
output_guardrail_results=[], output_guardrail_results=[],
_last_agent=Agent(name="test"), _last_agent=Agent(name="test"),
context_wrapper=RunContextWrapper(context=None),
) )