Litellm integration (#524)
litellm is a library that abstracts away the details and differences between many model providers. This commit adds an extension so that any provider supported by litellm can be integrated easily. --- [//]: # (BEGIN SAPLING FOOTER) * #532 * __->__ #524
This commit is contained in:
parent
0faadf7f7b
commit
bd404e0f87
10 changed files with 1696 additions and 376 deletions
55
examples/model_providers/litellm_provider.py
Normal file
55
examples/model_providers/litellm_provider.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from agents import Agent, Runner, function_tool, set_tracing_disabled
|
||||
from agents.extensions.models.litellm_model import LitellmModel
|
||||
|
||||
"""This example uses the LitellmModel directly, to hit any model provider.
|
||||
You can run it like this:
|
||||
uv run examples/model_providers/litellm_provider.py --model anthropic/claude-3-5-sonnet-20240620
|
||||
or
|
||||
uv run examples/model_providers/litellm_provider.py --model gemini/gemini-2.0-flash
|
||||
|
||||
Find more providers here: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
set_tracing_disabled(disabled=True)
|
||||
|
||||
|
||||
# Demo tool exposed to the agent: logs the requested city and returns a canned
# forecast string. Documentation is kept in comments (not a docstring) so the
# tool metadata derived from this function is left exactly as it was.
@function_tool
def get_weather(city: str):
    print(f"[debug] getting weather for {city}")
    forecast = f"The weather in {city} is sunny."
    return forecast
|
||||
|
||||
|
||||
async def main(model: str, api_key: str):
    """Run one agent turn through litellm and print the final output.

    Args:
        model: Model identifier handed to ``LitellmModel``
            (for example ``anthropic/claude-3-5-sonnet-20240620``).
        api_key: API key handed to ``LitellmModel``.
    """
    litellm_model = LitellmModel(model=model, api_key=api_key)
    haiku_agent = Agent(
        name="Assistant",
        instructions="You only respond in haikus.",
        model=litellm_model,
        tools=[get_weather],
    )

    run_result = await Runner.run(haiku_agent, "What's the weather in Tokyo?")
    print(run_result.final_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # First try to get model/api key from args; prompt interactively for
    # whichever one was not supplied on the command line.
    import argparse
    import getpass

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=False)
    parser.add_argument("--api-key", type=str, required=False)
    args = parser.parse_args()

    model = args.model
    if not model:
        model = input("Enter a model name for Litellm: ")

    api_key = args.api_key
    if not api_key:
        # Read the secret without echoing it, so the key does not end up
        # visible on screen or in terminal scrollback.
        api_key = getpass.getpass("Enter an API key for Litellm: ")

    asyncio.run(main(model, api_key))
|
||||
|
|
@ -36,6 +36,7 @@ Repository = "https://github.com/openai/openai-agents-python"
|
|||
[project.optional-dependencies]
|
||||
voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
|
||||
viz = ["graphviz>=0.17"]
|
||||
litellm = ["litellm>=1.65.0, <2"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
|
|
@ -44,7 +45,7 @@ dev = [
|
|||
"pytest",
|
||||
"pytest-asyncio",
|
||||
"pytest-mock>=3.14.0",
|
||||
"rich",
|
||||
"rich>=13.1.0, <14",
|
||||
"mkdocs>=1.6.0",
|
||||
"mkdocs-material>=9.6.0",
|
||||
"mkdocstrings[python]>=0.28.0",
|
||||
|
|
|
|||
0
src/agents/extensions/models/__init__.py
Normal file
0
src/agents/extensions/models/__init__.py
Normal file
380
src/agents/extensions/models/litellm_model.py
Normal file
380
src/agents/extensions/models/litellm_model.py
Normal file
|
|
@ -0,0 +1,380 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any, Literal, cast, overload
|
||||
|
||||
import litellm.types
|
||||
|
||||
from agents.exceptions import ModelBehaviorError
|
||||
|
||||
try:
|
||||
import litellm
|
||||
except ImportError as _e:
|
||||
raise ImportError(
|
||||
"`litellm` is required to use the LitellmModel. You can install it via the optional "
|
||||
"dependency group: `pip install 'openai-agents[litellm]'`."
|
||||
) from _e
|
||||
|
||||
from openai import NOT_GIVEN, AsyncStream, NotGiven
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_message import (
|
||||
Annotation,
|
||||
AnnotationURLCitation,
|
||||
ChatCompletionMessage,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from openai.types.responses import Response
|
||||
|
||||
from ... import _debug
|
||||
from ...agent_output import AgentOutputSchema
|
||||
from ...handoffs import Handoff
|
||||
from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
|
||||
from ...logger import logger
|
||||
from ...model_settings import ModelSettings
|
||||
from ...models.chatcmpl_converter import Converter
|
||||
from ...models.chatcmpl_helpers import HEADERS
|
||||
from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
|
||||
from ...models.fake_id import FAKE_RESPONSES_ID
|
||||
from ...models.interface import Model, ModelTracing
|
||||
from ...tool import Tool
|
||||
from ...tracing import generation_span
|
||||
from ...tracing.span_data import GenerationSpanData
|
||||
from ...tracing.spans import Span
|
||||
from ...usage import Usage
|
||||
|
||||
|
||||
class LitellmModel(Model):
    """``Model`` implementation that sends chat-completion requests through litellm.

    The instance only stores the model name, an optional base URL and an
    optional API key; all three are forwarded to ``litellm.acompletion`` on
    every request.
    """

    def __init__(
        self,
        model: str,
        base_url: str | None = None,
        api_key: str | None = None,
    ):
        """Store the connection settings used for every request.

        Args:
            model: Model identifier passed to litellm.
            base_url: Optional base URL passed to litellm.
            api_key: Optional API key passed to litellm.
        """
        self.model = model
        self.base_url = base_url
        self.api_key = api_key

    async def get_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        previous_response_id: str | None,
    ) -> ModelResponse:
        """Fetch one non-streaming completion and convert it to a ``ModelResponse``.

        ``previous_response_id`` is accepted for interface compatibility and is
        not used by this implementation. The returned ``response_id`` is
        always ``None``.
        """
        span_config = dataclasses.asdict(model_settings) | {
            "base_url": str(self.base_url or ""),
            "model_impl": "litellm",
        }
        with generation_span(
            model=str(self.model),
            model_config=span_config,
            disabled=tracing.is_disabled(),
        ) as span_generation:
            response = await self._fetch_response(
                system_instructions,
                input,
                model_settings,
                tools,
                output_schema,
                handoffs,
                span_generation,
                tracing,
                stream=False,
            )

            assert isinstance(response.choices[0], litellm.types.utils.Choices)

            if not _debug.DONT_LOG_MODEL_DATA:
                logger.debug(
                    f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n"
                )
            else:
                logger.debug("Received model response")

            # Three cases: no usage attribute at all (warn), an empty usage
            # value (silently default), or real token counts.
            if not hasattr(response, "usage"):
                usage = Usage()
                logger.warning("No usage information returned from Litellm")
            else:
                reported_usage = response.usage
                if response.usage:
                    usage = Usage(
                        requests=1,
                        input_tokens=reported_usage.prompt_tokens,
                        output_tokens=reported_usage.completion_tokens,
                        total_tokens=reported_usage.total_tokens,
                    )
                else:
                    usage = Usage()

            if tracing.include_data():
                span_generation.span_data.output = [response.choices[0].message.model_dump()]
            # Token counts are attached to the span independently of
            # include_data(), as the streaming path also does.
            span_generation.span_data.usage = {
                "input_tokens": usage.input_tokens,
                "output_tokens": usage.output_tokens,
            }

            openai_message = LitellmConverter.convert_message_to_openai(
                response.choices[0].message
            )
            items = Converter.message_to_output_items(openai_message)

            return ModelResponse(
                output=items,
                usage=usage,
                response_id=None,
            )

    async def stream_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: str | None,
    ) -> AsyncIterator[TResponseStreamEvent]:
        """
        Stream response events as they are produced by the stream handler.

        Once the stream is exhausted, the completed response (if one was seen)
        is used to fill in the generation span's output and usage.
        """
        span_config = dataclasses.asdict(model_settings) | {
            "base_url": str(self.base_url or ""),
            "model_impl": "litellm",
        }
        with generation_span(
            model=str(self.model),
            model_config=span_config,
            disabled=tracing.is_disabled(),
        ) as span_generation:
            response, stream = await self._fetch_response(
                system_instructions,
                input,
                model_settings,
                tools,
                output_schema,
                handoffs,
                span_generation,
                tracing,
                stream=True,
            )

            final_response: Response | None = None
            async for event in ChatCmplStreamHandler.handle_stream(response, stream):
                yield event

                # Remember the terminal event so the span can be completed below.
                if event.type == "response.completed":
                    final_response = event.response

            if tracing.include_data() and final_response:
                span_generation.span_data.output = [final_response.model_dump()]

            if final_response and final_response.usage:
                final_usage = final_response.usage
                span_generation.span_data.usage = {
                    "input_tokens": final_usage.input_tokens,
                    "output_tokens": final_usage.output_tokens,
                }

    @overload
    async def _fetch_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        span: Span[GenerationSpanData],
        tracing: ModelTracing,
        stream: Literal[True],
    ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...

    @overload
    async def _fetch_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        span: Span[GenerationSpanData],
        tracing: ModelTracing,
        stream: Literal[False],
    ) -> litellm.types.utils.ModelResponse: ...

    async def _fetch_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        span: Span[GenerationSpanData],
        tracing: ModelTracing,
        stream: bool = False,
    ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
        """Build the request from the agent inputs and call ``litellm.acompletion``.

        Returns:
            The litellm ``ModelResponse`` when litellm returns one; otherwise a
            ``(Response, stream)`` pair where ``Response`` is a placeholder
            carrying the request settings and ``stream`` is whatever litellm
            returned.
        """
        converted_messages = Converter.items_to_messages(input)

        if system_instructions:
            system_message = {
                "content": system_instructions,
                "role": "system",
            }
            converted_messages.insert(0, system_message)
        if tracing.include_data():
            span.span_data.input = converted_messages

        # True only when requested AND there are tools; an explicit False is
        # forwarded as False; every other combination is left unset (None).
        parallel_tool_calls: bool | None
        if model_settings.parallel_tool_calls and tools and len(tools) > 0:
            parallel_tool_calls = True
        elif model_settings.parallel_tool_calls is False:
            parallel_tool_calls = False
        else:
            parallel_tool_calls = None

        tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
        response_format = Converter.convert_response_format(output_schema)

        converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else []
        # Handoffs are offered to the model as additional tools.
        converted_tools.extend(Converter.convert_handoff_tool(handoff) for handoff in handoffs)

        if not _debug.DONT_LOG_MODEL_DATA:
            logger.debug(
                f"Calling Litellm model: {self.model}\n"
                f"{json.dumps(converted_messages, indent=2)}\n"
                f"Tools:\n{json.dumps(converted_tools, indent=2)}\n"
                f"Stream: {stream}\n"
                f"Tool choice: {tool_choice}\n"
                f"Response format: {response_format}\n"
            )
        else:
            logger.debug("Calling LLM")

        reasoning_effort = None
        if model_settings.reasoning:
            reasoning_effort = model_settings.reasoning.effort

        # Only send stream options when streaming and the caller expressed a preference.
        stream_options = None
        if stream and model_settings.include_usage is not None:
            stream_options = {"include_usage": model_settings.include_usage}

        # Optional settings are added only when they are set to something truthy.
        extra_kwargs = {}
        if model_settings.extra_query:
            extra_kwargs["extra_query"] = model_settings.extra_query
        if model_settings.metadata:
            extra_kwargs["metadata"] = model_settings.metadata

        ret = await litellm.acompletion(
            model=self.model,
            messages=converted_messages,
            tools=converted_tools or None,
            temperature=model_settings.temperature,
            top_p=model_settings.top_p,
            frequency_penalty=model_settings.frequency_penalty,
            presence_penalty=model_settings.presence_penalty,
            max_tokens=model_settings.max_tokens,
            tool_choice=self._remove_not_given(tool_choice),
            response_format=self._remove_not_given(response_format),
            parallel_tool_calls=parallel_tool_calls,
            stream=stream,
            stream_options=stream_options,
            reasoning_effort=reasoning_effort,
            extra_headers=HEADERS,
            api_key=self.api_key,
            base_url=self.base_url,
            **extra_kwargs,
        )

        if isinstance(ret, litellm.types.utils.ModelResponse):
            return ret

        # Anything else is handed back as the stream, paired with a placeholder
        # Response that records the settings used for the request.
        placeholder = Response(
            id=FAKE_RESPONSES_ID,
            created_at=time.time(),
            model=self.model,
            object="response",
            output=[],
            tool_choice=cast(Literal["auto", "required", "none"], tool_choice)
            if tool_choice != NOT_GIVEN
            else "auto",
            top_p=model_settings.top_p,
            temperature=model_settings.temperature,
            tools=[],
            parallel_tool_calls=parallel_tool_calls or False,
            reasoning=model_settings.reasoning,
        )
        return placeholder, ret

    def _remove_not_given(self, value: Any) -> Any:
        """Map a ``NotGiven`` instance to ``None``; return any other value unchanged."""
        return None if isinstance(value, NotGiven) else value
|
||||
|
||||
|
||||
class LitellmConverter:
    """Converts litellm message objects into OpenAI chat-completion types."""

    @classmethod
    def convert_message_to_openai(
        cls, message: litellm.types.utils.Message
    ) -> ChatCompletionMessage:
        """Build a ``ChatCompletionMessage`` from a litellm assistant message.

        Raises:
            ModelBehaviorError: If the message role is anything but ``"assistant"``.
        """
        if message.role != "assistant":
            raise ModelBehaviorError(f"Unsupported role: {message.role}")

        converted_calls = None
        if message.tool_calls:
            converted_calls = [
                LitellmConverter.convert_tool_call_to_openai(call) for call in message.tool_calls
            ]

        # The refusal text, when present, is read from the provider-specific fields.
        provider_fields = message.get("provider_specific_fields", None)
        if provider_fields:
            refusal = provider_fields.get("refusal", None)
        else:
            refusal = None

        return ChatCompletionMessage(
            content=message.content,
            refusal=refusal,
            role="assistant",
            annotations=cls.convert_annotations_to_openai(message),
            audio=message.get("audio", None),  # litellm deletes audio if not present
            tool_calls=converted_calls,
        )

    @classmethod
    def convert_annotations_to_openai(
        cls, message: litellm.types.utils.Message
    ) -> list[Annotation] | None:
        """Convert the message's annotations to OpenAI URL-citation annotations.

        Returns ``None`` when the message has no (or empty) annotations.
        """
        source_annotations: list[litellm.types.llms.openai.ChatCompletionAnnotation] | None = (
            message.get("annotations", None)
        )
        if not source_annotations:
            return None

        converted: list[Annotation] = []
        for entry in source_annotations:
            citation = AnnotationURLCitation(
                start_index=entry["url_citation"]["start_index"],
                end_index=entry["url_citation"]["end_index"],
                url=entry["url_citation"]["url"],
                title=entry["url_citation"]["title"],
            )
            converted.append(Annotation(type="url_citation", url_citation=citation))
        return converted

    @classmethod
    def convert_tool_call_to_openai(
        cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall
    ) -> ChatCompletionMessageToolCall:
        """Convert one litellm tool call to an OpenAI function tool call.

        A missing function name is replaced by the empty string.
        """
        openai_function = Function(
            name=tool_call.function.name or "", arguments=tool_call.function.arguments
        )
        return ChatCompletionMessageToolCall(
            id=tool_call.id,
            type="function",
            function=openai_function,
        )
|
||||
37
src/agents/models/chatcmpl_helpers.py
Normal file
37
src/agents/models/chatcmpl_helpers.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..model_settings import ModelSettings
|
||||
from ..version import __version__
|
||||
|
||||
_USER_AGENT = f"Agents/Python {__version__}"
|
||||
HEADERS = {"User-Agent": _USER_AGENT}
|
||||
|
||||
|
||||
class ChatCmplHelpers:
    """Helpers that derive chat-completions request parameters from a client and settings."""

    @classmethod
    def is_openai(cls, client: AsyncOpenAI):
        """Return whether the client's base URL starts with ``https://api.openai.com``."""
        base_url = str(client.base_url)
        return base_url.startswith("https://api.openai.com")

    @classmethod
    def get_store_param(cls, client: AsyncOpenAI, model_settings: ModelSettings) -> bool | None:
        """Return the ``store`` value to send.

        An explicit ``model_settings.store`` always wins. Otherwise the result
        is ``True`` for OpenAI clients and ``None`` for any other client.
        """
        # Match the behavior of Responses where store is True when not given
        fallback = True if cls.is_openai(client) else None
        if model_settings.store is not None:
            return model_settings.store
        return fallback

    @classmethod
    def get_stream_options_param(
        cls, client: AsyncOpenAI, model_settings: ModelSettings, stream: bool
    ) -> dict[str, bool] | None:
        """Return the ``stream_options`` dict to send, or ``None`` to send none.

        Non-streaming requests never get stream options. For streaming
        requests an explicit ``model_settings.include_usage`` wins; otherwise
        usage is included for OpenAI clients and left unset for other clients.
        """
        if not stream:
            return None

        fallback = True if cls.is_openai(client) else None
        if model_settings.include_usage is not None:
            include_usage = model_settings.include_usage
        else:
            include_usage = fallback

        if include_usage is None:
            return None
        return {"include_usage": include_usage}
|
||||
|
|
@ -21,8 +21,8 @@ from ..tracing import generation_span
|
|||
from ..tracing.span_data import GenerationSpanData
|
||||
from ..tracing.spans import Span
|
||||
from ..usage import Usage
|
||||
from ..version import __version__
|
||||
from .chatcmpl_converter import Converter
|
||||
from .chatcmpl_helpers import HEADERS, ChatCmplHelpers
|
||||
from .chatcmpl_stream_handler import ChatCmplStreamHandler
|
||||
from .fake_id import FAKE_RESPONSES_ID
|
||||
from .interface import Model, ModelTracing
|
||||
|
|
@ -31,10 +31,6 @@ if TYPE_CHECKING:
|
|||
from ..model_settings import ModelSettings
|
||||
|
||||
|
||||
_USER_AGENT = f"Agents/Python {__version__}"
|
||||
_HEADERS = {"User-Agent": _USER_AGENT}
|
||||
|
||||
|
||||
class OpenAIChatCompletionsModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -237,9 +233,9 @@ class OpenAIChatCompletionsModel(Model):
|
|||
)
|
||||
|
||||
reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
|
||||
store = _Helpers.get_store_param(self._get_client(), model_settings)
|
||||
store = ChatCmplHelpers.get_store_param(self._get_client(), model_settings)
|
||||
|
||||
stream_options = _Helpers.get_stream_options_param(
|
||||
stream_options = ChatCmplHelpers.get_stream_options_param(
|
||||
self._get_client(), model_settings, stream=stream
|
||||
)
|
||||
|
||||
|
|
@ -259,7 +255,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
stream_options=self._non_null_or_not_given(stream_options),
|
||||
store=self._non_null_or_not_given(store),
|
||||
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
|
||||
extra_headers=_HEADERS,
|
||||
extra_headers=HEADERS,
|
||||
extra_query=model_settings.extra_query,
|
||||
extra_body=model_settings.extra_body,
|
||||
metadata=self._non_null_or_not_given(model_settings.metadata),
|
||||
|
|
@ -289,31 +285,3 @@ class OpenAIChatCompletionsModel(Model):
|
|||
if self._client is None:
|
||||
self._client = AsyncOpenAI()
|
||||
return self._client
|
||||
|
||||
|
||||
class _Helpers:
|
||||
@classmethod
|
||||
def is_openai(cls, client: AsyncOpenAI):
|
||||
return str(client.base_url).startswith("https://api.openai.com")
|
||||
|
||||
@classmethod
|
||||
def get_store_param(cls, client: AsyncOpenAI, model_settings: ModelSettings) -> bool | None:
|
||||
# Match the behavior of Responses where store is True when not given
|
||||
default_store = True if cls.is_openai(client) else None
|
||||
return model_settings.store if model_settings.store is not None else default_store
|
||||
|
||||
@classmethod
|
||||
def get_stream_options_param(
|
||||
cls, client: AsyncOpenAI, model_settings: ModelSettings, stream: bool
|
||||
) -> dict[str, bool] | None:
|
||||
if not stream:
|
||||
return None
|
||||
|
||||
default_include_usage = True if cls.is_openai(client) else None
|
||||
include_usage = (
|
||||
model_settings.include_usage
|
||||
if model_settings.include_usage is not None
|
||||
else default_include_usage
|
||||
)
|
||||
stream_options = {"include_usage": include_usage} if include_usage is not None else None
|
||||
return stream_options
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ def varargs_function(x: int, *numbers: float, flag: bool = False, **kwargs: Any)
|
|||
def test_varargs_function():
|
||||
"""Test a function that uses *args and **kwargs."""
|
||||
|
||||
func_schema = function_schema(varargs_function)
|
||||
func_schema = function_schema(varargs_function, strict_json_schema=False)
|
||||
# Check JSON schema structure
|
||||
assert isinstance(func_schema.params_json_schema, dict)
|
||||
assert func_schema.params_json_schema.get("title") == "varargs_function_args"
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ from agents import (
|
|||
OpenAIProvider,
|
||||
generation_span,
|
||||
)
|
||||
from agents.models.chatcmpl_helpers import ChatCmplHelpers
|
||||
from agents.models.fake_id import FAKE_RESPONSES_ID
|
||||
from agents.models.openai_chatcompletions import _Helpers
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
|
|
@ -301,32 +301,32 @@ def test_store_param():
|
|||
|
||||
model_settings = ModelSettings()
|
||||
client = AsyncOpenAI()
|
||||
assert _Helpers.get_store_param(client, model_settings) is True, (
|
||||
assert ChatCmplHelpers.get_store_param(client, model_settings) is True, (
|
||||
"Should default to True for OpenAI API calls"
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(store=False)
|
||||
assert _Helpers.get_store_param(client, model_settings) is False, (
|
||||
assert ChatCmplHelpers.get_store_param(client, model_settings) is False, (
|
||||
"Should respect explicitly set store=False"
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(store=True)
|
||||
assert _Helpers.get_store_param(client, model_settings) is True, (
|
||||
assert ChatCmplHelpers.get_store_param(client, model_settings) is True, (
|
||||
"Should respect explicitly set store=True"
|
||||
)
|
||||
|
||||
client = AsyncOpenAI(base_url="http://www.notopenai.com")
|
||||
model_settings = ModelSettings()
|
||||
assert _Helpers.get_store_param(client, model_settings) is None, (
|
||||
assert ChatCmplHelpers.get_store_param(client, model_settings) is None, (
|
||||
"Should default to None for non-OpenAI API calls"
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(store=False)
|
||||
assert _Helpers.get_store_param(client, model_settings) is False, (
|
||||
assert ChatCmplHelpers.get_store_param(client, model_settings) is False, (
|
||||
"Should respect explicitly set store=False"
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(store=True)
|
||||
assert _Helpers.get_store_param(client, model_settings) is True, (
|
||||
assert ChatCmplHelpers.get_store_param(client, model_settings) is True, (
|
||||
"Should respect explicitly set store=True"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
# Skip voice tests on Python 3.9
|
||||
def pytest_ignore_collect(collection_path, config):
|
||||
if sys.version_info[:2] == (3, 9):
|
||||
this_dir = os.path.dirname(__file__)
|
||||
skip_marker = pytest.mark.skip(reason="Skipped on Python 3.9")
|
||||
|
||||
for item in items:
|
||||
if item.fspath.dirname.startswith(this_dir):
|
||||
item.add_marker(skip_marker)
|
||||
if str(collection_path).startswith(this_dir):
|
||||
return True
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue