voxtral mlx : improved chunking
This commit is contained in:
parent
9d8db7ab38
commit
dfd5bf417c
11 changed files with 1812 additions and 171 deletions
|
|
@ -13,6 +13,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
|
||||||
logging.getLogger().setLevel(logging.WARNING)
|
logging.getLogger().setLevel(logging.WARNING)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
logging.getLogger("whisperlivekit.qwen3_asr").setLevel(logging.DEBUG)
|
||||||
|
|
||||||
config = parse_args()
|
config = parse_args()
|
||||||
transcription_engine = None
|
transcription_engine = None
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,8 @@ BACKENDS = [
|
||||||
"install": "pip install faster-whisper",
|
"install": "pip install faster-whisper",
|
||||||
"description": "CTranslate2-based Whisper (fast, CPU/CUDA)",
|
"description": "CTranslate2-based Whisper (fast, CPU/CUDA)",
|
||||||
"policy": "localagreement",
|
"policy": "localagreement",
|
||||||
|
"streaming": "chunk", # batch inference with LocalAgreement/SimulStreaming
|
||||||
|
"devices": ["cpu", "cuda"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "whisper",
|
"id": "whisper",
|
||||||
|
|
@ -65,6 +67,8 @@ BACKENDS = [
|
||||||
"install": "pip install openai-whisper",
|
"install": "pip install openai-whisper",
|
||||||
"description": "Original OpenAI Whisper (PyTorch)",
|
"description": "Original OpenAI Whisper (PyTorch)",
|
||||||
"policy": "simulstreaming",
|
"policy": "simulstreaming",
|
||||||
|
"streaming": "chunk",
|
||||||
|
"devices": ["cpu", "cuda"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "mlx-whisper",
|
"id": "mlx-whisper",
|
||||||
|
|
@ -74,21 +78,27 @@ BACKENDS = [
|
||||||
"description": "Apple Silicon native Whisper (MLX)",
|
"description": "Apple Silicon native Whisper (MLX)",
|
||||||
"policy": "localagreement",
|
"policy": "localagreement",
|
||||||
"platform": "darwin-arm64",
|
"platform": "darwin-arm64",
|
||||||
|
"streaming": "chunk",
|
||||||
|
"devices": ["mlx"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "voxtral-mlx",
|
"id": "voxtral-mlx",
|
||||||
"name": "Voxtral MLX",
|
"name": "Voxtral MLX",
|
||||||
"module": "mlx",
|
"module": "mlx",
|
||||||
"install": "pip install whisperlivekit[voxtral-mlx]",
|
"install": "pip install whisperlivekit[voxtral-mlx]",
|
||||||
"description": "Mistral Voxtral Mini on Apple Silicon (MLX)",
|
"description": "Mistral Voxtral Mini on Apple Silicon (MLX, native streaming)",
|
||||||
"platform": "darwin-arm64",
|
"platform": "darwin-arm64",
|
||||||
|
"streaming": "native", # truly streaming (token-by-token)
|
||||||
|
"devices": ["mlx"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "voxtral",
|
"id": "voxtral",
|
||||||
"name": "Voxtral HF",
|
"name": "Voxtral HF",
|
||||||
"module": "transformers",
|
"module": "transformers",
|
||||||
"install": "pip install whisperlivekit[voxtral-hf]",
|
"install": "pip install whisperlivekit[voxtral-hf]",
|
||||||
"description": "Mistral Voxtral Mini (HF Transformers, CUDA/CPU/MPS)",
|
"description": "Mistral Voxtral Mini (HF Transformers, native streaming)",
|
||||||
|
"streaming": "native",
|
||||||
|
"devices": ["cuda", "mps", "cpu"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "qwen3",
|
"id": "qwen3",
|
||||||
|
|
@ -96,6 +106,8 @@ BACKENDS = [
|
||||||
"module": "qwen_asr",
|
"module": "qwen_asr",
|
||||||
"install": "pip install qwen-asr",
|
"install": "pip install qwen-asr",
|
||||||
"description": "Qwen3-ASR with ForcedAligner timestamps",
|
"description": "Qwen3-ASR with ForcedAligner timestamps",
|
||||||
|
"streaming": "chunk",
|
||||||
|
"devices": ["cuda", "mps", "cpu"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "openai-api",
|
"id": "openai-api",
|
||||||
|
|
@ -103,6 +115,8 @@ BACKENDS = [
|
||||||
"module": "openai",
|
"module": "openai",
|
||||||
"install": "pip install openai",
|
"install": "pip install openai",
|
||||||
"description": "Cloud-based transcription via OpenAI API",
|
"description": "Cloud-based transcription via OpenAI API",
|
||||||
|
"streaming": "cloud",
|
||||||
|
"devices": ["cloud"],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -159,6 +173,28 @@ QWEN3_REPOS = {
|
||||||
}
|
}
|
||||||
QWEN3_ALIGNER_REPO = "Qwen/Qwen3-ForcedAligner-0.6B"
|
QWEN3_ALIGNER_REPO = "Qwen/Qwen3-ForcedAligner-0.6B"
|
||||||
|
|
||||||
|
# Model catalog: metadata for display in `wlk models`
|
||||||
|
# params = approximate parameter count, disk = approximate download size
|
||||||
|
MODEL_CATALOG = [
|
||||||
|
# Whisper family (available across faster-whisper, mlx-whisper, whisper backends)
|
||||||
|
{"name": "tiny", "family": "whisper", "params": "39M", "disk": "75 MB", "languages": 99, "quality": "low", "speed": "fastest"},
|
||||||
|
{"name": "tiny.en", "family": "whisper", "params": "39M", "disk": "75 MB", "languages": 1, "quality": "low", "speed": "fastest"},
|
||||||
|
{"name": "base", "family": "whisper", "params": "74M", "disk": "142 MB", "languages": 99, "quality": "fair", "speed": "fast"},
|
||||||
|
{"name": "base.en", "family": "whisper", "params": "74M", "disk": "142 MB", "languages": 1, "quality": "fair", "speed": "fast"},
|
||||||
|
{"name": "small", "family": "whisper", "params": "244M", "disk": "466 MB", "languages": 99, "quality": "good", "speed": "medium"},
|
||||||
|
{"name": "small.en", "family": "whisper", "params": "244M", "disk": "466 MB", "languages": 1, "quality": "good", "speed": "medium"},
|
||||||
|
{"name": "medium", "family": "whisper", "params": "769M", "disk": "1.5 GB", "languages": 99, "quality": "great", "speed": "slow"},
|
||||||
|
{"name": "medium.en", "family": "whisper", "params": "769M", "disk": "1.5 GB", "languages": 1, "quality": "great", "speed": "slow"},
|
||||||
|
{"name": "large-v3", "family": "whisper", "params": "1.5B", "disk": "3.1 GB", "languages": 99, "quality": "best", "speed": "slowest"},
|
||||||
|
{"name": "large-v3-turbo", "family": "whisper", "params": "809M", "disk": "1.6 GB", "languages": 99, "quality": "great", "speed": "medium"},
|
||||||
|
# Voxtral (native streaming, single model)
|
||||||
|
{"name": "voxtral", "family": "voxtral", "params": "4B", "disk": "8.2 GB", "languages": 15, "quality": "great", "speed": "medium"},
|
||||||
|
{"name": "voxtral-mlx", "family": "voxtral", "params": "4B", "disk": "2.7 GB", "languages": 15, "quality": "great", "speed": "medium"},
|
||||||
|
# Qwen3 ASR
|
||||||
|
{"name": "qwen3:1.7b", "family": "qwen3", "params": "1.7B", "disk": "3.6 GB", "languages": 12, "quality": "good", "speed": "fast"},
|
||||||
|
{"name": "qwen3:0.6b", "family": "qwen3", "params": "0.6B", "disk": "1.4 GB", "languages": 12, "quality": "fair", "speed": "fastest"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _check_platform(backend: dict) -> bool:
|
def _check_platform(backend: dict) -> bool:
|
||||||
"""Check if backend is compatible with current platform."""
|
"""Check if backend is compatible with current platform."""
|
||||||
|
|
@ -254,93 +290,124 @@ def print_banner(config, host: str, port: int, ssl: bool = False):
|
||||||
# `wlk models` subcommand
|
# `wlk models` subcommand
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def cmd_models():
|
def _model_is_downloaded(model_entry: dict, downloaded: dict) -> bool:
|
||||||
"""List available backends and their installation status."""
|
"""Check if a model catalog entry has been downloaded."""
|
||||||
is_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64"
|
name = model_entry["name"]
|
||||||
|
family = model_entry["family"]
|
||||||
|
|
||||||
print("\nAvailable backends:\n")
|
if family == "whisper":
|
||||||
|
# Check all whisper backends
|
||||||
|
repos = [
|
||||||
|
FASTER_WHISPER_REPOS.get(name),
|
||||||
|
MLX_WHISPER_REPOS.get(name),
|
||||||
|
f"openai/whisper-{name}",
|
||||||
|
]
|
||||||
|
return any(r in downloaded for r in repos if r)
|
||||||
|
elif name == "voxtral":
|
||||||
|
return VOXTRAL_HF_REPO in downloaded
|
||||||
|
elif name == "voxtral-mlx":
|
||||||
|
return VOXTRAL_MLX_REPO in downloaded
|
||||||
|
elif family == "qwen3":
|
||||||
|
size = name.split(":")[1] if ":" in name else "1.7b"
|
||||||
|
return QWEN3_REPOS.get(size, "") in downloaded
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _best_backend_for_model(model_entry: dict) -> str:
|
||||||
|
"""Suggest the best available backend for a model."""
|
||||||
|
family = model_entry["family"]
|
||||||
|
is_apple = platform.system() == "Darwin" and platform.machine() == "arm64"
|
||||||
|
|
||||||
|
if family == "voxtral":
|
||||||
|
if "mlx" in model_entry["name"]:
|
||||||
|
return "voxtral-mlx"
|
||||||
|
return "voxtral"
|
||||||
|
elif family == "qwen3":
|
||||||
|
return "qwen3"
|
||||||
|
elif family == "whisper":
|
||||||
|
if is_apple and _module_available("mlx_whisper"):
|
||||||
|
return "mlx-whisper"
|
||||||
|
if _module_available("faster_whisper"):
|
||||||
|
return "faster-whisper"
|
||||||
|
if _module_available("whisper"):
|
||||||
|
return "whisper"
|
||||||
|
# Suggest best installable
|
||||||
|
return "mlx-whisper" if is_apple else "faster-whisper"
|
||||||
|
return "auto"
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_models():
|
||||||
|
"""List available models and backends (ollama-style)."""
|
||||||
|
is_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64"
|
||||||
|
downloaded = _scan_downloaded_models()
|
||||||
|
|
||||||
|
# --- Installed backends ---
|
||||||
|
print("\n Backends:\n")
|
||||||
|
|
||||||
max_name = max(len(b["name"]) for b in BACKENDS)
|
max_name = max(len(b["name"]) for b in BACKENDS)
|
||||||
|
|
||||||
for b in BACKENDS:
|
for b in BACKENDS:
|
||||||
compatible = _check_platform(b)
|
compatible = _check_platform(b)
|
||||||
installed = _is_installed(b)
|
installed = _is_installed(b)
|
||||||
|
streaming = b.get("streaming", "chunk")
|
||||||
|
stream_label = {"native": "streaming", "chunk": "chunked", "cloud": "cloud"}.get(streaming, streaming)
|
||||||
|
|
||||||
if installed:
|
if installed:
|
||||||
status = "\033[32m installed\033[0m"
|
status = "\033[32m+\033[0m"
|
||||||
elif not compatible:
|
elif not compatible:
|
||||||
status = "\033[90m n/a (wrong platform)\033[0m"
|
status = "\033[90m-\033[0m"
|
||||||
else:
|
else:
|
||||||
status = "\033[33m not installed\033[0m"
|
status = "\033[33m-\033[0m"
|
||||||
|
|
||||||
name_pad = b["name"].ljust(max_name)
|
name_pad = b["name"].ljust(max_name)
|
||||||
print(f" {name_pad} [{status}] {b['description']}")
|
desc_short = b["description"]
|
||||||
|
print(f" {status} {name_pad} {desc_short} [{stream_label}]")
|
||||||
|
|
||||||
if not installed and compatible:
|
if not installed and compatible:
|
||||||
print(f" {''.ljust(max_name)} └─ {b['install']}")
|
print(f" {''.ljust(max_name)} \033[90m{b['install']}\033[0m")
|
||||||
|
|
||||||
# System info
|
# --- System info ---
|
||||||
print(f"\n Platform: {platform.system()} {platform.machine()}")
|
print(f"\n Platform: {platform.system()} {platform.machine()}")
|
||||||
print(f" Python: {platform.python_version()}")
|
|
||||||
print(f" Accelerator: {_gpu_info()}")
|
print(f" Accelerator: {_gpu_info()}")
|
||||||
print(f" ffmpeg: {'found' if _check_ffmpeg() else 'NOT FOUND (required)'}")
|
print(f" ffmpeg: {'found' if _check_ffmpeg() else '\033[31mNOT FOUND\033[0m (required)'}")
|
||||||
|
|
||||||
|
# --- Model catalog ---
|
||||||
|
print("\n Models:\n")
|
||||||
|
|
||||||
|
# Table header
|
||||||
|
hdr = f" {'NAME':<20} {'PARAMS':>7} {'SIZE':>8} {'QUALITY':<8} {'SPEED':<8} {'LANGS':>5} {'STATUS':<10}"
|
||||||
|
print(hdr)
|
||||||
|
print(f" {'─' * 20} {'─' * 7} {'─' * 8} {'─' * 8} {'─' * 8} {'─' * 5} {'─' * 10}")
|
||||||
|
|
||||||
|
for m in MODEL_CATALOG:
|
||||||
|
name = m["name"]
|
||||||
|
# Skip platform-incompatible models
|
||||||
|
if name == "voxtral-mlx" and not is_apple_silicon:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_dl = _model_is_downloaded(m, downloaded)
|
||||||
|
|
||||||
|
if is_dl:
|
||||||
|
status = "\033[32mpulled\033[0m "
|
||||||
|
else:
|
||||||
|
status = "\033[90mavailable\033[0m "
|
||||||
|
|
||||||
|
langs = str(m["languages"]) if m["languages"] < 99 else "99+"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" {name:<20} {m['params']:>7} {m['disk']:>8} "
|
||||||
|
f"{m['quality']:<8} {m['speed']:<8} {langs:>5} {status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Quick start ---
|
||||||
|
print(f"\n Quick start:\n")
|
||||||
if is_apple_silicon:
|
if is_apple_silicon:
|
||||||
print("\n Tip: On Apple Silicon, mlx-whisper and voxtral-mlx offer the best performance.")
|
print(" wlk run voxtral-mlx # Best streaming on Apple Silicon")
|
||||||
|
print(" wlk run large-v3-turbo # Best quality/speed balance")
|
||||||
# Scan for downloaded models
|
else:
|
||||||
downloaded = _scan_downloaded_models()
|
print(" wlk run large-v3-turbo # Best quality/speed balance")
|
||||||
|
print(" wlk run voxtral # Native streaming (CUDA/CPU)")
|
||||||
print("\n Downloaded models:\n")
|
print(" wlk pull base # Download smallest multilingual model")
|
||||||
found_any = False
|
print(" wlk transcribe audio.mp3 # Offline transcription")
|
||||||
|
|
||||||
# Check Whisper-family models
|
|
||||||
all_repos = {
|
|
||||||
"faster-whisper": FASTER_WHISPER_REPOS,
|
|
||||||
"mlx-whisper": MLX_WHISPER_REPOS,
|
|
||||||
}
|
|
||||||
for backend_name, repos in all_repos.items():
|
|
||||||
for size, repo_id in repos.items():
|
|
||||||
if repo_id in downloaded:
|
|
||||||
found_any = True
|
|
||||||
print(f" \033[32m*\033[0m {backend_name}:{size} ({repo_id})")
|
|
||||||
|
|
||||||
# Check native whisper
|
|
||||||
for size in WHISPER_SIZES:
|
|
||||||
key = f"openai/whisper-{size}"
|
|
||||||
if key in downloaded:
|
|
||||||
found_any = True
|
|
||||||
print(f" \033[32m*\033[0m whisper:{size}")
|
|
||||||
|
|
||||||
# Check voxtral / qwen3
|
|
||||||
if VOXTRAL_HF_REPO in downloaded:
|
|
||||||
found_any = True
|
|
||||||
print(f" \033[32m*\033[0m voxtral ({VOXTRAL_HF_REPO})")
|
|
||||||
if VOXTRAL_MLX_REPO in downloaded:
|
|
||||||
found_any = True
|
|
||||||
print(f" \033[32m*\033[0m voxtral-mlx ({VOXTRAL_MLX_REPO})")
|
|
||||||
for qsize, repo_id in QWEN3_REPOS.items():
|
|
||||||
if repo_id in downloaded:
|
|
||||||
found_any = True
|
|
||||||
print(f" \033[32m*\033[0m qwen3:{qsize} ({repo_id})")
|
|
||||||
if QWEN3_ALIGNER_REPO in downloaded:
|
|
||||||
found_any = True
|
|
||||||
print(f" \033[32m*\033[0m qwen3-aligner ({QWEN3_ALIGNER_REPO})")
|
|
||||||
|
|
||||||
if not found_any:
|
|
||||||
print(" (none — models download automatically on first use, or use 'wlk pull')")
|
|
||||||
|
|
||||||
# Show pullable models
|
|
||||||
print("\n Available models (use 'wlk pull <name>'):\n")
|
|
||||||
print(" Whisper sizes: " + ", ".join(WHISPER_SIZES))
|
|
||||||
print(" Voxtral: voxtral, voxtral-mlx")
|
|
||||||
print(" Qwen3: qwen3:1.7b, qwen3:0.6b")
|
|
||||||
print()
|
|
||||||
print(" Examples:")
|
|
||||||
print(" wlk pull base # Download for best available backend")
|
|
||||||
print(" wlk pull faster-whisper:large-v3 # Specific backend + model")
|
|
||||||
print(" wlk pull voxtral # Voxtral HF model")
|
|
||||||
print(" wlk pull qwen3:1.7b # Qwen3-ASR 1.7B")
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1010,6 +1077,23 @@ def cmd_run(args: list):
|
||||||
if parsed.model:
|
if parsed.model:
|
||||||
backend_flag, model_flag = _resolve_run_spec(parsed.model)
|
backend_flag, model_flag = _resolve_run_spec(parsed.model)
|
||||||
|
|
||||||
|
# Show what we resolved
|
||||||
|
catalog_match = next(
|
||||||
|
(m for m in MODEL_CATALOG if m["name"] == parsed.model),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if catalog_match:
|
||||||
|
print(
|
||||||
|
f"\n Model: {catalog_match['name']} "
|
||||||
|
f"({catalog_match['params']} params, {catalog_match['disk']})",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
if backend_flag:
|
||||||
|
print(f" Backend: {backend_flag}", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
best = _best_backend_for_model(catalog_match)
|
||||||
|
print(f" Backend: {best} (auto-detected)", file=sys.stderr)
|
||||||
|
|
||||||
# Auto-pull if needed
|
# Auto-pull if needed
|
||||||
downloaded = _scan_downloaded_models()
|
downloaded = _scan_downloaded_models()
|
||||||
targets = _resolve_pull_target(parsed.model)
|
targets = _resolve_pull_target(parsed.model)
|
||||||
|
|
@ -1198,9 +1282,9 @@ def _probe_backend_state(processor) -> dict:
|
||||||
info["n_audio_tokens_fed"] = transcription._n_audio_tokens_fed
|
info["n_audio_tokens_fed"] = transcription._n_audio_tokens_fed
|
||||||
info["n_text_tokens_received"] = transcription._n_text_tokens_received
|
info["n_text_tokens_received"] = transcription._n_text_tokens_received
|
||||||
info["n_committed_words"] = transcription._n_committed_words
|
info["n_committed_words"] = transcription._n_committed_words
|
||||||
info["pending_audio_samples"] = len(transcription._pending_audio)
|
info["pending_audio_samples"] = transcription._pending_len
|
||||||
with transcription._text_lock:
|
with transcription._text_lock:
|
||||||
info["accumulated_text"] = transcription._accumulated_text
|
info["accumulated_text"] = transcription._get_accumulated_text()
|
||||||
if transcription._generate_error:
|
if transcription._generate_error:
|
||||||
info["generate_error"] = str(transcription._generate_error)
|
info["generate_error"] = str(transcription._generate_error)
|
||||||
# Audio queue depth
|
# Audio queue depth
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,10 @@ class WhisperLiveKitConfig:
|
||||||
nllb_backend: str = "transformers"
|
nllb_backend: str = "transformers"
|
||||||
nllb_size: str = "600M"
|
nllb_size: str = "600M"
|
||||||
|
|
||||||
|
# vLLM Realtime backend
|
||||||
|
vllm_url: str = "ws://localhost:8000/v1/realtime"
|
||||||
|
vllm_model: str = ""
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# .en model suffix forces English
|
# .en model suffix forces English
|
||||||
if self.model_size and self.model_size.endswith(".en"):
|
if self.model_size and self.model_size.endswith(".en"):
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,16 @@ class TranscriptionEngine:
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.transcription:
|
if config.transcription:
|
||||||
if config.backend == "voxtral-mlx":
|
if config.backend == "vllm-realtime":
|
||||||
|
from whisperlivekit.vllm_realtime import VLLMRealtimeASR
|
||||||
|
self.tokenizer = None
|
||||||
|
self.asr = VLLMRealtimeASR(
|
||||||
|
vllm_url=config.vllm_url,
|
||||||
|
model_name=config.vllm_model or "Qwen/Qwen3-ASR-1.7B",
|
||||||
|
lan=config.lan,
|
||||||
|
)
|
||||||
|
logger.info("Using vLLM Realtime streaming backend at %s", config.vllm_url)
|
||||||
|
elif config.backend == "voxtral-mlx":
|
||||||
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR
|
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.asr = VoxtralMLXASR(**transcription_common_params)
|
self.asr = VoxtralMLXASR(**transcription_common_params)
|
||||||
|
|
@ -112,6 +121,14 @@ class TranscriptionEngine:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
|
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
|
||||||
logger.info("Using Voxtral HF Transformers streaming backend")
|
logger.info("Using Voxtral HF Transformers streaming backend")
|
||||||
|
elif config.backend == "qwen3-simul":
|
||||||
|
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingASR
|
||||||
|
self.tokenizer = None
|
||||||
|
self.asr = Qwen3SimulStreamingASR(
|
||||||
|
**transcription_common_params,
|
||||||
|
alignment_heads_path=config.custom_alignment_heads,
|
||||||
|
)
|
||||||
|
logger.info("Using Qwen3-ASR backend with SimulStreaming policy")
|
||||||
elif config.backend == "qwen3":
|
elif config.backend == "qwen3":
|
||||||
from whisperlivekit.qwen3_asr import Qwen3ASR
|
from whisperlivekit.qwen3_asr import Qwen3ASR
|
||||||
self.asr = Qwen3ASR(**transcription_common_params)
|
self.asr = Qwen3ASR(**transcription_common_params)
|
||||||
|
|
@ -210,6 +227,12 @@ def online_factory(args, asr, language=None):
|
||||||
asr = SessionASRProxy(asr, language)
|
asr = SessionASRProxy(asr, language)
|
||||||
|
|
||||||
backend = getattr(args, 'backend', None)
|
backend = getattr(args, 'backend', None)
|
||||||
|
if backend == "vllm-realtime":
|
||||||
|
from whisperlivekit.vllm_realtime import VLLMRealtimeOnlineProcessor
|
||||||
|
return VLLMRealtimeOnlineProcessor(asr)
|
||||||
|
if backend == "qwen3-simul":
|
||||||
|
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingOnlineProcessor
|
||||||
|
return Qwen3SimulStreamingOnlineProcessor(asr)
|
||||||
if backend == "voxtral-mlx":
|
if backend == "voxtral-mlx":
|
||||||
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
|
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
|
||||||
return VoxtralMLXOnlineProcessor(asr)
|
return VoxtralMLXOnlineProcessor(asr)
|
||||||
|
|
|
||||||
|
|
@ -147,8 +147,8 @@ def parse_args():
|
||||||
"--backend",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="auto",
|
default="auto",
|
||||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3"],
|
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-simul", "vllm-realtime"],
|
||||||
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon. Use 'qwen3' for Qwen3-ASR.",
|
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-vac",
|
"--no-vac",
|
||||||
|
|
@ -196,6 +196,22 @@ def parse_args():
|
||||||
default=False,
|
default=False,
|
||||||
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
|
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
|
||||||
)
|
)
|
||||||
|
# vLLM Realtime backend arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--vllm-url",
|
||||||
|
type=str,
|
||||||
|
default="ws://localhost:8000/v1/realtime",
|
||||||
|
dest="vllm_url",
|
||||||
|
help="URL of the vLLM realtime WebSocket endpoint.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vllm-model",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
dest="vllm_model",
|
||||||
|
help="Model name to use with vLLM (e.g. Qwen/Qwen3-ASR-1.7B).",
|
||||||
|
)
|
||||||
|
|
||||||
# SimulStreaming-specific arguments
|
# SimulStreaming-specific arguments
|
||||||
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
@ -11,12 +12,10 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _patch_transformers_compat():
|
def _patch_transformers_compat():
|
||||||
"""Patch transformers for qwen_asr compatibility.
|
"""Patch transformers for qwen_asr 0.0.6 + transformers >= 5.3 compatibility."""
|
||||||
|
import torch
|
||||||
|
|
||||||
qwen_asr imports ``check_model_inputs`` from ``transformers.utils.generic``,
|
# 1. check_model_inputs was removed
|
||||||
but this decorator hasn't been released yet in any public transformers
|
|
||||||
version. We inject a no-op stub so the import succeeds.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
import transformers.utils.generic as _g
|
import transformers.utils.generic as _g
|
||||||
if not hasattr(_g, "check_model_inputs"):
|
if not hasattr(_g, "check_model_inputs"):
|
||||||
|
|
@ -28,6 +27,63 @@ def _patch_transformers_compat():
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# 2. 'default' rope type was removed from ROPE_INIT_FUNCTIONS
|
||||||
|
try:
|
||||||
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||||
|
if "default" not in ROPE_INIT_FUNCTIONS:
|
||||||
|
def _compute_default_rope_parameters(config=None, device=None, seq_len=None, **kwargs):
|
||||||
|
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
|
partial = getattr(config, "partial_rotary_factor", 1.0)
|
||||||
|
dim = int(head_dim * partial)
|
||||||
|
base = config.rope_theta
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
||||||
|
return inv_freq, 1.0
|
||||||
|
ROPE_INIT_FUNCTIONS["default"] = _compute_default_rope_parameters
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 3. pad_token_id missing on thinker config
|
||||||
|
try:
|
||||||
|
from qwen_asr.core.transformers_backend.configuration_qwen3_asr import (
|
||||||
|
Qwen3ASRThinkerConfig,
|
||||||
|
)
|
||||||
|
if not hasattr(Qwen3ASRThinkerConfig, "pad_token_id"):
|
||||||
|
Qwen3ASRThinkerConfig.pad_token_id = None
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 4. fix_mistral_regex kwarg not accepted by newer transformers
|
||||||
|
try:
|
||||||
|
from transformers.models.auto import processing_auto
|
||||||
|
_orig_ap_from_pretrained = processing_auto.AutoProcessor.from_pretrained.__func__
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _patched_ap_from_pretrained(cls, *args, **kwargs):
|
||||||
|
kwargs.pop("fix_mistral_regex", None)
|
||||||
|
return _orig_ap_from_pretrained(cls, *args, **kwargs)
|
||||||
|
|
||||||
|
processing_auto.AutoProcessor.from_pretrained = _patched_ap_from_pretrained
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 5. compute_default_rope_parameters missing on RotaryEmbedding
|
||||||
|
try:
|
||||||
|
from qwen_asr.core.transformers_backend.modeling_qwen3_asr import (
|
||||||
|
Qwen3ASRThinkerTextRotaryEmbedding,
|
||||||
|
)
|
||||||
|
if not hasattr(Qwen3ASRThinkerTextRotaryEmbedding, "compute_default_rope_parameters"):
|
||||||
|
@staticmethod
|
||||||
|
def _rope_params(config=None, device=None, seq_len=None, **kwargs):
|
||||||
|
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
|
partial = getattr(config, "partial_rotary_factor", 1.0)
|
||||||
|
dim = int(head_dim * partial)
|
||||||
|
base = config.rope_theta
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
||||||
|
return inv_freq, 1.0
|
||||||
|
Qwen3ASRThinkerTextRotaryEmbedding.compute_default_rope_parameters = _rope_params
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
_patch_transformers_compat()
|
_patch_transformers_compat()
|
||||||
|
|
||||||
|
|
@ -62,6 +118,9 @@ QWEN3_MODEL_MAPPING = {
|
||||||
}
|
}
|
||||||
|
|
||||||
_PUNCTUATION_ENDS = set(".!?。!?;;")
|
_PUNCTUATION_ENDS = set(".!?。!?;;")
|
||||||
|
# Qwen3 raw output starts with "language <Name>" metadata before <asr_text> tag.
|
||||||
|
# When the tag is missing (silence/noise), this metadata leaks as transcription text.
|
||||||
|
_GARBAGE_RE = re.compile(r"^language\s+\S+$", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
class Qwen3ASR(ASRBase):
|
class Qwen3ASR(ASRBase):
|
||||||
|
|
@ -88,8 +147,12 @@ class Qwen3ASR(ASRBase):
|
||||||
else:
|
else:
|
||||||
model_id = "Qwen/Qwen3-ASR-1.7B"
|
model_id = "Qwen/Qwen3-ASR-1.7B"
|
||||||
|
|
||||||
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
if torch.cuda.is_available():
|
||||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
dtype, device = torch.bfloat16, "cuda:0"
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
dtype, device = torch.float32, "mps"
|
||||||
|
else:
|
||||||
|
dtype, device = torch.float32, "cpu"
|
||||||
|
|
||||||
logger.info(f"Loading Qwen3-ASR: {model_id} ({dtype}, {device})")
|
logger.info(f"Loading Qwen3-ASR: {model_id} ({dtype}, {device})")
|
||||||
model = Qwen3ASRModel.from_pretrained(
|
model = Qwen3ASRModel.from_pretrained(
|
||||||
|
|
@ -126,17 +189,32 @@ class Qwen3ASR(ASRBase):
|
||||||
result = results[0]
|
result = results[0]
|
||||||
# Stash audio length for timestamp estimation fallback
|
# Stash audio length for timestamp estimation fallback
|
||||||
result._audio_duration = len(audio) / 16000
|
result._audio_duration = len(audio) / 16000
|
||||||
|
logger.info(
|
||||||
|
"Qwen3 result: language=%r text=%r ts=%s",
|
||||||
|
result.language, result.text[:80] if result.text else "",
|
||||||
|
bool(result.time_stamps),
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _detected_language(result) -> Optional[str]:
|
def _detected_language(result) -> Optional[str]:
|
||||||
"""Extract Whisper-style language code from Qwen3 result."""
|
"""Extract Whisper-style language code from Qwen3 result."""
|
||||||
lang = getattr(result, 'language', None)
|
lang = getattr(result, 'language', None)
|
||||||
if lang:
|
if not lang or lang.lower() == "none":
|
||||||
return QWEN3_TO_WHISPER_LANGUAGE.get(lang, lang.lower())
|
return None
|
||||||
return None
|
# merge_languages may return comma-separated; take the first
|
||||||
|
first = lang.split(",")[0].strip()
|
||||||
|
if not first or first.lower() == "none":
|
||||||
|
return None
|
||||||
|
return QWEN3_TO_WHISPER_LANGUAGE.get(first, first.lower())
|
||||||
|
|
||||||
def ts_words(self, result) -> List[ASRToken]:
|
def ts_words(self, result) -> List[ASRToken]:
|
||||||
|
# Filter garbage model output (e.g. "language None" for silence/noise)
|
||||||
|
text = (result.text or "").strip()
|
||||||
|
if not text or _GARBAGE_RE.match(text):
|
||||||
|
if text:
|
||||||
|
logger.info("Filtered garbage Qwen3 output: %r", text)
|
||||||
|
return []
|
||||||
detected = self._detected_language(result)
|
detected = self._detected_language(result)
|
||||||
if result.time_stamps:
|
if result.time_stamps:
|
||||||
tokens = []
|
tokens = []
|
||||||
|
|
|
||||||
837
whisperlivekit/qwen3_simul.py
Normal file
837
whisperlivekit/qwen3_simul.py
Normal file
|
|
@ -0,0 +1,837 @@
|
||||||
|
"""
|
||||||
|
SimulStreaming-style online processor for Qwen3-ASR.
|
||||||
|
|
||||||
|
Architecture overview
|
||||||
|
---------------------
|
||||||
|
Qwen3-ASR is a decoder-only multimodal model. Audio is encoded by an audio
|
||||||
|
encoder (Whisper-style) into a sequence of embeddings that replace <|audio_pad|>
|
||||||
|
placeholder tokens in the input sequence. The text decoder then uses causal
|
||||||
|
self-attention over the combined audio + text tokens.
|
||||||
|
|
||||||
|
Unlike Whisper (which has explicit cross-attention between decoder and encoder),
|
||||||
|
Qwen3-ASR uses self-attention where generated text tokens attend to earlier
|
||||||
|
audio tokens and previously generated text. This means "alignment heads" here
|
||||||
|
are self-attention heads whose attention over the *audio-token region* tracks
|
||||||
|
the monotonic audio-to-text alignment.
|
||||||
|
|
||||||
|
The border-distance policy works as follows:
|
||||||
|
- After each generated token, extract the attention weights from the
|
||||||
|
selected alignment heads, restricted to the audio-token region
|
||||||
|
- Find which audio frame each head attends to most strongly (argmax)
|
||||||
|
- If the most-attended audio frame is approaching the end of the available
|
||||||
|
audio, pause generation and wait for more audio
|
||||||
|
- If the most-attended frame jumps backward (rewind), discard recent tokens
|
||||||
|
|
||||||
|
This module loads the Qwen3-ASR model *directly* via transformers (not through
|
||||||
|
the qwen_asr package's Qwen3ASRModel wrapper), giving us full control over
|
||||||
|
forward passes, KV caches, and attention extraction.
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
- A pre-computed alignment heads JSON file (from detect_alignment_heads_qwen3.py)
|
||||||
|
- OR will fall back to all heads in a configurable set of layers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SAMPLE_RATE = 16000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Qwen3SimulConfig:
|
||||||
|
"""Configuration for Qwen3 SimulStreaming."""
|
||||||
|
model_id: str = "Qwen/Qwen3-ASR-1.7B"
|
||||||
|
alignment_heads_path: Optional[str] = None
|
||||||
|
language: str = "auto"
|
||||||
|
# Border/rewind thresholds as fraction of audio tokens (not absolute frames).
|
||||||
|
# Qwen3 has ~13 audio tokens/sec vs Whisper's ~50, so absolute thresholds
|
||||||
|
# don't transfer. 0.15 = pause when attention is within last 15% of audio.
|
||||||
|
border_fraction: float = 0.15 # Fraction of audio tokens from end to trigger pause
|
||||||
|
rewind_fraction: float = 0.12 # Max backward jump as fraction of audio tokens
|
||||||
|
audio_min_len: float = 0.5 # Minimum audio length before starting decode
|
||||||
|
audio_max_len: float = 15.0 # Maximum audio buffer length in seconds
|
||||||
|
max_context_tokens: int = 30 # Max committed tokens to include as context
|
||||||
|
init_prompt: Optional[str] = None
|
||||||
|
max_alignment_heads: int = 20 # Use only top N alignment heads
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Qwen3SimulState:
|
||||||
|
"""Per-session mutable state for Qwen3 SimulStreaming."""
|
||||||
|
# Audio
|
||||||
|
audio_buffer: np.ndarray = field(
|
||||||
|
default_factory=lambda: np.array([], dtype=np.float32)
|
||||||
|
)
|
||||||
|
cumulative_time_offset: float = 0.0
|
||||||
|
global_time_offset: float = 0.0
|
||||||
|
speaker: int = -1
|
||||||
|
|
||||||
|
# Decode state
|
||||||
|
last_attend_frame: int = -15
|
||||||
|
generated_tokens: List[int] = field(default_factory=list)
|
||||||
|
committed_text: str = ""
|
||||||
|
committed_word_count: int = 0 # How many words already emitted
|
||||||
|
committed_token_ids: List[int] = field(default_factory=list) # token IDs for prompt context
|
||||||
|
|
||||||
|
# Tracking
|
||||||
|
first_timestamp: Optional[float] = None
|
||||||
|
detected_language: Optional[str] = None
|
||||||
|
last_infer_samples: int = 0 # audio_buffer length at last inference
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3SimulStreamingASR:
|
||||||
|
"""
|
||||||
|
Shared backend for Qwen3-ASR SimulStreaming.
|
||||||
|
|
||||||
|
Loads the model once and is shared across sessions. Each session gets
|
||||||
|
its own Qwen3SimulStreamingOnlineProcessor with independent state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sep = ""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_size: str = None,
|
||||||
|
model_dir: str = None,
|
||||||
|
lan: str = "auto",
|
||||||
|
alignment_heads_path: Optional[str] = None,
|
||||||
|
border_fraction: float = 0.15,
|
||||||
|
min_chunk_size: float = 0.1,
|
||||||
|
warmup_file: Optional[str] = None,
|
||||||
|
model_cache_dir: Optional[str] = None,
|
||||||
|
model_path: Optional[str] = None,
|
||||||
|
lora_path: Optional[str] = None,
|
||||||
|
direct_english_translation: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.transcribe_kargs = {}
|
||||||
|
self.original_language = None if lan == "auto" else lan
|
||||||
|
self.warmup_file = warmup_file
|
||||||
|
|
||||||
|
self.cfg = Qwen3SimulConfig(
|
||||||
|
language=lan,
|
||||||
|
alignment_heads_path=alignment_heads_path,
|
||||||
|
border_fraction=border_fraction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load model directly via transformers
|
||||||
|
self._load_model(model_size, model_dir, model_cache_dir, model_path)
|
||||||
|
|
||||||
|
# Load alignment heads
|
||||||
|
self.alignment_heads = self._load_alignment_heads(alignment_heads_path)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
if warmup_file:
|
||||||
|
from whisperlivekit.warmup import load_file
|
||||||
|
audio = load_file(warmup_file)
|
||||||
|
if audio is not None:
|
||||||
|
logger.info("Warming up Qwen3 SimulStreaming model")
|
||||||
|
# Simple warmup: just encode a short audio
|
||||||
|
self._warmup(audio)
|
||||||
|
|
||||||
|
def _load_model(self, model_size, model_dir, model_cache_dir, model_path):
|
||||||
|
"""Load Qwen3-ASR via transformers (SDPA attention for speed)."""
|
||||||
|
from whisperlivekit.qwen3_asr import (
|
||||||
|
QWEN3_MODEL_MAPPING,
|
||||||
|
_patch_transformers_compat,
|
||||||
|
)
|
||||||
|
_patch_transformers_compat()
|
||||||
|
|
||||||
|
from qwen_asr.core.transformers_backend import (
|
||||||
|
Qwen3ASRConfig,
|
||||||
|
Qwen3ASRForConditionalGeneration,
|
||||||
|
Qwen3ASRProcessor,
|
||||||
|
)
|
||||||
|
from transformers import AutoConfig, AutoModel, AutoProcessor
|
||||||
|
|
||||||
|
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
|
||||||
|
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
|
||||||
|
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
|
||||||
|
|
||||||
|
if model_dir:
|
||||||
|
model_id = model_dir
|
||||||
|
elif model_path:
|
||||||
|
model_id = model_path
|
||||||
|
elif model_size:
|
||||||
|
model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size)
|
||||||
|
else:
|
||||||
|
model_id = "Qwen/Qwen3-ASR-1.7B"
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
dtype, device = torch.bfloat16, "cuda:0"
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
dtype, device = torch.float32, "mps"
|
||||||
|
else:
|
||||||
|
dtype, device = torch.float32, "cpu"
|
||||||
|
|
||||||
|
logger.info("Loading Qwen3-ASR for SimulStreaming: %s (sdpa attention)", model_id)
|
||||||
|
self.model = AutoModel.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
device_map=device,
|
||||||
|
)
|
||||||
|
self.model.eval()
|
||||||
|
self.processor = AutoProcessor.from_pretrained(model_id, fix_mistral_regex=True)
|
||||||
|
|
||||||
|
# Cache model properties
|
||||||
|
thinker = self.model.thinker
|
||||||
|
text_config = thinker.config.text_config
|
||||||
|
self.num_layers = text_config.num_hidden_layers
|
||||||
|
self.num_heads = text_config.num_attention_heads
|
||||||
|
self.num_kv_heads = text_config.num_key_value_heads
|
||||||
|
self.audio_token_id = thinker.config.audio_token_id
|
||||||
|
self.device = next(self.model.parameters()).device
|
||||||
|
self.dtype = next(self.model.parameters()).dtype
|
||||||
|
|
||||||
|
# Cache special token IDs for metadata stripping
|
||||||
|
self.asr_text_token_id = self.processor.tokenizer.convert_tokens_to_ids("<asr_text>")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Qwen3-ASR loaded: %d layers x %d heads, device=%s, <asr_text> id=%d",
|
||||||
|
self.num_layers, self.num_heads, self.device, self.asr_text_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_alignment_heads(
|
||||||
|
self, path: Optional[str],
|
||||||
|
) -> List[Tuple[int, int]]:
|
||||||
|
"""Load alignment heads from JSON or use defaults.
|
||||||
|
|
||||||
|
Only loads the top N heads (sorted by TS score) for efficiency.
|
||||||
|
The Qwen3-ASR model has alignment info spread across most heads
|
||||||
|
(decoder-only, no cross-attention), so we pick the strongest ones.
|
||||||
|
"""
|
||||||
|
max_heads = self.cfg.max_alignment_heads
|
||||||
|
|
||||||
|
if path and Path(path).exists():
|
||||||
|
with open(path) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
# alignment_heads_compact is pre-sorted by TS score (descending)
|
||||||
|
all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
|
||||||
|
heads = all_heads[:max_heads]
|
||||||
|
logger.info(
|
||||||
|
"Loaded top %d alignment heads from %s (of %d total)",
|
||||||
|
len(heads), path, len(all_heads),
|
||||||
|
)
|
||||||
|
return heads
|
||||||
|
|
||||||
|
# Default: use heads from the last quarter of layers
|
||||||
|
default_heads = []
|
||||||
|
start_layer = self.num_layers * 3 // 4
|
||||||
|
for layer in range(start_layer, self.num_layers):
|
||||||
|
for head in range(self.num_heads):
|
||||||
|
default_heads.append((layer, head))
|
||||||
|
logger.warning(
|
||||||
|
"No alignment heads file found. Using default heuristic: "
|
||||||
|
"%d heads from layers %d-%d. Run detect_alignment_heads_qwen3.py "
|
||||||
|
"to find optimal heads.",
|
||||||
|
len(default_heads), start_layer, self.num_layers - 1,
|
||||||
|
)
|
||||||
|
return default_heads[:max_heads]
|
||||||
|
|
||||||
|
def _warmup(self, audio: np.ndarray):
|
||||||
|
"""Run a short inference to warmup the model."""
|
||||||
|
try:
|
||||||
|
audio = audio[:SAMPLE_RATE * 2] # Max 2 seconds
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": ""},
|
||||||
|
{"role": "user", "content": [{"type": "audio", "audio": ""}]},
|
||||||
|
]
|
||||||
|
text_prompt = self.processor.apply_chat_template(
|
||||||
|
msgs, add_generation_prompt=True, tokenize=False,
|
||||||
|
)
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[text_prompt],
|
||||||
|
audio=[audio],
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
inputs = inputs.to(self.device).to(self.dtype)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
self.model.thinker.generate(
|
||||||
|
**inputs, max_new_tokens=5, do_sample=False,
|
||||||
|
)
|
||||||
|
logger.info("Qwen3 SimulStreaming warmup complete")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Warmup failed: %s", e)
|
||||||
|
|
||||||
|
def transcribe(self, audio):
|
||||||
|
"""No-op -- SimulStreaming uses the online processor directly."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3SimulStreamingOnlineProcessor:
|
||||||
|
"""
|
||||||
|
Per-session online processor for Qwen3-ASR SimulStreaming.
|
||||||
|
|
||||||
|
Implements the same interface as SimulStreamingOnlineProcessor:
|
||||||
|
- insert_audio_chunk(audio, time)
|
||||||
|
- process_iter(is_last=False) -> (List[ASRToken], float)
|
||||||
|
- get_buffer() -> Transcript
|
||||||
|
- start_silence() -> (List[ASRToken], float)
|
||||||
|
- end_silence(duration, offset)
|
||||||
|
- finish() -> (List[ASRToken], float)
|
||||||
|
"""
|
||||||
|
|
||||||
|
SAMPLING_RATE = 16000
|
||||||
|
MIN_DURATION_REAL_SILENCE = 5
|
||||||
|
|
||||||
|
def __init__(self, asr: Qwen3SimulStreamingASR, logfile=sys.stderr):
|
||||||
|
self.asr = asr
|
||||||
|
self.logfile = logfile
|
||||||
|
self.end = 0.0
|
||||||
|
self.buffer: List[ASRToken] = []
|
||||||
|
|
||||||
|
# Per-session state
|
||||||
|
self.state = Qwen3SimulState()
|
||||||
|
|
||||||
|
# Build the prompt template once
|
||||||
|
self._build_prompt_template()
|
||||||
|
|
||||||
|
def _build_prompt_template(self):
|
||||||
|
"""Build the base text prompt for Qwen3-ASR."""
|
||||||
|
from whisperlivekit.qwen3_asr import WHISPER_TO_QWEN3_LANGUAGE
|
||||||
|
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": ""},
|
||||||
|
{"role": "user", "content": [{"type": "audio", "audio": ""}]},
|
||||||
|
]
|
||||||
|
self._base_prompt = self.asr.processor.apply_chat_template(
|
||||||
|
msgs, add_generation_prompt=True, tokenize=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add language forcing if configured
|
||||||
|
lan = self.asr.cfg.language
|
||||||
|
if lan and lan != "auto":
|
||||||
|
lang_name = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan)
|
||||||
|
self._base_prompt += f"language {lang_name}<asr_text>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speaker(self):
|
||||||
|
return self.state.speaker
|
||||||
|
|
||||||
|
@speaker.setter
|
||||||
|
def speaker(self, value):
|
||||||
|
self.state.speaker = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_time_offset(self):
|
||||||
|
return self.state.global_time_offset
|
||||||
|
|
||||||
|
@global_time_offset.setter
|
||||||
|
def global_time_offset(self, value):
|
||||||
|
self.state.global_time_offset = value
|
||||||
|
|
||||||
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
||||||
|
"""Append an audio chunk to be processed."""
|
||||||
|
self.end = audio_stream_end_time
|
||||||
|
self.state.audio_buffer = np.append(self.state.audio_buffer, audio)
|
||||||
|
|
||||||
|
# Trim audio if too long
|
||||||
|
max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE)
|
||||||
|
if len(self.state.audio_buffer) > max_samples:
|
||||||
|
trim = len(self.state.audio_buffer) - max_samples
|
||||||
|
self.state.audio_buffer = self.state.audio_buffer[trim:]
|
||||||
|
self.state.cumulative_time_offset += trim / self.SAMPLING_RATE
|
||||||
|
# Adjust throttle counter so it tracks position within the trimmed buffer
|
||||||
|
self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim)
|
||||||
|
|
||||||
|
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""Handle start of silence -- flush all pending tokens.
|
||||||
|
|
||||||
|
Loops inference until the model produces no new tokens, since a
|
||||||
|
single is_last call may not exhaust all text for the buffered audio.
|
||||||
|
"""
|
||||||
|
all_tokens = []
|
||||||
|
for _ in range(5): # safety limit
|
||||||
|
tokens, processed_upto = self.process_iter(is_last=True)
|
||||||
|
if not tokens:
|
||||||
|
break
|
||||||
|
all_tokens.extend(tokens)
|
||||||
|
return all_tokens, self.end
|
||||||
|
|
||||||
|
def end_silence(self, silence_duration: float, offset: float):
|
||||||
|
"""Handle silence period."""
|
||||||
|
self.end += silence_duration
|
||||||
|
long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
|
||||||
|
if not long_silence:
|
||||||
|
gap_len = int(self.SAMPLING_RATE * silence_duration)
|
||||||
|
if gap_len > 0:
|
||||||
|
gap_silence = np.zeros(gap_len, dtype=np.float32)
|
||||||
|
self.state.audio_buffer = np.append(
|
||||||
|
self.state.audio_buffer, gap_silence,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Long silence: reset
|
||||||
|
self.state = Qwen3SimulState()
|
||||||
|
self.state.global_time_offset = silence_duration + offset
|
||||||
|
|
||||||
|
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||||
|
"""Handle speaker change event."""
|
||||||
|
self.process_iter(is_last=True)
|
||||||
|
self.state = Qwen3SimulState()
|
||||||
|
self.state.speaker = change_speaker.speaker
|
||||||
|
self.state.global_time_offset = change_speaker.start
|
||||||
|
|
||||||
|
def get_buffer(self) -> Transcript:
|
||||||
|
"""Get the current unvalidated buffer."""
|
||||||
|
return Transcript.from_tokens(tokens=self.buffer, sep='')
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""
|
||||||
|
Process accumulated audio using SimulStreaming with alignment heads.
|
||||||
|
|
||||||
|
This performs a full forward pass (encode audio + greedy decode with
|
||||||
|
attention extraction), applying the border-distance policy to decide
|
||||||
|
when to stop generating.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (committed ASRToken list, audio processed up to time).
|
||||||
|
"""
|
||||||
|
audio_duration = len(self.state.audio_buffer) / self.SAMPLING_RATE
|
||||||
|
if audio_duration < self.asr.cfg.audio_min_len:
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
# Throttle: skip inference if less than 1s of new audio since last run.
|
||||||
|
# Each inference re-encodes the full buffer, so calling too often wastes
|
||||||
|
# GPU/CPU time and causes lag to spiral.
|
||||||
|
new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
|
||||||
|
min_new_seconds = 1.0
|
||||||
|
if not is_last and new_samples < int(min_new_seconds * self.SAMPLING_RATE):
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
logger.info("Running SimulStreaming inference on %.2fs of audio (%.2fs new)", audio_duration, new_samples / self.SAMPLING_RATE)
|
||||||
|
self.state.last_infer_samples = len(self.state.audio_buffer)
|
||||||
|
|
||||||
|
try:
|
||||||
|
timestamped_words = self._infer(is_last)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Qwen3 SimulStreaming inference error: %s", e)
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
logger.info("SimulStreaming produced %d words", len(timestamped_words))
|
||||||
|
if not timestamped_words:
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
self.buffer = []
|
||||||
|
return timestamped_words, self.end
|
||||||
|
|
||||||
|
def _infer(self, is_last: bool) -> List[ASRToken]:
|
||||||
|
"""Run one inference cycle with alignment-head-based stopping.
|
||||||
|
|
||||||
|
Uses forward hooks on self_attn modules to capture attention weights
|
||||||
|
during generation. The Qwen3-ASR decoder layer discards attention
|
||||||
|
weights (hidden_states, _ = self.self_attn(...)), so output_attentions
|
||||||
|
via generate() would return None. Hooks capture them before discard.
|
||||||
|
"""
|
||||||
|
asr = self.asr
|
||||||
|
state = self.state
|
||||||
|
|
||||||
|
# Prepare inputs
|
||||||
|
inputs = asr.processor(
|
||||||
|
text=[self._base_prompt],
|
||||||
|
audio=[state.audio_buffer],
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
inputs = inputs.to(asr.device).to(asr.dtype)
|
||||||
|
|
||||||
|
# Append committed token IDs as context so generate() continues from
|
||||||
|
# where it left off. Cap at max_context_tokens to prevent prompt growth.
|
||||||
|
if state.committed_token_ids:
|
||||||
|
ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
|
||||||
|
ctx_ids = torch.tensor(
|
||||||
|
[ctx], dtype=inputs.input_ids.dtype,
|
||||||
|
device=inputs.input_ids.device,
|
||||||
|
)
|
||||||
|
inputs["input_ids"] = torch.cat([inputs.input_ids, ctx_ids], dim=1)
|
||||||
|
if "attention_mask" in inputs:
|
||||||
|
ctx_mask = torch.ones_like(ctx_ids)
|
||||||
|
inputs["attention_mask"] = torch.cat(
|
||||||
|
[inputs.attention_mask, ctx_mask], dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_len = inputs.input_ids.shape[1]
|
||||||
|
|
||||||
|
# Find audio token range
|
||||||
|
input_ids = inputs.input_ids[0]
|
||||||
|
audio_mask = (input_ids == asr.audio_token_id)
|
||||||
|
audio_positions = audio_mask.nonzero(as_tuple=True)[0]
|
||||||
|
if len(audio_positions) == 0:
|
||||||
|
return []
|
||||||
|
audio_start = audio_positions[0].item()
|
||||||
|
audio_end = audio_positions[-1].item() + 1
|
||||||
|
n_audio_tokens = audio_end - audio_start
|
||||||
|
|
||||||
|
audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE
|
||||||
|
|
||||||
|
# Install forward hooks to capture alignment attention from Q and K.
|
||||||
|
# With SDPA attention (fast), attn_weights are not returned. Instead,
|
||||||
|
# we hook self_attn to compute Q*K^T attention ONLY for alignment heads
|
||||||
|
# during autoregressive steps (q_len == 1). This is cheap because we
|
||||||
|
# only compute dot products for ~20 heads, not full attention for all.
|
||||||
|
#
|
||||||
|
# Key detail: self_attn is called with ALL keyword arguments from the
|
||||||
|
# decoder layer, so hidden_states/position_embeddings/past_key_values
|
||||||
|
# are all in kwargs, not args.
|
||||||
|
per_step_frames: List[List[int]] = []
|
||||||
|
current_step_frames: List[int] = []
|
||||||
|
|
||||||
|
heads_by_layer: dict = {}
|
||||||
|
for layer_idx, head_idx in asr.alignment_heads:
|
||||||
|
heads_by_layer.setdefault(layer_idx, []).append(head_idx)
|
||||||
|
|
||||||
|
decoder_layers = asr.model.thinker.model.layers
|
||||||
|
num_kv_heads = asr.num_kv_heads
|
||||||
|
num_heads = asr.num_heads
|
||||||
|
gqa_ratio = num_heads // num_kv_heads # GQA group size
|
||||||
|
|
||||||
|
# Import RoPE function used by this model's attention
|
||||||
|
from qwen_asr.core.transformers_backend.modeling_qwen3_asr import (
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = []
|
||||||
|
|
||||||
|
def _make_attn_hook(layer_idx):
|
||||||
|
"""Forward hook on self_attn that computes Q*K^T for alignment heads.
|
||||||
|
|
||||||
|
After the forward pass, we recompute Q (with RoPE) for the current
|
||||||
|
token and dot it against the cached K (which already has RoPE) in
|
||||||
|
the audio region. This gives us per-head alignment frames.
|
||||||
|
"""
|
||||||
|
head_indices = heads_by_layer[layer_idx]
|
||||||
|
|
||||||
|
def hook_fn(module, args, kwargs, output):
|
||||||
|
# All arguments are keyword-passed from the decoder layer
|
||||||
|
hidden_states = kwargs.get('hidden_states')
|
||||||
|
if hidden_states is None:
|
||||||
|
hidden_states = args[0] if args else None
|
||||||
|
if hidden_states is None or hidden_states.shape[1] != 1:
|
||||||
|
return # Skip prefill (seq_len > 1)
|
||||||
|
|
||||||
|
position_embeddings = kwargs.get('position_embeddings')
|
||||||
|
if position_embeddings is None and len(args) > 1:
|
||||||
|
position_embeddings = args[1]
|
||||||
|
past_kv = kwargs.get('past_key_values')
|
||||||
|
if position_embeddings is None or past_kv is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Recompute Q with RoPE (cheap: single token through q_proj + RoPE)
|
||||||
|
hidden_shape = (*hidden_states.shape[:-1], -1, module.head_dim)
|
||||||
|
q = module.q_norm(
|
||||||
|
module.q_proj(hidden_states).view(hidden_shape)
|
||||||
|
).transpose(1, 2)
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
q, _ = apply_rotary_pos_emb(q, q, cos, sin)
|
||||||
|
|
||||||
|
# K from cache already has RoPE applied
|
||||||
|
cache_layer = past_kv.layers[module.layer_idx]
|
||||||
|
k = cache_layer.keys # (batch, n_kv_heads, kv_len, head_dim)
|
||||||
|
if k is None or audio_end > k.shape[2]:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Compute attention scores for alignment heads only
|
||||||
|
for h_idx in head_indices:
|
||||||
|
if h_idx >= q.shape[1]:
|
||||||
|
continue
|
||||||
|
kv_h_idx = h_idx // gqa_ratio
|
||||||
|
q_h = q[0, h_idx, 0] # (head_dim,)
|
||||||
|
k_audio = k[0, kv_h_idx, audio_start:audio_end] # (n_audio, head_dim)
|
||||||
|
scores = torch.matmul(k_audio, q_h) # (n_audio,)
|
||||||
|
frame = scores.argmax().item()
|
||||||
|
current_step_frames.append(frame)
|
||||||
|
|
||||||
|
return hook_fn
|
||||||
|
|
||||||
|
for layer_idx in heads_by_layer:
|
||||||
|
if layer_idx < len(decoder_layers):
|
||||||
|
h = decoder_layers[layer_idx].self_attn.register_forward_hook(
|
||||||
|
_make_attn_hook(layer_idx),
|
||||||
|
with_kwargs=True,
|
||||||
|
)
|
||||||
|
hooks.append(h)
|
||||||
|
|
||||||
|
# Step boundary hook on lm_head to separate per-step frames
|
||||||
|
# and check border-distance stopping criteria in real-time.
|
||||||
|
# This is CRITICAL for performance: instead of generating 200 tokens
|
||||||
|
# then truncating, we stop as soon as attention hits the audio border.
|
||||||
|
# On MPS, each token costs ~50ms, so stopping at 10 tokens vs 200
|
||||||
|
# means ~0.5s vs ~10s inference.
|
||||||
|
last_attend_frame = state.last_attend_frame
|
||||||
|
border_stop_step: Optional[int] = None
|
||||||
|
|
||||||
|
# Compute absolute thresholds from fractional config
|
||||||
|
border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
|
||||||
|
rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction))
|
||||||
|
|
||||||
|
def _step_boundary_hook(module, args, output):
|
||||||
|
nonlocal current_step_frames, last_attend_frame, border_stop_step
|
||||||
|
if current_step_frames:
|
||||||
|
per_step_frames.append(current_step_frames)
|
||||||
|
current_step_frames = []
|
||||||
|
|
||||||
|
# Check border distance on each step.
|
||||||
|
# Allow at least 3 steps before checking, so short buffers
|
||||||
|
# can still produce some tokens during streaming.
|
||||||
|
if not is_last and border_stop_step is None and len(per_step_frames) >= 3:
|
||||||
|
latest = per_step_frames[-1]
|
||||||
|
if latest:
|
||||||
|
frames_sorted = sorted(latest)
|
||||||
|
attended = frames_sorted[len(frames_sorted) // 2]
|
||||||
|
|
||||||
|
# Rewind check
|
||||||
|
if last_attend_frame - attended > rewind_threshold:
|
||||||
|
border_stop_step = max(0, len(per_step_frames) - 2)
|
||||||
|
return
|
||||||
|
|
||||||
|
last_attend_frame = attended
|
||||||
|
|
||||||
|
# Border check
|
||||||
|
if (n_audio_tokens - attended) <= border_threshold:
|
||||||
|
border_stop_step = len(per_step_frames) - 1
|
||||||
|
return
|
||||||
|
|
||||||
|
lm_head = asr.model.thinker.lm_head
|
||||||
|
step_hook = lm_head.register_forward_hook(_step_boundary_hook)
|
||||||
|
hooks.append(step_hook)
|
||||||
|
|
||||||
|
# StoppingCriteria that stops generation when border distance is hit
|
||||||
|
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||||
|
|
||||||
|
class BorderStop(StoppingCriteria):
|
||||||
|
def __call__(self, input_ids, scores, **kwargs):
|
||||||
|
return border_stop_step is not None
|
||||||
|
|
||||||
|
stopping = StoppingCriteriaList([BorderStop()])
|
||||||
|
|
||||||
|
# Limit max tokens to what's reasonable for the audio duration.
|
||||||
|
# On MPS, each token costs ~50-100ms, so tight limits are critical.
|
||||||
|
# Speech produces ~4-6 tokens/sec; +5 for metadata prefix tokens.
|
||||||
|
# With is_last, allow slightly more for flushing remaining text.
|
||||||
|
new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
|
||||||
|
tokens_per_sec = 6
|
||||||
|
if is_last:
|
||||||
|
max_tokens = min(int(audio_duration * tokens_per_sec) + 10, 120)
|
||||||
|
else:
|
||||||
|
max_tokens = min(int(max(new_audio_secs, 1.0) * tokens_per_sec) + 5, 40)
|
||||||
|
|
||||||
|
try:
|
||||||
|
outputs = asr.model.thinker.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_tokens,
|
||||||
|
do_sample=False,
|
||||||
|
stopping_criteria=stopping,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
for h in hooks:
|
||||||
|
h.remove()
|
||||||
|
# Flush any remaining frames
|
||||||
|
if current_step_frames:
|
||||||
|
per_step_frames.append(current_step_frames)
|
||||||
|
|
||||||
|
state.last_attend_frame = last_attend_frame
|
||||||
|
|
||||||
|
# Extract generated tokens
|
||||||
|
all_generated = outputs[0, prompt_len:]
|
||||||
|
eos_ids = {151645, 151643}
|
||||||
|
if asr.processor.tokenizer.eos_token_id is not None:
|
||||||
|
eos_ids.add(asr.processor.tokenizer.eos_token_id)
|
||||||
|
|
||||||
|
num_gen = len(all_generated)
|
||||||
|
for i, tid in enumerate(all_generated):
|
||||||
|
if tid.item() in eos_ids:
|
||||||
|
num_gen = i
|
||||||
|
break
|
||||||
|
|
||||||
|
raw_text = asr.processor.tokenizer.decode(all_generated[:num_gen], skip_special_tokens=True)
|
||||||
|
logger.info(
|
||||||
|
"SimulStreaming raw output: %d tokens (stopped at step %s), text=%r",
|
||||||
|
num_gen, border_stop_step, raw_text[:100],
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_gen == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Strip metadata prefix: when language is "auto", the model generates
|
||||||
|
# "language <Name><asr_text>..." before actual transcription text.
|
||||||
|
# Find <asr_text> token and skip everything before it (including itself).
|
||||||
|
asr_text_id = asr.asr_text_token_id
|
||||||
|
metadata_offset = 0
|
||||||
|
for i in range(min(num_gen, 10)): # metadata is at most ~3-4 tokens
|
||||||
|
if all_generated[i].item() == asr_text_id:
|
||||||
|
# Detect language from the metadata prefix before stripping
|
||||||
|
if state.detected_language is None and i > 0:
|
||||||
|
from whisperlivekit.qwen3_asr import QWEN3_TO_WHISPER_LANGUAGE
|
||||||
|
prefix_text = asr.processor.tokenizer.decode(
|
||||||
|
all_generated[:i].tolist(), skip_special_tokens=True,
|
||||||
|
).strip()
|
||||||
|
parts = prefix_text.split()
|
||||||
|
if len(parts) >= 2:
|
||||||
|
lang_name = parts[-1]
|
||||||
|
if lang_name.lower() != "none":
|
||||||
|
state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
|
||||||
|
lang_name, lang_name.lower(),
|
||||||
|
)
|
||||||
|
logger.info("Auto-detected language: %s", state.detected_language)
|
||||||
|
metadata_offset = i + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
if metadata_offset > 0:
|
||||||
|
logger.info(
|
||||||
|
"Stripping %d metadata prefix tokens (before <asr_text>)",
|
||||||
|
metadata_offset,
|
||||||
|
)
|
||||||
|
all_generated = all_generated[metadata_offset:]
|
||||||
|
num_gen -= metadata_offset
|
||||||
|
per_step_frames = per_step_frames[metadata_offset:]
|
||||||
|
|
||||||
|
if num_gen <= 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Determine how many tokens to emit based on border stopping
|
||||||
|
step_frames = [f for f in per_step_frames if f]
|
||||||
|
if border_stop_step is not None:
|
||||||
|
emit_up_to = min(border_stop_step, num_gen)
|
||||||
|
else:
|
||||||
|
emit_up_to = num_gen
|
||||||
|
|
||||||
|
# Build timestamped words from the emitted tokens
|
||||||
|
generated_ids = all_generated[:emit_up_to]
|
||||||
|
if len(generated_ids) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
all_words = self._build_timestamped_words(
|
||||||
|
generated_ids, step_frames, emit_up_to,
|
||||||
|
n_audio_tokens, audio_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_words = all_words
|
||||||
|
|
||||||
|
# Update committed word count for space-prefix logic in next batch
|
||||||
|
state.committed_word_count += len(new_words)
|
||||||
|
|
||||||
|
# Append newly emitted token IDs to committed context for next call
|
||||||
|
new_emitted = outputs[0, prompt_len:prompt_len + emit_up_to + metadata_offset]
|
||||||
|
state.committed_token_ids.extend(new_emitted.tolist())
|
||||||
|
|
||||||
|
return new_words
|
||||||
|
|
||||||
|
def _build_timestamped_words(
|
||||||
|
self,
|
||||||
|
generated_ids: torch.Tensor,
|
||||||
|
step_frames: List[List[int]],
|
||||||
|
emit_up_to: int,
|
||||||
|
n_audio_tokens: int,
|
||||||
|
audio_duration: float,
|
||||||
|
) -> List[ASRToken]:
|
||||||
|
"""Build timestamped ASRToken list from generated tokens and hook-captured frames."""
|
||||||
|
asr = self.asr
|
||||||
|
state = self.state
|
||||||
|
|
||||||
|
# Get per-token attended audio frame (median of alignment head votes)
|
||||||
|
per_token_frame: List[Optional[int]] = []
|
||||||
|
for step in range(emit_up_to):
|
||||||
|
if step < len(step_frames) and step_frames[step]:
|
||||||
|
frames = sorted(step_frames[step])
|
||||||
|
per_token_frame.append(frames[len(frames) // 2])
|
||||||
|
else:
|
||||||
|
per_token_frame.append(None)
|
||||||
|
|
||||||
|
# Decode the full generated sequence at once, then split into words.
|
||||||
|
# This is more robust than per-token Ġ detection, which can fail when
|
||||||
|
# committed context causes the model to generate sub-word continuations.
|
||||||
|
tokenizer = asr.processor.tokenizer
|
||||||
|
full_text = tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True)
|
||||||
|
text_words = full_text.split()
|
||||||
|
|
||||||
|
# Map each text word to an approximate frame using token-level alignment.
|
||||||
|
# Distribute frames evenly across words (since exact token→word mapping
|
||||||
|
# is imprecise with BPE sub-words anyway).
|
||||||
|
all_frames = [f for f in per_token_frame if f is not None]
|
||||||
|
words = []
|
||||||
|
for wi, word in enumerate(text_words):
|
||||||
|
if all_frames:
|
||||||
|
# Proportionally assign frames to words
|
||||||
|
frac = wi / max(len(text_words), 1)
|
||||||
|
frame_idx = int(frac * len(all_frames))
|
||||||
|
frame_idx = min(frame_idx, len(all_frames) - 1)
|
||||||
|
frame = all_frames[frame_idx]
|
||||||
|
else:
|
||||||
|
frame = None
|
||||||
|
words.append((word, frame))
|
||||||
|
|
||||||
|
# Convert to ASRToken with timestamps
|
||||||
|
tokens = []
|
||||||
|
for i, (text, frame) in enumerate(words):
|
||||||
|
text = text.strip()
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if frame is not None and n_audio_tokens > 0:
|
||||||
|
timestamp = (
|
||||||
|
frame / n_audio_tokens * audio_duration
|
||||||
|
+ state.cumulative_time_offset
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
timestamp = (
|
||||||
|
(i / max(len(words), 1)) * audio_duration
|
||||||
|
+ state.cumulative_time_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefix space: first word of the very first batch has no space;
|
||||||
|
# all subsequent words (same batch or later batches) get a space.
|
||||||
|
is_very_first_word = (i == 0 and state.committed_word_count == 0)
|
||||||
|
display_text = text if is_very_first_word else " " + text
|
||||||
|
|
||||||
|
token = ASRToken(
|
||||||
|
start=round(timestamp, 2),
|
||||||
|
end=round(timestamp + 0.1, 2),
|
||||||
|
text=display_text,
|
||||||
|
speaker=state.speaker,
|
||||||
|
detected_language=state.detected_language,
|
||||||
|
).with_offset(state.global_time_offset)
|
||||||
|
tokens.append(token)
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _median_frame(frames: List[int]) -> Optional[int]:
|
||||||
|
"""Return median of frame list, or None if empty."""
|
||||||
|
if not frames:
|
||||||
|
return None
|
||||||
|
frames_sorted = sorted(frames)
|
||||||
|
return frames_sorted[len(frames_sorted) // 2]
|
||||||
|
|
||||||
|
def warmup(self, audio: np.ndarray, init_prompt: str = ""):
|
||||||
|
"""Warmup the model with a short audio clip."""
|
||||||
|
try:
|
||||||
|
self.state.audio_buffer = audio[:SAMPLE_RATE]
|
||||||
|
self.process_iter(is_last=True)
|
||||||
|
self.state = Qwen3SimulState()
|
||||||
|
logger.info("Qwen3 SimulStreaming online processor warmed up")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Warmup failed: %s", e)
|
||||||
|
self.state = Qwen3SimulState()
|
||||||
|
|
||||||
|
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""Flush remaining audio at end of stream."""
|
||||||
|
all_tokens = []
|
||||||
|
for _ in range(5): # safety limit
|
||||||
|
tokens, _ = self.process_iter(is_last=True)
|
||||||
|
if not tokens:
|
||||||
|
break
|
||||||
|
all_tokens.extend(tokens)
|
||||||
|
return all_tokens, self.end
|
||||||
416
whisperlivekit/vllm_realtime.py
Normal file
416
whisperlivekit/vllm_realtime.py
Normal file
|
|
@ -0,0 +1,416 @@
|
||||||
|
"""
|
||||||
|
vLLM Realtime WebSocket streaming backend for WhisperLiveKit.
|
||||||
|
|
||||||
|
Connects to a vLLM server's ``/v1/realtime`` WebSocket endpoint to stream
|
||||||
|
audio and receive transcription deltas. Uses ``websockets.sync.client``
|
||||||
|
for simplicity since ``process_iter`` runs inside ``asyncio.to_thread``.
|
||||||
|
|
||||||
|
Provides ``VLLMRealtimeASR`` (lightweight model holder) and
|
||||||
|
``VLLMRealtimeOnlineProcessor`` (streaming processor) that plug into
|
||||||
|
WhisperLiveKit's audio processing pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VLLMRealtimeASR:
|
||||||
|
"""Lightweight model holder — stores connection info for the vLLM server."""
|
||||||
|
|
||||||
|
sep = " "
|
||||||
|
SAMPLING_RATE = 16000
|
||||||
|
backend_choice = "vllm-realtime"
|
||||||
|
|
||||||
|
def __init__(self, vllm_url="ws://localhost:8000/v1/realtime",
|
||||||
|
model_name="Qwen/Qwen3-ASR-1.7B", lan="auto", **kwargs):
|
||||||
|
self.vllm_url = vllm_url
|
||||||
|
self.model_name = model_name
|
||||||
|
self.original_language = None if lan == "auto" else lan
|
||||||
|
self.tokenizer = None
|
||||||
|
|
||||||
|
def transcribe(self, audio):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VLLMRealtimeOnlineProcessor:
|
||||||
|
"""
|
||||||
|
Online processor that streams audio to a vLLM Realtime WebSocket.
|
||||||
|
|
||||||
|
Uses a background thread for WebSocket receiving and
|
||||||
|
``websockets.sync.client`` for the sync WebSocket connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SAMPLING_RATE = 16000
|
||||||
|
# Minimum audio samples before connecting (0.5s of audio)
|
||||||
|
_MIN_CONNECT_SAMPLES = SAMPLING_RATE // 2
|
||||||
|
|
||||||
|
def __init__(self, asr: VLLMRealtimeASR):
|
||||||
|
self.asr = asr
|
||||||
|
self.end = 0.0
|
||||||
|
self.buffer = []
|
||||||
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
|
|
||||||
|
self._reset_state()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"[vllm-realtime] Initialized. url=%s model=%s",
|
||||||
|
asr.vllm_url, asr.model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _reset_state(self):
|
||||||
|
self._pending_audio = np.zeros(0, dtype=np.float32)
|
||||||
|
self._ws = None
|
||||||
|
self._recv_thread: Optional[threading.Thread] = None
|
||||||
|
self._connected = False
|
||||||
|
self._done = False
|
||||||
|
self._recv_error: Optional[Exception] = None
|
||||||
|
|
||||||
|
# Text accumulation and word extraction
|
||||||
|
self._accumulated_text = ""
|
||||||
|
self._n_committed_words = 0
|
||||||
|
self._total_audio_duration = 0.0
|
||||||
|
self._global_time_offset = 0.0
|
||||||
|
|
||||||
|
# Lock for text state accessed from both recv thread and main thread
|
||||||
|
self._text_lock = threading.Lock()
|
||||||
|
|
||||||
|
# ── Interface methods ──
|
||||||
|
|
||||||
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
||||||
|
self.end = audio_stream_end_time
|
||||||
|
self._pending_audio = np.append(self._pending_audio, audio)
|
||||||
|
self.audio_buffer = self._pending_audio
|
||||||
|
|
||||||
|
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||||
|
try:
|
||||||
|
return self._process_iter_inner(is_last)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("[vllm-realtime] process_iter exception: %s", e, exc_info=True)
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
def get_buffer(self) -> Transcript:
|
||||||
|
"""Return all uncommitted text as buffer."""
|
||||||
|
self._drain_deltas()
|
||||||
|
with self._text_lock:
|
||||||
|
text = self._accumulated_text
|
||||||
|
if not text:
|
||||||
|
return Transcript(start=None, end=None, text="")
|
||||||
|
|
||||||
|
words = text.split()
|
||||||
|
uncommitted = words[self._n_committed_words:]
|
||||||
|
if uncommitted:
|
||||||
|
return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted))
|
||||||
|
return Transcript(start=None, end=None, text="")
|
||||||
|
|
||||||
|
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""Flush all pending words when silence starts.
|
||||||
|
|
||||||
|
Sends commit(final=true) to signal end of utterance, waits for
|
||||||
|
transcription.done, flushes all words, then prepares for reconnection
|
||||||
|
on the next utterance.
|
||||||
|
"""
|
||||||
|
if not self._connected or self._done:
|
||||||
|
words = self._flush_all_pending_words()
|
||||||
|
logger.info("[vllm-realtime] start_silence (not connected): flushed %d words", len(words))
|
||||||
|
return words, self.end
|
||||||
|
|
||||||
|
# Send any remaining buffered audio
|
||||||
|
self._send_pending_audio()
|
||||||
|
|
||||||
|
# Signal end of stream
|
||||||
|
self._send_commit(final=True)
|
||||||
|
|
||||||
|
# Wait for transcription.done
|
||||||
|
self._wait_for_done(timeout=10.0)
|
||||||
|
|
||||||
|
# Flush all remaining words
|
||||||
|
words = self._flush_all_pending_words()
|
||||||
|
|
||||||
|
# Close and reset for next utterance
|
||||||
|
self._close_ws()
|
||||||
|
old_offset = self._global_time_offset + self._total_audio_duration
|
||||||
|
self._reset_state()
|
||||||
|
self._global_time_offset = old_offset
|
||||||
|
|
||||||
|
logger.info("[vllm-realtime] start_silence: flushed %d words", len(words))
|
||||||
|
return words, self.end
|
||||||
|
|
||||||
|
def end_silence(self, silence_duration: float, offset: float):
|
||||||
|
self._global_time_offset += silence_duration
|
||||||
|
self.end += silence_duration
|
||||||
|
|
||||||
|
def new_speaker(self, change_speaker):
|
||||||
|
self.start_silence()
|
||||||
|
|
||||||
|
def warmup(self, audio, init_prompt=""):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""Close connection and flush all remaining words."""
|
||||||
|
if self._connected and not self._done:
|
||||||
|
# Send remaining audio
|
||||||
|
self._send_pending_audio()
|
||||||
|
|
||||||
|
# Signal final commit
|
||||||
|
self._send_commit(final=True)
|
||||||
|
|
||||||
|
# Wait for transcription.done
|
||||||
|
self._wait_for_done(timeout=30.0)
|
||||||
|
|
||||||
|
# Flush all words
|
||||||
|
words = self._flush_all_pending_words()
|
||||||
|
|
||||||
|
# Close WebSocket
|
||||||
|
self._close_ws()
|
||||||
|
|
||||||
|
logger.info("[vllm-realtime] finish: flushed %d words", len(words))
|
||||||
|
return words, self.end
|
||||||
|
|
||||||
|
# ── WebSocket connection management ──
|
||||||
|
|
||||||
|
def _connect(self):
|
||||||
|
"""Connect to the vLLM realtime WebSocket and start the receive thread."""
|
||||||
|
from websockets.sync.client import connect
|
||||||
|
|
||||||
|
url = self.asr.vllm_url
|
||||||
|
logger.info("[vllm-realtime] Connecting to %s", url)
|
||||||
|
|
||||||
|
self._ws = connect(url)
|
||||||
|
|
||||||
|
# Send session.update to select model
|
||||||
|
self._ws.send(json.dumps({
|
||||||
|
"type": "session.update",
|
||||||
|
"model": self.asr.model_name,
|
||||||
|
}))
|
||||||
|
|
||||||
|
# Send initial commit(final=false) to start generation
|
||||||
|
self._send_commit(final=False)
|
||||||
|
|
||||||
|
# Start receive thread
|
||||||
|
self._recv_thread = threading.Thread(target=self._recv_loop, daemon=True)
|
||||||
|
self._recv_thread.start()
|
||||||
|
|
||||||
|
self._connected = True
|
||||||
|
logger.info("[vllm-realtime] Connected and started receive thread")
|
||||||
|
|
||||||
|
def _close_ws(self):
|
||||||
|
"""Close the WebSocket connection and join the receive thread."""
|
||||||
|
if self._ws is not None:
|
||||||
|
try:
|
||||||
|
self._ws.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._ws = None
|
||||||
|
|
||||||
|
if self._recv_thread is not None:
|
||||||
|
self._recv_thread.join(timeout=5.0)
|
||||||
|
self._recv_thread = None
|
||||||
|
|
||||||
|
def _recv_loop(self):
|
||||||
|
"""Background thread: receive messages from the vLLM WebSocket."""
|
||||||
|
try:
|
||||||
|
while not self._done and self._ws is not None:
|
||||||
|
try:
|
||||||
|
raw = self._ws.recv(timeout=0.1)
|
||||||
|
except TimeoutError:
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = json.loads(raw)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
msg_type = msg.get("type", "")
|
||||||
|
|
||||||
|
if msg_type == "transcription.delta":
|
||||||
|
delta = msg.get("delta", "")
|
||||||
|
if delta:
|
||||||
|
with self._text_lock:
|
||||||
|
self._accumulated_text += delta
|
||||||
|
|
||||||
|
elif msg_type == "transcription.done":
|
||||||
|
done_text = msg.get("text", "")
|
||||||
|
if done_text:
|
||||||
|
with self._text_lock:
|
||||||
|
# Replace accumulated text with final text
|
||||||
|
self._accumulated_text = done_text
|
||||||
|
self._done = True
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("[vllm-realtime] recv_loop error: %s", e, exc_info=True)
|
||||||
|
self._recv_error = e
|
||||||
|
self._done = True
|
||||||
|
|
||||||
|
# ── Protocol messages ──
|
||||||
|
|
||||||
|
def _send_commit(self, final: bool):
|
||||||
|
"""Send input_audio_buffer.commit message."""
|
||||||
|
if self._ws is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self._ws.send(json.dumps({
|
||||||
|
"type": "input_audio_buffer.commit",
|
||||||
|
"final": final,
|
||||||
|
}))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("[vllm-realtime] Failed to send commit: %s", e)
|
||||||
|
|
||||||
|
def _send_audio(self, audio: np.ndarray):
|
||||||
|
"""Send audio as a base64-encoded PCM16 append message."""
|
||||||
|
if self._ws is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert float32 [-1, 1] to int16 PCM
|
||||||
|
pcm16 = (audio * 32767).astype(np.int16)
|
||||||
|
audio_bytes = pcm16.tobytes()
|
||||||
|
audio_b64 = base64.b64encode(audio_bytes).decode("ascii")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._ws.send(json.dumps({
|
||||||
|
"type": "input_audio_buffer.append",
|
||||||
|
"audio": audio_b64,
|
||||||
|
}))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("[vllm-realtime] Failed to send audio: %s", e)
|
||||||
|
|
||||||
|
def _send_pending_audio(self):
|
||||||
|
"""Send all pending audio to the vLLM server."""
|
||||||
|
if len(self._pending_audio) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Track total audio duration for timestamp estimation
|
||||||
|
self._total_audio_duration += len(self._pending_audio) / self.SAMPLING_RATE
|
||||||
|
|
||||||
|
# Send in chunks of 0.5s to avoid overwhelming the WebSocket
|
||||||
|
chunk_samples = self.SAMPLING_RATE // 2
|
||||||
|
while len(self._pending_audio) >= chunk_samples:
|
||||||
|
chunk = self._pending_audio[:chunk_samples]
|
||||||
|
self._send_audio(chunk)
|
||||||
|
self._pending_audio = self._pending_audio[chunk_samples:]
|
||||||
|
|
||||||
|
# Send remaining audio if any
|
||||||
|
if len(self._pending_audio) > 0:
|
||||||
|
self._send_audio(self._pending_audio)
|
||||||
|
self._pending_audio = np.zeros(0, dtype=np.float32)
|
||||||
|
|
||||||
|
self.audio_buffer = self._pending_audio
|
||||||
|
|
||||||
|
# ── Receive helpers ──
|
||||||
|
|
||||||
|
def _drain_deltas(self):
|
||||||
|
"""No-op since the recv thread accumulates text directly."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _wait_for_done(self, timeout: float = 10.0):
|
||||||
|
"""Wait for transcription.done message from the server."""
|
||||||
|
deadline = time.time() + timeout
|
||||||
|
while not self._done and time.time() < deadline:
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
if not self._done:
|
||||||
|
logger.warning("[vllm-realtime] Timed out waiting for transcription.done")
|
||||||
|
|
||||||
|
# ── Word extraction (same approach as VoxtralHF) ──
|
||||||
|
|
||||||
|
def _time_for_word(self, word_idx: int, n_words_total: int) -> Tuple[float, float]:
|
||||||
|
"""Estimate timestamps by linearly distributing words across audio duration."""
|
||||||
|
duration = max(self._total_audio_duration, 0.001)
|
||||||
|
n_total = max(n_words_total, 1)
|
||||||
|
|
||||||
|
start_time = (word_idx / n_total) * duration + self._global_time_offset
|
||||||
|
end_time = ((word_idx + 1) / n_total) * duration + self._global_time_offset
|
||||||
|
|
||||||
|
return start_time, end_time
|
||||||
|
|
||||||
|
def _extract_new_words(self) -> List[ASRToken]:
|
||||||
|
"""Extract complete words (all but the last, which may still grow)."""
|
||||||
|
with self._text_lock:
|
||||||
|
text = self._accumulated_text
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
words = text.split()
|
||||||
|
new_words: List[ASRToken] = []
|
||||||
|
n_words_total = len(words)
|
||||||
|
|
||||||
|
while len(words) > self._n_committed_words + 1:
|
||||||
|
word = words[self._n_committed_words]
|
||||||
|
start_time, end_time = self._time_for_word(self._n_committed_words, n_words_total)
|
||||||
|
|
||||||
|
text_out = word if self._n_committed_words == 0 else " " + word
|
||||||
|
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
|
||||||
|
self._n_committed_words += 1
|
||||||
|
|
||||||
|
return new_words
|
||||||
|
|
||||||
|
def _flush_all_pending_words(self) -> List[ASRToken]:
|
||||||
|
"""Flush ALL words including the last partial one."""
|
||||||
|
with self._text_lock:
|
||||||
|
text = self._accumulated_text
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
words = text.split()
|
||||||
|
new_words: List[ASRToken] = []
|
||||||
|
n_words_total = max(len(words), 1)
|
||||||
|
|
||||||
|
while self._n_committed_words < len(words):
|
||||||
|
word = words[self._n_committed_words]
|
||||||
|
start_time, end_time = self._time_for_word(self._n_committed_words, n_words_total)
|
||||||
|
|
||||||
|
text_out = word if self._n_committed_words == 0 else " " + word
|
||||||
|
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
|
||||||
|
self._n_committed_words += 1
|
||||||
|
|
||||||
|
return new_words
|
||||||
|
|
||||||
|
# ── Core processing ──
|
||||||
|
|
||||||
|
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
|
||||||
|
# Connect when we have enough audio buffered
|
||||||
|
if not self._connected:
|
||||||
|
if len(self._pending_audio) >= self._MIN_CONNECT_SAMPLES:
|
||||||
|
self._connect()
|
||||||
|
self._send_pending_audio()
|
||||||
|
else:
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
# Send any new pending audio
|
||||||
|
if self._connected and not self._done:
|
||||||
|
self._send_pending_audio()
|
||||||
|
|
||||||
|
# If connection closed unexpectedly but new audio arrived, reconnect
|
||||||
|
if self._done and len(self._pending_audio) >= self._MIN_CONNECT_SAMPLES:
|
||||||
|
flush_words = self._flush_all_pending_words()
|
||||||
|
old_offset = self._global_time_offset + self._total_audio_duration
|
||||||
|
self._close_ws()
|
||||||
|
self._reset_state()
|
||||||
|
self._global_time_offset = old_offset
|
||||||
|
self._connect()
|
||||||
|
self._send_pending_audio()
|
||||||
|
return flush_words, self.end
|
||||||
|
|
||||||
|
# Extract complete words
|
||||||
|
new_words = self._extract_new_words()
|
||||||
|
|
||||||
|
if new_words:
|
||||||
|
logger.info(
|
||||||
|
"[vllm-realtime] returning %d words: %s",
|
||||||
|
len(new_words), [w.text for w in new_words],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.buffer = []
|
||||||
|
return new_words, self.end
|
||||||
|
|
@ -102,7 +102,8 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
)
|
)
|
||||||
|
|
||||||
def _reset_state(self):
|
def _reset_state(self):
|
||||||
self._pending_audio = np.zeros(0, dtype=np.float32)
|
self._pending_chunks: List[np.ndarray] = []
|
||||||
|
self._pending_len = 0
|
||||||
self._audio_queue: queue.Queue = queue.Queue()
|
self._audio_queue: queue.Queue = queue.Queue()
|
||||||
self._streamer_texts: List[str] = []
|
self._streamer_texts: List[str] = []
|
||||||
self._generate_thread: Optional[threading.Thread] = None
|
self._generate_thread: Optional[threading.Thread] = None
|
||||||
|
|
@ -110,22 +111,63 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
self._generate_finished = False
|
self._generate_finished = False
|
||||||
self._generate_error: Optional[Exception] = None
|
self._generate_error: Optional[Exception] = None
|
||||||
|
|
||||||
# Text accumulation and word extraction
|
# Text accumulation (list of fragments, joined on demand)
|
||||||
self._accumulated_text = ""
|
self._text_fragments: List[str] = []
|
||||||
|
self._text_len = 0
|
||||||
|
# Fragment position tracking for accurate word timestamps:
|
||||||
|
# each entry is (char_offset_in_full_text, audio_tok_pos_consumed)
|
||||||
|
self._fragment_positions: List[Tuple[int, int]] = []
|
||||||
self._n_text_tokens_received = 0
|
self._n_text_tokens_received = 0
|
||||||
self._n_audio_tokens_fed = 0
|
self._n_audio_tokens_fed = 0
|
||||||
|
# Audio tokens actually consumed by the model (tracked inside generator)
|
||||||
|
self._n_audio_tokens_consumed = 0
|
||||||
self._n_committed_words = 0
|
self._n_committed_words = 0
|
||||||
self._global_time_offset = 0.0
|
self._global_time_offset = 0.0
|
||||||
|
|
||||||
|
# Event signalled by the generate thread when it finishes
|
||||||
|
self._generate_done = threading.Event()
|
||||||
|
|
||||||
# Lock for text state accessed from both generate thread and main thread
|
# Lock for text state accessed from both generate thread and main thread
|
||||||
self._text_lock = threading.Lock()
|
self._text_lock = threading.Lock()
|
||||||
|
|
||||||
|
# ── Audio / text helpers ──
|
||||||
|
|
||||||
|
def _get_pending_audio(self) -> np.ndarray:
|
||||||
|
"""Flatten pending audio chunks into a single array."""
|
||||||
|
if not self._pending_chunks:
|
||||||
|
return np.zeros(0, dtype=np.float32)
|
||||||
|
if len(self._pending_chunks) == 1:
|
||||||
|
return self._pending_chunks[0]
|
||||||
|
flat = np.concatenate(self._pending_chunks)
|
||||||
|
self._pending_chunks = [flat]
|
||||||
|
return flat
|
||||||
|
|
||||||
|
def _set_pending_audio(self, arr: np.ndarray):
|
||||||
|
"""Replace pending audio with a single array."""
|
||||||
|
if len(arr) == 0:
|
||||||
|
self._pending_chunks = []
|
||||||
|
self._pending_len = 0
|
||||||
|
else:
|
||||||
|
self._pending_chunks = [arr]
|
||||||
|
self._pending_len = len(arr)
|
||||||
|
|
||||||
|
def _get_accumulated_text(self) -> str:
|
||||||
|
"""Get the full accumulated text (joins fragments if needed)."""
|
||||||
|
if not self._text_fragments:
|
||||||
|
return ""
|
||||||
|
if len(self._text_fragments) == 1:
|
||||||
|
return self._text_fragments[0]
|
||||||
|
joined = "".join(self._text_fragments)
|
||||||
|
self._text_fragments = [joined]
|
||||||
|
return joined
|
||||||
|
|
||||||
# ── Interface methods ──
|
# ── Interface methods ──
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
||||||
self.end = audio_stream_end_time
|
self.end = audio_stream_end_time
|
||||||
self._pending_audio = np.append(self._pending_audio, audio)
|
self._pending_chunks.append(audio)
|
||||||
self.audio_buffer = self._pending_audio
|
self._pending_len += len(audio)
|
||||||
|
self.audio_buffer = audio # diagnostic only
|
||||||
|
|
||||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||||
try:
|
try:
|
||||||
|
|
@ -142,7 +184,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
"""
|
"""
|
||||||
self._drain_streamer()
|
self._drain_streamer()
|
||||||
with self._text_lock:
|
with self._text_lock:
|
||||||
text = self._accumulated_text
|
text = self._get_accumulated_text()
|
||||||
if not text:
|
if not text:
|
||||||
return Transcript(start=None, end=None, text="")
|
return Transcript(start=None, end=None, text="")
|
||||||
|
|
||||||
|
|
@ -174,16 +216,17 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
# real audio and shouldn't affect word timestamp calculations.
|
# real audio and shouldn't affect word timestamp calculations.
|
||||||
if self._right_pad_samples > 0:
|
if self._right_pad_samples > 0:
|
||||||
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
|
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
|
||||||
self._pending_audio = np.append(self._pending_audio, right_pad)
|
self._pending_chunks.append(right_pad)
|
||||||
|
self._pending_len += len(right_pad)
|
||||||
saved_count = self._n_audio_tokens_fed
|
saved_count = self._n_audio_tokens_fed
|
||||||
self._feed_pending_audio()
|
self._feed_pending_audio()
|
||||||
self._n_audio_tokens_fed = saved_count
|
self._n_audio_tokens_fed = saved_count
|
||||||
|
|
||||||
# Drain in a loop: the model may still be processing right-padding
|
# Drain in a loop: the model may continue producing text tokens after
|
||||||
# chunks after the first drain returns. Keep draining until no new
|
# the audio queue is empty (autoregressive generation). Each iteration
|
||||||
# text appears for two consecutive rounds.
|
# uses an event-driven blocking drain with short timeouts.
|
||||||
all_words: List[ASRToken] = []
|
all_words: List[ASRToken] = []
|
||||||
for _ in range(5): # at most 5 drain+flush cycles
|
for _ in range(5):
|
||||||
self._drain_streamer_blocking(timeout=5.0)
|
self._drain_streamer_blocking(timeout=5.0)
|
||||||
batch = self._flush_all_pending_words()
|
batch = self._flush_all_pending_words()
|
||||||
all_words.extend(batch)
|
all_words.extend(batch)
|
||||||
|
|
@ -208,7 +251,8 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
# Add right-padding so the model can finish decoding
|
# Add right-padding so the model can finish decoding
|
||||||
if self._right_pad_samples > 0:
|
if self._right_pad_samples > 0:
|
||||||
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
|
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
|
||||||
self._pending_audio = np.append(self._pending_audio, right_pad)
|
self._pending_chunks.append(right_pad)
|
||||||
|
self._pending_len += len(right_pad)
|
||||||
|
|
||||||
# Feed remaining audio
|
# Feed remaining audio
|
||||||
if self._generate_started and not self._generate_finished:
|
if self._generate_started and not self._generate_finished:
|
||||||
|
|
@ -218,7 +262,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
# Wait for generate to finish
|
# Wait for generate to finish
|
||||||
if self._generate_thread is not None:
|
if self._generate_thread is not None:
|
||||||
self._generate_thread.join(timeout=30.0)
|
self._generate_thread.join(timeout=30.0)
|
||||||
elif not self._generate_started and len(self._pending_audio) >= self._first_chunk_samples:
|
elif not self._generate_started and self._pending_len >= self._first_chunk_samples:
|
||||||
# Never started but have enough audio — start and immediately finish
|
# Never started but have enough audio — start and immediately finish
|
||||||
self._start_generate_thread()
|
self._start_generate_thread()
|
||||||
self._feed_pending_audio()
|
self._feed_pending_audio()
|
||||||
|
|
@ -242,8 +286,9 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
model = self.asr.model
|
model = self.asr.model
|
||||||
|
|
||||||
# Extract first chunk
|
# Extract first chunk
|
||||||
first_chunk_audio = self._pending_audio[:self._first_chunk_samples]
|
pending = self._get_pending_audio()
|
||||||
self._pending_audio = self._pending_audio[self._first_chunk_samples:]
|
first_chunk_audio = pending[:self._first_chunk_samples]
|
||||||
|
self._set_pending_audio(pending[self._first_chunk_samples:])
|
||||||
# First chunk covers multiple audio tokens
|
# First chunk covers multiple audio tokens
|
||||||
self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step)
|
self._n_audio_tokens_fed += max(1, self._first_chunk_samples // self._chunk_step)
|
||||||
|
|
||||||
|
|
@ -265,11 +310,14 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
audio_queue = self._audio_queue
|
audio_queue = self._audio_queue
|
||||||
|
|
||||||
def input_features_gen():
|
def input_features_gen():
|
||||||
|
# Track audio consumption inside the generator (runs in generate thread)
|
||||||
|
self._n_audio_tokens_consumed = max(1, self._first_chunk_samples // self._chunk_step)
|
||||||
yield first_inputs.input_features
|
yield first_inputs.input_features
|
||||||
while True:
|
while True:
|
||||||
chunk_audio = audio_queue.get()
|
chunk_audio = audio_queue.get()
|
||||||
if chunk_audio is None:
|
if chunk_audio is None:
|
||||||
break
|
break
|
||||||
|
self._n_audio_tokens_consumed += 1
|
||||||
inputs = processor(
|
inputs = processor(
|
||||||
chunk_audio,
|
chunk_audio,
|
||||||
is_streaming=True,
|
is_streaming=True,
|
||||||
|
|
@ -298,6 +346,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
self._generate_error = e
|
self._generate_error = e
|
||||||
finally:
|
finally:
|
||||||
self._generate_finished = True
|
self._generate_finished = True
|
||||||
|
self._generate_done.set()
|
||||||
|
|
||||||
self._generate_thread = threading.Thread(target=run_generate, daemon=True)
|
self._generate_thread = threading.Thread(target=run_generate, daemon=True)
|
||||||
self._generate_thread.start()
|
self._generate_thread.start()
|
||||||
|
|
@ -309,13 +358,22 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
chunk_size = self._chunk_samples
|
chunk_size = self._chunk_samples
|
||||||
step_size = self._chunk_step
|
step_size = self._chunk_step
|
||||||
|
|
||||||
while len(self._pending_audio) >= chunk_size:
|
pending = self._get_pending_audio()
|
||||||
chunk = self._pending_audio[:chunk_size]
|
while len(pending) >= chunk_size:
|
||||||
|
chunk = pending[:chunk_size]
|
||||||
self._audio_queue.put(chunk)
|
self._audio_queue.put(chunk)
|
||||||
self._pending_audio = self._pending_audio[step_size:]
|
pending = pending[step_size:]
|
||||||
self._n_audio_tokens_fed += 1
|
self._n_audio_tokens_fed += 1
|
||||||
|
|
||||||
self.audio_buffer = self._pending_audio
|
self._set_pending_audio(pending)
|
||||||
|
self.audio_buffer = pending
|
||||||
|
|
||||||
|
def _append_text_fragment(self, text_fragment: str):
|
||||||
|
"""Append a text fragment with its audio position (must hold _text_lock)."""
|
||||||
|
self._fragment_positions.append((self._text_len, self._n_audio_tokens_consumed))
|
||||||
|
self._text_fragments.append(text_fragment)
|
||||||
|
self._text_len += len(text_fragment)
|
||||||
|
self._n_text_tokens_received += 1
|
||||||
|
|
||||||
def _drain_streamer(self):
|
def _drain_streamer(self):
|
||||||
"""Non-blocking drain of all available text from the streamer."""
|
"""Non-blocking drain of all available text from the streamer."""
|
||||||
|
|
@ -333,19 +391,13 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
break
|
break
|
||||||
if text_fragment:
|
if text_fragment:
|
||||||
with self._text_lock:
|
with self._text_lock:
|
||||||
self._accumulated_text += text_fragment
|
self._append_text_fragment(text_fragment)
|
||||||
self._n_text_tokens_received += 1
|
|
||||||
|
|
||||||
def _drain_streamer_blocking(self, timeout=30.0):
|
def _drain_streamer_blocking(self, timeout=30.0):
|
||||||
"""Blocking drain: wait for the generate thread to process all queued
|
"""Blocking drain: wait for the generate thread to finish producing text.
|
||||||
audio and produce the corresponding text.
|
|
||||||
|
|
||||||
Polls the text queue while the audio queue has items (model still
|
Uses the _generate_done event to know when the model is truly finished.
|
||||||
processing). Once the audio queue is empty, waits for trailing
|
Falls back to text-queue polling with adaptive timeouts.
|
||||||
tokens, then returns.
|
|
||||||
|
|
||||||
This is critical for start_silence(): without it, the non-blocking
|
|
||||||
drain races with the generate thread and the last words get stuck.
|
|
||||||
"""
|
"""
|
||||||
if not self._generate_started or self._generate_finished:
|
if not self._generate_started or self._generate_finished:
|
||||||
self._drain_streamer()
|
self._drain_streamer()
|
||||||
|
|
@ -353,52 +405,101 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
|
|
||||||
text_queue = self._streamer.text_queue
|
text_queue = self._streamer.text_queue
|
||||||
deadline = time.time() + timeout
|
deadline = time.time() + timeout
|
||||||
|
# Count consecutive empty polls to detect when model has caught up
|
||||||
|
empty_streak = 0
|
||||||
|
|
||||||
while time.time() < deadline:
|
while time.time() < deadline:
|
||||||
# Short poll while model is still processing queued audio;
|
remaining = max(deadline - time.time(), 0.01)
|
||||||
# longer wait once the audio queue is empty (trailing tokens).
|
|
||||||
wait = 2.0 if self._audio_queue.empty() else 0.1
|
# If generate thread is done, do a final flush and exit
|
||||||
|
if self._generate_done.is_set() or self._generate_finished:
|
||||||
|
self._drain_streamer()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Adaptive wait: short while audio is queued, longer once queue is empty
|
||||||
|
if self._audio_queue.empty():
|
||||||
|
wait = min(remaining, 0.5)
|
||||||
|
else:
|
||||||
|
wait = min(remaining, 0.1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
text_fragment = text_queue.get(timeout=wait)
|
text_fragment = text_queue.get(timeout=wait)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
if self._audio_queue.empty():
|
empty_streak += 1
|
||||||
break # Audio done + no text for 2s → fully caught up
|
# Only exit if audio queue is empty AND we've had enough empty polls
|
||||||
continue # Audio still queued, model still working
|
# This prevents premature exit when the model is slow
|
||||||
|
if self._audio_queue.empty() and empty_streak >= 4:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
empty_streak = 0
|
||||||
if text_fragment is None:
|
if text_fragment is None:
|
||||||
self._generate_finished = True
|
self._generate_finished = True
|
||||||
break
|
break
|
||||||
if text_fragment:
|
if text_fragment:
|
||||||
with self._text_lock:
|
with self._text_lock:
|
||||||
self._accumulated_text += text_fragment
|
self._append_text_fragment(text_fragment)
|
||||||
self._n_text_tokens_received += 1
|
|
||||||
|
|
||||||
# ── Word extraction ──
|
# ── Word extraction ──
|
||||||
|
|
||||||
def _pos_to_time(self, token_position: int) -> float:
|
def _pos_to_time(self, token_position: int) -> float:
|
||||||
"""Convert token position to seconds."""
|
"""Convert audio token position to seconds."""
|
||||||
return token_position * self._seconds_per_token + self._global_time_offset
|
return token_position * self._seconds_per_token + self._global_time_offset
|
||||||
|
|
||||||
|
def _audio_pos_for_char(self, char_idx: int) -> int:
|
||||||
|
"""Look up the audio token position for a character index in the text.
|
||||||
|
|
||||||
|
Uses the fragment position index recorded when text arrives from the
|
||||||
|
generate thread. Returns the audio position of the fragment that
|
||||||
|
contains ``char_idx``, giving much better word timestamps than the
|
||||||
|
old uniform-distribution heuristic.
|
||||||
|
"""
|
||||||
|
if not self._fragment_positions:
|
||||||
|
return 0
|
||||||
|
# _fragment_positions is sorted by char_offset — find the last entry
|
||||||
|
# whose char_offset <= char_idx (the fragment containing this char).
|
||||||
|
pos = 0
|
||||||
|
for offset, audio_tok in self._fragment_positions:
|
||||||
|
if offset > char_idx:
|
||||||
|
break
|
||||||
|
pos = audio_tok
|
||||||
|
return pos
|
||||||
|
|
||||||
|
def _word_timestamps(self, text: str, words: List[str], start_idx: int, end_idx: int) -> List[Tuple[int, int]]:
|
||||||
|
"""Compute (tok_start, tok_end) for words[start_idx:end_idx] using fragment positions."""
|
||||||
|
# Build char offsets for each word
|
||||||
|
result = []
|
||||||
|
char_pos = 0
|
||||||
|
for i, word in enumerate(words):
|
||||||
|
if i > 0:
|
||||||
|
char_pos += 1 # space separator
|
||||||
|
if start_idx <= i < end_idx:
|
||||||
|
tok_start = self._audio_pos_for_char(char_pos)
|
||||||
|
tok_end = self._audio_pos_for_char(char_pos + len(word))
|
||||||
|
result.append((tok_start, tok_end))
|
||||||
|
char_pos += len(word)
|
||||||
|
return result
|
||||||
|
|
||||||
def _extract_new_words(self) -> List[ASRToken]:
|
def _extract_new_words(self) -> List[ASRToken]:
|
||||||
"""Extract complete words (all but the last, which may still be growing)."""
|
"""Extract complete words (all but the last, which may still be growing)."""
|
||||||
with self._text_lock:
|
with self._text_lock:
|
||||||
text = self._accumulated_text
|
text = self._get_accumulated_text()
|
||||||
if not text:
|
if not text:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
words = text.split()
|
words = text.split()
|
||||||
new_words: List[ASRToken] = []
|
new_words: List[ASRToken] = []
|
||||||
n_words_total = len(words)
|
n_to_commit = len(words) - 1 # keep last word (may still grow)
|
||||||
n_audio_toks = max(self._n_audio_tokens_fed, 1)
|
|
||||||
|
|
||||||
while len(words) > self._n_committed_words + 1:
|
if n_to_commit <= self._n_committed_words:
|
||||||
|
return []
|
||||||
|
|
||||||
|
timestamps = self._word_timestamps(text, words, self._n_committed_words, n_to_commit)
|
||||||
|
|
||||||
|
for tok_start, tok_end in timestamps:
|
||||||
word = words[self._n_committed_words]
|
word = words[self._n_committed_words]
|
||||||
word_idx = self._n_committed_words
|
|
||||||
|
|
||||||
tok_start = int(word_idx / n_words_total * n_audio_toks) if n_words_total > 0 else 0
|
|
||||||
tok_end = int((word_idx + 1) / n_words_total * n_audio_toks) if n_words_total > 0 else 0
|
|
||||||
|
|
||||||
start_time = self._pos_to_time(tok_start)
|
start_time = self._pos_to_time(tok_start)
|
||||||
end_time = self._pos_to_time(tok_end)
|
end_time = self._pos_to_time(max(tok_end, tok_start + 1))
|
||||||
|
|
||||||
text_out = word if self._n_committed_words == 0 else " " + word
|
text_out = word if self._n_committed_words == 0 else " " + word
|
||||||
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
|
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
|
||||||
|
|
@ -409,24 +510,22 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
def _flush_all_pending_words(self) -> List[ASRToken]:
|
def _flush_all_pending_words(self) -> List[ASRToken]:
|
||||||
"""Flush ALL words including the last partial one."""
|
"""Flush ALL words including the last partial one."""
|
||||||
with self._text_lock:
|
with self._text_lock:
|
||||||
text = self._accumulated_text
|
text = self._get_accumulated_text()
|
||||||
if not text:
|
if not text:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
words = text.split()
|
words = text.split()
|
||||||
new_words: List[ASRToken] = []
|
new_words: List[ASRToken] = []
|
||||||
n_words_total = max(len(words), 1)
|
|
||||||
n_audio_toks = max(self._n_audio_tokens_fed, 1)
|
|
||||||
|
|
||||||
while self._n_committed_words < len(words):
|
if self._n_committed_words >= len(words):
|
||||||
|
return []
|
||||||
|
|
||||||
|
timestamps = self._word_timestamps(text, words, self._n_committed_words, len(words))
|
||||||
|
|
||||||
|
for tok_start, tok_end in timestamps:
|
||||||
word = words[self._n_committed_words]
|
word = words[self._n_committed_words]
|
||||||
word_idx = self._n_committed_words
|
|
||||||
|
|
||||||
tok_start = int(word_idx / n_words_total * n_audio_toks)
|
|
||||||
tok_end = int((word_idx + 1) / n_words_total * n_audio_toks)
|
|
||||||
|
|
||||||
start_time = self._pos_to_time(tok_start)
|
start_time = self._pos_to_time(tok_start)
|
||||||
end_time = self._pos_to_time(tok_end)
|
end_time = self._pos_to_time(max(tok_end, tok_start + 1))
|
||||||
|
|
||||||
text_out = word if self._n_committed_words == 0 else " " + word
|
text_out = word if self._n_committed_words == 0 else " " + word
|
||||||
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
|
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
|
||||||
|
|
@ -439,7 +538,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
|
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
|
||||||
# Start generate thread when enough audio is buffered
|
# Start generate thread when enough audio is buffered
|
||||||
if not self._generate_started:
|
if not self._generate_started:
|
||||||
if len(self._pending_audio) >= self._first_chunk_samples:
|
if self._pending_len >= self._first_chunk_samples:
|
||||||
self._start_generate_thread()
|
self._start_generate_thread()
|
||||||
self._feed_pending_audio()
|
self._feed_pending_audio()
|
||||||
else:
|
else:
|
||||||
|
|
@ -450,7 +549,7 @@ class VoxtralHFStreamingOnlineProcessor:
|
||||||
self._feed_pending_audio()
|
self._feed_pending_audio()
|
||||||
|
|
||||||
# If generate finished unexpectedly (EOS) but new audio arrived, restart
|
# If generate finished unexpectedly (EOS) but new audio arrived, restart
|
||||||
if self._generate_finished and len(self._pending_audio) >= self._first_chunk_samples:
|
if self._generate_finished and self._pending_len >= self._first_chunk_samples:
|
||||||
self._drain_streamer()
|
self._drain_streamer()
|
||||||
flush_words = self._flush_all_pending_words()
|
flush_words = self._flush_all_pending_words()
|
||||||
# Reset for new utterance
|
# Reset for new utterance
|
||||||
|
|
|
||||||
|
|
@ -91,20 +91,33 @@ def _mel_filters() -> mx.array:
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# DFT helpers
|
# DFT helpers (cached — these are constant for a given WINDOW_SIZE)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_CACHED_WINDOW: mx.array | None = None
|
||||||
|
_CACHED_DFT_RE: mx.array | None = None
|
||||||
|
_CACHED_DFT_IM: mx.array | None = None
|
||||||
|
|
||||||
|
|
||||||
def _hann_window() -> mx.array:
|
def _hann_window() -> mx.array:
|
||||||
return mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32))
|
global _CACHED_WINDOW
|
||||||
|
if _CACHED_WINDOW is None:
|
||||||
|
_CACHED_WINDOW = mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32))
|
||||||
|
return _CACHED_WINDOW
|
||||||
|
|
||||||
|
|
||||||
def _dft_matrices():
|
def _dft_matrices():
|
||||||
"""Pre-compute the real / imaginary DFT basis matrices."""
|
"""Return cached real / imaginary DFT basis matrices."""
|
||||||
n_bins = WINDOW_SIZE // 2 + 1
|
global _CACHED_DFT_RE, _CACHED_DFT_IM
|
||||||
k = mx.arange(n_bins, dtype=mx.float32)[:, None]
|
if _CACHED_DFT_RE is None:
|
||||||
n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :]
|
n_bins = WINDOW_SIZE // 2 + 1
|
||||||
phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE
|
k = mx.arange(n_bins, dtype=mx.float32)[:, None]
|
||||||
return mx.cos(phase), mx.sin(phase)
|
n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :]
|
||||||
|
phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE
|
||||||
|
_CACHED_DFT_RE = mx.cos(phase)
|
||||||
|
_CACHED_DFT_IM = mx.sin(phase)
|
||||||
|
mx.eval(_CACHED_DFT_RE, _CACHED_DFT_IM)
|
||||||
|
return _CACHED_DFT_RE, _CACHED_DFT_IM
|
||||||
|
|
||||||
|
|
||||||
def _stft_frames(audio: mx.array, window: mx.array) -> mx.array:
|
def _stft_frames(audio: mx.array, window: mx.array) -> mx.array:
|
||||||
|
|
|
||||||
|
|
@ -135,8 +135,9 @@ class VoxtralMLXOnlineProcessor:
|
||||||
|
|
||||||
def _reset_state(self):
|
def _reset_state(self):
|
||||||
"""Reset all incremental state for a fresh utterance."""
|
"""Reset all incremental state for a fresh utterance."""
|
||||||
# Audio accumulation
|
# Audio accumulation (list of chunks, concatenated on demand)
|
||||||
self._pending = np.zeros(0, dtype=np.float32)
|
self._pending_chunks: list[np.ndarray] = []
|
||||||
|
self._pending_len = 0
|
||||||
# Mel overlap
|
# Mel overlap
|
||||||
self._mel_overlap: np.ndarray | None = None
|
self._mel_overlap: np.ndarray | None = None
|
||||||
# Encoder incremental state
|
# Encoder incremental state
|
||||||
|
|
@ -167,10 +168,30 @@ class VoxtralMLXOnlineProcessor:
|
||||||
|
|
||||||
# -- audio ingestion --
|
# -- audio ingestion --
|
||||||
|
|
||||||
|
def _get_pending(self) -> np.ndarray:
|
||||||
|
"""Flatten pending chunks into a single array."""
|
||||||
|
if not self._pending_chunks:
|
||||||
|
return np.zeros(0, dtype=np.float32)
|
||||||
|
if len(self._pending_chunks) == 1:
|
||||||
|
return self._pending_chunks[0]
|
||||||
|
flat = np.concatenate(self._pending_chunks)
|
||||||
|
self._pending_chunks = [flat]
|
||||||
|
return flat
|
||||||
|
|
||||||
|
def _set_pending(self, arr: np.ndarray):
|
||||||
|
"""Replace pending audio with a single array."""
|
||||||
|
if len(arr) == 0:
|
||||||
|
self._pending_chunks = []
|
||||||
|
self._pending_len = 0
|
||||||
|
else:
|
||||||
|
self._pending_chunks = [arr]
|
||||||
|
self._pending_len = len(arr)
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
||||||
self.end = audio_stream_end_time
|
self.end = audio_stream_end_time
|
||||||
self._pending = np.append(self._pending, audio)
|
self._pending_chunks.append(audio)
|
||||||
self.audio_buffer = self._pending
|
self._pending_len += len(audio)
|
||||||
|
self.audio_buffer = audio # diagnostic only
|
||||||
|
|
||||||
# -- core processing --
|
# -- core processing --
|
||||||
|
|
||||||
|
|
@ -231,22 +252,24 @@ class VoxtralMLXOnlineProcessor:
|
||||||
|
|
||||||
def _encode_pending(self):
|
def _encode_pending(self):
|
||||||
"""Feed pending audio through the incremental encoder."""
|
"""Feed pending audio through the incremental encoder."""
|
||||||
available = len(self._pending)
|
if self._pending_len < SAMPLES_PER_TOKEN:
|
||||||
if available < SAMPLES_PER_TOKEN:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
pending = self._get_pending()
|
||||||
|
available = len(pending)
|
||||||
|
|
||||||
if self._first_chunk:
|
if self._first_chunk:
|
||||||
# First chunk: prepend silence for left-padding
|
# First chunk: prepend silence for left-padding
|
||||||
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
|
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
|
||||||
left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
|
left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
|
||||||
chunk = np.concatenate([left_pad, self._pending[:n_take]])
|
chunk = np.concatenate([left_pad, pending[:n_take]])
|
||||||
self._pending = self._pending[n_take:]
|
self._set_pending(pending[n_take:])
|
||||||
self._samples_encoded += n_take
|
self._samples_encoded += n_take
|
||||||
self._first_chunk = False
|
self._first_chunk = False
|
||||||
else:
|
else:
|
||||||
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
|
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
|
||||||
chunk = self._pending[:n_take]
|
chunk = pending[:n_take]
|
||||||
self._pending = self._pending[n_take:]
|
self._set_pending(pending[n_take:])
|
||||||
self._samples_encoded += n_take
|
self._samples_encoded += n_take
|
||||||
|
|
||||||
mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)
|
mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)
|
||||||
|
|
@ -261,11 +284,10 @@ class VoxtralMLXOnlineProcessor:
|
||||||
mx.eval(embeds)
|
mx.eval(embeds)
|
||||||
if self._audio_embeds is not None:
|
if self._audio_embeds is not None:
|
||||||
self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
|
self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
|
||||||
|
mx.eval(self._audio_embeds)
|
||||||
else:
|
else:
|
||||||
self._audio_embeds = embeds
|
self._audio_embeds = embeds
|
||||||
|
|
||||||
self.audio_buffer = self._pending
|
|
||||||
|
|
||||||
def _do_prefill(self):
|
def _do_prefill(self):
|
||||||
"""Run the decoder prefill pass over the prompt + first audio embeddings."""
|
"""Run the decoder prefill pass over the prompt + first audio embeddings."""
|
||||||
n_dec_layers = len(self._model.decoder.blocks)
|
n_dec_layers = len(self._model.decoder.blocks)
|
||||||
|
|
@ -430,6 +452,55 @@ class VoxtralMLXOnlineProcessor:
|
||||||
return Transcript(start=None, end=None, text="")
|
return Transcript(start=None, end=None, text="")
|
||||||
|
|
||||||
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""Flush all pending words when silence starts.
|
||||||
|
|
||||||
|
Adds right-padding silence and forces a full decode pass so the
|
||||||
|
decoder emits tokens for the last words of speech. Without this,
|
||||||
|
the model holds back the final tokens waiting for future context.
|
||||||
|
"""
|
||||||
|
# Align pending audio to SAMPLES_PER_TOKEN boundary
|
||||||
|
remainder = self._pending_len % SAMPLES_PER_TOKEN
|
||||||
|
align_pad = (SAMPLES_PER_TOKEN - remainder) if remainder > 0 else 0
|
||||||
|
|
||||||
|
# Add alignment + right-padding silence to provide future context
|
||||||
|
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
|
||||||
|
if total_pad > 0:
|
||||||
|
self._pending_chunks.append(np.zeros(total_pad, dtype=np.float32))
|
||||||
|
self._pending_len += total_pad
|
||||||
|
|
||||||
|
# Encode remaining audio (including right-padding)
|
||||||
|
self._encode_pending()
|
||||||
|
|
||||||
|
# Decode everything that's left
|
||||||
|
if self._audio_embeds is not None and self._prefilled:
|
||||||
|
self._decode_positions(self._audio_embeds.shape[0])
|
||||||
|
|
||||||
|
# Flush last token if it wasn't EOS
|
||||||
|
if self._last_token is not None:
|
||||||
|
tid = self._last_token.item()
|
||||||
|
if tid != self._eos_id:
|
||||||
|
text = self._tokenizer.decode(
|
||||||
|
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
|
||||||
|
)
|
||||||
|
if text:
|
||||||
|
last_pos = self._positions_decoded - self._prefix_len
|
||||||
|
if text.lstrip() != text or not self._full_text:
|
||||||
|
if self._current_word_pos is not None:
|
||||||
|
self._word_audio_ends.append(last_pos)
|
||||||
|
self._word_audio_starts.append(last_pos)
|
||||||
|
self._current_word_pos = last_pos
|
||||||
|
elif self._current_word_pos is None:
|
||||||
|
self._word_audio_starts.append(last_pos)
|
||||||
|
self._current_word_pos = last_pos
|
||||||
|
self._full_text += text
|
||||||
|
self._n_text_tokens += 1
|
||||||
|
|
||||||
|
# Close the last word if still open
|
||||||
|
if self._current_word_pos is not None:
|
||||||
|
last_pos = self._positions_decoded - self._prefix_len
|
||||||
|
self._word_audio_ends.append(last_pos)
|
||||||
|
self._current_word_pos = None
|
||||||
|
|
||||||
words = self._flush_all_words()
|
words = self._flush_all_words()
|
||||||
logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
|
logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
|
||||||
return words, self.end
|
return words, self.end
|
||||||
|
|
@ -448,7 +519,7 @@ class VoxtralMLXOnlineProcessor:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
|
"[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
|
||||||
"samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
|
"samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
|
||||||
len(self._pending),
|
self._pending_len,
|
||||||
self._audio_embeds.shape if self._audio_embeds is not None else None,
|
self._audio_embeds.shape if self._audio_embeds is not None else None,
|
||||||
self._samples_encoded,
|
self._samples_encoded,
|
||||||
self._positions_decoded,
|
self._positions_decoded,
|
||||||
|
|
@ -457,7 +528,7 @@ class VoxtralMLXOnlineProcessor:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
|
# Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
|
||||||
remainder = len(self._pending) % SAMPLES_PER_TOKEN
|
remainder = self._pending_len % SAMPLES_PER_TOKEN
|
||||||
if remainder > 0:
|
if remainder > 0:
|
||||||
align_pad = SAMPLES_PER_TOKEN - remainder
|
align_pad = SAMPLES_PER_TOKEN - remainder
|
||||||
else:
|
else:
|
||||||
|
|
@ -466,9 +537,8 @@ class VoxtralMLXOnlineProcessor:
|
||||||
# Add alignment + right-padding silence
|
# Add alignment + right-padding silence
|
||||||
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
|
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
|
||||||
if total_pad > 0:
|
if total_pad > 0:
|
||||||
self._pending = np.append(
|
self._pending_chunks.append(np.zeros(total_pad, dtype=np.float32))
|
||||||
self._pending, np.zeros(total_pad, dtype=np.float32)
|
self._pending_len += total_pad
|
||||||
)
|
|
||||||
|
|
||||||
# Encode remaining audio (including right-padding)
|
# Encode remaining audio (including right-padding)
|
||||||
self._encode_pending()
|
self._encode_pending()
|
||||||
|
|
@ -476,7 +546,7 @@ class VoxtralMLXOnlineProcessor:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
|
"[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
|
||||||
self._audio_embeds.shape if self._audio_embeds is not None else None,
|
self._audio_embeds.shape if self._audio_embeds is not None else None,
|
||||||
len(self._pending),
|
self._pending_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
hit_eos = False
|
hit_eos = False
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue