From 8fd7773a5ef5121d9349edbabebc9522a2f3c4f0 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Thu, 24 Apr 2025 18:20:35 -0400 Subject: [PATCH] Add usage to context in streaming (#595) --- src/agents/result.py | 16 +++++++--------- src/agents/run.py | 3 +++ tests/fake_model.py | 19 ++++++++++++++++--- tests/test_result_cast.py | 3 ++- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/agents/result.py b/src/agents/result.py index 1f1c783..243db15 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -15,6 +15,7 @@ from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger +from .run_context import RunContextWrapper from .stream_events import StreamEvent from .tracing import Trace from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming @@ -50,6 +51,9 @@ class RunResultBase(abc.ABC): output_guardrail_results: list[OutputGuardrailResult] """Guardrail results for the final output of the agent.""" + context_wrapper: RunContextWrapper[Any] + """The context wrapper for the agent run.""" + @property @abc.abstractmethod def last_agent(self) -> Agent[Any]: @@ -75,9 +79,7 @@ class RunResultBase(abc.ABC): def to_input_list(self) -> list[TResponseInputItem]: """Creates a new input list, merging the original input with all the new items generated.""" - original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list( - self.input - ) + original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input) new_items = [item.to_input_item() for item in self.new_items] return original_items + new_items @@ -206,17 +208,13 @@ class RunResultStreaming(RunResultBase): def _check_errors(self): if self.current_turn > self.max_turns: - self._stored_exception = MaxTurnsExceeded( - f"Max turns ({self.max_turns}) exceeded" - ) + self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") # Fetch all the completed guardrail results from the queue and raise if needed while not self._input_guardrail_queue.empty(): guardrail_result = self._input_guardrail_queue.get_nowait() if guardrail_result.output.tripwire_triggered: - self._stored_exception = InputGuardrailTripwireTriggered( - guardrail_result - ) + self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result) # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): diff --git a/src/agents/run.py b/src/agents/run.py index 2af558d..849da7b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -270,6 +270,7 @@ class Runner: _last_agent=current_agent, input_guardrail_results=input_guardrail_results, output_guardrail_results=output_guardrail_results, + context_wrapper=context_wrapper, ) elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) @@ -423,6 +424,7 @@ class Runner: output_guardrail_results=[], _current_agent_output_schema=output_schema, trace=new_trace, + context_wrapper=context_wrapper, ) # Kick off the actual agent loop in the background and return the streamed result object. @@ -696,6 +698,7 @@ class Runner: usage=usage, response_id=event.response.id, ) + context_wrapper.usage.add(usage) streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) diff --git a/tests/fake_model.py b/tests/fake_model.py index da3019a..32f919e 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -3,7 +3,8 @@ from __future__ import annotations from collections.abc import AsyncIterator from typing import Any -from openai.types.responses import Response, ResponseCompletedEvent +from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents.agent_output import AgentOutputSchemaBase from agents.handoffs import Handoff @@ -33,6 +34,10 @@ class FakeModel(Model): ) self.tracing_enabled = tracing_enabled self.last_turn_args: dict[str, Any] = {} + self.hardcoded_usage: Usage | None = None + + def set_hardcoded_usage(self, usage: Usage): + self.hardcoded_usage = usage def set_next_output(self, output: list[TResponseOutputItem] | Exception): self.turn_outputs.append(output) @@ -83,7 +88,7 @@ class FakeModel(Model): return ModelResponse( output=output, - usage=Usage(), + usage=self.hardcoded_usage or Usage(), response_id=None, ) @@ -123,13 +128,14 @@ class FakeModel(Model): yield ResponseCompletedEvent( type="response.completed", - response=get_response_obj(output), + response=get_response_obj(output, usage=self.hardcoded_usage), ) def get_response_obj( output: list[TResponseOutputItem], response_id: str | None = None, + usage: Usage | None = None, ) -> Response: return Response( id=response_id or "123", @@ -141,4 +147,11 @@ def get_response_obj( tools=[], top_p=None, parallel_tool_calls=False, + usage=ResponseUsage( + input_tokens=usage.input_tokens if usage else 0, + output_tokens=usage.output_tokens if usage else 0, + total_tokens=usage.total_tokens if usage else 0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ), ) diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index ec17e32..c621e73 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -3,7 +3,7 @@ from typing import Any import pytest from pydantic import BaseModel -from agents import Agent, RunResult +from agents import Agent, RunContextWrapper, RunResult def create_run_result(final_output: Any) -> RunResult: @@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult: input_guardrail_results=[], output_guardrail_results=[], _last_agent=Agent(name="test"), + context_wrapper=RunContextWrapper(context=None), )