Dev/add usage details to Usage class (#726)

PR to enhance the `Usage` object and related logic, to support more
granular token accounting, matching the details available in the [OpenAI
Responses API](https://platform.openai.com/docs/api-reference/responses)
. Specifically, it:

- Adds `input_tokens_details` and `output_tokens_details` fields to the
`Usage` dataclass, storing detailed token breakdowns (e.g.,
`cached_tokens`, `reasoning_tokens`).
- Flows this change through
- Updates and extends tests to match
- Adds a test for the Usage.add method

### Motivation
- Aligns the SDK’s usage with the latest OpenAI responses API Usage
object
- Supports downstream use cases that require fine-grained token usage
data (e.g., billing, analytics, optimization) requested by startups

---------

Co-authored-by: Wulfie Bain <wulfie@openai.com>
This commit is contained in:
WJPBProjects 2025-05-20 18:23:56 +01:00 committed by GitHub
parent 428c9a65bf
commit 466b44df18
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 178 additions and 15 deletions

View file

@ -6,6 +6,7 @@ from collections.abc import AsyncIterator
from typing import Any, Literal, cast, overload
import litellm.types
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from agents.exceptions import ModelBehaviorError
@ -107,6 +108,16 @@ class LitellmModel(Model):
input_tokens=response_usage.prompt_tokens,
output_tokens=response_usage.completion_tokens,
total_tokens=response_usage.total_tokens,
input_tokens_details=InputTokensDetails(
cached_tokens=getattr(
response_usage.prompt_tokens_details, "cached_tokens", 0
)
),
output_tokens_details=OutputTokensDetails(
reasoning_tokens=getattr(
response_usage.completion_tokens_details, "reasoning_tokens", 0
)
),
)
if response.usage
else Usage()

View file

@ -9,6 +9,7 @@ from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
from openai.types import ChatModel
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.responses import Response
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from .. import _debug
from ..agent_output import AgentOutputSchemaBase
@ -83,6 +84,18 @@ class OpenAIChatCompletionsModel(Model):
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
input_tokens_details=InputTokensDetails(
cached_tokens=getattr(
response.usage.prompt_tokens_details, "cached_tokens", 0
)
or 0,
),
output_tokens_details=OutputTokensDetails(
reasoning_tokens=getattr(
response.usage.completion_tokens_details, "reasoning_tokens", 0
)
or 0,
),
)
if response.usage
else Usage()
@ -252,7 +265,7 @@ class OpenAIChatCompletionsModel(Model):
stream_options=self._non_null_or_not_given(stream_options),
store=self._non_null_or_not_given(store),
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) },
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
extra_query=model_settings.extra_query,
extra_body=model_settings.extra_body,
metadata=self._non_null_or_not_given(model_settings.metadata),

View file

@ -98,6 +98,8 @@ class OpenAIResponsesModel(Model):
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
total_tokens=response.usage.total_tokens,
input_tokens_details=response.usage.input_tokens_details,
output_tokens_details=response.usage.output_tokens_details,
)
if response.usage
else Usage()

View file

@ -689,6 +689,8 @@ class Runner:
input_tokens=event.response.usage.input_tokens,
output_tokens=event.response.usage.output_tokens,
total_tokens=event.response.usage.total_tokens,
input_tokens_details=event.response.usage.input_tokens_details,
output_tokens_details=event.response.usage.output_tokens_details,
)
if event.response.usage
else Usage()

View file

@ -1,4 +1,6 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
@dataclass
@ -9,9 +11,18 @@ class Usage:
input_tokens: int = 0
"""Total input tokens sent, across all requests."""
input_tokens_details: InputTokensDetails = field(
default_factory=lambda: InputTokensDetails(cached_tokens=0)
)
"""Details about the input tokens, matching responses API usage details."""
output_tokens: int = 0
"""Total output tokens received, across all requests."""
output_tokens_details: OutputTokensDetails = field(
default_factory=lambda: OutputTokensDetails(reasoning_tokens=0)
)
"""Details about the output tokens, matching responses API usage details."""
total_tokens: int = 0
"""Total tokens sent and received, across all requests."""
@ -20,3 +31,12 @@ class Usage:
self.input_tokens += other.input_tokens if other.input_tokens else 0
self.output_tokens += other.output_tokens if other.output_tokens else 0
self.total_tokens += other.total_tokens if other.total_tokens else 0
self.input_tokens_details = InputTokensDetails(
cached_tokens=self.input_tokens_details.cached_tokens
+ other.input_tokens_details.cached_tokens
)
self.output_tokens_details = OutputTokensDetails(
reasoning_tokens=self.output_tokens_details.reasoning_tokens
+ other.output_tokens_details.reasoning_tokens
)

View file

@ -8,7 +8,11 @@ from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.completion_usage import CompletionUsage
from openai.types.completion_usage import (
CompletionTokensDetails,
CompletionUsage,
PromptTokensDetails,
)
from openai.types.responses import (
Response,
ResponseFunctionToolCall,
@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))],
usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
usage=CompletionUsage(
completion_tokens=5,
prompt_tokens=7,
total_tokens=12,
completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2),
prompt_tokens_details=PromptTokensDetails(cached_tokens=6),
),
)
async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
@ -112,6 +122,8 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
assert completed_resp.usage.input_tokens == 7
assert completed_resp.usage.output_tokens == 5
assert completed_resp.usage.total_tokens == 12
assert completed_resp.usage.input_tokens_details.cached_tokens == 6
assert completed_resp.usage.output_tokens_details.reasoning_tokens == 2
@pytest.mark.allow_call_model_methods

View file

@ -1,6 +1,7 @@
import pytest
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel
@ -17,21 +18,29 @@ async def test_extra_headers_passed_to_openai_responses_model():
async def create(self, **kwargs):
nonlocal called_kwargs
called_kwargs = kwargs
class DummyResponse:
id = "dummy"
output = []
usage = type(
"Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
"Usage",
(),
{
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
"input_tokens_details": InputTokensDetails(cached_tokens=0),
"output_tokens_details": OutputTokensDetails(reasoning_tokens=0),
},
)()
return DummyResponse()
class DummyClient:
def __init__(self):
self.responses = DummyResponses()
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
extra_headers = {"X-Test-Header": "test-value"}
await model.get_response(
system_instructions=None,
@ -47,7 +56,6 @@ async def test_extra_headers_passed_to_openai_responses_model():
assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_headers_passed_to_openai_client():
@ -76,7 +84,7 @@ async def test_extra_headers_passed_to_openai_client():
self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
self.base_url = "https://api.openai.com"
model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
extra_headers = {"X-Test-Header": "test-value"}
await model.get_response(
system_instructions=None,

View file

@ -13,7 +13,10 @@ from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from openai.types.completion_usage import CompletionUsage
from openai.types.completion_usage import (
CompletionUsage,
PromptTokensDetails,
)
from openai.types.responses import (
Response,
ResponseFunctionToolCall,
@ -51,7 +54,13 @@ async def test_get_response_with_text_message(monkeypatch) -> None:
model="fake",
object="chat.completion",
choices=[choice],
usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
usage=CompletionUsage(
completion_tokens=5,
prompt_tokens=7,
total_tokens=12,
# completion_tokens_details left blank to test default
prompt_tokens_details=PromptTokensDetails(cached_tokens=3),
),
)
async def patched_fetch_response(self, *args, **kwargs):
@ -81,6 +90,8 @@ async def test_get_response_with_text_message(monkeypatch) -> None:
assert resp.usage.input_tokens == 7
assert resp.usage.output_tokens == 5
assert resp.usage.total_tokens == 12
assert resp.usage.input_tokens_details.cached_tokens == 3
assert resp.usage.output_tokens_details.reasoning_tokens == 0
assert resp.response_id is None
@ -127,6 +138,8 @@ async def test_get_response_with_refusal(monkeypatch) -> None:
assert resp.usage.requests == 0
assert resp.usage.input_tokens == 0
assert resp.usage.output_tokens == 0
assert resp.usage.input_tokens_details.cached_tokens == 0
assert resp.usage.output_tokens_details.reasoning_tokens == 0
@pytest.mark.allow_call_model_methods

View file

@ -8,7 +8,11 @@ from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.completion_usage import CompletionUsage
from openai.types.completion_usage import (
CompletionTokensDetails,
CompletionUsage,
PromptTokensDetails,
)
from openai.types.responses import (
Response,
ResponseFunctionToolCall,
@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))],
usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
usage=CompletionUsage(
completion_tokens=5,
prompt_tokens=7,
total_tokens=12,
prompt_tokens_details=PromptTokensDetails(cached_tokens=2),
completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3),
),
)
async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
@ -112,6 +122,8 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
assert completed_resp.usage.input_tokens == 7
assert completed_resp.usage.output_tokens == 5
assert completed_resp.usage.total_tokens == 12
assert completed_resp.usage.input_tokens_details.cached_tokens == 2
assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3
@pytest.mark.allow_call_model_methods

