Add examples and documentation for using custom model providers

This commit is contained in:
Rohan Mehta 2025-03-12 16:06:17 -07:00
parent 97a09067cf
commit 25a633139d
8 changed files with 247 additions and 34 deletions

View file

@ -53,21 +53,14 @@ async def main():
## Using other LLM providers
Many providers also support the OpenAI API format, which means you can pass a `base_url` to the existing OpenAI model implementations and use them easily. `ModelSettings` is used to configure tuning parameters (e.g., temperature, top_p) for the model you select.
You can use other LLM providers in 3 ways (examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/)):
```python
external_client = AsyncOpenAI(
api_key="EXTERNAL_API_KEY",
base_url="https://api.external.com/v1/",
)
1. [`set_default_openai_client`][agents.set_default_openai_client] is useful in cases where you want to globally use an instance of `AsyncOpenAI` as the LLM client. This is for cases where the LLM provider has an OpenAI compatible API endpoint, and you can set the `base_url` and `api_key`. See a configurable example in [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py).
2. [`ModelProvider`][agents.models.interface.ModelProvider] is at the `Runner.run` level. This lets you say "use a custom model provider for all agents in this run". See a configurable example in [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py).
3. [`Agent.model`][agents.agent.Agent.model] lets you specify the model on a specific Agent instance. This enables you to mix and match different providers for different agents. See a configurable example in [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py).
spanish_agent = Agent(
name="Spanish agent",
instructions="You only speak Spanish.",
model=OpenAIChatCompletionsModel(
model="EXTERNAL_MODEL_NAME",
openai_client=external_client,
),
model_settings=ModelSettings(temperature=0.5),
)
```
In cases where you do not have an API key from `platform.openai.com`, we recommend disabling tracing via `set_tracing_disabled()`, or setting up a [different tracing processor](tracing.md).
!!! note
In these examples, we use the Chat Completions API/model, because most LLM providers don't yet support the Responses API. If your LLM provider does support it, we recommend using Responses.

View file

@ -0,0 +1,19 @@
# Custom LLM providers
The examples in this directory demonstrate how you might use a non-OpenAI LLM provider. To run them, first set a base URL, API key and model.
```bash
export EXAMPLE_BASE_URL="..."
export EXAMPLE_API_KEY="..."
export EXAMPLE_MODEL_NAME"..."
```
Then run the examples, e.g.:
```
python examples/model_providers/custom_example_provider.py
Loops within themselves,
Function calls its own being,
Depth without ending.
```

View file

@ -0,0 +1,51 @@
import asyncio
import os
from openai import AsyncOpenAI
from agents import Agent, OpenAIChatCompletionsModel, Runner, set_tracing_disabled
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
if not BASE_URL or not API_KEY or not MODEL_NAME:
raise ValueError(
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
)
"""This example uses a custom provider for a specific agent. Steps:
1. Create a custom OpenAI client.
2. Create a `Model` that uses the custom client.
3. Set the `model` on the Agent.
Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing specific key.
"""
client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
set_tracing_disabled(disabled=True)
# An alternate approach that would also work:
# PROVIDER = OpenAIProvider(openai_client=client)
# agent = Agent(..., model="some-custom-model")
# Runner.run(agent, ..., run_config=RunConfig(model_provider=PROVIDER))
async def main():
# This agent will use the custom LLM provider
agent = Agent(
name="Assistant",
instructions="You only respond in haikus.",
model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),
)
result = await Runner.run(
agent,
"Tell me about recursion in programming.",
)
print(result.final_output)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,55 @@
import asyncio
import os
from openai import AsyncOpenAI
from agents import (
Agent,
Runner,
set_default_openai_api,
set_default_openai_client,
set_tracing_disabled,
)
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
if not BASE_URL or not API_KEY or not MODEL_NAME:
raise ValueError(
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
)
"""This example uses a custom provider for all requests by default. We do three things:
1. Create a custom client.
2. Set it as the default OpenAI client, and don't use it for tracing.
3. Set the default API as Chat Completions, as most LLM providers don't yet support Responses API.
Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing specific key.
"""
client = AsyncOpenAI(
base_url=BASE_URL,
api_key=API_KEY,
)
set_default_openai_client(client=client, use_for_tracing=False)
set_default_openai_api("chat_completions")
set_tracing_disabled(disabled=True)
async def main():
agent = Agent(
name="Assistant",
instructions="You only respond in haikus.",
model=MODEL_NAME,
)
result = await Runner.run(agent, "Tell me about recursion in programming.")
print(result.final_output)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,73 @@
from __future__ import annotations
import asyncio
import os
from openai import AsyncOpenAI
from agents import (
Agent,
Model,
ModelProvider,
OpenAIChatCompletionsModel,
RunConfig,
Runner,
set_tracing_disabled,
)
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
if not BASE_URL or not API_KEY or not MODEL_NAME:
raise ValueError(
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
)
"""This example uses a custom provider for some calls to Runner.run(), and direct calls to OpenAI for
others. Steps:
1. Create a custom OpenAI client.
2. Create a ModelProvider that uses the custom client.
3. Use the ModelProvider in calls to Runner.run(), only when we want to use the custom LLM provider.
Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing specific key.
"""
client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
set_tracing_disabled(disabled=True)
class CustomModelProvider(ModelProvider):
def get_model(self, model_name: str | None) -> Model:
return OpenAIChatCompletionsModel(model=model_name or MODEL_NAME, openai_client=client)
CUSTOM_MODEL_PROVIDER = CustomModelProvider()
async def main():
agent = Agent(
name="Assistant",
instructions="You only respond in haikus.",
)
# This will use the custom model provider
result = await Runner.run(
agent,
"Tell me about recursion in programming.",
run_config=RunConfig(model_provider=CUSTOM_MODEL_PROVIDER),
)
print(result.final_output)
# If you uncomment this, it will use OpenAI directly, not the custom provider
# result = await Runner.run(
# agent,
# "Tell me about recursion in programming.",
# )
# print(result.final_output)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -92,13 +92,19 @@ from .tracing import (
from .usage import Usage
def set_default_openai_key(key: str) -> None:
"""Set the default OpenAI API key to use for LLM requests and tracing. This is only necessary if
the OPENAI_API_KEY environment variable is not already set.
def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
"""Set the default OpenAI API key to use for LLM requests (and optionally tracing(). This is
only necessary if the OPENAI_API_KEY environment variable is not already set.
If provided, this key will be used instead of the OPENAI_API_KEY environment variable.
Args:
key: The OpenAI key to use.
use_for_tracing: Whether to also use this key to send traces to OpenAI. Defaults to True
If False, you'll either need to set the OPENAI_API_KEY environment variable or call
set_tracing_export_api_key() with the API key you want to use for tracing.
"""
_config.set_default_openai_key(key)
_config.set_default_openai_key(key, use_for_tracing)
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None:

View file

@ -5,15 +5,18 @@ from .models import _openai_shared
from .tracing import set_tracing_export_api_key
def set_default_openai_key(key: str) -> None:
set_tracing_export_api_key(key)
def set_default_openai_key(key: str, use_for_tracing: bool) -> None:
_openai_shared.set_default_openai_key(key)
if use_for_tracing:
set_tracing_export_api_key(key)
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool) -> None:
_openai_shared.set_default_openai_client(client)
if use_for_tracing:
set_tracing_export_api_key(client.api_key)
_openai_shared.set_default_openai_client(client)
def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:

View file

@ -38,28 +38,41 @@ class OpenAIProvider(ModelProvider):
assert api_key is None and base_url is None, (
"Don't provide api_key or base_url if you provide openai_client"
)
self._client = openai_client
self._client: AsyncOpenAI | None = openai_client
else:
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
api_key=api_key or _openai_shared.get_default_openai_key(),
base_url=base_url,
organization=organization,
project=project,
http_client=shared_http_client(),
)
self._client = None
self._stored_api_key = api_key
self._stored_base_url = base_url
self._stored_organization = organization
self._stored_project = project
self._is_openai_model = self._client.base_url.host.startswith("api.openai.com")
if use_responses is not None:
self._use_responses = use_responses
else:
self._use_responses = _openai_shared.get_use_responses_by_default()
# We lazy load the client in case you never actually use OpenAIProvider(). Otherwise
# AsyncOpenAI() raises an error if you don't have an API key set.
def _get_client(self) -> AsyncOpenAI:
if self._client is None:
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
api_key=self._stored_api_key or _openai_shared.get_default_openai_key(),
base_url=self._stored_base_url,
organization=self._stored_organization,
project=self._stored_project,
http_client=shared_http_client(),
)
return self._client
def get_model(self, model_name: str | None) -> Model:
if model_name is None:
model_name = DEFAULT_MODEL
client = self._get_client()
return (
OpenAIResponsesModel(model=model_name, openai_client=self._client)
OpenAIResponsesModel(model=model_name, openai_client=client)
if self._use_responses
else OpenAIChatCompletionsModel(model=model_name, openai_client=self._client)
else OpenAIChatCompletionsModel(model=model_name, openai_client=client)
)