Merge branch 'main' of https://github.com/openai/openai-agents-python into feat/draw_graph

This commit is contained in:
MartinEBravo 2025-03-17 10:17:36 +01:00
commit e984274da1
20 changed files with 388 additions and 48 deletions

View file

@ -0,0 +1,26 @@
---
name: Custom model providers
about: Questions or bugs about using non-OpenAI models
title: ''
labels: bug
assignees: ''
---
### Please read this first
- **Have you read the custom model provider docs, including the 'Common issues' section?** [Model provider docs](https://openai.github.io/openai-agents-python/models/#using-other-llm-providers)
- **Have you searched for related issues?** Others may have faced similar issues.
### Describe the question
A clear and concise description of what the question or bug is.
### Debug information
- Agents SDK version: (e.g. `v0.0.3`)
- Python version (e.g. Python 3.10)
### Repro steps
Ideally provide a minimal python script that can be run to reproduce the issue.
### Expected behavior
A clear and concise description of what you expected to happen.

View file

@ -53,21 +53,41 @@ async def main():
## Using other LLM providers
Many providers also support the OpenAI API format, which means you can pass a `base_url` to the existing OpenAI model implementations and use them easily. `ModelSettings` is used to configure tuning parameters (e.g., temperature, top_p) for the model you select.
You can use other LLM providers in 3 ways (examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/)):
```python
external_client = AsyncOpenAI(
api_key="EXTERNAL_API_KEY",
base_url="https://api.external.com/v1/",
)
1. [`set_default_openai_client`][agents.set_default_openai_client] is useful in cases where you want to globally use an instance of `AsyncOpenAI` as the LLM client. This is for cases where the LLM provider has an OpenAI compatible API endpoint, and you can set the `base_url` and `api_key`. See a configurable example in [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py).
2. [`ModelProvider`][agents.models.interface.ModelProvider] is at the `Runner.run` level. This lets you say "use a custom model provider for all agents in this run". See a configurable example in [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py).
3. [`Agent.model`][agents.agent.Agent.model] lets you specify the model on a specific Agent instance. This enables you to mix and match different providers for different agents. See a configurable example in [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py).
In cases where you do not have an API key from `platform.openai.com`, we recommend disabling tracing via `set_tracing_disabled()`, or setting up a [different tracing processor](tracing.md).
!!! note
In these examples, we use the Chat Completions API/model, because most LLM providers don't yet support the Responses API. If your LLM provider does support it, we recommend using Responses.
## Common issues with using other LLM providers
### Tracing client error 401
If you get errors related to tracing, this is because traces are uploaded to OpenAI servers, and you don't have an OpenAI API key. You have three options to resolve this:
1. Disable tracing entirely: [`set_tracing_disabled(True)`][agents.set_tracing_disabled].
2. Set an OpenAI key for tracing: [`set_tracing_export_api_key(...)`][agents.set_tracing_export_api_key]. This API key will only be used for uploading traces, and must be from [platform.openai.com](https://platform.openai.com/).
3. Use a non-OpenAI trace processor. See the [tracing docs](tracing.md#custom-tracing-processors).
### Responses API support
The SDK uses the Responses API by default, but most other LLM providers don't yet support it. You may see 404s or similar issues as a result. To resolve, you have two options:
1. Call [`set_default_openai_api("chat_completions")`][agents.set_default_openai_api]. This works if you are setting `OPENAI_API_KEY` and `OPENAI_BASE_URL` via environment vars.
2. Use [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel]. There are examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/).
### Structured outputs support
Some model providers don't have support for [structured outputs](https://platform.openai.com/docs/guides/structured-outputs). This sometimes results in an error that looks something like this:
spanish_agent = Agent(
name="Spanish agent",
instructions="You only speak Spanish.",
model=OpenAIChatCompletionsModel(
model="EXTERNAL_MODEL_NAME",
openai_client=external_client,
),
model_settings=ModelSettings(temperature=0.5),
)
```
BadRequestError: Error code: 400 - {'error': {'message': "'response_format.type' : value is not one of the allowed values ['text','json_object']", 'type': 'invalid_request_error'}}
```
This is a shortcoming of some model providers - they support JSON outputs, but don't allow you to specify the `json_schema` to use for the output. We are working on a fix for this, but we suggest relying on providers that do have support for JSON schema output, because otherwise your app will often break because of malformed JSON.

View file

@ -0,0 +1,19 @@
# Custom LLM providers
The examples in this directory demonstrate how you might use a non-OpenAI LLM provider. To run them, first set a base URL, API key and model.
```bash
export EXAMPLE_BASE_URL="..."
export EXAMPLE_API_KEY="..."
export EXAMPLE_MODEL_NAME"..."
```
Then run the examples, e.g.:
```
python examples/model_providers/custom_example_provider.py
Loops within themselves,
Function calls its own being,
Depth without ending.
```

View file

@ -0,0 +1,55 @@
import asyncio
import os
from openai import AsyncOpenAI
from agents import Agent, OpenAIChatCompletionsModel, Runner, function_tool, set_tracing_disabled
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
if not BASE_URL or not API_KEY or not MODEL_NAME:
raise ValueError(
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
)
"""This example uses a custom provider for a specific agent. Steps:
1. Create a custom OpenAI client.
2. Create a `Model` that uses the custom client.
3. Set the `model` on the Agent.
Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing specific key.
"""
client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
set_tracing_disabled(disabled=True)
# An alternate approach that would also work:
# PROVIDER = OpenAIProvider(openai_client=client)
# agent = Agent(..., model="some-custom-model")
# Runner.run(agent, ..., run_config=RunConfig(model_provider=PROVIDER))
@function_tool
def get_weather(city: str):
print(f"[debug] getting weather for {city}")
return f"The weather in {city} is sunny."
async def main():
# This agent will use the custom LLM provider
agent = Agent(
name="Assistant",
instructions="You only respond in haikus.",
model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),
tools=[get_weather],
)
result = await Runner.run(agent, "What's the weather in Tokyo?")
print(result.final_output)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,63 @@
import asyncio
import os
from openai import AsyncOpenAI
from agents import (
Agent,
Runner,
function_tool,
set_default_openai_api,
set_default_openai_client,
set_tracing_disabled,
)
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
if not BASE_URL or not API_KEY or not MODEL_NAME:
raise ValueError(
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
)
"""This example uses a custom provider for all requests by default. We do three things:
1. Create a custom client.
2. Set it as the default OpenAI client, and don't use it for tracing.
3. Set the default API as Chat Completions, as most LLM providers don't yet support Responses API.
Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing specific key.
"""
client = AsyncOpenAI(
base_url=BASE_URL,
api_key=API_KEY,
)
set_default_openai_client(client=client, use_for_tracing=False)
set_default_openai_api("chat_completions")
set_tracing_disabled(disabled=True)
@function_tool
def get_weather(city: str):
print(f"[debug] getting weather for {city}")
return f"The weather in {city} is sunny."
async def main():
agent = Agent(
name="Assistant",
instructions="You only respond in haikus.",
model=MODEL_NAME,
tools=[get_weather],
)
result = await Runner.run(agent, "What's the weather in Tokyo?")
print(result.final_output)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,77 @@
from __future__ import annotations
import asyncio
import os
from openai import AsyncOpenAI
from agents import (
Agent,
Model,
ModelProvider,
OpenAIChatCompletionsModel,
RunConfig,
Runner,
function_tool,
set_tracing_disabled,
)
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
if not BASE_URL or not API_KEY or not MODEL_NAME:
raise ValueError(
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
)
"""This example uses a custom provider for some calls to Runner.run(), and direct calls to OpenAI for
others. Steps:
1. Create a custom OpenAI client.
2. Create a ModelProvider that uses the custom client.
3. Use the ModelProvider in calls to Runner.run(), only when we want to use the custom LLM provider.
Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing specific key.
"""
client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
set_tracing_disabled(disabled=True)
class CustomModelProvider(ModelProvider):
def get_model(self, model_name: str | None) -> Model:
return OpenAIChatCompletionsModel(model=model_name or MODEL_NAME, openai_client=client)
CUSTOM_MODEL_PROVIDER = CustomModelProvider()
@function_tool
def get_weather(city: str):
print(f"[debug] getting weather for {city}")
return f"The weather in {city} is sunny."
async def main():
agent = Agent(name="Assistant", instructions="You only respond in haikus.", tools=[get_weather])
# This will use the custom model provider
result = await Runner.run(
agent,
"What's the weather in Tokyo?",
run_config=RunConfig(model_provider=CUSTOM_MODEL_PROVIDER),
)
print(result.final_output)
# If you uncomment this, it will use OpenAI directly, not the custom provider
# result = await Runner.run(
# agent,
# "What's the weather in Tokyo?",
# )
# print(result.final_output)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -1,6 +1,6 @@
[project]
name = "openai-agents"
version = "0.0.3"
version = "0.0.4"
description = "OpenAI Agents SDK"
readme = "README.md"
requires-python = ">=3.9"

View file

@ -92,13 +92,19 @@ from .tracing import (
from .usage import Usage
def set_default_openai_key(key: str) -> None:
"""Set the default OpenAI API key to use for LLM requests and tracing. This is only necessary if
the OPENAI_API_KEY environment variable is not already set.
def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
"""Set the default OpenAI API key to use for LLM requests (and optionally tracing(). This is
only necessary if the OPENAI_API_KEY environment variable is not already set.
If provided, this key will be used instead of the OPENAI_API_KEY environment variable.
Args:
key: The OpenAI key to use.
use_for_tracing: Whether to also use this key to send traces to OpenAI. Defaults to True
If False, you'll either need to set the OPENAI_API_KEY environment variable or call
set_tracing_export_api_key() with the API key you want to use for tracing.
"""
_config.set_default_openai_key(key)
_config.set_default_openai_key(key, use_for_tracing)
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None:
@ -123,8 +129,7 @@ def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> Non
def enable_verbose_stdout_logging():
"""Enables verbose logging to stdout. This is useful for debugging."""
for name in ["openai.agents", "openai.agents.tracing"]:
logger = logging.getLogger(name)
logger = logging.getLogger("openai.agents")
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))

View file

@ -5,15 +5,18 @@ from .models import _openai_shared
from .tracing import set_tracing_export_api_key
def set_default_openai_key(key: str) -> None:
set_tracing_export_api_key(key)
def set_default_openai_key(key: str, use_for_tracing: bool) -> None:
_openai_shared.set_default_openai_key(key)
if use_for_tracing:
set_tracing_export_api_key(key)
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool) -> None:
_openai_shared.set_default_openai_client(client)
if use_for_tracing:
set_tracing_export_api_key(client.api_key)
_openai_shared.set_default_openai_client(client)
def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:

View file

@ -33,6 +33,9 @@ class FuncSchema:
"""The signature of the function."""
takes_context: bool = False
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
"""
@ -337,4 +340,5 @@ def function_schema(
params_json_schema=json_schema,
signature=sig,
takes_context=takes_context,
strict_json_schema=strict_json_schema,
)

View file

@ -38,28 +38,41 @@ class OpenAIProvider(ModelProvider):
assert api_key is None and base_url is None, (
"Don't provide api_key or base_url if you provide openai_client"
)
self._client = openai_client
self._client: AsyncOpenAI | None = openai_client
else:
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
api_key=api_key or _openai_shared.get_default_openai_key(),
base_url=base_url,
organization=organization,
project=project,
http_client=shared_http_client(),
)
self._client = None
self._stored_api_key = api_key
self._stored_base_url = base_url
self._stored_organization = organization
self._stored_project = project
self._is_openai_model = self._client.base_url.host.startswith("api.openai.com")
if use_responses is not None:
self._use_responses = use_responses
else:
self._use_responses = _openai_shared.get_use_responses_by_default()
# We lazy load the client in case you never actually use OpenAIProvider(). Otherwise
# AsyncOpenAI() raises an error if you don't have an API key set.
def _get_client(self) -> AsyncOpenAI:
if self._client is None:
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
api_key=self._stored_api_key or _openai_shared.get_default_openai_key(),
base_url=self._stored_base_url,
organization=self._stored_organization,
project=self._stored_project,
http_client=shared_http_client(),
)
return self._client
def get_model(self, model_name: str | None) -> Model:
if model_name is None:
model_name = DEFAULT_MODEL
client = self._get_client()
return (
OpenAIResponsesModel(model=model_name, openai_client=self._client)
OpenAIResponsesModel(model=model_name, openai_client=client)
if self._use_responses
else OpenAIChatCompletionsModel(model=model_name, openai_client=self._client)
else OpenAIChatCompletionsModel(model=model_name, openai_client=client)
)

View file

@ -137,6 +137,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
@ -150,6 +151,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
@ -163,6 +165,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
@ -186,6 +189,8 @@ def function_tool(
failure_error_function: If provided, use this function to generate an error message when
the tool call fails. The error message is sent to the LLM. If you pass None, then no
error message will be sent and instead an Exception will be raised.
strict_mode: If False, parameters with default values become optional in the
function schema.
"""
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@ -195,6 +200,7 @@ def function_tool(
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
strict_json_schema=strict_mode,
)
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
@ -273,6 +279,7 @@ def function_tool(
description=schema.description or "",
params_json_schema=schema.params_json_schema,
on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode,
)
# If func is actually a callable, we were used as @function_tool with no parentheses

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from .logger import logger
from ..logger import logger
from .setup import GLOBAL_TRACE_PROVIDER
from .span_data import (
AgentSpanData,

View file

@ -9,7 +9,7 @@ from typing import Any
import httpx
from .logger import logger
from ..logger import logger
from .processor_interface import TracingExporter, TracingProcessor
from .spans import Span
from .traces import Trace
@ -40,7 +40,7 @@ class BackendSpanExporter(TracingExporter):
"""
Args:
api_key: The API key for the "Authorization" header. Defaults to
`os.environ["OPENAI_TRACE_API_KEY"]` if not provided.
`os.environ["OPENAI_API_KEY"]` if not provided.
organization: The OpenAI organization to use. Defaults to
`os.environ["OPENAI_ORG_ID"]` if not provided.
project: The OpenAI project to use. Defaults to

View file

@ -2,7 +2,7 @@
import contextvars
from typing import TYPE_CHECKING, Any
from .logger import logger
from ..logger import logger
if TYPE_CHECKING:
from .spans import Span

View file

@ -4,8 +4,8 @@ import os
import threading
from typing import Any
from ..logger import logger
from . import util
from .logger import logger
from .processor_interface import TracingProcessor
from .scope import Scope
from .spans import NoOpSpan, Span, SpanImpl, TSpanData

View file

@ -6,8 +6,8 @@ from typing import Any, Generic, TypeVar
from typing_extensions import TypedDict
from ..logger import logger
from . import util
from .logger import logger
from .processor_interface import TracingProcessor
from .scope import Scope
from .span_data import SpanData

View file

@ -4,8 +4,8 @@ import abc
import contextvars
from typing import Any
from ..logger import logger
from . import util
from .logger import logger
from .processor_interface import TracingProcessor
from .scope import Scope

View file

@ -1,6 +1,6 @@
import asyncio
import json
from typing import Any
from typing import Any, Optional
import pytest
@ -142,3 +142,52 @@ async def test_no_error_on_invalid_json_async():
tool = will_not_fail_on_bad_json_async
result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}")
assert result == "error_ModelBehaviorError"
@function_tool(strict_mode=False)
def optional_param_function(a: int, b: Optional[int] = None) -> str:
if b is None:
return f"{a}_no_b"
return f"{a}_{b}"
@pytest.mark.asyncio
async def test_optional_param_function():
tool = optional_param_function
input_data = {"a": 5}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_no_b"
input_data = {"a": 5, "b": 10}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_10"
@function_tool(strict_mode=False)
def multiple_optional_params_function(
x: int = 42,
y: str = "hello",
z: Optional[int] = None,
) -> str:
if z is None:
return f"{x}_{y}_no_z"
return f"{x}_{y}_{z}"
@pytest.mark.asyncio
async def test_multiple_optional_params_function():
tool = multiple_optional_params_function
input_data: dict[str,Any] = {}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "42_hello_no_z"
input_data = {"x": 10, "y": "world"}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_no_z"
input_data = {"x": 10, "y": "world", "z": 99}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_99"

View file

@ -1,5 +1,4 @@
version = 1
revision = 1
requires-python = ">=3.9"
[[package]]
@ -783,7 +782,7 @@ wheels = [
[[package]]
name = "openai-agents"
version = "0.0.3"
version = "0.0.4"
source = { editable = "." }
dependencies = [
{ name = "griffe" },