RFC: automatically use litellm if possible (#534)
## Summary This replaces the default model provider with a `MultiProvider`, which has the logic: - if the model name starts with `openai/` or doesn't contain "/", use OpenAI - if the model name starts with `litellm/`, use LiteLLM to route to the appropriate model provider. It's also extensible, so users can create their own mappings. I also imagine that if we natively supported Anthropic/Gemini etc., we could add them to MultiProvider to make it work. The goal is that it should be really easy to use any model provider. Today if you pass `model="gpt-4.1"`, it works great. But `model="claude-sonnet-3.7"` doesn't. If we can make it that easy, it's a win for devx. I'm not entirely sure if this is a good idea - is it too magical? Is the API too reliant on litellm? Comments welcome. ## Test plan For now, the example. Will add unit tests if we agree it's worth merging. --------- Co-authored-by: Steven Heidel <steven@heidel.ca>
This commit is contained in:
parent
0a3dfa071a
commit
a0254b0b74
4 changed files with 208 additions and 2 deletions
41
examples/model_providers/litellm_auto.py
Normal file
41
examples/model_providers/litellm_auto.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""This example uses the built-in support for LiteLLM.

To use this, ensure you have the ANTHROPIC_API_KEY environment variable set.
"""

# The description above must be the first statement in the module to be its
# docstring; `from __future__` imports are still allowed to follow it.
from __future__ import annotations

import asyncio

from agents import Agent, Runner, function_tool, set_tracing_disabled

# Turn tracing off for this example run.
# NOTE(review): presumably because trace export expects an OpenAI API key - confirm.
set_tracing_disabled(disabled=True)
|
||||
|
||||
|
||||
@function_tool
def get_weather(city: str) -> str:
    """Return a canned weather report for the given city.

    Args:
        city: Name of the city to report on.

    Returns:
        A fixed sentence saying the weather in `city` is sunny.
    """
    # Debug trace so the console shows when the tool is actually invoked.
    print(f"[debug] getting weather for {city}")
    return f"The weather in {city} is sunny."
|
||||
|
||||
|
||||
async def main():
    """Ask a haiku-only assistant about Tokyo's weather and print its final answer."""
    # We prefix with litellm/ to tell the Runner to use the LitellmModel
    model_name = "litellm/anthropic/claude-3-5-sonnet-20240620"

    haiku_agent = Agent(
        name="Assistant",
        instructions="You only respond in haikus.",
        model=model_name,
        tools=[get_weather],
    )

    run_result = await Runner.run(haiku_agent, "What's the weather in Tokyo?")
    print(run_result.final_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import os

    # Check for the key up front so a missing variable produces a clear message
    # before any model call is attempted.
    if os.getenv("ANTHROPIC_API_KEY") is None:
        raise ValueError(
            "ANTHROPIC_API_KEY is not set. Please set the environment variable and try again."
        )

    asyncio.run(main())
|
||||
21
src/agents/extensions/models/litellm_provider.py
Normal file
21
src/agents/extensions/models/litellm_provider.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from ...models.interface import Model, ModelProvider
|
||||
from .litellm_model import LitellmModel
|
||||
|
||||
# Model name used when get_model() is called with None or an empty string.
DEFAULT_MODEL: str = "gpt-4.1"


class LitellmProvider(ModelProvider):
    """A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
    ```python
    Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
    ```
    See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).

    NOTE: API keys must be set via environment variables. If you're using models that require
    additional configuration (e.g. Azure API base or version), those must also be set via the
    environment variables that LiteLLM expects. If you have more advanced needs, we recommend
    copy-pasting this class and making any modifications you need.
    """

    # The parameter annotation is quoted on purpose: this module does not use
    # `from __future__ import annotations`, so an unquoted `str | None` would be
    # evaluated when the method is defined and raise TypeError on Python < 3.10.
    def get_model(self, model_name: "str | None") -> Model:
        """Return a LitellmModel for the given model name.

        Args:
            model_name: The LiteLLM model name. When None or empty, DEFAULT_MODEL is used.

        Returns:
            A LitellmModel constructed with the resolved model name.
        """
        return LitellmModel(model_name or DEFAULT_MODEL)
|
||||
144
src/agents/models/multi_provider.py
Normal file
144
src/agents/models/multi_provider.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..exceptions import UserError
|
||||
from .interface import Model, ModelProvider
|
||||
from .openai_provider import OpenAIProvider
|
||||
|
||||
|
||||
class MultiProviderMap:
    """A map of model name prefixes to ModelProviders."""

    def __init__(self) -> None:
        """Create an empty prefix -> ModelProvider map."""
        self._mapping: dict[str, ModelProvider] = {}

    def has_prefix(self, prefix: str) -> bool:
        """Returns True if the given prefix is in the mapping."""
        return prefix in self._mapping

    def get_mapping(self) -> dict[str, ModelProvider]:
        """Returns a copy of the current prefix -> ModelProvider mapping."""
        return self._mapping.copy()

    def set_mapping(self, mapping: dict[str, ModelProvider]) -> None:
        """Overwrites the current mapping with a new one.

        Args:
            mapping: The new prefix -> ModelProvider mapping. It is copied, so changing
                the passed dict afterwards does not affect this map, and vice versa.
        """
        # Copy rather than alias: get_mapping() hands out copies, and add_provider() /
        # remove_provider() must not mutate a dict the caller still owns.
        self._mapping = dict(mapping)

    def get_provider(self, prefix: str) -> ModelProvider | None:
        """Returns the ModelProvider for the given prefix, or None if it is not mapped.

        Args:
            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
        """
        return self._mapping.get(prefix)

    def add_provider(self, prefix: str, provider: ModelProvider) -> None:
        """Adds a new prefix -> ModelProvider mapping, replacing any existing one.

        Args:
            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
            provider: The ModelProvider to use for the given prefix.
        """
        self._mapping[prefix] = provider

    def remove_provider(self, prefix: str) -> None:
        """Removes the mapping for the given prefix.

        Args:
            prefix: The prefix of the model name e.g. "openai" or "my_prefix".

        Raises:
            KeyError: If the prefix is not in the mapping.
        """
        del self._mapping[prefix]
|
||||
|
||||
|
||||
class MultiProvider(ModelProvider):
    """This ModelProvider maps to a Model based on the prefix of the model name. By default, the
    mapping is:
    - "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
    - "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"

    You can override or customize this mapping.
    """

    def __init__(
        self,
        *,
        provider_map: MultiProviderMap | None = None,
        openai_api_key: str | None = None,
        openai_base_url: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        openai_organization: str | None = None,
        openai_project: str | None = None,
        openai_use_responses: bool | None = None,
    ) -> None:
        """Create a new MultiProvider.

        Args:
            provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided,
                we will use a default mapping. See the documentation for this class to see the
                default mapping.
            openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use
                the default API key.
            openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will
                use the default base URL.
            openai_client: An optional OpenAI client to use. If not provided, we will create a new
                OpenAI client using the api_key and base_url.
            openai_organization: The organization to use for the OpenAI provider.
            openai_project: The project to use for the OpenAI provider.
            openai_use_responses: Whether to use the OpenAI responses API.
        """
        self.provider_map = provider_map
        self.openai_provider = OpenAIProvider(
            api_key=openai_api_key,
            base_url=openai_base_url,
            openai_client=openai_client,
            organization=openai_organization,
            project=openai_project,
            use_responses=openai_use_responses,
        )

        # Cache of built-in providers created on demand, keyed by prefix.
        self._fallback_providers: dict[str, ModelProvider] = {}

    def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
        """Split a model name into (prefix, remaining name) at the first "/".

        Returns (None, model_name) when there is no "/" (or the name is None).
        """
        if model_name is None:
            return None, None
        if "/" not in model_name:
            return None, model_name

        # Only the first "/" separates the prefix; the rest may contain further slashes,
        # e.g. "litellm/anthropic/claude" -> ("litellm", "anthropic/claude").
        prefix, _, remainder = model_name.partition("/")
        return prefix, remainder

    def _create_fallback_provider(self, prefix: str) -> ModelProvider:
        """Build the built-in provider for a prefix.

        Raises:
            UserError: If the prefix has no built-in provider.
        """
        if prefix == "litellm":
            # Imported here rather than at module level so the LiteLLM extension is only
            # loaded when a "litellm/" model is actually requested.
            from ..extensions.models.litellm_provider import LitellmProvider

            return LitellmProvider()

        raise UserError(f"Unknown prefix: {prefix}")

    def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
        """Return the built-in provider for a prefix, creating and caching it on first use."""
        if prefix is None or prefix == "openai":
            return self.openai_provider

        provider = self._fallback_providers.get(prefix)
        if provider is None:
            # Create before storing, so a UserError for an unknown prefix leaves the
            # cache untouched.
            provider = self._create_fallback_provider(prefix)
            self._fallback_providers[prefix] = provider
        return provider

    def get_model(self, model_name: str | None) -> Model:
        """Returns a Model based on the model name. The model name can have a prefix, ending with
        a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use
        the OpenAI provider.

        Args:
            model_name: The name of the model to get.

        Returns:
            A Model.

        Raises:
            UserError: If the prefix is neither in the provider map nor a built-in prefix.
        """
        prefix, model_name = self._get_prefix_and_model_name(model_name)

        # A user-supplied mapping takes precedence over the built-in prefixes.
        if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
            return provider.get_model(model_name)

        return self._get_fallback_provider(prefix).get_model(model_name)
|
||||
|
|
@ -34,7 +34,7 @@ from .lifecycle import RunHooks
|
|||
from .logger import logger
|
||||
from .model_settings import ModelSettings
|
||||
from .models.interface import Model, ModelProvider
|
||||
from .models.openai_provider import OpenAIProvider
|
||||
from .models.multi_provider import MultiProvider
|
||||
from .result import RunResult, RunResultStreaming
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
|
||||
|
|
@ -56,7 +56,7 @@ class RunConfig:
|
|||
agent. The model_provider passed in below must be able to resolve this model name.
|
||||
"""
|
||||
|
||||
model_provider: ModelProvider = field(default_factory=OpenAIProvider)
|
||||
model_provider: ModelProvider = field(default_factory=MultiProvider)
|
||||
"""The model provider to use when looking up string model names. Defaults to OpenAI."""
|
||||
|
||||
model_settings: ModelSettings | None = None
|
||||
|
|
|
|||
Loading…
Reference in a new issue