# AIvoices/server/fastapi/bot.py
# 2026-05-09 18:45:55 +05:30
#
# 209 lines
# 7.7 KiB
# Python

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Shared Pipecat bot logic for the local multi-transport server."""
import os
from typing import Literal
from voice_pipeline import build_voice_pipeline
from dotenv import load_dotenv
from gem_live_route import build_gem_live_route
from grok_route import build_grok_route
from loguru import logger
# NOTE(review): nothing is executed between these two log calls, so the
# "loaded" message is emitted immediately after the "loading" one. Presumably
# the actual Silero VAD load happens in another module — confirm, or drop
# these messages.
logger.info("Loading Silero VAD model...")
logger.info("Silero VAD model loaded")
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
ErrorFrame,
Frame,
InputTransportMessageFrame,
InterruptionFrame,
LLMContextFrame,
LLMRunFrame,
OutputAudioRawFrame,
OutputTransportMessageFrame,
STTMuteFrame,
TTSStoppedFrame,
UserStoppedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.transports.base_transport import BaseTransport
logger.info("All components loaded successfully")

# Read settings from a .env file before any of the constants below are
# evaluated; override=True is passed so .env entries win over variables that
# are already set in the process environment.
load_dotenv(override=True)

# Name of the voice route to run, normalised (trimmed, lower-cased) so the
# comparisons in run_bot_session are insensitive to case and stray whitespace.
CURRENT_VOICE_ROUTE = os.environ.get("CURRENT_VOICE_ROUTE", "classic").strip().lower()

# Sample rates handed to PipelineParams. Parsed once at import time, so a
# non-integer value raises ValueError while the module is loading.
AUDIO_IN_SAMPLE_RATE = int(os.environ.get("PIPELINE_AUDIO_IN_SAMPLE_RATE", "16000"))
AUDIO_OUT_SAMPLE_RATE = int(os.environ.get("PIPELINE_AUDIO_OUT_SAMPLE_RATE", "24000"))
class RealtimeInputControlProcessor(FrameProcessor):
    """Bridge incoming websocket control messages into Pipecat frames.

    Transport messages shaped like ``{"type": "instruction", "msg": ...}`` with
    a recognised ``msg`` are consumed and replaced by pipeline frames. Every
    other frame is forwarded unchanged in the direction it arrived.
    """

    def __init__(self, voice_route: str):
        """Store the active voice route.

        Args:
            voice_route: Route name. ``"gem_live"`` is handled differently from
                every other value when translating control messages.
        """
        super().__init__()
        self._voice_route = voice_route

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Translate control messages; pass everything else through.

        Args:
            frame: The frame received by this processor.
            direction: The direction the frame is travelling in.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, InputTransportMessageFrame):
            # Payloads that are not dicts are treated as carrying no
            # instruction and fall through to the plain forward below.
            message = frame.message if isinstance(frame.message, dict) else {}
            msg = message.get("msg")

            if message.get("type") == "instruction":
                if msg == "end_of_speech":
                    if self._voice_route == "gem_live":
                        await self.push_frame(
                            VADUserStoppedSpeakingFrame(), FrameDirection.DOWNSTREAM
                        )
                    else:
                        await self.push_frame(
                            EmulateUserStoppedSpeakingFrame(), FrameDirection.DOWNSTREAM
                        )
                        # NOTE(review): nesting was reconstructed from a copy
                        # without indentation. The mute is assumed to belong to
                        # the non-gem_live branch, mirroring the INTERRUPT
                        # handling below — confirm.
                        await self.push_frame(
                            STTMuteFrame(mute=True), FrameDirection.DOWNSTREAM
                        )
                    # The control message itself is consumed, not forwarded.
                    return

                if msg == "INTERRUPT":
                    await self.push_frame(InterruptionFrame(), FrameDirection.DOWNSTREAM)
                    if self._voice_route != "gem_live":
                        await self.push_frame(
                            STTMuteFrame(mute=False), FrameDirection.DOWNSTREAM
                        )
                    return

        await self.push_frame(frame, direction)
class RealtimeOutputControlProcessor(FrameProcessor):
    """Translate pipeline state changes into the old websocket control protocol.

    For downstream frames this emits ``OutputTransportMessageFrame`` messages of
    the form ``{"type": "server", "msg": ...}`` with one of ``AUDIO.COMMITTED``,
    ``RESPONSE.CREATED``, ``RESPONSE.COMPLETE`` or ``RESPONSE.ERROR``. The
    triggering frame is always forwarded as well.

    NOTE(review): the ``STTMuteFrame`` instances created here are pushed in the
    downstream direction. ``run_bot_session`` places this processor directly
    before ``transport.output()``, so confirm that whatever consumes
    ``STTMuteFrame`` is actually reachable from this position.
    """

    def __init__(self):
        """Start with no response in progress."""
        super().__init__()
        # True between the first audio frame of a response and its end/error,
        # so RESPONSE.CREATED is sent once per response rather than per frame.
        self._response_started = False

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Emit protocol messages for downstream frames, then forward the frame.

        Args:
            frame: The frame received by this processor.
            direction: The direction the frame is travelling in. Only
                downstream frames produce protocol messages.
        """
        await super().process_frame(frame, direction)

        if direction is FrameDirection.DOWNSTREAM:
            if isinstance(frame, (UserStoppedSpeakingFrame, VADUserStoppedSpeakingFrame)):
                await self.push_frame(
                    OutputTransportMessageFrame(
                        message={"type": "server", "msg": "AUDIO.COMMITTED"}
                    ),
                    direction,
                )
            elif isinstance(frame, OutputAudioRawFrame) and not self._response_started:
                self._response_started = True
                logger.debug("Sending RESPONSE.CREATED before first audio packet")
                await self.push_frame(STTMuteFrame(mute=True), direction)
                await self.push_frame(
                    OutputTransportMessageFrame(
                        message={"type": "server", "msg": "RESPONSE.CREATED"}
                    ),
                    direction,
                )
            elif isinstance(frame, (TTSStoppedFrame, BotStoppedSpeakingFrame)):
                self._response_started = False
                logger.debug("Sending RESPONSE.COMPLETE after TTS stop")
                await self.push_frame(STTMuteFrame(mute=False), direction)
                # Forward the stop frame *before* RESPONSE.COMPLETE so the
                # message follows it, then return to avoid forwarding it twice.
                await self.push_frame(frame, direction)
                await self.push_frame(
                    OutputTransportMessageFrame(
                        message={"type": "server", "msg": "RESPONSE.COMPLETE"}
                    ),
                    direction,
                )
                return
            elif isinstance(frame, ErrorFrame):
                self._response_started = False
                await self.push_frame(STTMuteFrame(mute=False), direction)
                await self.push_frame(
                    OutputTransportMessageFrame(
                        message={"type": "server", "msg": "RESPONSE.ERROR"}
                    ),
                    direction,
                )

        await self.push_frame(frame, direction)
