add overwrite mechanism for stream_options (#465)
fix issue https://github.com/openai/openai-agents-python/issues/442 below is an example to overwrite include_usage ``` result = Runner.run_streamed( agent, "Write a haiku about recursion in programming.", run_config=RunConfig( model_provider=CUSTOM_MODEL_PROVIDER, model_settings=ModelSettings(include_usage=True) ), ) ```
This commit is contained in:
parent
84fb734ba0
commit
8ded8a9981
3 changed files with 24 additions and 3 deletions
|
|
@ -54,6 +54,10 @@ class ModelSettings:
|
|||
"""Whether to store the generated model response for later retrieval.
|
||||
Defaults to True if not provided."""
|
||||
|
||||
include_usage: bool | None = None
|
||||
"""Whether to include usage chunk.
|
||||
Defaults to True if not provided."""
|
||||
|
||||
def resolve(self, override: ModelSettings | None) -> ModelSettings:
|
||||
"""Produce a new ModelSettings by overlaying any non-None values from the
|
||||
override on top of this instance."""
|
||||
|
|
|
|||
|
|
@ -521,6 +521,8 @@ class OpenAIChatCompletionsModel(Model):
|
|||
reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
|
||||
store = _Converter.get_store_param(self._get_client(), model_settings)
|
||||
|
||||
stream_options = _Converter.get_stream_options_param(self._get_client(), model_settings)
|
||||
|
||||
ret = await self._get_client().chat.completions.create(
|
||||
model=self.model,
|
||||
messages=converted_messages,
|
||||
|
|
@ -534,7 +536,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
response_format=response_format,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
stream=stream,
|
||||
stream_options={"include_usage": True} if stream else NOT_GIVEN,
|
||||
stream_options=self._non_null_or_not_given(stream_options),
|
||||
store=self._non_null_or_not_given(store),
|
||||
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
|
||||
extra_headers=_HEADERS,
|
||||
|
|
@ -568,12 +570,27 @@ class OpenAIChatCompletionsModel(Model):
|
|||
|
||||
|
||||
class _Converter:
|
||||
|
||||
@classmethod
|
||||
def is_openai(cls, client: AsyncOpenAI):
|
||||
return str(client.base_url).startswith("https://api.openai.com")
|
||||
|
||||
@classmethod
|
||||
def get_store_param(cls, client: AsyncOpenAI, model_settings: ModelSettings) -> bool | None:
|
||||
# Match the behavior of Responses where store is True when not given
|
||||
default_store = True if str(client.base_url).startswith("https://api.openai.com") else None
|
||||
default_store = True if cls.is_openai(client) else None
|
||||
return model_settings.store if model_settings.store is not None else default_store
|
||||
|
||||
@classmethod
|
||||
def get_stream_options_param(
|
||||
cls, client: AsyncOpenAI, model_settings: ModelSettings
|
||||
) -> dict[str, bool] | None:
|
||||
default_include_usage = True if cls.is_openai(client) else None
|
||||
include_usage = model_settings.include_usage if model_settings.include_usage is not None \
|
||||
else default_include_usage
|
||||
stream_options = {"include_usage": include_usage} if include_usage is not None else None
|
||||
return stream_options
|
||||
|
||||
@classmethod
|
||||
def convert_tool_choice(
|
||||
cls, tool_choice: Literal["auto", "required", "none"] | str | None
|
||||
|
|
|
|||
|
|
@ -282,7 +282,7 @@ async def test_fetch_response_stream(monkeypatch) -> None:
|
|||
# Check OpenAI client was called for streaming
|
||||
assert completions.kwargs["stream"] is True
|
||||
assert completions.kwargs["store"] is NOT_GIVEN
|
||||
assert completions.kwargs["stream_options"] == {"include_usage": True}
|
||||
assert completions.kwargs["stream_options"] is NOT_GIVEN
|
||||
# Response is a proper openai Response
|
||||
assert isinstance(response, Response)
|
||||
assert response.id == FAKE_RESPONSES_ID
|
||||
|
|
|
|||
Loading…
Reference in a new issue