Add usage to context in streaming (#595)
This commit is contained in:
parent
3bbc7c48cb
commit
8fd7773a5e
4 changed files with 28 additions and 13 deletions
|
|
@ -15,6 +15,7 @@ from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
|
||||||
from .guardrail import InputGuardrailResult, OutputGuardrailResult
|
from .guardrail import InputGuardrailResult, OutputGuardrailResult
|
||||||
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
||||||
from .logger import logger
|
from .logger import logger
|
||||||
|
from .run_context import RunContextWrapper
|
||||||
from .stream_events import StreamEvent
|
from .stream_events import StreamEvent
|
||||||
from .tracing import Trace
|
from .tracing import Trace
|
||||||
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
|
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
|
||||||
|
|
@ -50,6 +51,9 @@ class RunResultBase(abc.ABC):
|
||||||
output_guardrail_results: list[OutputGuardrailResult]
|
output_guardrail_results: list[OutputGuardrailResult]
|
||||||
"""Guardrail results for the final output of the agent."""
|
"""Guardrail results for the final output of the agent."""
|
||||||
|
|
||||||
|
context_wrapper: RunContextWrapper[Any]
|
||||||
|
"""The context wrapper for the agent run."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def last_agent(self) -> Agent[Any]:
|
def last_agent(self) -> Agent[Any]:
|
||||||
|
|
@ -75,9 +79,7 @@ class RunResultBase(abc.ABC):
|
||||||
|
|
||||||
def to_input_list(self) -> list[TResponseInputItem]:
|
def to_input_list(self) -> list[TResponseInputItem]:
|
||||||
"""Creates a new input list, merging the original input with all the new items generated."""
|
"""Creates a new input list, merging the original input with all the new items generated."""
|
||||||
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
|
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
|
||||||
self.input
|
|
||||||
)
|
|
||||||
new_items = [item.to_input_item() for item in self.new_items]
|
new_items = [item.to_input_item() for item in self.new_items]
|
||||||
|
|
||||||
return original_items + new_items
|
return original_items + new_items
|
||||||
|
|
@ -206,17 +208,13 @@ class RunResultStreaming(RunResultBase):
|
||||||
|
|
||||||
def _check_errors(self):
|
def _check_errors(self):
|
||||||
if self.current_turn > self.max_turns:
|
if self.current_turn > self.max_turns:
|
||||||
self._stored_exception = MaxTurnsExceeded(
|
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
|
||||||
f"Max turns ({self.max_turns}) exceeded"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fetch all the completed guardrail results from the queue and raise if needed
|
# Fetch all the completed guardrail results from the queue and raise if needed
|
||||||
while not self._input_guardrail_queue.empty():
|
while not self._input_guardrail_queue.empty():
|
||||||
guardrail_result = self._input_guardrail_queue.get_nowait()
|
guardrail_result = self._input_guardrail_queue.get_nowait()
|
||||||
if guardrail_result.output.tripwire_triggered:
|
if guardrail_result.output.tripwire_triggered:
|
||||||
self._stored_exception = InputGuardrailTripwireTriggered(
|
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
|
||||||
guardrail_result
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check the tasks for any exceptions
|
# Check the tasks for any exceptions
|
||||||
if self._run_impl_task and self._run_impl_task.done():
|
if self._run_impl_task and self._run_impl_task.done():
|
||||||
|
|
|
||||||
|
|
@ -270,6 +270,7 @@ class Runner:
|
||||||
_last_agent=current_agent,
|
_last_agent=current_agent,
|
||||||
input_guardrail_results=input_guardrail_results,
|
input_guardrail_results=input_guardrail_results,
|
||||||
output_guardrail_results=output_guardrail_results,
|
output_guardrail_results=output_guardrail_results,
|
||||||
|
context_wrapper=context_wrapper,
|
||||||
)
|
)
|
||||||
elif isinstance(turn_result.next_step, NextStepHandoff):
|
elif isinstance(turn_result.next_step, NextStepHandoff):
|
||||||
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
|
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
|
||||||
|
|
@ -423,6 +424,7 @@ class Runner:
|
||||||
output_guardrail_results=[],
|
output_guardrail_results=[],
|
||||||
_current_agent_output_schema=output_schema,
|
_current_agent_output_schema=output_schema,
|
||||||
trace=new_trace,
|
trace=new_trace,
|
||||||
|
context_wrapper=context_wrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Kick off the actual agent loop in the background and return the streamed result object.
|
# Kick off the actual agent loop in the background and return the streamed result object.
|
||||||
|
|
@ -696,6 +698,7 @@ class Runner:
|
||||||
usage=usage,
|
usage=usage,
|
||||||
response_id=event.response.id,
|
response_id=event.response.id,
|
||||||
)
|
)
|
||||||
|
context_wrapper.usage.add(usage)
|
||||||
|
|
||||||
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
|
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@ from __future__ import annotations
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from openai.types.responses import Response, ResponseCompletedEvent
|
from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage
|
||||||
|
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||||
|
|
||||||
from agents.agent_output import AgentOutputSchemaBase
|
from agents.agent_output import AgentOutputSchemaBase
|
||||||
from agents.handoffs import Handoff
|
from agents.handoffs import Handoff
|
||||||
|
|
@ -33,6 +34,10 @@ class FakeModel(Model):
|
||||||
)
|
)
|
||||||
self.tracing_enabled = tracing_enabled
|
self.tracing_enabled = tracing_enabled
|
||||||
self.last_turn_args: dict[str, Any] = {}
|
self.last_turn_args: dict[str, Any] = {}
|
||||||
|
self.hardcoded_usage: Usage | None = None
|
||||||
|
|
||||||
|
def set_hardcoded_usage(self, usage: Usage):
|
||||||
|
self.hardcoded_usage = usage
|
||||||
|
|
||||||
def set_next_output(self, output: list[TResponseOutputItem] | Exception):
|
def set_next_output(self, output: list[TResponseOutputItem] | Exception):
|
||||||
self.turn_outputs.append(output)
|
self.turn_outputs.append(output)
|
||||||
|
|
@ -83,7 +88,7 @@ class FakeModel(Model):
|
||||||
|
|
||||||
return ModelResponse(
|
return ModelResponse(
|
||||||
output=output,
|
output=output,
|
||||||
usage=Usage(),
|
usage=self.hardcoded_usage or Usage(),
|
||||||
response_id=None,
|
response_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -123,13 +128,14 @@ class FakeModel(Model):
|
||||||
|
|
||||||
yield ResponseCompletedEvent(
|
yield ResponseCompletedEvent(
|
||||||
type="response.completed",
|
type="response.completed",
|
||||||
response=get_response_obj(output),
|
response=get_response_obj(output, usage=self.hardcoded_usage),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_response_obj(
|
def get_response_obj(
|
||||||
output: list[TResponseOutputItem],
|
output: list[TResponseOutputItem],
|
||||||
response_id: str | None = None,
|
response_id: str | None = None,
|
||||||
|
usage: Usage | None = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
return Response(
|
return Response(
|
||||||
id=response_id or "123",
|
id=response_id or "123",
|
||||||
|
|
@ -141,4 +147,11 @@ def get_response_obj(
|
||||||
tools=[],
|
tools=[],
|
||||||
top_p=None,
|
top_p=None,
|
||||||
parallel_tool_calls=False,
|
parallel_tool_calls=False,
|
||||||
|
usage=ResponseUsage(
|
||||||
|
input_tokens=usage.input_tokens if usage else 0,
|
||||||
|
output_tokens=usage.output_tokens if usage else 0,
|
||||||
|
total_tokens=usage.total_tokens if usage else 0,
|
||||||
|
input_tokens_details=InputTokensDetails(cached_tokens=0),
|
||||||
|
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import Any
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from agents import Agent, RunResult
|
from agents import Agent, RunContextWrapper, RunResult
|
||||||
|
|
||||||
|
|
||||||
def create_run_result(final_output: Any) -> RunResult:
|
def create_run_result(final_output: Any) -> RunResult:
|
||||||
|
|
@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult:
|
||||||
input_guardrail_results=[],
|
input_guardrail_results=[],
|
||||||
output_guardrail_results=[],
|
output_guardrail_results=[],
|
||||||
_last_agent=Agent(name="test"),
|
_last_agent=Agent(name="test"),
|
||||||
|
context_wrapper=RunContextWrapper(context=None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue