RFC: automatically use litellm if possible (#534)
## Summary This replaces the default model provider with a `MultiProvider`, which has the logic: - if the model name starts with `openai/` or doesn't contain "/", use OpenAI - if the model name starts with `litellm/`, use LiteLLM to use the appropriate model provider. It's also extensible, so users can create their own mappings. I also imagine that if we natively supported Anthropic/Gemini etc, we can add it to MultiProvider to make it work. The goal is that it should be really easy to use any model provider. Today if you pass `model="gpt-4.1"`, it works great. But `model="claude-sonnet-3.7"` doesn't. If we can make it that easy, it's a win for devx. I'm not entirely sure if this is a good idea - is it too magical? Is the API too reliant on litellm? Comments welcome. ## Test plan For now, the example. Will add unit tests if we agree its worth mergin. --------- Co-authored-by: Steven Heidel <steven@heidel.ca>
This commit is contained in:
parent
0a3dfa071a
commit
a0254b0b74
4 changed files with 208 additions and 2 deletions
41
examples/model_providers/litellm_auto.py
Normal file
41
examples/model_providers/litellm_auto.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from agents import Agent, Runner, function_tool, set_tracing_disabled
|
||||||
|
|
||||||
|
"""This example uses the built-in support for LiteLLM. To use this, ensure you have the
|
||||||
|
ANTHROPIC_API_KEY environment variable set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
set_tracing_disabled(disabled=True)
|
||||||
|
|
||||||
|
|
||||||
|
@function_tool
|
||||||
|
def get_weather(city: str):
|
||||||
|
print(f"[debug] getting weather for {city}")
|
||||||
|
return f"The weather in {city} is sunny."
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
agent = Agent(
|
||||||
|
name="Assistant",
|
||||||
|
instructions="You only respond in haikus.",
|
||||||
|
# We prefix with litellm/ to tell the Runner to use the LitellmModel
|
||||||
|
model="litellm/anthropic/claude-3-5-sonnet-20240620",
|
||||||
|
tools=[get_weather],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await Runner.run(agent, "What's the weather in Tokyo?")
|
||||||
|
print(result.final_output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.getenv("ANTHROPIC_API_KEY") is None:
|
||||||
|
raise ValueError(
|
||||||
|
"ANTHROPIC_API_KEY is not set. Please set it the environment variable and try again."
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
21
src/agents/extensions/models/litellm_provider.py
Normal file
21
src/agents/extensions/models/litellm_provider.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
from ...models.interface import Model, ModelProvider
|
||||||
|
from .litellm_model import LitellmModel
|
||||||
|
|
||||||
|
DEFAULT_MODEL: str = "gpt-4.1"
|
||||||
|
|
||||||
|
|
||||||
|
class LitellmProvider(ModelProvider):
|
||||||
|
"""A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
|
||||||
|
```python
|
||||||
|
Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
|
||||||
|
```
|
||||||
|
See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
|
||||||
|
|
||||||
|
NOTE: API keys must be set via environment variables. If you're using models that require
|
||||||
|
additional configuration (e.g. Azure API base or version), those must also be set via the
|
||||||
|
environment variables that LiteLLM expects. If you have more advanced needs, we recommend
|
||||||
|
copy-pasting this class and making any modifications you need.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_model(self, model_name: str | None) -> Model:
|
||||||
|
return LitellmModel(model_name or DEFAULT_MODEL)
|
||||||
144
src/agents/models/multi_provider.py
Normal file
144
src/agents/models/multi_provider.py
Normal file
|
|
@ -0,0 +1,144 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from ..exceptions import UserError
|
||||||
|
from .interface import Model, ModelProvider
|
||||||
|
from .openai_provider import OpenAIProvider
|
||||||
|
|
||||||
|
|
||||||
|
class MultiProviderMap:
|
||||||
|
"""A map of model name prefixes to ModelProviders."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._mapping: dict[str, ModelProvider] = {}
|
||||||
|
|
||||||
|
def has_prefix(self, prefix: str) -> bool:
|
||||||
|
"""Returns True if the given prefix is in the mapping."""
|
||||||
|
return prefix in self._mapping
|
||||||
|
|
||||||
|
def get_mapping(self) -> dict[str, ModelProvider]:
|
||||||
|
"""Returns a copy of the current prefix -> ModelProvider mapping."""
|
||||||
|
return self._mapping.copy()
|
||||||
|
|
||||||
|
def set_mapping(self, mapping: dict[str, ModelProvider]):
|
||||||
|
"""Overwrites the current mapping with a new one."""
|
||||||
|
self._mapping = mapping
|
||||||
|
|
||||||
|
def get_provider(self, prefix: str) -> ModelProvider | None:
|
||||||
|
"""Returns the ModelProvider for the given prefix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
|
||||||
|
"""
|
||||||
|
return self._mapping.get(prefix)
|
||||||
|
|
||||||
|
def add_provider(self, prefix: str, provider: ModelProvider):
|
||||||
|
"""Adds a new prefix -> ModelProvider mapping.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
|
||||||
|
provider: The ModelProvider to use for the given prefix.
|
||||||
|
"""
|
||||||
|
self._mapping[prefix] = provider
|
||||||
|
|
||||||
|
def remove_provider(self, prefix: str):
|
||||||
|
"""Removes the mapping for the given prefix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
|
||||||
|
"""
|
||||||
|
del self._mapping[prefix]
|
||||||
|
|
||||||
|
|
||||||
|
class MultiProvider(ModelProvider):
|
||||||
|
"""This ModelProvider maps to a Model based on the prefix of the model name. By default, the
|
||||||
|
mapping is:
|
||||||
|
- "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
|
||||||
|
- "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"
|
||||||
|
|
||||||
|
You can override or customize this mapping.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
provider_map: MultiProviderMap | None = None,
|
||||||
|
openai_api_key: str | None = None,
|
||||||
|
openai_base_url: str | None = None,
|
||||||
|
openai_client: AsyncOpenAI | None = None,
|
||||||
|
openai_organization: str | None = None,
|
||||||
|
openai_project: str | None = None,
|
||||||
|
openai_use_responses: bool | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Create a new OpenAI provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided,
|
||||||
|
we will use a default mapping. See the documentation for this class to see the
|
||||||
|
default mapping.
|
||||||
|
openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use
|
||||||
|
the default API key.
|
||||||
|
openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will
|
||||||
|
use the default base URL.
|
||||||
|
openai_client: An optional OpenAI client to use. If not provided, we will create a new
|
||||||
|
OpenAI client using the api_key and base_url.
|
||||||
|
openai_organization: The organization to use for the OpenAI provider.
|
||||||
|
openai_project: The project to use for the OpenAI provider.
|
||||||
|
openai_use_responses: Whether to use the OpenAI responses API.
|
||||||
|
"""
|
||||||
|
self.provider_map = provider_map
|
||||||
|
self.openai_provider = OpenAIProvider(
|
||||||
|
api_key=openai_api_key,
|
||||||
|
base_url=openai_base_url,
|
||||||
|
openai_client=openai_client,
|
||||||
|
organization=openai_organization,
|
||||||
|
project=openai_project,
|
||||||
|
use_responses=openai_use_responses,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._fallback_providers: dict[str, ModelProvider] = {}
|
||||||
|
|
||||||
|
def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
|
||||||
|
if model_name is None:
|
||||||
|
return None, None
|
||||||
|
elif "/" in model_name:
|
||||||
|
prefix, model_name = model_name.split("/", 1)
|
||||||
|
return prefix, model_name
|
||||||
|
else:
|
||||||
|
return None, model_name
|
||||||
|
|
||||||
|
def _create_fallback_provider(self, prefix: str) -> ModelProvider:
|
||||||
|
if prefix == "litellm":
|
||||||
|
from ..extensions.models.litellm_provider import LitellmProvider
|
||||||
|
|
||||||
|
return LitellmProvider()
|
||||||
|
else:
|
||||||
|
raise UserError(f"Unknown prefix: {prefix}")
|
||||||
|
|
||||||
|
def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
|
||||||
|
if prefix is None or prefix == "openai":
|
||||||
|
return self.openai_provider
|
||||||
|
elif prefix in self._fallback_providers:
|
||||||
|
return self._fallback_providers[prefix]
|
||||||
|
else:
|
||||||
|
self._fallback_providers[prefix] = self._create_fallback_provider(prefix)
|
||||||
|
return self._fallback_providers[prefix]
|
||||||
|
|
||||||
|
def get_model(self, model_name: str | None) -> Model:
|
||||||
|
"""Returns a Model based on the model name. The model name can have a prefix, ending with
|
||||||
|
a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use
|
||||||
|
the OpenAI provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: The name of the model to get.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Model.
|
||||||
|
"""
|
||||||
|
prefix, model_name = self._get_prefix_and_model_name(model_name)
|
||||||
|
|
||||||
|
if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
|
||||||
|
return provider.get_model(model_name)
|
||||||
|
else:
|
||||||
|
return self._get_fallback_provider(prefix).get_model(model_name)
|
||||||
|
|
@ -34,7 +34,7 @@ from .lifecycle import RunHooks
|
||||||
from .logger import logger
|
from .logger import logger
|
||||||
from .model_settings import ModelSettings
|
from .model_settings import ModelSettings
|
||||||
from .models.interface import Model, ModelProvider
|
from .models.interface import Model, ModelProvider
|
||||||
from .models.openai_provider import OpenAIProvider
|
from .models.multi_provider import MultiProvider
|
||||||
from .result import RunResult, RunResultStreaming
|
from .result import RunResult, RunResultStreaming
|
||||||
from .run_context import RunContextWrapper, TContext
|
from .run_context import RunContextWrapper, TContext
|
||||||
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
|
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
|
||||||
|
|
@ -56,7 +56,7 @@ class RunConfig:
|
||||||
agent. The model_provider passed in below must be able to resolve this model name.
|
agent. The model_provider passed in below must be able to resolve this model name.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_provider: ModelProvider = field(default_factory=OpenAIProvider)
|
model_provider: ModelProvider = field(default_factory=MultiProvider)
|
||||||
"""The model provider to use when looking up string model names. Defaults to OpenAI."""
|
"""The model provider to use when looking up string model names. Defaults to OpenAI."""
|
||||||
|
|
||||||
model_settings: ModelSettings | None = None
|
model_settings: ModelSettings | None = None
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue