Don't send the "store" param unless its hitting OpenAI (#455)
Summary: See #443. Causes issues with Gemini. Test Plan: Tests. Also tested with Gemini to ensure it works.
This commit is contained in:
parent
50bbfdd8be
commit
2bcc864b81
3 changed files with 50 additions and 9 deletions
|
|
@ -518,10 +518,8 @@ class OpenAIChatCompletionsModel(Model):
|
||||||
f"Response format: {response_format}\n"
|
f"Response format: {response_format}\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Match the behavior of Responses where store is True when not given
|
|
||||||
store = model_settings.store if model_settings.store is not None else True
|
|
||||||
|
|
||||||
reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
|
reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
|
||||||
|
store = _Converter.get_store_param(self._get_client(), model_settings)
|
||||||
|
|
||||||
ret = await self._get_client().chat.completions.create(
|
ret = await self._get_client().chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
|
@ -537,10 +535,10 @@ class OpenAIChatCompletionsModel(Model):
|
||||||
parallel_tool_calls=parallel_tool_calls,
|
parallel_tool_calls=parallel_tool_calls,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stream_options={"include_usage": True} if stream else NOT_GIVEN,
|
stream_options={"include_usage": True} if stream else NOT_GIVEN,
|
||||||
store=store,
|
store=self._non_null_or_not_given(store),
|
||||||
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
|
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
|
||||||
extra_headers=_HEADERS,
|
extra_headers=_HEADERS,
|
||||||
metadata=model_settings.metadata,
|
metadata=self._non_null_or_not_given(model_settings.metadata),
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(ret, ChatCompletion):
|
if isinstance(ret, ChatCompletion):
|
||||||
|
|
@ -570,6 +568,12 @@ class OpenAIChatCompletionsModel(Model):
|
||||||
|
|
||||||
|
|
||||||
class _Converter:
|
class _Converter:
|
||||||
|
@classmethod
|
||||||
|
def get_store_param(cls, client: AsyncOpenAI, model_settings: ModelSettings) -> bool | None:
|
||||||
|
# Match the behavior of Responses where store is True when not given
|
||||||
|
default_store = True if str(client.base_url).startswith("https://api.openai.com") else None
|
||||||
|
return model_settings.store if model_settings.store is not None else default_store
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_tool_choice(
|
def convert_tool_choice(
|
||||||
cls, tool_choice: Literal["auto", "required", "none"] | str | None
|
cls, tool_choice: Literal["auto", "required", "none"] | str | None
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
from openai import NOT_GIVEN
|
from openai import NOT_GIVEN, AsyncOpenAI
|
||||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||||
|
|
@ -31,6 +31,7 @@ from agents import (
|
||||||
generation_span,
|
generation_span,
|
||||||
)
|
)
|
||||||
from agents.models.fake_id import FAKE_RESPONSES_ID
|
from agents.models.fake_id import FAKE_RESPONSES_ID
|
||||||
|
from agents.models.openai_chatcompletions import _Converter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.allow_call_model_methods
|
@pytest.mark.allow_call_model_methods
|
||||||
|
|
@ -226,7 +227,7 @@ async def test_fetch_response_non_stream(monkeypatch) -> None:
|
||||||
# Ensure expected args were passed through to OpenAI client.
|
# Ensure expected args were passed through to OpenAI client.
|
||||||
kwargs = completions.kwargs
|
kwargs = completions.kwargs
|
||||||
assert kwargs["stream"] is False
|
assert kwargs["stream"] is False
|
||||||
assert kwargs["store"] is True
|
assert kwargs["store"] is NOT_GIVEN
|
||||||
assert kwargs["model"] == "gpt-4"
|
assert kwargs["model"] == "gpt-4"
|
||||||
assert kwargs["messages"][0]["role"] == "system"
|
assert kwargs["messages"][0]["role"] == "system"
|
||||||
assert kwargs["messages"][0]["content"] == "sys"
|
assert kwargs["messages"][0]["content"] == "sys"
|
||||||
|
|
@ -280,7 +281,7 @@ async def test_fetch_response_stream(monkeypatch) -> None:
|
||||||
)
|
)
|
||||||
# Check OpenAI client was called for streaming
|
# Check OpenAI client was called for streaming
|
||||||
assert completions.kwargs["stream"] is True
|
assert completions.kwargs["stream"] is True
|
||||||
assert completions.kwargs["store"] is True
|
assert completions.kwargs["store"] is NOT_GIVEN
|
||||||
assert completions.kwargs["stream_options"] == {"include_usage": True}
|
assert completions.kwargs["stream_options"] == {"include_usage": True}
|
||||||
# Response is a proper openai Response
|
# Response is a proper openai Response
|
||||||
assert isinstance(response, Response)
|
assert isinstance(response, Response)
|
||||||
|
|
@ -290,3 +291,39 @@ async def test_fetch_response_stream(monkeypatch) -> None:
|
||||||
assert response.output == []
|
assert response.output == []
|
||||||
# We returned the async iterator produced by our dummy.
|
# We returned the async iterator produced by our dummy.
|
||||||
assert hasattr(stream, "__aiter__")
|
assert hasattr(stream, "__aiter__")
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_param():
|
||||||
|
"""Should default to True for OpenAI API calls, and False otherwise."""
|
||||||
|
|
||||||
|
model_settings = ModelSettings()
|
||||||
|
client = AsyncOpenAI()
|
||||||
|
assert _Converter.get_store_param(client, model_settings) is True, (
|
||||||
|
"Should default to True for OpenAI API calls"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_settings = ModelSettings(store=False)
|
||||||
|
assert _Converter.get_store_param(client, model_settings) is False, (
|
||||||
|
"Should respect explicitly set store=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_settings = ModelSettings(store=True)
|
||||||
|
assert _Converter.get_store_param(client, model_settings) is True, (
|
||||||
|
"Should respect explicitly set store=True"
|
||||||
|
)
|
||||||
|
|
||||||
|
client = AsyncOpenAI(base_url="http://www.notopenai.com")
|
||||||
|
model_settings = ModelSettings()
|
||||||
|
assert _Converter.get_store_param(client, model_settings) is None, (
|
||||||
|
"Should default to None for non-OpenAI API calls"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_settings = ModelSettings(store=False)
|
||||||
|
assert _Converter.get_store_param(client, model_settings) is False, (
|
||||||
|
"Should respect explicitly set store=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_settings = ModelSettings(store=True)
|
||||||
|
assert _Converter.get_store_param(client, model_settings) is True, (
|
||||||
|
"Should respect explicitly set store=True"
|
||||||
|
)
|
||||||
|
|
|
||||||
2
uv.lock
2
uv.lock
|
|
@ -1087,7 +1087,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openai-agents"
|
name = "openai-agents"
|
||||||
version = "0.0.7"
|
version = "0.0.8"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "griffe" },
|
{ name = "griffe" },
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue