Dev/add usage details to Usage class (#726)
PR to enhance the `Usage` object and related logic, to support more granular token accounting, matching the details available in the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) . Specifically, it: - Adds `input_tokens_details` and `output_tokens_details` fields to the `Usage` dataclass, storing detailed token breakdowns (e.g., `cached_tokens`, `reasoning_tokens`). - Flows this change through - Updates and extends tests to match - Adds a test for the Usage.add method ### Motivation - Aligns the SDK’s usage with the latest OpenAI responses API Usage object - Supports downstream use cases that require fine-grained token usage data (e.g., billing, analytics, optimization) requested by startups --------- Co-authored-by: Wulfie Bain <wulfie@openai.com>
This commit is contained in:
parent
428c9a65bf
commit
466b44df18
11 changed files with 178 additions and 15 deletions
|
|
@ -6,6 +6,7 @@ from collections.abc import AsyncIterator
|
|||
from typing import Any, Literal, cast, overload
|
||||
|
||||
import litellm.types
|
||||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||
|
||||
from agents.exceptions import ModelBehaviorError
|
||||
|
||||
|
|
@ -107,6 +108,16 @@ class LitellmModel(Model):
|
|||
input_tokens=response_usage.prompt_tokens,
|
||||
output_tokens=response_usage.completion_tokens,
|
||||
total_tokens=response_usage.total_tokens,
|
||||
input_tokens_details=InputTokensDetails(
|
||||
cached_tokens=getattr(
|
||||
response_usage.prompt_tokens_details, "cached_tokens", 0
|
||||
)
|
||||
),
|
||||
output_tokens_details=OutputTokensDetails(
|
||||
reasoning_tokens=getattr(
|
||||
response_usage.completion_tokens_details, "reasoning_tokens", 0
|
||||
)
|
||||
),
|
||||
)
|
||||
if response.usage
|
||||
else Usage()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
|
|||
from openai.types import ChatModel
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.responses import Response
|
||||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||
|
||||
from .. import _debug
|
||||
from ..agent_output import AgentOutputSchemaBase
|
||||
|
|
@ -83,6 +84,18 @@ class OpenAIChatCompletionsModel(Model):
|
|||
input_tokens=response.usage.prompt_tokens,
|
||||
output_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
input_tokens_details=InputTokensDetails(
|
||||
cached_tokens=getattr(
|
||||
response.usage.prompt_tokens_details, "cached_tokens", 0
|
||||
)
|
||||
or 0,
|
||||
),
|
||||
output_tokens_details=OutputTokensDetails(
|
||||
reasoning_tokens=getattr(
|
||||
response.usage.completion_tokens_details, "reasoning_tokens", 0
|
||||
)
|
||||
or 0,
|
||||
),
|
||||
)
|
||||
if response.usage
|
||||
else Usage()
|
||||
|
|
@ -252,7 +265,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
stream_options=self._non_null_or_not_given(stream_options),
|
||||
store=self._non_null_or_not_given(store),
|
||||
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
|
||||
extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) },
|
||||
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
|
||||
extra_query=model_settings.extra_query,
|
||||
extra_body=model_settings.extra_body,
|
||||
metadata=self._non_null_or_not_given(model_settings.metadata),
|
||||
|
|
|
|||
|
|
@ -98,6 +98,8 @@ class OpenAIResponsesModel(Model):
|
|||
input_tokens=response.usage.input_tokens,
|
||||
output_tokens=response.usage.output_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
input_tokens_details=response.usage.input_tokens_details,
|
||||
output_tokens_details=response.usage.output_tokens_details,
|
||||
)
|
||||
if response.usage
|
||||
else Usage()
|
||||
|
|
|
|||
|
|
@ -689,6 +689,8 @@ class Runner:
|
|||
input_tokens=event.response.usage.input_tokens,
|
||||
output_tokens=event.response.usage.output_tokens,
|
||||
total_tokens=event.response.usage.total_tokens,
|
||||
input_tokens_details=event.response.usage.input_tokens_details,
|
||||
output_tokens_details=event.response.usage.output_tokens_details,
|
||||
)
|
||||
if event.response.usage
|
||||
else Usage()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -9,9 +11,18 @@ class Usage:
|
|||
input_tokens: int = 0
|
||||
"""Total input tokens sent, across all requests."""
|
||||
|
||||
input_tokens_details: InputTokensDetails = field(
|
||||
default_factory=lambda: InputTokensDetails(cached_tokens=0)
|
||||
)
|
||||
"""Details about the input tokens, matching responses API usage details."""
|
||||
output_tokens: int = 0
|
||||
"""Total output tokens received, across all requests."""
|
||||
|
||||
output_tokens_details: OutputTokensDetails = field(
|
||||
default_factory=lambda: OutputTokensDetails(reasoning_tokens=0)
|
||||
)
|
||||
"""Details about the output tokens, matching responses API usage details."""
|
||||
|
||||
total_tokens: int = 0
|
||||
"""Total tokens sent and received, across all requests."""
|
||||
|
||||
|
|
@ -20,3 +31,12 @@ class Usage:
|
|||
self.input_tokens += other.input_tokens if other.input_tokens else 0
|
||||
self.output_tokens += other.output_tokens if other.output_tokens else 0
|
||||
self.total_tokens += other.total_tokens if other.total_tokens else 0
|
||||
self.input_tokens_details = InputTokensDetails(
|
||||
cached_tokens=self.input_tokens_details.cached_tokens
|
||||
+ other.input_tokens_details.cached_tokens
|
||||
)
|
||||
|
||||
self.output_tokens_details = OutputTokensDetails(
|
||||
reasoning_tokens=self.output_tokens_details.reasoning_tokens
|
||||
+ other.output_tokens_details.reasoning_tokens
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,11 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
ChoiceDeltaToolCall,
|
||||
ChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from openai.types.completion_usage import (
|
||||
CompletionTokensDetails,
|
||||
CompletionUsage,
|
||||
PromptTokensDetails,
|
||||
)
|
||||
from openai.types.responses import (
|
||||
Response,
|
||||
ResponseFunctionToolCall,
|
||||
|
|
@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
|
|||
model="fake",
|
||||
object="chat.completion.chunk",
|
||||
choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))],
|
||||
usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=5,
|
||||
prompt_tokens=7,
|
||||
total_tokens=12,
|
||||
completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2),
|
||||
prompt_tokens_details=PromptTokensDetails(cached_tokens=6),
|
||||
),
|
||||
)
|
||||
|
||||
async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
|
||||
|
|
@ -112,6 +122,8 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
|
|||
assert completed_resp.usage.input_tokens == 7
|
||||
assert completed_resp.usage.output_tokens == 5
|
||||
assert completed_resp.usage.total_tokens == 12
|
||||
assert completed_resp.usage.input_tokens_details.cached_tokens == 6
|
||||
assert completed_resp.usage.output_tokens_details.reasoning_tokens == 2
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||
|
||||
from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel
|
||||
|
||||
|
|
@ -17,21 +18,29 @@ async def test_extra_headers_passed_to_openai_responses_model():
|
|||
async def create(self, **kwargs):
|
||||
nonlocal called_kwargs
|
||||
called_kwargs = kwargs
|
||||
|
||||
class DummyResponse:
|
||||
id = "dummy"
|
||||
output = []
|
||||
usage = type(
|
||||
"Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
"Usage",
|
||||
(),
|
||||
{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"input_tokens_details": InputTokensDetails(cached_tokens=0),
|
||||
"output_tokens_details": OutputTokensDetails(reasoning_tokens=0),
|
||||
},
|
||||
)()
|
||||
|
||||
return DummyResponse()
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self):
|
||||
self.responses = DummyResponses()
|
||||
|
||||
|
||||
|
||||
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
|
||||
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
|
||||
extra_headers = {"X-Test-Header": "test-value"}
|
||||
await model.get_response(
|
||||
system_instructions=None,
|
||||
|
|
@ -47,7 +56,6 @@ async def test_extra_headers_passed_to_openai_responses_model():
|
|||
assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_headers_passed_to_openai_client():
|
||||
|
|
@ -76,7 +84,7 @@ async def test_extra_headers_passed_to_openai_client():
|
|||
self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
|
||||
self.base_url = "https://api.openai.com"
|
||||
|
||||
model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
|
||||
model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
|
||||
extra_headers = {"X-Test-Header": "test-value"}
|
||||
await model.get_response(
|
||||
system_instructions=None,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,10 @@ from openai.types.chat.chat_completion_message_tool_call import (
|
|||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from openai.types.completion_usage import (
|
||||
CompletionUsage,
|
||||
PromptTokensDetails,
|
||||
)
|
||||
from openai.types.responses import (
|
||||
Response,
|
||||
ResponseFunctionToolCall,
|
||||
|
|
@ -51,7 +54,13 @@ async def test_get_response_with_text_message(monkeypatch) -> None:
|
|||
model="fake",
|
||||
object="chat.completion",
|
||||
choices=[choice],
|
||||
usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=5,
|
||||
prompt_tokens=7,
|
||||
total_tokens=12,
|
||||
# completion_tokens_details left blank to test default
|
||||
prompt_tokens_details=PromptTokensDetails(cached_tokens=3),
|
||||
),
|
||||
)
|
||||
|
||||
async def patched_fetch_response(self, *args, **kwargs):
|
||||
|
|
@ -81,6 +90,8 @@ async def test_get_response_with_text_message(monkeypatch) -> None:
|
|||
assert resp.usage.input_tokens == 7
|
||||
assert resp.usage.output_tokens == 5
|
||||
assert resp.usage.total_tokens == 12
|
||||
assert resp.usage.input_tokens_details.cached_tokens == 3
|
||||
assert resp.usage.output_tokens_details.reasoning_tokens == 0
|
||||
assert resp.response_id is None
|
||||
|
||||
|
||||
|
|
@ -127,6 +138,8 @@ async def test_get_response_with_refusal(monkeypatch) -> None:
|
|||
assert resp.usage.requests == 0
|
||||
assert resp.usage.input_tokens == 0
|
||||
assert resp.usage.output_tokens == 0
|
||||
assert resp.usage.input_tokens_details.cached_tokens == 0
|
||||
assert resp.usage.output_tokens_details.reasoning_tokens == 0
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
|
|
|
|||
|
|
@ -8,7 +8,11 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
ChoiceDeltaToolCall,
|
||||
ChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from openai.types.completion_usage import (
|
||||
CompletionTokensDetails,
|
||||
CompletionUsage,
|
||||
PromptTokensDetails,
|
||||
)
|
||||
from openai.types.responses import (
|
||||
Response,
|
||||
ResponseFunctionToolCall,
|
||||
|
|
@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
|
|||
model="fake",
|
||||
object="chat.completion.chunk",
|
||||
choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))],
|
||||
usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=5,
|
||||
prompt_tokens=7,
|
||||
total_tokens=12,
|
||||
prompt_tokens_details=PromptTokensDetails(cached_tokens=2),
|
||||
completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3),
|
||||
),
|
||||
)
|
||||
|
||||
async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
|
||||
|
|
@ -112,6 +122,8 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
|
|||
assert completed_resp.usage.input_tokens == 7
|
||||
assert completed_resp.usage.output_tokens == 5
|
||||
assert completed_resp.usage.total_tokens == 12
|
||||
assert completed_resp.usage.input_tokens_details.cached_tokens == 2
|
||||
assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from inline_snapshot import snapshot
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.responses import ResponseCompletedEvent
|
||||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||
|
||||
from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
|
||||
from agents.tracing.span_data import ResponseSpanData
|
||||
|
|
@ -16,10 +19,25 @@ class DummyTracing:
|
|||
|
||||
|
||||
class DummyUsage:
|
||||
def __init__(self, input_tokens=1, output_tokens=1, total_tokens=2):
|
||||
def __init__(
|
||||
self,
|
||||
input_tokens: int = 1,
|
||||
input_tokens_details: Optional[InputTokensDetails] = None,
|
||||
output_tokens: int = 1,
|
||||
output_tokens_details: Optional[OutputTokensDetails] = None,
|
||||
total_tokens: int = 2,
|
||||
):
|
||||
self.input_tokens = input_tokens
|
||||
self.output_tokens = output_tokens
|
||||
self.total_tokens = total_tokens
|
||||
self.input_tokens_details = (
|
||||
input_tokens_details if input_tokens_details else InputTokensDetails(cached_tokens=0)
|
||||
)
|
||||
self.output_tokens_details = (
|
||||
output_tokens_details
|
||||
if output_tokens_details
|
||||
else OutputTokensDetails(reasoning_tokens=0)
|
||||
)
|
||||
|
||||
|
||||
class DummyResponse:
|
||||
|
|
|
|||
52
tests/test_usage.py
Normal file
52
tests/test_usage.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||
|
||||
from agents.usage import Usage
|
||||
|
||||
|
||||
def test_usage_add_aggregates_all_fields():
|
||||
u1 = Usage(
|
||||
requests=1,
|
||||
input_tokens=10,
|
||||
input_tokens_details=InputTokensDetails(cached_tokens=3),
|
||||
output_tokens=20,
|
||||
output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
|
||||
total_tokens=30,
|
||||
)
|
||||
u2 = Usage(
|
||||
requests=2,
|
||||
input_tokens=7,
|
||||
input_tokens_details=InputTokensDetails(cached_tokens=4),
|
||||
output_tokens=8,
|
||||
output_tokens_details=OutputTokensDetails(reasoning_tokens=6),
|
||||
total_tokens=15,
|
||||
)
|
||||
|
||||
u1.add(u2)
|
||||
|
||||
assert u1.requests == 3
|
||||
assert u1.input_tokens == 17
|
||||
assert u1.output_tokens == 28
|
||||
assert u1.total_tokens == 45
|
||||
assert u1.input_tokens_details.cached_tokens == 7
|
||||
assert u1.output_tokens_details.reasoning_tokens == 11
|
||||
|
||||
|
||||
def test_usage_add_aggregates_with_none_values():
|
||||
u1 = Usage()
|
||||
u2 = Usage(
|
||||
requests=2,
|
||||
input_tokens=7,
|
||||
input_tokens_details=InputTokensDetails(cached_tokens=4),
|
||||
output_tokens=8,
|
||||
output_tokens_details=OutputTokensDetails(reasoning_tokens=6),
|
||||
total_tokens=15,
|
||||
)
|
||||
|
||||
u1.add(u2)
|
||||
|
||||
assert u1.requests == 2
|
||||
assert u1.input_tokens == 7
|
||||
assert u1.output_tokens == 8
|
||||
assert u1.total_tokens == 15
|
||||
assert u1.input_tokens_details.cached_tokens == 4
|
||||
assert u1.output_tokens_details.reasoning_tokens == 6
|
||||
Loading…
Reference in a new issue