View file

@ -1,7 +1,10 @@
from typing import Optional
import pytest
from inline_snapshot import snapshot
from openai import AsyncOpenAI
from openai.types.responses import ResponseCompletedEvent
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
from agents.tracing.span_data import ResponseSpanData
@ -16,10 +19,25 @@ class DummyTracing:
class DummyUsage:
def __init__(self, input_tokens=1, output_tokens=1, total_tokens=2):
def __init__(
self,
input_tokens: int = 1,
input_tokens_details: Optional[InputTokensDetails] = None,
output_tokens: int = 1,
output_tokens_details: Optional[OutputTokensDetails] = None,
total_tokens: int = 2,
):
self.input_tokens = input_tokens
self.output_tokens = output_tokens
self.total_tokens = total_tokens
self.input_tokens_details = (
input_tokens_details if input_tokens_details else InputTokensDetails(cached_tokens=0)
)
self.output_tokens_details = (
output_tokens_details
if output_tokens_details
else OutputTokensDetails(reasoning_tokens=0)
)
class DummyResponse:

52
tests/test_usage.py Normal file
View file

@ -0,0 +1,52 @@
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from agents.usage import Usage
def test_usage_add_aggregates_all_fields():
u1 = Usage(
requests=1,
input_tokens=10,
input_tokens_details=InputTokensDetails(cached_tokens=3),
output_tokens=20,
output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
total_tokens=30,
)
u2 = Usage(
requests=2,
input_tokens=7,
input_tokens_details=InputTokensDetails(cached_tokens=4),
output_tokens=8,
output_tokens_details=OutputTokensDetails(reasoning_tokens=6),
total_tokens=15,
)
u1.add(u2)
assert u1.requests == 3
assert u1.input_tokens == 17
assert u1.output_tokens == 28
assert u1.total_tokens == 45
assert u1.input_tokens_details.cached_tokens == 7
assert u1.output_tokens_details.reasoning_tokens == 11
def test_usage_add_aggregates_with_none_values():
u1 = Usage()
u2 = Usage(
requests=2,
input_tokens=7,
input_tokens_details=InputTokensDetails(cached_tokens=4),
output_tokens=8,
output_tokens_details=OutputTokensDetails(reasoning_tokens=6),
total_tokens=15,
)
u1.add(u2)
assert u1.requests == 2
assert u1.input_tokens == 7
assert u1.output_tokens == 8
assert u1.total_tokens == 15
assert u1.input_tokens_details.cached_tokens == 4
assert u1.output_tokens_details.reasoning_tokens == 6