Adding extra_headers parameters to ModelSettings (#550)

This commit is contained in:
Jonny Kalambay 2025-04-22 19:26:47 -07:00 committed by GitHub
parent 83ce49ec7e
commit 111fc9ee66
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 100 additions and 4 deletions

View file

@ -286,7 +286,7 @@ class LitellmModel(Model):
stream=stream,
stream_options=stream_options,
reasoning_effort=reasoning_effort,
extra_headers=HEADERS,
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
api_key=self.api_key,
base_url=self.base_url,
**extra_kwargs,

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass, fields, replace
from typing import Literal
from openai._types import Body, Query
from openai._types import Body, Headers, Query
from openai.types.shared import Reasoning
@ -67,6 +67,10 @@ class ModelSettings:
"""Additional body fields to provide with the request.
Defaults to None if not provided."""
extra_headers: Headers | None = None
"""Additional headers to provide with the request.
Defaults to None if not provided."""
def resolve(self, override: ModelSettings | None) -> ModelSettings:
"""Produce a new ModelSettings by overlaying any non-None values from the
override on top of this instance."""

View file

@ -255,7 +255,7 @@ class OpenAIChatCompletionsModel(Model):
stream_options=self._non_null_or_not_given(stream_options),
store=self._non_null_or_not_given(store),
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
extra_headers=HEADERS,
extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) },
extra_query=model_settings.extra_query,
extra_body=model_settings.extra_body,
metadata=self._non_null_or_not_given(model_settings.metadata),

View file

@ -253,7 +253,7 @@ class OpenAIResponsesModel(Model):
tool_choice=tool_choice,
parallel_tool_calls=parallel_tool_calls,
stream=stream,
extra_headers=_HEADERS,
extra_headers={**_HEADERS, **(model_settings.extra_headers or {})},
extra_query=model_settings.extra_query,
extra_body=model_settings.extra_body,
text=response_format,

View file

@ -0,0 +1,92 @@
import pytest
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_headers_passed_to_openai_responses_model():
"""
Ensure extra_headers in ModelSettings is passed to the OpenAIResponsesModel client.
"""
called_kwargs = {}
class DummyResponses:
async def create(self, **kwargs):
nonlocal called_kwargs
called_kwargs = kwargs
class DummyResponse:
id = "dummy"
output = []
usage = type(
"Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
)()
return DummyResponse()
class DummyClient:
def __init__(self):
self.responses = DummyResponses()
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
extra_headers = {"X-Test-Header": "test-value"}
await model.get_response(
system_instructions=None,
input="hi",
model_settings=ModelSettings(extra_headers=extra_headers),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)
assert "extra_headers" in called_kwargs
assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_headers_passed_to_openai_client():
"""
Ensure extra_headers in ModelSettings is passed to the OpenAI client.
"""
called_kwargs = {}
class DummyCompletions:
async def create(self, **kwargs):
nonlocal called_kwargs
called_kwargs = kwargs
msg = ChatCompletionMessage(role="assistant", content="Hello")
choice = Choice(index=0, finish_reason="stop", message=msg)
return ChatCompletion(
id="resp-id",
created=0,
model="fake",
object="chat.completion",
choices=[choice],
usage=None,
)
class DummyClient:
def __init__(self):
self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
self.base_url = "https://api.openai.com"
model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
extra_headers = {"X-Test-Header": "test-value"}
await model.get_response(
system_instructions=None,
input="hi",
model_settings=ModelSettings(extra_headers=extra_headers),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)
assert "extra_headers" in called_kwargs
assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"