update tests (#266)

This commit is contained in:
Rohan Mehta 2025-03-20 13:10:47 -04:00 committed by GitHub
commit 3af879ec3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 57 additions and 87 deletions

View file

@ -1,15 +1,14 @@
import asyncio
import random
from agents import (
Agent,
from agents import Agent, function_tool
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
from agents.voice import (
AudioInput,
SingleAgentVoiceWorkflow,
SingleAgentWorkflowCallbacks,
VoicePipeline,
function_tool,
)
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
from .util import AudioPlayer, record_audio

View file

@ -2,15 +2,9 @@ import random
from collections.abc import AsyncIterator
from typing import Callable
from agents import (
Agent,
Runner,
TResponseInputItem,
VoiceWorkflowBase,
VoiceWorkflowHelper,
function_tool,
)
from agents import Agent, Runner, TResponseInputItem, function_tool
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
from agents.voice import VoiceWorkflowBase, VoiceWorkflowHelper
@function_tool

View file

@ -11,8 +11,7 @@ from textual.reactive import reactive
from textual.widgets import Button, RichLog, Static
from typing_extensions import override
from agents import VoicePipeline
from agents.voice.input import StreamedAudioInput
from agents.voice import StreamedAudioInput, VoicePipeline
from .agents import MyWorkflow

View file

@ -54,6 +54,7 @@ dev = [
"sounddevice",
"pynput",
"textual",
"websockets",
]
[tool.uv.workspace]
members = ["agents"]

View file

@ -98,31 +98,6 @@ from .tracing import (
transcription_span,
)
from .usage import Usage
from .voice import (
AudioInput,
OpenAISTTModel,
OpenAISTTTranscriptionSession,
OpenAITTSModel,
OpenAIVoiceModelProvider,
SingleAgentVoiceWorkflow,
SingleAgentWorkflowCallbacks,
StreamedAudioInput,
StreamedAudioResult,
StreamedTranscriptionSession,
STTModel,
STTModelSettings,
TTSModel,
TTSModelSettings,
VoiceModelProvider,
VoicePipeline,
VoicePipelineConfig,
VoiceStreamEvent,
VoiceStreamEventAudio,
VoiceStreamEventLifecycle,
VoiceWorkflowBase,
VoiceWorkflowHelper,
get_sentence_based_splitter,
)
def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
@ -268,27 +243,4 @@ __all__ = [
"gen_trace_id",
"gen_span_id",
"default_tool_error_function",
"AudioInput",
"StreamedAudioInput",
"STTModel",
"STTModelSettings",
"TTSModel",
"TTSModelSettings",
"VoiceModelProvider",
"StreamedAudioResult",
"SingleAgentVoiceWorkflow",
"OpenAIVoiceModelProvider",
"OpenAISTTModel",
"OpenAITTSModel",
"VoiceStreamEventAudio",
"VoiceStreamEventLifecycle",
"VoiceStreamEvent",
"VoicePipeline",
"VoicePipelineConfig",
"get_sentence_based_splitter",
"VoiceWorkflowHelper",
"VoiceWorkflowBase",
"StreamedTranscriptionSession",
"OpenAISTTTranscriptionSession",
"SingleAgentWorkflowCallbacks",
]

View file

@ -6,16 +6,19 @@ from typing import Literal
import numpy as np
import numpy.typing as npt
from agents.voice import (
AudioInput,
StreamedAudioInput,
StreamedTranscriptionSession,
STTModel,
STTModelSettings,
TTSModel,
TTSModelSettings,
VoiceWorkflowBase,
)
try:
from agents.voice import (
AudioInput,
StreamedAudioInput,
StreamedTranscriptionSession,
STTModel,
STTModelSettings,
TTSModel,
TTSModelSettings,
VoiceWorkflowBase,
)
except ImportError:
pass
class FakeTTS(TTSModel):

View file

@ -1,4 +1,7 @@
from agents.voice import StreamedAudioResult
try:
from agents.voice import StreamedAudioResult
except ImportError:
pass
async def extract_events(result: StreamedAudioResult) -> tuple[list[str], list[bytes]]:

View file

@ -4,9 +4,12 @@ import wave
import numpy as np
import pytest
from agents import UserError
from agents.voice import AudioInput, StreamedAudioInput
from agents.voice.input import DEFAULT_SAMPLE_RATE, _buffer_to_audio_file
try:
from agents import UserError
from agents.voice import AudioInput, StreamedAudioInput
from agents.voice.input import DEFAULT_SAMPLE_RATE, _buffer_to_audio_file
except ImportError:
pass
def test_buffer_to_audio_file_int16():

View file

@ -8,11 +8,15 @@ from unittest.mock import AsyncMock, patch
import numpy as np
import pytest
from agents.voice import OpenAISTTTranscriptionSession, StreamedAudioInput, STTModelSettings
from agents.voice.exceptions import STTWebsocketConnectionError
from agents.voice.models.openai_stt import EVENT_INACTIVITY_TIMEOUT
try:
from agents.voice import OpenAISTTTranscriptionSession, StreamedAudioInput, STTModelSettings
from agents.voice.exceptions import STTWebsocketConnectionError
from agents.voice.models.openai_stt import EVENT_INACTIVITY_TIMEOUT
from .fake_models import FakeStreamedAudioInput
except ImportError:
pass
from .fake_models import FakeStreamedAudioInput
# ===== Helpers =====

View file

@ -5,7 +5,10 @@ from typing import Any
import pytest
from agents.voice import OpenAITTSModel, TTSModelSettings
try:
from agents.voice import OpenAITTSModel, TTSModelSettings
except ImportError:
pass
class _FakeStreamResponse:

View file

@ -4,10 +4,13 @@ import numpy as np
import numpy.typing as npt
import pytest
from agents.voice import AudioInput, TTSModelSettings, VoicePipeline, VoicePipelineConfig
try:
from agents.voice import AudioInput, TTSModelSettings, VoicePipeline, VoicePipelineConfig
from .fake_models import FakeStreamedAudioInput, FakeSTT, FakeTTS, FakeWorkflow
from .helpers import extract_events
from .fake_models import FakeStreamedAudioInput, FakeSTT, FakeTTS, FakeWorkflow
from .helpers import extract_events
except ImportError:
pass
@pytest.mark.asyncio

View file

@ -17,10 +17,14 @@ from agents.items import (
TResponseOutputItem,
TResponseStreamEvent,
)
from agents.voice import SingleAgentVoiceWorkflow
from ..fake_model import get_response_obj
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
try:
from agents.voice import SingleAgentVoiceWorkflow
from ..fake_model import get_response_obj
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
except ImportError:
pass
class FakeStreamingModel(Model):

View file

@ -1085,6 +1085,7 @@ dev = [
{ name = "sounddevice" },
{ name = "textual" },
{ name = "types-pynput" },
{ name = "websockets" },
]
[package.metadata]
@ -1118,6 +1119,7 @@ dev = [
{ name = "sounddevice" },
{ name = "textual" },
{ name = "types-pynput" },
{ name = "websockets" },
]
[[package]]