Merge branch 'main' of https://github.com/openai/openai-agents-python into feat/draw_graph
This commit is contained in:
commit
900a97fa55
105 changed files with 6252 additions and 764 deletions
9
.github/workflows/issues.yml
vendored
9
.github/workflows/issues.yml
vendored
|
|
@ -17,7 +17,10 @@ jobs:
|
|||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 7 days with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 3 days since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
any-of-labels: 'question,needs-more-info'
|
||||
any-of-issue-labels: 'question,needs-more-info'
|
||||
days-before-pr-stale: 10
|
||||
days-before-pr-close: 7
|
||||
stale-pr-label: "stale"
|
||||
stale-pr-message: "This PR is stale because it has been open for 10 days with no activity."
|
||||
close-pr-message: "This PR was closed because it has been inactive for 7 days since being marked as stale."
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
|
|
|||
7
.github/workflows/tests.yml
vendored
7
.github/workflows/tests.yml
vendored
|
|
@ -8,6 +8,9 @@ on:
|
|||
branches:
|
||||
- main
|
||||
|
||||
env:
|
||||
UV_FROZEN: "1"
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
|
|
@ -50,8 +53,8 @@ jobs:
|
|||
enable-cache: true
|
||||
- name: Install dependencies
|
||||
run: make sync
|
||||
- name: Run tests
|
||||
run: make tests
|
||||
- name: Run tests with coverage
|
||||
run: make coverage
|
||||
|
||||
build-docs:
|
||||
runs-on: ubuntu-latest
|
||||
|
|
|
|||
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -135,7 +135,7 @@ dmypy.json
|
|||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
|
|
|||
11
Makefile
11
Makefile
|
|
@ -5,6 +5,7 @@ sync:
|
|||
.PHONY: format
|
||||
format:
|
||||
uv run ruff format
|
||||
uv run ruff check --fix
|
||||
|
||||
.PHONY: lint
|
||||
lint:
|
||||
|
|
@ -18,6 +19,13 @@ mypy:
|
|||
tests:
|
||||
uv run pytest
|
||||
|
||||
.PHONY: coverage
|
||||
coverage:
|
||||
|
||||
uv run coverage run -m pytest
|
||||
uv run coverage xml -o coverage.xml
|
||||
uv run coverage report -m --fail-under=95
|
||||
|
||||
.PHONY: snapshots-fix
|
||||
snapshots-fix:
|
||||
uv run pytest --inline-snapshot=fix
|
||||
|
|
@ -29,7 +37,6 @@ snapshots-create:
|
|||
.PHONY: old_version_tests
|
||||
old_version_tests:
|
||||
UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m pytest
|
||||
UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m mypy .
|
||||
|
||||
.PHONY: build-docs
|
||||
build-docs:
|
||||
|
|
@ -43,3 +50,5 @@ serve-docs:
|
|||
deploy-docs:
|
||||
uv run mkdocs gh-deploy --force --verbose
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ source env/bin/activate
|
|||
pip install openai-agents
|
||||
```
|
||||
|
||||
For voice support, install with the optional `voice` group: `pip install 'openai-agents[voice]'`.
|
||||
|
||||
## Hello world example
|
||||
|
||||
```python
|
||||
|
|
|
|||
|
|
@ -130,3 +130,23 @@ robot_agent = pirate_agent.clone(
|
|||
instructions="Write like a robot",
|
||||
)
|
||||
```
|
||||
|
||||
## Forcing tool use
|
||||
|
||||
Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are:
|
||||
|
||||
1. `auto`, which allows the LLM to decide whether or not to use a tool.
|
||||
2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool).
|
||||
3. `none`, which requires the LLM to _not_ use a tool.
|
||||
4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool.
|
||||
|
||||
!!! note
|
||||
|
||||
To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call in the following scenarios:
|
||||
|
||||
1. When `tool_choice` is set to a specific function name (any string that's not "auto", "required", or "none")
|
||||
2. When `tool_choice` is set to "required" AND there is only one tool available
|
||||
|
||||
This targeted reset mechanism allows the model to decide whether to make additional tool calls in subsequent turns while avoiding infinite loops in these specific cases.
|
||||
|
||||
If you want the Agent to completely stop after a tool call (rather than continuing with auto mode), you can set [`Agent.tool_use_behavior="stop_on_first_tool"`] which will directly use the tool output as the final response without further LLM processing.
|
||||
|
|
|
|||
|
|
@ -41,14 +41,14 @@ async def fetch_user_age(wrapper: RunContextWrapper[UserInfo]) -> str: # (2)!
|
|||
return f"User {wrapper.context.name} is 47 years old"
|
||||
|
||||
async def main():
|
||||
user_info = UserInfo(name="John", uid=123) # (3)!
|
||||
user_info = UserInfo(name="John", uid=123)
|
||||
|
||||
agent = Agent[UserInfo]( # (4)!
|
||||
agent = Agent[UserInfo]( # (3)!
|
||||
name="Assistant",
|
||||
tools=[fetch_user_age],
|
||||
)
|
||||
|
||||
result = await Runner.run(
|
||||
result = await Runner.run( # (4)!
|
||||
starting_agent=agent,
|
||||
input="What is the age of the user?",
|
||||
context=user_info,
|
||||
|
|
|
|||
36
docs/examples.md
Normal file
36
docs/examples.md
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# Examples
|
||||
|
||||
Check out a variety of sample implementations of the SDK in the examples section of the [repo](https://github.com/openai/openai-agents-python/tree/main/examples). The examples are organized into several categories that demonstrate different patterns and capabilities.
|
||||
|
||||
|
||||
## Categories
|
||||
|
||||
- **agent_patterns:**
|
||||
Examples in this category illustrate common agent design patterns, such as
|
||||
|
||||
- Deterministic workflows
|
||||
- Agents as tools
|
||||
- Parallel agent execution
|
||||
|
||||
- **basic:**
|
||||
These examples showcase foundational capabilities of the SDK, such as
|
||||
|
||||
- Dynamic system prompts
|
||||
- Streaming outputs
|
||||
- Lifecycle events
|
||||
|
||||
- **tool examples:**
|
||||
Learn how to implement OAI hosted tools such as web search and file search,
|
||||
and integrate them into your agents.
|
||||
|
||||
- **model providers:**
|
||||
Explore how to use non-OpenAI models with the SDK.
|
||||
|
||||
- **handoffs:**
|
||||
See practical examples of agent handoffs.
|
||||
|
||||
- **customer_service** and **research_bot:**
|
||||
Two more built-out examples that illustrate real-world applications
|
||||
|
||||
- **customer_service**: Example customer service system for an airline.
|
||||
- **research_bot**: Simple deep research clone.
|
||||
|
|
@ -29,7 +29,7 @@ Output guardrails run in 3 steps:
|
|||
|
||||
!!! Note
|
||||
|
||||
Output guardrails are intended to run on the final agent input, so an agent's guardrails only run if the agent is the *last* agent. Similar to the input guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability.
|
||||
Output guardrails are intended to run on the final agent output, so an agent's guardrails only run if the agent is the *last* agent. Similar to the input guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability.
|
||||
|
||||
## Tripwires
|
||||
|
||||
|
|
@ -111,8 +111,8 @@ class MessageOutput(BaseModel): # (1)!
|
|||
response: str
|
||||
|
||||
class MathOutput(BaseModel): # (2)!
|
||||
is_math: bool
|
||||
reasoning: str
|
||||
is_math: bool
|
||||
|
||||
guardrail_agent = Agent(
|
||||
name="Guardrail check",
|
||||
|
|
|
|||
3
docs/ref/voice/events.md
Normal file
3
docs/ref/voice/events.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `Events`
|
||||
|
||||
::: agents.voice.events
|
||||
3
docs/ref/voice/exceptions.md
Normal file
3
docs/ref/voice/exceptions.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `Exceptions`
|
||||
|
||||
::: agents.voice.exceptions
|
||||
3
docs/ref/voice/input.md
Normal file
3
docs/ref/voice/input.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `Input`
|
||||
|
||||
::: agents.voice.input
|
||||
3
docs/ref/voice/model.md
Normal file
3
docs/ref/voice/model.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `Model`
|
||||
|
||||
::: agents.voice.model
|
||||
3
docs/ref/voice/models/openai_provider.md
Normal file
3
docs/ref/voice/models/openai_provider.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `OpenAIVoiceModelProvider`
|
||||
|
||||
::: agents.voice.models.openai_model_provider
|
||||
3
docs/ref/voice/models/openai_stt.md
Normal file
3
docs/ref/voice/models/openai_stt.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `OpenAI STT`
|
||||
|
||||
::: agents.voice.models.openai_stt
|
||||
3
docs/ref/voice/models/openai_tts.md
Normal file
3
docs/ref/voice/models/openai_tts.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `OpenAI TTS`
|
||||
|
||||
::: agents.voice.models.openai_tts
|
||||
3
docs/ref/voice/pipeline.md
Normal file
3
docs/ref/voice/pipeline.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `Pipeline`
|
||||
|
||||
::: agents.voice.pipeline
|
||||
3
docs/ref/voice/pipeline_config.md
Normal file
3
docs/ref/voice/pipeline_config.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `Pipeline Config`
|
||||
|
||||
::: agents.voice.pipeline_config
|
||||
3
docs/ref/voice/result.md
Normal file
3
docs/ref/voice/result.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `Result`
|
||||
|
||||
::: agents.voice.result
|
||||
3
docs/ref/voice/utils.md
Normal file
3
docs/ref/voice/utils.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `Utils`
|
||||
|
||||
::: agents.voice.utils
|
||||
3
docs/ref/voice/workflow.md
Normal file
3
docs/ref/voice/workflow.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# `Workflow`
|
||||
|
||||
::: agents.voice.workflow
|
||||
|
|
@ -35,6 +35,9 @@ By default, the SDK traces the following:
|
|||
- Function tool calls are each wrapped in `function_span()`
|
||||
- Guardrails are wrapped in `guardrail_span()`
|
||||
- Handoffs are wrapped in `handoff_span()`
|
||||
- Audio inputs (speech-to-text) are wrapped in a `transcription_span()`
|
||||
- Audio outputs (text-to-speech) are wrapped in a `speech_span()`
|
||||
- Related audio spans may be parented under a `speech_group_span()`
|
||||
|
||||
By default, the trace is named "Agent trace". You can set this name if you use `trace`, or you can can configure the name and other properties with the [`RunConfig`][agents.run.RunConfig].
|
||||
|
||||
|
|
@ -76,7 +79,11 @@ Spans are automatically part of the current trace, and are nested under the near
|
|||
|
||||
## Sensitive data
|
||||
|
||||
Some spans track potentially sensitive data. For example, the `generation_span()` stores the inputs/outputs of the LLM generation, and `function_span()` stores the inputs/outputs of function calls. These may contain sensitive data, so you can disable capturing that data via [`RunConfig.trace_include_sensitive_data`][agents.run.RunConfig.trace_include_sensitive_data].
|
||||
Certain spans may capture potentially sensitive data.
|
||||
|
||||
The `generation_span()` stores the inputs/outputs of the LLM generation, and `function_span()` stores the inputs/outputs of function calls. These may contain sensitive data, so you can disable capturing that data via [`RunConfig.trace_include_sensitive_data`][agents.run.RunConfig.trace_include_sensitive_data].
|
||||
|
||||
Similarly, Audio spans include base64-encoded PCM data for input and output audio by default. You can disable capturing this audio data by configuring [`VoicePipelineConfig.trace_include_sensitive_audio_data`][agents.voice.pipeline_config.VoicePipelineConfig.trace_include_sensitive_audio_data].
|
||||
|
||||
## Custom tracing processors
|
||||
|
||||
|
|
@ -92,10 +99,15 @@ To customize this default setup, to send traces to alternative or additional bac
|
|||
|
||||
## External tracing processors list
|
||||
|
||||
- [Weights & Biases](https://weave-docs.wandb.ai/guides/integrations/openai_agents)
|
||||
- [Arize-Phoenix](https://docs.arize.com/phoenix/tracing/integrations-tracing/openai-agents-sdk)
|
||||
- [MLflow](https://mlflow.org/docs/latest/tracing/integrations/openai-agent)
|
||||
- [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk)
|
||||
- [Pydantic Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents)
|
||||
- [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk)
|
||||
- [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration)
|
||||
- [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent)
|
||||
- [LangSmith](https://docs.smith.langchain.com/observability/how_to_guides/trace_with_openai_agents_sdk)
|
||||
- [Maxim AI](https://www.getmaxim.ai/docs/observe/integrations/openai-agents-sdk)
|
||||
- [Comet Opik](https://www.comet.com/docs/opik/tracing/integrations/openai_agents)
|
||||
- [Langfuse](https://langfuse.com/docs/integrations/openaiagentssdk/openai-agents)
|
||||
|
|
|
|||
75
docs/voice/pipeline.md
Normal file
75
docs/voice/pipeline.md
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
# Pipelines and workflows
|
||||
|
||||
[`VoicePipeline`][agents.voice.pipeline.VoicePipeline] is a class that makes it easy to turn your agentic workflows into a voice app. You pass in a workflow to run, and the pipeline takes care of transcribing input audio, detecting when the audio ends, calling your workflow at the right time, and turning the workflow output back into audio.
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
%% Input
|
||||
A["🎤 Audio Input"]
|
||||
|
||||
%% Voice Pipeline
|
||||
subgraph Voice_Pipeline [Voice Pipeline]
|
||||
direction TB
|
||||
B["Transcribe (speech-to-text)"]
|
||||
C["Your Code"]:::highlight
|
||||
D["Text-to-speech"]
|
||||
B --> C --> D
|
||||
end
|
||||
|
||||
%% Output
|
||||
E["🎧 Audio Output"]
|
||||
|
||||
%% Flow
|
||||
A --> Voice_Pipeline
|
||||
Voice_Pipeline --> E
|
||||
|
||||
%% Custom styling
|
||||
classDef highlight fill:#ffcc66,stroke:#333,stroke-width:1px,font-weight:700;
|
||||
|
||||
```
|
||||
|
||||
## Configuring a pipeline
|
||||
|
||||
When you create a pipeline, you can set a few things:
|
||||
|
||||
1. The [`workflow`][agents.voice.workflow.VoiceWorkflowBase], which is the code that runs each time new audio is transcribed.
|
||||
2. The [`speech-to-text`][agents.voice.model.STTModel] and [`text-to-speech`][agents.voice.model.TTSModel] models used
|
||||
3. The [`config`][agents.voice.pipeline_config.VoicePipelineConfig], which lets you configure things like:
|
||||
- A model provider, which can map model names to models
|
||||
- Tracing, including whether to disable tracing, whether audio files are uploaded, the workflow name, trace IDs etc.
|
||||
- Settings on the TTS and STT models, like the prompt, language and data types used.
|
||||
|
||||
## Running a pipeline
|
||||
|
||||
You can run a pipeline via the [`run()`][agents.voice.pipeline.VoicePipeline.run] method, which lets you pass in audio input in two forms:
|
||||
|
||||
1. [`AudioInput`][agents.voice.input.AudioInput] is used when you have a full audio transcript, and just want to produce a result for it. This is useful in cases where you don't need to detect when a speaker is done speaking; for example, when you have pre-recorded audio or in push-to-talk apps where it's clear when the user is done speaking.
|
||||
2. [`StreamedAudioInput`][agents.voice.input.StreamedAudioInput] is used when you might need to detect when a user is done speaking. It allows you to push audio chunks as they are detected, and the voice pipeline will automatically run the agent workflow at the right time, via a process called "activity detection".
|
||||
|
||||
## Results
|
||||
|
||||
The result of a voice pipeline run is a [`StreamedAudioResult`][agents.voice.result.StreamedAudioResult]. This is an object that lets you stream events as they occur. There are a few kinds of [`VoiceStreamEvent`][agents.voice.events.VoiceStreamEvent], including:
|
||||
|
||||
1. [`VoiceStreamEventAudio`][agents.voice.events.VoiceStreamEventAudio], which contains a chunk of audio.
|
||||
2. [`VoiceStreamEventLifecycle`][agents.voice.events.VoiceStreamEventLifecycle], which informs you of lifecycle events like a turn starting or ending.
|
||||
3. [`VoiceStreamEventError`][agents.voice.events.VoiceStreamEventError], is an error event.
|
||||
|
||||
```python
|
||||
|
||||
result = await pipeline.run(input)
|
||||
|
||||
async for event in result.stream():
|
||||
if event.type == "voice_stream_event_audio":
|
||||
# play audio
|
||||
elif event.type == "voice_stream_event_lifecycle":
|
||||
# lifecycle
|
||||
elif event.type == "voice_stream_event_error"
|
||||
# error
|
||||
...
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
### Interruptions
|
||||
|
||||
The Agents SDK currently does not support any built-in interruptions support for [`StreamedAudioInput`][agents.voice.input.StreamedAudioInput]. Instead for every detected turn it will trigger a separate run of your workflow. If you want to handle interruptions inside your application you can listen to the [`VoiceStreamEventLifecycle`][agents.voice.events.VoiceStreamEventLifecycle] events. `turn_started` will indicate that a new turn was transcribed and processing is beginning. `turn_ended` will trigger after all the audio was dispatched for a respective turn. You could use these events to mute the microphone of the speaker when the model starts a turn and unmute it after you flushed all the related audio for a turn.
|
||||
194
docs/voice/quickstart.md
Normal file
194
docs/voice/quickstart.md
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
# Quickstart
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Make sure you've followed the base [quickstart instructions](../quickstart.md) for the Agents SDK, and set up a virtual environment. Then, install the optional voice dependencies from the SDK:
|
||||
|
||||
```bash
|
||||
pip install 'openai-agents[voice]'
|
||||
```
|
||||
|
||||
## Concepts
|
||||
|
||||
The main concept to know about is a [`VoicePipeline`][agents.voice.pipeline.VoicePipeline], which is a 3 step process:
|
||||
|
||||
1. Run a speech-to-text model to turn audio into text.
|
||||
2. Run your code, which is usually an agentic workflow, to produce a result.
|
||||
3. Run a text-to-speech model to turn the result text back into audio.
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
%% Input
|
||||
A["🎤 Audio Input"]
|
||||
|
||||
%% Voice Pipeline
|
||||
subgraph Voice_Pipeline [Voice Pipeline]
|
||||
direction TB
|
||||
B["Transcribe (speech-to-text)"]
|
||||
C["Your Code"]:::highlight
|
||||
D["Text-to-speech"]
|
||||
B --> C --> D
|
||||
end
|
||||
|
||||
%% Output
|
||||
E["🎧 Audio Output"]
|
||||
|
||||
%% Flow
|
||||
A --> Voice_Pipeline
|
||||
Voice_Pipeline --> E
|
||||
|
||||
%% Custom styling
|
||||
classDef highlight fill:#ffcc66,stroke:#333,stroke-width:1px,font-weight:700;
|
||||
|
||||
```
|
||||
|
||||
## Agents
|
||||
|
||||
First, let's set up some Agents. This should feel familiar to you if you've built any agents with this SDK. We'll have a couple of Agents, a handoff, and a tool.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from agents import (
|
||||
Agent,
|
||||
function_tool,
|
||||
)
|
||||
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
|
||||
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get the weather for a given city."""
|
||||
print(f"[debug] get_weather called with city: {city}")
|
||||
choices = ["sunny", "cloudy", "rainy", "snowy"]
|
||||
return f"The weather in {city} is {random.choice(choices)}."
|
||||
|
||||
|
||||
spanish_agent = Agent(
|
||||
name="Spanish",
|
||||
handoff_description="A spanish speaking agent.",
|
||||
instructions=prompt_with_handoff_instructions(
|
||||
"You're speaking to a human, so be polite and concise. Speak in Spanish.",
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
instructions=prompt_with_handoff_instructions(
|
||||
"You're speaking to a human, so be polite and concise. If the user speaks in Spanish, handoff to the spanish agent.",
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
handoffs=[spanish_agent],
|
||||
tools=[get_weather],
|
||||
)
|
||||
```
|
||||
|
||||
## Voice pipeline
|
||||
|
||||
We'll set up a simple voice pipeline, using [`SingleAgentVoiceWorkflow`][agents.voice.workflow.SingleAgentVoiceWorkflow] as the workflow.
|
||||
|
||||
```python
|
||||
from agents.voice import SingleAgentVoiceWorkflow, VoicePipeline
|
||||
pipeline = VoicePipeline(workflow=SingleAgentVoiceWorkflow(agent))
|
||||
```
|
||||
|
||||
## Run the pipeline
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
from agents.voice import AudioInput
|
||||
|
||||
# For simplicity, we'll just create 3 seconds of silence
|
||||
# In reality, you'd get microphone data
|
||||
buffer = np.zeros(24000 * 3, dtype=np.int16)
|
||||
audio_input = AudioInput(buffer=buffer)
|
||||
|
||||
result = await pipeline.run(audio_input)
|
||||
|
||||
# Create an audio player using `sounddevice`
|
||||
player = sd.OutputStream(samplerate=24000, channels=1, dtype=np.int16)
|
||||
player.start()
|
||||
|
||||
# Play the audio stream as it comes in
|
||||
async for event in result.stream():
|
||||
if event.type == "voice_stream_event_audio":
|
||||
player.write(event.data)
|
||||
|
||||
```
|
||||
|
||||
## Put it all together
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
|
||||
from agents import (
|
||||
Agent,
|
||||
function_tool,
|
||||
set_tracing_disabled,
|
||||
)
|
||||
from agents.voice import (
|
||||
AudioInput,
|
||||
SingleAgentVoiceWorkflow,
|
||||
VoicePipeline,
|
||||
)
|
||||
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get the weather for a given city."""
|
||||
print(f"[debug] get_weather called with city: {city}")
|
||||
choices = ["sunny", "cloudy", "rainy", "snowy"]
|
||||
return f"The weather in {city} is {random.choice(choices)}."
|
||||
|
||||
|
||||
spanish_agent = Agent(
|
||||
name="Spanish",
|
||||
handoff_description="A spanish speaking agent.",
|
||||
instructions=prompt_with_handoff_instructions(
|
||||
"You're speaking to a human, so be polite and concise. Speak in Spanish.",
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
instructions=prompt_with_handoff_instructions(
|
||||
"You're speaking to a human, so be polite and concise. If the user speaks in Spanish, handoff to the spanish agent.",
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
handoffs=[spanish_agent],
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
pipeline = VoicePipeline(workflow=SingleAgentVoiceWorkflow(agent))
|
||||
buffer = np.zeros(24000 * 3, dtype=np.int16)
|
||||
audio_input = AudioInput(buffer=buffer)
|
||||
|
||||
result = await pipeline.run(audio_input)
|
||||
|
||||
# Create an audio player using `sounddevice`
|
||||
player = sd.OutputStream(samplerate=24000, channels=1, dtype=np.int16)
|
||||
player.start()
|
||||
|
||||
# Play the audio stream as it comes in
|
||||
async for event in result.stream():
|
||||
if event.type == "voice_stream_event_audio":
|
||||
player.write(event.data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
If you run this example, the agent will speak to you! Check out the example in [examples/voice/static](https://github.com/openai/openai-agents-python/tree/main/examples/voice/static) to see a demo where you can speak to the agent yourself.
|
||||
14
docs/voice/tracing.md
Normal file
14
docs/voice/tracing.md
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# Tracing
|
||||
|
||||
Just like the way [agents are traced](../tracing.md), voice pipelines are also automatically traced.
|
||||
|
||||
You can read the tracing doc above for basic tracing information, but you can additionally configure tracing of a pipeline via [`VoicePipelineConfig`][agents.voice.pipeline_config.VoicePipelineConfig].
|
||||
|
||||
Key tracing related fields are:
|
||||
|
||||
- [`tracing_disabled`][agents.voice.pipeline_config.VoicePipelineConfig.tracing_disabled]: controls whether tracing is disabled. By default, tracing is enabled.
|
||||
- [`trace_include_sensitive_data`][agents.voice.pipeline_config.VoicePipelineConfig.trace_include_sensitive_data]: controls whether traces include potentially sensitive data, like audio transcripts. This is specifically for the voice pipeline, and not for anything that goes on inside your Workflow.
|
||||
- [`trace_include_sensitive_audio_data`][agents.voice.pipeline_config.VoicePipelineConfig.trace_include_sensitive_audio_data]: controls whether traces include audio data.
|
||||
- [`workflow_name`][agents.voice.pipeline_config.VoicePipelineConfig.workflow_name]: The name of the trace workflow.
|
||||
- [`group_id`][agents.voice.pipeline_config.VoicePipelineConfig.group_id]: The `group_id` of the trace, which lets you link multiple traces.
|
||||
- [`trace_metadata`][agents.voice.pipeline_config.VoicePipelineConfig.tracing_disabled]: Additional metadata to include with the trace.
|
||||
99
examples/agent_patterns/forcing_tool_use.py
Normal file
99
examples/agent_patterns/forcing_tool_use.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agents import (
|
||||
Agent,
|
||||
FunctionToolResult,
|
||||
ModelSettings,
|
||||
RunContextWrapper,
|
||||
Runner,
|
||||
ToolsToFinalOutputFunction,
|
||||
ToolsToFinalOutputResult,
|
||||
function_tool,
|
||||
)
|
||||
|
||||
"""
|
||||
This example shows how to force the agent to use a tool. It uses `ModelSettings(tool_choice="required")`
|
||||
to force the agent to use any tool.
|
||||
|
||||
You can run it with 3 options:
|
||||
1. `default`: The default behavior, which is to send the tool output to the LLM. In this case,
|
||||
`tool_choice` is not set, because otherwise it would result in an infinite loop - the LLM would
|
||||
call the tool, the tool would run and send the results to the LLM, and that would repeat
|
||||
(because the model is forced to use a tool every time.)
|
||||
2. `first_tool_result`: The first tool result is used as the final output.
|
||||
3. `custom`: A custom tool use behavior function is used. The custom function receives all the tool
|
||||
results, and chooses to use the first tool result to generate the final output.
|
||||
|
||||
Usage:
|
||||
python examples/agent_patterns/forcing_tool_use.py -t default
|
||||
python examples/agent_patterns/forcing_tool_use.py -t first_tool
|
||||
python examples/agent_patterns/forcing_tool_use.py -t custom
|
||||
"""
|
||||
|
||||
|
||||
class Weather(BaseModel):
|
||||
city: str
|
||||
temperature_range: str
|
||||
conditions: str
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_weather(city: str) -> Weather:
|
||||
print("[debug] get_weather called")
|
||||
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind")
|
||||
|
||||
|
||||
async def custom_tool_use_behavior(
|
||||
context: RunContextWrapper[Any], results: list[FunctionToolResult]
|
||||
) -> ToolsToFinalOutputResult:
|
||||
weather: Weather = results[0].output
|
||||
return ToolsToFinalOutputResult(
|
||||
is_final_output=True, final_output=f"{weather.city} is {weather.conditions}."
|
||||
)
|
||||
|
||||
|
||||
async def main(tool_use_behavior: Literal["default", "first_tool", "custom"] = "default"):
|
||||
if tool_use_behavior == "default":
|
||||
behavior: Literal["run_llm_again", "stop_on_first_tool"] | ToolsToFinalOutputFunction = (
|
||||
"run_llm_again"
|
||||
)
|
||||
elif tool_use_behavior == "first_tool":
|
||||
behavior = "stop_on_first_tool"
|
||||
elif tool_use_behavior == "custom":
|
||||
behavior = custom_tool_use_behavior
|
||||
|
||||
agent = Agent(
|
||||
name="Weather agent",
|
||||
instructions="You are a helpful agent.",
|
||||
tools=[get_weather],
|
||||
tool_use_behavior=behavior,
|
||||
model_settings=ModelSettings(
|
||||
tool_choice="required" if tool_use_behavior != "default" else None
|
||||
),
|
||||
)
|
||||
|
||||
result = await Runner.run(agent, input="What's the weather in Tokyo?")
|
||||
print(result.final_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tool-use-behavior",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["default", "first_tool", "custom"],
|
||||
help="The behavior to use for tool use. Default will cause tool outputs to be sent to the model. "
|
||||
"first_tool_result will cause the first tool result to be used as the final output. "
|
||||
"custom will use a custom tool use behavior function.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
asyncio.run(main(args.tool_use_behavior))
|
||||
|
|
@ -79,7 +79,7 @@ multiply_agent = Agent(
|
|||
|
||||
start_agent = Agent(
|
||||
name="Start Agent",
|
||||
instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multipler agent.",
|
||||
instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multiplier agent.",
|
||||
tools=[random_number],
|
||||
output_type=FinalResult,
|
||||
handoffs=[multiply_agent],
|
||||
|
|
|
|||
34
examples/basic/tools.py
Normal file
34
examples/basic/tools.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent, Runner, function_tool
|
||||
|
||||
|
||||
class Weather(BaseModel):
|
||||
city: str
|
||||
temperature_range: str
|
||||
conditions: str
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_weather(city: str) -> Weather:
|
||||
print("[debug] get_weather called")
|
||||
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
name="Hello world",
|
||||
instructions="You are a helpful agent.",
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
result = await Runner.run(agent, input="What's the weather in Tokyo?")
|
||||
print(result.final_output)
|
||||
# The weather in Tokyo is sunny.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
38
examples/financial_research_agent/README.md
Normal file
38
examples/financial_research_agent/README.md
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
# Financial Research Agent Example
|
||||
|
||||
This example shows how you might compose a richer financial research agent using the Agents SDK. The pattern is similar to the `research_bot` example, but with more specialized sub‑agents and a verification step.
|
||||
|
||||
The flow is:
|
||||
|
||||
1. **Planning**: A planner agent turns the end user’s request into a list of search terms relevant to financial analysis – recent news, earnings calls, corporate filings, industry commentary, etc.
|
||||
2. **Search**: A search agent uses the built‑in `WebSearchTool` to retrieve terse summaries for each search term. (You could also add `FileSearchTool` if you have indexed PDFs or 10‑Ks.)
|
||||
3. **Sub‑analysts**: Additional agents (e.g. a fundamentals analyst and a risk analyst) are exposed as tools so the writer can call them inline and incorporate their outputs.
|
||||
4. **Writing**: A senior writer agent brings together the search snippets and any sub‑analyst summaries into a long‑form markdown report plus a short executive summary.
|
||||
5. **Verification**: A final verifier agent audits the report for obvious inconsistencies or missing sourcing.
|
||||
|
||||
You can run the example with:
|
||||
|
||||
```bash
|
||||
python -m examples.financial_research_agent.main
|
||||
```
|
||||
|
||||
and enter a query like:
|
||||
|
||||
```
|
||||
Write up an analysis of Apple Inc.'s most recent quarter.
|
||||
```
|
||||
|
||||
### Starter prompt
|
||||
|
||||
The writer agent is seeded with instructions similar to:
|
||||
|
||||
```
|
||||
You are a senior financial analyst. You will be provided with the original query
|
||||
and a set of raw search summaries. Your job is to synthesize these into a
|
||||
long‑form markdown report (at least several paragraphs) with a short executive
|
||||
summary. You also have access to tools like `fundamentals_analysis` and
|
||||
`risk_analysis` to get short specialist write‑ups if you want to incorporate them.
|
||||
Add a few follow‑up questions for further research.
|
||||
```
|
||||
|
||||
You can tweak these prompts and sub‑agents to suit your own data sources and preferred report structure.
|
||||
0
examples/financial_research_agent/__init__.py
Normal file
0
examples/financial_research_agent/__init__.py
Normal file
0
examples/financial_research_agent/agents/__init__.py
Normal file
0
examples/financial_research_agent/agents/__init__.py
Normal file
23
examples/financial_research_agent/agents/financials_agent.py
Normal file
23
examples/financial_research_agent/agents/financials_agent.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent
|
||||
|
||||
# A sub‑agent focused on analyzing a company's fundamentals.
|
||||
FINANCIALS_PROMPT = (
|
||||
"You are a financial analyst focused on company fundamentals such as revenue, "
|
||||
"profit, margins and growth trajectory. Given a collection of web (and optional file) "
|
||||
"search results about a company, write a concise analysis of its recent financial "
|
||||
"performance. Pull out key metrics or quotes. Keep it under 2 paragraphs."
|
||||
)
|
||||
|
||||
|
||||
class AnalysisSummary(BaseModel):
|
||||
summary: str
|
||||
"""Short text summary for this aspect of the analysis."""
|
||||
|
||||
|
||||
financials_agent = Agent(
|
||||
name="FundamentalsAnalystAgent",
|
||||
instructions=FINANCIALS_PROMPT,
|
||||
output_type=AnalysisSummary,
|
||||
)
|
||||
35
examples/financial_research_agent/agents/planner_agent.py
Normal file
35
examples/financial_research_agent/agents/planner_agent.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent
|
||||
|
||||
# Generate a plan of searches to ground the financial analysis.
|
||||
# For a given financial question or company, we want to search for
|
||||
# recent news, official filings, analyst commentary, and other
|
||||
# relevant background.
|
||||
PROMPT = (
|
||||
"You are a financial research planner. Given a request for financial analysis, "
|
||||
"produce a set of web searches to gather the context needed. Aim for recent "
|
||||
"headlines, earnings calls or 10‑K snippets, analyst commentary, and industry background. "
|
||||
"Output between 5 and 15 search terms to query for."
|
||||
)
|
||||
|
||||
|
||||
class FinancialSearchItem(BaseModel):
|
||||
reason: str
|
||||
"""Your reasoning for why this search is relevant."""
|
||||
|
||||
query: str
|
||||
"""The search term to feed into a web (or file) search."""
|
||||
|
||||
|
||||
class FinancialSearchPlan(BaseModel):
|
||||
searches: list[FinancialSearchItem]
|
||||
"""A list of searches to perform."""
|
||||
|
||||
|
||||
planner_agent = Agent(
|
||||
name="FinancialPlannerAgent",
|
||||
instructions=PROMPT,
|
||||
model="o3-mini",
|
||||
output_type=FinancialSearchPlan,
|
||||
)
|
||||
22
examples/financial_research_agent/agents/risk_agent.py
Normal file
22
examples/financial_research_agent/agents/risk_agent.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent
|
||||
|
||||
# A sub‑agent specializing in identifying risk factors or concerns.
|
||||
RISK_PROMPT = (
|
||||
"You are a risk analyst looking for potential red flags in a company's outlook. "
|
||||
"Given background research, produce a short analysis of risks such as competitive threats, "
|
||||
"regulatory issues, supply chain problems, or slowing growth. Keep it under 2 paragraphs."
|
||||
)
|
||||
|
||||
|
||||
class AnalysisSummary(BaseModel):
|
||||
summary: str
|
||||
"""Short text summary for this aspect of the analysis."""
|
||||
|
||||
|
||||
risk_agent = Agent(
|
||||
name="RiskAnalystAgent",
|
||||
instructions=RISK_PROMPT,
|
||||
output_type=AnalysisSummary,
|
||||
)
|
||||
18
examples/financial_research_agent/agents/search_agent.py
Normal file
18
examples/financial_research_agent/agents/search_agent.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from agents import Agent, WebSearchTool
|
||||
from agents.model_settings import ModelSettings
|
||||
|
||||
# Given a search term, use web search to pull back a brief summary.
|
||||
# Summaries should be concise but capture the main financial points.
|
||||
INSTRUCTIONS = (
|
||||
"You are a research assistant specializing in financial topics. "
|
||||
"Given a search term, use web search to retrieve up‑to‑date context and "
|
||||
"produce a short summary of at most 300 words. Focus on key numbers, events, "
|
||||
"or quotes that will be useful to a financial analyst."
|
||||
)
|
||||
|
||||
search_agent = Agent(
|
||||
name="FinancialSearchAgent",
|
||||
instructions=INSTRUCTIONS,
|
||||
tools=[WebSearchTool()],
|
||||
model_settings=ModelSettings(tool_choice="required"),
|
||||
)
|
||||
27
examples/financial_research_agent/agents/verifier_agent.py
Normal file
27
examples/financial_research_agent/agents/verifier_agent.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent
|
||||
|
||||
# Agent to sanity‑check a synthesized report for consistency and recall.
|
||||
# This can be used to flag potential gaps or obvious mistakes.
|
||||
VERIFIER_PROMPT = (
|
||||
"You are a meticulous auditor. You have been handed a financial analysis report. "
|
||||
"Your job is to verify the report is internally consistent, clearly sourced, and makes "
|
||||
"no unsupported claims. Point out any issues or uncertainties."
|
||||
)
|
||||
|
||||
|
||||
class VerificationResult(BaseModel):
|
||||
verified: bool
|
||||
"""Whether the report seems coherent and plausible."""
|
||||
|
||||
issues: str
|
||||
"""If not verified, describe the main issues or concerns."""
|
||||
|
||||
|
||||
verifier_agent = Agent(
|
||||
name="VerificationAgent",
|
||||
instructions=VERIFIER_PROMPT,
|
||||
model="gpt-4o",
|
||||
output_type=VerificationResult,
|
||||
)
|
||||
34
examples/financial_research_agent/agents/writer_agent.py
Normal file
34
examples/financial_research_agent/agents/writer_agent.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent
|
||||
|
||||
# Writer agent brings together the raw search results and optionally calls out
|
||||
# to sub‑analyst tools for specialized commentary, then returns a cohesive markdown report.
|
||||
WRITER_PROMPT = (
|
||||
"You are a senior financial analyst. You will be provided with the original query and "
|
||||
"a set of raw search summaries. Your task is to synthesize these into a long‑form markdown "
|
||||
"report (at least several paragraphs) including a short executive summary and follow‑up "
|
||||
"questions. If needed, you can call the available analysis tools (e.g. fundamentals_analysis, "
|
||||
"risk_analysis) to get short specialist write‑ups to incorporate."
|
||||
)
|
||||
|
||||
|
||||
class FinancialReportData(BaseModel):
|
||||
short_summary: str
|
||||
"""A short 2‑3 sentence executive summary."""
|
||||
|
||||
markdown_report: str
|
||||
"""The full markdown report."""
|
||||
|
||||
follow_up_questions: list[str]
|
||||
"""Suggested follow‑up questions for further research."""
|
||||
|
||||
|
||||
# Note: We will attach handoffs to specialist analyst agents at runtime in the manager.
|
||||
# This shows how an agent can use handoffs to delegate to specialized subagents.
|
||||
writer_agent = Agent(
|
||||
name="FinancialWriterAgent",
|
||||
instructions=WRITER_PROMPT,
|
||||
model="gpt-4.5-preview-2025-02-27",
|
||||
output_type=FinancialReportData,
|
||||
)
|
||||
17
examples/financial_research_agent/main.py
Normal file
17
examples/financial_research_agent/main.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import asyncio
|
||||
|
||||
from .manager import FinancialResearchManager
|
||||
|
||||
|
||||
# Entrypoint for the financial bot example.
|
||||
# Run this as `python -m examples.financial_bot.main` and enter a
|
||||
# financial research query, for example:
|
||||
# "Write up an analysis of Apple Inc.'s most recent quarter."
|
||||
async def main() -> None:
|
||||
query = input("Enter a financial research query: ")
|
||||
mgr = FinancialResearchManager()
|
||||
await mgr.run(query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
135
examples/financial_research_agent/manager.py
Normal file
135
examples/financial_research_agent/manager.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from agents import Runner, RunResult, custom_span, gen_trace_id, trace
|
||||
|
||||
from .agents.financials_agent import financials_agent
|
||||
from .agents.planner_agent import FinancialSearchItem, FinancialSearchPlan, planner_agent
|
||||
from .agents.risk_agent import risk_agent
|
||||
from .agents.search_agent import search_agent
|
||||
from .agents.verifier_agent import VerificationResult, verifier_agent
|
||||
from .agents.writer_agent import FinancialReportData, writer_agent
|
||||
from .printer import Printer
|
||||
|
||||
|
||||
async def _summary_extractor(run_result: RunResult) -> str:
|
||||
"""Custom output extractor for sub‑agents that return an AnalysisSummary."""
|
||||
# The financial/risk analyst agents emit an AnalysisSummary with a `summary` field.
|
||||
# We want the tool call to return just that summary text so the writer can drop it inline.
|
||||
return str(run_result.final_output.summary)
|
||||
|
||||
|
||||
class FinancialResearchManager:
|
||||
"""
|
||||
Orchestrates the full flow: planning, searching, sub‑analysis, writing, and verification.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.console = Console()
|
||||
self.printer = Printer(self.console)
|
||||
|
||||
async def run(self, query: str) -> None:
|
||||
trace_id = gen_trace_id()
|
||||
with trace("Financial research trace", trace_id=trace_id):
|
||||
self.printer.update_item(
|
||||
"trace_id",
|
||||
f"View trace: https://platform.openai.com/traces/{trace_id}",
|
||||
is_done=True,
|
||||
hide_checkmark=True,
|
||||
)
|
||||
self.printer.update_item("start", "Starting financial research...", is_done=True)
|
||||
search_plan = await self._plan_searches(query)
|
||||
search_results = await self._perform_searches(search_plan)
|
||||
report = await self._write_report(query, search_results)
|
||||
verification = await self._verify_report(report)
|
||||
|
||||
final_report = f"Report summary\n\n{report.short_summary}"
|
||||
self.printer.update_item("final_report", final_report, is_done=True)
|
||||
|
||||
self.printer.end()
|
||||
|
||||
# Print to stdout
|
||||
print("\n\n=====REPORT=====\n\n")
|
||||
print(f"Report:\n{report.markdown_report}")
|
||||
print("\n\n=====FOLLOW UP QUESTIONS=====\n\n")
|
||||
print("\n".join(report.follow_up_questions))
|
||||
print("\n\n=====VERIFICATION=====\n\n")
|
||||
print(verification)
|
||||
|
||||
async def _plan_searches(self, query: str) -> FinancialSearchPlan:
|
||||
self.printer.update_item("planning", "Planning searches...")
|
||||
result = await Runner.run(planner_agent, f"Query: {query}")
|
||||
self.printer.update_item(
|
||||
"planning",
|
||||
f"Will perform {len(result.final_output.searches)} searches",
|
||||
is_done=True,
|
||||
)
|
||||
return result.final_output_as(FinancialSearchPlan)
|
||||
|
||||
async def _perform_searches(self, search_plan: FinancialSearchPlan) -> Sequence[str]:
|
||||
with custom_span("Search the web"):
|
||||
self.printer.update_item("searching", "Searching...")
|
||||
tasks = [asyncio.create_task(self._search(item)) for item in search_plan.searches]
|
||||
results: list[str] = []
|
||||
num_completed = 0
|
||||
for task in asyncio.as_completed(tasks):
|
||||
result = await task
|
||||
if result is not None:
|
||||
results.append(result)
|
||||
num_completed += 1
|
||||
self.printer.update_item(
|
||||
"searching", f"Searching... {num_completed}/{len(tasks)} completed"
|
||||
)
|
||||
self.printer.mark_item_done("searching")
|
||||
return results
|
||||
|
||||
async def _search(self, item: FinancialSearchItem) -> str | None:
|
||||
input_data = f"Search term: {item.query}\nReason: {item.reason}"
|
||||
try:
|
||||
result = await Runner.run(search_agent, input_data)
|
||||
return str(result.final_output)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _write_report(self, query: str, search_results: Sequence[str]) -> FinancialReportData:
|
||||
# Expose the specialist analysts as tools so the writer can invoke them inline
|
||||
# and still produce the final FinancialReportData output.
|
||||
fundamentals_tool = financials_agent.as_tool(
|
||||
tool_name="fundamentals_analysis",
|
||||
tool_description="Use to get a short write‑up of key financial metrics",
|
||||
custom_output_extractor=_summary_extractor,
|
||||
)
|
||||
risk_tool = risk_agent.as_tool(
|
||||
tool_name="risk_analysis",
|
||||
tool_description="Use to get a short write‑up of potential red flags",
|
||||
custom_output_extractor=_summary_extractor,
|
||||
)
|
||||
writer_with_tools = writer_agent.clone(tools=[fundamentals_tool, risk_tool])
|
||||
self.printer.update_item("writing", "Thinking about report...")
|
||||
input_data = f"Original query: {query}\nSummarized search results: {search_results}"
|
||||
result = Runner.run_streamed(writer_with_tools, input_data)
|
||||
update_messages = [
|
||||
"Planning report structure...",
|
||||
"Writing sections...",
|
||||
"Finalizing report...",
|
||||
]
|
||||
last_update = time.time()
|
||||
next_message = 0
|
||||
async for _ in result.stream_events():
|
||||
if time.time() - last_update > 5 and next_message < len(update_messages):
|
||||
self.printer.update_item("writing", update_messages[next_message])
|
||||
next_message += 1
|
||||
last_update = time.time()
|
||||
self.printer.mark_item_done("writing")
|
||||
return result.final_output_as(FinancialReportData)
|
||||
|
||||
async def _verify_report(self, report: FinancialReportData) -> VerificationResult:
|
||||
self.printer.update_item("verifying", "Verifying report...")
|
||||
result = await Runner.run(verifier_agent, report.markdown_report)
|
||||
self.printer.mark_item_done("verifying")
|
||||
return result.final_output_as(VerificationResult)
|
||||
46
examples/financial_research_agent/printer.py
Normal file
46
examples/financial_research_agent/printer.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
from typing import Any
|
||||
|
||||
from rich.console import Console, Group
|
||||
from rich.live import Live
|
||||
from rich.spinner import Spinner
|
||||
|
||||
|
||||
class Printer:
|
||||
"""
|
||||
Simple wrapper to stream status updates. Used by the financial bot
|
||||
manager as it orchestrates planning, search and writing.
|
||||
"""
|
||||
|
||||
def __init__(self, console: Console) -> None:
|
||||
self.live = Live(console=console)
|
||||
self.items: dict[str, tuple[str, bool]] = {}
|
||||
self.hide_done_ids: set[str] = set()
|
||||
self.live.start()
|
||||
|
||||
def end(self) -> None:
|
||||
self.live.stop()
|
||||
|
||||
def hide_done_checkmark(self, item_id: str) -> None:
|
||||
self.hide_done_ids.add(item_id)
|
||||
|
||||
def update_item(
|
||||
self, item_id: str, content: str, is_done: bool = False, hide_checkmark: bool = False
|
||||
) -> None:
|
||||
self.items[item_id] = (content, is_done)
|
||||
if hide_checkmark:
|
||||
self.hide_done_ids.add(item_id)
|
||||
self.flush()
|
||||
|
||||
def mark_item_done(self, item_id: str) -> None:
|
||||
self.items[item_id] = (self.items[item_id][0], True)
|
||||
self.flush()
|
||||
|
||||
def flush(self) -> None:
|
||||
renderables: list[Any] = []
|
||||
for item_id, (content, is_done) in self.items.items():
|
||||
if is_done:
|
||||
prefix = "✅ " if item_id not in self.hide_done_ids else ""
|
||||
renderables.append(prefix + content)
|
||||
else:
|
||||
renderables.append(Spinner("dots", text=content))
|
||||
self.live.update(Group(*renderables))
|
||||
|
|
@ -4,7 +4,7 @@ from agents.model_settings import ModelSettings
|
|||
INSTRUCTIONS = (
|
||||
"You are a research assistant. Given a search term, you search the web for that term and"
|
||||
"produce a concise summary of the results. The summary must 2-3 paragraphs and less than 300"
|
||||
"words. Capture the main points. Write succintly, no need to have complete sentences or good"
|
||||
"words. Capture the main points. Write succinctly, no need to have complete sentences or good"
|
||||
"grammar. This will be consumed by someone synthesizing a report, so its vital you capture the"
|
||||
"essence and ignore any fluff. Do not include any additional commentary other than the summary"
|
||||
"itself."
|
||||
|
|
|
|||
0
examples/voice/__init__.py
Normal file
0
examples/voice/__init__.py
Normal file
26
examples/voice/static/README.md
Normal file
26
examples/voice/static/README.md
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
# Static voice demo
|
||||
|
||||
This demo operates by capturing a recording, then running a voice pipeline on it.
|
||||
|
||||
Run via:
|
||||
|
||||
```
|
||||
python -m examples.voice.static.main
|
||||
```
|
||||
|
||||
## How it works
|
||||
|
||||
1. We create a `VoicePipeline`, setup with a custom workflow. The workflow runs an Agent, but it also has some custom responses if you say the secret word.
|
||||
2. When you speak, audio is forwarded to the voice pipeline. When you stop speaking, the agent runs.
|
||||
3. The pipeline is run with the audio, which causes it to:
|
||||
1. Transcribe the audio
|
||||
2. Feed the transcription to the workflow, which runs the agent.
|
||||
3. Stream the output of the agent to a text-to-speech model.
|
||||
4. Play the audio.
|
||||
|
||||
Some suggested examples to try:
|
||||
|
||||
- Tell me a joke (_the assistant tells you a joke_)
|
||||
- What's the weather in Tokyo? (_will call the `get_weather` tool and then speak_)
|
||||
- Hola, como estas? (_will handoff to the spanish agent_)
|
||||
- Tell me about dogs. (_will respond with the hardcoded "you guessed the secret word" message_)
|
||||
0
examples/voice/static/__init__.py
Normal file
0
examples/voice/static/__init__.py
Normal file
83
examples/voice/static/main.py
Normal file
83
examples/voice/static/main.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import asyncio
|
||||
import random
|
||||
|
||||
from agents import Agent, function_tool
|
||||
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
|
||||
from agents.voice import (
|
||||
AudioInput,
|
||||
SingleAgentVoiceWorkflow,
|
||||
SingleAgentWorkflowCallbacks,
|
||||
VoicePipeline,
|
||||
)
|
||||
|
||||
from .util import AudioPlayer, record_audio
|
||||
|
||||
"""
|
||||
This is a simple example that uses a recorded audio buffer. Run it via:
|
||||
`python -m examples.voice.static.main`
|
||||
|
||||
1. You can record an audio clip in the terminal.
|
||||
2. The pipeline automatically transcribes the audio.
|
||||
3. The agent workflow is a simple one that starts at the Assistant agent.
|
||||
4. The output of the agent is streamed to the audio player.
|
||||
|
||||
Try examples like:
|
||||
- Tell me a joke (will respond with a joke)
|
||||
- What's the weather in Tokyo? (will call the `get_weather` tool and then speak)
|
||||
- Hola, como estas? (will handoff to the spanish agent)
|
||||
"""
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get the weather for a given city."""
|
||||
print(f"[debug] get_weather called with city: {city}")
|
||||
choices = ["sunny", "cloudy", "rainy", "snowy"]
|
||||
return f"The weather in {city} is {random.choice(choices)}."
|
||||
|
||||
|
||||
spanish_agent = Agent(
|
||||
name="Spanish",
|
||||
handoff_description="A spanish speaking agent.",
|
||||
instructions=prompt_with_handoff_instructions(
|
||||
"You're speaking to a human, so be polite and concise. Speak in Spanish.",
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
instructions=prompt_with_handoff_instructions(
|
||||
"You're speaking to a human, so be polite and concise. If the user speaks in Spanish, handoff to the spanish agent.",
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
handoffs=[spanish_agent],
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
|
||||
class WorkflowCallbacks(SingleAgentWorkflowCallbacks):
|
||||
def on_run(self, workflow: SingleAgentVoiceWorkflow, transcription: str) -> None:
|
||||
print(f"[debug] on_run called with transcription: {transcription}")
|
||||
|
||||
|
||||
async def main():
|
||||
pipeline = VoicePipeline(
|
||||
workflow=SingleAgentVoiceWorkflow(agent, callbacks=WorkflowCallbacks())
|
||||
)
|
||||
|
||||
audio_input = AudioInput(buffer=record_audio())
|
||||
|
||||
result = await pipeline.run(audio_input)
|
||||
|
||||
with AudioPlayer() as player:
|
||||
async for event in result.stream():
|
||||
if event.type == "voice_stream_event_audio":
|
||||
player.add_audio(event.data)
|
||||
print("Received audio")
|
||||
elif event.type == "voice_stream_event_lifecycle":
|
||||
print(f"Received lifecycle event: {event.event}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
68
examples/voice/static/util.py
Normal file
68
examples/voice/static/util.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
import curses
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import sounddevice as sd
|
||||
|
||||
|
||||
def _record_audio(screen: curses.window) -> npt.NDArray[np.float32]:
|
||||
screen.nodelay(True) # Non-blocking input
|
||||
screen.clear()
|
||||
screen.addstr(
|
||||
"Press <spacebar> to start recording. Press <spacebar> again to stop recording.\n"
|
||||
)
|
||||
screen.refresh()
|
||||
|
||||
recording = False
|
||||
audio_buffer: list[npt.NDArray[np.float32]] = []
|
||||
|
||||
def _audio_callback(indata, frames, time_info, status):
|
||||
if status:
|
||||
screen.addstr(f"Status: {status}\n")
|
||||
screen.refresh()
|
||||
if recording:
|
||||
audio_buffer.append(indata.copy())
|
||||
|
||||
# Open the audio stream with the callback.
|
||||
with sd.InputStream(samplerate=24000, channels=1, dtype=np.float32, callback=_audio_callback):
|
||||
while True:
|
||||
key = screen.getch()
|
||||
if key == ord(" "):
|
||||
recording = not recording
|
||||
if recording:
|
||||
screen.addstr("Recording started...\n")
|
||||
else:
|
||||
screen.addstr("Recording stopped.\n")
|
||||
break
|
||||
screen.refresh()
|
||||
time.sleep(0.01)
|
||||
|
||||
# Combine recorded audio chunks.
|
||||
if audio_buffer:
|
||||
audio_data = np.concatenate(audio_buffer, axis=0)
|
||||
else:
|
||||
audio_data = np.empty((0,), dtype=np.float32)
|
||||
|
||||
return audio_data
|
||||
|
||||
|
||||
def record_audio():
|
||||
# Using curses to record audio in a way that:
|
||||
# - doesn't require accessibility permissions on macos
|
||||
# - doesn't block the terminal
|
||||
audio_data = curses.wrapper(_record_audio)
|
||||
return audio_data
|
||||
|
||||
|
||||
class AudioPlayer:
|
||||
def __enter__(self):
|
||||
self.stream = sd.OutputStream(samplerate=24000, channels=1, dtype=np.int16)
|
||||
self.stream.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.stream.close()
|
||||
|
||||
def add_audio(self, audio_data: npt.NDArray[np.int16]):
|
||||
self.stream.write(audio_data)
|
||||
25
examples/voice/streamed/README.md
Normal file
25
examples/voice/streamed/README.md
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
# Streamed voice demo
|
||||
|
||||
This is an interactive demo, where you can talk to an Agent conversationally. It uses the voice pipeline's built in turn detection feature, so if you stop speaking the Agent responds.
|
||||
|
||||
Run via:
|
||||
|
||||
```
|
||||
python -m examples.voice.streamed.main
|
||||
```
|
||||
|
||||
## How it works
|
||||
|
||||
1. We create a `VoicePipeline`, setup with a `SingleAgentVoiceWorkflow`. This is a workflow that starts at an Assistant agent, has tools and handoffs.
|
||||
2. Audio input is captured from the terminal.
|
||||
3. The pipeline is run with the recorded audio, which causes it to:
|
||||
1. Transcribe the audio
|
||||
2. Feed the transcription to the workflow, which runs the agent.
|
||||
3. Stream the output of the agent to a text-to-speech model.
|
||||
4. Play the audio.
|
||||
|
||||
Some suggested examples to try:
|
||||
|
||||
- Tell me a joke (_the assistant tells you a joke_)
|
||||
- What's the weather in Tokyo? (_will call the `get_weather` tool and then speak_)
|
||||
- Hola, como estas? (_will handoff to the spanish agent_)
|
||||
0
examples/voice/streamed/__init__.py
Normal file
0
examples/voice/streamed/__init__.py
Normal file
233
examples/voice/streamed/main.py
Normal file
233
examples/voice/streamed/main.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
from textual import events
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.containers import Container
|
||||
from textual.reactive import reactive
|
||||
from textual.widgets import Button, RichLog, Static
|
||||
from typing_extensions import override
|
||||
|
||||
from agents.voice import StreamedAudioInput, VoicePipeline
|
||||
|
||||
# Import MyWorkflow class - handle both module and package use cases
|
||||
if TYPE_CHECKING:
|
||||
# For type checking, use the relative import
|
||||
from .my_workflow import MyWorkflow
|
||||
else:
|
||||
# At runtime, try both import styles
|
||||
try:
|
||||
# Try relative import first (when used as a package)
|
||||
from .my_workflow import MyWorkflow
|
||||
except ImportError:
|
||||
# Fall back to direct import (when run as a script)
|
||||
from my_workflow import MyWorkflow
|
||||
|
||||
CHUNK_LENGTH_S = 0.05 # 100ms
|
||||
SAMPLE_RATE = 24000
|
||||
FORMAT = np.int16
|
||||
CHANNELS = 1
|
||||
|
||||
|
||||
class Header(Static):
|
||||
"""A header widget."""
|
||||
|
||||
session_id = reactive("")
|
||||
|
||||
@override
|
||||
def render(self) -> str:
|
||||
return "Speak to the agent. When you stop speaking, it will respond."
|
||||
|
||||
|
||||
class AudioStatusIndicator(Static):
|
||||
"""A widget that shows the current audio recording status."""
|
||||
|
||||
is_recording = reactive(False)
|
||||
|
||||
@override
|
||||
def render(self) -> str:
|
||||
status = (
|
||||
"🔴 Recording... (Press K to stop)"
|
||||
if self.is_recording
|
||||
else "⚪ Press K to start recording (Q to quit)"
|
||||
)
|
||||
return status
|
||||
|
||||
|
||||
class RealtimeApp(App[None]):
|
||||
CSS = """
|
||||
Screen {
|
||||
background: #1a1b26; /* Dark blue-grey background */
|
||||
}
|
||||
|
||||
Container {
|
||||
border: double rgb(91, 164, 91);
|
||||
}
|
||||
|
||||
Horizontal {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
#input-container {
|
||||
height: 5; /* Explicit height for input container */
|
||||
margin: 1 1;
|
||||
padding: 1 2;
|
||||
}
|
||||
|
||||
Input {
|
||||
width: 80%;
|
||||
height: 3; /* Explicit height for input */
|
||||
}
|
||||
|
||||
Button {
|
||||
width: 20%;
|
||||
height: 3; /* Explicit height for button */
|
||||
}
|
||||
|
||||
#bottom-pane {
|
||||
width: 100%;
|
||||
height: 82%; /* Reduced to make room for session display */
|
||||
border: round rgb(205, 133, 63);
|
||||
content-align: center middle;
|
||||
}
|
||||
|
||||
#status-indicator {
|
||||
height: 3;
|
||||
content-align: center middle;
|
||||
background: #2a2b36;
|
||||
border: solid rgb(91, 164, 91);
|
||||
margin: 1 1;
|
||||
}
|
||||
|
||||
#session-display {
|
||||
height: 3;
|
||||
content-align: center middle;
|
||||
background: #2a2b36;
|
||||
border: solid rgb(91, 164, 91);
|
||||
margin: 1 1;
|
||||
}
|
||||
|
||||
Static {
|
||||
color: white;
|
||||
}
|
||||
"""
|
||||
|
||||
should_send_audio: asyncio.Event
|
||||
audio_player: sd.OutputStream
|
||||
last_audio_item_id: str | None
|
||||
connected: asyncio.Event
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.last_audio_item_id = None
|
||||
self.should_send_audio = asyncio.Event()
|
||||
self.connected = asyncio.Event()
|
||||
self.pipeline = VoicePipeline(
|
||||
workflow=MyWorkflow(secret_word="dog", on_start=self._on_transcription)
|
||||
)
|
||||
self._audio_input = StreamedAudioInput()
|
||||
self.audio_player = sd.OutputStream(
|
||||
samplerate=SAMPLE_RATE,
|
||||
channels=CHANNELS,
|
||||
dtype=FORMAT,
|
||||
)
|
||||
|
||||
def _on_transcription(self, transcription: str) -> None:
|
||||
try:
|
||||
self.query_one("#bottom-pane", RichLog).write(f"Transcription: {transcription}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@override
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Create child widgets for the app."""
|
||||
with Container():
|
||||
yield Header(id="session-display")
|
||||
yield AudioStatusIndicator(id="status-indicator")
|
||||
yield RichLog(id="bottom-pane", wrap=True, highlight=True, markup=True)
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
self.run_worker(self.start_voice_pipeline())
|
||||
self.run_worker(self.send_mic_audio())
|
||||
|
||||
async def start_voice_pipeline(self) -> None:
|
||||
try:
|
||||
self.audio_player.start()
|
||||
self.result = await self.pipeline.run(self._audio_input)
|
||||
|
||||
async for event in self.result.stream():
|
||||
bottom_pane = self.query_one("#bottom-pane", RichLog)
|
||||
if event.type == "voice_stream_event_audio":
|
||||
self.audio_player.write(event.data)
|
||||
bottom_pane.write(
|
||||
f"Received audio: {len(event.data) if event.data is not None else '0'} bytes"
|
||||
)
|
||||
elif event.type == "voice_stream_event_lifecycle":
|
||||
bottom_pane.write(f"Lifecycle event: {event.event}")
|
||||
except Exception as e:
|
||||
bottom_pane = self.query_one("#bottom-pane", RichLog)
|
||||
bottom_pane.write(f"Error: {e}")
|
||||
finally:
|
||||
self.audio_player.close()
|
||||
|
||||
async def send_mic_audio(self) -> None:
|
||||
device_info = sd.query_devices()
|
||||
print(device_info)
|
||||
|
||||
read_size = int(SAMPLE_RATE * 0.02)
|
||||
|
||||
stream = sd.InputStream(
|
||||
channels=CHANNELS,
|
||||
samplerate=SAMPLE_RATE,
|
||||
dtype="int16",
|
||||
)
|
||||
stream.start()
|
||||
|
||||
status_indicator = self.query_one(AudioStatusIndicator)
|
||||
|
||||
try:
|
||||
while True:
|
||||
if stream.read_available < read_size:
|
||||
await asyncio.sleep(0)
|
||||
continue
|
||||
|
||||
await self.should_send_audio.wait()
|
||||
status_indicator.is_recording = True
|
||||
|
||||
data, _ = stream.read(read_size)
|
||||
|
||||
await self._audio_input.add_audio(data)
|
||||
await asyncio.sleep(0)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
stream.stop()
|
||||
stream.close()
|
||||
|
||||
async def on_key(self, event: events.Key) -> None:
|
||||
"""Handle key press events."""
|
||||
if event.key == "enter":
|
||||
self.query_one(Button).press()
|
||||
return
|
||||
|
||||
if event.key == "q":
|
||||
self.exit()
|
||||
return
|
||||
|
||||
if event.key == "k":
|
||||
status_indicator = self.query_one(AudioStatusIndicator)
|
||||
if status_indicator.is_recording:
|
||||
self.should_send_audio.clear()
|
||||
status_indicator.is_recording = False
|
||||
else:
|
||||
self.should_send_audio.set()
|
||||
status_indicator.is_recording = True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = RealtimeApp()
|
||||
app.run()
|
||||
81
examples/voice/streamed/my_workflow.py
Normal file
81
examples/voice/streamed/my_workflow.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import random
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Callable
|
||||
|
||||
from agents import Agent, Runner, TResponseInputItem, function_tool
|
||||
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
|
||||
from agents.voice import VoiceWorkflowBase, VoiceWorkflowHelper
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get the weather for a given city."""
|
||||
print(f"[debug] get_weather called with city: {city}")
|
||||
choices = ["sunny", "cloudy", "rainy", "snowy"]
|
||||
return f"The weather in {city} is {random.choice(choices)}."
|
||||
|
||||
|
||||
spanish_agent = Agent(
|
||||
name="Spanish",
|
||||
handoff_description="A spanish speaking agent.",
|
||||
instructions=prompt_with_handoff_instructions(
|
||||
"You're speaking to a human, so be polite and concise. Speak in Spanish.",
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
instructions=prompt_with_handoff_instructions(
|
||||
"You're speaking to a human, so be polite and concise. If the user speaks in Spanish, handoff to the spanish agent.",
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
handoffs=[spanish_agent],
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
|
||||
class MyWorkflow(VoiceWorkflowBase):
|
||||
def __init__(self, secret_word: str, on_start: Callable[[str], None]):
|
||||
"""
|
||||
Args:
|
||||
secret_word: The secret word to guess.
|
||||
on_start: A callback that is called when the workflow starts. The transcription
|
||||
is passed in as an argument.
|
||||
"""
|
||||
self._input_history: list[TResponseInputItem] = []
|
||||
self._current_agent = agent
|
||||
self._secret_word = secret_word.lower()
|
||||
self._on_start = on_start
|
||||
|
||||
async def run(self, transcription: str) -> AsyncIterator[str]:
|
||||
self._on_start(transcription)
|
||||
|
||||
# Add the transcription to the input history
|
||||
self._input_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": transcription,
|
||||
}
|
||||
)
|
||||
|
||||
# If the user guessed the secret word, do alternate logic
|
||||
if self._secret_word in transcription.lower():
|
||||
yield "You guessed the secret word!"
|
||||
self._input_history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "You guessed the secret word!",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Otherwise, run the agent
|
||||
result = Runner.run_streamed(self._current_agent, self._input_history)
|
||||
|
||||
async for chunk in VoiceWorkflowHelper.stream_text_from(result):
|
||||
yield chunk
|
||||
|
||||
# Update the input history and current agent
|
||||
self._input_history = result.to_input_list()
|
||||
self._current_agent = result.last_agent
|
||||
234
mkdocs.yml
234
mkdocs.yml
|
|
@ -1,121 +1,143 @@
|
|||
site_name: OpenAI Agents SDK
|
||||
theme:
|
||||
name: material
|
||||
features:
|
||||
# Allows copying code blocks
|
||||
- content.code.copy
|
||||
# Allows selecting code blocks
|
||||
- content.code.select
|
||||
# Shows the current path in the sidebar
|
||||
- navigation.path
|
||||
# Shows sections in the sidebar
|
||||
- navigation.sections
|
||||
# Shows sections expanded by default
|
||||
- navigation.expand
|
||||
# Enables annotations in code blocks
|
||||
- content.code.annotate
|
||||
palette:
|
||||
primary: black
|
||||
logo: assets/logo.svg
|
||||
favicon: images/favicon-platform.svg
|
||||
name: material
|
||||
features:
|
||||
# Allows copying code blocks
|
||||
- content.code.copy
|
||||
# Allows selecting code blocks
|
||||
- content.code.select
|
||||
# Shows the current path in the sidebar
|
||||
- navigation.path
|
||||
# Shows sections in the sidebar
|
||||
- navigation.sections
|
||||
# Shows sections expanded by default
|
||||
- navigation.expand
|
||||
# Enables annotations in code blocks
|
||||
- content.code.annotate
|
||||
palette:
|
||||
primary: black
|
||||
logo: assets/logo.svg
|
||||
favicon: images/favicon-platform.svg
|
||||
nav:
|
||||
- Intro: index.md
|
||||
- Quickstart: quickstart.md
|
||||
- Documentation:
|
||||
- agents.md
|
||||
- running_agents.md
|
||||
- results.md
|
||||
- streaming.md
|
||||
- tools.md
|
||||
- handoffs.md
|
||||
- tracing.md
|
||||
- context.md
|
||||
- guardrails.md
|
||||
- multi_agent.md
|
||||
- models.md
|
||||
- config.md
|
||||
- API Reference:
|
||||
- Agents:
|
||||
- ref/index.md
|
||||
- ref/agent.md
|
||||
- ref/run.md
|
||||
- ref/tool.md
|
||||
- ref/result.md
|
||||
- ref/stream_events.md
|
||||
- ref/handoffs.md
|
||||
- ref/lifecycle.md
|
||||
- ref/items.md
|
||||
- ref/run_context.md
|
||||
- ref/usage.md
|
||||
- ref/exceptions.md
|
||||
- ref/guardrail.md
|
||||
- ref/model_settings.md
|
||||
- ref/agent_output.md
|
||||
- ref/function_schema.md
|
||||
- ref/models/interface.md
|
||||
- ref/models/openai_chatcompletions.md
|
||||
- ref/models/openai_responses.md
|
||||
- Tracing:
|
||||
- ref/tracing/index.md
|
||||
- ref/tracing/create.md
|
||||
- ref/tracing/traces.md
|
||||
- ref/tracing/spans.md
|
||||
- ref/tracing/processor_interface.md
|
||||
- ref/tracing/processors.md
|
||||
- ref/tracing/scope.md
|
||||
- ref/tracing/setup.md
|
||||
- ref/tracing/span_data.md
|
||||
- ref/tracing/util.md
|
||||
- Extensions:
|
||||
- ref/extensions/handoff_filters.md
|
||||
- ref/extensions/handoff_prompt.md
|
||||
- Intro: index.md
|
||||
- Quickstart: quickstart.md
|
||||
- Examples: examples.md
|
||||
- Documentation:
|
||||
- agents.md
|
||||
- running_agents.md
|
||||
- results.md
|
||||
- streaming.md
|
||||
- tools.md
|
||||
- handoffs.md
|
||||
- tracing.md
|
||||
- context.md
|
||||
- guardrails.md
|
||||
- multi_agent.md
|
||||
- models.md
|
||||
- config.md
|
||||
- Voice agents:
|
||||
- voice/quickstart.md
|
||||
- voice/pipeline.md
|
||||
- voice/tracing.md
|
||||
- API Reference:
|
||||
- Agents:
|
||||
- ref/index.md
|
||||
- ref/agent.md
|
||||
- ref/run.md
|
||||
- ref/tool.md
|
||||
- ref/result.md
|
||||
- ref/stream_events.md
|
||||
- ref/handoffs.md
|
||||
- ref/lifecycle.md
|
||||
- ref/items.md
|
||||
- ref/run_context.md
|
||||
- ref/usage.md
|
||||
- ref/exceptions.md
|
||||
- ref/guardrail.md
|
||||
- ref/model_settings.md
|
||||
- ref/agent_output.md
|
||||
- ref/function_schema.md
|
||||
- ref/models/interface.md
|
||||
- ref/models/openai_chatcompletions.md
|
||||
- ref/models/openai_responses.md
|
||||
- Tracing:
|
||||
- ref/tracing/index.md
|
||||
- ref/tracing/create.md
|
||||
- ref/tracing/traces.md
|
||||
- ref/tracing/spans.md
|
||||
- ref/tracing/processor_interface.md
|
||||
- ref/tracing/processors.md
|
||||
- ref/tracing/scope.md
|
||||
- ref/tracing/setup.md
|
||||
- ref/tracing/span_data.md
|
||||
- ref/tracing/util.md
|
||||
- Voice:
|
||||
- ref/voice/pipeline.md
|
||||
- ref/voice/workflow.md
|
||||
- ref/voice/input.md
|
||||
- ref/voice/result.md
|
||||
- ref/voice/pipeline_config.md
|
||||
- ref/voice/events.md
|
||||
- ref/voice/exceptions.md
|
||||
- ref/voice/model.md
|
||||
- ref/voice/utils.md
|
||||
- ref/voice/models/openai_provider.md
|
||||
- ref/voice/models/openai_stt.md
|
||||
- ref/voice/models/openai_tts.md
|
||||
- Extensions:
|
||||
- ref/extensions/handoff_filters.md
|
||||
- ref/extensions/handoff_prompt.md
|
||||
|
||||
plugins:
|
||||
- search
|
||||
- mkdocstrings:
|
||||
handlers:
|
||||
python:
|
||||
paths: ["src/agents"]
|
||||
selection:
|
||||
docstring_style: google
|
||||
options:
|
||||
# Shows links to other members in signatures
|
||||
signature_crossrefs: true
|
||||
# Orders members by source order, rather than alphabetical
|
||||
members_order: source
|
||||
# Puts the signature on a separate line from the member name
|
||||
separate_signature: true
|
||||
# Shows type annotations in signatures
|
||||
show_signature_annotations: true
|
||||
# Makes the font sizes nicer
|
||||
heading_level: 3
|
||||
- search
|
||||
- mkdocstrings:
|
||||
handlers:
|
||||
python:
|
||||
paths: ["src/agents"]
|
||||
selection:
|
||||
docstring_style: google
|
||||
options:
|
||||
# Shows links to other members in signatures
|
||||
signature_crossrefs: true
|
||||
# Orders members by source order, rather than alphabetical
|
||||
members_order: source
|
||||
# Puts the signature on a separate line from the member name
|
||||
separate_signature: true
|
||||
# Shows type annotations in signatures
|
||||
show_signature_annotations: true
|
||||
# Makes the font sizes nicer
|
||||
heading_level: 3
|
||||
|
||||
extra:
|
||||
# Remove material generation message in footer
|
||||
generator: false
|
||||
# Remove material generation message in footer
|
||||
generator: false
|
||||
|
||||
markdown_extensions:
|
||||
- admonition
|
||||
- pymdownx.details
|
||||
- pymdownx.superfences
|
||||
- attr_list
|
||||
- md_in_html
|
||||
- pymdownx.highlight:
|
||||
anchor_linenums: true
|
||||
line_spans: __span
|
||||
pygments_lang_class: true
|
||||
- pymdownx.inlinehilite
|
||||
- pymdownx.snippets
|
||||
- pymdownx.superfences
|
||||
- pymdownx.superfences:
|
||||
custom_fences:
|
||||
- name: mermaid
|
||||
class: mermaid
|
||||
format: !!python/name:pymdownx.superfences.fence_code_format
|
||||
- admonition
|
||||
- pymdownx.details
|
||||
- attr_list
|
||||
- md_in_html
|
||||
- pymdownx.highlight:
|
||||
anchor_linenums: true
|
||||
line_spans: __span
|
||||
pygments_lang_class: true
|
||||
- pymdownx.inlinehilite
|
||||
- pymdownx.snippets
|
||||
- pymdownx.superfences
|
||||
|
||||
validation:
|
||||
omitted_files: warn
|
||||
absolute_links: warn
|
||||
unrecognized_links: warn
|
||||
anchors: warn
|
||||
omitted_files: warn
|
||||
absolute_links: warn
|
||||
unrecognized_links: warn
|
||||
anchors: warn
|
||||
|
||||
extra_css:
|
||||
- stylesheets/extra.css
|
||||
- stylesheets/extra.css
|
||||
|
||||
watch:
|
||||
- "src/agents"
|
||||
- "src/agents"
|
||||
|
|
|
|||
|
|
@ -1,15 +1,13 @@
|
|||
[project]
|
||||
name = "openai-agents"
|
||||
version = "0.0.4"
|
||||
version = "0.0.6"
|
||||
description = "OpenAI Agents SDK"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = "MIT"
|
||||
authors = [
|
||||
{ name = "OpenAI", email = "support@openai.com" },
|
||||
]
|
||||
authors = [{ name = "OpenAI", email = "support@openai.com" }]
|
||||
dependencies = [
|
||||
"openai>=1.66.2",
|
||||
"openai>=1.66.5",
|
||||
"pydantic>=2.10, <3",
|
||||
"griffe>=1.5.6, <2",
|
||||
"typing-extensions>=4.12.2, <5",
|
||||
|
|
@ -27,13 +25,16 @@ classifiers = [
|
|||
"Intended Audience :: Developers",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"License :: OSI Approved :: MIT License"
|
||||
"License :: OSI Approved :: MIT License",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/openai/openai-agents-python"
|
||||
Repository = "https://github.com/openai/openai-agents-python"
|
||||
|
||||
[project.optional-dependencies]
|
||||
voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"mypy",
|
||||
|
|
@ -48,6 +49,12 @@ dev = [
|
|||
"coverage>=7.6.12",
|
||||
"playwright==1.50.0",
|
||||
"inline-snapshot>=0.20.7",
|
||||
"pynput",
|
||||
"types-pynput",
|
||||
"sounddevice",
|
||||
"pynput",
|
||||
"textual",
|
||||
"websockets",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
@ -80,8 +87,8 @@ select = [
|
|||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
]
|
||||
isort = { combine-as-imports = true, known-first-party = ["agents"] }
|
||||
|
||||
|
|
@ -97,11 +104,12 @@ disallow_incomplete_defs = false
|
|||
disallow_untyped_defs = false
|
||||
disallow_untyped_calls = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "sounddevice.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.coverage.run]
|
||||
source = [
|
||||
"tests",
|
||||
"src/agents",
|
||||
]
|
||||
source = ["tests", "src/agents"]
|
||||
|
||||
[tool.coverage.report]
|
||||
show_missing = true
|
||||
|
|
@ -126,4 +134,4 @@ markers = [
|
|||
]
|
||||
|
||||
[tool.inline-snapshot]
|
||||
format-command="ruff format --stdin-filename {filename}"
|
||||
format-command = "ruff format --stdin-filename {filename}"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Literal
|
|||
from openai import AsyncOpenAI
|
||||
|
||||
from . import _config
|
||||
from .agent import Agent
|
||||
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
|
||||
from .agent_output import AgentOutputSchema
|
||||
from .computer import AsyncComputer, Button, Computer, Environment
|
||||
from .exceptions import (
|
||||
|
|
@ -57,6 +57,7 @@ from .tool import (
|
|||
ComputerTool,
|
||||
FileSearchTool,
|
||||
FunctionTool,
|
||||
FunctionToolResult,
|
||||
Tool,
|
||||
WebSearchTool,
|
||||
default_tool_error_function,
|
||||
|
|
@ -72,8 +73,11 @@ from .tracing import (
|
|||
Span,
|
||||
SpanData,
|
||||
SpanError,
|
||||
SpeechGroupSpanData,
|
||||
SpeechSpanData,
|
||||
Trace,
|
||||
TracingProcessor,
|
||||
TranscriptionSpanData,
|
||||
add_trace_processor,
|
||||
agent_span,
|
||||
custom_span,
|
||||
|
|
@ -88,7 +92,10 @@ from .tracing import (
|
|||
set_trace_processors,
|
||||
set_tracing_disabled,
|
||||
set_tracing_export_api_key,
|
||||
speech_group_span,
|
||||
speech_span,
|
||||
trace,
|
||||
transcription_span,
|
||||
)
|
||||
from .usage import Usage
|
||||
|
||||
|
|
@ -137,6 +144,8 @@ def enable_verbose_stdout_logging():
|
|||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"ToolsToFinalOutputFunction",
|
||||
"ToolsToFinalOutputResult",
|
||||
"Runner",
|
||||
"Model",
|
||||
"ModelProvider",
|
||||
|
|
@ -190,6 +199,7 @@ __all__ = [
|
|||
"AgentUpdatedStreamEvent",
|
||||
"StreamEvent",
|
||||
"FunctionTool",
|
||||
"FunctionToolResult",
|
||||
"ComputerTool",
|
||||
"FileSearchTool",
|
||||
"Tool",
|
||||
|
|
@ -207,6 +217,9 @@ __all__ = [
|
|||
"handoff_span",
|
||||
"set_trace_processors",
|
||||
"set_tracing_disabled",
|
||||
"speech_group_span",
|
||||
"transcription_span",
|
||||
"speech_span",
|
||||
"trace",
|
||||
"Trace",
|
||||
"TracingProcessor",
|
||||
|
|
@ -219,6 +232,9 @@ __all__ = [
|
|||
"GenerationSpanData",
|
||||
"GuardrailSpanData",
|
||||
"HandoffSpanData",
|
||||
"SpeechGroupSpanData",
|
||||
"SpeechSpanData",
|
||||
"TranscriptionSpanData",
|
||||
"set_default_openai_key",
|
||||
"set_default_openai_client",
|
||||
"set_default_openai_api",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import inspect
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseComputerToolCall,
|
||||
|
|
@ -25,7 +28,7 @@ from openai.types.responses.response_computer_tool_call import (
|
|||
from openai.types.responses.response_input_param import ComputerCallOutput
|
||||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||
|
||||
from .agent import Agent
|
||||
from .agent import Agent, ToolsToFinalOutputResult
|
||||
from .agent_output import AgentOutputSchema
|
||||
from .computer import AsyncComputer, Computer
|
||||
from .exceptions import AgentsException, ModelBehaviorError, UserError
|
||||
|
|
@ -45,10 +48,11 @@ from .items import (
|
|||
)
|
||||
from .lifecycle import RunHooks
|
||||
from .logger import logger
|
||||
from .model_settings import ModelSettings
|
||||
from .models.interface import ModelTracing
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .stream_events import RunItemStreamEvent, StreamEvent
|
||||
from .tool import ComputerTool, FunctionTool
|
||||
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
|
||||
from .tracing import (
|
||||
SpanError,
|
||||
Trace,
|
||||
|
|
@ -70,6 +74,8 @@ class QueueCompleteSentinel:
|
|||
|
||||
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
|
||||
|
||||
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolRunHandoff:
|
||||
|
|
@ -199,9 +205,32 @@ class RunImpl:
|
|||
config=run_config,
|
||||
),
|
||||
)
|
||||
new_step_items.extend(function_results)
|
||||
new_step_items.extend([result.run_item for result in function_results])
|
||||
new_step_items.extend(computer_results)
|
||||
|
||||
# Reset tool_choice to "auto" after tool execution to prevent infinite loops
|
||||
if processed_response.functions or processed_response.computer_actions:
|
||||
tools = agent.tools
|
||||
|
||||
if (
|
||||
run_config.model_settings and
|
||||
cls._should_reset_tool_choice(run_config.model_settings, tools)
|
||||
):
|
||||
# update the run_config model settings with a copy
|
||||
new_run_config_settings = dataclasses.replace(
|
||||
run_config.model_settings,
|
||||
tool_choice="auto"
|
||||
)
|
||||
run_config = dataclasses.replace(run_config, model_settings=new_run_config_settings)
|
||||
|
||||
if cls._should_reset_tool_choice(agent.model_settings, tools):
|
||||
# Create a modified copy instead of modifying the original agent
|
||||
new_model_settings = dataclasses.replace(
|
||||
agent.model_settings,
|
||||
tool_choice="auto"
|
||||
)
|
||||
agent = dataclasses.replace(agent, model_settings=new_model_settings)
|
||||
|
||||
# Second, check if there are any handoffs
|
||||
if run_handoffs := processed_response.handoffs:
|
||||
return await cls.execute_handoffs(
|
||||
|
|
@ -216,6 +245,36 @@ class RunImpl:
|
|||
run_config=run_config,
|
||||
)
|
||||
|
||||
# Third, we'll check if the tool use should result in a final output
|
||||
check_tool_use = await cls._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=function_results,
|
||||
context_wrapper=context_wrapper,
|
||||
config=run_config,
|
||||
)
|
||||
|
||||
if check_tool_use.is_final_output:
|
||||
# If the output type is str, then let's just stringify it
|
||||
if not agent.output_type or agent.output_type is str:
|
||||
check_tool_use.final_output = str(check_tool_use.final_output)
|
||||
|
||||
if check_tool_use.final_output is None:
|
||||
logger.error(
|
||||
"Model returned a final output of None. Not raising an error because we assume"
|
||||
"you know what you're doing."
|
||||
)
|
||||
|
||||
return await cls.execute_final_output(
|
||||
agent=agent,
|
||||
original_input=original_input,
|
||||
new_response=new_response,
|
||||
pre_step_items=pre_step_items,
|
||||
new_step_items=new_step_items,
|
||||
final_output=check_tool_use.final_output,
|
||||
hooks=hooks,
|
||||
context_wrapper=context_wrapper,
|
||||
)
|
||||
|
||||
# Now we can check if the model also produced a final output
|
||||
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
|
||||
|
||||
|
|
@ -262,6 +321,24 @@ class RunImpl:
|
|||
next_step=NextStepRunAgain(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool:
|
||||
if model_settings is None or model_settings.tool_choice is None:
|
||||
return False
|
||||
|
||||
# for specific tool choices
|
||||
if (
|
||||
isinstance(model_settings.tool_choice, str) and
|
||||
model_settings.tool_choice not in ["auto", "required", "none"]
|
||||
):
|
||||
return True
|
||||
|
||||
# for one tool and required tool choice
|
||||
if model_settings.tool_choice == "required":
|
||||
return len(tools) == 1
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def process_model_response(
|
||||
cls,
|
||||
|
|
@ -355,10 +432,10 @@ class RunImpl:
|
|||
hooks: RunHooks[TContext],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
config: RunConfig,
|
||||
) -> list[RunItem]:
|
||||
) -> list[FunctionToolResult]:
|
||||
async def run_single_tool(
|
||||
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
|
||||
) -> str:
|
||||
) -> Any:
|
||||
with function_span(func_tool.name) as span_fn:
|
||||
if config.trace_include_sensitive_data:
|
||||
span_fn.span_data.input = tool_call.arguments
|
||||
|
|
@ -404,10 +481,14 @@ class RunImpl:
|
|||
results = await asyncio.gather(*tasks)
|
||||
|
||||
return [
|
||||
ToolCallOutputItem(
|
||||
output=str(result),
|
||||
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
|
||||
agent=agent,
|
||||
FunctionToolResult(
|
||||
tool=tool_run.function_tool,
|
||||
output=result,
|
||||
run_item=ToolCallOutputItem(
|
||||
output=result,
|
||||
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
|
||||
agent=agent,
|
||||
),
|
||||
)
|
||||
for tool_run, result in zip(tool_runs, results)
|
||||
]
|
||||
|
|
@ -646,6 +727,47 @@ class RunImpl:
|
|||
if event:
|
||||
queue.put_nowait(event)
|
||||
|
||||
@classmethod
|
||||
async def _check_for_final_output_from_tools(
|
||||
cls,
|
||||
*,
|
||||
agent: Agent[TContext],
|
||||
tool_results: list[FunctionToolResult],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
config: RunConfig,
|
||||
) -> ToolsToFinalOutputResult:
|
||||
"""Returns (i, final_output)."""
|
||||
if not tool_results:
|
||||
return _NOT_FINAL_OUTPUT
|
||||
|
||||
if agent.tool_use_behavior == "run_llm_again":
|
||||
return _NOT_FINAL_OUTPUT
|
||||
elif agent.tool_use_behavior == "stop_on_first_tool":
|
||||
return ToolsToFinalOutputResult(
|
||||
is_final_output=True, final_output=tool_results[0].output
|
||||
)
|
||||
elif isinstance(agent.tool_use_behavior, dict):
|
||||
names = agent.tool_use_behavior.get("stop_at_tool_names", [])
|
||||
for tool_result in tool_results:
|
||||
if tool_result.tool.name in names:
|
||||
return ToolsToFinalOutputResult(
|
||||
is_final_output=True, final_output=tool_result.output
|
||||
)
|
||||
return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
|
||||
elif callable(agent.tool_use_behavior):
|
||||
if inspect.iscoroutinefunction(agent.tool_use_behavior):
|
||||
return await cast(
|
||||
Awaitable[ToolsToFinalOutputResult],
|
||||
agent.tool_use_behavior(context_wrapper, tool_results),
|
||||
)
|
||||
else:
|
||||
return cast(
|
||||
ToolsToFinalOutputResult, agent.tool_use_behavior(context_wrapper, tool_results)
|
||||
)
|
||||
|
||||
logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
|
||||
raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
|
||||
|
||||
|
||||
class TraceCtxManager:
|
||||
"""Creates a trace only if there is no current trace, and manages the trace lifecycle."""
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ import dataclasses
|
|||
import inspect
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
|
||||
|
||||
from typing_extensions import TypeAlias, TypedDict
|
||||
|
||||
from .guardrail import InputGuardrail, OutputGuardrail
|
||||
from .handoffs import Handoff
|
||||
|
|
@ -13,7 +15,7 @@ from .logger import logger
|
|||
from .model_settings import ModelSettings
|
||||
from .models.interface import Model
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .tool import Tool, function_tool
|
||||
from .tool import FunctionToolResult, Tool, function_tool
|
||||
from .util import _transforms
|
||||
from .util._types import MaybeAwaitable
|
||||
|
||||
|
|
@ -22,6 +24,33 @@ if TYPE_CHECKING:
|
|||
from .result import RunResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolsToFinalOutputResult:
|
||||
is_final_output: bool
|
||||
"""Whether this is the final output. If False, the LLM will run again and receive the tool call
|
||||
output.
|
||||
"""
|
||||
|
||||
final_output: Any | None = None
|
||||
"""The final output. Can be None if `is_final_output` is False, otherwise must match the
|
||||
`output_type` of the agent.
|
||||
"""
|
||||
|
||||
|
||||
ToolsToFinalOutputFunction: TypeAlias = Callable[
|
||||
[RunContextWrapper[TContext], list[FunctionToolResult]],
|
||||
MaybeAwaitable[ToolsToFinalOutputResult],
|
||||
]
|
||||
"""A function that takes a run context and a list of tool results, and returns a
|
||||
`ToolToFinalOutputResult`.
|
||||
"""
|
||||
|
||||
|
||||
class StopAtTools(TypedDict):
|
||||
stop_at_tool_names: list[str]
|
||||
"""A list of tool names, any of which will stop the agent from running further."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Agent(Generic[TContext]):
|
||||
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
|
||||
|
|
@ -95,6 +124,25 @@ class Agent(Generic[TContext]):
|
|||
"""A class that receives callbacks on various lifecycle events for this agent.
|
||||
"""
|
||||
|
||||
tool_use_behavior: (
|
||||
Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction
|
||||
) = "run_llm_again"
|
||||
"""This lets you configure how tool use is handled.
|
||||
- "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results
|
||||
and gets to respond.
|
||||
- "stop_on_first_tool": The output of the first tool call is used as the final output. This
|
||||
means that the LLM does not process the result of the tool call.
|
||||
- A list of tool names: The agent will stop running if any of the tools in the list are called.
|
||||
The final output will be the output of the first matching tool call. The LLM does not
|
||||
process the result of the tool call.
|
||||
- A function: If you pass a function, it will be called with the run context and the list of
|
||||
tool results. It must return a `ToolToFinalOutputResult`, which determines whether the tool
|
||||
calls result in a final output.
|
||||
|
||||
NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
|
||||
web search, etc are always processed by the LLM.
|
||||
"""
|
||||
|
||||
def clone(self, **kwargs: Any) -> Agent[TContext]:
|
||||
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
|
||||
```
|
||||
|
|
|
|||
|
|
@ -129,8 +129,10 @@ class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutpu
|
|||
raw_item: FunctionCallOutput | ComputerCallOutput
|
||||
"""The raw item from the model."""
|
||||
|
||||
output: str
|
||||
"""The output of the tool call."""
|
||||
output: Any
|
||||
"""The output of the tool call. This is whatever the tool call returned; the `raw_item`
|
||||
contains a string representation of the output.
|
||||
"""
|
||||
|
||||
type: Literal["tool_call_output_item"] = "tool_call_output_item"
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ from openai.types.responses import (
|
|||
ResponseUsage,
|
||||
)
|
||||
from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message
|
||||
from openai.types.responses.response_usage import OutputTokensDetails
|
||||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||
|
||||
from .. import _debug
|
||||
from ..agent_output import AgentOutputSchema
|
||||
|
|
@ -420,6 +420,11 @@ class OpenAIChatCompletionsModel(Model):
|
|||
and usage.completion_tokens_details.reasoning_tokens
|
||||
else 0
|
||||
),
|
||||
input_tokens_details=InputTokensDetails(
|
||||
cached_tokens=usage.prompt_tokens_details.cached_tokens
|
||||
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens
|
||||
else 0
|
||||
),
|
||||
)
|
||||
if usage
|
||||
else None
|
||||
|
|
@ -752,7 +757,7 @@ class _Converter:
|
|||
elif isinstance(c, dict) and c.get("type") == "input_file":
|
||||
raise UserError(f"File uploads are not supported for chat completions {c}")
|
||||
else:
|
||||
raise UserError(f"Unknonw content: {c}")
|
||||
raise UserError(f"Unknown content: {c}")
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -34,6 +34,19 @@ class OpenAIProvider(ModelProvider):
|
|||
project: str | None = None,
|
||||
use_responses: bool | None = None,
|
||||
) -> None:
|
||||
"""Create a new OpenAI provider.
|
||||
|
||||
Args:
|
||||
api_key: The API key to use for the OpenAI client. If not provided, we will use the
|
||||
default API key.
|
||||
base_url: The base URL to use for the OpenAI client. If not provided, we will use the
|
||||
default base URL.
|
||||
openai_client: An optional OpenAI client to use. If not provided, we will create a new
|
||||
OpenAI client using the api_key and base_url.
|
||||
organization: The organization to use for the OpenAI client.
|
||||
project: The project to use for the OpenAI client.
|
||||
use_responses: Whether to use the OpenAI responses API.
|
||||
"""
|
||||
if openai_client is not None:
|
||||
assert api_key is None and base_url is None, (
|
||||
"Don't provide api_key or base_url if you provide openai_client"
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class OpenAIResponsesModel(Model):
|
|||
)
|
||||
|
||||
if _debug.DONT_LOG_MODEL_DATA:
|
||||
logger.debug("LLM responsed")
|
||||
logger.debug("LLM responded")
|
||||
else:
|
||||
logger.debug(
|
||||
"LLM resp:\n"
|
||||
|
|
@ -208,7 +208,9 @@ class OpenAIResponsesModel(Model):
|
|||
list_input = ItemHelpers.input_to_new_input_list(input)
|
||||
|
||||
parallel_tool_calls = (
|
||||
True if model_settings.parallel_tool_calls and tools and len(tools) > 0 else NOT_GIVEN
|
||||
True if model_settings.parallel_tool_calls and tools and len(tools) > 0
|
||||
else False if model_settings.parallel_tool_calls is False
|
||||
else NOT_GIVEN
|
||||
)
|
||||
|
||||
tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
|
||||
|
|
|
|||
1
src/agents/py.typed
Normal file
1
src/agents/py.typed
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -15,6 +15,7 @@ from . import _debug
|
|||
from .computer import AsyncComputer, Computer
|
||||
from .exceptions import ModelBehaviorError
|
||||
from .function_schema import DocstringStyle, function_schema
|
||||
from .items import RunItem
|
||||
from .logger import logger
|
||||
from .run_context import RunContextWrapper
|
||||
from .tracing import SpanError
|
||||
|
|
@ -29,6 +30,18 @@ ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParam
|
|||
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionToolResult:
|
||||
tool: FunctionTool
|
||||
"""The tool that was run."""
|
||||
|
||||
output: Any
|
||||
"""The output of the tool."""
|
||||
|
||||
run_item: RunItem
|
||||
"""The run item that was produced as a result of the tool call."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTool:
|
||||
"""A tool that wraps a function. In most cases, you should use the `function_tool` helpers to
|
||||
|
|
@ -44,15 +57,15 @@ class FunctionTool:
|
|||
params_json_schema: dict[str, Any]
|
||||
"""The JSON schema for the tool's parameters."""
|
||||
|
||||
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[str]]
|
||||
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
|
||||
"""A function that invokes the tool with the given context and parameters. The params passed
|
||||
are:
|
||||
1. The tool run context.
|
||||
2. The arguments from the LLM, as a JSON string.
|
||||
|
||||
You must return a string representation of the tool output. In case of errors, you can either
|
||||
raise an Exception (which will cause the run to fail) or return a string error message (which
|
||||
will be sent back to the LLM).
|
||||
You must return a string representation of the tool output, or something we can call `str()` on.
|
||||
In case of errors, you can either raise an Exception (which will cause the run to fail) or
|
||||
return a string error message (which will be sent back to the LLM).
|
||||
"""
|
||||
|
||||
strict_json_schema: bool = True
|
||||
|
|
@ -190,8 +203,11 @@ def function_tool(
|
|||
failure_error_function: If provided, use this function to generate an error message when
|
||||
the tool call fails. The error message is sent to the LLM. If you pass None, then no
|
||||
error message will be sent and instead an Exception will be raised.
|
||||
strict_mode: If False, parameters with default values become optional in the
|
||||
function schema.
|
||||
strict_mode: Whether to enable strict mode for the tool's JSON schema. We *strongly*
|
||||
recommend setting this to True, as it increases the likelihood of correct JSON input.
|
||||
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
|
||||
value, it will be optional, additional properties are allowed, etc. See here for more:
|
||||
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
|
||||
"""
|
||||
|
||||
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
|
||||
|
|
@ -204,7 +220,7 @@ def function_tool(
|
|||
strict_json_schema=strict_mode,
|
||||
)
|
||||
|
||||
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
|
||||
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
|
||||
try:
|
||||
json_data: dict[str, Any] = json.loads(input) if input else {}
|
||||
except Exception as e:
|
||||
|
|
@ -251,9 +267,9 @@ def function_tool(
|
|||
else:
|
||||
logger.debug(f"Tool {schema.name} returned {result}")
|
||||
|
||||
return str(result)
|
||||
return result
|
||||
|
||||
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
|
||||
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
|
||||
try:
|
||||
return await _on_invoke_tool_impl(ctx, input)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,10 @@ from .create import (
|
|||
guardrail_span,
|
||||
handoff_span,
|
||||
response_span,
|
||||
speech_group_span,
|
||||
speech_span,
|
||||
trace,
|
||||
transcription_span,
|
||||
)
|
||||
from .processor_interface import TracingProcessor
|
||||
from .processors import default_exporter, default_processor
|
||||
|
|
@ -24,6 +27,9 @@ from .span_data import (
|
|||
HandoffSpanData,
|
||||
ResponseSpanData,
|
||||
SpanData,
|
||||
SpeechGroupSpanData,
|
||||
SpeechSpanData,
|
||||
TranscriptionSpanData,
|
||||
)
|
||||
from .spans import Span, SpanError
|
||||
from .traces import Trace
|
||||
|
|
@ -54,9 +60,15 @@ __all__ = [
|
|||
"GuardrailSpanData",
|
||||
"HandoffSpanData",
|
||||
"ResponseSpanData",
|
||||
"SpeechGroupSpanData",
|
||||
"SpeechSpanData",
|
||||
"TranscriptionSpanData",
|
||||
"TracingProcessor",
|
||||
"gen_trace_id",
|
||||
"gen_span_id",
|
||||
"speech_group_span",
|
||||
"speech_span",
|
||||
"transcription_span",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ from .span_data import (
|
|||
GuardrailSpanData,
|
||||
HandoffSpanData,
|
||||
ResponseSpanData,
|
||||
SpeechGroupSpanData,
|
||||
SpeechSpanData,
|
||||
TranscriptionSpanData,
|
||||
)
|
||||
from .spans import Span
|
||||
from .traces import Trace
|
||||
|
|
@ -181,7 +184,11 @@ def generation_span(
|
|||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
span_data=GenerationSpanData(
|
||||
input=input, output=output, model=model, model_config=model_config, usage=usage
|
||||
input=input,
|
||||
output=output,
|
||||
model=model,
|
||||
model_config=model_config,
|
||||
usage=usage,
|
||||
),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
|
|
@ -304,3 +311,116 @@ def guardrail_span(
|
|||
parent=parent,
|
||||
disabled=disabled,
|
||||
)
|
||||
|
||||
|
||||
def transcription_span(
|
||||
model: str | None = None,
|
||||
input: str | None = None,
|
||||
input_format: str | None = "pcm",
|
||||
output: str | None = None,
|
||||
model_config: Mapping[str, Any] | None = None,
|
||||
span_id: str | None = None,
|
||||
parent: Trace | Span[Any] | None = None,
|
||||
disabled: bool = False,
|
||||
) -> Span[TranscriptionSpanData]:
|
||||
"""Create a new transcription span. The span will not be started automatically, you should
|
||||
either do `with transcription_span() ...` or call `span.start()` + `span.finish()` manually.
|
||||
|
||||
Args:
|
||||
model: The name of the model used for the speech-to-text.
|
||||
input: The audio input of the speech-to-text transcription, as a base64 encoded string of
|
||||
audio bytes.
|
||||
input_format: The format of the audio input (defaults to "pcm").
|
||||
output: The output of the speech-to-text transcription.
|
||||
model_config: The model configuration (hyperparameters) used.
|
||||
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
|
||||
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
|
||||
correctly formatted.
|
||||
parent: The parent span or trace. If not provided, we will automatically use the current
|
||||
trace/span as the parent.
|
||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||
|
||||
Returns:
|
||||
The newly created speech-to-text span.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
span_data=TranscriptionSpanData(
|
||||
input=input,
|
||||
input_format=input_format,
|
||||
output=output,
|
||||
model=model,
|
||||
model_config=model_config,
|
||||
),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
disabled=disabled,
|
||||
)
|
||||
|
||||
|
||||
def speech_span(
|
||||
model: str | None = None,
|
||||
input: str | None = None,
|
||||
output: str | None = None,
|
||||
output_format: str | None = "pcm",
|
||||
model_config: Mapping[str, Any] | None = None,
|
||||
first_content_at: str | None = None,
|
||||
span_id: str | None = None,
|
||||
parent: Trace | Span[Any] | None = None,
|
||||
disabled: bool = False,
|
||||
) -> Span[SpeechSpanData]:
|
||||
"""Create a new speech span. The span will not be started automatically, you should either do
|
||||
`with speech_span() ...` or call `span.start()` + `span.finish()` manually.
|
||||
|
||||
Args:
|
||||
model: The name of the model used for the text-to-speech.
|
||||
input: The text input of the text-to-speech.
|
||||
output: The audio output of the text-to-speech as base64 encoded string of PCM audio bytes.
|
||||
output_format: The format of the audio output (defaults to "pcm").
|
||||
model_config: The model configuration (hyperparameters) used.
|
||||
first_content_at: The time of the first byte of the audio output.
|
||||
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
|
||||
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
|
||||
correctly formatted.
|
||||
parent: The parent span or trace. If not provided, we will automatically use the current
|
||||
trace/span as the parent.
|
||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
span_data=SpeechSpanData(
|
||||
model=model,
|
||||
input=input,
|
||||
output=output,
|
||||
output_format=output_format,
|
||||
model_config=model_config,
|
||||
first_content_at=first_content_at,
|
||||
),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
disabled=disabled,
|
||||
)
|
||||
|
||||
|
||||
def speech_group_span(
|
||||
input: str | None = None,
|
||||
span_id: str | None = None,
|
||||
parent: Trace | Span[Any] | None = None,
|
||||
disabled: bool = False,
|
||||
) -> Span[SpeechGroupSpanData]:
|
||||
"""Create a new speech group span. The span will not be started automatically, you should
|
||||
either do `with speech_group_span() ...` or call `span.start()` + `span.finish()` manually.
|
||||
|
||||
Args:
|
||||
input: The input text used for the speech request.
|
||||
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
|
||||
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
|
||||
correctly formatted.
|
||||
parent: The parent span or trace. If not provided, we will automatically use the current
|
||||
trace/span as the parent.
|
||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
span_data=SpeechGroupSpanData(input=input),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
disabled=disabled,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import queue
|
|||
import random
|
||||
import threading
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
|
@ -50,9 +51,9 @@ class BackendSpanExporter(TracingExporter):
|
|||
base_delay: Base delay (in seconds) for the first backoff.
|
||||
max_delay: Maximum delay (in seconds) for backoff growth.
|
||||
"""
|
||||
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||
self.organization = organization or os.environ.get("OPENAI_ORG_ID")
|
||||
self.project = project or os.environ.get("OPENAI_PROJECT_ID")
|
||||
self._api_key = api_key
|
||||
self._organization = organization
|
||||
self._project = project
|
||||
self.endpoint = endpoint
|
||||
self.max_retries = max_retries
|
||||
self.base_delay = base_delay
|
||||
|
|
@ -68,8 +69,22 @@ class BackendSpanExporter(TracingExporter):
|
|||
api_key: The OpenAI API key to use. This is the same key used by the OpenAI Python
|
||||
client.
|
||||
"""
|
||||
# We're specifically setting the underlying cached property as well
|
||||
self._api_key = api_key
|
||||
self.api_key = api_key
|
||||
|
||||
@cached_property
|
||||
def api_key(self):
|
||||
return self._api_key or os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
@cached_property
|
||||
def organization(self):
|
||||
return self._organization or os.environ.get("OPENAI_ORG_ID")
|
||||
|
||||
@cached_property
|
||||
def project(self):
|
||||
return self._project or os.environ.get("OPENAI_PROJECT_ID")
|
||||
|
||||
def export(self, items: list[Trace | Span[Any]]) -> None:
|
||||
if not items:
|
||||
return
|
||||
|
|
@ -102,18 +117,22 @@ class BackendSpanExporter(TracingExporter):
|
|||
|
||||
# If the response is a client error (4xx), we wont retry
|
||||
if 400 <= response.status_code < 500:
|
||||
logger.error(f"Tracing client error {response.status_code}: {response.text}")
|
||||
logger.error(
|
||||
f"[non-fatal] Tracing client error {response.status_code}: {response.text}"
|
||||
)
|
||||
return
|
||||
|
||||
# For 5xx or other unexpected codes, treat it as transient and retry
|
||||
logger.warning(f"Server error {response.status_code}, retrying.")
|
||||
logger.warning(
|
||||
f"[non-fatal] Tracing: server error {response.status_code}, retrying."
|
||||
)
|
||||
except httpx.RequestError as exc:
|
||||
# Network or other I/O error, we'll retry
|
||||
logger.warning(f"Request failed: {exc}")
|
||||
logger.warning(f"[non-fatal] Tracing: request failed: {exc}")
|
||||
|
||||
# If we reach here, we need to retry or give up
|
||||
if attempt >= self.max_retries:
|
||||
logger.error("Max retries reached, giving up on this batch.")
|
||||
logger.error("[non-fatal] Tracing: max retries reached, giving up on this batch.")
|
||||
return
|
||||
|
||||
# Exponential backoff + jitter
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class AgentSpanData(SpanData):
|
|||
class FunctionSpanData(SpanData):
|
||||
__slots__ = ("name", "input", "output")
|
||||
|
||||
def __init__(self, name: str, input: str | None, output: str | None):
|
||||
def __init__(self, name: str, input: str | None, output: Any | None):
|
||||
self.name = name
|
||||
self.input = input
|
||||
self.output = output
|
||||
|
|
@ -65,7 +65,7 @@ class FunctionSpanData(SpanData):
|
|||
"type": self.type,
|
||||
"name": self.name,
|
||||
"input": self.input,
|
||||
"output": self.output,
|
||||
"output": str(self.output) if self.output else None,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -186,3 +186,99 @@ class GuardrailSpanData(SpanData):
|
|||
"name": self.name,
|
||||
"triggered": self.triggered,
|
||||
}
|
||||
|
||||
|
||||
class TranscriptionSpanData(SpanData):
|
||||
__slots__ = (
|
||||
"input",
|
||||
"output",
|
||||
"model",
|
||||
"model_config",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input: str | None = None,
|
||||
input_format: str | None = "pcm",
|
||||
output: str | None = None,
|
||||
model: str | None = None,
|
||||
model_config: Mapping[str, Any] | None = None,
|
||||
):
|
||||
self.input = input
|
||||
self.input_format = input_format
|
||||
self.output = output
|
||||
self.model = model
|
||||
self.model_config = model_config
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "transcription"
|
||||
|
||||
def export(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": self.type,
|
||||
"input": {
|
||||
"data": self.input or "",
|
||||
"format": self.input_format,
|
||||
},
|
||||
"output": self.output,
|
||||
"model": self.model,
|
||||
"model_config": self.model_config,
|
||||
}
|
||||
|
||||
|
||||
class SpeechSpanData(SpanData):
|
||||
__slots__ = ("input", "output", "model", "model_config", "first_byte_at")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input: str | None = None,
|
||||
output: str | None = None,
|
||||
output_format: str | None = "pcm",
|
||||
model: str | None = None,
|
||||
model_config: Mapping[str, Any] | None = None,
|
||||
first_content_at: str | None = None,
|
||||
):
|
||||
self.input = input
|
||||
self.output = output
|
||||
self.output_format = output_format
|
||||
self.model = model
|
||||
self.model_config = model_config
|
||||
self.first_content_at = first_content_at
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "speech"
|
||||
|
||||
def export(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": self.type,
|
||||
"input": self.input,
|
||||
"output": {
|
||||
"data": self.output or "",
|
||||
"format": self.output_format,
|
||||
},
|
||||
"model": self.model,
|
||||
"model_config": self.model_config,
|
||||
"first_content_at": self.first_content_at,
|
||||
}
|
||||
|
||||
|
||||
class SpeechGroupSpanData(SpanData):
|
||||
__slots__ = "input"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input: str | None = None,
|
||||
):
|
||||
self.input = input
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "speech-group"
|
||||
|
||||
def export(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": self.type,
|
||||
"input": self.input,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,3 +15,8 @@ def gen_trace_id() -> str:
|
|||
def gen_span_id() -> str:
|
||||
"""Generates a new span ID."""
|
||||
return f"span_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
|
||||
def gen_group_id() -> str:
|
||||
"""Generates a new group ID."""
|
||||
return f"group_{uuid.uuid4().hex[:24]}"
|
||||
|
|
|
|||
51
src/agents/voice/__init__.py
Normal file
51
src/agents/voice/__init__.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from .events import VoiceStreamEvent, VoiceStreamEventAudio, VoiceStreamEventLifecycle
|
||||
from .exceptions import STTWebsocketConnectionError
|
||||
from .input import AudioInput, StreamedAudioInput
|
||||
from .model import (
|
||||
StreamedTranscriptionSession,
|
||||
STTModel,
|
||||
STTModelSettings,
|
||||
TTSModel,
|
||||
TTSModelSettings,
|
||||
VoiceModelProvider,
|
||||
)
|
||||
from .models.openai_model_provider import OpenAIVoiceModelProvider
|
||||
from .models.openai_stt import OpenAISTTModel, OpenAISTTTranscriptionSession
|
||||
from .models.openai_tts import OpenAITTSModel
|
||||
from .pipeline import VoicePipeline
|
||||
from .pipeline_config import VoicePipelineConfig
|
||||
from .result import StreamedAudioResult
|
||||
from .utils import get_sentence_based_splitter
|
||||
from .workflow import (
|
||||
SingleAgentVoiceWorkflow,
|
||||
SingleAgentWorkflowCallbacks,
|
||||
VoiceWorkflowBase,
|
||||
VoiceWorkflowHelper,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AudioInput",
|
||||
"StreamedAudioInput",
|
||||
"STTModel",
|
||||
"STTModelSettings",
|
||||
"TTSModel",
|
||||
"TTSModelSettings",
|
||||
"VoiceModelProvider",
|
||||
"StreamedAudioResult",
|
||||
"SingleAgentVoiceWorkflow",
|
||||
"OpenAIVoiceModelProvider",
|
||||
"OpenAISTTModel",
|
||||
"OpenAITTSModel",
|
||||
"VoiceStreamEventAudio",
|
||||
"VoiceStreamEventLifecycle",
|
||||
"VoiceStreamEvent",
|
||||
"VoicePipeline",
|
||||
"VoicePipelineConfig",
|
||||
"get_sentence_based_splitter",
|
||||
"VoiceWorkflowHelper",
|
||||
"VoiceWorkflowBase",
|
||||
"SingleAgentWorkflowCallbacks",
|
||||
"StreamedTranscriptionSession",
|
||||
"OpenAISTTTranscriptionSession",
|
||||
"STTWebsocketConnectionError",
|
||||
]
|
||||
47
src/agents/voice/events.py
Normal file
47
src/agents/voice/events.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Union
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .imports import np, npt
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceStreamEventAudio:
|
||||
"""Streaming event from the VoicePipeline"""
|
||||
|
||||
data: npt.NDArray[np.int16 | np.float32] | None
|
||||
"""The audio data."""
|
||||
|
||||
type: Literal["voice_stream_event_audio"] = "voice_stream_event_audio"
|
||||
"""The type of event."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceStreamEventLifecycle:
|
||||
"""Streaming event from the VoicePipeline"""
|
||||
|
||||
event: Literal["turn_started", "turn_ended", "session_ended"]
|
||||
"""The event that occurred."""
|
||||
|
||||
type: Literal["voice_stream_event_lifecycle"] = "voice_stream_event_lifecycle"
|
||||
"""The type of event."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceStreamEventError:
|
||||
"""Streaming event from the VoicePipeline"""
|
||||
|
||||
error: Exception
|
||||
"""The error that occurred."""
|
||||
|
||||
type: Literal["voice_stream_event_error"] = "voice_stream_event_error"
|
||||
"""The type of event."""
|
||||
|
||||
|
||||
VoiceStreamEvent: TypeAlias = Union[
|
||||
VoiceStreamEventAudio, VoiceStreamEventLifecycle, VoiceStreamEventError
|
||||
]
|
||||
"""An event from the `VoicePipeline`, streamed via `StreamedAudioResult.stream()`."""
|
||||
8
src/agents/voice/exceptions.py
Normal file
8
src/agents/voice/exceptions.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from ..exceptions import AgentsException
|
||||
|
||||
|
||||
class STTWebsocketConnectionError(AgentsException):
|
||||
"""Exception raised when the STT websocket connection fails."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
11
src/agents/voice/imports.py
Normal file
11
src/agents/voice/imports.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
try:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import websockets
|
||||
except ImportError as _e:
|
||||
raise ImportError(
|
||||
"`numpy` + `websockets` are required to use voice. You can install them via the optional "
|
||||
"dependency group: `pip install 'openai-agents[voice]'`."
|
||||
) from _e
|
||||
|
||||
__all__ = ["np", "npt", "websockets"]
|
||||
88
src/agents/voice/input.py
Normal file
88
src/agents/voice/input.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..exceptions import UserError
|
||||
from .imports import np, npt
|
||||
|
||||
DEFAULT_SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
def _buffer_to_audio_file(
|
||||
buffer: npt.NDArray[np.int16 | np.float32],
|
||||
frame_rate: int = DEFAULT_SAMPLE_RATE,
|
||||
sample_width: int = 2,
|
||||
channels: int = 1,
|
||||
) -> tuple[str, io.BytesIO, str]:
|
||||
if buffer.dtype == np.float32:
|
||||
# convert to int16
|
||||
buffer = np.clip(buffer, -1.0, 1.0)
|
||||
buffer = (buffer * 32767).astype(np.int16)
|
||||
elif buffer.dtype != np.int16:
|
||||
raise UserError("Buffer must be a numpy array of int16 or float32")
|
||||
|
||||
audio_file = io.BytesIO()
|
||||
with wave.open(audio_file, "w") as wav_file:
|
||||
wav_file.setnchannels(channels)
|
||||
wav_file.setsampwidth(sample_width)
|
||||
wav_file.setframerate(frame_rate)
|
||||
wav_file.writeframes(buffer.tobytes())
|
||||
audio_file.seek(0)
|
||||
|
||||
# (filename, bytes, content_type)
|
||||
return ("audio.wav", audio_file, "audio/wav")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioInput:
|
||||
"""Static audio to be used as input for the VoicePipeline."""
|
||||
|
||||
buffer: npt.NDArray[np.int16 | np.float32]
|
||||
"""
|
||||
A buffer containing the audio data for the agent. Must be a numpy array of int16 or float32.
|
||||
"""
|
||||
|
||||
frame_rate: int = DEFAULT_SAMPLE_RATE
|
||||
"""The sample rate of the audio data. Defaults to 24000."""
|
||||
|
||||
sample_width: int = 2
|
||||
"""The sample width of the audio data. Defaults to 2."""
|
||||
|
||||
channels: int = 1
|
||||
"""The number of channels in the audio data. Defaults to 1."""
|
||||
|
||||
def to_audio_file(self) -> tuple[str, io.BytesIO, str]:
|
||||
"""Returns a tuple of (filename, bytes, content_type)"""
|
||||
return _buffer_to_audio_file(self.buffer, self.frame_rate, self.sample_width, self.channels)
|
||||
|
||||
def to_base64(self) -> str:
|
||||
"""Returns the audio data as a base64 encoded string."""
|
||||
if self.buffer.dtype == np.float32:
|
||||
# convert to int16
|
||||
self.buffer = np.clip(self.buffer, -1.0, 1.0)
|
||||
self.buffer = (self.buffer * 32767).astype(np.int16)
|
||||
elif self.buffer.dtype != np.int16:
|
||||
raise UserError("Buffer must be a numpy array of int16 or float32")
|
||||
|
||||
return base64.b64encode(self.buffer.tobytes()).decode("utf-8")
|
||||
|
||||
|
||||
class StreamedAudioInput:
|
||||
"""Audio input represented as a stream of audio data. You can pass this to the `VoicePipeline`
|
||||
and then push audio data into the queue using the `add_audio` method.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = asyncio.Queue()
|
||||
|
||||
async def add_audio(self, audio: npt.NDArray[np.int16 | np.float32]):
|
||||
"""Adds more audio data to the stream.
|
||||
|
||||
Args:
|
||||
audio: The audio data to add. Must be a numpy array of int16 or float32.
|
||||
"""
|
||||
await self.queue.put(audio)
|
||||
193
src/agents/voice/model.py
Normal file
193
src/agents/voice/model.py
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
from .imports import np, npt
|
||||
from .input import AudioInput, StreamedAudioInput
|
||||
from .utils import get_sentence_based_splitter
|
||||
|
||||
DEFAULT_TTS_INSTRUCTIONS = (
|
||||
"You will receive partial sentences. Do not complete the sentence, just read out the text."
|
||||
)
|
||||
DEFAULT_TTS_BUFFER_SIZE = 120
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSModelSettings:
|
||||
"""Settings for a TTS model."""
|
||||
|
||||
voice: (
|
||||
Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] | None
|
||||
) = None
|
||||
"""
|
||||
The voice to use for the TTS model. If not provided, the default voice for the respective model
|
||||
will be used.
|
||||
"""
|
||||
|
||||
buffer_size: int = 120
|
||||
"""The minimal size of the chunks of audio data that are being streamed out."""
|
||||
|
||||
dtype: npt.DTypeLike = np.int16
|
||||
"""The data type for the audio data to be returned in."""
|
||||
|
||||
transform_data: (
|
||||
Callable[[npt.NDArray[np.int16 | np.float32]], npt.NDArray[np.int16 | np.float32]] | None
|
||||
) = None
|
||||
"""
|
||||
A function to transform the data from the TTS model. This is useful if you want the resulting
|
||||
audio stream to have the data in a specific shape already.
|
||||
"""
|
||||
|
||||
instructions: str = (
|
||||
"You will receive partial sentences. Do not complete the sentence just read out the text."
|
||||
)
|
||||
"""
|
||||
The instructions to use for the TTS model. This is useful if you want to control the tone of the
|
||||
audio output.
|
||||
"""
|
||||
|
||||
text_splitter: Callable[[str], tuple[str, str]] = get_sentence_based_splitter()
|
||||
"""
|
||||
A function to split the text into chunks. This is useful if you want to split the text into
|
||||
chunks before sending it to the TTS model rather than waiting for the whole text to be
|
||||
processed.
|
||||
"""
|
||||
|
||||
speed: float | None = None
|
||||
"""The speed with which the TTS model will read the text. Between 0.25 and 4.0."""
|
||||
|
||||
|
||||
class TTSModel(abc.ABC):
|
||||
"""A text-to-speech model that can convert text into audio output."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def model_name(self) -> str:
|
||||
"""The name of the TTS model."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[bytes]:
|
||||
"""Given a text string, produces a stream of audio bytes, in PCM format.
|
||||
|
||||
Args:
|
||||
text: The text to convert to audio.
|
||||
|
||||
Returns:
|
||||
An async iterator of audio bytes, in PCM format.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StreamedTranscriptionSession(abc.ABC):
|
||||
"""A streamed transcription of audio input."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def transcribe_turns(self) -> AsyncIterator[str]:
|
||||
"""Yields a stream of text transcriptions. Each transcription is a turn in the conversation.
|
||||
|
||||
This method is expected to return only after `close()` is called.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Closes the session."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTModelSettings:
|
||||
"""Settings for a speech-to-text model."""
|
||||
|
||||
prompt: str | None = None
|
||||
"""Instructions for the model to follow."""
|
||||
|
||||
language: str | None = None
|
||||
"""The language of the audio input."""
|
||||
|
||||
temperature: float | None = None
|
||||
"""The temperature of the model."""
|
||||
|
||||
turn_detection: dict[str, Any] | None = None
|
||||
"""The turn detection settings for the model when using streamed audio input."""
|
||||
|
||||
|
||||
class STTModel(abc.ABC):
|
||||
"""A speech-to-text model that can convert audio input into text."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def model_name(self) -> str:
|
||||
"""The name of the STT model."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def transcribe(
|
||||
self,
|
||||
input: AudioInput,
|
||||
settings: STTModelSettings,
|
||||
trace_include_sensitive_data: bool,
|
||||
trace_include_sensitive_audio_data: bool,
|
||||
) -> str:
|
||||
"""Given an audio input, produces a text transcription.
|
||||
|
||||
Args:
|
||||
input: The audio input to transcribe.
|
||||
settings: The settings to use for the transcription.
|
||||
trace_include_sensitive_data: Whether to include sensitive data in traces.
|
||||
trace_include_sensitive_audio_data: Whether to include sensitive audio data in traces.
|
||||
|
||||
Returns:
|
||||
The text transcription of the audio input.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def create_session(
|
||||
self,
|
||||
input: StreamedAudioInput,
|
||||
settings: STTModelSettings,
|
||||
trace_include_sensitive_data: bool,
|
||||
trace_include_sensitive_audio_data: bool,
|
||||
) -> StreamedTranscriptionSession:
|
||||
"""Creates a new transcription session, which you can push audio to, and receive a stream
|
||||
of text transcriptions.
|
||||
|
||||
Args:
|
||||
input: The audio input to transcribe.
|
||||
settings: The settings to use for the transcription.
|
||||
trace_include_sensitive_data: Whether to include sensitive data in traces.
|
||||
trace_include_sensitive_audio_data: Whether to include sensitive audio data in traces.
|
||||
|
||||
Returns:
|
||||
A new transcription session.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class VoiceModelProvider(abc.ABC):
|
||||
"""The base interface for a voice model provider.
|
||||
|
||||
A model provider is responsible for creating speech-to-text and text-to-speech models, given a
|
||||
name.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stt_model(self, model_name: str | None) -> STTModel:
|
||||
"""Get a speech-to-text model by name.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model to get.
|
||||
|
||||
Returns:
|
||||
The speech-to-text model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_tts_model(self, model_name: str | None) -> TTSModel:
|
||||
"""Get a text-to-speech model by name."""
|
||||
0
src/agents/voice/models/__init__.py
Normal file
0
src/agents/voice/models/__init__.py
Normal file
97
src/agents/voice/models/openai_model_provider.py
Normal file
97
src/agents/voice/models/openai_model_provider.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
|
||||
|
||||
from ...models import _openai_shared
|
||||
from ..model import STTModel, TTSModel, VoiceModelProvider
|
||||
from .openai_stt import OpenAISTTModel
|
||||
from .openai_tts import OpenAITTSModel
|
||||
|
||||
_http_client: httpx.AsyncClient | None = None
|
||||
|
||||
|
||||
# If we create a new httpx client for each request, that would mean no sharing of connection pools,
|
||||
# which would mean worse latency and resource usage. So, we share the client across requests.
|
||||
def shared_http_client() -> httpx.AsyncClient:
|
||||
global _http_client
|
||||
if _http_client is None:
|
||||
_http_client = DefaultAsyncHttpxClient()
|
||||
return _http_client
|
||||
|
||||
|
||||
DEFAULT_STT_MODEL = "gpt-4o-transcribe"
|
||||
DEFAULT_TTS_MODEL = "gpt-4o-mini-tts"
|
||||
|
||||
|
||||
class OpenAIVoiceModelProvider(VoiceModelProvider):
|
||||
"""A voice model provider that uses OpenAI models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
openai_client: AsyncOpenAI | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
) -> None:
|
||||
"""Create a new OpenAI voice model provider.
|
||||
|
||||
Args:
|
||||
api_key: The API key to use for the OpenAI client. If not provided, we will use the
|
||||
default API key.
|
||||
base_url: The base URL to use for the OpenAI client. If not provided, we will use the
|
||||
default base URL.
|
||||
openai_client: An optional OpenAI client to use. If not provided, we will create a new
|
||||
OpenAI client using the api_key and base_url.
|
||||
organization: The organization to use for the OpenAI client.
|
||||
project: The project to use for the OpenAI client.
|
||||
"""
|
||||
if openai_client is not None:
|
||||
assert api_key is None and base_url is None, (
|
||||
"Don't provide api_key or base_url if you provide openai_client"
|
||||
)
|
||||
self._client: AsyncOpenAI | None = openai_client
|
||||
else:
|
||||
self._client = None
|
||||
self._stored_api_key = api_key
|
||||
self._stored_base_url = base_url
|
||||
self._stored_organization = organization
|
||||
self._stored_project = project
|
||||
|
||||
# We lazy load the client in case you never actually use OpenAIProvider(). Otherwise
|
||||
# AsyncOpenAI() raises an error if you don't have an API key set.
|
||||
def _get_client(self) -> AsyncOpenAI:
|
||||
if self._client is None:
|
||||
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
|
||||
api_key=self._stored_api_key or _openai_shared.get_default_openai_key(),
|
||||
base_url=self._stored_base_url,
|
||||
organization=self._stored_organization,
|
||||
project=self._stored_project,
|
||||
http_client=shared_http_client(),
|
||||
)
|
||||
|
||||
return self._client
|
||||
|
||||
def get_stt_model(self, model_name: str | None) -> STTModel:
|
||||
"""Get a speech-to-text model by name.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model to get.
|
||||
|
||||
Returns:
|
||||
The speech-to-text model.
|
||||
"""
|
||||
return OpenAISTTModel(model_name or DEFAULT_STT_MODEL, self._get_client())
|
||||
|
||||
def get_tts_model(self, model_name: str | None) -> TTSModel:
|
||||
"""Get a text-to-speech model by name.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model to get.
|
||||
|
||||
Returns:
|
||||
The text-to-speech model.
|
||||
"""
|
||||
return OpenAITTSModel(model_name or DEFAULT_TTS_MODEL, self._get_client())
|
||||
457
src/agents/voice/models/openai_stt.py
Normal file
457
src/agents/voice/models/openai_stt.py
Normal file
|
|
@ -0,0 +1,457 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from agents.exceptions import AgentsException
|
||||
|
||||
from ... import _debug
|
||||
from ...logger import logger
|
||||
from ...tracing import Span, SpanError, TranscriptionSpanData, transcription_span
|
||||
from ..exceptions import STTWebsocketConnectionError
|
||||
from ..imports import np, npt, websockets
|
||||
from ..input import AudioInput, StreamedAudioInput
|
||||
from ..model import StreamedTranscriptionSession, STTModel, STTModelSettings
|
||||
|
||||
EVENT_INACTIVITY_TIMEOUT = 1000 # Timeout for inactivity in event processing
|
||||
SESSION_CREATION_TIMEOUT = 10 # Timeout waiting for session.created event
|
||||
SESSION_UPDATE_TIMEOUT = 10 # Timeout waiting for session.updated event
|
||||
|
||||
DEFAULT_TURN_DETECTION = {"type": "semantic_vad"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorSentinel:
|
||||
error: Exception
|
||||
|
||||
|
||||
class SessionCompleteSentinel:
|
||||
pass
|
||||
|
||||
|
||||
class WebsocketDoneSentinel:
|
||||
pass
|
||||
|
||||
|
||||
def _audio_to_base64(audio_data: list[npt.NDArray[np.int16 | np.float32]]) -> str:
|
||||
concatenated_audio = np.concatenate(audio_data)
|
||||
if concatenated_audio.dtype == np.float32:
|
||||
# convert to int16
|
||||
concatenated_audio = np.clip(concatenated_audio, -1.0, 1.0)
|
||||
concatenated_audio = (concatenated_audio * 32767).astype(np.int16)
|
||||
audio_bytes = concatenated_audio.tobytes()
|
||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
|
||||
async def _wait_for_event(
|
||||
event_queue: asyncio.Queue[dict[str, Any]], expected_types: list[str], timeout: float
|
||||
):
|
||||
"""
|
||||
Wait for an event from event_queue whose type is in expected_types within the specified timeout.
|
||||
"""
|
||||
start_time = time.time()
|
||||
while True:
|
||||
remaining = timeout - (time.time() - start_time)
|
||||
if remaining <= 0:
|
||||
raise TimeoutError(f"Timeout waiting for event(s): {expected_types}")
|
||||
evt = await asyncio.wait_for(event_queue.get(), timeout=remaining)
|
||||
evt_type = evt.get("type", "")
|
||||
if evt_type in expected_types:
|
||||
return evt
|
||||
elif evt_type == "error":
|
||||
raise Exception(f"Error event: {evt.get('error')}")
|
||||
|
||||
|
||||
class OpenAISTTTranscriptionSession(StreamedTranscriptionSession):
|
||||
"""A transcription session for OpenAI's STT model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input: StreamedAudioInput,
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
settings: STTModelSettings,
|
||||
trace_include_sensitive_data: bool,
|
||||
trace_include_sensitive_audio_data: bool,
|
||||
):
|
||||
self.connected: bool = False
|
||||
self._client = client
|
||||
self._model = model
|
||||
self._settings = settings
|
||||
self._turn_detection = settings.turn_detection or DEFAULT_TURN_DETECTION
|
||||
self._trace_include_sensitive_data = trace_include_sensitive_data
|
||||
self._trace_include_sensitive_audio_data = trace_include_sensitive_audio_data
|
||||
|
||||
self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = input.queue
|
||||
self._output_queue: asyncio.Queue[str | ErrorSentinel | SessionCompleteSentinel] = (
|
||||
asyncio.Queue()
|
||||
)
|
||||
self._websocket: websockets.ClientConnection | None = None
|
||||
self._event_queue: asyncio.Queue[dict[str, Any] | WebsocketDoneSentinel] = asyncio.Queue()
|
||||
self._state_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||
self._turn_audio_buffer: list[npt.NDArray[np.int16 | np.float32]] = []
|
||||
self._tracing_span: Span[TranscriptionSpanData] | None = None
|
||||
|
||||
# tasks
|
||||
self._listener_task: asyncio.Task[Any] | None = None
|
||||
self._process_events_task: asyncio.Task[Any] | None = None
|
||||
self._stream_audio_task: asyncio.Task[Any] | None = None
|
||||
self._connection_task: asyncio.Task[Any] | None = None
|
||||
self._stored_exception: Exception | None = None
|
||||
|
||||
def _start_turn(self) -> None:
|
||||
self._tracing_span = transcription_span(
|
||||
model=self._model,
|
||||
model_config={
|
||||
"temperature": self._settings.temperature,
|
||||
"language": self._settings.language,
|
||||
"prompt": self._settings.prompt,
|
||||
"turn_detection": self._turn_detection,
|
||||
},
|
||||
)
|
||||
self._tracing_span.start()
|
||||
|
||||
def _end_turn(self, _transcript: str) -> None:
|
||||
if len(_transcript) < 1:
|
||||
return
|
||||
|
||||
if self._tracing_span:
|
||||
if self._trace_include_sensitive_audio_data:
|
||||
self._tracing_span.span_data.input = _audio_to_base64(self._turn_audio_buffer)
|
||||
|
||||
self._tracing_span.span_data.input_format = "pcm"
|
||||
|
||||
if self._trace_include_sensitive_data:
|
||||
self._tracing_span.span_data.output = _transcript
|
||||
|
||||
self._tracing_span.finish()
|
||||
self._turn_audio_buffer = []
|
||||
self._tracing_span = None
|
||||
|
||||
async def _event_listener(self) -> None:
|
||||
assert self._websocket is not None, "Websocket not initialized"
|
||||
|
||||
async for message in self._websocket:
|
||||
try:
|
||||
event = json.loads(message)
|
||||
|
||||
if event.get("type") == "error":
|
||||
raise STTWebsocketConnectionError(f"Error event: {event.get('error')}")
|
||||
|
||||
if event.get("type") in [
|
||||
"session.updated",
|
||||
"transcription_session.updated",
|
||||
"session.created",
|
||||
"transcription_session.created",
|
||||
]:
|
||||
await self._state_queue.put(event)
|
||||
|
||||
await self._event_queue.put(event)
|
||||
except Exception as e:
|
||||
await self._output_queue.put(ErrorSentinel(e))
|
||||
raise STTWebsocketConnectionError("Error parsing events") from e
|
||||
await self._event_queue.put(WebsocketDoneSentinel())
|
||||
|
||||
async def _configure_session(self) -> None:
|
||||
assert self._websocket is not None, "Websocket not initialized"
|
||||
await self._websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "transcription_session.update",
|
||||
"session": {
|
||||
"input_audio_format": "pcm16",
|
||||
"input_audio_transcription": {"model": self._model},
|
||||
"turn_detection": self._turn_detection,
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def _setup_connection(self, ws: websockets.ClientConnection) -> None:
|
||||
self._websocket = ws
|
||||
self._listener_task = asyncio.create_task(self._event_listener())
|
||||
|
||||
try:
|
||||
event = await _wait_for_event(
|
||||
self._state_queue,
|
||||
["session.created", "transcription_session.created"],
|
||||
SESSION_CREATION_TIMEOUT,
|
||||
)
|
||||
except TimeoutError as e:
|
||||
wrapped_err = STTWebsocketConnectionError(
|
||||
"Timeout waiting for transcription_session.created event"
|
||||
)
|
||||
await self._output_queue.put(ErrorSentinel(wrapped_err))
|
||||
raise wrapped_err from e
|
||||
except Exception as e:
|
||||
await self._output_queue.put(ErrorSentinel(e))
|
||||
raise e
|
||||
|
||||
await self._configure_session()
|
||||
|
||||
try:
|
||||
event = await _wait_for_event(
|
||||
self._state_queue,
|
||||
["session.updated", "transcription_session.updated"],
|
||||
SESSION_UPDATE_TIMEOUT,
|
||||
)
|
||||
if _debug.DONT_LOG_MODEL_DATA:
|
||||
logger.debug("Session updated")
|
||||
else:
|
||||
logger.debug(f"Session updated: {event}")
|
||||
except TimeoutError as e:
|
||||
wrapped_err = STTWebsocketConnectionError(
|
||||
"Timeout waiting for transcription_session.updated event"
|
||||
)
|
||||
await self._output_queue.put(ErrorSentinel(wrapped_err))
|
||||
raise wrapped_err from e
|
||||
except Exception as e:
|
||||
await self._output_queue.put(ErrorSentinel(e))
|
||||
raise
|
||||
|
||||
async def _handle_events(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
event = await asyncio.wait_for(
|
||||
self._event_queue.get(), timeout=EVENT_INACTIVITY_TIMEOUT
|
||||
)
|
||||
if isinstance(event, WebsocketDoneSentinel):
|
||||
# processed all events and websocket is done
|
||||
break
|
||||
|
||||
event_type = event.get("type", "unknown")
|
||||
if event_type == "conversation.item.input_audio_transcription.completed":
|
||||
transcript = cast(str, event.get("transcript", ""))
|
||||
if len(transcript) > 0:
|
||||
self._end_turn(transcript)
|
||||
self._start_turn()
|
||||
await self._output_queue.put(transcript)
|
||||
await asyncio.sleep(0) # yield control
|
||||
except asyncio.TimeoutError:
|
||||
# No new events for a while. Assume the session is done.
|
||||
break
|
||||
except Exception as e:
|
||||
await self._output_queue.put(ErrorSentinel(e))
|
||||
raise e
|
||||
await self._output_queue.put(SessionCompleteSentinel())
|
||||
|
||||
async def _stream_audio(
|
||||
self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]]
|
||||
) -> None:
|
||||
assert self._websocket is not None, "Websocket not initialized"
|
||||
self._start_turn()
|
||||
while True:
|
||||
buffer = await audio_queue.get()
|
||||
if buffer is None:
|
||||
break
|
||||
|
||||
self._turn_audio_buffer.append(buffer)
|
||||
try:
|
||||
await self._websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": base64.b64encode(buffer.tobytes()).decode("utf-8"),
|
||||
}
|
||||
)
|
||||
)
|
||||
except websockets.ConnectionClosed:
|
||||
break
|
||||
except Exception as e:
|
||||
await self._output_queue.put(ErrorSentinel(e))
|
||||
raise e
|
||||
|
||||
await asyncio.sleep(0) # yield control
|
||||
|
||||
async def _process_websocket_connection(self) -> None:
|
||||
try:
|
||||
async with websockets.connect(
|
||||
"wss://api.openai.com/v1/realtime?intent=transcription",
|
||||
additional_headers={
|
||||
"Authorization": f"Bearer {self._client.api_key}",
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
"OpenAI-Log-Session": "1",
|
||||
},
|
||||
) as ws:
|
||||
await self._setup_connection(ws)
|
||||
self._process_events_task = asyncio.create_task(self._handle_events())
|
||||
self._stream_audio_task = asyncio.create_task(self._stream_audio(self._input_queue))
|
||||
self.connected = True
|
||||
if self._listener_task:
|
||||
await self._listener_task
|
||||
else:
|
||||
logger.error("Listener task not initialized")
|
||||
raise AgentsException("Listener task not initialized")
|
||||
except Exception as e:
|
||||
await self._output_queue.put(ErrorSentinel(e))
|
||||
raise e
|
||||
|
||||
def _check_errors(self) -> None:
|
||||
if self._connection_task and self._connection_task.done():
|
||||
exc = self._connection_task.exception()
|
||||
if exc and isinstance(exc, Exception):
|
||||
self._stored_exception = exc
|
||||
|
||||
if self._process_events_task and self._process_events_task.done():
|
||||
exc = self._process_events_task.exception()
|
||||
if exc and isinstance(exc, Exception):
|
||||
self._stored_exception = exc
|
||||
|
||||
if self._stream_audio_task and self._stream_audio_task.done():
|
||||
exc = self._stream_audio_task.exception()
|
||||
if exc and isinstance(exc, Exception):
|
||||
self._stored_exception = exc
|
||||
|
||||
if self._listener_task and self._listener_task.done():
|
||||
exc = self._listener_task.exception()
|
||||
if exc and isinstance(exc, Exception):
|
||||
self._stored_exception = exc
|
||||
|
||||
def _cleanup_tasks(self) -> None:
|
||||
if self._listener_task and not self._listener_task.done():
|
||||
self._listener_task.cancel()
|
||||
|
||||
if self._process_events_task and not self._process_events_task.done():
|
||||
self._process_events_task.cancel()
|
||||
|
||||
if self._stream_audio_task and not self._stream_audio_task.done():
|
||||
self._stream_audio_task.cancel()
|
||||
|
||||
if self._connection_task and not self._connection_task.done():
|
||||
self._connection_task.cancel()
|
||||
|
||||
async def transcribe_turns(self) -> AsyncIterator[str]:
|
||||
self._connection_task = asyncio.create_task(self._process_websocket_connection())
|
||||
|
||||
while True:
|
||||
try:
|
||||
turn = await self._output_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
if (
|
||||
turn is None
|
||||
or isinstance(turn, ErrorSentinel)
|
||||
or isinstance(turn, SessionCompleteSentinel)
|
||||
):
|
||||
self._output_queue.task_done()
|
||||
break
|
||||
yield turn
|
||||
self._output_queue.task_done()
|
||||
|
||||
if self._tracing_span:
|
||||
self._end_turn("")
|
||||
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
|
||||
self._check_errors()
|
||||
if self._stored_exception:
|
||||
raise self._stored_exception
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
|
||||
self._cleanup_tasks()
|
||||
|
||||
|
||||
class OpenAISTTModel(STTModel):
|
||||
"""A speech-to-text model for OpenAI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
openai_client: AsyncOpenAI,
|
||||
):
|
||||
"""Create a new OpenAI speech-to-text model.
|
||||
|
||||
Args:
|
||||
model: The name of the model to use.
|
||||
openai_client: The OpenAI client to use.
|
||||
"""
|
||||
self.model = model
|
||||
self._client = openai_client
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self.model
|
||||
|
||||
def _non_null_or_not_given(self, value: Any) -> Any:
|
||||
return value if value is not None else None # NOT_GIVEN
|
||||
|
||||
async def transcribe(
|
||||
self,
|
||||
input: AudioInput,
|
||||
settings: STTModelSettings,
|
||||
trace_include_sensitive_data: bool,
|
||||
trace_include_sensitive_audio_data: bool,
|
||||
) -> str:
|
||||
"""Transcribe an audio input.
|
||||
|
||||
Args:
|
||||
input: The audio input to transcribe.
|
||||
settings: The settings to use for the transcription.
|
||||
|
||||
Returns:
|
||||
The transcribed text.
|
||||
"""
|
||||
with transcription_span(
|
||||
model=self.model,
|
||||
input=input.to_base64() if trace_include_sensitive_audio_data else "",
|
||||
input_format="pcm",
|
||||
model_config={
|
||||
"temperature": self._non_null_or_not_given(settings.temperature),
|
||||
"language": self._non_null_or_not_given(settings.language),
|
||||
"prompt": self._non_null_or_not_given(settings.prompt),
|
||||
},
|
||||
) as span:
|
||||
try:
|
||||
response = await self._client.audio.transcriptions.create(
|
||||
model=self.model,
|
||||
file=input.to_audio_file(),
|
||||
prompt=self._non_null_or_not_given(settings.prompt),
|
||||
language=self._non_null_or_not_given(settings.language),
|
||||
temperature=self._non_null_or_not_given(settings.temperature),
|
||||
)
|
||||
if trace_include_sensitive_data:
|
||||
span.span_data.output = response.text
|
||||
return response.text
|
||||
except Exception as e:
|
||||
span.span_data.output = ""
|
||||
span.set_error(SpanError(message=str(e), data={}))
|
||||
raise e
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
input: StreamedAudioInput,
|
||||
settings: STTModelSettings,
|
||||
trace_include_sensitive_data: bool,
|
||||
trace_include_sensitive_audio_data: bool,
|
||||
) -> StreamedTranscriptionSession:
|
||||
"""Create a new transcription session.
|
||||
|
||||
Args:
|
||||
input: The audio input to transcribe.
|
||||
settings: The settings to use for the transcription.
|
||||
trace_include_sensitive_data: Whether to include sensitive data in traces.
|
||||
trace_include_sensitive_audio_data: Whether to include sensitive audio data in traces.
|
||||
|
||||
Returns:
|
||||
A new transcription session.
|
||||
"""
|
||||
return OpenAISTTTranscriptionSession(
|
||||
input,
|
||||
self._client,
|
||||
self.model,
|
||||
settings,
|
||||
trace_include_sensitive_data,
|
||||
trace_include_sensitive_audio_data,
|
||||
)
|
||||
54
src/agents/voice/models/openai_tts.py
Normal file
54
src/agents/voice/models/openai_tts.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Literal
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..model import TTSModel, TTSModelSettings
|
||||
|
||||
DEFAULT_VOICE: Literal["ash"] = "ash"
|
||||
|
||||
|
||||
class OpenAITTSModel(TTSModel):
|
||||
"""A text-to-speech model for OpenAI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
openai_client: AsyncOpenAI,
|
||||
):
|
||||
"""Create a new OpenAI text-to-speech model.
|
||||
|
||||
Args:
|
||||
model: The name of the model to use.
|
||||
openai_client: The OpenAI client to use.
|
||||
"""
|
||||
self.model = model
|
||||
self._client = openai_client
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self.model
|
||||
|
||||
async def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[bytes]:
|
||||
"""Run the text-to-speech model.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech.
|
||||
settings: The settings to use for the text-to-speech model.
|
||||
|
||||
Returns:
|
||||
An iterator of audio chunks.
|
||||
"""
|
||||
response = self._client.audio.speech.with_streaming_response.create(
|
||||
model=self.model,
|
||||
voice=settings.voice or DEFAULT_VOICE,
|
||||
input=text,
|
||||
response_format="pcm",
|
||||
extra_body={
|
||||
"instructions": settings.instructions,
|
||||
},
|
||||
)
|
||||
|
||||
async with response as stream:
|
||||
async for chunk in stream.iter_bytes(chunk_size=1024):
|
||||
yield chunk
|
||||
151
src/agents/voice/pipeline.py
Normal file
151
src/agents/voice/pipeline.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from .._run_impl import TraceCtxManager
|
||||
from ..exceptions import UserError
|
||||
from ..logger import logger
|
||||
from .input import AudioInput, StreamedAudioInput
|
||||
from .model import STTModel, TTSModel
|
||||
from .pipeline_config import VoicePipelineConfig
|
||||
from .result import StreamedAudioResult
|
||||
from .workflow import VoiceWorkflowBase
|
||||
|
||||
|
||||
class VoicePipeline:
|
||||
"""An opinionated voice agent pipeline. It works in three steps:
|
||||
1. Transcribe audio input into text.
|
||||
2. Run the provided `workflow`, which produces a sequence of text responses.
|
||||
3. Convert the text responses into streaming audio output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workflow: VoiceWorkflowBase,
|
||||
stt_model: STTModel | str | None = None,
|
||||
tts_model: TTSModel | str | None = None,
|
||||
config: VoicePipelineConfig | None = None,
|
||||
):
|
||||
"""Create a new voice pipeline.
|
||||
|
||||
Args:
|
||||
workflow: The workflow to run. See `VoiceWorkflowBase`.
|
||||
stt_model: The speech-to-text model to use. If not provided, a default OpenAI
|
||||
model will be used.
|
||||
tts_model: The text-to-speech model to use. If not provided, a default OpenAI
|
||||
model will be used.
|
||||
config: The pipeline configuration. If not provided, a default configuration will be
|
||||
used.
|
||||
"""
|
||||
self.workflow = workflow
|
||||
self.stt_model = stt_model if isinstance(stt_model, STTModel) else None
|
||||
self.tts_model = tts_model if isinstance(tts_model, TTSModel) else None
|
||||
self._stt_model_name = stt_model if isinstance(stt_model, str) else None
|
||||
self._tts_model_name = tts_model if isinstance(tts_model, str) else None
|
||||
self.config = config or VoicePipelineConfig()
|
||||
|
||||
async def run(self, audio_input: AudioInput | StreamedAudioInput) -> StreamedAudioResult:
|
||||
"""Run the voice pipeline.
|
||||
|
||||
Args:
|
||||
audio_input: The audio input to process. This can either be an `AudioInput` instance,
|
||||
which is a single static buffer, or a `StreamedAudioInput` instance, which is a
|
||||
stream of audio data that you can append to.
|
||||
|
||||
Returns:
|
||||
A `StreamedAudioResult` instance. You can use this object to stream audio events and
|
||||
play them out.
|
||||
"""
|
||||
if isinstance(audio_input, AudioInput):
|
||||
return await self._run_single_turn(audio_input)
|
||||
elif isinstance(audio_input, StreamedAudioInput):
|
||||
return await self._run_multi_turn(audio_input)
|
||||
else:
|
||||
raise UserError(f"Unsupported audio input type: {type(audio_input)}")
|
||||
|
||||
def _get_tts_model(self) -> TTSModel:
|
||||
if not self.tts_model:
|
||||
self.tts_model = self.config.model_provider.get_tts_model(self._tts_model_name)
|
||||
return self.tts_model
|
||||
|
||||
def _get_stt_model(self) -> STTModel:
|
||||
if not self.stt_model:
|
||||
self.stt_model = self.config.model_provider.get_stt_model(self._stt_model_name)
|
||||
return self.stt_model
|
||||
|
||||
async def _process_audio_input(self, audio_input: AudioInput) -> str:
|
||||
model = self._get_stt_model()
|
||||
return await model.transcribe(
|
||||
audio_input,
|
||||
self.config.stt_settings,
|
||||
self.config.trace_include_sensitive_data,
|
||||
self.config.trace_include_sensitive_audio_data,
|
||||
)
|
||||
|
||||
async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
|
||||
# Since this is single turn, we can use the TraceCtxManager to manage starting/ending the
|
||||
# trace
|
||||
with TraceCtxManager(
|
||||
workflow_name=self.config.workflow_name or "Voice Agent",
|
||||
trace_id=None, # Automatically generated
|
||||
group_id=self.config.group_id,
|
||||
metadata=self.config.trace_metadata,
|
||||
disabled=self.config.tracing_disabled,
|
||||
):
|
||||
input_text = await self._process_audio_input(audio_input)
|
||||
|
||||
output = StreamedAudioResult(
|
||||
self._get_tts_model(), self.config.tts_settings, self.config
|
||||
)
|
||||
|
||||
async def stream_events():
|
||||
try:
|
||||
async for text_event in self.workflow.run(input_text):
|
||||
await output._add_text(text_event)
|
||||
await output._turn_done()
|
||||
await output._done()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing single turn: {e}")
|
||||
await output._add_error(e)
|
||||
raise e
|
||||
|
||||
output._set_task(asyncio.create_task(stream_events()))
|
||||
return output
|
||||
|
||||
async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
|
||||
with TraceCtxManager(
|
||||
workflow_name=self.config.workflow_name or "Voice Agent",
|
||||
trace_id=None,
|
||||
group_id=self.config.group_id,
|
||||
metadata=self.config.trace_metadata,
|
||||
disabled=self.config.tracing_disabled,
|
||||
):
|
||||
output = StreamedAudioResult(
|
||||
self._get_tts_model(), self.config.tts_settings, self.config
|
||||
)
|
||||
|
||||
transcription_session = await self._get_stt_model().create_session(
|
||||
audio_input,
|
||||
self.config.stt_settings,
|
||||
self.config.trace_include_sensitive_data,
|
||||
self.config.trace_include_sensitive_audio_data,
|
||||
)
|
||||
|
||||
async def process_turns():
|
||||
try:
|
||||
async for input_text in transcription_session.transcribe_turns():
|
||||
result = self.workflow.run(input_text)
|
||||
async for text_event in result:
|
||||
await output._add_text(text_event)
|
||||
await output._turn_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing turns: {e}")
|
||||
await output._add_error(e)
|
||||
raise e
|
||||
finally:
|
||||
await transcription_session.close()
|
||||
await output._done()
|
||||
|
||||
output._set_task(asyncio.create_task(process_turns()))
|
||||
return output
|
||||
46
src/agents/voice/pipeline_config.py
Normal file
46
src/agents/voice/pipeline_config.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..tracing.util import gen_group_id
|
||||
from .model import STTModelSettings, TTSModelSettings, VoiceModelProvider
|
||||
from .models.openai_model_provider import OpenAIVoiceModelProvider
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoicePipelineConfig:
|
||||
"""Configuration for a `VoicePipeline`."""
|
||||
|
||||
model_provider: VoiceModelProvider = field(default_factory=OpenAIVoiceModelProvider)
|
||||
"""The voice model provider to use for the pipeline. Defaults to OpenAI."""
|
||||
|
||||
tracing_disabled: bool = False
|
||||
"""Whether to disable tracing of the pipeline. Defaults to `False`."""
|
||||
|
||||
trace_include_sensitive_data: bool = True
|
||||
"""Whether to include sensitive data in traces. Defaults to `True`. This is specifically for the
|
||||
voice pipeline, and not for anything that goes on inside your Workflow."""
|
||||
|
||||
trace_include_sensitive_audio_data: bool = True
|
||||
"""Whether to include audio data in traces. Defaults to `True`."""
|
||||
|
||||
workflow_name: str = "Voice Agent"
|
||||
"""The name of the workflow to use for tracing. Defaults to `Voice Agent`."""
|
||||
|
||||
group_id: str = field(default_factory=gen_group_id)
|
||||
"""
|
||||
A grouping identifier to use for tracing, to link multiple traces from the same conversation
|
||||
or process. If not provided, we will create a random group ID.
|
||||
"""
|
||||
|
||||
trace_metadata: dict[str, Any] | None = None
|
||||
"""
|
||||
An optional dictionary of additional metadata to include with the trace.
|
||||
"""
|
||||
|
||||
stt_settings: STTModelSettings = field(default_factory=STTModelSettings)
|
||||
"""The settings to use for the STT model."""
|
||||
|
||||
tts_settings: TTSModelSettings = field(default_factory=TTSModelSettings)
|
||||
"""The settings to use for the TTS model."""
|
||||
287
src/agents/voice/result.py
Normal file
287
src/agents/voice/result.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import UserError
|
||||
from ..logger import logger
|
||||
from ..tracing import Span, SpeechGroupSpanData, speech_group_span, speech_span
|
||||
from ..tracing.util import time_iso
|
||||
from .events import (
|
||||
VoiceStreamEvent,
|
||||
VoiceStreamEventAudio,
|
||||
VoiceStreamEventError,
|
||||
VoiceStreamEventLifecycle,
|
||||
)
|
||||
from .imports import np, npt
|
||||
from .model import TTSModel, TTSModelSettings
|
||||
from .pipeline_config import VoicePipelineConfig
|
||||
|
||||
|
||||
def _audio_to_base64(audio_data: list[bytes]) -> str:
|
||||
joined_audio_data = b"".join(audio_data)
|
||||
return base64.b64encode(joined_audio_data).decode("utf-8")
|
||||
|
||||
|
||||
class StreamedAudioResult:
|
||||
"""The output of a `VoicePipeline`. Streams events and audio data as they're generated."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tts_model: TTSModel,
|
||||
tts_settings: TTSModelSettings,
|
||||
voice_pipeline_config: VoicePipelineConfig,
|
||||
):
|
||||
"""Create a new `StreamedAudioResult` instance.
|
||||
|
||||
Args:
|
||||
tts_model: The TTS model to use.
|
||||
tts_settings: The TTS settings to use.
|
||||
voice_pipeline_config: The voice pipeline config to use.
|
||||
"""
|
||||
self.tts_model = tts_model
|
||||
self.tts_settings = tts_settings
|
||||
self.total_output_text = ""
|
||||
self.instructions = tts_settings.instructions
|
||||
self.text_generation_task: asyncio.Task[Any] | None = None
|
||||
|
||||
self._voice_pipeline_config = voice_pipeline_config
|
||||
self._text_buffer = ""
|
||||
self._turn_text_buffer = ""
|
||||
self._queue: asyncio.Queue[VoiceStreamEvent] = asyncio.Queue()
|
||||
self._tasks: list[asyncio.Task[Any]] = []
|
||||
self._ordered_tasks: list[
|
||||
asyncio.Queue[VoiceStreamEvent | None]
|
||||
] = [] # New: list to hold local queues for each text segment
|
||||
self._dispatcher_task: asyncio.Task[Any] | None = (
|
||||
None # Task to dispatch audio chunks in order
|
||||
)
|
||||
|
||||
self._done_processing = False
|
||||
self._buffer_size = tts_settings.buffer_size
|
||||
self._started_processing_turn = False
|
||||
self._first_byte_received = False
|
||||
self._generation_start_time: str | None = None
|
||||
self._completed_session = False
|
||||
self._stored_exception: BaseException | None = None
|
||||
self._tracing_span: Span[SpeechGroupSpanData] | None = None
|
||||
|
||||
async def _start_turn(self):
|
||||
if self._started_processing_turn:
|
||||
return
|
||||
|
||||
self._tracing_span = speech_group_span()
|
||||
self._tracing_span.start()
|
||||
self._started_processing_turn = True
|
||||
self._first_byte_received = False
|
||||
self._generation_start_time = time_iso()
|
||||
await self._queue.put(VoiceStreamEventLifecycle(event="turn_started"))
|
||||
|
||||
def _set_task(self, task: asyncio.Task[Any]):
|
||||
self.text_generation_task = task
|
||||
|
||||
async def _add_error(self, error: Exception):
|
||||
await self._queue.put(VoiceStreamEventError(error))
|
||||
|
||||
def _transform_audio_buffer(
|
||||
self, buffer: list[bytes], output_dtype: npt.DTypeLike
|
||||
) -> npt.NDArray[np.int16 | np.float32]:
|
||||
np_array = np.frombuffer(b"".join(buffer), dtype=np.int16)
|
||||
|
||||
if output_dtype == np.int16:
|
||||
return np_array
|
||||
elif output_dtype == np.float32:
|
||||
return (np_array.astype(np.float32) / 32767.0).reshape(-1, 1)
|
||||
else:
|
||||
raise UserError("Invalid output dtype")
|
||||
|
||||
async def _stream_audio(
|
||||
self,
|
||||
text: str,
|
||||
local_queue: asyncio.Queue[VoiceStreamEvent | None],
|
||||
finish_turn: bool = False,
|
||||
):
|
||||
with speech_span(
|
||||
model=self.tts_model.model_name,
|
||||
input=text if self._voice_pipeline_config.trace_include_sensitive_data else "",
|
||||
model_config={
|
||||
"voice": self.tts_settings.voice,
|
||||
"instructions": self.instructions,
|
||||
"speed": self.tts_settings.speed,
|
||||
},
|
||||
output_format="pcm",
|
||||
parent=self._tracing_span,
|
||||
) as tts_span:
|
||||
try:
|
||||
first_byte_received = False
|
||||
buffer: list[bytes] = []
|
||||
full_audio_data: list[bytes] = []
|
||||
|
||||
async for chunk in self.tts_model.run(text, self.tts_settings):
|
||||
if not first_byte_received:
|
||||
first_byte_received = True
|
||||
tts_span.span_data.first_content_at = time_iso()
|
||||
|
||||
if chunk:
|
||||
buffer.append(chunk)
|
||||
full_audio_data.append(chunk)
|
||||
if len(buffer) >= self._buffer_size:
|
||||
audio_np = self._transform_audio_buffer(buffer, self.tts_settings.dtype)
|
||||
if self.tts_settings.transform_data:
|
||||
audio_np = self.tts_settings.transform_data(audio_np)
|
||||
await local_queue.put(
|
||||
VoiceStreamEventAudio(data=audio_np)
|
||||
) # Use local queue
|
||||
buffer = []
|
||||
if buffer:
|
||||
audio_np = self._transform_audio_buffer(buffer, self.tts_settings.dtype)
|
||||
if self.tts_settings.transform_data:
|
||||
audio_np = self.tts_settings.transform_data(audio_np)
|
||||
await local_queue.put(VoiceStreamEventAudio(data=audio_np)) # Use local queue
|
||||
|
||||
if self._voice_pipeline_config.trace_include_sensitive_audio_data:
|
||||
tts_span.span_data.output = _audio_to_base64(full_audio_data)
|
||||
else:
|
||||
tts_span.span_data.output = ""
|
||||
|
||||
if finish_turn:
|
||||
await local_queue.put(VoiceStreamEventLifecycle(event="turn_ended"))
|
||||
else:
|
||||
await local_queue.put(None) # Signal completion for this segment
|
||||
except Exception as e:
|
||||
tts_span.set_error(
|
||||
{
|
||||
"message": str(e),
|
||||
"data": {
|
||||
"text": text
|
||||
if self._voice_pipeline_config.trace_include_sensitive_data
|
||||
else "",
|
||||
},
|
||||
}
|
||||
)
|
||||
logger.error(f"Error streaming audio: {e}")
|
||||
|
||||
# Signal completion for whole session because of error
|
||||
await local_queue.put(VoiceStreamEventLifecycle(event="session_ended"))
|
||||
raise e
|
||||
|
||||
async def _add_text(self, text: str):
|
||||
await self._start_turn()
|
||||
|
||||
self._text_buffer += text
|
||||
self.total_output_text += text
|
||||
self._turn_text_buffer += text
|
||||
|
||||
combined_sentences, self._text_buffer = self.tts_settings.text_splitter(self._text_buffer)
|
||||
|
||||
if len(combined_sentences) >= 20:
|
||||
local_queue: asyncio.Queue[VoiceStreamEvent | None] = asyncio.Queue()
|
||||
self._ordered_tasks.append(local_queue)
|
||||
self._tasks.append(
|
||||
asyncio.create_task(self._stream_audio(combined_sentences, local_queue))
|
||||
)
|
||||
if self._dispatcher_task is None:
|
||||
self._dispatcher_task = asyncio.create_task(self._dispatch_audio())
|
||||
|
||||
async def _turn_done(self):
|
||||
if self._text_buffer:
|
||||
local_queue: asyncio.Queue[VoiceStreamEvent | None] = asyncio.Queue()
|
||||
self._ordered_tasks.append(local_queue) # Append the local queue for the final segment
|
||||
self._tasks.append(
|
||||
asyncio.create_task(
|
||||
self._stream_audio(self._text_buffer, local_queue, finish_turn=True)
|
||||
)
|
||||
)
|
||||
self._text_buffer = ""
|
||||
self._done_processing = True
|
||||
if self._dispatcher_task is None:
|
||||
self._dispatcher_task = asyncio.create_task(self._dispatch_audio())
|
||||
await asyncio.gather(*self._tasks)
|
||||
|
||||
def _finish_turn(self):
|
||||
if self._tracing_span:
|
||||
if self._voice_pipeline_config.trace_include_sensitive_data:
|
||||
self._tracing_span.span_data.input = self._turn_text_buffer
|
||||
else:
|
||||
self._tracing_span.span_data.input = ""
|
||||
|
||||
self._tracing_span.finish()
|
||||
self._tracing_span = None
|
||||
self._turn_text_buffer = ""
|
||||
self._started_processing_turn = False
|
||||
|
||||
async def _done(self):
|
||||
self._completed_session = True
|
||||
await self._wait_for_completion()
|
||||
|
||||
async def _dispatch_audio(self):
|
||||
# Dispatch audio chunks from each segment in the order they were added
|
||||
while True:
|
||||
if len(self._ordered_tasks) == 0:
|
||||
if self._completed_session:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
continue
|
||||
local_queue = self._ordered_tasks.pop(0)
|
||||
while True:
|
||||
chunk = await local_queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
await self._queue.put(chunk)
|
||||
if isinstance(chunk, VoiceStreamEventLifecycle):
|
||||
local_queue.task_done()
|
||||
if chunk.event == "turn_ended":
|
||||
self._finish_turn()
|
||||
break
|
||||
await self._queue.put(VoiceStreamEventLifecycle(event="session_ended"))
|
||||
|
||||
async def _wait_for_completion(self):
|
||||
tasks: list[asyncio.Task[Any]] = self._tasks
|
||||
if self._dispatcher_task is not None:
|
||||
tasks.append(self._dispatcher_task)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def _cleanup_tasks(self):
|
||||
self._finish_turn()
|
||||
|
||||
for task in self._tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
if self._dispatcher_task and not self._dispatcher_task.done():
|
||||
self._dispatcher_task.cancel()
|
||||
|
||||
if self.text_generation_task and not self.text_generation_task.done():
|
||||
self.text_generation_task.cancel()
|
||||
|
||||
def _check_errors(self):
|
||||
for task in self._tasks:
|
||||
if task.done():
|
||||
if task.exception():
|
||||
self._stored_exception = task.exception()
|
||||
break
|
||||
|
||||
async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
|
||||
"""Stream the events and audio data as they're generated."""
|
||||
while True:
|
||||
try:
|
||||
event = await self._queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
if isinstance(event, VoiceStreamEventError):
|
||||
self._stored_exception = event.error
|
||||
logger.error(f"Error processing output: {event.error}")
|
||||
break
|
||||
if event is None:
|
||||
break
|
||||
yield event
|
||||
if event.type == "voice_stream_event_lifecycle" and event.event == "session_ended":
|
||||
break
|
||||
|
||||
self._check_errors()
|
||||
self._cleanup_tasks()
|
||||
|
||||
if self._stored_exception:
|
||||
raise self._stored_exception
|
||||
37
src/agents/voice/utils.py
Normal file
37
src/agents/voice/utils.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import re
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def get_sentence_based_splitter(
|
||||
min_sentence_length: int = 20,
|
||||
) -> Callable[[str], tuple[str, str]]:
|
||||
"""Returns a function that splits text into chunks based on sentence boundaries.
|
||||
|
||||
Args:
|
||||
min_sentence_length: The minimum length of a sentence to be included in a chunk.
|
||||
|
||||
Returns:
|
||||
A function that splits text into chunks based on sentence boundaries.
|
||||
"""
|
||||
|
||||
def sentence_based_text_splitter(text_buffer: str) -> tuple[str, str]:
|
||||
"""
|
||||
A function to split the text into chunks. This is useful if you want to split the text into
|
||||
chunks before sending it to the TTS model rather than waiting for the whole text to be
|
||||
processed.
|
||||
|
||||
Args:
|
||||
text_buffer: The text to split.
|
||||
|
||||
Returns:
|
||||
A tuple of the text to process and the remaining text buffer.
|
||||
"""
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text_buffer.strip())
|
||||
if len(sentences) >= 1:
|
||||
combined_sentences = " ".join(sentences[:-1])
|
||||
if len(combined_sentences) >= min_sentence_length:
|
||||
remaining_text_buffer = sentences[-1]
|
||||
return combined_sentences, remaining_text_buffer
|
||||
return "", text_buffer
|
||||
|
||||
return sentence_based_text_splitter
|
||||
93
src/agents/voice/workflow.py
Normal file
93
src/agents/voice/workflow.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ..agent import Agent
|
||||
from ..items import TResponseInputItem
|
||||
from ..result import RunResultStreaming
|
||||
from ..run import Runner
|
||||
|
||||
|
||||
class VoiceWorkflowBase(abc.ABC):
|
||||
"""
|
||||
A base class for a voice workflow. You must implement the `run` method. A "workflow" is any
|
||||
code you want, that receives a transcription and yields text that will be turned into speech
|
||||
by a text-to-speech model.
|
||||
In most cases, you'll create `Agent`s and use `Runner.run_streamed()` to run them, returning
|
||||
some or all of the text events from the stream. You can use the `VoiceWorkflowHelper` class to
|
||||
help with extracting text events from the stream.
|
||||
If you have a simple workflow that has a single starting agent and no custom logic, you can
|
||||
use `SingleAgentVoiceWorkflow` directly.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, transcription: str) -> AsyncIterator[str]:
|
||||
"""
|
||||
Run the voice workflow. You will receive an input transcription, and must yield text that
|
||||
will be spoken to the user. You can run whatever logic you want here. In most cases, the
|
||||
final logic will involve calling `Runner.run_streamed()` and yielding any text events from
|
||||
the stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class VoiceWorkflowHelper:
|
||||
@classmethod
|
||||
async def stream_text_from(cls, result: RunResultStreaming) -> AsyncIterator[str]:
|
||||
"""Wraps a `RunResultStreaming` object and yields text events from the stream."""
|
||||
async for event in result.stream_events():
|
||||
if (
|
||||
event.type == "raw_response_event"
|
||||
and event.data.type == "response.output_text.delta"
|
||||
):
|
||||
yield event.data.delta
|
||||
|
||||
|
||||
class SingleAgentWorkflowCallbacks:
|
||||
def on_run(self, workflow: SingleAgentVoiceWorkflow, transcription: str) -> None:
|
||||
"""Called when the workflow is run."""
|
||||
pass
|
||||
|
||||
|
||||
class SingleAgentVoiceWorkflow(VoiceWorkflowBase):
|
||||
"""A simple voice workflow that runs a single agent. Each transcription and result is added to
|
||||
the input history.
|
||||
For more complex workflows (e.g. multiple Runner calls, custom message history, custom logic,
|
||||
custom configs), subclass `VoiceWorkflowBase` and implement your own logic.
|
||||
"""
|
||||
|
||||
def __init__(self, agent: Agent[Any], callbacks: SingleAgentWorkflowCallbacks | None = None):
|
||||
"""Create a new single agent voice workflow.
|
||||
|
||||
Args:
|
||||
agent: The agent to run.
|
||||
callbacks: Optional callbacks to call during the workflow.
|
||||
"""
|
||||
self._input_history: list[TResponseInputItem] = []
|
||||
self._current_agent = agent
|
||||
self._callbacks = callbacks
|
||||
|
||||
async def run(self, transcription: str) -> AsyncIterator[str]:
|
||||
if self._callbacks:
|
||||
self._callbacks.on_run(self, transcription)
|
||||
|
||||
# Add the transcription to the input history
|
||||
self._input_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": transcription,
|
||||
}
|
||||
)
|
||||
|
||||
# Run the agent
|
||||
result = Runner.run_streamed(self._current_agent, self._input_history)
|
||||
|
||||
# Stream the text from the result
|
||||
async for chunk in VoiceWorkflowHelper.stream_text_from(result):
|
||||
yield chunk
|
||||
|
||||
# Update the input history and current agent
|
||||
self._input_history = result.to_input_list()
|
||||
self._current_agent = result.last_agent
|
||||
|
|
@ -21,6 +21,8 @@ from agents import (
|
|||
UserError,
|
||||
handoff,
|
||||
)
|
||||
from agents.agent import ToolsToFinalOutputResult
|
||||
from agents.tool import FunctionToolResult, function_tool
|
||||
|
||||
from .fake_model import FakeModel
|
||||
from .test_responses import (
|
||||
|
|
@ -552,3 +554,83 @@ async def test_output_guardrail_tripwire_triggered_causes_exception():
|
|||
|
||||
with pytest.raises(OutputGuardrailTripwireTriggered):
|
||||
await Runner.run(agent, input="user_message")
|
||||
|
||||
|
||||
@function_tool
|
||||
def test_tool_one():
|
||||
return Foo(bar="tool_one_result")
|
||||
|
||||
|
||||
@function_tool
|
||||
def test_tool_two():
|
||||
return "tool_two_result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_use_behavior_first_output():
|
||||
model = FakeModel()
|
||||
agent = Agent(
|
||||
name="test",
|
||||
model=model,
|
||||
tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
|
||||
tool_use_behavior="stop_on_first_tool",
|
||||
output_type=Foo,
|
||||
)
|
||||
|
||||
model.add_multiple_turn_outputs(
|
||||
[
|
||||
# First turn: a message and tool call
|
||||
[
|
||||
get_text_message("a_message"),
|
||||
get_function_tool_call("test_tool_one", None),
|
||||
get_function_tool_call("test_tool_two", None),
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
result = await Runner.run(agent, input="user_message")
|
||||
|
||||
assert result.final_output == Foo(bar="tool_one_result"), (
|
||||
"should have used the first tool result"
|
||||
)
|
||||
|
||||
|
||||
def custom_tool_use_behavior(
|
||||
context: RunContextWrapper[Any], results: list[FunctionToolResult]
|
||||
) -> ToolsToFinalOutputResult:
|
||||
if "test_tool_one" in [result.tool.name for result in results]:
|
||||
return ToolsToFinalOutputResult(is_final_output=True, final_output="the_final_output")
|
||||
else:
|
||||
return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_use_behavior_custom_function():
|
||||
model = FakeModel()
|
||||
agent = Agent(
|
||||
name="test",
|
||||
model=model,
|
||||
tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
|
||||
tool_use_behavior=custom_tool_use_behavior,
|
||||
)
|
||||
|
||||
model.add_multiple_turn_outputs(
|
||||
[
|
||||
# First turn: a message and tool call
|
||||
[
|
||||
get_text_message("a_message"),
|
||||
get_function_tool_call("test_tool_two", None),
|
||||
],
|
||||
# Second turn: a message and tool call
|
||||
[
|
||||
get_text_message("a_message"),
|
||||
get_function_tool_call("test_tool_one", None),
|
||||
get_function_tool_call("test_tool_two", None),
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
result = await Runner.run(agent, input="user_message")
|
||||
|
||||
assert len(result.raw_responses) == 2, "should have two model responses"
|
||||
assert result.final_output == "the_final_output", "should have used the custom function"
|
||||
|
|
|
|||
|
|
@ -674,7 +674,7 @@ async def test_streaming_events():
|
|||
total_expected_item_count = sum(expected_item_type_map.values())
|
||||
|
||||
assert event_counts["run_item_stream_event"] == total_expected_item_count, (
|
||||
f"Expectd {total_expected_item_count} events, got {event_counts['run_item_stream_event']}"
|
||||
f"Expected {total_expected_item_count} events, got {event_counts['run_item_stream_event']}"
|
||||
f"Expected events were: {expected_item_type_map}, got {event_counts}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from agents import Agent, RunConfig, Runner, trace
|
|||
|
||||
from .fake_model import FakeModel
|
||||
from .test_responses import get_text_message
|
||||
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
|
||||
from .testing_processor import assert_no_traces, fetch_normalized_spans
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -23,9 +23,6 @@ async def test_single_run_is_single_trace():
|
|||
|
||||
await Runner.run(agent, input="first_test")
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -45,12 +42,6 @@ async def test_single_run_is_single_trace():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 1, (
|
||||
f"Got {len(spans)}, but expected 1: the agent span. data:"
|
||||
f"{[span.span_data for span in spans]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_runs_are_multiple_traces():
|
||||
|
|
@ -69,9 +60,6 @@ async def test_multiple_runs_are_multiple_traces():
|
|||
await Runner.run(agent, input="first_test")
|
||||
await Runner.run(agent, input="second_test")
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -105,9 +93,6 @@ async def test_multiple_runs_are_multiple_traces():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 2, f"Got {len(spans)}, but expected 2: agent span per run"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrapped_trace_is_single_trace():
|
||||
|
|
@ -129,9 +114,6 @@ async def test_wrapped_trace_is_single_trace():
|
|||
await Runner.run(agent, input="second_test")
|
||||
await Runner.run(agent, input="third_test")
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -169,9 +151,6 @@ async def test_wrapped_trace_is_single_trace():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 3, f"Got {len(spans)}, but expected 3: the agent span per run"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parent_disabled_trace_disabled_agent_trace():
|
||||
|
|
@ -185,14 +164,7 @@ async def test_parent_disabled_trace_disabled_agent_trace():
|
|||
|
||||
await Runner.run(agent, input="first_test")
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
|
||||
assert fetch_normalized_spans() == snapshot([])
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 0, (
|
||||
f"Expected no spans, got {len(spans)}, with {[x.span_data for x in spans]}"
|
||||
)
|
||||
assert_no_traces()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -206,12 +178,7 @@ async def test_manual_disabling_works():
|
|||
|
||||
await Runner.run(agent, input="first_test", run_config=RunConfig(tracing_disabled=True))
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
|
||||
assert fetch_normalized_spans() == snapshot([])
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 0, f"Got {len(spans)}, but expected no spans"
|
||||
assert_no_traces()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -226,16 +193,29 @@ async def test_trace_config_works():
|
|||
await Runner.run(
|
||||
agent,
|
||||
input="first_test",
|
||||
run_config=RunConfig(workflow_name="Foo bar", group_id="123", trace_id="456"),
|
||||
run_config=RunConfig(workflow_name="Foo bar", group_id="123", trace_id="trace_456"),
|
||||
)
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
export = traces[0].export()
|
||||
assert export is not None, "Trace export should not be None"
|
||||
assert export["workflow_name"] == "Foo bar"
|
||||
assert export["group_id"] == "123"
|
||||
assert export["id"] == "456"
|
||||
assert fetch_normalized_spans(keep_trace_id=True) == snapshot(
|
||||
[
|
||||
{
|
||||
"id": "trace_456",
|
||||
"workflow_name": "Foo bar",
|
||||
"group_id": "123",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -255,9 +235,6 @@ async def test_not_starting_streaming_creates_trace():
|
|||
break
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -277,9 +254,6 @@ async def test_not_starting_streaming_creates_trace():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 1, f"Got {len(spans)}, but expected 1: the agent span"
|
||||
|
||||
# Await the stream to avoid warnings about it not being awaited
|
||||
async for _ in result.stream_events():
|
||||
pass
|
||||
|
|
@ -298,8 +272,24 @@ async def test_streaming_single_run_is_single_trace():
|
|||
async for _ in x.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "Agent workflow",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -324,8 +314,38 @@ async def test_multiple_streamed_runs_are_multiple_traces():
|
|||
async for _ in x.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}"
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "Agent workflow",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_1",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"workflow_name": "Agent workflow",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_1",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -356,8 +376,42 @@ async def test_wrapped_streaming_trace_is_single_trace():
|
|||
async for _ in x.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "test_workflow",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_1",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_1",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_1",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -386,8 +440,42 @@ async def test_wrapped_mixed_trace_is_single_trace():
|
|||
async for _ in x.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "test_workflow",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_1",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_1",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_1",
|
||||
"handoffs": [],
|
||||
"tools": [],
|
||||
"output_type": "str",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -409,8 +497,7 @@ async def test_parent_disabled_trace_disables_streaming_agent_trace():
|
|||
async for _ in x.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
|
||||
assert_no_traces()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -431,5 +518,4 @@ async def test_manual_streaming_disabling_works():
|
|||
async for _ in x.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
|
||||
assert_no_traces()
|
||||
|
|
|
|||
|
|
@ -49,10 +49,10 @@ async def test_simple_function():
|
|||
assert tool.name == "simple_function"
|
||||
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
|
||||
assert result == "6"
|
||||
assert result == 6
|
||||
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
|
||||
assert result == "3"
|
||||
assert result == 3
|
||||
|
||||
# Missing required argument should raise an error
|
||||
with pytest.raises(ModelBehaviorError):
|
||||
|
|
|
|||
|
|
@ -152,9 +152,13 @@ def optional_param_function(a: int, b: Optional[int] = None) -> str:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optional_param_function():
|
||||
async def test_non_strict_mode_function():
|
||||
tool = optional_param_function
|
||||
|
||||
assert tool.strict_json_schema is False, "strict_json_schema should be False"
|
||||
|
||||
assert tool.params_json_schema.get("required") == ["a"], "required should only be a"
|
||||
|
||||
input_data = {"a": 5}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "5_no_b"
|
||||
|
|
@ -165,7 +169,7 @@ async def test_optional_param_function():
|
|||
|
||||
|
||||
@function_tool(strict_mode=False)
|
||||
def multiple_optional_params_function(
|
||||
def all_optional_params_function(
|
||||
x: int = 42,
|
||||
y: str = "hello",
|
||||
z: Optional[int] = None,
|
||||
|
|
@ -176,8 +180,12 @@ def multiple_optional_params_function(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_optional_params_function():
|
||||
tool = multiple_optional_params_function
|
||||
async def test_all_optional_params_function():
|
||||
tool = all_optional_params_function
|
||||
|
||||
assert tool.strict_json_schema is False, "strict_json_schema should be False"
|
||||
|
||||
assert tool.params_json_schema.get("required") is None, "required should be empty"
|
||||
|
||||
input_data: dict[str, Any] = {}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ class Foo(TypedDict):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structed_output_non_streamed_agent_hooks():
|
||||
async def test_structured_output_non_streamed_agent_hooks():
|
||||
hooks = RunHooksForTests()
|
||||
model = FakeModel()
|
||||
agent_1 = Agent(name="test_1", model=model)
|
||||
|
|
@ -296,7 +296,7 @@ async def test_structed_output_non_streamed_agent_hooks():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structed_output_streamed_agent_hooks():
|
||||
async def test_structured_output_streamed_agent_hooks():
|
||||
hooks = RunHooksForTests()
|
||||
model = FakeModel()
|
||||
agent_1 = Agent(name="test_1", model=model)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
|
|||
from agents.tracing.span_data import ResponseSpanData
|
||||
from tests import fake_model
|
||||
|
||||
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans
|
||||
from .testing_processor import assert_no_spans, fetch_normalized_spans, fetch_ordered_spans
|
||||
|
||||
|
||||
class DummyTracing:
|
||||
|
|
@ -64,13 +64,6 @@ async def test_get_response_creates_trace(monkeypatch):
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 1
|
||||
|
||||
assert isinstance(spans[0].span_data, ResponseSpanData)
|
||||
assert spans[0].span_data.response is not None
|
||||
assert spans[0].span_data.response.id == "dummy-id"
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -96,9 +89,8 @@ async def test_non_data_tracing_doesnt_set_response_id(monkeypatch):
|
|||
[{"workflow_name": "test", "children": [{"type": "response"}]}]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 1
|
||||
assert spans[0].span_data.response is None
|
||||
[span] = fetch_ordered_spans()
|
||||
assert span.span_data.response is None
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
|
|
@ -123,8 +115,7 @@ async def test_disable_tracing_does_not_create_span(monkeypatch):
|
|||
|
||||
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 0
|
||||
assert_no_spans()
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
|
|
@ -164,12 +155,6 @@ async def test_stream_response_creates_trace(monkeypatch):
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 1
|
||||
assert isinstance(spans[0].span_data, ResponseSpanData)
|
||||
assert spans[0].span_data.response is not None
|
||||
assert spans[0].span_data.response.id == "dummy-id-123"
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -203,10 +188,9 @@ async def test_stream_non_data_tracing_doesnt_set_response_id(monkeypatch):
|
|||
[{"workflow_name": "test", "children": [{"type": "response"}]}]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 1
|
||||
assert isinstance(spans[0].span_data, ResponseSpanData)
|
||||
assert spans[0].span_data.response is None
|
||||
[span] = fetch_ordered_spans()
|
||||
assert isinstance(span.span_data, ResponseSpanData)
|
||||
assert span.span_data.response is None
|
||||
|
||||
|
||||
@pytest.mark.allow_call_model_methods
|
||||
|
|
@ -239,5 +223,4 @@ async def test_stream_disabled_tracing_doesnt_create_span(monkeypatch):
|
|||
|
||||
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 0
|
||||
assert_no_spans()
|
||||
|
|
|
|||
161
tests/test_tool_choice_reset.py
Normal file
161
tests/test_tool_choice_reset.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
import pytest
|
||||
|
||||
from agents import Agent, ModelSettings, Runner, Tool
|
||||
from agents._run_impl import RunImpl
|
||||
|
||||
from .fake_model import FakeModel
|
||||
from .test_responses import (
|
||||
get_function_tool,
|
||||
get_function_tool_call,
|
||||
get_text_message,
|
||||
)
|
||||
|
||||
|
||||
class TestToolChoiceReset:
|
||||
|
||||
def test_should_reset_tool_choice_direct(self):
|
||||
"""
|
||||
Test the _should_reset_tool_choice method directly with various inputs
|
||||
to ensure it correctly identifies cases where reset is needed.
|
||||
"""
|
||||
# Case 1: tool_choice = None should not reset
|
||||
model_settings = ModelSettings(tool_choice=None)
|
||||
tools1: list[Tool] = [get_function_tool("tool1")]
|
||||
# Cast to list[Tool] to fix type checking issues
|
||||
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
|
||||
# Case 2: tool_choice = "auto" should not reset
|
||||
model_settings = ModelSettings(tool_choice="auto")
|
||||
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
|
||||
# Case 3: tool_choice = "none" should not reset
|
||||
model_settings = ModelSettings(tool_choice="none")
|
||||
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
|
||||
# Case 4: tool_choice = "required" with one tool should reset
|
||||
model_settings = ModelSettings(tool_choice="required")
|
||||
assert RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
|
||||
# Case 5: tool_choice = "required" with multiple tools should not reset
|
||||
model_settings = ModelSettings(tool_choice="required")
|
||||
tools2: list[Tool] = [get_function_tool("tool1"), get_function_tool("tool2")]
|
||||
assert not RunImpl._should_reset_tool_choice(model_settings, tools2)
|
||||
|
||||
# Case 6: Specific tool choice should reset
|
||||
model_settings = ModelSettings(tool_choice="specific_tool")
|
||||
assert RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_tool_choice_with_multiple_runs(self):
|
||||
"""
|
||||
Test scenario 1: When multiple runs are executed with tool_choice="required"
|
||||
Ensure each run works correctly and doesn't get stuck in infinite loop
|
||||
Also verify that tool_choice remains "required" between runs
|
||||
"""
|
||||
# Set up our fake model with responses for two runs
|
||||
fake_model = FakeModel()
|
||||
fake_model.add_multiple_turn_outputs([
|
||||
[get_text_message("First run response")],
|
||||
[get_text_message("Second run response")]
|
||||
])
|
||||
|
||||
# Create agent with a custom tool and tool_choice="required"
|
||||
custom_tool = get_function_tool("custom_tool")
|
||||
agent = Agent(
|
||||
name="test_agent",
|
||||
model=fake_model,
|
||||
tools=[custom_tool],
|
||||
model_settings=ModelSettings(tool_choice="required"),
|
||||
)
|
||||
|
||||
# First run should work correctly and preserve tool_choice
|
||||
result1 = await Runner.run(agent, "first run")
|
||||
assert result1.final_output == "First run response"
|
||||
assert agent.model_settings.tool_choice == "required", "tool_choice should stay required"
|
||||
|
||||
# Second run should also work correctly with tool_choice still required
|
||||
result2 = await Runner.run(agent, "second run")
|
||||
assert result2.final_output == "Second run response"
|
||||
assert agent.model_settings.tool_choice == "required", "tool_choice should stay required"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_with_stop_at_tool_name(self):
|
||||
"""
|
||||
Test scenario 2: When using required tool_choice with stop_at_tool_names behavior
|
||||
Ensure it correctly stops at the specified tool
|
||||
"""
|
||||
# Set up fake model to return a tool call for second_tool
|
||||
fake_model = FakeModel()
|
||||
fake_model.set_next_output([
|
||||
get_function_tool_call("second_tool", "{}")
|
||||
])
|
||||
|
||||
# Create agent with two tools and tool_choice="required" and stop_at_tool behavior
|
||||
first_tool = get_function_tool("first_tool", return_value="first tool result")
|
||||
second_tool = get_function_tool("second_tool", return_value="second tool result")
|
||||
|
||||
agent = Agent(
|
||||
name="test_agent",
|
||||
model=fake_model,
|
||||
tools=[first_tool, second_tool],
|
||||
model_settings=ModelSettings(tool_choice="required"),
|
||||
tool_use_behavior={"stop_at_tool_names": ["second_tool"]},
|
||||
)
|
||||
|
||||
# Run should stop after using second_tool
|
||||
result = await Runner.run(agent, "run test")
|
||||
assert result.final_output == "second tool result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_specific_tool_choice(self):
|
||||
"""
|
||||
Test scenario 3: When using a specific tool choice name
|
||||
Ensure it doesn't cause infinite loops
|
||||
"""
|
||||
# Set up fake model to return a text message
|
||||
fake_model = FakeModel()
|
||||
fake_model.set_next_output([get_text_message("Test message")])
|
||||
|
||||
# Create agent with specific tool_choice
|
||||
tool1 = get_function_tool("tool1")
|
||||
tool2 = get_function_tool("tool2")
|
||||
tool3 = get_function_tool("tool3")
|
||||
|
||||
agent = Agent(
|
||||
name="test_agent",
|
||||
model=fake_model,
|
||||
tools=[tool1, tool2, tool3],
|
||||
model_settings=ModelSettings(tool_choice="tool1"), # Specific tool
|
||||
)
|
||||
|
||||
# Run should complete without infinite loops
|
||||
result = await Runner.run(agent, "first run")
|
||||
assert result.final_output == "Test message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_with_single_tool(self):
|
||||
"""
|
||||
Test scenario 4: When using required tool_choice with only one tool
|
||||
Ensure it doesn't cause infinite loops
|
||||
"""
|
||||
# Set up fake model to return a tool call followed by a text message
|
||||
fake_model = FakeModel()
|
||||
fake_model.add_multiple_turn_outputs([
|
||||
# First call returns a tool call
|
||||
[get_function_tool_call("custom_tool", "{}")],
|
||||
# Second call returns a text message
|
||||
[get_text_message("Final response")]
|
||||
])
|
||||
|
||||
# Create agent with a single tool and tool_choice="required"
|
||||
custom_tool = get_function_tool("custom_tool", return_value="tool result")
|
||||
agent = Agent(
|
||||
name="test_agent",
|
||||
model=fake_model,
|
||||
tools=[custom_tool],
|
||||
model_settings=ModelSettings(tool_choice="required"),
|
||||
)
|
||||
|
||||
# Run should complete without infinite loops
|
||||
result = await Runner.run(agent, "first run")
|
||||
assert result.final_output == "Final response"
|
||||
194
tests/test_tool_use_behavior.py
Normal file
194
tests/test_tool_use_behavior.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
# Copyright
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from openai.types.responses.response_input_item_param import FunctionCallOutput
|
||||
|
||||
from agents import (
|
||||
Agent,
|
||||
FunctionToolResult,
|
||||
RunConfig,
|
||||
RunContextWrapper,
|
||||
ToolCallOutputItem,
|
||||
ToolsToFinalOutputResult,
|
||||
UserError,
|
||||
)
|
||||
from agents._run_impl import RunImpl
|
||||
|
||||
from .test_responses import get_function_tool
|
||||
|
||||
|
||||
def _make_function_tool_result(
|
||||
agent: Agent, output: str, tool_name: str | None = None
|
||||
) -> FunctionToolResult:
|
||||
# Construct a FunctionToolResult with the given output using a simple function tool.
|
||||
tool = get_function_tool(tool_name or "dummy", return_value=output)
|
||||
raw_item: FunctionCallOutput = cast(
|
||||
FunctionCallOutput,
|
||||
{
|
||||
"call_id": "1",
|
||||
"output": output,
|
||||
"type": "function_call_output",
|
||||
},
|
||||
)
|
||||
# For this test we don't care about the specific RunItem subclass, only the output field
|
||||
run_item = ToolCallOutputItem(agent=agent, raw_item=raw_item, output=output)
|
||||
return FunctionToolResult(tool=tool, output=output, run_item=run_item)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_results_returns_not_final_output() -> None:
|
||||
# If there are no tool results at all, tool_use_behavior should not produce a final output.
|
||||
agent = Agent(name="test")
|
||||
result = await RunImpl._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=[],
|
||||
context_wrapper=RunContextWrapper(context=None),
|
||||
config=RunConfig(),
|
||||
)
|
||||
assert result.is_final_output is False
|
||||
assert result.final_output is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_llm_again_behavior() -> None:
|
||||
# With the default run_llm_again behavior, even with tools we still expect to keep running.
|
||||
agent = Agent(name="test", tool_use_behavior="run_llm_again")
|
||||
tool_results = [_make_function_tool_result(agent, "ignored")]
|
||||
result = await RunImpl._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=tool_results,
|
||||
context_wrapper=RunContextWrapper(context=None),
|
||||
config=RunConfig(),
|
||||
)
|
||||
assert result.is_final_output is False
|
||||
assert result.final_output is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_on_first_tool_behavior() -> None:
|
||||
# When tool_use_behavior is stop_on_first_tool, we should surface first tool output as final.
|
||||
agent = Agent(name="test", tool_use_behavior="stop_on_first_tool")
|
||||
tool_results = [
|
||||
_make_function_tool_result(agent, "first_tool_output"),
|
||||
_make_function_tool_result(agent, "ignored"),
|
||||
]
|
||||
result = await RunImpl._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=tool_results,
|
||||
context_wrapper=RunContextWrapper(context=None),
|
||||
config=RunConfig(),
|
||||
)
|
||||
assert result.is_final_output is True
|
||||
assert result.final_output == "first_tool_output"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_tool_use_behavior_sync() -> None:
|
||||
"""If tool_use_behavior is a sync function, we should call it and propagate its return."""
|
||||
|
||||
def behavior(
|
||||
context: RunContextWrapper, results: list[FunctionToolResult]
|
||||
) -> ToolsToFinalOutputResult:
|
||||
assert len(results) == 3
|
||||
return ToolsToFinalOutputResult(is_final_output=True, final_output="custom")
|
||||
|
||||
agent = Agent(name="test", tool_use_behavior=behavior)
|
||||
tool_results = [
|
||||
_make_function_tool_result(agent, "ignored1"),
|
||||
_make_function_tool_result(agent, "ignored2"),
|
||||
_make_function_tool_result(agent, "ignored3"),
|
||||
]
|
||||
result = await RunImpl._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=tool_results,
|
||||
context_wrapper=RunContextWrapper(context=None),
|
||||
config=RunConfig(),
|
||||
)
|
||||
assert result.is_final_output is True
|
||||
assert result.final_output == "custom"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_tool_use_behavior_async() -> None:
|
||||
"""If tool_use_behavior is an async function, we should await it and propagate its return."""
|
||||
|
||||
async def behavior(
|
||||
context: RunContextWrapper, results: list[FunctionToolResult]
|
||||
) -> ToolsToFinalOutputResult:
|
||||
assert len(results) == 3
|
||||
return ToolsToFinalOutputResult(is_final_output=True, final_output="async_custom")
|
||||
|
||||
agent = Agent(name="test", tool_use_behavior=behavior)
|
||||
tool_results = [
|
||||
_make_function_tool_result(agent, "ignored1"),
|
||||
_make_function_tool_result(agent, "ignored2"),
|
||||
_make_function_tool_result(agent, "ignored3"),
|
||||
]
|
||||
result = await RunImpl._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=tool_results,
|
||||
context_wrapper=RunContextWrapper(context=None),
|
||||
config=RunConfig(),
|
||||
)
|
||||
assert result.is_final_output is True
|
||||
assert result.final_output == "async_custom"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_tool_use_behavior_raises() -> None:
|
||||
"""If tool_use_behavior is invalid, we should raise a UserError."""
|
||||
agent = Agent(name="test")
|
||||
# Force an invalid value; mypy will complain, so ignore the type here.
|
||||
agent.tool_use_behavior = "bad_value" # type: ignore[assignment]
|
||||
tool_results = [_make_function_tool_result(agent, "ignored")]
|
||||
with pytest.raises(UserError):
|
||||
await RunImpl._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=tool_results,
|
||||
context_wrapper=RunContextWrapper(context=None),
|
||||
config=RunConfig(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_names_to_stop_at_behavior() -> None:
|
||||
agent = Agent(
|
||||
name="test",
|
||||
tools=[
|
||||
get_function_tool("tool1", return_value="tool1_output"),
|
||||
get_function_tool("tool2", return_value="tool2_output"),
|
||||
get_function_tool("tool3", return_value="tool3_output"),
|
||||
],
|
||||
tool_use_behavior={"stop_at_tool_names": ["tool1"]},
|
||||
)
|
||||
|
||||
tool_results = [
|
||||
_make_function_tool_result(agent, "ignored1", "tool2"),
|
||||
_make_function_tool_result(agent, "ignored3", "tool3"),
|
||||
]
|
||||
result = await RunImpl._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=tool_results,
|
||||
context_wrapper=RunContextWrapper(context=None),
|
||||
config=RunConfig(),
|
||||
)
|
||||
assert result.is_final_output is False, "We should not have stopped at tool1"
|
||||
|
||||
# Now test with a tool that matches the list
|
||||
tool_results = [
|
||||
_make_function_tool_result(agent, "output1", "tool1"),
|
||||
_make_function_tool_result(agent, "ignored2", "tool2"),
|
||||
_make_function_tool_result(agent, "ignored3", "tool3"),
|
||||
]
|
||||
result = await RunImpl._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=tool_results,
|
||||
context_wrapper=RunContextWrapper(context=None),
|
||||
config=RunConfig(),
|
||||
)
|
||||
assert result.is_final_output is True, "We should have stopped at tool1"
|
||||
assert result.final_output == "output1"
|
||||
|
|
@ -4,6 +4,7 @@ import asyncio
|
|||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agents.tracing import (
|
||||
Span,
|
||||
|
|
@ -17,7 +18,12 @@ from agents.tracing import (
|
|||
)
|
||||
from agents.tracing.spans import SpanError
|
||||
|
||||
from .testing_processor import fetch_events, fetch_ordered_spans, fetch_traces
|
||||
from .testing_processor import (
|
||||
SPAN_PROCESSOR_TESTING,
|
||||
assert_no_traces,
|
||||
fetch_events,
|
||||
fetch_normalized_spans,
|
||||
)
|
||||
|
||||
### HELPERS
|
||||
|
||||
|
|
@ -47,7 +53,7 @@ def simple_tracing():
|
|||
x = trace("test")
|
||||
x.start()
|
||||
|
||||
span_1 = agent_span(name="agent_1", parent=x)
|
||||
span_1 = agent_span(name="agent_1", span_id="span_1", parent=x)
|
||||
span_1.start()
|
||||
span_1.finish()
|
||||
|
||||
|
|
@ -66,33 +72,36 @@ def simple_tracing():
|
|||
def test_simple_tracing() -> None:
|
||||
simple_tracing()
|
||||
|
||||
spans, traces = fetch_ordered_spans(), fetch_traces()
|
||||
assert len(spans) == 3
|
||||
assert len(traces) == 1
|
||||
|
||||
trace = traces[0]
|
||||
standard_trace_checks(trace, name_check="test")
|
||||
trace_id = trace.trace_id
|
||||
|
||||
first_span = spans[0]
|
||||
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="agent")
|
||||
assert first_span.span_data.name == "agent_1"
|
||||
|
||||
second_span = spans[1]
|
||||
standard_span_checks(second_span, trace_id=trace_id, parent_id=None, span_type="custom")
|
||||
assert second_span.span_id == "span_2"
|
||||
assert second_span.span_data.name == "custom_1"
|
||||
|
||||
third_span = spans[2]
|
||||
standard_span_checks(
|
||||
third_span, trace_id=trace_id, parent_id=second_span.span_id, span_type="custom"
|
||||
assert fetch_normalized_spans(keep_span_id=True) == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "test",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"id": "span_1",
|
||||
"data": {"name": "agent_1"},
|
||||
},
|
||||
{
|
||||
"type": "custom",
|
||||
"id": "span_2",
|
||||
"data": {"name": "custom_1", "data": {}},
|
||||
"children": [
|
||||
{
|
||||
"type": "custom",
|
||||
"id": "span_3",
|
||||
"data": {"name": "custom_2", "data": {}},
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
assert third_span.span_id == "span_3"
|
||||
assert third_span.span_data.name == "custom_2"
|
||||
|
||||
|
||||
def ctxmanager_spans():
|
||||
with trace(workflow_name="test", trace_id="123", group_id="456"):
|
||||
with trace(workflow_name="test", trace_id="trace_123", group_id="456"):
|
||||
with custom_span(name="custom_1", span_id="span_1"):
|
||||
with custom_span(name="custom_2", span_id="span_1_inner"):
|
||||
pass
|
||||
|
|
@ -104,36 +113,38 @@ def ctxmanager_spans():
|
|||
def test_ctxmanager_spans() -> None:
|
||||
ctxmanager_spans()
|
||||
|
||||
spans, traces = fetch_ordered_spans(), fetch_traces()
|
||||
assert len(spans) == 3
|
||||
assert len(traces) == 1
|
||||
|
||||
trace = traces[0]
|
||||
standard_trace_checks(trace, name_check="test")
|
||||
trace_id = trace.trace_id
|
||||
|
||||
first_span = spans[0]
|
||||
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="custom")
|
||||
assert first_span.span_id == "span_1"
|
||||
|
||||
first_inner_span = spans[1]
|
||||
standard_span_checks(
|
||||
first_inner_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="custom"
|
||||
assert fetch_normalized_spans(keep_span_id=True) == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "test",
|
||||
"group_id": "456",
|
||||
"children": [
|
||||
{
|
||||
"type": "custom",
|
||||
"id": "span_1",
|
||||
"data": {"name": "custom_1", "data": {}},
|
||||
"children": [
|
||||
{
|
||||
"type": "custom",
|
||||
"id": "span_1_inner",
|
||||
"data": {"name": "custom_2", "data": {}},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"type": "custom", "id": "span_2", "data": {"name": "custom_2", "data": {}}},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
assert first_inner_span.span_id == "span_1_inner"
|
||||
|
||||
second_span = spans[2]
|
||||
standard_span_checks(second_span, trace_id=trace_id, parent_id=None, span_type="custom")
|
||||
assert second_span.span_id == "span_2"
|
||||
|
||||
|
||||
async def run_subtask(span_id: str | None = None) -> None:
|
||||
with generation_span(span_id=span_id):
|
||||
await asyncio.sleep(0.01)
|
||||
await asyncio.sleep(0.0001)
|
||||
|
||||
|
||||
async def simple_async_tracing():
|
||||
with trace(workflow_name="test", trace_id="123", group_id="456"):
|
||||
with trace(workflow_name="test", trace_id="trace_123", group_id="group_456"):
|
||||
await run_subtask(span_id="span_1")
|
||||
await run_subtask(span_id="span_2")
|
||||
|
||||
|
|
@ -142,21 +153,18 @@ async def simple_async_tracing():
|
|||
async def test_async_tracing() -> None:
|
||||
await simple_async_tracing()
|
||||
|
||||
spans, traces = fetch_ordered_spans(), fetch_traces()
|
||||
assert len(spans) == 2
|
||||
assert len(traces) == 1
|
||||
|
||||
trace = traces[0]
|
||||
standard_trace_checks(trace, name_check="test")
|
||||
trace_id = trace.trace_id
|
||||
|
||||
# We don't care about ordering here, just that they're there
|
||||
for s in spans:
|
||||
standard_span_checks(s, trace_id=trace_id, parent_id=None, span_type="generation")
|
||||
|
||||
ids = [span.span_id for span in spans]
|
||||
assert "span_1" in ids
|
||||
assert "span_2" in ids
|
||||
assert fetch_normalized_spans(keep_span_id=True) == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "test",
|
||||
"group_id": "group_456",
|
||||
"children": [
|
||||
{"type": "generation", "id": "span_1"},
|
||||
{"type": "generation", "id": "span_2"},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def run_tasks_parallel(span_ids: list[str]) -> None:
|
||||
|
|
@ -171,13 +179,11 @@ async def run_tasks_as_children(first_span_id: str, second_span_id: str) -> None
|
|||
|
||||
|
||||
async def complex_async_tracing():
|
||||
with trace(workflow_name="test", trace_id="123", group_id="456"):
|
||||
await asyncio.sleep(0.01)
|
||||
with trace(workflow_name="test", trace_id="trace_123", group_id="456"):
|
||||
await asyncio.gather(
|
||||
run_tasks_parallel(["span_1", "span_2"]),
|
||||
run_tasks_parallel(["span_3", "span_4"]),
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
await asyncio.gather(
|
||||
run_tasks_as_children("span_5", "span_6"),
|
||||
run_tasks_as_children("span_7", "span_8"),
|
||||
|
|
@ -186,39 +192,38 @@ async def complex_async_tracing():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_async_tracing() -> None:
|
||||
await complex_async_tracing()
|
||||
for _ in range(300):
|
||||
SPAN_PROCESSOR_TESTING.clear()
|
||||
await complex_async_tracing()
|
||||
|
||||
spans, traces = fetch_ordered_spans(), fetch_traces()
|
||||
assert len(spans) == 8
|
||||
assert len(traces) == 1
|
||||
|
||||
trace = traces[0]
|
||||
standard_trace_checks(trace, name_check="test")
|
||||
trace_id = trace.trace_id
|
||||
|
||||
# First ensure 1,2,3,4 exist and are in parallel with the trace as parent
|
||||
for span_id in ["span_1", "span_2", "span_3", "span_4"]:
|
||||
span = next((s for s in spans if s.span_id == span_id), None)
|
||||
assert span is not None
|
||||
standard_span_checks(span, trace_id=trace_id, parent_id=None, span_type="generation")
|
||||
|
||||
# Ensure 5 and 7 exist and have the trace as parent
|
||||
for span_id in ["span_5", "span_7"]:
|
||||
span = next((s for s in spans if s.span_id == span_id), None)
|
||||
assert span is not None
|
||||
standard_span_checks(span, trace_id=trace_id, parent_id=None, span_type="generation")
|
||||
|
||||
# Ensure 6 and 8 exist and have 5 and 7 as parents
|
||||
six = next((s for s in spans if s.span_id == "span_6"), None)
|
||||
assert six is not None
|
||||
standard_span_checks(six, trace_id=trace_id, parent_id="span_5", span_type="generation")
|
||||
eight = next((s for s in spans if s.span_id == "span_8"), None)
|
||||
assert eight is not None
|
||||
standard_span_checks(eight, trace_id=trace_id, parent_id="span_7", span_type="generation")
|
||||
assert fetch_normalized_spans(keep_span_id=True) == (
|
||||
[
|
||||
{
|
||||
"workflow_name": "test",
|
||||
"group_id": "456",
|
||||
"children": [
|
||||
{"type": "generation", "id": "span_1"},
|
||||
{"type": "generation", "id": "span_2"},
|
||||
{"type": "generation", "id": "span_3"},
|
||||
{"type": "generation", "id": "span_4"},
|
||||
{
|
||||
"type": "generation",
|
||||
"id": "span_5",
|
||||
"children": [{"type": "generation", "id": "span_6"}],
|
||||
},
|
||||
{
|
||||
"type": "generation",
|
||||
"id": "span_7",
|
||||
"children": [{"type": "generation", "id": "span_8"}],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def spans_with_setters():
|
||||
with trace(workflow_name="test", trace_id="123", group_id="456"):
|
||||
with trace(workflow_name="test", trace_id="trace_123", group_id="456"):
|
||||
with agent_span(name="agent_1") as span_a:
|
||||
span_a.span_data.name = "agent_2"
|
||||
|
||||
|
|
@ -236,34 +241,33 @@ def spans_with_setters():
|
|||
def test_spans_with_setters() -> None:
|
||||
spans_with_setters()
|
||||
|
||||
spans, traces = fetch_ordered_spans(), fetch_traces()
|
||||
assert len(spans) == 4
|
||||
assert len(traces) == 1
|
||||
|
||||
trace = traces[0]
|
||||
standard_trace_checks(trace, name_check="test")
|
||||
trace_id = trace.trace_id
|
||||
|
||||
# Check the spans
|
||||
first_span = spans[0]
|
||||
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="agent")
|
||||
assert first_span.span_data.name == "agent_2"
|
||||
|
||||
second_span = spans[1]
|
||||
standard_span_checks(
|
||||
second_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="function"
|
||||
)
|
||||
assert second_span.span_data.input == "i"
|
||||
assert second_span.span_data.output == "o"
|
||||
|
||||
third_span = spans[2]
|
||||
standard_span_checks(
|
||||
third_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="generation"
|
||||
)
|
||||
|
||||
fourth_span = spans[3]
|
||||
standard_span_checks(
|
||||
fourth_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="handoff"
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "test",
|
||||
"group_id": "456",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {"name": "agent_2"},
|
||||
"children": [
|
||||
{
|
||||
"type": "function",
|
||||
"data": {"name": "function_1", "input": "i", "output": "o"},
|
||||
},
|
||||
{
|
||||
"type": "generation",
|
||||
"data": {"input": [{"foo": "bar"}]},
|
||||
},
|
||||
{
|
||||
"type": "handoff",
|
||||
"data": {"from_agent": "agent_1", "to_agent": "agent_2"},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -276,14 +280,11 @@ def disabled_tracing():
|
|||
|
||||
def test_disabled_tracing():
|
||||
disabled_tracing()
|
||||
|
||||
spans, traces = fetch_ordered_spans(), fetch_traces()
|
||||
assert len(spans) == 0
|
||||
assert len(traces) == 0
|
||||
assert_no_traces()
|
||||
|
||||
|
||||
def enabled_trace_disabled_span():
|
||||
with trace(workflow_name="test", trace_id="123"):
|
||||
with trace(workflow_name="test", trace_id="trace_123"):
|
||||
with agent_span(name="agent_1"):
|
||||
with function_span(name="function_1", disabled=True):
|
||||
with generation_span():
|
||||
|
|
@ -293,17 +294,19 @@ def enabled_trace_disabled_span():
|
|||
def test_enabled_trace_disabled_span():
|
||||
enabled_trace_disabled_span()
|
||||
|
||||
spans, traces = fetch_ordered_spans(), fetch_traces()
|
||||
assert len(spans) == 1 # Only the agent span is recorded
|
||||
assert len(traces) == 1 # The trace is recorded
|
||||
|
||||
trace = traces[0]
|
||||
standard_trace_checks(trace, name_check="test")
|
||||
trace_id = trace.trace_id
|
||||
|
||||
first_span = spans[0]
|
||||
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="agent")
|
||||
assert first_span.span_data.name == "agent_1"
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "test",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {"name": "agent_1"},
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_start_and_end_called_manual():
|
||||
|
|
@ -367,9 +370,7 @@ async def test_noop_span_doesnt_record():
|
|||
with custom_span(name="span_1") as span:
|
||||
span.set_error(SpanError(message="test", data={}))
|
||||
|
||||
spans, traces = fetch_ordered_spans(), fetch_traces()
|
||||
assert len(spans) == 0
|
||||
assert len(traces) == 0
|
||||
assert_no_traces()
|
||||
|
||||
assert t.export() is None
|
||||
assert span.export() is None
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from agents import (
|
|||
Runner,
|
||||
TResponseInputItem,
|
||||
)
|
||||
from agents.tracing import AgentSpanData, FunctionSpanData, GenerationSpanData
|
||||
|
||||
from .fake_model import FakeModel
|
||||
from .test_responses import (
|
||||
|
|
@ -28,7 +27,7 @@ from .test_responses import (
|
|||
get_handoff_tool_call,
|
||||
get_text_message,
|
||||
)
|
||||
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
|
||||
from .testing_processor import fetch_normalized_spans
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -43,9 +42,6 @@ async def test_single_turn_model_error():
|
|||
with pytest.raises(ValueError):
|
||||
await Runner.run(agent, input="first_test")
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -74,13 +70,6 @@ async def test_single_turn_model_error():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}"
|
||||
|
||||
generation_span = spans[1]
|
||||
assert isinstance(generation_span.span_data, GenerationSpanData)
|
||||
assert generation_span.error, "should have error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_turn_no_handoffs():
|
||||
|
|
@ -106,9 +95,6 @@ async def test_multi_turn_no_handoffs():
|
|||
with pytest.raises(ValueError):
|
||||
await Runner.run(agent, input="first_test")
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -146,15 +132,6 @@ async def test_multi_turn_no_handoffs():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 4, (
|
||||
f"should have agent, generation, tool, generation, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
last_generation_span = [x for x in spans if isinstance(x.span_data, GenerationSpanData)][-1]
|
||||
assert last_generation_span.error, "should have error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_error():
|
||||
|
|
@ -173,9 +150,6 @@ async def test_tool_call_error():
|
|||
with pytest.raises(ModelBehaviorError):
|
||||
await Runner.run(agent, input="first_test")
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -209,15 +183,6 @@ async def test_tool_call_error():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 3, (
|
||||
f"should have agent, generation, tool spans, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
function_span = [x for x in spans if isinstance(x.span_data, FunctionSpanData)][0]
|
||||
assert function_span.error, "should have error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_handoff_doesnt_error():
|
||||
|
|
@ -255,9 +220,6 @@ async def test_multiple_handoff_doesnt_error():
|
|||
result = await Runner.run(agent_3, input="user_message")
|
||||
assert result.last_agent == agent_1, "should have picked first handoff"
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -295,12 +257,6 @@ async def test_multiple_handoff_doesnt_error():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 7, (
|
||||
f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
|
||||
class Foo(TypedDict):
|
||||
bar: str
|
||||
|
|
@ -326,9 +282,6 @@ async def test_multiple_final_output_doesnt_error():
|
|||
result = await Runner.run(agent_1, input="user_message")
|
||||
assert result.final_output == Foo(bar="abc")
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -344,12 +297,6 @@ async def test_multiple_final_output_doesnt_error():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 2, (
|
||||
f"should have 1 agent, 1 generation, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handoffs_lead_to_correct_agent_spans():
|
||||
|
|
@ -399,9 +346,6 @@ async def test_handoffs_lead_to_correct_agent_spans():
|
|||
f"should have ended on the third agent, got {result.last_agent.name}"
|
||||
)
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -472,12 +416,6 @@ async def test_handoffs_lead_to_correct_agent_spans():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 12, (
|
||||
f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_turns_exceeded():
|
||||
|
|
@ -503,9 +441,6 @@ async def test_max_turns_exceeded():
|
|||
with pytest.raises(MaxTurnsExceeded):
|
||||
await Runner.run(agent, input="user_message", max_turns=2)
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -538,15 +473,6 @@ async def test_max_turns_exceeded():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 5, (
|
||||
f"should have 1 agent span, 2 generations, 2 function calls, got "
|
||||
f"{len(spans)} with data: {[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
|
||||
assert agent_span.error, "last agent should have error"
|
||||
|
||||
|
||||
def guardrail_function(
|
||||
context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
|
||||
|
|
@ -568,9 +494,6 @@ async def test_guardrail_error():
|
|||
with pytest.raises(InputGuardrailTripwireTriggered):
|
||||
await Runner.run(agent, input="user_message")
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -594,12 +517,3 @@ async def test_guardrail_error():
|
|||
}
|
||||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 2, (
|
||||
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
|
||||
assert agent_span.error, "last agent should have error"
|
||||
|
|
|
|||
|
|
@ -10,9 +10,6 @@ from typing_extensions import TypedDict
|
|||
|
||||
from agents import (
|
||||
Agent,
|
||||
AgentSpanData,
|
||||
FunctionSpanData,
|
||||
GenerationSpanData,
|
||||
GuardrailFunctionOutput,
|
||||
InputGuardrail,
|
||||
InputGuardrailTripwireTriggered,
|
||||
|
|
@ -33,7 +30,7 @@ from .test_responses import (
|
|||
get_handoff_tool_call,
|
||||
get_text_message,
|
||||
)
|
||||
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
|
||||
from .testing_processor import fetch_normalized_spans
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -50,9 +47,6 @@ async def test_single_turn_model_error():
|
|||
async for _ in result.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -82,13 +76,6 @@ async def test_single_turn_model_error():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}"
|
||||
|
||||
generation_span = spans[1]
|
||||
assert isinstance(generation_span.span_data, GenerationSpanData)
|
||||
assert generation_span.error, "should have error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_turn_no_handoffs():
|
||||
|
|
@ -116,9 +103,6 @@ async def test_multi_turn_no_handoffs():
|
|||
async for _ in result.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -157,15 +141,6 @@ async def test_multi_turn_no_handoffs():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 4, (
|
||||
f"should have agent, generation, tool, generation, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
last_generation_span = [x for x in spans if isinstance(x.span_data, GenerationSpanData)][-1]
|
||||
assert last_generation_span.error, "should have error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_error():
|
||||
|
|
@ -186,9 +161,6 @@ async def test_tool_call_error():
|
|||
async for _ in result.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -226,15 +198,6 @@ async def test_tool_call_error():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 3, (
|
||||
f"should have agent, generation, tool spans, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
function_span = [x for x in spans if isinstance(x.span_data, FunctionSpanData)][0]
|
||||
assert function_span.error, "should have error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_handoff_doesnt_error():
|
||||
|
|
@ -275,9 +238,6 @@ async def test_multiple_handoff_doesnt_error():
|
|||
|
||||
assert result.last_agent == agent_1, "should have picked first handoff"
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -315,12 +275,6 @@ async def test_multiple_handoff_doesnt_error():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 7, (
|
||||
f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
|
||||
class Foo(TypedDict):
|
||||
bar: str
|
||||
|
|
@ -350,9 +304,6 @@ async def test_multiple_final_output_no_error():
|
|||
assert isinstance(result.final_output, dict)
|
||||
assert result.final_output["bar"] == "abc"
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -368,12 +319,6 @@ async def test_multiple_final_output_no_error():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 2, (
|
||||
f"should have 1 agent, 1 generation, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handoffs_lead_to_correct_agent_spans():
|
||||
|
|
@ -425,85 +370,6 @@ async def test_handoffs_lead_to_correct_agent_spans():
|
|||
f"should have ended on the third agent, got {result.last_agent.name}"
|
||||
)
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
"workflow_name": "Agent workflow",
|
||||
"children": [
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_3",
|
||||
"handoffs": ["test_agent_1", "test_agent_2"],
|
||||
"tools": ["some_function"],
|
||||
"output_type": "str",
|
||||
},
|
||||
"children": [
|
||||
{"type": "generation"},
|
||||
{
|
||||
"type": "function",
|
||||
"data": {
|
||||
"name": "some_function",
|
||||
"input": '{"a": "b"}',
|
||||
"output": "result",
|
||||
},
|
||||
},
|
||||
{"type": "generation"},
|
||||
{
|
||||
"type": "handoff",
|
||||
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_1",
|
||||
"handoffs": ["test_agent_3"],
|
||||
"tools": ["some_function"],
|
||||
"output_type": "str",
|
||||
},
|
||||
"children": [
|
||||
{"type": "generation"},
|
||||
{
|
||||
"type": "function",
|
||||
"data": {
|
||||
"name": "some_function",
|
||||
"input": '{"a": "b"}',
|
||||
"output": "result",
|
||||
},
|
||||
},
|
||||
{"type": "generation"},
|
||||
{
|
||||
"type": "handoff",
|
||||
"data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "agent",
|
||||
"data": {
|
||||
"name": "test_agent_3",
|
||||
"handoffs": ["test_agent_1", "test_agent_2"],
|
||||
"tools": ["some_function"],
|
||||
"output_type": "str",
|
||||
},
|
||||
"children": [{"type": "generation"}],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 12, (
|
||||
f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -601,9 +467,6 @@ async def test_max_turns_exceeded():
|
|||
async for _ in result.stream_events():
|
||||
pass
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -636,15 +499,6 @@ async def test_max_turns_exceeded():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 5, (
|
||||
f"should have 1 agent, 2 generations, 2 function calls, got "
|
||||
f"{len(spans)} with data: {[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
|
||||
assert agent_span.error, "last agent should have error"
|
||||
|
||||
|
||||
def input_guardrail_function(
|
||||
context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
|
||||
|
|
@ -673,9 +527,6 @@ async def test_input_guardrail_error():
|
|||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -703,15 +554,6 @@ async def test_input_guardrail_error():
|
|||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 2, (
|
||||
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
|
||||
assert agent_span.error, "last agent should have error"
|
||||
|
||||
|
||||
def output_guardrail_function(
|
||||
context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any
|
||||
|
|
@ -740,9 +582,6 @@ async def test_output_guardrail_error():
|
|||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
traces = fetch_traces()
|
||||
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
|
||||
|
||||
assert fetch_normalized_spans() == snapshot(
|
||||
[
|
||||
{
|
||||
|
|
@ -766,12 +605,3 @@ async def test_output_guardrail_error():
|
|||
}
|
||||
]
|
||||
)
|
||||
|
||||
spans = fetch_ordered_spans()
|
||||
assert len(spans) == 2, (
|
||||
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "
|
||||
f"{[x.span_data for x in spans]}"
|
||||
)
|
||||
|
||||
agent_span = [x for x in spans if isinstance(x.span_data, AgentSpanData)][-1]
|
||||
assert agent_span.error, "last agent should have error"
|
||||
|
|
|
|||
|
|
@ -80,26 +80,44 @@ def fetch_events() -> list[TestSpanProcessorEvent]:
|
|||
return SPAN_PROCESSOR_TESTING._events
|
||||
|
||||
|
||||
def fetch_normalized_spans():
|
||||
def assert_no_spans():
|
||||
spans = fetch_ordered_spans()
|
||||
if spans:
|
||||
raise AssertionError(f"Expected 0 spans, got {len(spans)}")
|
||||
|
||||
|
||||
def assert_no_traces():
|
||||
traces = fetch_traces()
|
||||
if traces:
|
||||
raise AssertionError(f"Expected 0 traces, got {len(traces)}")
|
||||
assert_no_spans()
|
||||
|
||||
|
||||
def fetch_normalized_spans(
|
||||
keep_span_id: bool = False, keep_trace_id: bool = False
|
||||
) -> list[dict[str, Any]]:
|
||||
nodes: dict[tuple[str, str | None], dict[str, Any]] = {}
|
||||
traces = []
|
||||
for trace_obj in fetch_traces():
|
||||
trace = trace_obj.export()
|
||||
assert trace
|
||||
assert trace.pop("object") == "trace"
|
||||
assert trace.pop("id").startswith("trace_")
|
||||
assert trace["id"].startswith("trace_")
|
||||
if not keep_trace_id:
|
||||
del trace["id"]
|
||||
trace = {k: v for k, v in trace.items() if v is not None}
|
||||
nodes[(trace_obj.trace_id, None)] = trace
|
||||
traces.append(trace)
|
||||
|
||||
if not traces:
|
||||
assert not fetch_ordered_spans()
|
||||
assert traces, "Use assert_no_traces() to check for empty traces"
|
||||
|
||||
for span_obj in fetch_ordered_spans():
|
||||
span = span_obj.export()
|
||||
assert span
|
||||
assert span.pop("object") == "trace.span"
|
||||
assert span.pop("id").startswith("span_")
|
||||
assert span["id"].startswith("span_")
|
||||
if not keep_span_id:
|
||||
del span["id"]
|
||||
assert datetime.fromisoformat(span.pop("started_at"))
|
||||
assert datetime.fromisoformat(span.pop("ended_at"))
|
||||
parent_id = span.pop("parent_id")
|
||||
|
|
|
|||
27
tests/tracing/test_processor_api_key.py
Normal file
27
tests/tracing/test_processor_api_key.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import pytest
|
||||
|
||||
from agents.tracing.processors import BackendSpanExporter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processor_api_key(monkeypatch):
|
||||
# If the API key is not set, it should be None
|
||||
monkeypatch.delenv("OPENAI_API_KEY", None)
|
||||
processor = BackendSpanExporter()
|
||||
assert processor.api_key is None
|
||||
|
||||
# If we set it afterwards, it should be the new value
|
||||
processor.set_api_key("test_api_key")
|
||||
assert processor.api_key == "test_api_key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processor_api_key_from_env(monkeypatch):
|
||||
# If the API key is not set at creation time but set before access time, it should be the new
|
||||
# value
|
||||
monkeypatch.delenv("OPENAI_API_KEY", None)
|
||||
processor = BackendSpanExporter()
|
||||
|
||||
# If we set it afterwards, it should be the new value
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "foo_bar_123")
|
||||
assert processor.api_key == "foo_bar_123"
|
||||
0
tests/voice/__init__.py
Normal file
0
tests/voice/__init__.py
Normal file
14
tests/voice/conftest.py
Normal file
14
tests/voice/conftest.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
if sys.version_info[:2] == (3, 9):
|
||||
this_dir = os.path.dirname(__file__)
|
||||
skip_marker = pytest.mark.skip(reason="Skipped on Python 3.9")
|
||||
|
||||
for item in items:
|
||||
if item.fspath.dirname.startswith(this_dir):
|
||||
item.add_marker(skip_marker)
|
||||
115
tests/voice/fake_models.py
Normal file
115
tests/voice/fake_models.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
try:
|
||||
from agents.voice import (
|
||||
AudioInput,
|
||||
StreamedAudioInput,
|
||||
StreamedTranscriptionSession,
|
||||
STTModel,
|
||||
STTModelSettings,
|
||||
TTSModel,
|
||||
TTSModelSettings,
|
||||
VoiceWorkflowBase,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class FakeTTS(TTSModel):
|
||||
"""Fakes TTS by just returning string bytes."""
|
||||
|
||||
def __init__(self, strategy: Literal["default", "split_words"] = "default"):
|
||||
self.strategy = strategy
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return "fake_tts"
|
||||
|
||||
async def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[bytes]:
|
||||
if self.strategy == "default":
|
||||
yield np.zeros(2, dtype=np.int16).tobytes()
|
||||
elif self.strategy == "split_words":
|
||||
for _ in text.split():
|
||||
yield np.zeros(2, dtype=np.int16).tobytes()
|
||||
|
||||
async def verify_audio(self, text: str, audio: bytes, dtype: npt.DTypeLike = np.int16) -> None:
|
||||
assert audio == np.zeros(2, dtype=dtype).tobytes()
|
||||
|
||||
async def verify_audio_chunks(
|
||||
self, text: str, audio_chunks: list[bytes], dtype: npt.DTypeLike = np.int16
|
||||
) -> None:
|
||||
assert audio_chunks == [np.zeros(2, dtype=dtype).tobytes() for _word in text.split()]
|
||||
|
||||
|
||||
class FakeSession(StreamedTranscriptionSession):
|
||||
"""A fake streamed transcription session that yields preconfigured transcripts."""
|
||||
|
||||
def __init__(self):
|
||||
self.outputs: list[str] = []
|
||||
|
||||
async def transcribe_turns(self) -> AsyncIterator[str]:
|
||||
for t in self.outputs:
|
||||
yield t
|
||||
|
||||
async def close(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class FakeSTT(STTModel):
|
||||
"""A fake STT model that either returns a single transcript or yields multiple."""
|
||||
|
||||
def __init__(self, outputs: list[str] | None = None):
|
||||
self.outputs = outputs or []
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return "fake_stt"
|
||||
|
||||
async def transcribe(self, _: AudioInput, __: STTModelSettings, ___: bool, ____: bool) -> str:
|
||||
return self.outputs.pop(0)
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
_: StreamedAudioInput,
|
||||
__: STTModelSettings,
|
||||
___: bool,
|
||||
____: bool,
|
||||
) -> StreamedTranscriptionSession:
|
||||
session = FakeSession()
|
||||
session.outputs = self.outputs
|
||||
return session
|
||||
|
||||
|
||||
class FakeWorkflow(VoiceWorkflowBase):
|
||||
"""A fake workflow that yields preconfigured outputs."""
|
||||
|
||||
def __init__(self, outputs: list[list[str]] | None = None):
|
||||
self.outputs = outputs or []
|
||||
|
||||
def add_output(self, output: list[str]) -> None:
|
||||
self.outputs.append(output)
|
||||
|
||||
def add_multiple_outputs(self, outputs: list[list[str]]) -> None:
|
||||
self.outputs.extend(outputs)
|
||||
|
||||
async def run(self, _: str) -> AsyncIterator[str]:
|
||||
if not self.outputs:
|
||||
raise ValueError("No output configured")
|
||||
output = self.outputs.pop(0)
|
||||
for t in output:
|
||||
yield t
|
||||
|
||||
|
||||
class FakeStreamedAudioInput:
|
||||
@classmethod
|
||||
async def get(cls, count: int) -> StreamedAudioInput:
|
||||
input = StreamedAudioInput()
|
||||
for _ in range(count):
|
||||
await input.add_audio(np.zeros(2, dtype=np.int16))
|
||||
return input
|
||||
21
tests/voice/helpers.py
Normal file
21
tests/voice/helpers.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
try:
|
||||
from agents.voice import StreamedAudioResult
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
async def extract_events(result: StreamedAudioResult) -> tuple[list[str], list[bytes]]:
|
||||
"""Collapse pipeline stream events to simple labels for ordering assertions."""
|
||||
flattened: list[str] = []
|
||||
audio_chunks: list[bytes] = []
|
||||
|
||||
async for ev in result.stream():
|
||||
if ev.type == "voice_stream_event_audio":
|
||||
if ev.data is not None:
|
||||
audio_chunks.append(ev.data.tobytes())
|
||||
flattened.append("audio")
|
||||
elif ev.type == "voice_stream_event_lifecycle":
|
||||
flattened.append(ev.event)
|
||||
elif ev.type == "voice_stream_event_error":
|
||||
flattened.append("error")
|
||||
return flattened, audio_chunks
|
||||
127
tests/voice/test_input.py
Normal file
127
tests/voice/test_input.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import io
|
||||
import wave
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from agents import UserError
|
||||
from agents.voice import AudioInput, StreamedAudioInput
|
||||
from agents.voice.input import DEFAULT_SAMPLE_RATE, _buffer_to_audio_file
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def test_buffer_to_audio_file_int16():
|
||||
# Create a simple sine wave in int16 format
|
||||
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
|
||||
buffer = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
|
||||
|
||||
filename, audio_file, content_type = _buffer_to_audio_file(buffer)
|
||||
|
||||
assert filename == "audio.wav"
|
||||
assert content_type == "audio/wav"
|
||||
assert isinstance(audio_file, io.BytesIO)
|
||||
|
||||
# Verify the WAV file contents
|
||||
with wave.open(audio_file, "rb") as wav_file:
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getframerate() == DEFAULT_SAMPLE_RATE
|
||||
assert wav_file.getnframes() == len(buffer)
|
||||
|
||||
|
||||
def test_buffer_to_audio_file_float32():
|
||||
# Create a simple sine wave in float32 format
|
||||
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
|
||||
buffer = np.sin(2 * np.pi * 440 * t).astype(np.float32)
|
||||
|
||||
filename, audio_file, content_type = _buffer_to_audio_file(buffer)
|
||||
|
||||
assert filename == "audio.wav"
|
||||
assert content_type == "audio/wav"
|
||||
assert isinstance(audio_file, io.BytesIO)
|
||||
|
||||
# Verify the WAV file contents
|
||||
with wave.open(audio_file, "rb") as wav_file:
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getframerate() == DEFAULT_SAMPLE_RATE
|
||||
assert wav_file.getnframes() == len(buffer)
|
||||
|
||||
|
||||
def test_buffer_to_audio_file_invalid_dtype():
|
||||
# Create a buffer with invalid dtype (float64)
|
||||
buffer = np.array([1.0, 2.0, 3.0], dtype=np.float64)
|
||||
|
||||
with pytest.raises(UserError, match="Buffer must be a numpy array of int16 or float32"):
|
||||
# Purposely ignore the type error
|
||||
_buffer_to_audio_file(buffer) # type: ignore
|
||||
|
||||
|
||||
class TestAudioInput:
|
||||
def test_audio_input_default_params(self):
|
||||
# Create a simple sine wave
|
||||
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
|
||||
buffer = np.sin(2 * np.pi * 440 * t).astype(np.float32)
|
||||
|
||||
audio_input = AudioInput(buffer=buffer)
|
||||
|
||||
assert audio_input.frame_rate == DEFAULT_SAMPLE_RATE
|
||||
assert audio_input.sample_width == 2
|
||||
assert audio_input.channels == 1
|
||||
assert np.array_equal(audio_input.buffer, buffer)
|
||||
|
||||
def test_audio_input_custom_params(self):
|
||||
# Create a simple sine wave
|
||||
t = np.linspace(0, 1, 48000)
|
||||
buffer = np.sin(2 * np.pi * 440 * t).astype(np.float32)
|
||||
|
||||
audio_input = AudioInput(buffer=buffer, frame_rate=48000, sample_width=4, channels=2)
|
||||
|
||||
assert audio_input.frame_rate == 48000
|
||||
assert audio_input.sample_width == 4
|
||||
assert audio_input.channels == 2
|
||||
assert np.array_equal(audio_input.buffer, buffer)
|
||||
|
||||
def test_audio_input_to_audio_file(self):
|
||||
# Create a simple sine wave
|
||||
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
|
||||
buffer = np.sin(2 * np.pi * 440 * t).astype(np.float32)
|
||||
|
||||
audio_input = AudioInput(buffer=buffer)
|
||||
filename, audio_file, content_type = audio_input.to_audio_file()
|
||||
|
||||
assert filename == "audio.wav"
|
||||
assert content_type == "audio/wav"
|
||||
assert isinstance(audio_file, io.BytesIO)
|
||||
|
||||
# Verify the WAV file contents
|
||||
with wave.open(audio_file, "rb") as wav_file:
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getframerate() == DEFAULT_SAMPLE_RATE
|
||||
assert wav_file.getnframes() == len(buffer)
|
||||
|
||||
|
||||
class TestStreamedAudioInput:
|
||||
@pytest.mark.asyncio
|
||||
async def test_streamed_audio_input(self):
|
||||
streamed_input = StreamedAudioInput()
|
||||
|
||||
# Create some test audio data
|
||||
t = np.linspace(0, 1, DEFAULT_SAMPLE_RATE)
|
||||
audio1 = np.sin(2 * np.pi * 440 * t).astype(np.float32)
|
||||
audio2 = np.sin(2 * np.pi * 880 * t).astype(np.float32)
|
||||
|
||||
# Add audio to the queue
|
||||
await streamed_input.add_audio(audio1)
|
||||
await streamed_input.add_audio(audio2)
|
||||
|
||||
# Verify the queue contents
|
||||
assert streamed_input.queue.qsize() == 2
|
||||
# Test non-blocking get
|
||||
assert np.array_equal(streamed_input.queue.get_nowait(), audio1)
|
||||
# Test blocking get
|
||||
assert np.array_equal(await streamed_input.queue.get(), audio2)
|
||||
assert streamed_input.queue.empty()
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue