Add examples and documentation for using custom model providers
This commit is contained in:
parent
97a09067cf
commit
25a633139d
8 changed files with 247 additions and 34 deletions
|
|
@ -53,21 +53,14 @@ async def main():
|
||||||
|
|
||||||
## Using other LLM providers
|
## Using other LLM providers
|
||||||
|
|
||||||
Many providers also support the OpenAI API format, which means you can pass a `base_url` to the existing OpenAI model implementations and use them easily. `ModelSettings` is used to configure tuning parameters (e.g., temperature, top_p) for the model you select.
|
You can use other LLM providers in 3 ways (examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/)):
|
||||||
|
|
||||||
```python
|
1. [`set_default_openai_client`][agents.set_default_openai_client] is useful in cases where you want to globally use an instance of `AsyncOpenAI` as the LLM client. This is for cases where the LLM provider has an OpenAI compatible API endpoint, and you can set the `base_url` and `api_key`. See a configurable example in [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py).
|
||||||
external_client = AsyncOpenAI(
|
2. [`ModelProvider`][agents.models.interface.ModelProvider] is at the `Runner.run` level. This lets you say "use a custom model provider for all agents in this run". See a configurable example in [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py).
|
||||||
api_key="EXTERNAL_API_KEY",
|
3. [`Agent.model`][agents.agent.Agent.model] lets you specify the model on a specific Agent instance. This enables you to mix and match different providers for different agents. See a configurable example in [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py).
|
||||||
base_url="https://api.external.com/v1/",
|
|
||||||
)
|
|
||||||
|
|
||||||
spanish_agent = Agent(
|
In cases where you do not have an API key from `platform.openai.com`, we recommend disabling tracing via `set_tracing_disabled()`, or setting up a [different tracing processor](tracing.md).
|
||||||
name="Spanish agent",
|
|
||||||
instructions="You only speak Spanish.",
|
!!! note
|
||||||
model=OpenAIChatCompletionsModel(
|
|
||||||
model="EXTERNAL_MODEL_NAME",
|
In these examples, we use the Chat Completions API/model, because most LLM providers don't yet support the Responses API. If your LLM provider does support it, we recommend using Responses.
|
||||||
openai_client=external_client,
|
|
||||||
),
|
|
||||||
model_settings=ModelSettings(temperature=0.5),
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
|
||||||
19
examples/model_providers/README.md
Normal file
19
examples/model_providers/README.md
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Custom LLM providers
|
||||||
|
|
||||||
|
The examples in this directory demonstrate how you might use a non-OpenAI LLM provider. To run them, first set a base URL, API key and model.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export EXAMPLE_BASE_URL="..."
|
||||||
|
export EXAMPLE_API_KEY="..."
|
||||||
|
export EXAMPLE_MODEL_NAME"..."
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run the examples, e.g.:
|
||||||
|
|
||||||
|
```
|
||||||
|
python examples/model_providers/custom_example_provider.py
|
||||||
|
|
||||||
|
Loops within themselves,
|
||||||
|
Function calls its own being,
|
||||||
|
Depth without ending.
|
||||||
|
```
|
||||||
51
examples/model_providers/custom_example_agent.py
Normal file
51
examples/model_providers/custom_example_agent.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from agents import Agent, OpenAIChatCompletionsModel, Runner, set_tracing_disabled
|
||||||
|
|
||||||
|
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
|
||||||
|
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
|
||||||
|
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
|
||||||
|
|
||||||
|
if not BASE_URL or not API_KEY or not MODEL_NAME:
|
||||||
|
raise ValueError(
|
||||||
|
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
|
||||||
|
)
|
||||||
|
|
||||||
|
"""This example uses a custom provider for a specific agent. Steps:
|
||||||
|
1. Create a custom OpenAI client.
|
||||||
|
2. Create a `Model` that uses the custom client.
|
||||||
|
3. Set the `model` on the Agent.
|
||||||
|
|
||||||
|
Note that in this example, we disable tracing under the assumption that you don't have an API key
|
||||||
|
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
|
||||||
|
or call set_tracing_export_api_key() to set a tracing specific key.
|
||||||
|
"""
|
||||||
|
client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
|
||||||
|
set_tracing_disabled(disabled=True)
|
||||||
|
|
||||||
|
# An alternate approach that would also work:
|
||||||
|
# PROVIDER = OpenAIProvider(openai_client=client)
|
||||||
|
# agent = Agent(..., model="some-custom-model")
|
||||||
|
# Runner.run(agent, ..., run_config=RunConfig(model_provider=PROVIDER))
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# This agent will use the custom LLM provider
|
||||||
|
agent = Agent(
|
||||||
|
name="Assistant",
|
||||||
|
instructions="You only respond in haikus.",
|
||||||
|
model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await Runner.run(
|
||||||
|
agent,
|
||||||
|
"Tell me about recursion in programming.",
|
||||||
|
)
|
||||||
|
print(result.final_output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
55
examples/model_providers/custom_example_global.py
Normal file
55
examples/model_providers/custom_example_global.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from agents import (
|
||||||
|
Agent,
|
||||||
|
Runner,
|
||||||
|
set_default_openai_api,
|
||||||
|
set_default_openai_client,
|
||||||
|
set_tracing_disabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
|
||||||
|
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
|
||||||
|
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
|
||||||
|
|
||||||
|
if not BASE_URL or not API_KEY or not MODEL_NAME:
|
||||||
|
raise ValueError(
|
||||||
|
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
"""This example uses a custom provider for all requests by default. We do three things:
|
||||||
|
1. Create a custom client.
|
||||||
|
2. Set it as the default OpenAI client, and don't use it for tracing.
|
||||||
|
3. Set the default API as Chat Completions, as most LLM providers don't yet support Responses API.
|
||||||
|
|
||||||
|
Note that in this example, we disable tracing under the assumption that you don't have an API key
|
||||||
|
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
|
||||||
|
or call set_tracing_export_api_key() to set a tracing specific key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
client = AsyncOpenAI(
|
||||||
|
base_url=BASE_URL,
|
||||||
|
api_key=API_KEY,
|
||||||
|
)
|
||||||
|
set_default_openai_client(client=client, use_for_tracing=False)
|
||||||
|
set_default_openai_api("chat_completions")
|
||||||
|
set_tracing_disabled(disabled=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
agent = Agent(
|
||||||
|
name="Assistant",
|
||||||
|
instructions="You only respond in haikus.",
|
||||||
|
model=MODEL_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await Runner.run(agent, "Tell me about recursion in programming.")
|
||||||
|
print(result.final_output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
73
examples/model_providers/custom_example_provider.py
Normal file
73
examples/model_providers/custom_example_provider.py
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from agents import (
|
||||||
|
Agent,
|
||||||
|
Model,
|
||||||
|
ModelProvider,
|
||||||
|
OpenAIChatCompletionsModel,
|
||||||
|
RunConfig,
|
||||||
|
Runner,
|
||||||
|
set_tracing_disabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
|
||||||
|
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
|
||||||
|
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
|
||||||
|
|
||||||
|
if not BASE_URL or not API_KEY or not MODEL_NAME:
|
||||||
|
raise ValueError(
|
||||||
|
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
"""This example uses a custom provider for some calls to Runner.run(), and direct calls to OpenAI for
|
||||||
|
others. Steps:
|
||||||
|
1. Create a custom OpenAI client.
|
||||||
|
2. Create a ModelProvider that uses the custom client.
|
||||||
|
3. Use the ModelProvider in calls to Runner.run(), only when we want to use the custom LLM provider.
|
||||||
|
|
||||||
|
Note that in this example, we disable tracing under the assumption that you don't have an API key
|
||||||
|
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
|
||||||
|
or call set_tracing_export_api_key() to set a tracing specific key.
|
||||||
|
"""
|
||||||
|
client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
|
||||||
|
set_tracing_disabled(disabled=True)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomModelProvider(ModelProvider):
|
||||||
|
def get_model(self, model_name: str | None) -> Model:
|
||||||
|
return OpenAIChatCompletionsModel(model=model_name or MODEL_NAME, openai_client=client)
|
||||||
|
|
||||||
|
|
||||||
|
CUSTOM_MODEL_PROVIDER = CustomModelProvider()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
agent = Agent(
|
||||||
|
name="Assistant",
|
||||||
|
instructions="You only respond in haikus.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# This will use the custom model provider
|
||||||
|
result = await Runner.run(
|
||||||
|
agent,
|
||||||
|
"Tell me about recursion in programming.",
|
||||||
|
run_config=RunConfig(model_provider=CUSTOM_MODEL_PROVIDER),
|
||||||
|
)
|
||||||
|
print(result.final_output)
|
||||||
|
|
||||||
|
# If you uncomment this, it will use OpenAI directly, not the custom provider
|
||||||
|
# result = await Runner.run(
|
||||||
|
# agent,
|
||||||
|
# "Tell me about recursion in programming.",
|
||||||
|
# )
|
||||||
|
# print(result.final_output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|
@ -92,13 +92,19 @@ from .tracing import (
|
||||||
from .usage import Usage
|
from .usage import Usage
|
||||||
|
|
||||||
|
|
||||||
def set_default_openai_key(key: str) -> None:
|
def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
|
||||||
"""Set the default OpenAI API key to use for LLM requests and tracing. This is only necessary if
|
"""Set the default OpenAI API key to use for LLM requests (and optionally tracing(). This is
|
||||||
the OPENAI_API_KEY environment variable is not already set.
|
only necessary if the OPENAI_API_KEY environment variable is not already set.
|
||||||
|
|
||||||
If provided, this key will be used instead of the OPENAI_API_KEY environment variable.
|
If provided, this key will be used instead of the OPENAI_API_KEY environment variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The OpenAI key to use.
|
||||||
|
use_for_tracing: Whether to also use this key to send traces to OpenAI. Defaults to True
|
||||||
|
If False, you'll either need to set the OPENAI_API_KEY environment variable or call
|
||||||
|
set_tracing_export_api_key() with the API key you want to use for tracing.
|
||||||
"""
|
"""
|
||||||
_config.set_default_openai_key(key)
|
_config.set_default_openai_key(key, use_for_tracing)
|
||||||
|
|
||||||
|
|
||||||
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None:
|
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None:
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,18 @@ from .models import _openai_shared
|
||||||
from .tracing import set_tracing_export_api_key
|
from .tracing import set_tracing_export_api_key
|
||||||
|
|
||||||
|
|
||||||
def set_default_openai_key(key: str) -> None:
|
def set_default_openai_key(key: str, use_for_tracing: bool) -> None:
|
||||||
set_tracing_export_api_key(key)
|
|
||||||
_openai_shared.set_default_openai_key(key)
|
_openai_shared.set_default_openai_key(key)
|
||||||
|
|
||||||
|
if use_for_tracing:
|
||||||
|
set_tracing_export_api_key(key)
|
||||||
|
|
||||||
|
|
||||||
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool) -> None:
|
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool) -> None:
|
||||||
|
_openai_shared.set_default_openai_client(client)
|
||||||
|
|
||||||
if use_for_tracing:
|
if use_for_tracing:
|
||||||
set_tracing_export_api_key(client.api_key)
|
set_tracing_export_api_key(client.api_key)
|
||||||
_openai_shared.set_default_openai_client(client)
|
|
||||||
|
|
||||||
|
|
||||||
def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:
|
def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:
|
||||||
|
|
|
||||||
|
|
@ -38,28 +38,41 @@ class OpenAIProvider(ModelProvider):
|
||||||
assert api_key is None and base_url is None, (
|
assert api_key is None and base_url is None, (
|
||||||
"Don't provide api_key or base_url if you provide openai_client"
|
"Don't provide api_key or base_url if you provide openai_client"
|
||||||
)
|
)
|
||||||
self._client = openai_client
|
self._client: AsyncOpenAI | None = openai_client
|
||||||
else:
|
else:
|
||||||
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
|
self._client = None
|
||||||
api_key=api_key or _openai_shared.get_default_openai_key(),
|
self._stored_api_key = api_key
|
||||||
base_url=base_url,
|
self._stored_base_url = base_url
|
||||||
organization=organization,
|
self._stored_organization = organization
|
||||||
project=project,
|
self._stored_project = project
|
||||||
http_client=shared_http_client(),
|
|
||||||
)
|
|
||||||
|
|
||||||
self._is_openai_model = self._client.base_url.host.startswith("api.openai.com")
|
|
||||||
if use_responses is not None:
|
if use_responses is not None:
|
||||||
self._use_responses = use_responses
|
self._use_responses = use_responses
|
||||||
else:
|
else:
|
||||||
self._use_responses = _openai_shared.get_use_responses_by_default()
|
self._use_responses = _openai_shared.get_use_responses_by_default()
|
||||||
|
|
||||||
|
# We lazy load the client in case you never actually use OpenAIProvider(). Otherwise
|
||||||
|
# AsyncOpenAI() raises an error if you don't have an API key set.
|
||||||
|
def _get_client(self) -> AsyncOpenAI:
|
||||||
|
if self._client is None:
|
||||||
|
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
|
||||||
|
api_key=self._stored_api_key or _openai_shared.get_default_openai_key(),
|
||||||
|
base_url=self._stored_base_url,
|
||||||
|
organization=self._stored_organization,
|
||||||
|
project=self._stored_project,
|
||||||
|
http_client=shared_http_client(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._client
|
||||||
|
|
||||||
def get_model(self, model_name: str | None) -> Model:
|
def get_model(self, model_name: str | None) -> Model:
|
||||||
if model_name is None:
|
if model_name is None:
|
||||||
model_name = DEFAULT_MODEL
|
model_name = DEFAULT_MODEL
|
||||||
|
|
||||||
|
client = self._get_client()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
OpenAIResponsesModel(model=model_name, openai_client=self._client)
|
OpenAIResponsesModel(model=model_name, openai_client=client)
|
||||||
if self._use_responses
|
if self._use_responses
|
||||||
else OpenAIChatCompletionsModel(model=model_name, openai_client=self._client)
|
else OpenAIChatCompletionsModel(model=model_name, openai_client=client)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue