Merge branch 'main' of https://github.com/openai/openai-agents-python into feat/draw_graph

This commit is contained in:
Martín Bravo 2025-03-25 16:58:01 +01:00
commit 900a97fa55
105 changed files with 6252 additions and 764 deletions

View file

@ -17,7 +17,10 @@ jobs:
stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 7 days with no activity."
close-issue-message: "This issue was closed because it has been inactive for 3 days since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1
any-of-labels: 'question,needs-more-info'
any-of-issue-labels: 'question,needs-more-info'
days-before-pr-stale: 10
days-before-pr-close: 7
stale-pr-label: "stale"
stale-pr-message: "This PR is stale because it has been open for 10 days with no activity."
close-pr-message: "This PR was closed because it has been inactive for 7 days since being marked as stale."
repo-token: ${{ secrets.GITHUB_TOKEN }}

View file

@ -8,6 +8,9 @@ on:
branches:
- main
env:
UV_FROZEN: "1"
jobs:
lint:
runs-on: ubuntu-latest
@ -50,8 +53,8 @@ jobs:
enable-cache: true
- name: Install dependencies
run: make sync
- name: Run tests
run: make tests
- name: Run tests with coverage
run: make coverage
build-docs:
runs-on: ubuntu-latest

2
.gitignore vendored
View file

@ -135,7 +135,7 @@ dmypy.json
cython_debug/
# PyCharm
#.idea/
.idea/
# Ruff stuff:
.ruff_cache/

View file

@ -5,6 +5,7 @@ sync:
.PHONY: format
format:
uv run ruff format
uv run ruff check --fix
.PHONY: lint
lint:
@ -18,6 +19,13 @@ mypy:
tests:
uv run pytest
.PHONY: coverage
coverage:
uv run coverage run -m pytest
uv run coverage xml -o coverage.xml
uv run coverage report -m --fail-under=95
.PHONY: snapshots-fix
snapshots-fix:
uv run pytest --inline-snapshot=fix
@ -29,7 +37,6 @@ snapshots-create:
.PHONY: old_version_tests
old_version_tests:
UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m pytest
UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m mypy .
.PHONY: build-docs
build-docs:
@ -42,4 +49,6 @@ serve-docs:
.PHONY: deploy-docs
deploy-docs:
uv run mkdocs gh-deploy --force --verbose

View file

@ -30,6 +30,8 @@ source env/bin/activate
pip install openai-agents
```
For voice support, install with the optional `voice` group: `pip install 'openai-agents[voice]'`.
## Hello world example
```python

View file

@ -130,3 +130,23 @@ robot_agent = pirate_agent.clone(
instructions="Write like a robot",
)
```
## Forcing tool use
Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are:
1. `auto`, which allows the LLM to decide whether or not to use a tool.
2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool).
3. `none`, which requires the LLM to _not_ use a tool.
4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool.
!!! note
To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call in the following scenarios:
1. When `tool_choice` is set to a specific function name (any string that's not "auto", "required", or "none")
2. When `tool_choice` is set to "required" AND there is only one tool available
This targeted reset mechanism allows the model to decide whether to make additional tool calls in subsequent turns while avoiding infinite loops in these specific cases.
If you want the Agent to completely stop after a tool call (rather than continuing with auto mode), you can set [`Agent.tool_use_behavior="stop_on_first_tool"`] which will directly use the tool output as the final response without further LLM processing.

View file

@ -41,14 +41,14 @@ async def fetch_user_age(wrapper: RunContextWrapper[UserInfo]) -> str: # (2)!
return f"User {wrapper.context.name} is 47 years old"
async def main():
user_info = UserInfo(name="John", uid=123) # (3)!
user_info = UserInfo(name="John", uid=123)
agent = Agent[UserInfo]( # (4)!
agent = Agent[UserInfo]( # (3)!
name="Assistant",
tools=[fetch_user_age],
)
result = await Runner.run(
result = await Runner.run( # (4)!
starting_agent=agent,
input="What is the age of the user?",
context=user_info,

36
docs/examples.md Normal file
View file

@ -0,0 +1,36 @@
# Examples
Check out a variety of sample implementations of the SDK in the examples section of the [repo](https://github.com/openai/openai-agents-python/tree/main/examples). The examples are organized into several categories that demonstrate different patterns and capabilities.
## Categories
- **agent_patterns:**
Examples in this category illustrate common agent design patterns, such as
- Deterministic workflows
- Agents as tools
- Parallel agent execution
- **basic:**
These examples showcase foundational capabilities of the SDK, such as
- Dynamic system prompts
- Streaming outputs
- Lifecycle events
- **tool examples:**
Learn how to implement OAI hosted tools such as web search and file search,
and integrate them into your agents.
- **model providers:**
Explore how to use non-OpenAI models with the SDK.
- **handoffs:**
See practical examples of agent handoffs.
- **customer_service** and **research_bot:**
Two more built-out examples that illustrate real-world applications
- **customer_service**: Example customer service system for an airline.
- **research_bot**: Simple deep research clone.

View file

@ -29,7 +29,7 @@ Output guardrails run in 3 steps:
!!! Note
Output guardrails are intended to run on the final agent input, so an agent's guardrails only run if the agent is the *last* agent. Similar to the input guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability.
Output guardrails are intended to run on the final agent output, so an agent's guardrails only run if the agent is the *last* agent. Similar to the input guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability.
## Tripwires
@ -111,8 +111,8 @@ class MessageOutput(BaseModel): # (1)!
response: str
class MathOutput(BaseModel): # (2)!
is_math: bool
reasoning: str
is_math: bool
guardrail_agent = Agent(
name="Guardrail check",

3
docs/ref/voice/events.md Normal file
View file

@ -0,0 +1,3 @@
# `Events`
::: agents.voice.events

View file

@ -0,0 +1,3 @@
# `Exceptions`
::: agents.voice.exceptions

3
docs/ref/voice/input.md Normal file
View file

@ -0,0 +1,3 @@
# `Input`
::: agents.voice.input

3
docs/ref/voice/model.md Normal file
View file

@ -0,0 +1,3 @@
# `Model`
::: agents.voice.model

View file

@ -0,0 +1,3 @@
# `OpenAIVoiceModelProvider`
::: agents.voice.models.openai_model_provider

View file

@ -0,0 +1,3 @@
# `OpenAI STT`
::: agents.voice.models.openai_stt

View file

@ -0,0 +1,3 @@
# `OpenAI TTS`
::: agents.voice.models.openai_tts

View file

@ -0,0 +1,3 @@
# `Pipeline`
::: agents.voice.pipeline

View file

@ -0,0 +1,3 @@
# `Pipeline Config`
::: agents.voice.pipeline_config

3
docs/ref/voice/result.md Normal file
View file

@ -0,0 +1,3 @@
# `Result`
::: agents.voice.result

3
docs/ref/voice/utils.md Normal file
View file

@ -0,0 +1,3 @@
# `Utils`
::: agents.voice.utils

View file

@ -0,0 +1,3 @@
# `Workflow`
::: agents.voice.workflow

View file

@ -35,6 +35,9 @@ By default, the SDK traces the following:
- Function tool calls are each wrapped in `function_span()`
- Guardrails are wrapped in `guardrail_span()`
- Handoffs are wrapped in `handoff_span()`
- Audio inputs (speech-to-text) are wrapped in a `transcription_span()`
- Audio outputs (text-to-speech) are wrapped in a `speech_span()`
- Related audio spans may be parented under a `speech_group_span()`
By default, the trace is named "Agent trace". You can set this name if you use `trace`, or you can can configure the name and other properties with the [`RunConfig`][agents.run.RunConfig].
@ -76,7 +79,11 @@ Spans are automatically part of the current trace, and are nested under the near
## Sensitive data
Some spans track potentially sensitive data. For example, the `generation_span()` stores the inputs/outputs of the LLM generation, and `function_span()` stores the inputs/outputs of function calls. These may contain sensitive data, so you can disable capturing that data via [`RunConfig.trace_include_sensitive_data`][agents.run.RunConfig.trace_include_sensitive_data].
Certain spans may capture potentially sensitive data.
The `generation_span()` stores the inputs/outputs of the LLM generation, and `function_span()` stores the inputs/outputs of function calls. These may contain sensitive data, so you can disable capturing that data via [`RunConfig.trace_include_sensitive_data`][agents.run.RunConfig.trace_include_sensitive_data].
Similarly, Audio spans include base64-encoded PCM data for input and output audio by default. You can disable capturing this audio data by configuring [`VoicePipelineConfig.trace_include_sensitive_audio_data`][agents.voice.pipeline_config.VoicePipelineConfig.trace_include_sensitive_audio_data].
## Custom tracing processors
@ -92,10 +99,15 @@ To customize this default setup, to send traces to alternative or additional bac
## External tracing processors list
- [Weights & Biases](https://weave-docs.wandb.ai/guides/integrations/openai_agents)
- [Arize-Phoenix](https://docs.arize.com/phoenix/tracing/integrations-tracing/openai-agents-sdk)
- [MLflow](https://mlflow.org/docs/latest/tracing/integrations/openai-agent)
- [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk)
- [Pydantic Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents)
- [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk)
- [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration)
- [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent)
- [LangSmith](https://docs.smith.langchain.com/observability/how_to_guides/trace_with_openai_agents_sdk)
- [Maxim AI](https://www.getmaxim.ai/docs/observe/integrations/openai-agents-sdk)
- [Comet Opik](https://www.comet.com/docs/opik/tracing/integrations/openai_agents)
- [Langfuse](https://langfuse.com/docs/integrations/openaiagentssdk/openai-agents)

75
docs/voice/pipeline.md Normal file
View file

@ -0,0 +1,75 @@
# Pipelines and workflows
[`VoicePipeline`][agents.voice.pipeline.VoicePipeline] is a class that makes it easy to turn your agentic workflows into a voice app. You pass in a workflow to run, and the pipeline takes care of transcribing input audio, detecting when the audio ends, calling your workflow at the right time, and turning the workflow output back into audio.
```mermaid
graph LR
%% Input
A["🎤 Audio Input"]
%% Voice Pipeline
subgraph Voice_Pipeline [Voice Pipeline]
direction TB
B["Transcribe (speech-to-text)"]
C["Your Code"]:::highlight
D["Text-to-speech"]
B --> C --> D
end
%% Output
E["🎧 Audio Output"]
%% Flow
A --> Voice_Pipeline
Voice_Pipeline --> E
%% Custom styling
classDef highlight fill:#ffcc66,stroke:#333,stroke-width:1px,font-weight:700;
```
## Configuring a pipeline
When you create a pipeline, you can set a few things:
1. The [`workflow`][agents.voice.workflow.VoiceWorkflowBase], which is the code that runs each time new audio is transcribed.
2. The [`speech-to-text`][agents.voice.model.STTModel] and [`text-to-speech`][agents.voice.model.TTSModel] models used
3. The [`config`][agents.voice.pipeline_config.VoicePipelineConfig], which lets you configure things like:
- A model provider, which can map model names to models
- Tracing, including whether to disable tracing, whether audio files are uploaded, the workflow name, trace IDs etc.
- Settings on the TTS and STT models, like the prompt, language and data types used.
## Running a pipeline
You can run a pipeline via the [`run()`][agents.voice.pipeline.VoicePipeline.run] method, which lets you pass in audio input in two forms:
1. [`AudioInput`][agents.voice.input.AudioInput] is used when you have a full audio transcript, and just want to produce a result for it. This is useful in cases where you don't need to detect when a speaker is done speaking; for example, when you have pre-recorded audio or in push-to-talk apps where it's clear when the user is done speaking.
2. [`StreamedAudioInput`][agents.voice.input.StreamedAudioInput] is used when you might need to detect when a user is done speaking. It allows you to push audio chunks as they are detected, and the voice pipeline will automatically run the agent workflow at the right time, via a process called "activity detection".
## Results
The result of a voice pipeline run is a [`StreamedAudioResult`][agents.voice.result.StreamedAudioResult]. This is an object that lets you stream events as they occur. There are a few kinds of [`VoiceStreamEvent`][agents.voice.events.VoiceStreamEvent], including:
1. [`VoiceStreamEventAudio`][agents.voice.events.VoiceStreamEventAudio], which contains a chunk of audio.
2. [`VoiceStreamEventLifecycle`][agents.voice.events.VoiceStreamEventLifecycle], which informs you of lifecycle events like a turn starting or ending.
3. [`VoiceStreamEventError`][agents.voice.events.VoiceStreamEventError], is an error event.
```python
result = await pipeline.run(input)
async for event in result.stream():
if event.type == "voice_stream_event_audio":
# play audio
elif event.type == "voice_stream_event_lifecycle":
# lifecycle
elif event.type == "voice_stream_event_error"
# error
...
```
## Best practices
### Interruptions
The Agents SDK currently does not support any built-in interruptions support for [`StreamedAudioInput`][agents.voice.input.StreamedAudioInput]. Instead for every detected turn it will trigger a separate run of your workflow. If you want to handle interruptions inside your application you can listen to the [`VoiceStreamEventLifecycle`][agents.voice.events.VoiceStreamEventLifecycle] events. `turn_started` will indicate that a new turn was transcribed and processing is beginning. `turn_ended` will trigger after all the audio was dispatched for a respective turn. You could use these events to mute the microphone of the speaker when the model starts a turn and unmute it after you flushed all the related audio for a turn.

194
docs/voice/quickstart.md Normal file
View file

@ -0,0 +1,194 @@
# Quickstart
## Prerequisites
Make sure you've followed the base [quickstart instructions](../quickstart.md) for the Agents SDK, and set up a virtual environment. Then, install the optional voice dependencies from the SDK:
```bash
pip install 'openai-agents[voice]'
```
## Concepts
The main concept to know about is a [`VoicePipeline`][agents.voice.pipeline.VoicePipeline], which is a 3 step process:
1. Run a speech-to-text model to turn audio into text.
2. Run your code, which is usually an agentic workflow, to produce a result.
3. Run a text-to-speech model to turn the result text back into audio.
```mermaid
graph LR
%% Input
A["🎤 Audio Input"]
%% Voice Pipeline
subgraph Voice_Pipeline [Voice Pipeline]
direction TB
B["Transcribe (speech-to-text)"]
C["Your Code"]:::highlight
D["Text-to-speech"]
B --> C --> D
end
%% Output
E["🎧 Audio Output"]
%% Flow
A --> Voice_Pipeline
Voice_Pipeline --> E
%% Custom styling
classDef highlight fill:#ffcc66,stroke:#333,stroke-width:1px,font-weight:700;
```
## Agents
First, let's set up some Agents. This should feel familiar to you if you've built any agents with this SDK. We'll have a couple of Agents, a handoff, and a tool.
```python
import asyncio
import random
from agents import (
Agent,
function_tool,
)
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
@function_tool
def get_weather(city: str) -> str:
"""Get the weather for a given city."""
print(f"[debug] get_weather called with city: {city}")
choices = ["sunny", "cloudy", "rainy", "snowy"]
return f"The weather in {city} is {random.choice(choices)}."
spanish_agent = Agent(
name="Spanish",
handoff_description="A spanish speaking agent.",
instructions=prompt_with_handoff_instructions(
"You're speaking to a human, so be polite and concise. Speak in Spanish.",
),
model="gpt-4o-mini",
)
agent = Agent(
name="Assistant",
instructions=prompt_with_handoff_instructions(
"You're speaking to a human, so be polite and concise. If the user speaks in Spanish, handoff to the spanish agent.",
),
model="gpt-4o-mini",
handoffs=[spanish_agent],
tools=[get_weather],
)
```
## Voice pipeline
We'll set up a simple voice pipeline, using [`SingleAgentVoiceWorkflow`][agents.voice.workflow.SingleAgentVoiceWorkflow] as the workflow.
```python
from agents.voice import SingleAgentVoiceWorkflow, VoicePipeline
pipeline = VoicePipeline(workflow=SingleAgentVoiceWorkflow(agent))
```
## Run the pipeline
```python
import numpy as np
import sounddevice as sd
from agents.voice import AudioInput
# For simplicity, we'll just create 3 seconds of silence
# In reality, you'd get microphone data
buffer = np.zeros(24000 * 3, dtype=np.int16)
audio_input = AudioInput(buffer=buffer)
result = await pipeline.run(audio_input)
# Create an audio player using `sounddevice`
player = sd.OutputStream(samplerate=24000, channels=1, dtype=np.int16)
player.start()
# Play the audio stream as it comes in
async for event in result.stream():
if event.type == "voice_stream_event_audio":
player.write(event.data)
```
## Put it all together
```python
import asyncio
import random
import numpy as np
import sounddevice as sd
from agents import (
Agent,
function_tool,
set_tracing_disabled,
)
from agents.voice import (
AudioInput,
SingleAgentVoiceWorkflow,
VoicePipeline,
)
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
@function_tool
def get_weather(city: str) -> str:
"""Get the weather for a given city."""
print(f"[debug] get_weather called with city: {city}")
choices = ["sunny", "cloudy", "rainy", "snowy"]
return f"The weather in {city} is {random.choice(choices)}."
spanish_agent = Agent(
name="Spanish",
handoff_description="A spanish speaking agent.",
instructions=prompt_with_handoff_instructions(
"You're speaking to a human, so be polite and concise. Speak in Spanish.",
),
model="gpt-4o-mini",
)
agent = Agent(
name="Assistant",
instructions=prompt_with_handoff_instructions(
"You're speaking to a human, so be polite and concise. If the user speaks in Spanish, handoff to the spanish agent.",
),
model="gpt-4o-mini",
handoffs=[spanish_agent],
tools=[get_weather],
)
async def main():
pipeline = VoicePipeline(workflow=SingleAgentVoiceWorkflow(agent))
buffer = np.zeros(24000 * 3, dtype=np.int16)
audio_input = AudioInput(buffer=buffer)
result = await pipeline.run(audio_input)
# Create an audio player using `sounddevice`
player = sd.OutputStream(samplerate=24000, channels=1, dtype=np.int16)
player.start()
# Play the audio stream as it comes in
async for event in result.stream():
if event.type == "voice_stream_event_audio":
player.write(event.data)
if __name__ == "__main__":
asyncio.run(main())
```
If you run this example, the agent will speak to you! Check out the example in [examples/voice/static](https://github.com/openai/openai-agents-python/tree/main/examples/voice/static) to see a demo where you can speak to the agent yourself.

14
docs/voice/tracing.md Normal file
View file

@ -0,0 +1,14 @@
# Tracing
Just like the way [agents are traced](../tracing.md), voice pipelines are also automatically traced.
You can read the tracing doc above for basic tracing information, but you can additionally configure tracing of a pipeline via [`VoicePipelineConfig`][agents.voice.pipeline_config.VoicePipelineConfig].
Key tracing related fields are:
- [`tracing_disabled`][agents.voice.pipeline_config.VoicePipelineConfig.tracing_disabled]: controls whether tracing is disabled. By default, tracing is enabled.
- [`trace_include_sensitive_data`][agents.voice.pipeline_config.VoicePipelineConfig.trace_include_sensitive_data]: controls whether traces include potentially sensitive data, like audio transcripts. This is specifically for the voice pipeline, and not for anything that goes on inside your Workflow.
- [`trace_include_sensitive_audio_data`][agents.voice.pipeline_config.VoicePipelineConfig.trace_include_sensitive_audio_data]: controls whether traces include audio data.
- [`workflow_name`][agents.voice.pipeline_config.VoicePipelineConfig.workflow_name]: The name of the trace workflow.
- [`group_id`][agents.voice.pipeline_config.VoicePipelineConfig.group_id]: The `group_id` of the trace, which lets you link multiple traces.
- [`trace_metadata`][agents.voice.pipeline_config.VoicePipelineConfig.tracing_disabled]: Additional metadata to include with the trace.

View file

@ -0,0 +1,99 @@
from __future__ import annotations
import asyncio
from typing import Any, Literal
from pydantic import BaseModel
from agents import (
Agent,
FunctionToolResult,
ModelSettings,
RunContextWrapper,
Runner,
ToolsToFinalOutputFunction,
ToolsToFinalOutputResult,
function_tool,
)
"""
This example shows how to force the agent to use a tool. It uses `ModelSettings(tool_choice="required")`
to force the agent to use any tool.
You can run it with 3 options:
1. `default`: The default behavior, which is to send the tool output to the LLM. In this case,
`tool_choice` is not set, because otherwise it would result in an infinite loop - the LLM would
call the tool, the tool would run and send the results to the LLM, and that would repeat
(because the model is forced to use a tool every time.)
2. `first_tool_result`: The first tool result is used as the final output.
3. `custom`: A custom tool use behavior function is used. The custom function receives all the tool
results, and chooses to use the first tool result to generate the final output.
Usage:
python examples/agent_patterns/forcing_tool_use.py -t default
python examples/agent_patterns/forcing_tool_use.py -t first_tool
python examples/agent_patterns/forcing_tool_use.py -t custom
"""
class Weather(BaseModel):
city: str
temperature_range: str
conditions: str
@function_tool
def get_weather(city: str) -> Weather:
print("[debug] get_weather called")
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind")
async def custom_tool_use_behavior(
context: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
weather: Weather = results[0].output
return ToolsToFinalOutputResult(
is_final_output=True, final_output=f"{weather.city} is {weather.conditions}."
)
async def main(tool_use_behavior: Literal["default", "first_tool", "custom"] = "default"):
if tool_use_behavior == "default":
behavior: Literal["run_llm_again", "stop_on_first_tool"] | ToolsToFinalOutputFunction = (
"run_llm_again"
)
elif tool_use_behavior == "first_tool":
behavior = "stop_on_first_tool"
elif tool_use_behavior == "custom":
behavior = custom_tool_use_behavior
agent = Agent(
name="Weather agent",
instructions="You are a helpful agent.",
tools=[get_weather],
tool_use_behavior=behavior,
model_settings=ModelSettings(
tool_choice="required" if tool_use_behavior != "default" else None
),
)
result = await Runner.run(agent, input="What's the weather in Tokyo?")
print(result.final_output)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"-t",
"--tool-use-behavior",
type=str,
required=True,
choices=["default", "first_tool", "custom"],
help="The behavior to use for tool use. Default will cause tool outputs to be sent to the model. "
"first_tool_result will cause the first tool result to be used as the final output. "
"custom will use a custom tool use behavior function.",
)
args = parser.parse_args()
asyncio.run(main(args.tool_use_behavior))

View file

@ -79,7 +79,7 @@ multiply_agent = Agent(
start_agent = Agent(
name="Start Agent",
instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multipler agent.",
instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multiplier agent.",
tools=[random_number],
output_type=FinalResult,
handoffs=[multiply_agent],

34
examples/basic/tools.py Normal file
View file

@ -0,0 +1,34 @@
import asyncio
from pydantic import BaseModel
from agents import Agent, Runner, function_tool
class Weather(BaseModel):
city: str
temperature_range: str
conditions: str
@function_tool
def get_weather(city: str) -> Weather:
print("[debug] get_weather called")
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
tools=[get_weather],
)
async def main():
result = await Runner.run(agent, input="What's the weather in Tokyo?")
print(result.final_output)
# The weather in Tokyo is sunny.
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,38 @@
# Financial Research Agent Example
This example shows how you might compose a richer financial research agent using the Agents SDK. The pattern is similar to the `research_bot` example, but with more specialized subagents and a verification step.
The flow is:
1. **Planning**: A planner agent turns the end users request into a list of search terms relevant to financial analysis recent news, earnings calls, corporate filings, industry commentary, etc.
2. **Search**: A search agent uses the builtin `WebSearchTool` to retrieve terse summaries for each search term. (You could also add `FileSearchTool` if you have indexed PDFs or 10Ks.)
3. **Subanalysts**: Additional agents (e.g. a fundamentals analyst and a risk analyst) are exposed as tools so the writer can call them inline and incorporate their outputs.
4. **Writing**: A senior writer agent brings together the search snippets and any subanalyst summaries into a longform markdown report plus a short executive summary.
5. **Verification**: A final verifier agent audits the report for obvious inconsistencies or missing sourcing.
You can run the example with:
```bash
python -m examples.financial_research_agent.main
```
and enter a query like:
```
Write up an analysis of Apple Inc.'s most recent quarter.
```
### Starter prompt
The writer agent is seeded with instructions similar to:
```
You are a senior financial analyst. You will be provided with the original query
and a set of raw search summaries. Your job is to synthesize these into a
longform markdown report (at least several paragraphs) with a short executive
summary. You also have access to tools like `fundamentals_analysis` and
`risk_analysis` to get short specialist writeups if you want to incorporate them.
Add a few followup questions for further research.
```
You can tweak these prompts and subagents to suit your own data sources and preferred report structure.

View file

@ -0,0 +1,23 @@
from pydantic import BaseModel
from agents import Agent
# A subagent focused on analyzing a company's fundamentals.
FINANCIALS_PROMPT = (
"You are a financial analyst focused on company fundamentals such as revenue, "
"profit, margins and growth trajectory. Given a collection of web (and optional file) "
"search results about a company, write a concise analysis of its recent financial "
"performance. Pull out key metrics or quotes. Keep it under 2 paragraphs."
)
class AnalysisSummary(BaseModel):
summary: str
"""Short text summary for this aspect of the analysis."""
financials_agent = Agent(
name="FundamentalsAnalystAgent",
instructions=FINANCIALS_PROMPT,
output_type=AnalysisSummary,
)

View file

@ -0,0 +1,35 @@
from pydantic import BaseModel
from agents import Agent
# Generate a plan of searches to ground the financial analysis.
# For a given financial question or company, we want to search for
# recent news, official filings, analyst commentary, and other
# relevant background.
PROMPT = (
"You are a financial research planner. Given a request for financial analysis, "
"produce a set of web searches to gather the context needed. Aim for recent "
"headlines, earnings calls or 10K snippets, analyst commentary, and industry background. "
"Output between 5 and 15 search terms to query for."
)
class FinancialSearchItem(BaseModel):
reason: str
"""Your reasoning for why this search is relevant."""
query: str
"""The search term to feed into a web (or file) search."""
class FinancialSearchPlan(BaseModel):
searches: list[FinancialSearchItem]
"""A list of searches to perform."""
planner_agent = Agent(
name="FinancialPlannerAgent",
instructions=PROMPT,
model="o3-mini",
output_type=FinancialSearchPlan,
)

View file

@ -0,0 +1,22 @@
from pydantic import BaseModel
from agents import Agent
# A subagent specializing in identifying risk factors or concerns.
RISK_PROMPT = (
"You are a risk analyst looking for potential red flags in a company's outlook. "
"Given background research, produce a short analysis of risks such as competitive threats, "
"regulatory issues, supply chain problems, or slowing growth. Keep it under 2 paragraphs."
)
class AnalysisSummary(BaseModel):
summary: str
"""Short text summary for this aspect of the analysis."""
risk_agent = Agent(
name="RiskAnalystAgent",
instructions=RISK_PROMPT,
output_type=AnalysisSummary,
)

View file

@ -0,0 +1,18 @@
from agents import Agent, WebSearchTool
from agents.model_settings import ModelSettings
# Given a search term, use web search to pull back a brief summary.
# Summaries should be concise but capture the main financial points.
INSTRUCTIONS = (
"You are a research assistant specializing in financial topics. "
"Given a search term, use web search to retrieve uptodate context and "
"produce a short summary of at most 300 words. Focus on key numbers, events, "
"or quotes that will be useful to a financial analyst."
)
search_agent = Agent(
name="FinancialSearchAgent",
instructions=INSTRUCTIONS,
tools=[WebSearchTool()],
model_settings=ModelSettings(tool_choice="required"),
)

View file

@ -0,0 +1,27 @@
from pydantic import BaseModel
from agents import Agent
# Agent to sanitycheck a synthesized report for consistency and recall.
# This can be used to flag potential gaps or obvious mistakes.
VERIFIER_PROMPT = (
"You are a meticulous auditor. You have been handed a financial analysis report. "
"Your job is to verify the report is internally consistent, clearly sourced, and makes "
"no unsupported claims. Point out any issues or uncertainties."
)
class VerificationResult(BaseModel):
verified: bool
"""Whether the report seems coherent and plausible."""
issues: str
"""If not verified, describe the main issues or concerns."""
verifier_agent = Agent(
name="VerificationAgent",
instructions=VERIFIER_PROMPT,
model="gpt-4o",
output_type=VerificationResult,
)

View file

@ -0,0 +1,34 @@
from pydantic import BaseModel
from agents import Agent
# Writer agent brings together the raw search results and optionally calls out
# to subanalyst tools for specialized commentary, then returns a cohesive markdown report.
WRITER_PROMPT = (
"You are a senior financial analyst. You will be provided with the original query and "
"a set of raw search summaries. Your task is to synthesize these into a longform markdown "
"report (at least several paragraphs) including a short executive summary and followup "
"questions. If needed, you can call the available analysis tools (e.g. fundamentals_analysis, "
"risk_analysis) to get short specialist writeups to incorporate."
)
class FinancialReportData(BaseModel):
short_summary: str
"""A short 23 sentence executive summary."""
markdown_report: str
"""The full markdown report."""
follow_up_questions: list[str]
"""Suggested followup questions for further research."""
# Note: We will attach handoffs to specialist analyst agents at runtime in the manager.
# This shows how an agent can use handoffs to delegate to specialized subagents.
writer_agent = Agent(
name="FinancialWriterAgent",
instructions=WRITER_PROMPT,
model="gpt-4.5-preview-2025-02-27",
output_type=FinancialReportData,
)

View file

@ -0,0 +1,17 @@
import asyncio
from .manager import FinancialResearchManager
# Entrypoint for the financial bot example.
# Run this as `python -m examples.financial_bot.main` and enter a
# financial research query, for example:
# "Write up an analysis of Apple Inc.'s most recent quarter."
async def main() -> None:
query = input("Enter a financial research query: ")
mgr = FinancialResearchManager()
await mgr.run(query)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,135 @@
from __future__ import annotations
import asyncio
import time
from collections.abc import Sequence
from rich.console import Console
from agents import Runner, RunResult, custom_span, gen_trace_id, trace
from .agents.financials_agent import financials_agent
from .agents.planner_agent import FinancialSearchItem, FinancialSearchPlan, planner_agent
from .agents.risk_agent import risk_agent
from .agents.search_agent import search_agent
from .agents.verifier_agent import VerificationResult, verifier_agent
from .agents.writer_agent import FinancialReportData, writer_agent
from .printer import Printer
async def _summary_extractor(run_result: RunResult) -> str:
"""Custom output extractor for subagents that return an AnalysisSummary."""
# The financial/risk analyst agents emit an AnalysisSummary with a `summary` field.
# We want the tool call to return just that summary text so the writer can drop it inline.
return str(run_result.final_output.summary)
class FinancialResearchManager:
"""
Orchestrates the full flow: planning, searching, subanalysis, writing, and verification.
"""
def __init__(self) -> None:
self.console = Console()
self.printer = Printer(self.console)
async def run(self, query: str) -> None:
trace_id = gen_trace_id()
with trace("Financial research trace", trace_id=trace_id):
self.printer.update_item(
"trace_id",
f"View trace: https://platform.openai.com/traces/{trace_id}",
is_done=True,
hide_checkmark=True,
)
self.printer.update_item("start", "Starting financial research...", is_done=True)
search_plan = await self._plan_searches(query)
search_results = await self._perform_searches(search_plan)
report = await self._write_report(query, search_results)
verification = await self._verify_report(report)
final_report = f"Report summary\n\n{report.short_summary}"
self.printer.update_item("final_report", final_report, is_done=True)
self.printer.end()
# Print to stdout
print("\n\n=====REPORT=====\n\n")
print(f"Report:\n{report.markdown_report}")
print("\n\n=====FOLLOW UP QUESTIONS=====\n\n")
print("\n".join(report.follow_up_questions))
print("\n\n=====VERIFICATION=====\n\n")
print(verification)
async def _plan_searches(self, query: str) -> FinancialSearchPlan:
self.printer.update_item("planning", "Planning searches...")
result = await Runner.run(planner_agent, f"Query: {query}")
self.printer.update_item(
"planning",
f"Will perform {len(result.final_output.searches)} searches",
is_done=True,
)
return result.final_output_as(FinancialSearchPlan)
async def _perform_searches(self, search_plan: FinancialSearchPlan) -> Sequence[str]:
with custom_span("Search the web"):
self.printer.update_item("searching", "Searching...")
tasks = [asyncio.create_task(self._search(item)) for item in search_plan.searches]
results: list[str] = []
num_completed = 0
for task in asyncio.as_completed(tasks):
result = await task
if result is not None:
results.append(result)
num_completed += 1
self.printer.update_item(
"searching", f"Searching... {num_completed}/{len(tasks)} completed"
)
self.printer.mark_item_done("searching")
return results
async def _search(self, item: FinancialSearchItem) -> str | None:
input_data = f"Search term: {item.query}\nReason: {item.reason}"
try:
result = await Runner.run(search_agent, input_data)
return str(result.final_output)
except Exception:
return None
async def _write_report(self, query: str, search_results: Sequence[str]) -> FinancialReportData:
# Expose the specialist analysts as tools so the writer can invoke them inline
# and still produce the final FinancialReportData output.
fundamentals_tool = financials_agent.as_tool(
tool_name="fundamentals_analysis",
tool_description="Use to get a short writeup of key financial metrics",
custom_output_extractor=_summary_extractor,
)
risk_tool = risk_agent.as_tool(
tool_name="risk_analysis",
tool_description="Use to get a short writeup of potential red flags",
custom_output_extractor=_summary_extractor,
)
writer_with_tools = writer_agent.clone(tools=[fundamentals_tool, risk_tool])
self.printer.update_item("writing", "Thinking about report...")
input_data = f"Original query: {query}\nSummarized search results: {search_results}"
result = Runner.run_streamed(writer_with_tools, input_data)
update_messages = [
"Planning report structure...",
"Writing sections...",
"Finalizing report...",
]
last_update = time.time()
next_message = 0
async for _ in result.stream_events():
if time.time() - last_update > 5 and next_message < len(update_messages):
self.printer.update_item("writing", update_messages[next_message])
next_message += 1
last_update = time.time()
self.printer.mark_item_done("writing")
return result.final_output_as(FinancialReportData)
async def _verify_report(self, report: FinancialReportData) -> VerificationResult:
self.printer.update_item("verifying", "Verifying report...")
result = await Runner.run(verifier_agent, report.markdown_report)
self.printer.mark_item_done("verifying")
return result.final_output_as(VerificationResult)

View file

@ -0,0 +1,46 @@
from typing import Any
from rich.console import Console, Group
from rich.live import Live
from rich.spinner import Spinner
class Printer:
"""
Simple wrapper to stream status updates. Used by the financial bot
manager as it orchestrates planning, search and writing.
"""
def __init__(self, console: Console) -> None:
self.live = Live(console=console)
self.items: dict[str, tuple[str, bool]] = {}
self.hide_done_ids: set[str] = set()
self.live.start()
def end(self) -> None:
self.live.stop()
def hide_done_checkmark(self, item_id: str) -> None:
self.hide_done_ids.add(item_id)
def update_item(
self, item_id: str, content: str, is_done: bool = False, hide_checkmark: bool = False
) -> None:
self.items[item_id] = (content, is_done)
if hide_checkmark:
self.hide_done_ids.add(item_id)
self.flush()
def mark_item_done(self, item_id: str) -> None:
self.items[item_id] = (self.items[item_id][0], True)
self.flush()
def flush(self) -> None:
renderables: list[Any] = []
for item_id, (content, is_done) in self.items.items():
if is_done:
prefix = "" if item_id not in self.hide_done_ids else ""
renderables.append(prefix + content)
else:
renderables.append(Spinner("dots", text=content))
self.live.update(Group(*renderables))

View file

@ -4,7 +4,7 @@ from agents.model_settings import ModelSettings
INSTRUCTIONS = (
"You are a research assistant. Given a search term, you search the web for that term and"
"produce a concise summary of the results. The summary must 2-3 paragraphs and less than 300"
"words. Capture the main points. Write succintly, no need to have complete sentences or good"
"words. Capture the main points. Write succinctly, no need to have complete sentences or good"
"grammar. This will be consumed by someone synthesizing a report, so its vital you capture the"
"essence and ignore any fluff. Do not include any additional commentary other than the summary"
"itself."

View file

View file

@ -0,0 +1,26 @@
# Static voice demo
This demo operates by capturing a recording, then running a voice pipeline on it.
Run via:
```
python -m examples.voice.static.main
```
## How it works
1. We create a `VoicePipeline`, setup with a custom workflow. The workflow runs an Agent, but it also has some custom responses if you say the secret word.
2. When you speak, audio is forwarded to the voice pipeline. When you stop speaking, the agent runs.
3. The pipeline is run with the audio, which causes it to:
1. Transcribe the audio
2. Feed the transcription to the workflow, which runs the agent.
3. Stream the output of the agent to a text-to-speech model.
4. Play the audio.
Some suggested examples to try:
- Tell me a joke (_the assistant tells you a joke_)
- What's the weather in Tokyo? (_will call the `get_weather` tool and then speak_)
- Hola, como estas? (_will handoff to the spanish agent_)
- Tell me about dogs. (_will respond with the hardcoded "you guessed the secret word" message_)

View file

View file

@ -0,0 +1,83 @@
import asyncio
import random
from agents import Agent, function_tool
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
from agents.voice import (
AudioInput,
SingleAgentVoiceWorkflow,
SingleAgentWorkflowCallbacks,
VoicePipeline,
)
from .util import AudioPlayer, record_audio
"""
This is a simple example that uses a recorded audio buffer. Run it via:
`python -m examples.voice.static.main`
1. You can record an audio clip in the terminal.
2. The pipeline automatically transcribes the audio.
3. The agent workflow is a simple one that starts at the Assistant agent.
4. The output of the agent is streamed to the audio player.
Try examples like:
- Tell me a joke (will respond with a joke)
- What's the weather in Tokyo? (will call the `get_weather` tool and then speak)
- Hola, como estas? (will handoff to the spanish agent)
"""
@function_tool
def get_weather(city: str) -> str:
"""Get the weather for a given city."""
print(f"[debug] get_weather called with city: {city}")
choices = ["sunny", "cloudy", "rainy", "snowy"]
return f"The weather in {city} is {random.choice(choices)}."
spanish_agent = Agent(
name="Spanish",
handoff_description="A spanish speaking agent.",
instructions=prompt_with_handoff_instructions(
"You're speaking to a human, so be polite and concise. Speak in Spanish.",
),
model="gpt-4o-mini",
)
agent = Agent(
name="Assistant",
instructions=prompt_with_handoff_instructions(
"You're speaking to a human, so be polite and concise. If the user speaks in Spanish, handoff to the spanish agent.",
),
model="gpt-4o-mini",
handoffs=[spanish_agent],
tools=[get_weather],
)
class WorkflowCallbacks(SingleAgentWorkflowCallbacks):
def on_run(self, workflow: SingleAgentVoiceWorkflow, transcription: str) -> None:
print(f"[debug] on_run called with transcription: {transcription}")
async def main():
pipeline = VoicePipeline(
workflow=SingleAgentVoiceWorkflow(agent, callbacks=WorkflowCallbacks())
)
audio_input = AudioInput(buffer=record_audio())
result = await pipeline.run(audio_input)
with AudioPlayer() as player:
async for event in result.stream():
if event.type == "voice_stream_event_audio":
player.add_audio(event.data)
print("Received audio")
elif event.type == "voice_stream_event_lifecycle":
print(f"Received lifecycle event: {event.event}")
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,68 @@
import curses
import time
import numpy as np
import numpy.typing as npt
import sounddevice as sd
def _record_audio(screen: curses.window) -> npt.NDArray[np.float32]:
screen.nodelay(True) # Non-blocking input
screen.clear()
screen.addstr(
"Press <spacebar> to start recording. Press <spacebar> again to stop recording.\n"
)
screen.refresh()
recording = False
audio_buffer: list[npt.NDArray[np.float32]] = []
def _audio_callback(indata, frames, time_info, status):
if status:
screen.addstr(f"Status: {status}\n")
screen.refresh()
if recording:
audio_buffer.append(indata.copy())
# Open the audio stream with the callback.
with sd.InputStream(samplerate=24000, channels=1, dtype=np.float32, callback=_audio_callback):
while True:
key = screen.getch()
if key == ord(" "):
recording = not recording
if recording:
screen.addstr("Recording started...\n")
else:
screen.addstr("Recording stopped.\n")
break
screen.refresh()
time.sleep(0.01)
# Combine recorded audio chunks.
if audio_buffer:
audio_data = np.concatenate(audio_buffer, axis=0)
else:
audio_data = np.empty((0,), dtype=np.float32)
return audio_data
def record_audio():
# Using curses to record audio in a way that:
# - doesn't require accessibility permissions on macos
# - doesn't block the terminal
audio_data = curses.wrapper(_record_audio)
return audio_data
class AudioPlayer:
def __enter__(self):
self.stream = sd.OutputStream(samplerate=24000, channels=1, dtype=np.int16)
self.stream.start()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.stream.close()
def add_audio(self, audio_data: npt.NDArray[np.int16]):
self.stream.write(audio_data)

View file

@ -0,0 +1,25 @@
# Streamed voice demo
This is an interactive demo, where you can talk to an Agent conversationally. It uses the voice pipeline's built in turn detection feature, so if you stop speaking the Agent responds.
Run via:
```
python -m examples.voice.streamed.main
```
## How it works
1. We create a `VoicePipeline`, setup with a `SingleAgentVoiceWorkflow`. This is a workflow that starts at an Assistant agent, has tools and handoffs.
2. Audio input is captured from the terminal.
3. The pipeline is run with the recorded audio, which causes it to:
1. Transcribe the audio
2. Feed the transcription to the workflow, which runs the agent.
3. Stream the output of the agent to a text-to-speech model.
4. Play the audio.
Some suggested examples to try:
- Tell me a joke (_the assistant tells you a joke_)
- What's the weather in Tokyo? (_will call the `get_weather` tool and then speak_)
- Hola, como estas? (_will handoff to the spanish agent_)

View file

View file

@ -0,0 +1,233 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
import numpy as np
import sounddevice as sd
from textual import events
from textual.app import App, ComposeResult
from textual.containers import Container
from textual.reactive import reactive
from textual.widgets import Button, RichLog, Static
from typing_extensions import override
from agents.voice import StreamedAudioInput, VoicePipeline
# Import MyWorkflow class - handle both module and package use cases
if TYPE_CHECKING:
# For type checking, use the relative import
from .my_workflow import MyWorkflow
else:
# At runtime, try both import styles
try:
# Try relative import first (when used as a package)
from .my_workflow import MyWorkflow
except ImportError:
# Fall back to direct import (when run as a script)
from my_workflow import MyWorkflow
CHUNK_LENGTH_S = 0.05 # 100ms
SAMPLE_RATE = 24000
FORMAT = np.int16
CHANNELS = 1
class Header(Static):
"""A header widget."""
session_id = reactive("")
@override
def render(self) -> str:
return "Speak to the agent. When you stop speaking, it will respond."
class AudioStatusIndicator(Static):
"""A widget that shows the current audio recording status."""
is_recording = reactive(False)
@override
def render(self) -> str:
status = (
"🔴 Recording... (Press K to stop)"
if self.is_recording
else "⚪ Press K to start recording (Q to quit)"
)
return status
class RealtimeApp(App[None]):
CSS = """
Screen {
background: #1a1b26; /* Dark blue-grey background */
}
Container {
border: double rgb(91, 164, 91);
}
Horizontal {
width: 100%;
}
#input-container {
height: 5; /* Explicit height for input container */
margin: 1 1;
padding: 1 2;
}
Input {
width: 80%;
height: 3; /* Explicit height for input */
}
Button {
width: 20%;
height: 3; /* Explicit height for button */
}
#bottom-pane {
width: 100%;
height: 82%; /* Reduced to make room for session display */
border: round rgb(205, 133, 63);
content-align: center middle;
}
#status-indicator {
height: 3;
content-align: center middle;
background: #2a2b36;
border: solid rgb(91, 164, 91);
margin: 1 1;
}
#session-display {
height: 3;
content-align: center middle;
background: #2a2b36;
border: solid rgb(91, 164, 91);
margin: 1 1;
}
Static {
color: white;
}
"""
should_send_audio: asyncio.Event
audio_player: sd.OutputStream
last_audio_item_id: str | None
connected: asyncio.Event
def __init__(self) -> None:
super().__init__()
self.last_audio_item_id = None
self.should_send_audio = asyncio.Event()
self.connected = asyncio.Event()
self.pipeline = VoicePipeline(
workflow=MyWorkflow(secret_word="dog", on_start=self._on_transcription)
)
self._audio_input = StreamedAudioInput()
self.audio_player = sd.OutputStream(
samplerate=SAMPLE_RATE,
channels=CHANNELS,
dtype=FORMAT,
)
def _on_transcription(self, transcription: str) -> None:
try:
self.query_one("#bottom-pane", RichLog).write(f"Transcription: {transcription}")
except Exception:
pass
@override
def compose(self) -> ComposeResult:
"""Create child widgets for the app."""
with Container():
yield Header(id="session-display")
yield AudioStatusIndicator(id="status-indicator")
yield RichLog(id="bottom-pane", wrap=True, highlight=True, markup=True)
async def on_mount(self) -> None:
self.run_worker(self.start_voice_pipeline())
self.run_worker(self.send_mic_audio())
async def start_voice_pipeline(self) -> None:
try:
self.audio_player.start()
self.result = await self.pipeline.run(self._audio_input)
async for event in self.result.stream():
bottom_pane = self.query_one("#bottom-pane", RichLog)
if event.type == "voice_stream_event_audio":
self.audio_player.write(event.data)
bottom_pane.write(
f"Received audio: {len(event.data) if event.data is not None else '0'} bytes"
)
elif event.type == "voice_stream_event_lifecycle":
bottom_pane.write(f"Lifecycle event: {event.event}")
except Exception as e:
bottom_pane = self.query_one("#bottom-pane", RichLog)
bottom_pane.write(f"Error: {e}")
finally:
self.audio_player.close()
async def send_mic_audio(self) -> None:
device_info = sd.query_devices()
print(device_info)
read_size = int(SAMPLE_RATE * 0.02)
stream = sd.InputStream(
channels=CHANNELS,
samplerate=SAMPLE_RATE,
dtype="int16",
)
stream.start()
status_indicator = self.query_one(AudioStatusIndicator)
try:
while True:
if stream.read_available < read_size:
await asyncio.sleep(0)
continue
await self.should_send_audio.wait()
status_indicator.is_recording = True
data, _ = stream.read(read_size)
await self._audio_input.add_audio(data)
await asyncio.sleep(0)
except KeyboardInterrupt:
pass
finally:
stream.stop()
stream.close()
async def on_key(self, event: events.Key) -> None:
"""Handle key press events."""
if event.key == "enter":
self.query_one(Button).press()
return
if event.key == "q":
self.exit()
return
if event.key == "k":
status_indicator = self.query_one(AudioStatusIndicator)
if status_indicator.is_recording:
self.should_send_audio.clear()
status_indicator.is_recording = False
else:
self.should_send_audio.set()
status_indicator.is_recording = True
if __name__ == "__main__":
app = RealtimeApp()
app.run()

View file

@ -0,0 +1,81 @@
import random
from collections.abc import AsyncIterator
from typing import Callable
from agents import Agent, Runner, TResponseInputItem, function_tool
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
from agents.voice import VoiceWorkflowBase, VoiceWorkflowHelper
@function_tool
def get_weather(city: str) -> str:
"""Get the weather for a given city."""
print(f"[debug] get_weather called with city: {city}")
choices = ["sunny", "cloudy", "rainy", "snowy"]
return f"The weather in {city} is {random.choice(choices)}."
spanish_agent = Agent(
name="Spanish",
handoff_description="A spanish speaking agent.",
instructions=prompt_with_handoff_instructions(
"You're speaking to a human, so be polite and concise. Speak in Spanish.",
),
model="gpt-4o-mini",
)
agent = Agent(
name="Assistant",
instructions=prompt_with_handoff_instructions(
"You're speaking to a human, so be polite and concise. If the user speaks in Spanish, handoff to the spanish agent.",
),
model="gpt-4o-mini",
handoffs=[spanish_agent],
tools=[get_weather],
)
class MyWorkflow(VoiceWorkflowBase):
def __init__(self, secret_word: str, on_start: Callable[[str], None]):
"""
Args:
secret_word: The secret word to guess.
on_start: A callback that is called when the workflow starts. The transcription
is passed in as an argument.
"""
self._input_history: list[TResponseInputItem] = []
self._current_agent = agent
self._secret_word = secret_word.lower()
self._on_start = on_start
async def run(self, transcription: str) -> AsyncIterator[str]:
self._on_start(transcription)
# Add the transcription to the input history
self._input_history.append(
{
"role": "user",
"content": transcription,
}
)
# If the user guessed the secret word, do alternate logic
if self._secret_word in transcription.lower():
yield "You guessed the secret word!"
self._input_history.append(
{
"role": "assistant",
"content": "You guessed the secret word!",
}
)
return
# Otherwise, run the agent
result = Runner.run_streamed(self._current_agent, self._input_history)
async for chunk in VoiceWorkflowHelper.stream_text_from(result):
yield chunk
# Update the input history and current agent
self._input_history = result.to_input_list()
self._current_agent = result.last_agent

View file

@ -1,121 +1,143 @@
site_name: OpenAI Agents SDK
theme:
name: material
features:
# Allows copying code blocks
- content.code.copy
# Allows selecting code blocks
- content.code.select
# Shows the current path in the sidebar
- navigation.path
# Shows sections in the sidebar
- navigation.sections
# Shows sections expanded by default
- navigation.expand
# Enables annotations in code blocks
- content.code.annotate
palette:
primary: black
logo: assets/logo.svg
favicon: images/favicon-platform.svg
name: material
features:
# Allows copying code blocks
- content.code.copy
# Allows selecting code blocks
- content.code.select
# Shows the current path in the sidebar
- navigation.path
# Shows sections in the sidebar
- navigation.sections
# Shows sections expanded by default
- navigation.expand
# Enables annotations in code blocks
- content.code.annotate
palette:
primary: black
logo: assets/logo.svg
favicon: images/favicon-platform.svg
nav:
- Intro: index.md
- Quickstart: quickstart.md
- Documentation:
- agents.md
- running_agents.md
- results.md
- streaming.md
- tools.md
- handoffs.md
- tracing.md
- context.md
- guardrails.md
- multi_agent.md
- models.md
- config.md
- API Reference:
- Agents:
- ref/index.md
- ref/agent.md
- ref/run.md
- ref/tool.md
- ref/result.md
- ref/stream_events.md
- ref/handoffs.md
- ref/lifecycle.md
- ref/items.md
- ref/run_context.md
- ref/usage.md
- ref/exceptions.md
- ref/guardrail.md
- ref/model_settings.md
- ref/agent_output.md
- ref/function_schema.md
- ref/models/interface.md
- ref/models/openai_chatcompletions.md
- ref/models/openai_responses.md
- Tracing:
- ref/tracing/index.md
- ref/tracing/create.md
- ref/tracing/traces.md
- ref/tracing/spans.md
- ref/tracing/processor_interface.md
- ref/tracing/processors.md
- ref/tracing/scope.md
- ref/tracing/setup.md
- ref/tracing/span_data.md
- ref/tracing/util.md
- Extensions:
- ref/extensions/handoff_filters.md
- ref/extensions/handoff_prompt.md
- Intro: index.md
- Quickstart: quickstart.md
- Examples: examples.md
- Documentation:
- agents.md
- running_agents.md
- results.md
- streaming.md
- tools.md
- handoffs.md
- tracing.md
- context.md
- guardrails.md
- multi_agent.md
- models.md
- config.md
- Voice agents:
- voice/quickstart.md
- voice/pipeline.md
- voice/tracing.md
- API Reference:
- Agents:
- ref/index.md
- ref/agent.md
- ref/run.md
- ref/tool.md
- ref/result.md
- ref/stream_events.md
- ref/handoffs.md
- ref/lifecycle.md
- ref/items.md
- ref/run_context.md
- ref/usage.md
- ref/exceptions.md
- ref/guardrail.md
- ref/model_settings.md
- ref/agent_output.md
- ref/function_schema.md
- ref/models/interface.md
- ref/models/openai_chatcompletions.md
- ref/models/openai_responses.md
- Tracing:
- ref/tracing/index.md
- ref/tracing/create.md
- ref/tracing/traces.md
- ref/tracing/spans.md
- ref/tracing/processor_interface.md
- ref/tracing/processors.md
- ref/tracing/scope.md
- ref/tracing/setup.md
- ref/tracing/span_data.md
- ref/tracing/util.md
- Voice:
- ref/voice/pipeline.md
- ref/voice/workflow.md
- ref/voice/input.md
- ref/voice/result.md
- ref/voice/pipeline_config.md
- ref/voice/events.md
- ref/voice/exceptions.md
- ref/voice/model.md
- ref/voice/utils.md
- ref/voice/models/openai_provider.md
- ref/voice/models/openai_stt.md
- ref/voice/models/openai_tts.md
- Extensions:
- ref/extensions/handoff_filters.md
- ref/extensions/handoff_prompt.md
plugins:
- search
- mkdocstrings:
handlers:
python:
paths: ["src/agents"]
selection:
docstring_style: google
options:
# Shows links to other members in signatures
signature_crossrefs: true
# Orders members by source order, rather than alphabetical
members_order: source
# Puts the signature on a separate line from the member name
separate_signature: true
# Shows type annotations in signatures
show_signature_annotations: true
# Makes the font sizes nicer
heading_level: 3
- search
- mkdocstrings:
handlers:
python:
paths: ["src/agents"]
selection:
docstring_style: google
options:
# Shows links to other members in signatures
signature_crossrefs: true
# Orders members by source order, rather than alphabetical
members_order: source
# Puts the signature on a separate line from the member name
separate_signature: true
# Shows type annotations in signatures
show_signature_annotations: true
# Makes the font sizes nicer
heading_level: 3
extra:
# Remove material generation message in footer
generator: false
# Remove material generation message in footer
generator: false
markdown_extensions:
- admonition
- pymdownx.details
- pymdownx.superfences
- attr_list
- md_in_html
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
- pymdownx.superfences:
custom_fences:
- name: mermaid
class: mermaid
format: !!python/name:pymdownx.superfences.fence_code_format
- admonition
- pymdownx.details
- attr_list
- md_in_html
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
validation:
omitted_files: warn
absolute_links: warn
unrecognized_links: warn
anchors: warn
omitted_files: warn
absolute_links: warn
unrecognized_links: warn
anchors: warn
extra_css:
- stylesheets/extra.css
- stylesheets/extra.css
watch:
- "src/agents"
- "src/agents"

View file

@ -1,15 +1,13 @@
[project]
name = "openai-agents"
version = "0.0.4"
version = "0.0.6"
description = "OpenAI Agents SDK"
readme = "README.md"
requires-python = ">=3.9"
license = "MIT"
authors = [
{ name = "OpenAI", email = "support@openai.com" },
]
authors = [{ name = "OpenAI", email = "support@openai.com" }]
dependencies = [
"openai>=1.66.2",
"openai>=1.66.5",
"pydantic>=2.10, <3",
"griffe>=1.5.6, <2",
"typing-extensions>=4.12.2, <5",
@ -27,13 +25,16 @@ classifiers = [
"Intended Audience :: Developers",
"Operating System :: OS Independent",
"Topic :: Software Development :: Libraries :: Python Modules",
"License :: OSI Approved :: MIT License"
"License :: OSI Approved :: MIT License",
]
[project.urls]
Homepage = "https://github.com/openai/openai-agents-python"
Repository = "https://github.com/openai/openai-agents-python"
[project.optional-dependencies]
voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
[dependency-groups]
dev = [
"mypy",
@ -48,6 +49,12 @@ dev = [
"coverage>=7.6.12",
"playwright==1.50.0",
"inline-snapshot>=0.20.7",
"pynput",
"types-pynput",
"sounddevice",
"pynput",
"textual",
"websockets",
]
[project.optional-dependencies]
@ -80,8 +87,8 @@ select = [
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"C4", # flake8-comprehensions
"UP", # pyupgrade
]
isort = { combine-as-imports = true, known-first-party = ["agents"] }
@ -97,11 +104,12 @@ disallow_incomplete_defs = false
disallow_untyped_defs = false
disallow_untyped_calls = false
[[tool.mypy.overrides]]
module = "sounddevice.*"
ignore_missing_imports = true
[tool.coverage.run]
source = [
"tests",
"src/agents",
]
source = ["tests", "src/agents"]
[tool.coverage.report]
show_missing = true
@ -115,7 +123,7 @@ exclude_also = [
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
filterwarnings = [
# This is a warning that is expected to happen: we have an async filter that raises an exception
@ -126,4 +134,4 @@ markers = [
]
[tool.inline-snapshot]
format-command="ruff format --stdin-filename {filename}"
format-command = "ruff format --stdin-filename {filename}"

View file

@ -5,7 +5,7 @@ from typing import Literal
from openai import AsyncOpenAI
from . import _config
from .agent import Agent
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
from .agent_output import AgentOutputSchema
from .computer import AsyncComputer, Button, Computer, Environment
from .exceptions import (
@ -57,6 +57,7 @@ from .tool import (
ComputerTool,
FileSearchTool,
FunctionTool,
FunctionToolResult,
Tool,
WebSearchTool,
default_tool_error_function,
@ -72,8 +73,11 @@ from .tracing import (
Span,
SpanData,
SpanError,
SpeechGroupSpanData,
SpeechSpanData,
Trace,
TracingProcessor,
TranscriptionSpanData,
add_trace_processor,
agent_span,
custom_span,
@ -88,7 +92,10 @@ from .tracing import (
set_trace_processors,
set_tracing_disabled,
set_tracing_export_api_key,
speech_group_span,
speech_span,
trace,
transcription_span,
)
from .usage import Usage
@ -137,6 +144,8 @@ def enable_verbose_stdout_logging():
__all__ = [
"Agent",
"ToolsToFinalOutputFunction",
"ToolsToFinalOutputResult",
"Runner",
"Model",
"ModelProvider",
@ -190,6 +199,7 @@ __all__ = [
"AgentUpdatedStreamEvent",
"StreamEvent",
"FunctionTool",
"FunctionToolResult",
"ComputerTool",
"FileSearchTool",
"Tool",
@ -207,6 +217,9 @@ __all__ = [
"handoff_span",
"set_trace_processors",
"set_tracing_disabled",
"speech_group_span",
"transcription_span",
"speech_span",
"trace",
"Trace",
"TracingProcessor",
@ -219,6 +232,9 @@ __all__ = [
"GenerationSpanData",
"GuardrailSpanData",
"HandoffSpanData",
"SpeechGroupSpanData",
"SpeechSpanData",
"TranscriptionSpanData",
"set_default_openai_key",
"set_default_openai_client",
"set_default_openai_api",

View file

@ -1,8 +1,11 @@
from __future__ import annotations
import asyncio
import dataclasses
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from openai.types.responses import (
ResponseComputerToolCall,
@ -25,7 +28,7 @@ from openai.types.responses.response_computer_tool_call import (
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from .agent import Agent
from .agent import Agent, ToolsToFinalOutputResult
from .agent_output import AgentOutputSchema
from .computer import AsyncComputer, Computer
from .exceptions import AgentsException, ModelBehaviorError, UserError
@ -45,10 +48,11 @@ from .items import (
)
from .lifecycle import RunHooks
from .logger import logger
from .model_settings import ModelSettings
from .models.interface import ModelTracing
from .run_context import RunContextWrapper, TContext
from .stream_events import RunItemStreamEvent, StreamEvent
from .tool import ComputerTool, FunctionTool
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
from .tracing import (
SpanError,
Trace,
@ -70,6 +74,8 @@ class QueueCompleteSentinel:
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
@dataclass
class ToolRunHandoff:
@ -199,9 +205,32 @@ class RunImpl:
config=run_config,
),
)
new_step_items.extend(function_results)
new_step_items.extend([result.run_item for result in function_results])
new_step_items.extend(computer_results)
# Reset tool_choice to "auto" after tool execution to prevent infinite loops
if processed_response.functions or processed_response.computer_actions:
tools = agent.tools
if (
run_config.model_settings and
cls._should_reset_tool_choice(run_config.model_settings, tools)
):
# update the run_config model settings with a copy
new_run_config_settings = dataclasses.replace(
run_config.model_settings,
tool_choice="auto"
)
run_config = dataclasses.replace(run_config, model_settings=new_run_config_settings)
if cls._should_reset_tool_choice(agent.model_settings, tools):
# Create a modified copy instead of modifying the original agent
new_model_settings = dataclasses.replace(
agent.model_settings,
tool_choice="auto"
)
agent = dataclasses.replace(agent, model_settings=new_model_settings)
# Second, check if there are any handoffs
if run_handoffs := processed_response.handoffs:
return await cls.execute_handoffs(
@ -216,6 +245,36 @@ class RunImpl:
run_config=run_config,
)
# Third, we'll check if the tool use should result in a final output
check_tool_use = await cls._check_for_final_output_from_tools(
agent=agent,
tool_results=function_results,
context_wrapper=context_wrapper,
config=run_config,
)
if check_tool_use.is_final_output:
# If the output type is str, then let's just stringify it
if not agent.output_type or agent.output_type is str:
check_tool_use.final_output = str(check_tool_use.final_output)
if check_tool_use.final_output is None:
logger.error(
"Model returned a final output of None. Not raising an error because we assume"
"you know what you're doing."
)
return await cls.execute_final_output(
agent=agent,
original_input=original_input,
new_response=new_response,
pre_step_items=pre_step_items,
new_step_items=new_step_items,
final_output=check_tool_use.final_output,
hooks=hooks,
context_wrapper=context_wrapper,
)
# Now we can check if the model also produced a final output
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
@ -262,6 +321,24 @@ class RunImpl:
next_step=NextStepRunAgain(),
)
@classmethod
def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool:
if model_settings is None or model_settings.tool_choice is None:
return False
# for specific tool choices
if (
isinstance(model_settings.tool_choice, str) and
model_settings.tool_choice not in ["auto", "required", "none"]
):
return True
# for one tool and required tool choice
if model_settings.tool_choice == "required":
return len(tools) == 1
return False
@classmethod
def process_model_response(
cls,
@ -355,10 +432,10 @@ class RunImpl:
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> list[RunItem]:
) -> list[FunctionToolResult]:
async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> str:
) -> Any:
with function_span(func_tool.name) as span_fn:
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
@ -404,10 +481,14 @@ class RunImpl:
results = await asyncio.gather(*tasks)
return [
ToolCallOutputItem(
output=str(result),
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
agent=agent,
FunctionToolResult(
tool=tool_run.function_tool,
output=result,
run_item=ToolCallOutputItem(
output=result,
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
agent=agent,
),
)
for tool_run, result in zip(tool_runs, results)
]
@ -646,6 +727,47 @@ class RunImpl:
if event:
queue.put_nowait(event)
@classmethod
async def _check_for_final_output_from_tools(
cls,
*,
agent: Agent[TContext],
tool_results: list[FunctionToolResult],
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> ToolsToFinalOutputResult:
"""Returns (i, final_output)."""
if not tool_results:
return _NOT_FINAL_OUTPUT
if agent.tool_use_behavior == "run_llm_again":
return _NOT_FINAL_OUTPUT
elif agent.tool_use_behavior == "stop_on_first_tool":
return ToolsToFinalOutputResult(
is_final_output=True, final_output=tool_results[0].output
)
elif isinstance(agent.tool_use_behavior, dict):
names = agent.tool_use_behavior.get("stop_at_tool_names", [])
for tool_result in tool_results:
if tool_result.tool.name in names:
return ToolsToFinalOutputResult(
is_final_output=True, final_output=tool_result.output
)
return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
elif callable(agent.tool_use_behavior):
if inspect.iscoroutinefunction(agent.tool_use_behavior):
return await cast(
Awaitable[ToolsToFinalOutputResult],
agent.tool_use_behavior(context_wrapper, tool_results),
)
else:
return cast(
ToolsToFinalOutputResult, agent.tool_use_behavior(context_wrapper, tool_results)
)
logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
class TraceCtxManager:
"""Creates a trace only if there is no current trace, and manages the trace lifecycle."""

View file

@ -4,7 +4,9 @@ import dataclasses
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
from typing_extensions import TypeAlias, TypedDict
from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
@ -13,7 +15,7 @@ from .logger import logger
from .model_settings import ModelSettings
from .models.interface import Model
from .run_context import RunContextWrapper, TContext
from .tool import Tool, function_tool
from .tool import FunctionToolResult, Tool, function_tool
from .util import _transforms
from .util._types import MaybeAwaitable
@ -22,6 +24,33 @@ if TYPE_CHECKING:
from .result import RunResult
@dataclass
class ToolsToFinalOutputResult:
is_final_output: bool
"""Whether this is the final output. If False, the LLM will run again and receive the tool call
output.
"""
final_output: Any | None = None
"""The final output. Can be None if `is_final_output` is False, otherwise must match the
`output_type` of the agent.
"""
ToolsToFinalOutputFunction: TypeAlias = Callable[
[RunContextWrapper[TContext], list[FunctionToolResult]],
MaybeAwaitable[ToolsToFinalOutputResult],
]
"""A function that takes a run context and a list of tool results, and returns a
`ToolToFinalOutputResult`.
"""
class StopAtTools(TypedDict):
stop_at_tool_names: list[str]
"""A list of tool names, any of which will stop the agent from running further."""
@dataclass
class Agent(Generic[TContext]):
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
@ -95,6 +124,25 @@ class Agent(Generic[TContext]):
"""A class that receives callbacks on various lifecycle events for this agent.
"""
tool_use_behavior: (
Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction
) = "run_llm_again"
"""This lets you configure how tool use is handled.
- "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results
and gets to respond.
- "stop_on_first_tool": The output of the first tool call is used as the final output. This
means that the LLM does not process the result of the tool call.
- A list of tool names: The agent will stop running if any of the tools in the list are called.
The final output will be the output of the first matching tool call. The LLM does not
process the result of the tool call.
- A function: If you pass a function, it will be called with the run context and the list of
tool results. It must return a `ToolToFinalOutputResult`, which determines whether the tool
calls result in a final output.
NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
web search, etc are always processed by the LLM.
"""
def clone(self, **kwargs: Any) -> Agent[TContext]:
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
```

View file

@ -129,8 +129,10 @@ class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutpu
raw_item: FunctionCallOutput | ComputerCallOutput
"""The raw item from the model."""
output: str
"""The output of the tool call."""
output: Any
"""The output of the tool call. This is whatever the tool call returned; the `raw_item`
contains a string representation of the output.
"""
type: Literal["tool_call_output_item"] = "tool_call_output_item"

View file

@ -54,7 +54,7 @@ from openai.types.responses import (
ResponseUsage,
)
from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message
from openai.types.responses.response_usage import OutputTokensDetails
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from .. import _debug
from ..agent_output import AgentOutputSchema
@ -420,6 +420,11 @@ class OpenAIChatCompletionsModel(Model):
and usage.completion_tokens_details.reasoning_tokens
else 0
),
input_tokens_details=InputTokensDetails(
cached_tokens=usage.prompt_tokens_details.cached_tokens
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens
else 0
),
)
if usage
else None
@ -752,7 +757,7 @@ class _Converter:
elif isinstance(c, dict) and c.get("type") == "input_file":
raise UserError(f"File uploads are not supported for chat completions {c}")
else:
raise UserError(f"Unknonw content: {c}")
raise UserError(f"Unknown content: {c}")
return out
@classmethod

View file

@ -34,6 +34,19 @@ class OpenAIProvider(ModelProvider):
project: str | None = None,
use_responses: bool | None = None,
) -> None:
"""Create a new OpenAI provider.
Args:
api_key: The API key to use for the OpenAI client. If not provided, we will use the
default API key.
base_url: The base URL to use for the OpenAI client. If not provided, we will use the
default base URL.
openai_client: An optional OpenAI client to use. If not provided, we will create a new
OpenAI client using the api_key and base_url.
organization: The organization to use for the OpenAI client.
project: The project to use for the OpenAI client.
use_responses: Whether to use the OpenAI responses API.
"""
if openai_client is not None:
assert api_key is None and base_url is None, (
"Don't provide api_key or base_url if you provide openai_client"

View file

@ -83,7 +83,7 @@ class OpenAIResponsesModel(Model):
)
if _debug.DONT_LOG_MODEL_DATA:
logger.debug("LLM responsed")
logger.debug("LLM responded")
else:
logger.debug(
"LLM resp:\n"
@ -208,7 +208,9 @@ class OpenAIResponsesModel(Model):
list_input = ItemHelpers.input_to_new_input_list(input)
parallel_tool_calls = (
True if model_settings.parallel_tool_calls and tools and len(tools) > 0 else NOT_GIVEN
True if model_settings.parallel_tool_calls and tools and len(tools) > 0
else False if model_settings.parallel_tool_calls is False
else NOT_GIVEN
)
tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)

1
src/agents/py.typed Normal file
View file

@ -0,0 +1 @@

View file

@ -15,6 +15,7 @@ from . import _debug
from .computer import AsyncComputer, Computer
from .exceptions import ModelBehaviorError
from .function_schema import DocstringStyle, function_schema
from .items import RunItem
from .logger import logger
from .run_context import RunContextWrapper
from .tracing import SpanError
@ -29,6 +30,18 @@ ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParam
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
@dataclass
class FunctionToolResult:
tool: FunctionTool
"""The tool that was run."""
output: Any
"""The output of the tool."""
run_item: RunItem
"""The run item that was produced as a result of the tool call."""
@dataclass
class FunctionTool:
"""A tool that wraps a function. In most cases, you should use the `function_tool` helpers to
@ -44,15 +57,15 @@ class FunctionTool:
params_json_schema: dict[str, Any]
"""The JSON schema for the tool's parameters."""
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[str]]
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
"""A function that invokes the tool with the given context and parameters. The params passed
are:
1. The tool run context.
2. The arguments from the LLM, as a JSON string.
You must return a string representation of the tool output. In case of errors, you can either
raise an Exception (which will cause the run to fail) or return a string error message (which
will be sent back to the LLM).
You must return a string representation of the tool output, or something we can call `str()` on.
In case of errors, you can either raise an Exception (which will cause the run to fail) or
return a string error message (which will be sent back to the LLM).
"""
strict_json_schema: bool = True
@ -190,8 +203,11 @@ def function_tool(
failure_error_function: If provided, use this function to generate an error message when
the tool call fails. The error message is sent to the LLM. If you pass None, then no
error message will be sent and instead an Exception will be raised.
strict_mode: If False, parameters with default values become optional in the
function schema.
strict_mode: Whether to enable strict mode for the tool's JSON schema. We *strongly*
recommend setting this to True, as it increases the likelihood of correct JSON input.
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
value, it will be optional, additional properties are allowed, etc. See here for more:
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
"""
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@ -204,7 +220,7 @@ def function_tool(
strict_json_schema=strict_mode,
)
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
try:
json_data: dict[str, Any] = json.loads(input) if input else {}
except Exception as e:
@ -251,9 +267,9 @@ def function_tool(
else:
logger.debug(f"Tool {schema.name} returned {result}")
return str(result)
return result
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
try:
return await _on_invoke_tool_impl(ctx, input)
except Exception as e:

View file

@ -10,7 +10,10 @@ from .create import (
guardrail_span,
handoff_span,
response_span,
speech_group_span,
speech_span,
trace,
transcription_span,
)
from .processor_interface import TracingProcessor
from .processors import default_exporter, default_processor
@ -24,6 +27,9 @@ from .span_data import (
HandoffSpanData,
ResponseSpanData,
SpanData,
SpeechGroupSpanData,
SpeechSpanData,
TranscriptionSpanData,
)
from .spans import Span, SpanError
from .traces import Trace
@ -54,9 +60,15 @@ __all__ = [
"GuardrailSpanData",
"HandoffSpanData",
"ResponseSpanData",
"SpeechGroupSpanData",
"SpeechSpanData",
"TranscriptionSpanData",
"TracingProcessor",
"gen_trace_id",
"gen_span_id",
"speech_group_span",
"speech_span",
"transcription_span",
]

View file

@ -13,6 +13,9 @@ from .span_data import (
GuardrailSpanData,
HandoffSpanData,
ResponseSpanData,
SpeechGroupSpanData,
SpeechSpanData,
TranscriptionSpanData,
)
from .spans import Span
from .traces import Trace
@ -181,7 +184,11 @@ def generation_span(
"""
return GLOBAL_TRACE_PROVIDER.create_span(
span_data=GenerationSpanData(
input=input, output=output, model=model, model_config=model_config, usage=usage
input=input,
output=output,
model=model,
model_config=model_config,
usage=usage,
),
span_id=span_id,
parent=parent,
@ -304,3 +311,116 @@ def guardrail_span(
parent=parent,
disabled=disabled,
)
def transcription_span(
model: str | None = None,
input: str | None = None,
input_format: str | None = "pcm",
output: str | None = None,
model_config: Mapping[str, Any] | None = None,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[TranscriptionSpanData]:
"""Create a new transcription span. The span will not be started automatically, you should
either do `with transcription_span() ...` or call `span.start()` + `span.finish()` manually.
Args:
model: The name of the model used for the speech-to-text.
input: The audio input of the speech-to-text transcription, as a base64 encoded string of
audio bytes.
input_format: The format of the audio input (defaults to "pcm").
output: The output of the speech-to-text transcription.
model_config: The model configuration (hyperparameters) used.
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
correctly formatted.
parent: The parent span or trace. If not provided, we will automatically use the current
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
Returns:
The newly created speech-to-text span.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
span_data=TranscriptionSpanData(
input=input,
input_format=input_format,
output=output,
model=model,
model_config=model_config,
),
span_id=span_id,
parent=parent,
disabled=disabled,
)
def speech_span(
model: str | None = None,
input: str | None = None,
output: str | None = None,
output_format: str | None = "pcm",
model_config: Mapping[str, Any] | None = None,
first_content_at: str | None = None,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[SpeechSpanData]:
"""Create a new speech span. The span will not be started automatically, you should either do
`with speech_span() ...` or call `span.start()` + `span.finish()` manually.
Args:
model: The name of the model used for the text-to-speech.
input: The text input of the text-to-speech.
output: The audio output of the text-to-speech as base64 encoded string of PCM audio bytes.
output_format: The format of the audio output (defaults to "pcm").
model_config: The model configuration (hyperparameters) used.
first_content_at: The time of the first byte of the audio output.
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
correctly formatted.
parent: The parent span or trace. If not provided, we will automatically use the current
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
span_data=SpeechSpanData(
model=model,
input=input,
output=output,
output_format=output_format,
model_config=model_config,
first_content_at=first_content_at,
),
span_id=span_id,
parent=parent,
disabled=disabled,
)
def speech_group_span(
input: str | None = None,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[SpeechGroupSpanData]:
"""Create a new speech group span. The span will not be started automatically, you should
either do `with speech_group_span() ...` or call `span.start()` + `span.finish()` manually.
Args:
input: The input text used for the speech request.
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
correctly formatted.
parent: The parent span or trace. If not provided, we will automatically use the current
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
"""
return GLOBAL_TRACE_PROVIDER.create_span(
span_data=SpeechGroupSpanData(input=input),
span_id=span_id,
parent=parent,
disabled=disabled,
)

View file

@ -5,6 +5,7 @@ import queue
import random
import threading
import time
from functools import cached_property
from typing import Any
import httpx
@ -50,9 +51,9 @@ class BackendSpanExporter(TracingExporter):
base_delay: Base delay (in seconds) for the first backoff.
max_delay: Maximum delay (in seconds) for backoff growth.
"""
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
self.organization = organization or os.environ.get("OPENAI_ORG_ID")
self.project = project or os.environ.get("OPENAI_PROJECT_ID")
self._api_key = api_key
self._organization = organization
self._project = project
self.endpoint = endpoint
self.max_retries = max_retries
self.base_delay = base_delay
@ -68,8 +69,22 @@ class BackendSpanExporter(TracingExporter):
api_key: The OpenAI API key to use. This is the same key used by the OpenAI Python
client.
"""
# We're specifically setting the underlying cached property as well
self._api_key = api_key
self.api_key = api_key
@cached_property
def api_key(self):
return self._api_key or os.environ.get("OPENAI_API_KEY")
@cached_property
def organization(self):
return self._organization or os.environ.get("OPENAI_ORG_ID")
@cached_property
def project(self):
return self._project or os.environ.get("OPENAI_PROJECT_ID")
def export(self, items: list[Trace | Span[Any]]) -> None:
if not items:
return
@ -102,18 +117,22 @@ class BackendSpanExporter(TracingExporter):
# If the response is a client error (4xx), we wont retry
if 400 <= response.status_code < 500:
logger.error(f"Tracing client error {response.status_code}: {response.text}")
logger.error(
f"[non-fatal] Tracing client error {response.status_code}: {response.text}"
)
return
# For 5xx or other unexpected codes, treat it as transient and retry
logger.warning(f"Server error {response.status_code}, retrying.")
logger.warning(
f"[non-fatal] Tracing: server error {response.status_code}, retrying."
)
except httpx.RequestError as exc:
# Network or other I/O error, we'll retry
logger.warning(f"Request failed: {exc}")
logger.warning(f"[non-fatal] Tracing: request failed: {exc}")
# If we reach here, we need to retry or give up
if attempt >= self.max_retries:
logger.error("Max retries reached, giving up on this batch.")
logger.error("[non-fatal] Tracing: max retries reached, giving up on this batch.")
return
# Exponential backoff + jitter

View file

@ -51,7 +51,7 @@ class AgentSpanData(SpanData):
class FunctionSpanData(SpanData):
__slots__ = ("name", "input", "output")
def __init__(self, name: str, input: str | None, output: str | None):
def __init__(self, name: str, input: str | None, output: Any | None):
self.name = name
self.input = input
self.output = output
@ -65,7 +65,7 @@ class FunctionSpanData(SpanData):
"type": self.type,
"name": self.name,
"input": self.input,
"output": self.output,
"output": str(self.output) if self.output else None,
}
@ -186,3 +186,99 @@ class GuardrailSpanData(SpanData):
"name": self.name,
"triggered": self.triggered,
}
class TranscriptionSpanData(SpanData):
__slots__ = (
"input",
"output",
"model",
"model_config",
)
def __init__(
self,
input: str | None = None,
input_format: str | None = "pcm",
output: str | None = None,
model: str | None = None,
model_config: Mapping[str, Any] | None = None,
):
self.input = input
self.input_format = input_format
self.output = output
self.model = model
self.model_config = model_config
@property
def type(self) -> str:
return "transcription"
def export(self) -> dict[str, Any]:
return {
"type": self.type,
"input": {
"data": self.input or "",
"format": self.input_format,
},
"output": self.output,
"model": self.model,
"model_config": self.model_config,
}
class SpeechSpanData(SpanData):
__slots__ = ("input", "output", "model", "model_config", "first_byte_at")
def __init__(
self,
input: str | None = None,
output: str | None = None,
output_format: str | None = "pcm",
model: str | None = None,
model_config: Mapping[str, Any] | None = None,
first_content_at: str | None = None,
):
self.input = input
self.output = output
self.output_format = output_format
self.model = model
self.model_config = model_config
self.first_content_at = first_content_at
@property
def type(self) -> str:
return "speech"
def export(self) -> dict[str, Any]:
return {
"type": self.type,
"input": self.input,
"output": {
"data": self.output or "",
"format": self.output_format,
},
"model": self.model,
"model_config": self.model_config,
"first_content_at": self.first_content_at,
}
class SpeechGroupSpanData(SpanData):
__slots__ = "input"
def __init__(
self,
input: str | None = None,
):
self.input = input
@property
def type(self) -> str:
return "speech-group"
def export(self) -> dict[str, Any]:
return {
"type": self.type,
"input": self.input,
}

View file

@ -15,3 +15,8 @@ def gen_trace_id() -> str:
def gen_span_id() -> str:
"""Generates a new span ID."""
return f"span_{uuid.uuid4().hex[:24]}"
def gen_group_id() -> str:
"""Generates a new group ID."""
return f"group_{uuid.uuid4().hex[:24]}"

View file

@ -0,0 +1,51 @@
from .events import VoiceStreamEvent, VoiceStreamEventAudio, VoiceStreamEventLifecycle
from .exceptions import STTWebsocketConnectionError
from .input import AudioInput, StreamedAudioInput
from .model import (
StreamedTranscriptionSession,
STTModel,
STTModelSettings,
TTSModel,
TTSModelSettings,
VoiceModelProvider,
)
from .models.openai_model_provider import OpenAIVoiceModelProvider
from .models.openai_stt import OpenAISTTModel, OpenAISTTTranscriptionSession
from .models.openai_tts import OpenAITTSModel
from .pipeline import VoicePipeline
from .pipeline_config import VoicePipelineConfig
from .result import StreamedAudioResult
from .utils import get_sentence_based_splitter
from .workflow import (
SingleAgentVoiceWorkflow,
SingleAgentWorkflowCallbacks,
VoiceWorkflowBase,
VoiceWorkflowHelper,
)
__all__ = [
"AudioInput",
"StreamedAudioInput",
"STTModel",
"STTModelSettings",
"TTSModel",
"TTSModelSettings",
"VoiceModelProvider",
"StreamedAudioResult",
"SingleAgentVoiceWorkflow",
"OpenAIVoiceModelProvider",
"OpenAISTTModel",
"OpenAITTSModel",
"VoiceStreamEventAudio",
"VoiceStreamEventLifecycle",
"VoiceStreamEvent",
"VoicePipeline",
"VoicePipelineConfig",
"get_sentence_based_splitter",
"VoiceWorkflowHelper",
"VoiceWorkflowBase",
"SingleAgentWorkflowCallbacks",
"StreamedTranscriptionSession",
"OpenAISTTTranscriptionSession",
"STTWebsocketConnectionError",
]

View file

@ -0,0 +1,47 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, Union
from typing_extensions import TypeAlias
from .imports import np, npt
@dataclass
class VoiceStreamEventAudio:
"""Streaming event from the VoicePipeline"""
data: npt.NDArray[np.int16 | np.float32] | None
"""The audio data."""
type: Literal["voice_stream_event_audio"] = "voice_stream_event_audio"
"""The type of event."""
@dataclass
class VoiceStreamEventLifecycle:
"""Streaming event from the VoicePipeline"""
event: Literal["turn_started", "turn_ended", "session_ended"]
"""The event that occurred."""
type: Literal["voice_stream_event_lifecycle"] = "voice_stream_event_lifecycle"
"""The type of event."""
@dataclass
class VoiceStreamEventError:
"""Streaming event from the VoicePipeline"""
error: Exception
"""The error that occurred."""
type: Literal["voice_stream_event_error"] = "voice_stream_event_error"
"""The type of event."""
VoiceStreamEvent: TypeAlias = Union[
VoiceStreamEventAudio, VoiceStreamEventLifecycle, VoiceStreamEventError
]
"""An event from the `VoicePipeline`, streamed via `StreamedAudioResult.stream()`."""

View file

@ -0,0 +1,8 @@
from ..exceptions import AgentsException
class STTWebsocketConnectionError(AgentsException):
"""Exception raised when the STT websocket connection fails."""
def __init__(self, message: str):
self.message = message

View file

@ -0,0 +1,11 @@
try:
import numpy as np
import numpy.typing as npt
import websockets
except ImportError as _e:
raise ImportError(
"`numpy` + `websockets` are required to use voice. You can install them via the optional "
"dependency group: `pip install 'openai-agents[voice]'`."
) from _e
__all__ = ["np", "npt", "websockets"]

88
src/agents/voice/input.py Normal file
View file

@ -0,0 +1,88 @@
from __future__ import annotations
import asyncio
import base64
import io
import wave
from dataclasses import dataclass
from ..exceptions import UserError
from .imports import np, npt
DEFAULT_SAMPLE_RATE = 24000
def _buffer_to_audio_file(
buffer: npt.NDArray[np.int16 | np.float32],
frame_rate: int = DEFAULT_SAMPLE_RATE,
sample_width: int = 2,
channels: int = 1,
) -> tuple[str, io.BytesIO, str]:
if buffer.dtype == np.float32:
# convert to int16
buffer = np.clip(buffer, -1.0, 1.0)
buffer = (buffer * 32767).astype(np.int16)
elif buffer.dtype != np.int16:
raise UserError("Buffer must be a numpy array of int16 or float32")
audio_file = io.BytesIO()
with wave.open(audio_file, "w") as wav_file:
wav_file.setnchannels(channels)
wav_file.setsampwidth(sample_width)
wav_file.setframerate(frame_rate)
wav_file.writeframes(buffer.tobytes())
audio_file.seek(0)
# (filename, bytes, content_type)
return ("audio.wav", audio_file, "audio/wav")
@dataclass
class AudioInput:
"""Static audio to be used as input for the VoicePipeline."""
buffer: npt.NDArray[np.int16 | np.float32]
"""
A buffer containing the audio data for the agent. Must be a numpy array of int16 or float32.
"""
frame_rate: int = DEFAULT_SAMPLE_RATE
"""The sample rate of the audio data. Defaults to 24000."""
sample_width: int = 2
"""The sample width of the audio data. Defaults to 2."""
channels: int = 1
"""The number of channels in the audio data. Defaults to 1."""
def to_audio_file(self) -> tuple[str, io.BytesIO, str]:
"""Returns a tuple of (filename, bytes, content_type)"""
return _buffer_to_audio_file(self.buffer, self.frame_rate, self.sample_width, self.channels)
def to_base64(self) -> str:
"""Returns the audio data as a base64 encoded string."""
if self.buffer.dtype == np.float32:
# convert to int16
self.buffer = np.clip(self.buffer, -1.0, 1.0)
self.buffer = (self.buffer * 32767).astype(np.int16)
elif self.buffer.dtype != np.int16:
raise UserError("Buffer must be a numpy array of int16 or float32")
return base64.b64encode(self.buffer.tobytes()).decode("utf-8")
class StreamedAudioInput:
"""Audio input represented as a stream of audio data. You can pass this to the `VoicePipeline`
and then push audio data into the queue using the `add_audio` method.
"""
def __init__(self):
self.queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = asyncio.Queue()
async def add_audio(self, audio: npt.NDArray[np.int16 | np.float32]):
"""Adds more audio data to the stream.
Args:
audio: The audio data to add. Must be a numpy array of int16 or float32.
"""
await self.queue.put(audio)

193
src/agents/voice/model.py Normal file
View file

@ -0,0 +1,193 @@
from __future__ import annotations
import abc
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, Callable, Literal
from .imports import np, npt
from .input import AudioInput, StreamedAudioInput
from .utils import get_sentence_based_splitter
DEFAULT_TTS_INSTRUCTIONS = (
"You will receive partial sentences. Do not complete the sentence, just read out the text."
)
DEFAULT_TTS_BUFFER_SIZE = 120
@dataclass
class TTSModelSettings:
"""Settings for a TTS model."""
voice: (
Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] | None
) = None
"""
The voice to use for the TTS model. If not provided, the default voice for the respective model
will be used.
"""
buffer_size: int = 120
"""The minimal size of the chunks of audio data that are being streamed out."""
dtype: npt.DTypeLike = np.int16
"""The data type for the audio data to be returned in."""
transform_data: (
Callable[[npt.NDArray[np.int16 | np.float32]], npt.NDArray[np.int16 | np.float32]] | None
) = None
"""
A function to transform the data from the TTS model. This is useful if you want the resulting
audio stream to have the data in a specific shape already.
"""
instructions: str = (
"You will receive partial sentences. Do not complete the sentence just read out the text."
)
"""
The instructions to use for the TTS model. This is useful if you want to control the tone of the
audio output.
"""
text_splitter: Callable[[str], tuple[str, str]] = get_sentence_based_splitter()
"""
A function to split the text into chunks. This is useful if you want to split the text into
chunks before sending it to the TTS model rather than waiting for the whole text to be
processed.
"""
speed: float | None = None
"""The speed with which the TTS model will read the text. Between 0.25 and 4.0."""
class TTSModel(abc.ABC):
"""A text-to-speech model that can convert text into audio output."""
@property
@abc.abstractmethod
def model_name(self) -> str:
"""The name of the TTS model."""
pass
@abc.abstractmethod
def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[bytes]:
"""Given a text string, produces a stream of audio bytes, in PCM format.
Args:
text: The text to convert to audio.
Returns:
An async iterator of audio bytes, in PCM format.
"""
pass
class StreamedTranscriptionSession(abc.ABC):
"""A streamed transcription of audio input."""
@abc.abstractmethod
def transcribe_turns(self) -> AsyncIterator[str]:
"""Yields a stream of text transcriptions. Each transcription is a turn in the conversation.
This method is expected to return only after `close()` is called.
"""
pass
@abc.abstractmethod
async def close(self) -> None:
"""Closes the session."""
pass
@dataclass
class STTModelSettings:
"""Settings for a speech-to-text model."""
prompt: str | None = None
"""Instructions for the model to follow."""
language: str | None = None
"""The language of the audio input."""
temperature: float | None = None
"""The temperature of the model."""
turn_detection: dict[str, Any] | None = None
"""The turn detection settings for the model when using streamed audio input."""
class STTModel(abc.ABC):
"""A speech-to-text model that can convert audio input into text."""
@property
@abc.abstractmethod
def model_name(self) -> str:
"""The name of the STT model."""
pass
@abc.abstractmethod
async def transcribe(
self,
input: AudioInput,
settings: STTModelSettings,
trace_include_sensitive_data: bool,
trace_include_sensitive_audio_data: bool,
) -> str:
"""Given an audio input, produces a text transcription.
Args:
input: The audio input to transcribe.
settings: The settings to use for the transcription.
trace_include_sensitive_data: Whether to include sensitive data in traces.
trace_include_sensitive_audio_data: Whether to include sensitive audio data in traces.
Returns:
The text transcription of the audio input.
"""
pass
@abc.abstractmethod
async def create_session(
self,
input: StreamedAudioInput,
settings: STTModelSettings,
trace_include_sensitive_data: bool,
trace_include_sensitive_audio_data: bool,
) -> StreamedTranscriptionSession:
"""Creates a new transcription session, which you can push audio to, and receive a stream
of text transcriptions.
Args:
input: The audio input to transcribe.
settings: The settings to use for the transcription.
trace_include_sensitive_data: Whether to include sensitive data in traces.
trace_include_sensitive_audio_data: Whether to include sensitive audio data in traces.
Returns:
A new transcription session.
"""
pass
class VoiceModelProvider(abc.ABC):
"""The base interface for a voice model provider.
A model provider is responsible for creating speech-to-text and text-to-speech models, given a
name.
"""
@abc.abstractmethod
def get_stt_model(self, model_name: str | None) -> STTModel:
"""Get a speech-to-text model by name.
Args:
model_name: The name of the model to get.
Returns:
The speech-to-text model.
"""
pass
@abc.abstractmethod
def get_tts_model(self, model_name: str | None) -> TTSModel:
"""Get a text-to-speech model by name."""

View file

View file

@ -0,0 +1,97 @@
from __future__ import annotations
import httpx
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from ...models import _openai_shared
from ..model import STTModel, TTSModel, VoiceModelProvider
from .openai_stt import OpenAISTTModel
from .openai_tts import OpenAITTSModel
_http_client: httpx.AsyncClient | None = None
# If we create a new httpx client for each request, that would mean no sharing of connection pools,
# which would mean worse latency and resource usage. So, we share the client across requests.
def shared_http_client() -> httpx.AsyncClient:
global _http_client
if _http_client is None:
_http_client = DefaultAsyncHttpxClient()
return _http_client
DEFAULT_STT_MODEL = "gpt-4o-transcribe"
DEFAULT_TTS_MODEL = "gpt-4o-mini-tts"
class OpenAIVoiceModelProvider(VoiceModelProvider):
"""A voice model provider that uses OpenAI models."""
def __init__(
self,
*,
api_key: str | None = None,
base_url: str | None = None,
openai_client: AsyncOpenAI | None = None,
organization: str | None = None,
project: str | None = None,
) -> None:
"""Create a new OpenAI voice model provider.
Args:
api_key: The API key to use for the OpenAI client. If not provided, we will use the
default API key.
base_url: The base URL to use for the OpenAI client. If not provided, we will use the
default base URL.
openai_client: An optional OpenAI client to use. If not provided, we will create a new
OpenAI client using the api_key and base_url.
organization: The organization to use for the OpenAI client.
project: The project to use for the OpenAI client.
"""
if openai_client is not None:
assert api_key is None and base_url is None, (
"Don't provide api_key or base_url if you provide openai_client"
)
self._client: AsyncOpenAI | None = openai_client
else:
self._client = None
self._stored_api_key = api_key
self._stored_base_url = base_url
self._stored_organization = organization
self._stored_project = project
# We lazy load the client in case you never actually use OpenAIProvider(). Otherwise
# AsyncOpenAI() raises an error if you don't have an API key set.
def _get_client(self) -> AsyncOpenAI:
if self._client is None:
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
api_key=self._stored_api_key or _openai_shared.get_default_openai_key(),
base_url=self._stored_base_url,
organization=self._stored_organization,
project=self._stored_project,
http_client=shared_http_client(),
)
return self._client
def get_stt_model(self, model_name: str | None) -> STTModel:
"""Get a speech-to-text model by name.
Args:
model_name: The name of the model to get.
Returns:
The speech-to-text model.
"""
return OpenAISTTModel(model_name or DEFAULT_STT_MODEL, self._get_client())
def get_tts_model(self, model_name: str | None) -> TTSModel:
"""Get a text-to-speech model by name.
Args:
model_name: The name of the model to get.
Returns:
The text-to-speech model.
"""
return OpenAITTSModel(model_name or DEFAULT_TTS_MODEL, self._get_client())

View file

@ -0,0 +1,457 @@
from __future__ import annotations
import asyncio
import base64
import json
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, cast
from openai import AsyncOpenAI
from agents.exceptions import AgentsException
from ... import _debug
from ...logger import logger
from ...tracing import Span, SpanError, TranscriptionSpanData, transcription_span
from ..exceptions import STTWebsocketConnectionError
from ..imports import np, npt, websockets
from ..input import AudioInput, StreamedAudioInput
from ..model import StreamedTranscriptionSession, STTModel, STTModelSettings
EVENT_INACTIVITY_TIMEOUT = 1000 # Timeout for inactivity in event processing
SESSION_CREATION_TIMEOUT = 10 # Timeout waiting for session.created event
SESSION_UPDATE_TIMEOUT = 10 # Timeout waiting for session.updated event
DEFAULT_TURN_DETECTION = {"type": "semantic_vad"}
@dataclass
class ErrorSentinel:
error: Exception
class SessionCompleteSentinel:
pass
class WebsocketDoneSentinel:
pass
def _audio_to_base64(audio_data: list[npt.NDArray[np.int16 | np.float32]]) -> str:
concatenated_audio = np.concatenate(audio_data)
if concatenated_audio.dtype == np.float32:
# convert to int16
concatenated_audio = np.clip(concatenated_audio, -1.0, 1.0)
concatenated_audio = (concatenated_audio * 32767).astype(np.int16)
audio_bytes = concatenated_audio.tobytes()
return base64.b64encode(audio_bytes).decode("utf-8")
async def _wait_for_event(
event_queue: asyncio.Queue[dict[str, Any]], expected_types: list[str], timeout: float
):
"""
Wait for an event from event_queue whose type is in expected_types within the specified timeout.
"""
start_time = time.time()
while True:
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
raise TimeoutError(f"Timeout waiting for event(s): {expected_types}")
evt = await asyncio.wait_for(event_queue.get(), timeout=remaining)
evt_type = evt.get("type", "")
if evt_type in expected_types:
return evt
elif evt_type == "error":
raise Exception(f"Error event: {evt.get('error')}")
class OpenAISTTTranscriptionSession(StreamedTranscriptionSession):
"""A transcription session for OpenAI's STT model."""
def __init__(
self,
input: StreamedAudioInput,
client: AsyncOpenAI,
model: str,
settings: STTModelSettings,
trace_include_sensitive_data: bool,
trace_include_sensitive_audio_data: bool,
):
self.connected: bool = False
self._client = client
self._model = model
self._settings = settings
self._turn_detection = settings.turn_detection or DEFAULT_TURN_DETECTION
self._trace_include_sensitive_data = trace_include_sensitive_data
self._trace_include_sensitive_audio_data = trace_include_sensitive_audio_data
self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = input.queue
self._output_queue: asyncio.Queue[str | ErrorSentinel | SessionCompleteSentinel] = (
asyncio.Queue()
)
self._websocket: websockets.ClientConnection | None = None
self._event_queue: asyncio.Queue[dict[str, Any] | WebsocketDoneSentinel] = asyncio.Queue()
self._state_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._turn_audio_buffer: list[npt.NDArray[np.int16 | np.float32]] = []
self._tracing_span: Span[TranscriptionSpanData] | None = None
# tasks
self._listener_task: asyncio.Task[Any] | None = None
self._process_events_task: asyncio.Task[Any] | None = None
self._stream_audio_task: asyncio.Task[Any] | None = None
self._connection_task: asyncio.Task[Any] | None = None
self._stored_exception: Exception | None = None
def _start_turn(self) -> None:
self._tracing_span = transcription_span(
model=self._model,
model_config={
"temperature": self._settings.temperature,
"language": self._settings.language,
"prompt": self._settings.prompt,
"turn_detection": self._turn_detection,
},
)
self._tracing_span.start()
def _end_turn(self, _transcript: str) -> None:
if len(_transcript) < 1:
return
if self._tracing_span:
if self._trace_include_sensitive_audio_data:
self._tracing_span.span_data.input = _audio_to_base64(self._turn_audio_buffer)
self._tracing_span.span_data.input_format = "pcm"
if self._trace_include_sensitive_data:
self._tracing_span.span_data.output = _transcript
self._tracing_span.finish()
self._turn_audio_buffer = []
self._tracing_span = None
async def _event_listener(self) -> None:
assert self._websocket is not None, "Websocket not initialized"
async for message in self._websocket:
try:
event = json.loads(message)
if event.get("type") == "error":
raise STTWebsocketConnectionError(f"Error event: {event.get('error')}")
if event.get("type") in [
"session.updated",
"transcription_session.updated",
"session.created",
"transcription_session.created",
]:
await self._state_queue.put(event)
await self._event_queue.put(event)
except Exception as e:
await self._output_queue.put(ErrorSentinel(e))
raise STTWebsocketConnectionError("Error parsing events") from e
await self._event_queue.put(WebsocketDoneSentinel())
async def _configure_session(self) -> None:
assert self._websocket is not None, "Websocket not initialized"
await self._websocket.send(
json.dumps(
{
"type": "transcription_session.update",
"session": {
"input_audio_format": "pcm16",
"input_audio_transcription": {"model": self._model},
"turn_detection": self._turn_detection,
},
}
)
)
async def _setup_connection(self, ws: websockets.ClientConnection) -> None:
self._websocket = ws
self._listener_task = asyncio.create_task(self._event_listener())
try:
event = await _wait_for_event(
self._state_queue,
["session.created", "transcription_session.created"],
SESSION_CREATION_TIMEOUT,
)
except TimeoutError as e:
wrapped_err = STTWebsocketConnectionError(
"Timeout waiting for transcription_session.created event"
)
await self._output_queue.put(ErrorSentinel(wrapped_err))
raise wrapped_err from e
except Exception as e:
await self._output_queue.put(ErrorSentinel(e))
raise e
await self._configure_session()
try:
event = await _wait_for_event(
self._state_queue,
["session.updated", "transcription_session.updated"],
SESSION_UPDATE_TIMEOUT,
)
if _debug.DONT_LOG_MODEL_DATA:
logger.debug("Session updated")
else:
logger.debug(f"Session updated: {event}")
except TimeoutError as e:
wrapped_err = STTWebsocketConnectionError(
"Timeout waiting for transcription_session.updated event"
)
await self._output_queue.put(ErrorSentinel(wrapped_err))
raise wrapped_err from e
except Exception as e:
await self._output_queue.put(ErrorSentinel(e))
raise
async def _handle_events(self) -> None:
while True:
try:
event = await asyncio.wait_for(
self._event_queue.get(), timeout=EVENT_INACTIVITY_TIMEOUT
)
if isinstance(event, WebsocketDoneSentinel):
# processed all events and websocket is done
break
event_type = event.get("type", "unknown")
if event_type == "conversation.item.input_audio_transcription.completed":
transcript = cast(str, event.get("transcript", ""))
if len(transcript) > 0:
self._end_turn(transcript)
self._start_turn()
await self._output_queue.put(transcript)
await asyncio.sleep(0) # yield control
except asyncio.TimeoutError:
# No new events for a while. Assume the session is done.
break
except Exception as e:
await self._output_queue.put(ErrorSentinel(e))
raise e
await self._output_queue.put(SessionCompleteSentinel())
async def _stream_audio(
self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]]
) -> None:
assert self._websocket is not None, "Websocket not initialized"
self._start_turn()
while True:
buffer = await audio_queue.get()
if buffer is None:
break
self._turn_audio_buffer.append(buffer)
try:
await self._websocket.send(
json.dumps(
{
"type": "input_audio_buffer.append",
"audio": base64.b64encode(buffer.tobytes()).decode("utf-8"),
}
)
)
except websockets.ConnectionClosed:
break
except Exception as e:
await self._output_queue.put(ErrorSentinel(e))
raise e
await asyncio.sleep(0) # yield control
async def _process_websocket_connection(self) -> None:
try:
async with websockets.connect(
"wss://api.openai.com/v1/realtime?intent=transcription",
additional_headers={
"Authorization": f"Bearer {self._client.api_key}",
"OpenAI-Beta": "realtime=v1",
"OpenAI-Log-Session": "1",
},
) as ws:
await self._setup_connection(ws)
self._process_events_task = asyncio.create_task(self._handle_events())
self._stream_audio_task = asyncio.create_task(self._stream_audio(self._input_queue))
self.connected = True
if self._listener_task:
await self._listener_task
else:
logger.error("Listener task not initialized")
raise AgentsException("Listener task not initialized")
except Exception as e:
await self._output_queue.put(ErrorSentinel(e))
raise e
def _check_errors(self) -> None:
if self._connection_task and self._connection_task.done():
exc = self._connection_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
if self._process_events_task and self._process_events_task.done():
exc = self._process_events_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
if self._stream_audio_task and self._stream_audio_task.done():
exc = self._stream_audio_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
if self._listener_task and self._listener_task.done():
exc = self._listener_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
def _cleanup_tasks(self) -> None:
if self._listener_task and not self._listener_task.done():
self._listener_task.cancel()
if self._process_events_task and not self._process_events_task.done():
self._process_events_task.cancel()
if self._stream_audio_task and not self._stream_audio_task.done():
self._stream_audio_task.cancel()
if self._connection_task and not self._connection_task.done():
self._connection_task.cancel()
async def transcribe_turns(self) -> AsyncIterator[str]:
self._connection_task = asyncio.create_task(self._process_websocket_connection())
while True:
try:
turn = await self._output_queue.get()
except asyncio.CancelledError:
break
if (
turn is None
or isinstance(turn, ErrorSentinel)
or isinstance(turn, SessionCompleteSentinel)
):
self._output_queue.task_done()
break
yield turn
self._output_queue.task_done()
if self._tracing_span:
self._end_turn("")
if self._websocket:
await self._websocket.close()
self._check_errors()
if self._stored_exception:
raise self._stored_exception
async def close(self) -> None:
if self._websocket:
await self._websocket.close()
self._cleanup_tasks()
class OpenAISTTModel(STTModel):
"""A speech-to-text model for OpenAI."""
def __init__(
self,
model: str,
openai_client: AsyncOpenAI,
):
"""Create a new OpenAI speech-to-text model.
Args:
model: The name of the model to use.
openai_client: The OpenAI client to use.
"""
self.model = model
self._client = openai_client
@property
def model_name(self) -> str:
return self.model
def _non_null_or_not_given(self, value: Any) -> Any:
return value if value is not None else None # NOT_GIVEN
async def transcribe(
self,
input: AudioInput,
settings: STTModelSettings,
trace_include_sensitive_data: bool,
trace_include_sensitive_audio_data: bool,
) -> str:
"""Transcribe an audio input.
Args:
input: The audio input to transcribe.
settings: The settings to use for the transcription.
Returns:
The transcribed text.
"""
with transcription_span(
model=self.model,
input=input.to_base64() if trace_include_sensitive_audio_data else "",
input_format="pcm",
model_config={
"temperature": self._non_null_or_not_given(settings.temperature),
"language": self._non_null_or_not_given(settings.language),
"prompt": self._non_null_or_not_given(settings.prompt),
},
) as span:
try:
response = await self._client.audio.transcriptions.create(
model=self.model,
file=input.to_audio_file(),
prompt=self._non_null_or_not_given(settings.prompt),
language=self._non_null_or_not_given(settings.language),
temperature=self._non_null_or_not_given(settings.temperature),
)
if trace_include_sensitive_data:
span.span_data.output = response.text
return response.text
except Exception as e:
span.span_data.output = ""
span.set_error(SpanError(message=str(e), data={}))
raise e
async def create_session(
self,
input: StreamedAudioInput,
settings: STTModelSettings,
trace_include_sensitive_data: bool,
trace_include_sensitive_audio_data: bool,
) -> StreamedTranscriptionSession:
"""Create a new transcription session.
Args:
input: The audio input to transcribe.
settings: The settings to use for the transcription.
trace_include_sensitive_data: Whether to include sensitive data in traces.
trace_include_sensitive_audio_data: Whether to include sensitive audio data in traces.
Returns:
A new transcription session.
"""
return OpenAISTTTranscriptionSession(
input,
self._client,
self.model,
settings,
trace_include_sensitive_data,
trace_include_sensitive_audio_data,
)

View file

@ -0,0 +1,54 @@
from collections.abc import AsyncIterator
from typing import Literal
from openai import AsyncOpenAI
from ..model import TTSModel, TTSModelSettings
DEFAULT_VOICE: Literal["ash"] = "ash"
class OpenAITTSModel(TTSModel):
"""A text-to-speech model for OpenAI."""
def __init__(
self,
model: str,
openai_client: AsyncOpenAI,
):
"""Create a new OpenAI text-to-speech model.
Args:
model: The name of the model to use.
openai_client: The OpenAI client to use.
"""
self.model = model
self._client = openai_client
@property
def model_name(self) -> str:
return self.model
async def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[bytes]:
"""Run the text-to-speech model.
Args:
text: The text to convert to speech.
settings: The settings to use for the text-to-speech model.
Returns:
An iterator of audio chunks.
"""
response = self._client.audio.speech.with_streaming_response.create(
model=self.model,
voice=settings.voice or DEFAULT_VOICE,
input=text,
response_format="pcm",
extra_body={
"instructions": settings.instructions,
},
)
async with response as stream:
async for chunk in stream.iter_bytes(chunk_size=1024):
yield chunk

View file

