openai-agents-python/src/agents/voice/models/openai_stt.py
Dominik Kundel c7ce154637 feat: add voice pipeline support
> Co-authored-by: rm@openai.com
2025-03-20 09:43:13 -07:00

457 lines
16 KiB
Python

from __future__ import annotations
import asyncio
import base64
import json
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, cast
from openai import AsyncOpenAI
from agents.exceptions import AgentsException
from ... import _debug
from ...logger import logger
from ...tracing import Span, SpanError, TranscriptionSpanData, transcription_span
from ..exceptions import STTWebsocketConnectionError
from ..imports import np, npt, websockets
from ..input import AudioInput, StreamedAudioInput
from ..model import StreamedTranscriptionSession, STTModel, STTModelSettings
# Seconds to wait for the next queued websocket event before event processing
# assumes the session is over.
EVENT_INACTIVITY_TIMEOUT = 1000

# Seconds to wait for the session.created / transcription_session.created event.
SESSION_CREATION_TIMEOUT = 10

# Seconds to wait for the session.updated / transcription_session.updated event.
SESSION_UPDATE_TIMEOUT = 10

# Turn detection configuration used when the settings do not provide one.
DEFAULT_TURN_DETECTION = {"type": "semantic_vad"}
@dataclass
class ErrorSentinel:
    """Queue marker that carries an exception to the consumer of a queue."""

    # The exception that was raised by the producer.
    error: Exception
class SessionCompleteSentinel:
    """Queue marker signalling that the transcription session has finished."""
class WebsocketDoneSentinel:
    """Queue marker signalling that the websocket delivered its last message."""
def _audio_to_base64(audio_data: list[npt.NDArray[np.int16 | np.float32]]) -> str:
concatenated_audio = np.concatenate(audio_data)
if concatenated_audio.dtype == np.float32:
# convert to int16
concatenated_audio = np.clip(concatenated_audio, -1.0, 1.0)
concatenated_audio = (concatenated_audio * 32767).astype(np.int16)
audio_bytes = concatenated_audio.tobytes()
return base64.b64encode(audio_bytes).decode("utf-8")
async def _wait_for_event(
event_queue: asyncio.Queue[dict[str, Any]], expected_types: list[str], timeout: float
):
"""
Wait for an event from event_queue whose type is in expected_types within the specified timeout.
"""
start_time = time.time()
while True:
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
raise TimeoutError(f"Timeout waiting for event(s): {expected_types}")
evt = await asyncio.wait_for(event_queue.get(), timeout=remaining)
evt_type = evt.get("type", "")
if evt_type in expected_types:
return evt
elif evt_type == "error":
raise Exception(f"Error event: {evt.get('error')}")
class OpenAISTTTranscriptionSession(StreamedTranscriptionSession):
    """A transcription session for OpenAI's STT model.

    Audio chunks are read from the input queue and sent over a websocket, incoming
    websocket events are turned into transcripts, and `transcribe_turns` yields one
    transcript per completed turn.
    """

    def __init__(
        self,
        input: StreamedAudioInput,
        client: AsyncOpenAI,
        model: str,
        settings: STTModelSettings,
        trace_include_sensitive_data: bool,
        trace_include_sensitive_audio_data: bool,
    ):
        """Store the configuration and create the queues shared by the tasks.

        Args:
            input: Streamed audio input whose queue supplies the audio chunks.
            client: OpenAI client; its API key authenticates the websocket.
            model: Name of the transcription model.
            settings: STT settings (turn detection, temperature, language, prompt).
            trace_include_sensitive_data: Whether transcripts are stored on spans.
            trace_include_sensitive_audio_data: Whether turn audio is stored on spans.
        """
        self.connected: bool = False
        self._client = client
        self._model = model
        self._settings = settings
        self._turn_detection = settings.turn_detection or DEFAULT_TURN_DETECTION
        self._trace_include_sensitive_data = trace_include_sensitive_data
        self._trace_include_sensitive_audio_data = trace_include_sensitive_audio_data

        self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = input.queue
        self._output_queue: asyncio.Queue[str | ErrorSentinel | SessionCompleteSentinel] = (
            asyncio.Queue()
        )
        self._websocket: websockets.ClientConnection | None = None
        # All websocket events, consumed by _handle_events.
        self._event_queue: asyncio.Queue[dict[str, Any] | WebsocketDoneSentinel] = asyncio.Queue()
        # Session created/updated events only, consumed during connection setup.
        self._state_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self._turn_audio_buffer: list[npt.NDArray[np.int16 | np.float32]] = []
        self._tracing_span: Span[TranscriptionSpanData] | None = None

        # tasks
        self._listener_task: asyncio.Task[Any] | None = None
        self._process_events_task: asyncio.Task[Any] | None = None
        self._stream_audio_task: asyncio.Task[Any] | None = None
        self._connection_task: asyncio.Task[Any] | None = None
        self._stored_exception: Exception | None = None

    def _start_turn(self) -> None:
        """Create and start the tracing span for a new turn."""
        self._tracing_span = transcription_span(
            model=self._model,
            model_config={
                "temperature": self._settings.temperature,
                "language": self._settings.language,
                "prompt": self._settings.prompt,
                "turn_detection": self._turn_detection,
            },
        )
        self._tracing_span.start()

    def _end_turn(self, _transcript: str) -> None:
        """Finish the current turn's span and reset the per-turn state.

        Does nothing when `_transcript` is empty.

        Args:
            _transcript: The transcript of the turn that just completed.
        """
        if not _transcript:
            return

        if self._tracing_span:
            if self._trace_include_sensitive_audio_data:
                self._tracing_span.span_data.input = _audio_to_base64(self._turn_audio_buffer)
            self._tracing_span.span_data.input_format = "pcm"
            if self._trace_include_sensitive_data:
                self._tracing_span.span_data.output = _transcript
            self._tracing_span.finish()
            self._turn_audio_buffer = []
            self._tracing_span = None

    async def _event_listener(self) -> None:
        """Read websocket messages and route the decoded events to the queues.

        Session created/updated events go to the state queue; every event goes to
        the event queue. A `WebsocketDoneSentinel` is queued once the websocket
        stops yielding messages.

        Raises:
            STTWebsocketConnectionError: If a message cannot be processed or is an
                error event. An `ErrorSentinel` is put on the output queue first.
        """
        assert self._websocket is not None, "Websocket not initialized"

        async for message in self._websocket:
            try:
                event = json.loads(message)

                if event.get("type") == "error":
                    raise STTWebsocketConnectionError(f"Error event: {event.get('error')}")

                if event.get("type") in [
                    "session.updated",
                    "transcription_session.updated",
                    "session.created",
                    "transcription_session.created",
                ]:
                    await self._state_queue.put(event)

                await self._event_queue.put(event)
            except Exception as e:
                await self._output_queue.put(ErrorSentinel(e))
                raise STTWebsocketConnectionError("Error parsing events") from e
        await self._event_queue.put(WebsocketDoneSentinel())

    async def _configure_session(self) -> None:
        """Send the transcription session configuration over the websocket."""
        assert self._websocket is not None, "Websocket not initialized"
        await self._websocket.send(
            json.dumps(
                {
                    "type": "transcription_session.update",
                    "session": {
                        "input_audio_format": "pcm16",
                        "input_audio_transcription": {"model": self._model},
                        "turn_detection": self._turn_detection,
                    },
                }
            )
        )

    async def _setup_connection(self, ws: websockets.ClientConnection) -> None:
        """Start the listener, wait for session creation, then configure the session.

        Args:
            ws: The connected websocket to use for this session.

        Raises:
            STTWebsocketConnectionError: If the created or updated event does not
                arrive in time. Any failure is also reported on the output queue.
        """
        self._websocket = ws
        self._listener_task = asyncio.create_task(self._event_listener())

        try:
            await _wait_for_event(
                self._state_queue,
                ["session.created", "transcription_session.created"],
                SESSION_CREATION_TIMEOUT,
            )
        # asyncio.TimeoutError is only an alias of TimeoutError from Python 3.11 on,
        # so both are caught to wrap timeouts on every supported version.
        except (TimeoutError, asyncio.TimeoutError) as e:
            wrapped_err = STTWebsocketConnectionError(
                "Timeout waiting for transcription_session.created event"
            )
            await self._output_queue.put(ErrorSentinel(wrapped_err))
            raise wrapped_err from e
        except Exception as e:
            await self._output_queue.put(ErrorSentinel(e))
            raise

        await self._configure_session()

        try:
            event = await _wait_for_event(
                self._state_queue,
                ["session.updated", "transcription_session.updated"],
                SESSION_UPDATE_TIMEOUT,
            )
            if _debug.DONT_LOG_MODEL_DATA:
                logger.debug("Session updated")
            else:
                logger.debug(f"Session updated: {event}")
        except (TimeoutError, asyncio.TimeoutError) as e:
            wrapped_err = STTWebsocketConnectionError(
                "Timeout waiting for transcription_session.updated event"
            )
            await self._output_queue.put(ErrorSentinel(wrapped_err))
            raise wrapped_err from e
        except Exception as e:
            await self._output_queue.put(ErrorSentinel(e))
            raise

    async def _handle_events(self) -> None:
        """Turn completed-transcription events into transcripts on the output queue.

        Stops when the websocket is done or when no event arrives within
        `EVENT_INACTIVITY_TIMEOUT` seconds, then queues a `SessionCompleteSentinel`.
        """
        while True:
            try:
                event = await asyncio.wait_for(
                    self._event_queue.get(), timeout=EVENT_INACTIVITY_TIMEOUT
                )
                if isinstance(event, WebsocketDoneSentinel):
                    # processed all events and websocket is done
                    break

                event_type = event.get("type", "unknown")
                if event_type == "conversation.item.input_audio_transcription.completed":
                    transcript = cast(str, event.get("transcript", ""))
                    if transcript:
                        # Close the span of the finished turn and open one for the next.
                        self._end_turn(transcript)
                        self._start_turn()
                        await self._output_queue.put(transcript)
                await asyncio.sleep(0)  # yield control
            except asyncio.TimeoutError:
                # No new events for a while. Assume the session is done.
                break
            except Exception as e:
                await self._output_queue.put(ErrorSentinel(e))
                raise
        await self._output_queue.put(SessionCompleteSentinel())

    async def _stream_audio(
        self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]]
    ) -> None:
        """Forward audio chunks from `audio_queue` to the websocket.

        Stops when `None` is read from the queue or the connection is closed.

        Args:
            audio_queue: Queue of audio chunks to send.
        """
        assert self._websocket is not None, "Websocket not initialized"
        self._start_turn()
        while True:
            buffer = await audio_queue.get()
            if buffer is None:
                break

            self._turn_audio_buffer.append(buffer)
            try:
                await self._websocket.send(
                    json.dumps(
                        {
                            "type": "input_audio_buffer.append",
                            # The session is configured as pcm16, so float32 chunks
                            # are converted rather than sent as raw float bytes.
                            "audio": _audio_to_base64([buffer]),
                        }
                    )
                )
            except websockets.ConnectionClosed:
                break
            except Exception as e:
                await self._output_queue.put(ErrorSentinel(e))
                raise

            await asyncio.sleep(0)  # yield control

    async def _process_websocket_connection(self) -> None:
        """Open the websocket, start the worker tasks and wait for the listener.

        Any failure is reported on the output queue and re-raised.
        """
        try:
            async with websockets.connect(
                "wss://api.openai.com/v1/realtime?intent=transcription",
                additional_headers={
                    "Authorization": f"Bearer {self._client.api_key}",
                    "OpenAI-Beta": "realtime=v1",
                    "OpenAI-Log-Session": "1",
                },
            ) as ws:
                await self._setup_connection(ws)
                self._process_events_task = asyncio.create_task(self._handle_events())
                self._stream_audio_task = asyncio.create_task(self._stream_audio(self._input_queue))
                self.connected = True
                if self._listener_task:
                    await self._listener_task
                else:
                    logger.error("Listener task not initialized")
                    raise AgentsException("Listener task not initialized")
        except Exception as e:
            await self._output_queue.put(ErrorSentinel(e))
            raise

    def _check_errors(self) -> None:
        """Store the exception of any background task that finished with one.

        Tasks are inspected in a fixed order and a later failure overwrites an
        earlier one, so the listener task's exception wins when several failed.
        """
        for task in (
            self._connection_task,
            self._process_events_task,
            self._stream_audio_task,
            self._listener_task,
        ):
            if task is None or not task.done():
                continue
            if task.cancelled():
                # Task.exception() raises CancelledError for a cancelled task.
                continue
            exc = task.exception()
            if isinstance(exc, Exception):
                self._stored_exception = exc

    def _cleanup_tasks(self) -> None:
        """Cancel every background task that is still running."""
        for task in (
            self._listener_task,
            self._process_events_task,
            self._stream_audio_task,
            self._connection_task,
        ):
            if task and not task.done():
                task.cancel()

    async def transcribe_turns(self) -> AsyncIterator[str]:
        """Start the connection and yield the transcript of each completed turn.

        Iteration ends on an error sentinel, a session-complete sentinel, or
        cancellation while waiting for output.

        Yields:
            The transcript text of each turn.

        Raises:
            Exception: The exception stored from a failed background task, if any.
        """
        self._connection_task = asyncio.create_task(self._process_websocket_connection())

        while True:
            try:
                turn = await self._output_queue.get()
            except asyncio.CancelledError:
                break

            if turn is None or isinstance(turn, (ErrorSentinel, SessionCompleteSentinel)):
                self._output_queue.task_done()
                break
            yield turn
            self._output_queue.task_done()

        if self._tracing_span:
            # NOTE(review): _end_turn returns early for an empty transcript, so this
            # call does not finish the open span — confirm whether that is intended.
            self._end_turn("")

        if self._websocket:
            await self._websocket.close()

        self._check_errors()
        if self._stored_exception:
            raise self._stored_exception

    async def close(self) -> None:
        """Close the websocket and cancel the background tasks."""
        try:
            if self._websocket:
                await self._websocket.close()
        finally:
            # Cancel the tasks even if closing the websocket fails.
            self._cleanup_tasks()
class OpenAISTTModel(STTModel):
    """Speech-to-text model backed by the OpenAI API."""

    def __init__(
        self,
        model: str,
        openai_client: AsyncOpenAI,
    ):
        """Initialize the speech-to-text model.

        Args:
            model: Name of the model used for transcription.
            openai_client: OpenAI client used to issue requests.
        """
        self._client = openai_client
        self.model = model

    @property
    def model_name(self) -> str:
        """Name of the model used for transcription."""
        return self.model

    def _non_null_or_not_given(self, value: Any) -> Any:
        """Return `value` as-is; a missing value is returned as `None`."""
        if value is None:
            return None  # NOT_GIVEN
        return value

    async def transcribe(
        self,
        input: AudioInput,
        settings: STTModelSettings,
        trace_include_sensitive_data: bool,
        trace_include_sensitive_audio_data: bool,
    ) -> str:
        """Transcribe a complete audio input in a single request.

        Args:
            input: The audio to transcribe.
            settings: Transcription settings (prompt, language, temperature).
            trace_include_sensitive_data: Whether to store the transcript on the span.
            trace_include_sensitive_audio_data: Whether to store the audio on the span.

        Returns:
            The text of the transcription response.

        Raises:
            Exception: Any error raised by the request; it is recorded on the span
                and re-raised.
        """
        traced_audio = input.to_base64() if trace_include_sensitive_audio_data else ""
        temperature = self._non_null_or_not_given(settings.temperature)
        language = self._non_null_or_not_given(settings.language)
        prompt = self._non_null_or_not_given(settings.prompt)
        span_config = {
            "temperature": temperature,
            "language": language,
            "prompt": prompt,
        }

        with transcription_span(
            model=self.model,
            input=traced_audio,
            input_format="pcm",
            model_config=span_config,
        ) as span:
            try:
                response = await self._client.audio.transcriptions.create(
                    model=self.model,
                    file=input.to_audio_file(),
                    prompt=prompt,
                    language=language,
                    temperature=temperature,
                )
                transcript = response.text
                if trace_include_sensitive_data:
                    span.span_data.output = transcript
                return transcript
            except Exception as e:
                # Record the failure on the span before propagating it.
                span.span_data.output = ""
                span.set_error(SpanError(message=str(e), data={}))
                raise e

    async def create_session(
        self,
        input: StreamedAudioInput,
        settings: STTModelSettings,
        trace_include_sensitive_data: bool,
        trace_include_sensitive_audio_data: bool,
    ) -> StreamedTranscriptionSession:
        """Build a streaming transcription session for this model.

        Args:
            input: The streamed audio input to transcribe.
            settings: The settings to use for the transcription.
            trace_include_sensitive_data: Whether to include sensitive data in traces.
            trace_include_sensitive_audio_data: Whether to include sensitive audio data in traces.

        Returns:
            A new, not yet connected, transcription session.
        """
        return OpenAISTTTranscriptionSession(
            input=input,
            client=self._client,
            model=self.model,
            settings=settings,
            trace_include_sensitive_data=trace_include_sensitive_data,
            trace_include_sensitive_audio_data=trace_include_sensitive_audio_data,
        )