Don't send the "store" param unless its hitting OpenAI (#455)
Summary: See #443. Causes issues with Gemini. Test Plan: Tests. Also tested with Gemini to ensure it works.
This commit is contained in:
parent
50bbfdd8be
commit
2bcc864b81
3 changed files with 50 additions and 9 deletions
|
|
@ -518,10 +518,8 @@ class OpenAIChatCompletionsModel(Model):
|
|||
f"Response format: {response_format}\n"
|
||||
)
|
||||
|
||||
# Match the behavior of Responses where store is True when not given
|
||||
store = model_settings.store if model_settings.store is not None else True
|
||||
|
||||
reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
|
||||
store = _Converter.get_store_param(self._get_client(), model_settings)
|
||||
|
||||
ret = await self._get_client().chat.completions.create(
|
||||
model=self.model,
|
||||
|
|
@ -537,10 +535,10 @@ class OpenAIChatCompletionsModel(Model):
|
|||
parallel_tool_calls=parallel_tool_calls,
|
||||
stream=stream,
|
||||
stream_options={"include_usage": True} if stream else NOT_GIVEN,
|
||||
store=store,
|
||||
store=self._non_null_or_not_given(store),
|
||||
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
|
||||
extra_headers=_HEADERS,
|
||||
metadata=model_settings.metadata,
|
||||
metadata=self._non_null_or_not_given(model_settings.metadata),
|
||||
)
|
||||
|
||||
if isinstance(ret, ChatCompletion):
|
||||
|
|
@ -570,6 +568,12 @@ class OpenAIChatCompletionsModel(Model):
|
|||
|
||||
|
||||
class _Converter:
|
||||
@classmethod
|
||||
def get_store_param(cls, client: AsyncOpenAI, model_settings: ModelSettings) -> bool | None:
|
||||
# Match the behavior of Responses where store is True when not given
|
||||
default_store = True if str(client.base_url).startswith("https://api.openai.com") else None
|
||||
return model_settings.store if model_settings.store is not None else default_store
|
||||
|
||||
@classmethod
|
||||
def convert_tool_choice(
|
||||
cls, tool_choice: Literal["auto", "required", "none"] | str | None
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Any
|
|||
|
||||
import httpx
|
||||
import pytest
|
||||
from openai import NOT_GIVEN
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
|
|
@ -31,6 +31,7 @@ from agents import (
|
|||
generation_span,
|
||||
)
|
||||
from agents.models.fake_id import FAKE_RESPONSES_ID
|
||||
from agents.models.openai_chatcompletions import _Converter
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
|
|
@ -226,7 +227,7 @@ async def test_fetch_response_non_stream(monkeypatch) -> None:
|
|||
# Ensure expected args were passed through to OpenAI client.
|
||||
kwargs = completions.kwargs
|
||||
assert kwargs["stream"] is False
|
||||
assert kwargs["store"] is True
|
||||
assert kwargs["store"] is NOT_GIVEN
|
||||
assert kwargs["model"] == "gpt-4"
|
||||
assert kwargs["messages"][0]["role"] == "system"
|
||||
assert kwargs["messages"][0]["content"] == "sys"
|
||||
|
|
@ -280,7 +281,7 @@ async def test_fetch_response_stream(monkeypatch) -> None:
|
|||
)
|
||||
# Check OpenAI client was called for streaming
|
||||
assert completions.kwargs["stream"] is True
|
||||
assert completions.kwargs["store"] is True
|
||||
assert completions.kwargs["store"] is NOT_GIVEN
|
||||
assert completions.kwargs["stream_options"] == {"include_usage": True}
|
||||
# Response is a proper openai Response
|
||||
assert isinstance(response, Response)
|
||||
|
|
@ -290,3 +291,39 @@ async def test_fetch_response_stream(monkeypatch) -> None:
|
|||
assert response.output == []
|
||||
# We returned the async iterator produced by our dummy.
|
||||
assert hasattr(stream, "__aiter__")
|
||||
|
||||
|
||||
def test_store_param():
|
||||
"""Should default to True for OpenAI API calls, and False otherwise."""
|
||||
|
||||
model_settings = ModelSettings()
|
||||
client = AsyncOpenAI()
|
||||
assert _Converter.get_store_param(client, model_settings) is True, (
|
||||
"Should default to True for OpenAI API calls"
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(store=False)
|
||||
assert _Converter.get_store_param(client, model_settings) is False, (
|
||||
"Should respect explicitly set store=False"
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(store=True)
|
||||
assert _Converter.get_store_param(client, model_settings) is True, (
|
||||
"Should respect explicitly set store=True"
|
||||
)
|
||||
|
||||
client = AsyncOpenAI(base_url="http://www.notopenai.com")
|
||||
model_settings = ModelSettings()
|
||||
assert _Converter.get_store_param(client, model_settings) is None, (
|
||||
"Should default to None for non-OpenAI API calls"
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(store=False)
|
||||
assert _Converter.get_store_param(client, model_settings) is False, (
|
||||
"Should respect explicitly set store=False"
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(store=True)
|
||||
assert _Converter.get_store_param(client, model_settings) is True, (
|
||||
"Should respect explicitly set store=True"
|
||||
)
|
||||
|
|
|
|||
2
uv.lock
2
uv.lock
|
|
@ -1087,7 +1087,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "openai-agents"
|
||||
version = "0.0.7"
|
||||
version = "0.0.8"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "griffe" },
|
||||
|
|
|
|||
Loading…
Reference in a new issue