@ -0,0 +1,151 @@
from __future__ import annotations
import asyncio
from .._run_impl import TraceCtxManager
from ..exceptions import UserError
from ..logger import logger
from .input import AudioInput, StreamedAudioInput
from .model import STTModel, TTSModel
from .pipeline_config import VoicePipelineConfig
from .result import StreamedAudioResult
from .workflow import VoiceWorkflowBase
class VoicePipeline:
"""An opinionated voice agent pipeline. It works in three steps:
1. Transcribe audio input into text.
2. Run the provided `workflow`, which produces a sequence of text responses.
3. Convert the text responses into streaming audio output.
"""
def __init__(
self,
*,
workflow: VoiceWorkflowBase,
stt_model: STTModel | str | None = None,
tts_model: TTSModel | str | None = None,
config: VoicePipelineConfig | None = None,
):
"""Create a new voice pipeline.
Args:
workflow: The workflow to run. See `VoiceWorkflowBase`.
stt_model: The speech-to-text model to use. If not provided, a default OpenAI
model will be used.
tts_model: The text-to-speech model to use. If not provided, a default OpenAI
model will be used.
config: The pipeline configuration. If not provided, a default configuration will be
used.
"""
self.workflow = workflow
self.stt_model = stt_model if isinstance(stt_model, STTModel) else None
self.tts_model = tts_model if isinstance(tts_model, TTSModel) else None
self._stt_model_name = stt_model if isinstance(stt_model, str) else None
self._tts_model_name = tts_model if isinstance(tts_model, str) else None
self.config = config or VoicePipelineConfig()
async def run(self, audio_input: AudioInput | StreamedAudioInput) -> StreamedAudioResult:
"""Run the voice pipeline.
Args:
audio_input: The audio input to process. This can either be an `AudioInput` instance,
which is a single static buffer, or a `StreamedAudioInput` instance, which is a
stream of audio data that you can append to.
Returns:
A `StreamedAudioResult` instance. You can use this object to stream audio events and
play them out.
"""
if isinstance(audio_input, AudioInput):
return await self._run_single_turn(audio_input)
elif isinstance(audio_input, StreamedAudioInput):
return await self._run_multi_turn(audio_input)
else:
raise UserError(f"Unsupported audio input type: {type(audio_input)}")
def _get_tts_model(self) -> TTSModel:
if not self.tts_model:
self.tts_model = self.config.model_provider.get_tts_model(self._tts_model_name)
return self.tts_model
def _get_stt_model(self) -> STTModel:
if not self.stt_model:
self.stt_model = self.config.model_provider.get_stt_model(self._stt_model_name)
return self.stt_model
async def _process_audio_input(self, audio_input: AudioInput) -> str:
model = self._get_stt_model()
return await model.transcribe(
audio_input,
self.config.stt_settings,
self.config.trace_include_sensitive_data,
self.config.trace_include_sensitive_audio_data,
)
async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
# Since this is single turn, we can use the TraceCtxManager to manage starting/ending the
# trace
with TraceCtxManager(
workflow_name=self.config.workflow_name or "Voice Agent",
trace_id=None, # Automatically generated
group_id=self.config.group_id,
metadata=self.config.trace_metadata,
disabled=self.config.tracing_disabled,
):
input_text = await self._process_audio_input(audio_input)
output = StreamedAudioResult(
self._get_tts_model(), self.config.tts_settings, self.config
)
async def stream_events():
try:
async for text_event in self.workflow.run(input_text):
await output._add_text(text_event)
await output._turn_done()
await output._done()
except Exception as e:
logger.error(f"Error processing single turn: {e}")
await output._add_error(e)
raise e
output._set_task(asyncio.create_task(stream_events()))
return output
async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
with TraceCtxManager(
workflow_name=self.config.workflow_name or "Voice Agent",
trace_id=None,
group_id=self.config.group_id,
metadata=self.config.trace_metadata,
disabled=self.config.tracing_disabled,
):
output = StreamedAudioResult(
self._get_tts_model(), self.config.tts_settings, self.config
)
transcription_session = await self._get_stt_model().create_session(
audio_input,
self.config.stt_settings,
self.config.trace_include_sensitive_data,
self.config.trace_include_sensitive_audio_data,
)
async def process_turns():
try:
async for input_text in transcription_session.transcribe_turns():
result = self.workflow.run(input_text)
async for text_event in result:
await output._add_text(text_event)
await output._turn_done()
except Exception as e:
logger.error(f"Error processing turns: {e}")
await output._add_error(e)
raise e
finally:
await transcription_session.close()
await output._done()
output._set_task(asyncio.create_task(process_turns()))
return output

View file

@ -0,0 +1,46 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from ..tracing.util import gen_group_id
from .model import STTModelSettings, TTSModelSettings, VoiceModelProvider
from .models.openai_model_provider import OpenAIVoiceModelProvider
@dataclass
class VoicePipelineConfig:
"""Configuration for a `VoicePipeline`."""
model_provider: VoiceModelProvider = field(default_factory=OpenAIVoiceModelProvider)
"""The voice model provider to use for the pipeline. Defaults to OpenAI."""
tracing_disabled: bool = False
"""Whether to disable tracing of the pipeline. Defaults to `False`."""
trace_include_sensitive_data: bool = True
"""Whether to include sensitive data in traces. Defaults to `True`. This is specifically for the
voice pipeline, and not for anything that goes on inside your Workflow."""
trace_include_sensitive_audio_data: bool = True
"""Whether to include audio data in traces. Defaults to `True`."""
workflow_name: str = "Voice Agent"
"""The name of the workflow to use for tracing. Defaults to `Voice Agent`."""
group_id: str = field(default_factory=gen_group_id)
"""
A grouping identifier to use for tracing, to link multiple traces from the same conversation
or process. If not provided, we will create a random group ID.
"""
trace_metadata: dict[str, Any] | None = None
"""
An optional dictionary of additional metadata to include with the trace.
"""
stt_settings: STTModelSettings = field(default_factory=STTModelSettings)
"""The settings to use for the STT model."""
tts_settings: TTSModelSettings = field(default_factory=TTSModelSettings)
"""The settings to use for the TTS model."""

287
src/agents/voice/result.py Normal file
View file

@ -0,0 +1,287 @@
from __future__ import annotations
import asyncio
import base64
from collections.abc import AsyncIterator
from typing import Any
from ..exceptions import UserError
from ..logger import logger
from ..tracing import Span, SpeechGroupSpanData, speech_group_span, speech_span
from ..tracing.util import time_iso
from .events import (
VoiceStreamEvent,
VoiceStreamEventAudio,
VoiceStreamEventError,
VoiceStreamEventLifecycle,
)
from .imports import np, npt
from .model import TTSModel, TTSModelSettings
from .pipeline_config import VoicePipelineConfig
def _audio_to_base64(audio_data: list[bytes]) -> str:
joined_audio_data = b"".join(audio_data)
return base64.b64encode(joined_audio_data).decode("utf-8")
class StreamedAudioResult:
"""The output of a `VoicePipeline`. Streams events and audio data as they're generated."""
def __init__(
self,
tts_model: TTSModel,
tts_settings: TTSModelSettings,
voice_pipeline_config: VoicePipelineConfig,
):
"""Create a new `StreamedAudioResult` instance.
Args:
tts_model: The TTS model to use.
tts_settings: The TTS settings to use.
voice_pipeline_config: The voice pipeline config to use.
"""
self.tts_model = tts_model
self.tts_settings = tts_settings
self.total_output_text = ""
self.instructions = tts_settings.instructions
self.text_generation_task: asyncio.Task[Any] | None = None
self._voice_pipeline_config = voice_pipeline_config
self._text_buffer = ""
self._turn_text_buffer = ""
self._queue: asyncio.Queue[VoiceStreamEvent] = asyncio.Queue()
self._tasks: list[asyncio.Task[Any]] = []
self._ordered_tasks: list[
asyncio.Queue[VoiceStreamEvent | None]
] = [] # New: list to hold local queues for each text segment
self._dispatcher_task: asyncio.Task[Any] | None = (
None # Task to dispatch audio chunks in order
)
self._done_processing = False
self._buffer_size = tts_settings.buffer_size
self._started_processing_turn = False
self._first_byte_received = False
self._generation_start_time: str | None = None
self._completed_session = False
self._stored_exception: BaseException | None = None
self._tracing_span: Span[SpeechGroupSpanData] | None = None
async def _start_turn(self):
if self._started_processing_turn:
return
self._tracing_span = speech_group_span()
self._tracing_span.start()
self._started_processing_turn = True
self._first_byte_received = False
self._generation_start_time = time_iso()
await self._queue.put(VoiceStreamEventLifecycle(event="turn_started"))
def _set_task(self, task: asyncio.Task[Any]):
self.text_generation_task = task
async def _add_error(self, error: Exception):
await self._queue.put(VoiceStreamEventError(error))
def _transform_audio_buffer(
self, buffer: list[bytes], output_dtype: npt.DTypeLike
) -> npt.NDArray[np.int16 | np.float32]:
np_array = np.frombuffer(b"".join(buffer), dtype=np.int16)
if output_dtype == np.int16:
return np_array
elif output_dtype == np.float32:
return (np_array.astype(np.float32) / 32767.0).reshape(-1, 1)
else:
raise UserError("Invalid output dtype")
async def _stream_audio(
self,
text: str,
local_queue: asyncio.Queue[VoiceStreamEvent | None],
finish_turn: bool = False,
):
with speech_span(
model=self.tts_model.model_name,
input=text if self._voice_pipeline_config.trace_include_sensitive_data else "",
model_config={
"voice": self.tts_settings.voice,
"instructions": self.instructions,
"speed": self.tts_settings.speed,
},
output_format="pcm",
parent=self._tracing_span,
) as tts_span:
try:
first_byte_received = False
buffer: list[bytes] = []
full_audio_data: list[bytes] = []
async for chunk in self.tts_model.run(text, self.tts_settings):
if not first_byte_received:
first_byte_received = True
tts_span.span_data.first_content_at = time_iso()
if chunk:
buffer.append(chunk)
full_audio_data.append(chunk)
if len(buffer) >= self._buffer_size:
audio_np = self._transform_audio_buffer(buffer, self.tts_settings.dtype)
if self.tts_settings.transform_data:
audio_np = self.tts_settings.transform_data(audio_np)
await local_queue.put(
VoiceStreamEventAudio(data=audio_np)
) # Use local queue
buffer = []
if buffer:
audio_np = self._transform_audio_buffer(buffer, self.tts_settings.dtype)
if self.tts_settings.transform_data:
audio_np = self.tts_settings.transform_data(audio_np)
await local_queue.put(VoiceStreamEventAudio(data=audio_np)) # Use local queue
if self._voice_pipeline_config.trace_include_sensitive_audio_data:
tts_span.span_data.output = _audio_to_base64(full_audio_data)
else:
tts_span.span_data.output = ""
if finish_turn:
await local_queue.put(VoiceStreamEventLifecycle(event="turn_ended"))
else:
await local_queue.put(None) # Signal completion for this segment
except Exception as e:
tts_span.set_error(
{
"message": str(e),
"data": {
"text": text
if self._voice_pipeline_config.trace_include_sensitive_data
else "",
},
}
)
logger.error(f"Error streaming audio: {e}")
# Signal completion for whole session because of error
await local_queue.put(VoiceStreamEventLifecycle(event="session_ended"))
raise e
async def _add_text(self, text: str):
await self._start_turn()
self._text_buffer += text
self.total_output_text += text
self._turn_text_buffer += text
combined_sentences, self._text_buffer = self.tts_settings.text_splitter(self._text_buffer)
if len(combined_sentences) >= 20:
local_queue: asyncio.Queue[VoiceStreamEvent | None] = asyncio.Queue()
self._ordered_tasks.append(local_queue)
self._tasks.append(
asyncio.create_task(self._stream_audio(combined_sentences, local_queue))
)
if self._dispatcher_task is None:
self._dispatcher_task = asyncio.create_task(self._dispatch_audio())
async def _turn_done(self):
if self._text_buffer:
local_queue: asyncio.Queue[VoiceStreamEvent | None] = asyncio.Queue()
self._ordered_tasks.append(local_queue) # Append the local queue for the final segment
self._tasks.append(
asyncio.create_task(
self._stream_audio(self._text_buffer, local_queue, finish_turn=True)
)
)
self._text_buffer = ""
self._done_processing = True
if self._dispatcher_task is None:
self._dispatcher_task = asyncio.create_task(self._dispatch_audio())
await asyncio.gather(*self._tasks)
def _finish_turn(self):
if self._tracing_span:
if self._voice_pipeline_config.trace_include_sensitive_data:
self._tracing_span.span_data.input = self._turn_text_buffer
else:
self._tracing_span.span_data.input = ""
self._tracing_span.finish()
self._tracing_span = None
self._turn_text_buffer = ""
self._started_processing_turn = False
async def _done(self):
self._completed_session = True
await self._wait_for_completion()
async def _dispatch_audio(self):
# Dispatch audio chunks from each segment in the order they were added
while True:
if len(self._ordered_tasks) == 0:
if self._completed_session:
break
await asyncio.sleep(0)
continue
local_queue = self._ordered_tasks.pop(0)
while True:
chunk = await local_queue.get()
if chunk is None:
break
await self._queue.put(chunk)
if isinstance(chunk, VoiceStreamEventLifecycle):
local_queue.task_done()
if chunk.event == "turn_ended":
self._finish_turn()
break
await self._queue.put(VoiceStreamEventLifecycle(event="session_ended"))
async def _wait_for_completion(self):
tasks: list[asyncio.Task[Any]] = self._tasks
if self._dispatcher_task is not None:
tasks.append(self._dispatcher_task)
await asyncio.gather(*tasks)
def _cleanup_tasks(self):
self._finish_turn()
for task in self._tasks:
if not task.done():
task.cancel()
if self._dispatcher_task and not self._dispatcher_task.done():
self._dispatcher_task.cancel()
if self.text_generation_task and not self.text_generation_task.done():
self.text_generation_task.cancel()
def _check_errors(self):
for task in self._tasks:
if task.done():
if task.exception():
self._stored_exception = task.exception()
break
async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
"""Stream the events and audio data as they're generated."""
while True:
try:
event = await self._queue.get()
except asyncio.CancelledError:
break
if isinstance(event, VoiceStreamEventError):
self._stored_exception = event.error
logger.error(f"Error processing output: {event.error}")
break
if event is None:
break
yield event
if event.type == "voice_stream_event_lifecycle" and event.event == "session_ended":
break
self._check_errors()
self._cleanup_tasks()
if self._stored_exception:
raise self._stored_exception

37
src/agents/voice/utils.py Normal file
View file

@ -0,0 +1,37 @@
import re
from typing import Callable
def get_sentence_based_splitter(
min_sentence_length: int = 20,
) -> Callable[[str], tuple[str, str]]:
"""Returns a function that splits text into chunks based on sentence boundaries.
Args:
min_sentence_length: The minimum length of a sentence to be included in a chunk.
Returns:
A function that splits text into chunks based on sentence boundaries.
"""
def sentence_based_text_splitter(text_buffer: str) -> tuple[str, str]:
"""
A function to split the text into chunks. This is useful if you want to split the text into
chunks before sending it to the TTS model rather than waiting for the whole text to be
processed.
Args:
text_buffer: The text to split.
Returns:
A tuple of the text to process and the remaining text buffer.
"""
sentences = re.split(r"(?<=[.!?])\s+", text_buffer.strip())
if len(sentences) >= 1:
combined_sentences = " ".join(sentences[:-1])
if len(combined_sentences) >= min_sentence_length:
remaining_text_buffer = sentences[-1]
return combined_sentences, remaining_text_buffer
return "", text_buffer
return sentence_based_text_splitter

View file

@ -0,0 +1,93 @@
from __future__ import annotations
import abc
from collections.abc import AsyncIterator
from typing import Any
from ..agent import Agent
from ..items import TResponseInputItem
from ..result import RunResultStreaming
from ..run import Runner
class VoiceWorkflowBase(abc.ABC):
"""
A base class for a voice workflow. You must implement the `run` method. A "workflow" is any
code you want, that receives a transcription and yields text that will be turned into speech
by a text-to-speech model.
In most cases, you'll create `Agent`s and use `Runner.run_streamed()` to run them, returning
some or all of the text events from the stream. You can use the `VoiceWorkflowHelper` class to
help with extracting text events from the stream.
If you have a simple workflow that has a single starting agent and no custom logic, you can
use `SingleAgentVoiceWorkflow` directly.
"""
@abc.abstractmethod
def run(self, transcription: str) -> AsyncIterator[str]:
"""
Run the voice workflow. You will receive an input transcription, and must yield text that
will be spoken to the user. You can run whatever logic you want here. In most cases, the
final logic will involve calling `Runner.run_streamed()` and yielding any text events from
the stream.
"""
pass
class VoiceWorkflowHelper:
@classmethod
async def stream_text_from(cls, result: RunResultStreaming) -> AsyncIterator[str]:
"""Wraps a `RunResultStreaming` object and yields text events from the stream."""
async for event in result.stream_events():
if (
event.type == "raw_response_event"
and event.data.type == "response.output_text.delta"
):
yield event.data.delta
class SingleAgentWorkflowCallbacks:
def on_run(self, workflow: SingleAgentVoiceWorkflow, transcription: str) -> None:
"""Called when the workflow is run."""
pass
class SingleAgentVoiceWorkflow(VoiceWorkflowBase):
"""A simple voice workflow that runs a single agent. Each transcription and result is added to
the input history.
For more complex workflows (e.g. multiple Runner calls, custom message history, custom logic,
custom configs), subclass `VoiceWorkflowBase` and implement your own logic.
"""
def __init__(self, agent: Agent[Any], callbacks: SingleAgentWorkflowCallbacks | None = None):
"""Create a new single agent voice workflow.
Args:
agent: The agent to run.
callbacks: Optional callbacks to call during the workflow.
"""
self._input_history: list[TResponseInputItem] = []
self._current_agent = agent
self._callbacks = callbacks
async def run(self, transcription: str) -> AsyncIterator[str]:
if self._callbacks:
self._callbacks.on_run(self, transcription)
# Add the transcription to the input history
self._input_history.append(
{
"role": "user",
"content": transcription,
}
)
# Run the agent
result = Runner.run_streamed(self._current_agent, self._input_history)
# Stream the text from the result
async for chunk in VoiceWorkflowHelper.stream_text_from(result):
yield chunk
# Update the input history and current agent
self._input_history = result.to_input_list()
self._current_agent = result.last_agent

View file

@ -21,6 +21,8 @@ from agents import (
UserError,
handoff,
)
from agents.agent import ToolsToFinalOutputResult
from agents.tool import FunctionToolResult, function_tool
from .fake_model import FakeModel
from .test_responses import (
@ -552,3 +554,83 @@ async def test_output_guardrail_tripwire_triggered_causes_exception():
with pytest.raises(OutputGuardrailTripwireTriggered):
await Runner.run(agent, input="user_message")
@function_tool
def test_tool_one():
return Foo(bar="tool_one_result")
@function_tool
def test_tool_two():
return "tool_two_result"
@pytest.mark.asyncio
async def test_tool_use_behavior_first_output():
model = FakeModel()
agent = Agent(
name="test",
model=model,
tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
tool_use_behavior="stop_on_first_tool",
output_type=Foo,
)
model.add_multiple_turn_outputs(
[
# First turn: a message and tool call
[
get_text_message("a_message"),
get_function_tool_call("test_tool_one", None),
get_function_tool_call("test_tool_two", None),
],
]
)
result = await Runner.run(agent, input="user_message")
assert result.final_output == Foo(bar="tool_one_result"), (
"should have used the first tool result"
)
def custom_tool_use_behavior(
context: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
if "test_tool_one" in [result.tool.name for result in results]:
return ToolsToFinalOutputResult(is_final_output=True, final_output="the_final_output")
else:
return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
@pytest.mark.asyncio
async def test_tool_use_behavior_custom_function():
model = FakeModel()
agent = Agent(
name="test",
model=model,
tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
tool_use_behavior=custom_tool_use_behavior,
)
model.add_multiple_turn_outputs(
[
# First turn: a message and tool call
[
get_text_message("a_message"),
get_function_tool_call("test_tool_two", None),
],
# Second turn: a message and tool call
[
get_text_message("a_message"),
get_function_tool_call("test_tool_one", None),
get_function_tool_call("test_tool_two", None),
],
]
)
result = await Runner.run(agent, input="user_message")
assert len(result.raw_responses) == 2, "should have two model responses"
assert result.final_output == "the_final_output", "should have used the custom function"

View file

@ -674,7 +674,7 @@ async def test_streaming_events():
total_expected_item_count = sum(expected_item_type_map.values())
assert event_counts["run_item_stream_event"] == total_expected_item_count, (
f"Expectd {total_expected_item_count} events, got {event_counts['run_item_stream_event']}"
f"Expected {total_expected_item_count} events, got {event_counts['run_item_stream_event']}"
f"Expected events were: {expected_item_type_map}, got {event_counts}"
)

View file

@ -9,7 +9,7 @@ from agents import Agent, RunConfig, Runner, trace
from .fake_model import FakeModel
from .test_responses import get_text_message
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
from .testing_processor import assert_no_traces, fetch_normalized_spans
@pytest.mark.asyncio
@ -23,9 +23,6 @@ async def test_single_run_is_single_trace():
await Runner.run(agent, input="first_test")
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -45,12 +42,6 @@ async def test_single_run_is_single_trace():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 1, (
f"Got {len(spans)}, but expected 1: the agent span. data:"
f"{[span.span_data for span in spans]}"
)
@pytest.mark.asyncio
async def test_multiple_runs_are_multiple_traces():
@ -69,9 +60,6 @@ async def test_multiple_runs_are_multiple_traces():
await Runner.run(agent, input="first_test")
await Runner.run(agent, input="second_test")
traces = fetch_traces()
assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -105,9 +93,6 @@ async def test_multiple_runs_are_multiple_traces():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, f"Got {len(spans)}, but expected 2: agent span per run"
@pytest.mark.asyncio
async def test_wrapped_trace_is_single_trace():
@ -129,9 +114,6 @@ async def test_wrapped_trace_is_single_trace():
await Runner.run(agent, input="second_test")
await Runner.run(agent, input="third_test")
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -169,9 +151,6 @@ async def test_wrapped_trace_is_single_trace():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 3, f"Got {len(spans)}, but expected 3: the agent span per run"
@pytest.mark.asyncio
async def test_parent_disabled_trace_disabled_agent_trace():
@ -185,14 +164,7 @@ async def test_parent_disabled_trace_disabled_agent_trace():
await Runner.run(agent, input="first_test")
traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])
spans = fetch_ordered_spans()
assert len(spans) == 0, (
f"Expected no spans, got {len(spans)}, with {[x.span_data for x in spans]}"
)
assert_no_traces()
@pytest.mark.asyncio
@ -206,12 +178,7 @@ async def test_manual_disabling_works():
await Runner.run(agent, input="first_test", run_config=RunConfig(tracing_disabled=True))
traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])
spans = fetch_ordered_spans()
assert len(spans) == 0, f"Got {len(spans)}, but expected no spans"
assert_no_traces()
@pytest.mark.asyncio
@ -226,16 +193,29 @@ async def test_trace_config_works():
await Runner.run(
agent,
input="first_test",
run_config=RunConfig(workflow_name="Foo bar", group_id="123", trace_id="456"),
run_config=RunConfig(workflow_name="Foo bar", group_id="123", trace_id="trace_456"),
)
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
export = traces[0].export()
assert export is not None, "Trace export should not be None"
assert export["workflow_name"] == "Foo bar"
assert export["group_id"] == "123"
assert export["id"] == "456"
assert fetch_normalized_spans(keep_trace_id=True) == snapshot(
[
{
"id": "trace_456",
"workflow_name": "Foo bar",
"group_id": "123",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)
@pytest.mark.asyncio
@ -255,9 +235,6 @@ async def test_not_starting_streaming_creates_trace():
break
await asyncio.sleep(0.1)
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -277,9 +254,6 @@ async def test_not_starting_streaming_creates_trace():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 1, f"Got {len(spans)}, but expected 1: the agent span"
# Await the stream to avoid warnings about it not being awaited
async for _ in result.stream_events():
pass
@ -298,8 +272,24 @@ async def test_streaming_single_run_is_single_trace():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)
@pytest.mark.asyncio
@ -324,8 +314,38 @@ async def test_multiple_streamed_runs_are_multiple_traces():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
]
)
@pytest.mark.asyncio
@ -356,8 +376,42 @@ async def test_wrapped_streaming_trace_is_single_trace():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test_workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
],
}
]
)
@pytest.mark.asyncio
@ -386,8 +440,42 @@ async def test_wrapped_mixed_trace_is_single_trace():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test_workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
],
}
]
)
@pytest.mark.asyncio
@ -409,8 +497,7 @@ async def test_parent_disabled_trace_disables_streaming_agent_trace():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert_no_traces()
@pytest.mark.asyncio
@ -431,5 +518,4 @@ async def test_manual_streaming_disabling_works():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert_no_traces()

View file

@ -49,10 +49,10 @@ async def test_simple_function():
assert tool.name == "simple_function"
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
assert result == "6"
assert result == 6
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
assert result == "3"
assert result == 3
# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):

View file

@ -152,9 +152,13 @@ def optional_param_function(a: int, b: Optional[int] = None) -> str:
@pytest.mark.asyncio
async def test_optional_param_function():
async def test_non_strict_mode_function():
tool = optional_param_function
assert tool.strict_json_schema is False, "strict_json_schema should be False"
assert tool.params_json_schema.get("required") == ["a"], "required should only be a"
input_data = {"a": 5}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_no_b"
@ -165,7 +169,7 @@ async def test_optional_param_function():
@function_tool(strict_mode=False)
def multiple_optional_params_function(
def all_optional_params_function(
x: int = 42,
y: str = "hello",
z: Optional[int] = None,
@ -176,8 +180,12 @@ def multiple_optional_params_function(
@pytest.mark.asyncio
async def test_multiple_optional_params_function():
tool = multiple_optional_params_function
async def test_all_optional_params_function():
tool = all_optional_params_function
assert tool.strict_json_schema is False, "strict_json_schema should be False"
assert tool.params_json_schema.get("required") is None, "required should be empty"
input_data: dict[str, Any] = {}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))

View file

@ -223,7 +223,7 @@ class Foo(TypedDict):
@pytest.mark.asyncio
async def test_structed_output_non_streamed_agent_hooks():
async def test_structured_output_non_streamed_agent_hooks():
hooks = RunHooksForTests()
model = FakeModel()
agent_1 = Agent(name="test_1", model=model)
@ -296,7 +296,7 @@ async def test_structed_output_non_streamed_agent_hooks():
@pytest.mark.asyncio
async def test_structed_output_streamed_agent_hooks():
async def test_structured_output_streamed_agent_hooks():
hooks = RunHooksForTests()
model = FakeModel()
agent_1 = Agent(name="test_1", model=model)

View file

@ -7,7 +7,7 @@ from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
from agents.tracing.span_data import ResponseSpanData
from tests import fake_model
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans
from .testing_processor import assert_no_spans, fetch_normalized_spans, fetch_ordered_spans
class DummyTracing:
@ -64,13 +64,6 @@ async def test_get_response_creates_trace(monkeypatch):
]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData)
assert spans[0].span_data.response is not None
assert spans[0].span_data.response.id == "dummy-id"
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
@ -96,9 +89,8 @@ async def test_non_data_tracing_doesnt_set_response_id(monkeypatch):
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
assert spans[0].span_data.response is None
[span] = fetch_ordered_spans()
assert span.span_data.response is None
@pytest.mark.allow_call_model_methods
@ -123,8 +115,7 @@ async def test_disable_tracing_does_not_create_span(monkeypatch):
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
spans = fetch_ordered_spans()
assert len(spans) == 0
assert_no_spans()
@pytest.mark.allow_call_model_methods
@ -164,12 +155,6 @@ async def test_stream_response_creates_trace(monkeypatch):
]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData)
assert spans[0].span_data.response is not None
assert spans[0].span_data.response.id == "dummy-id-123"
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
@ -203,10 +188,9 @@ async def test_stream_non_data_tracing_doesnt_set_response_id(monkeypatch):
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData)
assert spans[0].span_data.response is None
[span] = fetch_ordered_spans()
assert isinstance(span.span_data, ResponseSpanData)
assert span.span_data.response is None
@pytest.mark.allow_call_model_methods
@ -239,5 +223,4 @@ async def test_stream_disabled_tracing_doesnt_create_span(monkeypatch):
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
spans = fetch_ordered_spans()
assert len(spans) == 0
assert_no_spans()

View file

@ -0,0 +1,161 @@
import pytest
from agents import Agent, ModelSettings, Runner, Tool
from agents._run_impl import RunImpl
from .fake_model import FakeModel
from .test_responses import (
get_function_tool,
get_function_tool_call,
get_text_message,
)
class TestToolChoiceReset:
def test_should_reset_tool_choice_direct(self):
"""
Test the _should_reset_tool_choice method directly with various inputs
to ensure it correctly identifies cases where reset is needed.
"""
# Case 1: tool_choice = None should not reset
model_settings = ModelSettings(tool_choice=None)
tools1: list[Tool] = [get_function_tool("tool1")]
# Cast to list[Tool] to fix type checking issues
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
# Case 2: tool_choice = "auto" should not reset
model_settings = ModelSettings(tool_choice="auto")
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
# Case 3: tool_choice = "none" should not reset
model_settings = ModelSettings(tool_choice="none")
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
# Case 4: tool_choice = "required" with one tool should reset
model_settings = ModelSettings(tool_choice="required")
assert RunImpl._should_reset_tool_choice(model_settings, tools1)
# Case 5: tool_choice = "required" with multiple tools should not reset
model_settings = ModelSettings(tool_choice="required")
tools2: list[Tool] = [get_function_tool("tool1"), get_function_tool("tool2")]
assert not RunImpl._should_reset_tool_choice(model_settings, tools2)
# Case 6: Specific tool choice should reset
model_settings = ModelSettings(tool_choice="specific_tool")
assert RunImpl._should_reset_tool_choice(model_settings, tools1)
@pytest.mark.asyncio
async def test_required_tool_choice_with_multiple_runs(self):
"""
Test scenario 1: When multiple runs are executed with tool_choice="required"
Ensure each run works correctly and doesn't get stuck in infinite loop
Also verify that tool_choice remains "required" between runs
"""
# Set up our fake model with responses for two runs
fake_model = FakeModel()
fake_model.add_multiple_turn_outputs([
[get_text_message("First run response")],
[get_text_message("Second run response")]
])
# Create agent with a custom tool and tool_choice="required"
custom_tool = get_function_tool("custom_tool")
agent = Agent(
name="test_agent",
model=fake_model,
tools=[custom_tool],
model_settings=ModelSettings(tool_choice="required"),
)
# First run should work correctly and preserve tool_choice
result1 = await Runner.run(agent, "first run")
assert result1.final_output == "First run response"
assert agent.model_settings.tool_choice == "required", "tool_choice should stay required"
# Second run should also work correctly with tool_choice still required
result2 = await Runner.run(agent, "second run")
assert result2.final_output == "Second run response"
assert agent.model_settings.tool_choice == "required", "tool_choice should stay required"
@pytest.mark.asyncio
async def test_required_with_stop_at_tool_name(self):
"""
Test scenario 2: When using required tool_choice with stop_at_tool_names behavior
Ensure it correctly stops at the specified tool
"""
# Set up fake model to return a tool call for second_tool
fake_model = FakeModel()
fake_model.set_next_output([
get_function_tool_call("second_tool", "{}")
])
# Create agent with two tools and tool_choice="required" and stop_at_tool behavior
first_tool = get_function_tool("first_tool", return_value="first tool result")
second_tool = get_function_tool("second_tool", return_value="second tool result")
agent = Agent(
name="test_agent",
model=fake_model,
tools=[first_tool, second_tool],
model_settings=ModelSettings(tool_choice="required"),
tool_use_behavior={"stop_at_tool_names": ["second_tool"]},
)
# Run should stop after using second_tool
result = await Runner.run(agent, "run test")
assert result.final_output == "second tool result"
@pytest.mark.asyncio
async def test_specific_tool_choice(self):
"""
Test scenario 3: When using a specific tool choice name
Ensure it doesn't cause infinite loops
"""
# Set up fake model to return a text message
fake_model = FakeModel()
fake_model.set_next_output([get_text_message("Test message")])
# Create agent with specific tool_choice
tool1 = get_function_tool("tool1")
tool2 = get_function_tool("tool2")
tool3 = get_function_tool("tool3")
agent = Agent(
name="test_agent",
model=fake_model,
tools=[tool1, tool2, tool3],
model_settings=ModelSettings(tool_choice="tool1"), # Specific tool
)
# Run should complete without infinite loops
result = await Runner.run(agent, "first run")
assert result.final_output == "Test message"
@pytest.mark.asyncio
async def test_required_with_single_tool(self):
"""
Test scenario 4: When using required tool_choice with only one tool
Ensure it doesn't cause infinite loops
"""
# Set up fake model to return a tool call followed by a text message
fake_model = FakeModel()
fake_model.add_multiple_turn_outputs([
# First call returns a tool call
[get_function_tool_call("custom_tool", "{}")],
# Second call returns a text message
[get_text_message("Final response")]
])
# Create agent with a single tool and tool_choice="required"
custom_tool = get_function_tool("custom_tool", return_value="tool result")
agent = Agent(
name="test_agent",
model=fake_model,
tools=[custom_tool],
model_settings=ModelSettings(tool_choice="required"),
)
# Run should complete without infinite loops
result = await Runner.run(agent, "first run")
assert result.final_output == "Final response"

View file

@ -0,0 +1,194 @@
# Copyright
from __future__ import annotations
from typing import cast
import pytest
from openai.types.responses.response_input_item_param import FunctionCallOutput
from agents import (
Agent,
FunctionToolResult,
RunConfig,
RunContextWrapper,
ToolCallOutputItem,
ToolsToFinalOutputResult,
UserError,
)
from agents._run_impl import RunImpl
from .test_responses import get_function_tool
def _make_function_tool_result(
agent: Agent, output: str, tool_name: str | None = None
) -> FunctionToolResult:
# Construct a FunctionToolResult with the given output using a simple function tool.
tool = get_function_tool(tool_name or "dummy", return_value=output)
raw_item: FunctionCallOutput = cast(
FunctionCallOutput,
{
"call_id": "1",
"output": output,
"type": "function_call_output",
},
)
# For this test we don't care about the specific RunItem subclass, only the output field
run_item = ToolCallOutputItem(agent=agent, raw_item=raw_item, output=output)
return FunctionToolResult(tool=tool, output=output, run_item=run_item)
@pytest.mark.asyncio
async def test_no_tool_results_returns_not_final_output() -> None:
# If there are no tool results at all, tool_use_behavior should not produce a final output.
agent = Agent(name="test")
result = await RunImpl._check_for_final_output_from_tools(
agent=agent,
tool_results=[],
context_wrapper=RunContextWrapper(context=None),
config=RunConfig(),
)
assert result.is_final_output is False
assert result.final_output is None
@pytest.mark.asyncio
async def test_run_llm_again_behavior() -> None:
# With the default run_llm_again behavior, even with tools we still expect to keep running.
agent = Agent(name="test", tool_use_behavior="run_llm_again")
tool_results = [_make_function_tool_result(agent, "ignored")]
result = await RunImpl._check_for_final_output_from_tools(
agent=agent,
tool_results=tool_results,
context_wrapper=RunContextWrapper(context=None),
config=RunConfig(),
)
assert result.is_final_output is False
assert result.final_output is None
@pytest.mark.asyncio
async def test_stop_on_first_tool_behavior() -> None:
# When tool_use_behavior is stop_on_first_tool, we should surface first tool output as final.
agent = Agent(name="test", tool_use_behavior="stop_on_first_tool")
tool_results = [
_make_function_tool_result(agent, "first_tool_output"),
_make_function_tool_result(agent, "ignored"),
]
result = await RunImpl._check_for_final_output_from_tools(
agent=agent,
tool_results=tool_results,
context_wrapper=RunContextWrapper(context=None),
config=RunConfig(),
)
assert result.is_final_output is True
assert result.final_output == "first_tool_output"
@pytest.mark.asyncio
async def test_custom_tool_use_behavior_sync() -> None:
"""If tool_use_behavior is a sync function, we should call it and propagate its return."""
def behavior(
context: RunContextWrapper, results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
assert len(results) == 3
return ToolsToFinalOutputResult(is_final_output=True, final_output="custom")
agent = Agent(name="test", tool_use_behavior=behavior)
tool_results = [
_make_function_tool_result(agent, "ignored1"),
_make_function_tool_result(agent, "ignored2"),
_make_function_tool_result(agent, "ignored3"),
]
result = await RunImpl._check_for_final_output_from_tools(
agent=agent,
tool_results=tool_results,
context_wrapper=RunContextWrapper(context=None),
config=RunConfig(),
)
assert result.is_final_output is True
assert result.final_output == "custom"
@pytest.mark.asyncio
async def test_custom_tool_use_behavior_async() -> None:
"""If tool_use_behavior is an async function, we should await it and propagate its return."""
async def behavior(
context: RunContextWrapper, results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
assert len(results) == 3
return ToolsToFinalOutputResult(is_final_output=True, final_output="async_custom")
agent = Agent(name="test", tool_use_behavior=behavior)
tool_results = [
_make_function_tool_result(agent, "ignored1"),
_make_function_tool_result(agent, "ignored2"),
_make_function_tool_result(agent, "ignored3"),
]
result = await RunImpl._check_for_final_output_from_tools(
agent=agent,
tool_results=tool_results,
context_wrapper=RunContextWrapper(context=None),
config=RunConfig(),
)
assert result.is_final_output is True
assert result.final_output == "async_custom"
@pytest.mark.asyncio
async def test_invalid_tool_use_behavior_raises() -> None:
"""If tool_use_behavior is invalid, we should raise a UserError."""
agent = Agent(name="test")
# Force an invalid value; mypy will complain, so ignore the type here.
agent.tool_use_behavior = "bad_value" # type: ignore[assignment]
tool_results = [_make_function_tool_result(agent, "ignored")]
with pytest.raises(UserError):
await RunImpl._check_for_final_output_from_tools(
agent=agent,
tool_results=tool_results,
context_wrapper=RunContextWrapper(context=None),
config=RunConfig(),
)
@pytest.mark.asyncio
async def test_tool_names_to_stop_at_behavior() -> None:
agent = Agent(
name="test",
tools=[
get_function_tool("tool1", return_value="tool1_output"),
get_function_tool("tool2", return_value="tool2_output"),
get_function_tool("tool3", return_value="tool3_output"),
],
tool_use_behavior={"stop_at_tool_names": ["tool1"]},
)
tool_results = [
_make_function_tool_result(agent, "ignored1", "tool2"),
_make_function_tool_result(agent, "ignored3", "tool3"),
]
result = await RunImpl._check_for_final_output_from_tools(
agent=agent,
tool_results=tool_results,
context_wrapper=RunContextWrapper(context=None),
config=RunConfig(),
)
assert result.is_final_output is False, "We should not have stopped at tool1"
# Now test with a tool that matches the list
tool_results = [
_make_function_tool_result(agent, "output1", "tool1"),
_make_function_tool_result(agent, "ignored2", "tool2"),
_make_function_tool_result(agent, "ignored3", "tool3"),
]
result = await RunImpl._check_for_final_output_from_tools(
agent=agent,
tool_results=tool_results,
context_wrapper=RunContextWrapper(context=None),
config=RunConfig(),
)
assert result.is_final_output is True, "We should have stopped at tool1"
assert result.final_output == "output1"

View file

@ -4,6 +4,7 @@ import asyncio
from typing import Any
import pytest
from inline_snapshot import snapshot
from agents.tracing import (
Span,
@ -17,7 +18,12 @@ from agents.tracing import (
)
from agents.tracing.spans import SpanError
from .testing_processor import fetch_events, fetch_ordered_spans, fetch_traces
from .testing_processor import (
SPAN_PROCESSOR_TESTING,
assert_no_traces,
fetch_events,
fetch_normalized_spans,
)
### HELPERS
@ -47,7 +53,7 @@ def simple_tracing():
x = trace("test")
x.start()
span_1 = agent_span(name="agent_1", parent=x)
span_1 = agent_span(name="agent_1", span_id="span_1", parent=x)
span_1.start()
span_1.finish()
@ -66,33 +72,36 @@ def simple_tracing():
def test_simple_tracing() -> None:
simple_tracing()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 3
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
first_span = spans[0]
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="agent")
assert first_span.span_data.name == "agent_1"
second_span = spans[1]
standard_span_checks(second_span, trace_id=trace_id, parent_id=None, span_type="custom")
assert second_span.span_id == "span_2"
assert second_span.span_data.name == "custom_1"
third_span = spans[2]
standard_span_checks(
third_span, trace_id=trace_id, parent_id=second_span.span_id, span_type="custom"
assert fetch_normalized_spans(keep_span_id=True) == snapshot(
[
{
"workflow_name": "test",
"children": [
{
"type": "agent",
"id": "span_1",
"data": {"name": "agent_1"},
},
{
"type": "custom",
"id": "span_2",
"data": {"name": "custom_1", "data": {}},
"children": [
{
"type": "custom",
"id": "span_3",
"data": {"name": "custom_2", "data": {}},
}
],
},
],
}
]
)
assert third_span.span_id == "span_3"
assert third_span.span_data.name == "custom_2"
def ctxmanager_spans():
with trace(workflow_name="test", trace_id="123", group_id="456"):
with trace(workflow_name="test", trace_id="trace_123", group_id="456"):
with custom_span(name="custom_1", span_id="span_1"):
with custom_span(name="custom_2", span_id="span_1_inner"):
pass
@ -104,36 +113,38 @@ def ctxmanager_spans():
def test_ctxmanager_spans() -> None:
ctxmanager_spans()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 3
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
first_span = spans[0]
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="custom")
assert first_span.span_id == "span_1"
first_inner_span = spans[1]
standard_span_checks(
first_inner_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="custom"
assert fetch_normalized_spans(keep_span_id=True) == snapshot(
[
{
"workflow_name": "test",
"group_id": "456",
"children": [
{
"type": "custom",
"id": "span_1",
"data": {"name": "custom_1", "data": {}},
"children": [
{
"type": "custom",
"id": "span_1_inner",
"data": {"name": "custom_2", "data": {}},
}
],
},
{"type": "custom", "id": "span_2", "data": {"name": "custom_2", "data": {}}},
],
}
]
)
assert first_inner_span.span_id == "span_1_inner"
second_span = spans[2]
standard_span_checks(second_span, trace_id=trace_id, parent_id=None, span_type="custom")
assert second_span.span_id == "span_2"
async def run_subtask(span_id: str | None = None) -> None:
with generation_span(span_id=span_id):
await asyncio.sleep(0.01)
await asyncio.sleep(0.0001)
async def simple_async_tracing():
with trace(workflow_name="test", trace_id="123", group_id="456"):
with trace(workflow_name="test", trace_id="trace_123", group_id="group_456"):
await run_subtask(span_id="span_1")
await run_subtask(span_id="span_2")
@ -142,21 +153,18 @@ async def simple_async_tracing():
async def test_async_tracing() -> None:
await simple_async_tracing()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 2
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
# We don't care about ordering here, just that they're there
for s in spans:
standard_span_checks(s, trace_id=trace_id, parent_id=None, span_type="generation")
ids = [span.span_id for span in spans]
assert "span_1" in ids
assert "span_2" in ids
assert fetch_normalized_spans(keep_span_id=True) == snapshot(
[
{
"workflow_name": "test",
"group_id": "group_456",
"children": [
{"type": "generation", "id": "span_1"},
{"type": "generation", "id": "span_2"},
],
}
]
)
async def run_tasks_parallel(span_ids: list[str]) -> None:
@ -171,13 +179,11 @@ async def run_tasks_as_children(first_span_id: str, second_span_id: str) -> None
async def complex_async_tracing():
with trace(workflow_name="test", trace_id="123", group_id="456"):
await asyncio.sleep(0.01)
with trace(workflow_name="test", trace_id="trace_123", group_id="456"):
await asyncio.gather(
run_tasks_parallel(["span_1", "span_2"]),
run_tasks_parallel(["span_3", "span_4"]),
)
await asyncio.sleep(0.01)
await asyncio.gather(
run_tasks_as_children("span_5", "span_6"),
run_tasks_as_children("span_7", "span_8"),
@ -186,39 +192,38 @@ async def complex_async_tracing():
@pytest.mark.asyncio
async def test_complex_async_tracing() -> None:
await complex_async_tracing()
for _ in range(300):
SPAN_PROCESSOR_TESTING.clear()
await complex_async_tracing()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 8
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
# First ensure 1,2,3,4 exist and are in parallel with the trace as parent
for span_id in ["span_1", "span_2", "span_3", "span_4"]:
span = next((s for s in spans if s.span_id == span_id), None)
assert span is not None
standard_span_checks(span, trace_id=trace_id, parent_id=None, span_type="generation")
# Ensure 5 and 7 exist and have the trace as parent
for span_id in ["span_5", "span_7"]:
span = next((s for s in spans if s.span_id == span_id), None)
assert span is not None
standard_span_checks(span, trace_id=trace_id, parent_id=None, span_type="generation")
# Ensure 6 and 8 exist and have 5 and 7 as parents
six = next((s for s in spans if s.span_id == "span_6"), None)
assert six is not None
standard_span_checks(six, trace_id=trace_id, parent_id="span_5", span_type="generation")
eight = next((s for s in spans if s.span_id == "span_8"), None)
assert eight is not None
standard_span_checks(eight, trace_id=trace_id, parent_id="span_7", span_type="generation")
assert fetch_normalized_spans(keep_span_id=True) == (
[
{
"workflow_name": "test",
"group_id": "456",
"children": [
{"type": "generation", "id": "span_1"},
{"type": "generation", "id": "span_2"},
{"type": "generation", "id": "span_3"},
{"type": "generation", "id": "span_4"},
{
"type": "generation",
"id": "span_5",
"children": [{"type": "generation", "id": "span_6"}],
},
{
"type": "generation",
"id": "span_7",
"children": [{"type": "generation", "id": "span_8"}],
},
],
}
]
)
def spans_with_setters():
with trace(workflow_name="test", trace_id="123", group_id="456"):
with trace(workflow_name="test", trace_id="trace_123", group_id="456"):
with agent_span(name="agent_1") as span_a:
span_a.span_data.name = "agent_2"
@ -236,34 +241,33 @@ def spans_with_setters():
def test_spans_with_setters() -> None:
spans_with_setters()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 4
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
# Check the spans
first_span = spans[0]
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="agent")
assert first_span.span_data.name == "agent_2"
second_span = spans[1]
standard_span_checks(
second_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="function"
)
assert second_span.span_data.input == "i"
assert second_span.span_data.output == "o"
third_span = spans[2]
standard_span_checks(
third_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="generation"
)
fourth_span = spans[3]
standard_span_checks(
fourth_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="handoff"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"group_id": "456",
"children": [
{
"type": "agent",
"data": {"name": "agent_2"},
"children": [
{
"type": "function",
"data": {"name": "function_1", "input": "i", "output": "o"},
},
{
"type": "generation",
"data": {"input": [{"foo": "bar"}]},
},
{
"type": "handoff",
"data": {"from_agent": "agent_1", "to_agent": "agent_2"},
},
],
}
],
}
]
)
@ -276,14 +280,11 @@ def disabled_tracing():
def test_disabled_tracing():
disabled_tracing()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 0
assert len(traces) == 0
assert_no_traces()
def enabled_trace_disabled_span():
with trace(workflow_name="test", trace_id="123"):
with trace(workflow_name="test", trace_id="trace_123"):
with agent_span(name="agent_1"):
with function_span(name="function_1", disabled=True):
with generation_span():
@ -293,17 +294,19 @@ def enabled_trace_disabled_span():
def test_enabled_trace_disabled_span():
enabled_trace_disabled_span()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 1 # Only the agent span is recorded
assert len(traces) == 1 # The trace is recorded
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
first_span = spans[0]
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="agent")
assert first_span.span_data.name == "agent_1"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"children": [
{
"type": "agent",
"data": {"name": "agent_1"},
}
],
}
]
)
def test_start_and_end_called_manual():
@ -367,9 +370,7 @@ async def test_noop_span_doesnt_record():
with custom_span(name="span_1") as span:
span.set_error(SpanError(message="test", data={}))
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 0
assert len(traces) == 0
assert_no_traces()
assert t.export() is None
assert span.export() is None

View file

@ -18,7 +18,6 @@ from agents import (
Runner,
TResponseInputItem,
)
from agents.tracing import AgentSpanData, FunctionSpanData, GenerationSpanData
from .fake_model import FakeModel
from .test_responses import (
@ -28,7 +27,7 @@ from .test_responses import (
get_handoff_tool_call,
get_text_message,
)
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
from .testing_processor import fetch_normalized_spans
@pytest.mark.asyncio
@ -43,9 +42,6 @@ async def test_single_turn_model_error():
with pytest.raises(ValueError):
await Runner.run(agent, input="first_test")
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -74,13 +70,6 @@ async def test_single_turn_model_error():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}"
generation_span = spans[1]
assert isinstance(generation_span.span_data, GenerationSpanData)
assert generation_span.error, "should have error"
@pytest.mark.asyncio
async def test_multi_turn_no_handoffs():
@ -106,9 +95,6 @@ async def test_multi_turn_no_handoffs():
with pytest.raises(ValueError):
await Runner.run(agent, input="first_test")
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -146,15 +132,6 @@ async def test_multi_turn_no_handoffs():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 4, (
f"should have agent, generation, tool, generation, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
last_generation_span = [x for x in spans if isinstance(x.span_data, GenerationSpanData)][-1]
assert last_generation_span.error, "should have error"
@pytest.mark.asyncio
async def test_tool_call_error():
@ -173,9 +150,6 @@ async def test_tool_call_error():
with pytest.raises(ModelBehaviorError):
await Runner.run(agent, input="first_test")
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -209,15 +183,6 @@ async def test_tool_call_error():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 3, (
f"should have agent, generation, tool spans, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
function_span = [x for x in spans if isinstance(x.span_data, FunctionSpanData)][0]
assert function_span.error, "should have error"
@pytest.mark.asyncio
async def test_multiple_handoff_doesnt_error():
@ -255,9 +220,6 @@ async def test_multiple_handoff_doesnt_error():
result = await Runner.run(agent_3, input="user_message")
assert result.last_agent == agent_1, "should have picked first handoff"
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -295,12 +257,6 @@ async def test_multiple_handoff_doesnt_error():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 7, (
f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
class Foo(TypedDict):
bar: str
@ -326,9 +282,6 @@ async def test_multiple_final_output_doesnt_error():
result = await Runner.run(agent_1, input="user_message")
assert result.final_output == Foo(bar="abc")
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -344,12 +297,6 @@ async def test_multiple_final_output_doesnt_error():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 generation, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
@pytest.mark.asyncio
async def test_handoffs_lead_to_correct_agent_spans():
@ -399,9 +346,6 @@ async def test_handoffs_lead_to_correct_agent_spans():
f"should have ended on the third agent, got {result.last_agent.name}"
)
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -472,12 +416,6 @@ async def test_handoffs_lead_to_correct_agent_spans():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 12, (
f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
@pytest.mark.asyncio
async def test_max_turns_exceeded():
@ -503,9 +441,6 @@ async def test_max_turns_exceeded():
with pytest.raises(MaxTurnsExceeded):
await Runner.run(agent, input="user_message", max_turns=2)
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -538,15 +473,6 @@ async def test_max_turns_exceeded():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 5, (
f"should have 1 agent span, 2 generations, 2 function calls, got "
f"{len(spans)} with data: {[x.span_data for x in spans]}"
)
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
assert agent_span.error, "last agent should have error"
def guardrail_function(
context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
@ -568,9 +494,6 @@ async def test_guardrail_error():
with pytest.raises(InputGuardrailTripwireTriggered):
await Runner.run(agent, input="user_message")
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -594,12 +517,3 @@ async def test_guardrail_error():
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
assert agent_span.error, "last agent should have error"

View file

@ -10,9 +10,6 @@ from typing_extensions import TypedDict
from agents import (
Agent,
AgentSpanData,
FunctionSpanData,
GenerationSpanData,
GuardrailFunctionOutput,
InputGuardrail,
InputGuardrailTripwireTriggered,
@ -33,7 +30,7 @@ from .test_responses import (
get_handoff_tool_call,
get_text_message,
)
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
from .testing_processor import fetch_normalized_spans
@pytest.mark.asyncio
@ -50,9 +47,6 @@ async def test_single_turn_model_error():
async for _ in result.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -82,13 +76,6 @@ async def test_single_turn_model_error():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}"
generation_span = spans[1]
assert isinstance(generation_span.span_data, GenerationSpanData)
assert generation_span.error, "should have error"
@pytest.mark.asyncio
async def test_multi_turn_no_handoffs():
@ -116,9 +103,6 @@ async def test_multi_turn_no_handoffs():
async for _ in result.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -157,15 +141,6 @@ async def test_multi_turn_no_handoffs():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 4, (
f"should have agent, generation, tool, generation, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
last_generation_span = [x for x in spans if isinstance(x.span_data, GenerationSpanData)][-1]
assert last_generation_span.error, "should have error"
@pytest.mark.asyncio
async def test_tool_call_error():
@ -186,9 +161,6 @@ async def test_tool_call_error():
async for _ in result.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -226,15 +198,6 @@ async def test_tool_call_error():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 3, (
f"should have agent, generation, tool spans, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
function_span = [x for x in spans if isinstance(x.span_data, FunctionSpanData)][0]
assert function_span.error, "should have error"
@pytest.mark.asyncio
async def test_multiple_handoff_doesnt_error():
@ -275,9 +238,6 @@ async def test_multiple_handoff_doesnt_error():
assert result.last_agent == agent_1, "should have picked first handoff"
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -315,12 +275,6 @@ async def test_multiple_handoff_doesnt_error():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 7, (
f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
class Foo(TypedDict):
bar: str
@ -350,9 +304,6 @@ async def test_multiple_final_output_no_error():
assert isinstance(result.final_output, dict)
assert result.final_output["bar"] == "abc"
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -368,12 +319,6 @@ async def test_multiple_final_output_no_error():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 generation, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
@pytest.mark.asyncio
async def test_handoffs_lead_to_correct_agent_spans():
@ -425,85 +370,6 @@ async def test_handoffs_lead_to_correct_agent_spans():
f"should have ended on the third agent, got {result.last_agent.name}"
)
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": ["test_agent_3"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [{"type": "generation"}],
},
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 12, (
f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
assert fetch_normalized_spans() == snapshot(
[
{
@ -601,9 +467,6 @@ async def test_max_turns_exceeded():
async for _ in result.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -636,15 +499,6 @@ async def test_max_turns_exceeded():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 5, (
f"should have 1 agent, 2 generations, 2 function calls, got "
f"{len(spans)} with data: {[x.span_data for x in spans]}"
)
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
assert agent_span.error, "last agent should have error"
def input_guardrail_function(
context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
@ -673,9 +527,6 @@ async def test_input_guardrail_error():
await asyncio.sleep(1)
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -703,15 +554,6 @@ async def test_input_guardrail_error():
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
assert agent_span.error, "last agent should have error"
def output_guardrail_function(
context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any
@ -740,9 +582,6 @@ async def test_output_guardrail_error():
await asyncio.sleep(1)
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
@ -766,12 +605,3 @@ async def test_output_guardrail_error():
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
assert agent_span.error, "last agent should have error"

View file

@ -80,26 +80,44 @@ def fetch_events() -> list[TestSpanProcessorEvent]:
return SPAN_PROCESSOR_TESTING._events
def fetch_normalized_spans():
def assert_no_spans():
spans = fetch_ordered_spans()
if spans:
raise AssertionError(f"Expected 0 spans, got {len(spans)}")
def assert_no_traces():
traces = fetch_traces()
if traces:
raise AssertionError(f"Expected 0 traces, got {len(traces)}")
assert_no_spans()
def fetch_normalized_spans(
keep_span_id: bool = False, keep_trace_id: bool = False
) -> list[dict[str, Any]]:
nodes: dict[tuple[str, str | None], dict[str, Any]] = {}
traces = []
for trace_obj in fetch_traces():
trace = trace_obj.export()
assert trace
assert trace.pop("object") == "trace"
assert trace.pop("id").startswith("trace_")
assert trace["id"].startswith("trace_")
if not keep_trace_id:
del trace["id"]
trace = {k: v for k, v in trace.items() if v is not None}
nodes[(trace_obj.trace_id, None)] = trace
traces.append(trace)
if not traces:
assert not fetch_ordered_spans()
assert traces, "Use assert_no_traces() to check for empty traces"
for span_obj in fetch_ordered_spans():
span = span_obj.export()
assert span
assert span.pop("object") == "trace.span"
assert span.pop("id").startswith("span_")
assert span["id"].startswith("span_")
if not keep_span_id:
del span["id"]
assert datetime.fromisoformat(span.pop("started_at"))
assert datetime.fromisoformat(span.pop("ended_at"))
parent_id = span.pop("parent_id")

View file

@ -0,0 +1,27 @@
import pytest
from agents.tracing.processors import BackendSpanExporter
@pytest.mark.asyncio
async def test_processor_api_key(monkeypatch):
# If the API key is not set, it should be None
monkeypatch.delenv("OPENAI_API_KEY", None)
processor = BackendSpanExporter()
assert processor.api_key is None
# If we set it afterwards, it should be the new value
processor.set_api_key("test_api_key")
assert processor.api_key == "test_api_key"
@pytest.mark.asyncio
async def test_processor_api_key_from_env(monkeypatch):
# If the API key is not set at creation time but set before access time, it should be the new
# value
monkeypatch.delenv("OPENAI_API_KEY", None)
processor = BackendSpanExporter()
# If we set it afterwards, it should be the new value
monkeypatch.setenv("OPENAI_API_KEY", "foo_bar_123")
assert processor.api_key == "foo_bar_123"

0
tests/voice/__init__.py Normal file
View file

14
tests/voice/conftest.py Normal file
View file

@ -0,0 +1,14 @@
import os
import sys
import pytest
def pytest_collection_modifyitems(config, items):
if sys.version_info[:2] == (3, 9):
this_dir = os.path.dirname(__file__)
skip_marker = pytest.mark.skip(reason="Skipped on Python 3.9")
for item in items:
if item.fspath.dirname.startswith(this_dir):
item.add_marker(skip_marker)

115
tests/voice/fake_models.py Normal file
View file

@ -0,0 +1,115 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Literal
import numpy as np
import numpy.typing as npt
try:
from agents.voice import (
AudioInput,
StreamedAudioInput,
StreamedTranscriptionSession,
STTModel,
STTModelSettings,
TTSModel,
TTSModelSettings,
VoiceWorkflowBase,
)
except ImportError:
pass
class FakeTTS(TTSModel):
"""Fakes TTS by just returning string bytes."""
def __init__(self, strategy: Literal["default", "split_words"] = "default"):
self.strategy = strategy
@property
def model_name(self) -> str:
return "fake_tts"
async def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[bytes]:
if self.strategy == "default":
yield np.zeros(2, dtype=np.int16).tobytes()
elif self.strategy == "split_words":
for _ in text.split():
yield np.zeros(2, dtype=np.int16).tobytes()
async def verify_audio(self, text: str, audio: bytes, dtype: npt.DTypeLike = np.int16) -> None:
assert audio == np.zeros(2, dtype=dtype).tobytes()
async def verify_audio_chunks(
self, text: str, audio_chunks: list[bytes], dtype: npt.DTypeLike = np.int16
) -> None:
assert audio_chunks == [np.zeros(2, dtype=dtype).tobytes() for _word in text.split()]
class FakeSession(StreamedTranscriptionSession):
"""A fake streamed transcription session that yields preconfigured transcripts."""
def __init__(self):
self.outputs: list[str] = []
async def transcribe_turns(self) -> AsyncIterator[str]:
for t in self.outputs:
yield t
async def close(self) -> None:
return None
class FakeSTT(STTModel):
"""A fake STT model that either returns a single transcript or yields multiple."""
def __init__(self, outputs: list[str] | None = None):
self.outputs = outputs or []
@property
def model_name(self) -> str:
return "fake_stt"
async def transcribe(self, _: AudioInput, __: STTModelSettings, ___: bool, ____: bool) -> str:
return self.outputs.pop(0)
async def create_session(
self,
_: StreamedAudioInput,
__: STTModelSettings,
___: bool,
____: bool,
) -> StreamedTranscriptionSession:
session = FakeSession()
session.outputs = self.outputs
return session
class FakeWorkflow(VoiceWorkflowBase):
"""A fake workflow that yields preconfigured outputs."""
def __init__(self, outputs: list[list[str]] | None = None):
self.outputs = outputs or []
def add_output(self, output: list[str]) -> None:
self.outputs.append(output)
def add_multiple_outputs(self, outputs: list[list[str]]) -> None:
self.outputs.extend(outputs)
async def run(self, _: str) -> AsyncIterator[str]:
if not self.outputs:
raise ValueError("No output configured")
output = self.outputs.pop(0)
for t in output:
yield t
class FakeStreamedAudioInput:
@classmethod
async def get(cls, count: int) -> StreamedAudioInput:
input = StreamedAudioInput()
for _ in range(count):
await input.add_audio(np.zeros(2, dtype=np.int16))
return input

21
tests/voice/helpers.py Normal file
View file

@ -0,0 +1,21 @@
try:
from agents.voice import StreamedAudioResult
except ImportError:
pass
async def extract_events(result: StreamedAudioResult) -> tuple[list[str], list[bytes]]:
"""Collapse pipeline stream events to simple labels for ordering assertions."""
flattened: list[str] = []
audio_chunks: list[bytes] = []
async for ev in result.stream():
if ev.type == "voice_stream_event_audio":
if ev.data is not None:
audio_chunks.append(ev.data.tobytes())
flattened.append("audio")
elif ev.type == "voice_stream_event_lifecycle":
flattened.append(ev.event)
elif ev.type == "voice_stream_event_error":
flattened.append("error")
return flattened, audio_chunks

127
tests/voice/test_input.py Normal file
View file

@ -0,0 +1,127 @@
import io
import wave
import numpy as np
import pytest
try:
from agents import UserError
from agents.voice import AudioInput, StreamedAudioInput
from agents.voice.input import DEFAULT_SAMPLE_RATE, _buffer_to_audio_file
except ImportError:
pass
def test_buffer_to_audio_file_int16():
# Create a simple sine wave in int16 format
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
buffer = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
filename, audio_file, content_type = _buffer_to_audio_file(buffer)
assert filename == "audio.wav"
assert content_type == "audio/wav"
assert isinstance(audio_file, io.BytesIO)
# Verify the WAV file contents
with wave.open(audio_file, "rb") as wav_file:
assert wav_file.getnchannels() == 1
assert wav_file.getsampwidth() == 2
assert wav_file.getframerate() == DEFAULT_SAMPLE_RATE
assert wav_file.getnframes() == len(buffer)
def test_buffer_to_audio_file_float32():
# Create a simple sine wave in float32 format
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
buffer = np.sin(2 * np.pi * 440 * t).astype(np.float32)
filename, audio_file, content_type = _buffer_to_audio_file(buffer)
assert filename == "audio.wav"
assert content_type == "audio/wav"
assert isinstance(audio_file, io.BytesIO)
# Verify the WAV file contents
with wave.open(audio_file, "rb") as wav_file:
assert wav_file.getnchannels() == 1
assert wav_file.getsampwidth() == 2
assert wav_file.getframerate() == DEFAULT_SAMPLE_RATE
assert wav_file.getnframes() == len(buffer)
def test_buffer_to_audio_file_invalid_dtype():
# Create a buffer with invalid dtype (float64)
buffer = np.array([1.0, 2.0, 3.0], dtype=np.float64)
with pytest.raises(UserError, match="Buffer must be a numpy array of int16 or float32"):
# Purposely ignore the type error
_buffer_to_audio_file(buffer) # type: ignore
class TestAudioInput:
def test_audio_input_default_params(self):
# Create a simple sine wave
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
buffer = np.sin(2 * np.pi * 440 * t).astype(np.float32)
audio_input = AudioInput(buffer=buffer)
assert audio_input.frame_rate == DEFAULT_SAMPLE_RATE
assert audio_input.sample_width == 2
assert audio_input.channels == 1
assert np.array_equal(audio_input.buffer, buffer)
def test_audio_input_custom_params(self):
# Create a simple sine wave
t = np.linspace(0, 1, 48000)
buffer = np.sin(2 * np.pi * 440 * t).astype(np.float32)
audio_input = AudioInput(buffer=buffer, frame_rate=48000, sample_width=4, channels=2)
assert audio_input.frame_rate == 48000
assert audio_input.sample_width == 4
assert audio_input.channels == 2
assert np.array_equal(audio_input.buffer, buffer)
def test_audio_input_to_audio_file(self):
# Create a simple sine wave
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
buffer = np.sin(2 * np.pi * 440 * t).astype(np.float32)
audio_input = AudioInput(buffer=buffer)
filename, audio_file, content_type = audio_input.to_audio_file()
assert filename == "audio.wav"
assert content_type == "audio/wav"
assert isinstance(audio_file, io.BytesIO)
# Verify the WAV file contents
with wave.open(audio_file, "rb") as wav_file:
assert wav_file.getnchannels() == 1
assert wav_file.getsampwidth() == 2
assert wav_file.getframerate() == DEFAULT_SAMPLE_RATE
assert wav_file.getnframes() == len(buffer)
class TestStreamedAudioInput:
@pytest.mark.asyncio
async def test_streamed_audio_input(self):
streamed_input = StreamedAudioInput()
# Create some test audio data
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
audio1 = np.sin(2 * np.pi * 440 * t).astype(np.float32)
audio2 = np.sin(2 * np.pi * 880 * t).astype(np.float32)
# Add audio to the queue
await streamed_input.add_audio(audio1)
await streamed_input.add_audio(audio2)
# Verify the queue contents
assert streamed_input.queue.qsize() == 2
# Test non-blocking get
assert np.array_equal(streamed_input.queue.get_nowait(), audio1)
# Test blocking get
assert np.array_equal(await streamed_input.queue.get(), audio2)
assert streamed_input.queue.empty()

Some files were not shown because too many files have changed in this diff Show more