Create to_json_dict for ModelSettings (#582)

Now that `ModelSettings` has `Reasoning`, a non-primitive object,
`dataclasses.as_dict()` wont work. It will raise an error when you try
to serialize (e.g. for tracing). This ensures the object is actually
serializable.
This commit is contained in:
Rohan Mehta 2025-04-23 20:39:07 -04:00 committed by GitHub
parent a113fea0ee
commit 3755ea8658
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 84 additions and 15 deletions

View file

@ -7,7 +7,7 @@ requires-python = ">=3.9"
license = "MIT"
authors = [{ name = "OpenAI", email = "support@openai.com" }]
dependencies = [
"openai>=1.66.5",
"openai>=1.76.0",
"pydantic>=2.10, <3",
"griffe>=1.5.6, <2",
"typing-extensions>=4.12.2, <5",

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import dataclasses
import json
import time
from collections.abc import AsyncIterator
@ -75,7 +74,7 @@ class LitellmModel(Model):
) -> ModelResponse:
with generation_span(
model=str(self.model),
model_config=dataclasses.asdict(model_settings)
model_config=model_settings.to_json_dict()
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
disabled=tracing.is_disabled(),
) as span_generation:
@ -147,7 +146,7 @@ class LitellmModel(Model):
) -> AsyncIterator[TResponseStreamEvent]:
with generation_span(
model=str(self.model),
model_config=dataclasses.asdict(model_settings)
model_config=model_settings.to_json_dict()
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
disabled=tracing.is_disabled(),
) as span_generation:

View file

@ -1,10 +1,12 @@
from __future__ import annotations
import dataclasses
from dataclasses import dataclass, fields, replace
from typing import Literal
from typing import Any, Literal
from openai._types import Body, Headers, Query
from openai.types.shared import Reasoning
from pydantic import BaseModel
@dataclass
@ -83,3 +85,16 @@ class ModelSettings:
if getattr(override, field.name) is not None
}
return replace(self, **changes)
def to_json_dict(self) -> dict[str, Any]:
dataclass_dict = dataclasses.asdict(self)
json_dict: dict[str, Any] = {}
for field_name, value in dataclass_dict.items():
if isinstance(value, BaseModel):
json_dict[field_name] = value.model_dump(mode="json")
else:
json_dict[field_name] = value
return json_dict

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import dataclasses
import json
import time
from collections.abc import AsyncIterator
@ -56,8 +55,7 @@ class OpenAIChatCompletionsModel(Model):
) -> ModelResponse:
with generation_span(
model=str(self.model),
model_config=dataclasses.asdict(model_settings)
| {"base_url": str(self._client.base_url)},
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
disabled=tracing.is_disabled(),
) as span_generation:
response = await self._fetch_response(
@ -121,8 +119,7 @@ class OpenAIChatCompletionsModel(Model):
"""
with generation_span(
model=str(self.model),
model_config=dataclasses.asdict(model_settings)
| {"base_url": str(self._client.base_url)},
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
disabled=tracing.is_disabled(),
) as span_generation:
response, stream = await self._fetch_response(

View file

@ -0,0 +1,59 @@
import json
from dataclasses import fields
from openai.types.shared import Reasoning
from agents.model_settings import ModelSettings
def verify_serialization(model_settings: ModelSettings) -> None:
"""Verify that ModelSettings can be serialized to a JSON string."""
json_dict = model_settings.to_json_dict()
json_string = json.dumps(json_dict)
assert json_string is not None
def test_basic_serialization() -> None:
"""Tests whether ModelSettings can be serialized to a JSON string."""
# First, lets create a ModelSettings instance
model_settings = ModelSettings(
temperature=0.5,
top_p=0.9,
max_tokens=100,
)
# Now, lets serialize the ModelSettings instance to a JSON string
verify_serialization(model_settings)
def test_all_fields_serialization() -> None:
"""Tests whether ModelSettings can be serialized to a JSON string."""
# First, lets create a ModelSettings instance
model_settings = ModelSettings(
temperature=0.5,
top_p=0.9,
frequency_penalty=0.0,
presence_penalty=0.0,
tool_choice="auto",
parallel_tool_calls=True,
truncation="auto",
max_tokens=100,
reasoning=Reasoning(),
metadata={"foo": "bar"},
store=False,
include_usage=False,
extra_query={"foo": "bar"},
extra_body={"foo": "bar"},
extra_headers={"foo": "bar"},
)
# Verify that every single field is set to a non-None value
for field in fields(model_settings):
assert getattr(model_settings, field.name) is not None, (
f"You must set the {field.name} field"
)
# Now, lets serialize the ModelSettings instance to a JSON string
verify_serialization(model_settings)

View file

@ -9,4 +9,3 @@ def pytest_ignore_collect(collection_path, config):
if str(collection_path).startswith(this_dir):
return True

View file

@ -1463,7 +1463,7 @@ wheels = [
[[package]]
name = "openai"
version = "1.74.0"
version = "1.76.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@ -1475,9 +1475,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/75/86/c605a6e84da0248f2cebfcd864b5a6076ecf78849245af5e11d2a5ec7977/openai-1.74.0.tar.gz", hash = "sha256:592c25b8747a7cad33a841958f5eb859a785caea9ee22b9e4f4a2ec062236526", size = 427571 }
sdist = { url = "https://files.pythonhosted.org/packages/84/51/817969ec969b73d8ddad085670ecd8a45ef1af1811d8c3b8a177ca4d1309/openai-1.76.0.tar.gz", hash = "sha256:fd2bfaf4608f48102d6b74f9e11c5ecaa058b60dad9c36e409c12477dfd91fb2", size = 434660 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a9/91/8c150f16a96367e14bd7d20e86e0bbbec3080e3eb593e63f21a7f013f8e4/openai-1.74.0-py3-none-any.whl", hash = "sha256:aff3e0f9fb209836382ec112778667027f4fd6ae38bdb2334bc9e173598b092a", size = 644790 },
{ url = "https://files.pythonhosted.org/packages/59/aa/84e02ab500ca871eb8f62784426963a1c7c17a72fea3c7f268af4bbaafa5/openai-1.76.0-py3-none-any.whl", hash = "sha256:a712b50e78cf78e6d7b2a8f69c4978243517c2c36999756673e07a14ce37dc0a", size = 661201 },
]
[[package]]
@ -1538,7 +1538,7 @@ requires-dist = [
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.65.0,<2" },
{ name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.6.0,<2" },
{ name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" },
{ name = "openai", specifier = ">=1.66.5" },
{ name = "openai", specifier = ">=1.76.0" },
{ name = "pydantic", specifier = ">=2.10,<3" },
{ name = "requests", specifier = ">=2.0,<3" },
{ name = "types-requests", specifier = ">=2.0,<3" },