def create_esp32_auth_message() -> dict:
    """Build the ``auth`` message dict using the ESP32 default settings.

    ``ESP32_DEFAULT_VOLUME`` and ``ESP32_DEFAULT_PITCH_FACTOR`` are read from
    the environment on every call. A variable that is unset, empty, or only
    whitespace falls back to its default (``100`` and ``1.0``).

    Returns:
        A dict with the keys ``type`` (always ``"auth"``), ``volume_control``
        (int), ``pitch_factor`` (float), ``is_ota`` and ``is_reset`` (both
        always ``False``).

    Raises:
        ValueError: If a variable holds a non-blank value that is not a valid
            int (volume) or float (pitch factor).
    """
    # os.getenv's default only covers an *unset* variable. A blank entry such
    # as `ESP32_DEFAULT_VOLUME=` in a .env file yields "", which int()/float()
    # reject, so blank values are mapped to the defaults explicitly.
    volume = os.getenv("ESP32_DEFAULT_VOLUME", "").strip() or "100"
    pitch_factor = os.getenv("ESP32_DEFAULT_PITCH_FACTOR", "").strip() or "1.0"

    return {
        "type": "auth",
        "volume_control": int(volume),
        "pitch_factor": float(pitch_factor),
        "is_ota": False,
        "is_reset": False,
    }
async def run_bot_session(
    transport: BaseTransport,
    transport_kind: Literal["browser", "esp32"],
    handle_sigint: bool = False,
):
    """Assemble the pipeline for one client session and run it to completion.

    The route named by ``CURRENT_VOICE_ROUTE`` selects which builder supplies
    the middle of the pipeline: ``"gem_live"`` and ``"grok"`` have dedicated
    builders, and any other value uses ``build_voice_pipeline``.

    Args:
        transport: Transport supplying the pipeline's input and output
            processors and the client connect/disconnect events.
        transport_kind: Which kind of client is attached; used in log messages
            and to decide whether the output control processor is installed.
        handle_sigint: Passed through to ``PipelineRunner``.
    """
    voice_route = CURRENT_VOICE_ROUTE
    logger.info(f"Starting bot session for {transport_kind} via route={voice_route}")

    context = LLMContext()
    input_processor = RealtimeInputControlProcessor(voice_route)

    # Every builder receives the same input processor and context and returns
    # (processors for the middle of the pipeline, assistant aggregator).
    if voice_route == "gem_live":
        route_processors, assistant_aggregator = build_gem_live_route(input_processor, context)
    elif voice_route == "grok":
        route_processors, assistant_aggregator = build_grok_route(input_processor, context)
    else:
        route_processors, assistant_aggregator = build_voice_pipeline(input_processor, context)

    processors = [transport.input(), *route_processors]
    # Both values allowed by the transport_kind annotation satisfy this check.
    if transport_kind in {"esp32", "browser"}:
        processors.append(RealtimeOutputControlProcessor())
    processors.append(transport.output())
    # The assistant aggregator is placed after the output transport.
    processors.append(assistant_aggregator)

    pipeline = Pipeline(processors)
    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            enable_metrics=True,
            enable_usage_metrics=True,
            audio_in_sample_rate=AUDIO_IN_SAMPLE_RATE,
            audio_out_sample_rate=AUDIO_OUT_SAMPLE_RATE,
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        """Seed the context with a greeting prompt and start the first turn."""
        logger.info(f"{transport_kind} client connected")
        if voice_route in {"gem_live", "grok"}:
            # These routes get the prompt as a user message and are started by
            # queueing the whole context.
            context.add_message(
                {
                    "role": "user",
                    "content": "Say hello and briefly introduce yourself.",
                }
            )
            await task.queue_frames([LLMContextFrame(context=context)])
        else:
            # Other routes get the prompt as a developer message and are
            # started with an LLMRunFrame.
            context.add_message(
                {
                    "role": "developer",
                    "content": "Say hello and briefly introduce yourself.",
                }
            )
            await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        """Cancel the pipeline task once the client has gone away."""
        logger.info(f"{transport_kind} client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=handle_sigint)
    await runner.run(task)