Add usage to context in streaming (#595)

This commit is contained in:
Rohan Mehta 2025-04-24 18:20:35 -04:00 committed by GitHub
parent 3bbc7c48cb
commit 8fd7773a5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 28 additions and 13 deletions

View file

@@ -15,6 +15,7 @@ from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
from .run_context import RunContextWrapper
from .stream_events import StreamEvent
from .tracing import Trace
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
@@ -50,6 +51,9 @@ class RunResultBase(abc.ABC):
output_guardrail_results: list[OutputGuardrailResult]
"""Guardrail results for the final output of the agent."""
context_wrapper: RunContextWrapper[Any]
"""The context wrapper for the agent run."""
@property
@abc.abstractmethod
def last_agent(self) -> Agent[Any]:
@@ -75,9 +79,7 @@ class RunResultBase(abc.ABC):
def to_input_list(self) -> list[TResponseInputItem]:
"""Creates a new input list, merging the original input with all the new items generated."""
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
self.input
)
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
new_items = [item.to_input_item() for item in self.new_items]
return original_items + new_items
@@ -206,17 +208,13 @@ class RunResultStreaming(RunResultBase):
def _check_errors(self):
if self.current_turn > self.max_turns:
self._stored_exception = MaxTurnsExceeded(
f"Max turns ({self.max_turns}) exceeded"
)
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
# Fetch all the completed guardrail results from the queue and raise if needed
while not self._input_guardrail_queue.empty():
guardrail_result = self._input_guardrail_queue.get_nowait()
if guardrail_result.output.tripwire_triggered:
self._stored_exception = InputGuardrailTripwireTriggered(
guardrail_result
)
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
# Check the tasks for any exceptions
if self._run_impl_task and self._run_impl_task.done():

View file

@@ -270,6 +270,7 @@ class Runner:
_last_agent=current_agent,
input_guardrail_results=input_guardrail_results,
output_guardrail_results=output_guardrail_results,
context_wrapper=context_wrapper,
)
elif isinstance(turn_result.next_step, NextStepHandoff):
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
@@ -423,6 +424,7 @@ class Runner:
output_guardrail_results=[],
_current_agent_output_schema=output_schema,
trace=new_trace,
context_wrapper=context_wrapper,
)
# Kick off the actual agent loop in the background and return the streamed result object.
@@ -696,6 +698,7 @@ class Runner:
usage=usage,
response_id=event.response.id,
)
context_wrapper.usage.add(usage)
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

View file

@@ -3,7 +3,8 @@ from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
from openai.types.responses import Response, ResponseCompletedEvent
from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from agents.agent_output import AgentOutputSchemaBase
from agents.handoffs import Handoff
@@ -33,6 +34,10 @@ class FakeModel(Model):
)
self.tracing_enabled = tracing_enabled
self.last_turn_args: dict[str, Any] = {}
self.hardcoded_usage: Usage | None = None
def set_hardcoded_usage(self, usage: Usage):
self.hardcoded_usage = usage
def set_next_output(self, output: list[TResponseOutputItem] | Exception):
self.turn_outputs.append(output)
@@ -83,7 +88,7 @@ class FakeModel(Model):
return ModelResponse(
output=output,
usage=Usage(),
usage=self.hardcoded_usage or Usage(),
response_id=None,
)
@@ -123,13 +128,14 @@ class FakeModel(Model):
yield ResponseCompletedEvent(
type="response.completed",
response=get_response_obj(output),
response=get_response_obj(output, usage=self.hardcoded_usage),
)
def get_response_obj(
output: list[TResponseOutputItem],
response_id: str | None = None,
usage: Usage | None = None,
) -> Response:
return Response(
id=response_id or "123",
@@ -141,4 +147,11 @@ def get_response_obj(
tools=[],
top_p=None,
parallel_tool_calls=False,
usage=ResponseUsage(
input_tokens=usage.input_tokens if usage else 0,
output_tokens=usage.output_tokens if usage else 0,
total_tokens=usage.total_tokens if usage else 0,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
),
)

View file

@@ -3,7 +3,7 @@ from typing import Any
import pytest
from pydantic import BaseModel
from agents import Agent, RunResult
from agents import Agent, RunContextWrapper, RunResult
def create_run_result(final_output: Any) -> RunResult:
@@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult:
input_guardrail_results=[],
output_guardrail_results=[],
_last_agent=Agent(name="test"),
context_wrapper=RunContextWrapper(context=None),
)