Create to_json_dict for ModelSettings (#582)
Now that `ModelSettings` has `Reasoning`, a non-primitive object, `dataclasses.as_dict()` wont work. It will raise an error when you try to serialize (e.g. for tracing). This ensures the object is actually serializable.
This commit is contained in:
parent
a113fea0ee
commit
3755ea8658
7 changed files with 84 additions and 15 deletions
|
|
@ -7,7 +7,7 @@ requires-python = ">=3.9"
|
|||
license = "MIT"
|
||||
authors = [{ name = "OpenAI", email = "support@openai.com" }]
|
||||
dependencies = [
|
||||
"openai>=1.66.5",
|
||||
"openai>=1.76.0",
|
||||
"pydantic>=2.10, <3",
|
||||
"griffe>=1.5.6, <2",
|
||||
"typing-extensions>=4.12.2, <5",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
|
|
@ -75,7 +74,7 @@ class LitellmModel(Model):
|
|||
) -> ModelResponse:
|
||||
with generation_span(
|
||||
model=str(self.model),
|
||||
model_config=dataclasses.asdict(model_settings)
|
||||
model_config=model_settings.to_json_dict()
|
||||
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
|
||||
disabled=tracing.is_disabled(),
|
||||
) as span_generation:
|
||||
|
|
@ -147,7 +146,7 @@ class LitellmModel(Model):
|
|||
) -> AsyncIterator[TResponseStreamEvent]:
|
||||
with generation_span(
|
||||
model=str(self.model),
|
||||
model_config=dataclasses.asdict(model_settings)
|
||||
model_config=model_settings.to_json_dict()
|
||||
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
|
||||
disabled=tracing.is_disabled(),
|
||||
) as span_generation:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from openai._types import Body, Headers, Query
|
||||
from openai.types.shared import Reasoning
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -83,3 +85,16 @@ class ModelSettings:
|
|||
if getattr(override, field.name) is not None
|
||||
}
|
||||
return replace(self, **changes)
|
||||
|
||||
def to_json_dict(self) -> dict[str, Any]:
|
||||
dataclass_dict = dataclasses.asdict(self)
|
||||
|
||||
json_dict: dict[str, Any] = {}
|
||||
|
||||
for field_name, value in dataclass_dict.items():
|
||||
if isinstance(value, BaseModel):
|
||||
json_dict[field_name] = value.model_dump(mode="json")
|
||||
else:
|
||||
json_dict[field_name] = value
|
||||
|
||||
return json_dict
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
|
|
@ -56,8 +55,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
) -> ModelResponse:
|
||||
with generation_span(
|
||||
model=str(self.model),
|
||||
model_config=dataclasses.asdict(model_settings)
|
||||
| {"base_url": str(self._client.base_url)},
|
||||
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
|
||||
disabled=tracing.is_disabled(),
|
||||
) as span_generation:
|
||||
response = await self._fetch_response(
|
||||
|
|
@ -121,8 +119,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
"""
|
||||
with generation_span(
|
||||
model=str(self.model),
|
||||
model_config=dataclasses.asdict(model_settings)
|
||||
| {"base_url": str(self._client.base_url)},
|
||||
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
|
||||
disabled=tracing.is_disabled(),
|
||||
) as span_generation:
|
||||
response, stream = await self._fetch_response(
|
||||
|
|
|
|||
59
tests/model_settings/test_serialization.py
Normal file
59
tests/model_settings/test_serialization.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import json
|
||||
from dataclasses import fields
|
||||
|
||||
from openai.types.shared import Reasoning
|
||||
|
||||
from agents.model_settings import ModelSettings
|
||||
|
||||
|
||||
def verify_serialization(model_settings: ModelSettings) -> None:
|
||||
"""Verify that ModelSettings can be serialized to a JSON string."""
|
||||
json_dict = model_settings.to_json_dict()
|
||||
json_string = json.dumps(json_dict)
|
||||
assert json_string is not None
|
||||
|
||||
|
||||
def test_basic_serialization() -> None:
|
||||
"""Tests whether ModelSettings can be serialized to a JSON string."""
|
||||
|
||||
# First, lets create a ModelSettings instance
|
||||
model_settings = ModelSettings(
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
# Now, lets serialize the ModelSettings instance to a JSON string
|
||||
verify_serialization(model_settings)
|
||||
|
||||
|
||||
def test_all_fields_serialization() -> None:
|
||||
"""Tests whether ModelSettings can be serialized to a JSON string."""
|
||||
|
||||
# First, lets create a ModelSettings instance
|
||||
model_settings = ModelSettings(
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
frequency_penalty=0.0,
|
||||
presence_penalty=0.0,
|
||||
tool_choice="auto",
|
||||
parallel_tool_calls=True,
|
||||
truncation="auto",
|
||||
max_tokens=100,
|
||||
reasoning=Reasoning(),
|
||||
metadata={"foo": "bar"},
|
||||
store=False,
|
||||
include_usage=False,
|
||||
extra_query={"foo": "bar"},
|
||||
extra_body={"foo": "bar"},
|
||||
extra_headers={"foo": "bar"},
|
||||
)
|
||||
|
||||
# Verify that every single field is set to a non-None value
|
||||
for field in fields(model_settings):
|
||||
assert getattr(model_settings, field.name) is not None, (
|
||||
f"You must set the {field.name} field"
|
||||
)
|
||||
|
||||
# Now, lets serialize the ModelSettings instance to a JSON string
|
||||
verify_serialization(model_settings)
|
||||
|
|
@ -9,4 +9,3 @@ def pytest_ignore_collect(collection_path, config):
|
|||
|
||||
if str(collection_path).startswith(this_dir):
|
||||
return True
|
||||
|
||||
|
|
|
|||
8
uv.lock
8
uv.lock
|
|
@ -1463,7 +1463,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.74.0"
|
||||
version = "1.76.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
|
|
@ -1475,9 +1475,9 @@ dependencies = [
|
|||
{ name = "tqdm" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/75/86/c605a6e84da0248f2cebfcd864b5a6076ecf78849245af5e11d2a5ec7977/openai-1.74.0.tar.gz", hash = "sha256:592c25b8747a7cad33a841958f5eb859a785caea9ee22b9e4f4a2ec062236526", size = 427571 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/84/51/817969ec969b73d8ddad085670ecd8a45ef1af1811d8c3b8a177ca4d1309/openai-1.76.0.tar.gz", hash = "sha256:fd2bfaf4608f48102d6b74f9e11c5ecaa058b60dad9c36e409c12477dfd91fb2", size = 434660 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/91/8c150f16a96367e14bd7d20e86e0bbbec3080e3eb593e63f21a7f013f8e4/openai-1.74.0-py3-none-any.whl", hash = "sha256:aff3e0f9fb209836382ec112778667027f4fd6ae38bdb2334bc9e173598b092a", size = 644790 },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/aa/84e02ab500ca871eb8f62784426963a1c7c17a72fea3c7f268af4bbaafa5/openai-1.76.0-py3-none-any.whl", hash = "sha256:a712b50e78cf78e6d7b2a8f69c4978243517c2c36999756673e07a14ce37dc0a", size = 661201 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1538,7 +1538,7 @@ requires-dist = [
|
|||
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.65.0,<2" },
|
||||
{ name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.6.0,<2" },
|
||||
{ name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" },
|
||||
{ name = "openai", specifier = ">=1.66.5" },
|
||||
{ name = "openai", specifier = ">=1.76.0" },
|
||||
{ name = "pydantic", specifier = ">=2.10,<3" },
|
||||
{ name = "requests", specifier = ">=2.0,<3" },
|
||||
{ name = "types-requests", specifier = ">=2.0,<3" },
|
||||
|
|
|
|||
Loading…
Reference in a new issue