260 lines
10 KiB
Python
260 lines
10 KiB
Python
import logging
|
|
import re
|
|
import sys
|
|
from typing import List, Optional
|
|
|
|
import numpy as np
|
|
|
|
from whisperlivekit.local_agreement.backends import ASRBase
|
|
from whisperlivekit.timed_objects import ASRToken
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _default_rope_parameters(config=None, device=None, seq_len=None, **kwargs):
    """Return ``(inv_freq, attention_factor)`` for unscaled rotary embeddings.

    Shared by the ``ROPE_INIT_FUNCTIONS["default"]`` shim and the
    ``compute_default_rope_parameters`` shim installed by
    :func:`_patch_transformers_compat`.

    Args:
        config: Model config exposing ``rope_theta`` and either ``head_dim`` or
            ``hidden_size`` / ``num_attention_heads``; ``partial_rotary_factor``
            is optional.
        device: Device the inverse-frequency tensor is moved to.
        seq_len: Accepted for signature compatibility; unused.
        **kwargs: Accepted for signature compatibility; unused.

    Returns:
        Tuple of the inverse-frequency tensor and the constant factor ``1.0``.
    """
    import torch

    # ``head_dim`` may exist on the config but be ``None``; a plain ``getattr``
    # default would not fall back in that case (and would evaluate the
    # fallback eagerly even when it is not needed).
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    partial = getattr(config, "partial_rotary_factor", None)
    if partial is None:
        partial = 1.0
    dim = int(head_dim * partial)
    base = config.rope_theta
    exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
    return 1.0 / (base ** exponents), 1.0


def _patch_transformers_compat():
    """Patch transformers for qwen_asr 0.0.6 + transformers >= 5.3 compatibility.

    Every patch is best-effort: a missing package is silently skipped, and each
    shim is only installed when the attribute it provides is absent.
    """
    import functools

    log = logging.getLogger(__name__)

    # 1. check_model_inputs was removed
    try:
        import transformers.utils.generic as _g
        if not hasattr(_g, "check_model_inputs"):
            def check_model_inputs(*args, **kwargs):
                # Bare usage (``@check_model_inputs``): the decorated function
                # arrives as the only positional argument; hand it back as-is.
                if len(args) == 1 and not kwargs and callable(args[0]):
                    return args[0]

                # Factory usage (``@check_model_inputs(...)``).
                def decorator(fn):
                    return fn
                return decorator
            _g.check_model_inputs = check_model_inputs
    except ImportError:
        pass

    # 2. 'default' rope type was removed from ROPE_INIT_FUNCTIONS
    try:
        from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
        if "default" not in ROPE_INIT_FUNCTIONS:
            ROPE_INIT_FUNCTIONS["default"] = _default_rope_parameters
    except ImportError:
        pass

    # 3. pad_token_id missing on thinker config
    try:
        from qwen_asr.core.transformers_backend.configuration_qwen3_asr import (
            Qwen3ASRThinkerConfig,
        )
        if not hasattr(Qwen3ASRThinkerConfig, "pad_token_id"):
            Qwen3ASRThinkerConfig.pad_token_id = None
    except ImportError:
        pass

    # 4. fix_mistral_regex kwarg not accepted by newer transformers
    try:
        from transformers.models.auto import processing_auto
        _orig_ap_from_pretrained = processing_auto.AutoProcessor.from_pretrained.__func__

        # Skip when our wrapper is already installed so repeated calls do not
        # stack wrappers on top of each other.
        if not getattr(_orig_ap_from_pretrained, "_wlk_drops_fix_mistral_regex", False):
            @functools.wraps(_orig_ap_from_pretrained)
            def _patched_ap_from_pretrained(cls, *args, **kwargs):
                kwargs.pop("fix_mistral_regex", None)
                return _orig_ap_from_pretrained(cls, *args, **kwargs)

            _patched_ap_from_pretrained._wlk_drops_fix_mistral_regex = True
            processing_auto.AutoProcessor.from_pretrained = classmethod(_patched_ap_from_pretrained)
    except ImportError:
        pass
    except Exception:
        # Still best-effort, but leave a trace instead of failing silently.
        log.debug("Could not patch AutoProcessor.from_pretrained", exc_info=True)

    # 5. compute_default_rope_parameters missing on RotaryEmbedding
    try:
        from qwen_asr.core.transformers_backend.modeling_qwen3_asr import (
            Qwen3ASRThinkerTextRotaryEmbedding,
        )
        if not hasattr(Qwen3ASRThinkerTextRotaryEmbedding, "compute_default_rope_parameters"):
            Qwen3ASRThinkerTextRotaryEmbedding.compute_default_rope_parameters = staticmethod(
                _default_rope_parameters
            )
    except ImportError:
        pass


_patch_transformers_compat()
|
|
|
|
# Lookup table: Whisper language code -> Qwen3 canonical language name.
WHISPER_TO_QWEN3_LANGUAGE = {
    "zh": "Chinese",
    "en": "English",
    "yue": "Cantonese",
    "ar": "Arabic",
    "de": "German",
    "fr": "French",
    "es": "Spanish",
    "pt": "Portuguese",
    "id": "Indonesian",
    "it": "Italian",
    "ko": "Korean",
    "ru": "Russian",
    "th": "Thai",
    "vi": "Vietnamese",
    "ja": "Japanese",
    "tr": "Turkish",
    "hi": "Hindi",
    "ms": "Malay",
    "nl": "Dutch",
    "sv": "Swedish",
    "da": "Danish",
    "fi": "Finnish",
    "pl": "Polish",
    "cs": "Czech",
    "fa": "Persian",
    "el": "Greek",
    "hu": "Hungarian",
    "mk": "Macedonian",
    "ro": "Romanian",
}

# Inverse lookup: Qwen3 canonical language name -> Whisper language code.
QWEN3_TO_WHISPER_LANGUAGE = {
    name: code for code, name in WHISPER_TO_QWEN3_LANGUAGE.items()
}
|
|
|
|
# Friendly aliases -> HuggingFace model IDs. Every target lives under the
# "Qwen/Qwen3-ASR-<size>" naming scheme, so only the size suffix is listed.
QWEN3_MODEL_MAPPING = {
    alias: f"Qwen/Qwen3-ASR-{size}"
    for alias, size in (
        ("qwen3-asr-1.7b", "1.7B"),
        ("qwen3-asr-0.6b", "0.6B"),
        ("qwen3-1.7b", "1.7B"),
        ("qwen3-0.6b", "0.6B"),
        # Whisper-style size aliases (map to closest Qwen3 model)
        ("large", "1.7B"),
        ("large-v3", "1.7B"),
        ("medium", "1.7B"),
        ("base", "0.6B"),
        ("small", "0.6B"),
        ("tiny", "0.6B"),
    )
}
|
|
|
|
# Characters that close a sentence: ASCII ". ! ? ;" plus their CJK
# counterparts -- ideographic full stop (U+3002) and the full-width
# "!" (U+FF01), "?" (U+FF1F) and ";" (U+FF1B). The CJK forms are written as
# escapes so they cannot be folded into ASCII look-alikes by text tooling.
_PUNCTUATION_ENDS = set(".!?;\u3002\uff01\uff1f\uff1b")
# Qwen3 raw output starts with "language <Name>" metadata before <asr_text> tag.
# When the tag is missing (silence/noise), this metadata leaks as transcription text.
_GARBAGE_RE = re.compile(r"^language\s+\S+$", re.IGNORECASE)
|
|
|
|
|
|
class Qwen3ASR(ASRBase):
    """Qwen3-ASR backend with ForcedAligner word-level timestamps.

    Wraps ``qwen_asr.Qwen3ASRModel`` and converts its results into ``ASRToken``
    lists and segment end times.
    """

    sep = ""  # tokens include leading spaces, like faster-whisper
    SAMPLING_RATE = 16000

    def __init__(self, lan="auto", model_size=None, cache_dir=None,
                 model_dir=None, logfile=sys.stderr, **kwargs):
        """Store settings and load the model.

        Args:
            lan: Whisper-style language code, or ``"auto"`` for no fixed language.
            model_size: Alias from ``QWEN3_MODEL_MAPPING`` or a model ID.
            cache_dir: Forwarded to :meth:`load_model`.
            model_dir: Explicit model path/ID; takes precedence over ``model_size``.
            logfile: Stream kept on the instance as ``self.logfile``.
            **kwargs: Accepted and ignored.
        """
        # NOTE(review): ASRBase.__init__ is not invoked here; the attributes are
        # set directly -- confirm this mirrors what the base class expects.
        self.logfile = logfile
        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan
        self.model = self.load_model(model_size, cache_dir, model_dir)

    def load_model(self, model_size=None, cache_dir=None, model_dir=None):
        """Load Qwen3-ASR together with the ``Qwen/Qwen3-ForcedAligner-0.6B`` aligner.

        Model selection: ``model_dir`` if given, else ``model_size`` resolved
        through ``QWEN3_MODEL_MAPPING`` (unknown values are used verbatim), else
        ``"Qwen/Qwen3-ASR-1.7B"``. Device selection: CUDA (bfloat16), then MPS
        (float32), then CPU (float32).

        ``cache_dir`` is accepted for interface compatibility and is currently
        not forwarded to the loader.

        Returns:
            The loaded ``Qwen3ASRModel``.
        """
        import torch
        from qwen_asr import Qwen3ASRModel

        if model_dir:
            model_id = model_dir
        elif model_size:
            model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size)
        else:
            model_id = "Qwen/Qwen3-ASR-1.7B"

        if torch.cuda.is_available():
            dtype, device = torch.bfloat16, "cuda:0"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            dtype, device = torch.float32, "mps"
        else:
            dtype, device = torch.float32, "cpu"

        logger.info("Loading Qwen3-ASR: %s (%s, %s)", model_id, dtype, device)
        model = Qwen3ASRModel.from_pretrained(
            model_id,
            forced_aligner="Qwen/Qwen3-ForcedAligner-0.6B",
            forced_aligner_kwargs=dict(dtype=dtype, device_map=device),
            dtype=dtype,
            device_map=device,
        )
        logger.info("Qwen3-ASR loaded with ForcedAligner")
        return model

    def _qwen3_language(self) -> Optional[str]:
        """Return the Qwen3 language name for the configured language.

        ``None`` is returned both when no language was fixed and when the
        configured code has no entry in ``WHISPER_TO_QWEN3_LANGUAGE``.
        """
        if self.original_language is None:
            return None
        return WHISPER_TO_QWEN3_LANGUAGE.get(self.original_language)

    def _transcribe_once(self, audio, init_prompt, with_timestamps):
        """Run one model transcription pass, with or without timestamps."""
        return self.model.transcribe(
            audio=(audio, self.SAMPLING_RATE),
            language=self._qwen3_language(),
            context=init_prompt or "",
            return_time_stamps=with_timestamps,
        )

    def transcribe(self, audio: np.ndarray, init_prompt: str = ""):
        """Transcribe ``audio`` and return the first model result.

        A timestamped pass is attempted first; if it raises, the failure is
        logged and the transcription is retried without timestamps.

        Args:
            audio: Samples at ``SAMPLING_RATE`` Hz.
            init_prompt: Context text passed to the model (empty if falsy).

        Returns:
            The first result object, with ``_audio_duration`` (seconds) attached.
        """
        try:
            results = self._transcribe_once(audio, init_prompt, True)
        except Exception:
            logger.warning("Qwen3 timestamp alignment failed, falling back to no timestamps", exc_info=True)
            results = self._transcribe_once(audio, init_prompt, False)
        result = results[0]
        # Stash audio length for timestamp estimation fallback
        result._audio_duration = len(audio) / self.SAMPLING_RATE
        logger.info(
            "Qwen3 result: language=%r text=%r ts=%s",
            result.language, result.text[:80] if result.text else "",
            bool(result.time_stamps),
        )
        return result

    @staticmethod
    def _detected_language(result) -> Optional[str]:
        """Extract Whisper-style language code from Qwen3 result.

        Returns ``None`` when the language is missing or the literal ``"none"``.
        Names without a known Whisper code are returned lower-cased.
        """
        lang = getattr(result, 'language', None)
        if not lang or lang.lower() == "none":
            return None
        # merge_languages may return comma-separated; take the first
        first = lang.split(",")[0].strip()
        if not first or first.lower() == "none":
            return None
        return QWEN3_TO_WHISPER_LANGUAGE.get(first, first.lower())

    def ts_words(self, result) -> List[ASRToken]:
        """Convert a transcription result into timed tokens.

        Uses the result's timestamps when present; otherwise spreads the words
        evenly over the stashed audio duration (5.0 s if none was stashed).
        Every token except the first gets a leading space.

        Returns:
            List of ``ASRToken``; empty for blank or garbage output.
        """
        transcript = (result.text or "").strip()
        if not transcript:
            return []
        # Filter garbage model output (e.g. "language None" for silence/noise)
        if _GARBAGE_RE.match(transcript):
            logger.info("Filtered garbage Qwen3 output: %r", transcript)
            return []

        detected = self._detected_language(result)
        if result.time_stamps:
            # Prepend space to match faster-whisper convention (tokens carry
            # their own whitespace so ''.join works in Segment.from_tokens)
            return [
                ASRToken(
                    start=item.start_time, end=item.end_time,
                    text=item.text if i == 0 else " " + item.text,
                    detected_language=detected,
                )
                for i, item in enumerate(result.time_stamps)
            ]

        # Fallback: estimate timestamps from word count. ``transcript`` is known
        # to be non-empty here, so no further emptiness check is needed.
        words = transcript.split()
        duration = getattr(result, '_audio_duration', 5.0)
        step = duration / max(len(words), 1)
        return [
            ASRToken(
                start=round(i * step, 3), end=round((i + 1) * step, 3),
                text=w if i == 0 else " " + w,
                detected_language=detected,
            )
            for i, w in enumerate(words)
        ]

    def segments_end_ts(self, result) -> List[float]:
        """Return segment end times for ``result``.

        Without timestamps the single boundary is the stashed audio duration
        (5.0 s if none was stashed). Otherwise a boundary is placed after each
        item ending in sentence punctuation, and the final item's end time is
        always the last boundary.
        """
        stamps = result.time_stamps
        if not stamps:
            return [getattr(result, '_audio_duration', 5.0)]
        # Create segment boundaries at punctuation marks
        ends = [
            item.end_time
            for item in stamps
            if item.text and item.text.rstrip()[-1:] in _PUNCTUATION_ENDS
        ]
        last_end = stamps[-1].end_time
        if not ends or ends[-1] != last_end:
            ends.append(last_end)
        return ends

    def use_vad(self):
        """Return ``False``."""
        return False
|