Update custom models to use tools

This commit is contained in:
Rohan Mehta 2025-03-13 13:10:25 -04:00
parent 691be07339
commit 6ab8c91d23
4 changed files with 29 additions and 14 deletions

View file

@ -3,7 +3,7 @@ import os
from openai import AsyncOpenAI
from agents import Agent, OpenAIChatCompletionsModel, Runner, set_tracing_disabled
from agents import Agent, OpenAIChatCompletionsModel, Runner, function_tool, set_tracing_disabled
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
@ -32,18 +32,22 @@ set_tracing_disabled(disabled=True)
# Runner.run(agent, ..., run_config=RunConfig(model_provider=PROVIDER))
@function_tool
def get_weather(city: str):
print(f"[debug] getting weather for {city}")
return f"The weather in {city} is sunny."
async def main():
# This agent will use the custom LLM provider
agent = Agent(
name="Assistant",
instructions="You only respond in haikus.",
model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),
tools=[get_weather],
)
result = await Runner.run(
agent,
"Tell me about recursion in programming.",
)
result = await Runner.run(agent, "What's the weather in Tokyo?")
print(result.final_output)

View file

@ -6,6 +6,7 @@ from openai import AsyncOpenAI
from agents import (
Agent,
Runner,
function_tool,
set_default_openai_api,
set_default_openai_client,
set_tracing_disabled,
@ -40,14 +41,21 @@ set_default_openai_api("chat_completions")
set_tracing_disabled(disabled=True)
@function_tool
def get_weather(city: str):
print(f"[debug] getting weather for {city}")
return f"The weather in {city} is sunny."
async def main():
agent = Agent(
name="Assistant",
instructions="You only respond in haikus.",
model=MODEL_NAME,
tools=[get_weather],
)
result = await Runner.run(agent, "Tell me about recursion in programming.")
result = await Runner.run(agent, "What's the weather in Tokyo?")
print(result.final_output)

View file

@ -12,6 +12,7 @@ from agents import (
OpenAIChatCompletionsModel,
RunConfig,
Runner,
function_tool,
set_tracing_disabled,
)
@ -47,16 +48,19 @@ class CustomModelProvider(ModelProvider):
CUSTOM_MODEL_PROVIDER = CustomModelProvider()
@function_tool
def get_weather(city: str):
print(f"[debug] getting weather for {city}")
return f"The weather in {city} is sunny."
async def main():
agent = Agent(
name="Assistant",
instructions="You only respond in haikus.",
)
agent = Agent(name="Assistant", instructions="You only respond in haikus.", tools=[get_weather])
# This will use the custom model provider
result = await Runner.run(
agent,
"Tell me about recursion in programming.",
"What's the weather in Tokyo?",
run_config=RunConfig(model_provider=CUSTOM_MODEL_PROVIDER),
)
print(result.final_output)
@ -64,7 +68,7 @@ async def main():
# If you uncomment this, it will use OpenAI directly, not the custom provider
# result = await Runner.run(
# agent,
# "Tell me about recursion in programming.",
# "What's the weather in Tokyo?",
# )
# print(result.final_output)

View file

@ -1,5 +1,4 @@
version = 1
revision = 1
requires-python = ">=3.9"
[[package]]
@ -783,7 +782,7 @@ wheels = [
[[package]]
name = "openai-agents"
version = "0.0.3"
version = "0.0.4"
source = { editable = "." }
dependencies = [
{ name = "griffe" },