update tests
This commit is contained in:
parent
6f13d50a6e
commit
1771c1e856
13 changed files with 57 additions and 87 deletions
|
|
@ -1,15 +1,14 @@
|
|||
import asyncio
|
||||
import random
|
||||
|
||||
from agents import (
|
||||
Agent,
|
||||
from agents import Agent, function_tool
|
||||
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
|
||||
from agents.voice import (
|
||||
AudioInput,
|
||||
SingleAgentVoiceWorkflow,
|
||||
SingleAgentWorkflowCallbacks,
|
||||
VoicePipeline,
|
||||
function_tool,
|
||||
)
|
||||
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
|
||||
|
||||
from .util import AudioPlayer, record_audio
|
||||
|
||||
|
|
|
|||
|
|
@ -2,15 +2,9 @@ import random
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Callable
|
||||
|
||||
from agents import (
|
||||
Agent,
|
||||
Runner,
|
||||
TResponseInputItem,
|
||||
VoiceWorkflowBase,
|
||||
VoiceWorkflowHelper,
|
||||
function_tool,
|
||||
)
|
||||
from agents import Agent, Runner, TResponseInputItem, function_tool
|
||||
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
|
||||
from agents.voice import VoiceWorkflowBase, VoiceWorkflowHelper
|
||||
|
||||
|
||||
@function_tool
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from textual.reactive import reactive
|
|||
from textual.widgets import Button, RichLog, Static
|
||||
from typing_extensions import override
|
||||
|
||||
from agents import VoicePipeline
|
||||
from agents.voice.input import StreamedAudioInput
|
||||
from agents.voice import StreamedAudioInput, VoicePipeline
|
||||
|
||||
from .agents import MyWorkflow
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ dev = [
|
|||
"sounddevice",
|
||||
"pynput",
|
||||
"textual",
|
||||
"websockets",
|
||||
]
|
||||
[tool.uv.workspace]
|
||||
members = ["agents"]
|
||||
|
|
|
|||
|
|
@ -98,31 +98,6 @@ from .tracing import (
|
|||
transcription_span,
|
||||
)
|
||||
from .usage import Usage
|
||||
from .voice import (
|
||||
AudioInput,
|
||||
OpenAISTTModel,
|
||||
OpenAISTTTranscriptionSession,
|
||||
OpenAITTSModel,
|
||||
OpenAIVoiceModelProvider,
|
||||
SingleAgentVoiceWorkflow,
|
||||
SingleAgentWorkflowCallbacks,
|
||||
StreamedAudioInput,
|
||||
StreamedAudioResult,
|
||||
StreamedTranscriptionSession,
|
||||
STTModel,
|
||||
STTModelSettings,
|
||||
TTSModel,
|
||||
TTSModelSettings,
|
||||
VoiceModelProvider,
|
||||
VoicePipeline,
|
||||
VoicePipelineConfig,
|
||||
VoiceStreamEvent,
|
||||
VoiceStreamEventAudio,
|
||||
VoiceStreamEventLifecycle,
|
||||
VoiceWorkflowBase,
|
||||
VoiceWorkflowHelper,
|
||||
get_sentence_based_splitter,
|
||||
)
|
||||
|
||||
|
||||
def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
|
||||
|
|
@ -268,27 +243,4 @@ __all__ = [
|
|||
"gen_trace_id",
|
||||
"gen_span_id",
|
||||
"default_tool_error_function",
|
||||
"AudioInput",
|
||||
"StreamedAudioInput",
|
||||
"STTModel",
|
||||
"STTModelSettings",
|
||||
"TTSModel",
|
||||
"TTSModelSettings",
|
||||
"VoiceModelProvider",
|
||||
"StreamedAudioResult",
|
||||
"SingleAgentVoiceWorkflow",
|
||||
"OpenAIVoiceModelProvider",
|
||||
"OpenAISTTModel",
|
||||
"OpenAITTSModel",
|
||||
"VoiceStreamEventAudio",
|
||||
"VoiceStreamEventLifecycle",
|
||||
"VoiceStreamEvent",
|
||||
"VoicePipeline",
|
||||
"VoicePipelineConfig",
|
||||
"get_sentence_based_splitter",
|
||||
"VoiceWorkflowHelper",
|
||||
"VoiceWorkflowBase",
|
||||
"StreamedTranscriptionSession",
|
||||
"OpenAISTTTranscriptionSession",
|
||||
"SingleAgentWorkflowCallbacks",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,16 +6,19 @@ from typing import Literal
|
|||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from agents.voice import (
|
||||
AudioInput,
|
||||
StreamedAudioInput,
|
||||
StreamedTranscriptionSession,
|
||||
STTModel,
|
||||
STTModelSettings,
|
||||
TTSModel,
|
||||
TTSModelSettings,
|
||||
VoiceWorkflowBase,
|
||||
)
|
||||
try:
|
||||
from agents.voice import (
|
||||
AudioInput,
|
||||
StreamedAudioInput,
|
||||
StreamedTranscriptionSession,
|
||||
STTModel,
|
||||
STTModelSettings,
|
||||
TTSModel,
|
||||
TTSModelSettings,
|
||||
VoiceWorkflowBase,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class FakeTTS(TTSModel):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from agents.voice import StreamedAudioResult
|
||||
try:
|
||||
from agents.voice import StreamedAudioResult
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
async def extract_events(result: StreamedAudioResult) -> tuple[list[str], list[bytes]]:
|
||||
|
|
|
|||
|
|
@ -4,9 +4,12 @@ import wave
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from agents import UserError
|
||||
from agents.voice import AudioInput, StreamedAudioInput
|
||||
from agents.voice.input import DEFAULT_SAMPLE_RATE, _buffer_to_audio_file
|
||||
try:
|
||||
from agents import UserError
|
||||
from agents.voice import AudioInput, StreamedAudioInput
|
||||
from agents.voice.input import DEFAULT_SAMPLE_RATE, _buffer_to_audio_file
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def test_buffer_to_audio_file_int16():
|
||||
|
|
|
|||
|
|
@ -8,11 +8,15 @@ from unittest.mock import AsyncMock, patch
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from agents.voice import OpenAISTTTranscriptionSession, StreamedAudioInput, STTModelSettings
|
||||
from agents.voice.exceptions import STTWebsocketConnectionError
|
||||
from agents.voice.models.openai_stt import EVENT_INACTIVITY_TIMEOUT
|
||||
try:
|
||||
from agents.voice import OpenAISTTTranscriptionSession, StreamedAudioInput, STTModelSettings
|
||||
from agents.voice.exceptions import STTWebsocketConnectionError
|
||||
from agents.voice.models.openai_stt import EVENT_INACTIVITY_TIMEOUT
|
||||
|
||||
from .fake_models import FakeStreamedAudioInput
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from .fake_models import FakeStreamedAudioInput
|
||||
|
||||
# ===== Helpers =====
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,10 @@ from typing import Any
|
|||
|
||||
import pytest
|
||||
|
||||
from agents.voice import OpenAITTSModel, TTSModelSettings
|
||||
try:
|
||||
from agents.voice import OpenAITTSModel, TTSModelSettings
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeStreamResponse:
|
||||
|
|
|
|||
|
|
@ -4,10 +4,13 @@ import numpy as np
|
|||
import numpy.typing as npt
|
||||
import pytest
|
||||
|
||||
from agents.voice import AudioInput, TTSModelSettings, VoicePipeline, VoicePipelineConfig
|
||||
try:
|
||||
from agents.voice import AudioInput, TTSModelSettings, VoicePipeline, VoicePipelineConfig
|
||||
|
||||
from .fake_models import FakeStreamedAudioInput, FakeSTT, FakeTTS, FakeWorkflow
|
||||
from .helpers import extract_events
|
||||
from .fake_models import FakeStreamedAudioInput, FakeSTT, FakeTTS, FakeWorkflow
|
||||
from .helpers import extract_events
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -17,10 +17,14 @@ from agents.items import (
|
|||
TResponseOutputItem,
|
||||
TResponseStreamEvent,
|
||||
)
|
||||
from agents.voice import SingleAgentVoiceWorkflow
|
||||
|
||||
from ..fake_model import get_response_obj
|
||||
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
|
||||
try:
|
||||
from agents.voice import SingleAgentVoiceWorkflow
|
||||
|
||||
from ..fake_model import get_response_obj
|
||||
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class FakeStreamingModel(Model):
|
||||
|
|
|
|||
2
uv.lock
2
uv.lock
|
|
@ -1085,6 +1085,7 @@ dev = [
|
|||
{ name = "sounddevice" },
|
||||
{ name = "textual" },
|
||||
{ name = "types-pynput" },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
|
|
@ -1118,6 +1119,7 @@ dev = [
|
|||
{ name = "sounddevice" },
|
||||
{ name = "textual" },
|
||||
{ name = "types-pynput" },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
Loading…
Reference in a new issue