Adding extra_headers parameters to ModelSettings (#550)
This commit is contained in:
parent
83ce49ec7e
commit
111fc9ee66
5 changed files with 100 additions and 4 deletions
|
|
@ -286,7 +286,7 @@ class LitellmModel(Model):
|
|||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
reasoning_effort=reasoning_effort,
|
||||
extra_headers=HEADERS,
|
||||
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
**extra_kwargs,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Literal
|
||||
|
||||
from openai._types import Body, Query
|
||||
from openai._types import Body, Headers, Query
|
||||
from openai.types.shared import Reasoning
|
||||
|
||||
|
||||
|
|
@ -67,6 +67,10 @@ class ModelSettings:
|
|||
"""Additional body fields to provide with the request.
|
||||
Defaults to None if not provided."""
|
||||
|
||||
extra_headers: Headers | None = None
|
||||
"""Additional headers to provide with the request.
|
||||
Defaults to None if not provided."""
|
||||
|
||||
def resolve(self, override: ModelSettings | None) -> ModelSettings:
|
||||
"""Produce a new ModelSettings by overlaying any non-None values from the
|
||||
override on top of this instance."""
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
stream_options=self._non_null_or_not_given(stream_options),
|
||||
store=self._non_null_or_not_given(store),
|
||||
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
|
||||
extra_headers=HEADERS,
|
||||
extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) },
|
||||
extra_query=model_settings.extra_query,
|
||||
extra_body=model_settings.extra_body,
|
||||
metadata=self._non_null_or_not_given(model_settings.metadata),
|
||||
|
|
|
|||
|
|
@ -253,7 +253,7 @@ class OpenAIResponsesModel(Model):
|
|||
tool_choice=tool_choice,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
stream=stream,
|
||||
extra_headers=_HEADERS,
|
||||
extra_headers={**_HEADERS, **(model_settings.extra_headers or {})},
|
||||
extra_query=model_settings.extra_query,
|
||||
extra_body=model_settings.extra_body,
|
||||
text=response_format,
|
||||
|
|
|
|||
92
tests/test_extra_headers.py
Normal file
92
tests/test_extra_headers.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import pytest
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
|
||||
from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_headers_passed_to_openai_responses_model():
|
||||
"""
|
||||
Ensure extra_headers in ModelSettings is passed to the OpenAIResponsesModel client.
|
||||
"""
|
||||
called_kwargs = {}
|
||||
|
||||
class DummyResponses:
|
||||
async def create(self, **kwargs):
|
||||
nonlocal called_kwargs
|
||||
called_kwargs = kwargs
|
||||
class DummyResponse:
|
||||
id = "dummy"
|
||||
output = []
|
||||
usage = type(
|
||||
"Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
)()
|
||||
return DummyResponse()
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self):
|
||||
self.responses = DummyResponses()
|
||||
|
||||
|
||||
|
||||
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
|
||||
extra_headers = {"X-Test-Header": "test-value"}
|
||||
await model.get_response(
|
||||
system_instructions=None,
|
||||
input="hi",
|
||||
model_settings=ModelSettings(extra_headers=extra_headers),
|
||||
tools=[],
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
)
|
||||
assert "extra_headers" in called_kwargs
|
||||
assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_headers_passed_to_openai_client():
|
||||
"""
|
||||
Ensure extra_headers in ModelSettings is passed to the OpenAI client.
|
||||
"""
|
||||
called_kwargs = {}
|
||||
|
||||
class DummyCompletions:
|
||||
async def create(self, **kwargs):
|
||||
nonlocal called_kwargs
|
||||
called_kwargs = kwargs
|
||||
msg = ChatCompletionMessage(role="assistant", content="Hello")
|
||||
choice = Choice(index=0, finish_reason="stop", message=msg)
|
||||
return ChatCompletion(
|
||||
id="resp-id",
|
||||
created=0,
|
||||
model="fake",
|
||||
object="chat.completion",
|
||||
choices=[choice],
|
||||
usage=None,
|
||||
)
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self):
|
||||
self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
|
||||
self.base_url = "https://api.openai.com"
|
||||
|
||||
model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
|
||||
extra_headers = {"X-Test-Header": "test-value"}
|
||||
await model.get_response(
|
||||
system_instructions=None,
|
||||
input="hi",
|
||||
model_settings=ModelSettings(extra_headers=extra_headers),
|
||||
tools=[],
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
)
|
||||
assert "extra_headers" in called_kwargs
|
||||
assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"
|
||||
Loading…
Reference in a new issue