Add usage to context in streaming (#595)
This commit is contained in:
parent 3bbc7c48cb
commit 8fd7773a5e
4 changed files with 28 additions and 13 deletions
|
|
@@ -15,6 +15,7 @@ from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
|
|||
from .guardrail import InputGuardrailResult, OutputGuardrailResult
|
||||
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
||||
from .logger import logger
|
||||
from .run_context import RunContextWrapper
|
||||
from .stream_events import StreamEvent
|
||||
from .tracing import Trace
|
||||
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
|
||||
|
|
@@ -50,6 +51,9 @@ class RunResultBase(abc.ABC):
|
|||
output_guardrail_results: list[OutputGuardrailResult]
|
||||
"""Guardrail results for the final output of the agent."""
|
||||
|
||||
context_wrapper: RunContextWrapper[Any]
|
||||
"""The context wrapper for the agent run."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def last_agent(self) -> Agent[Any]:
|
||||
|
|
@@ -75,9 +79,7 @@ class RunResultBase(abc.ABC):
|
|||
|
||||
def to_input_list(self) -> list[TResponseInputItem]:
|
||||
"""Creates a new input list, merging the original input with all the new items generated."""
|
||||
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
|
||||
self.input
|
||||
)
|
||||
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
|
||||
new_items = [item.to_input_item() for item in self.new_items]
|
||||
|
||||
return original_items + new_items
|
||||
|
|
@@ -206,17 +208,13 @@ class RunResultStreaming(RunResultBase):
|
|||
|
||||
def _check_errors(self):
|
||||
if self.current_turn > self.max_turns:
|
||||
self._stored_exception = MaxTurnsExceeded(
|
||||
f"Max turns ({self.max_turns}) exceeded"
|
||||
)
|
||||
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
|
||||
|
||||
# Fetch all the completed guardrail results from the queue and raise if needed
|
||||
while not self._input_guardrail_queue.empty():
|
||||
guardrail_result = self._input_guardrail_queue.get_nowait()
|
||||
if guardrail_result.output.tripwire_triggered:
|
||||
self._stored_exception = InputGuardrailTripwireTriggered(
|
||||
guardrail_result
|
||||
)
|
||||
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
|
||||
|
||||
# Check the tasks for any exceptions
|
||||
if self._run_impl_task and self._run_impl_task.done():
|
||||
|
|
|
|||
|
|
@@ -270,6 +270,7 @@ class Runner:
|
|||
_last_agent=current_agent,
|
||||
input_guardrail_results=input_guardrail_results,
|
||||
output_guardrail_results=output_guardrail_results,
|
||||
context_wrapper=context_wrapper,
|
||||
)
|
||||
elif isinstance(turn_result.next_step, NextStepHandoff):
|
||||
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
|
||||
|
|
@@ -423,6 +424,7 @@ class Runner:
|
|||
output_guardrail_results=[],
|
||||
_current_agent_output_schema=output_schema,
|
||||
trace=new_trace,
|
||||
context_wrapper=context_wrapper,
|
||||
)
|
||||
|
||||
# Kick off the actual agent loop in the background and return the streamed result object.
|
||||
|
|
@@ -696,6 +698,7 @@ class Runner:
|
|||
usage=usage,
|
||||
response_id=event.response.id,
|
||||
)
|
||||
context_wrapper.usage.add(usage)
|
||||
|
||||
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
|
||||
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai.types.responses import Response, ResponseCompletedEvent
|
||||
from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage
|
||||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||
|
||||
from agents.agent_output import AgentOutputSchemaBase
|
||||
from agents.handoffs import Handoff
|
||||
|
|
@@ -33,6 +34,10 @@ class FakeModel(Model):
|
|||
)
|
||||
self.tracing_enabled = tracing_enabled
|
||||
self.last_turn_args: dict[str, Any] = {}
|
||||
self.hardcoded_usage: Usage | None = None
|
||||
|
||||
def set_hardcoded_usage(self, usage: Usage):
|
||||
self.hardcoded_usage = usage
|
||||
|
||||
def set_next_output(self, output: list[TResponseOutputItem] | Exception):
|
||||
self.turn_outputs.append(output)
|
||||
|
|
@@ -83,7 +88,7 @@ class FakeModel(Model):
|
|||
|
||||
return ModelResponse(
|
||||
output=output,
|
||||
usage=Usage(),
|
||||
usage=self.hardcoded_usage or Usage(),
|
||||
response_id=None,
|
||||
)
|
||||
|
||||
|
|
@@ -123,13 +128,14 @@ class FakeModel(Model):
|
|||
|
||||
yield ResponseCompletedEvent(
|
||||
type="response.completed",
|
||||
response=get_response_obj(output),
|
||||
response=get_response_obj(output, usage=self.hardcoded_usage),
|
||||
)
|
||||
|
||||
|
||||
def get_response_obj(
|
||||
output: list[TResponseOutputItem],
|
||||
response_id: str | None = None,
|
||||
usage: Usage | None = None,
|
||||
) -> Response:
|
||||
return Response(
|
||||
id=response_id or "123",
|
||||
|
|
@@ -141,4 +147,11 @@ def get_response_obj(
|
|||
tools=[],
|
||||
top_p=None,
|
||||
parallel_tool_calls=False,
|
||||
usage=ResponseUsage(
|
||||
input_tokens=usage.input_tokens if usage else 0,
|
||||
output_tokens=usage.output_tokens if usage else 0,
|
||||
total_tokens=usage.total_tokens if usage else 0,
|
||||
input_tokens_details=InputTokensDetails(cached_tokens=0),
|
||||
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,7 @@ from typing import Any
|
|||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent, RunResult
|
||||
from agents import Agent, RunContextWrapper, RunResult
|
||||
|
||||
|
||||
def create_run_result(final_output: Any) -> RunResult:
|
||||
|
|
@@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult:
|
|||
input_guardrail_results=[],
|
||||
output_guardrail_results=[],
|
||||
_last_agent=Agent(name="test"),
|
||||
context_wrapper=RunContextWrapper(context=None),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue