RFC: automatically use litellm if possible (#534)

## Summary
This replaces the default model provider with a `MultiProvider`, which
has the logic:
- if the model name starts with `openai/` or doesn't contain "/", use
OpenAI
- if the model name starts with `litellm/`, use LiteLLM to use the
appropriate model provider.

It's also extensible, so users can create their own mappings. I also
imagine that if we natively supported Anthropic/Gemini etc, we can add
it to MultiProvider to make it work.

The goal is that it should be really easy to use any model provider.
Today if you pass `model="gpt-4.1"`, it works great. But
`model="claude-sonnet-3.7"` doesn't. If we can make it that easy, it's a
win for devx.

I'm not entirely sure if this is a good idea - is it too magical? Is the
API too reliant on litellm? Comments welcome.

## Test plan

For now, the example. Will add unit tests if we agree its worth mergin.

---------

Co-authored-by: Steven Heidel <steven@heidel.ca>
This commit is contained in:
Rohan Mehta 2025-04-21 15:03:06 -04:00 committed by GitHub
parent 0a3dfa071a
commit a0254b0b74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 208 additions and 2 deletions

View file

@ -0,0 +1,41 @@
from __future__ import annotations
import asyncio
from agents import Agent, Runner, function_tool, set_tracing_disabled
"""This example uses the built-in support for LiteLLM. To use this, ensure you have the
ANTHROPIC_API_KEY environment variable set.
"""
set_tracing_disabled(disabled=True)
@function_tool
def get_weather(city: str):
print(f"[debug] getting weather for {city}")
return f"The weather in {city} is sunny."
async def main():
agent = Agent(
name="Assistant",
instructions="You only respond in haikus.",
# We prefix with litellm/ to tell the Runner to use the LitellmModel
model="litellm/anthropic/claude-3-5-sonnet-20240620",
tools=[get_weather],
)
result = await Runner.run(agent, "What's the weather in Tokyo?")
print(result.final_output)
if __name__ == "__main__":
import os
if os.getenv("ANTHROPIC_API_KEY") is None:
raise ValueError(
"ANTHROPIC_API_KEY is not set. Please set it the environment variable and try again."
)
asyncio.run(main())

View file

@ -0,0 +1,21 @@
from ...models.interface import Model, ModelProvider
from .litellm_model import LitellmModel
DEFAULT_MODEL: str = "gpt-4.1"
class LitellmProvider(ModelProvider):
"""A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
```python
Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
```
See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
NOTE: API keys must be set via environment variables. If you're using models that require
additional configuration (e.g. Azure API base or version), those must also be set via the
environment variables that LiteLLM expects. If you have more advanced needs, we recommend
copy-pasting this class and making any modifications you need.
"""
def get_model(self, model_name: str | None) -> Model:
return LitellmModel(model_name or DEFAULT_MODEL)

View file

@ -0,0 +1,144 @@
from __future__ import annotations
from openai import AsyncOpenAI
from ..exceptions import UserError
from .interface import Model, ModelProvider
from .openai_provider import OpenAIProvider
class MultiProviderMap:
"""A map of model name prefixes to ModelProviders."""
def __init__(self):
self._mapping: dict[str, ModelProvider] = {}
def has_prefix(self, prefix: str) -> bool:
"""Returns True if the given prefix is in the mapping."""
return prefix in self._mapping
def get_mapping(self) -> dict[str, ModelProvider]:
"""Returns a copy of the current prefix -> ModelProvider mapping."""
return self._mapping.copy()
def set_mapping(self, mapping: dict[str, ModelProvider]):
"""Overwrites the current mapping with a new one."""
self._mapping = mapping
def get_provider(self, prefix: str) -> ModelProvider | None:
"""Returns the ModelProvider for the given prefix.
Args:
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
"""
return self._mapping.get(prefix)
def add_provider(self, prefix: str, provider: ModelProvider):
"""Adds a new prefix -> ModelProvider mapping.
Args:
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
provider: The ModelProvider to use for the given prefix.
"""
self._mapping[prefix] = provider
def remove_provider(self, prefix: str):
"""Removes the mapping for the given prefix.
Args:
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
"""
del self._mapping[prefix]
class MultiProvider(ModelProvider):
"""This ModelProvider maps to a Model based on the prefix of the model name. By default, the
mapping is:
- "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
- "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"
You can override or customize this mapping.
"""
def __init__(
self,
*,
provider_map: MultiProviderMap | None = None,
openai_api_key: str | None = None,
openai_base_url: str | None = None,
openai_client: AsyncOpenAI | None = None,
openai_organization: str | None = None,
openai_project: str | None = None,
openai_use_responses: bool | None = None,
) -> None:
"""Create a new OpenAI provider.
Args:
provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided,
we will use a default mapping. See the documentation for this class to see the
default mapping.
openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use
the default API key.
openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will
use the default base URL.
openai_client: An optional OpenAI client to use. If not provided, we will create a new
OpenAI client using the api_key and base_url.
openai_organization: The organization to use for the OpenAI provider.
openai_project: The project to use for the OpenAI provider.
openai_use_responses: Whether to use the OpenAI responses API.
"""
self.provider_map = provider_map
self.openai_provider = OpenAIProvider(
api_key=openai_api_key,
base_url=openai_base_url,
openai_client=openai_client,
organization=openai_organization,
project=openai_project,
use_responses=openai_use_responses,
)
self._fallback_providers: dict[str, ModelProvider] = {}
def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
if model_name is None:
return None, None
elif "/" in model_name:
prefix, model_name = model_name.split("/", 1)
return prefix, model_name
else:
return None, model_name
def _create_fallback_provider(self, prefix: str) -> ModelProvider:
if prefix == "litellm":
from ..extensions.models.litellm_provider import LitellmProvider
return LitellmProvider()
else:
raise UserError(f"Unknown prefix: {prefix}")
def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
if prefix is None or prefix == "openai":
return self.openai_provider
elif prefix in self._fallback_providers:
return self._fallback_providers[prefix]
else:
self._fallback_providers[prefix] = self._create_fallback_provider(prefix)
return self._fallback_providers[prefix]
def get_model(self, model_name: str | None) -> Model:
"""Returns a Model based on the model name. The model name can have a prefix, ending with
a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use
the OpenAI provider.
Args:
model_name: The name of the model to get.
Returns:
A Model.
"""
prefix, model_name = self._get_prefix_and_model_name(model_name)
if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
return provider.get_model(model_name)
else:
return self._get_fallback_provider(prefix).get_model(model_name)

View file

@ -34,7 +34,7 @@ from .lifecycle import RunHooks
from .logger import logger
from .model_settings import ModelSettings
from .models.interface import Model, ModelProvider
from .models.openai_provider import OpenAIProvider
from .models.multi_provider import MultiProvider
from .result import RunResult, RunResultStreaming
from .run_context import RunContextWrapper, TContext
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
@ -56,7 +56,7 @@ class RunConfig:
agent. The model_provider passed in below must be able to resolve this model name.
"""
model_provider: ModelProvider = field(default_factory=OpenAIProvider)
model_provider: ModelProvider = field(default_factory=MultiProvider)
"""The model provider to use when looking up string model names. Defaults to OpenAI."""
model_settings: ModelSettings | None = None