591 lines
23 KiB
Python
591 lines
23 KiB
Python
"""
|
|
Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit.
|
|
|
|
Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor``
|
|
(streaming processor) that plug into WhisperLiveKit's audio processing
|
|
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
|
|
|
|
Unlike the HuggingFace backend, this runs the full inference loop in-process
|
|
(no background thread / queue) — MLX operations on Apple Silicon are fast
|
|
enough to run synchronously inside ``asyncio.to_thread(process_iter)``.
|
|
"""
|
|
|
|
import logging
|
|
import sys
|
|
import time
|
|
from typing import List, Optional, Tuple
|
|
|
|
import mlx.core as mx
|
|
import numpy as np
|
|
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
|
|
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
|
from whisperlivekit.voxtral_mlx.loader import DEFAULT_MODEL_ID, load_voxtral_model
|
|
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
|
|
from whisperlivekit.voxtral_mlx.spectrogram import (
|
|
LEFT_PAD_TOKENS,
|
|
RIGHT_PAD_TOKENS,
|
|
SAMPLES_PER_TOKEN,
|
|
compute_mel_streaming,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Sliding-window length (in decoder positions) given to every decoder layer's
# ``SlidingKVCache`` at prefill time; stated to match the model's training
# configuration.
_DECODER_WINDOW = 8 * 1024
|
|
|
|
|
|
def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
    """Return ``(token_ids, n_delay)`` for the streaming decoder prompt.

    The prompt is the tokenizer's BOS id followed by ``n_left_pad + n_delay``
    copies of the ``[STREAMING_PAD]`` special token. ``n_delay`` is echoed
    back unchanged so callers can keep it together with the prompt.
    """
    streaming_pad = tokenizer.get_special_token("[STREAMING_PAD]")
    prompt = [tokenizer.bos_id]
    prompt.extend(streaming_pad for _ in range(n_left_pad + n_delay))
    return prompt, n_delay
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Model holder
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class VoxtralMLXASR:
    """Model holder for the pure-MLX Voxtral backend.

    ``__init__`` loads the MLX model, tokenizer and config once and keeps them
    on the instance; all streaming work happens in
    ``VoxtralMLXOnlineProcessor``.
    """

    sep = " "
    SAMPLING_RATE = 16_000

    def __init__(self, logfile=sys.stderr, **kwargs):
        self.logfile = logfile
        self.transcribe_kargs = {}

        # "auto" (the default) means no fixed source language.
        language = kwargs.get("lan", "auto")
        self.original_language = language if language != "auto" else None

        # Explicit model_dir / model_path win; otherwise a path-like
        # model_size ("org/name", "./dir", ...) is used as-is, and anything
        # else falls back to the default model id.
        source = kwargs.get("model_dir") or kwargs.get("model_path")
        if not source:
            size_hint = kwargs.get("model_size", "")
            looks_like_path = bool(size_hint) and (
                "/" in size_hint or size_hint.startswith(".")
            )
            source = size_hint if looks_like_path else DEFAULT_MODEL_ID

        started = time.time()
        logger.info("Loading Voxtral MLX model '%s' ...", source)
        self.model, self.tokenizer, self.config = load_voxtral_model(source)
        logger.info("Voxtral MLX model loaded in %.2fs", time.time() - started)

        self.backend_choice = "voxtral-mlx"

    def transcribe(self, audio):
        """No-op kept for interface compatibility; returns ``None``.

        All inference happens in the online processor.
        """
        return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Online processor
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class VoxtralMLXOnlineProcessor:
    """Streaming processor that incrementally encodes audio and decodes text
    using the MLX Voxtral model.

    Lifecycle (called by ``AudioProcessor.transcription_processor``):

        insert_audio_chunk(pcm, time) → process_iter() → get_buffer()
        ... repeat ...
        start_silence() / end_silence()
        finish()

    Timing model: an *audio position* is a decoder position minus the prompt
    prefix length. ``_audio_pos_to_time`` maps it to seconds and adds
    ``_time_offset``. The offset absorbs everything the current utterance's
    positions do not represent:

    - silence reported through ``end_silence``;
    - audio consumed by utterances that ended with EOS;
    - (negatively) the synthetic padding that ``start_silence`` injects.
    """

    SAMPLING_RATE = 16_000

    def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr):
        self.asr = asr
        self.logfile = logfile
        self.end = 0.0
        self.buffer: list = []
        self.audio_buffer = np.array([], dtype=np.float32)

        self._model = asr.model
        self._tokenizer = asr.tokenizer

        # Pre-compute prompt tokens and delay conditioning (constant across utterances).
        self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer)
        self._prefix_len = len(self._prompt_ids)

        self._delay_cond = self._model.delay_embedding(
            mx.array([self._n_delay], dtype=mx.float32)
        )
        mx.eval(self._delay_cond)

        self._prompt_embeds = self._model.decoder.embed(
            mx.array([self._prompt_ids])
        )[0]  # [prefix_len, dim]
        mx.eval(self._prompt_embeds)

        self._eos_id = self._tokenizer.eos_id
        self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE
        # The streaming model has an inherent delay: text for audio at position P
        # is generated at decoder position P + n_delay. Compensate timestamps.
        self._delay_secs = self._n_delay * self._secs_per_token

        # Whether any word has been returned to the caller during this session.
        # It lives outside _reset_state so that the first word after an
        # EOS-triggered utterance restart still gets a leading-space separator.
        self._any_word_emitted = False

        self._reset_state()

    # -- state management --

    def _reset_state(self):
        """Reset all incremental state for a fresh utterance."""
        # Audio accumulation (list of chunks, concatenated on demand)
        self._pending_chunks: list[np.ndarray] = []
        self._pending_len = 0
        # Mel overlap
        self._mel_overlap: np.ndarray | None = None
        # Encoder incremental state
        self._conv_tail1 = None
        self._conv_tail2 = None
        self._enc_cache = None
        self._ds_remainder = None
        # Audio embeddings not yet decoded
        self._audio_embeds: mx.array | None = None
        # Decoder state
        self._dec_cache: list[SlidingKVCache] | None = None
        self._last_token: mx.array | None = None
        # True when _flush_tail() already appended _last_token's text; the next
        # decode step still feeds the token to the model but must not append
        # its text a second time.
        self._last_token_emitted = False
        # Bookkeeping
        self._samples_encoded = 0
        self._positions_decoded = 0
        self._prefilled = False
        self._first_chunk = True
        # Text state
        self._full_text = ""
        self._n_text_tokens = 0
        self._n_committed_words = 0
        self._time_offset = 0.0
        # Per-word audio position tracking: decoder position (relative to prefix)
        # where each word in _full_text started and ended
        self._word_audio_starts: list[int] = []  # audio pos where word i started
        self._word_audio_ends: list[int] = []  # audio pos where word i last produced a token
        self._current_word_pos: Optional[int] = None  # audio pos of current (incomplete) word's first token

    def _restart_utterance(self):
        """Reset decoding state for a new utterance without losing the clock.

        Everything encoded so far (``_samples_encoded``) is folded into
        ``_time_offset``, so positions of the next utterance map to the right
        stream time. Audio that is still pending (not yet encoded) is real
        audio belonging to the next utterance, so it is carried over instead of
        being dropped.
        """
        offset = self._time_offset + self._samples_encoded / self.SAMPLING_RATE
        pending_chunks, pending_len = self._pending_chunks, self._pending_len
        self._reset_state()
        self._time_offset = offset
        self._pending_chunks, self._pending_len = pending_chunks, pending_len

    # -- audio ingestion --

    def _get_pending(self) -> np.ndarray:
        """Flatten pending chunks into a single array."""
        if not self._pending_chunks:
            return np.zeros(0, dtype=np.float32)
        if len(self._pending_chunks) == 1:
            return self._pending_chunks[0]
        flat = np.concatenate(self._pending_chunks)
        self._pending_chunks = [flat]
        return flat

    def _set_pending(self, arr: np.ndarray):
        """Replace pending audio with a single array."""
        if len(arr) == 0:
            self._pending_chunks = []
            self._pending_len = 0
        else:
            self._pending_chunks = [arr]
            self._pending_len = len(arr)

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        self.end = audio_stream_end_time
        self._pending_chunks.append(audio)
        self._pending_len += len(audio)
        self.audio_buffer = audio  # diagnostic only

    # -- core processing --

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        # Best-effort: a failing step is logged and yields no words rather
        # than breaking the caller's processing loop.
        try:
            return self._step(is_last)
        except Exception as e:
            logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True)
            return [], self.end

    def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
        # 1. Encode any new audio
        self._encode_pending()

        if self._audio_embeds is None:
            return [], self.end

        # 2. Compute how many positions we can safely decode
        total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
        n_available = self._audio_embeds.shape[0]
        n_decodable = min(n_available, total_safe - self._positions_decoded)

        if n_decodable <= 0:
            return [], self.end

        # 3. Prefill if needed
        if not self._prefilled:
            if self._positions_decoded + n_available < self._prefix_len:
                return [], self.end
            self._do_prefill()
            # Re-check after consuming prefix embeddings
            n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0
            n_decodable = min(n_available, total_safe - self._positions_decoded)

        if n_decodable <= 0 or self._audio_embeds is None:
            return [], self.end

        # 4. Decode available positions
        hit_eos = self._decode_positions(n_decodable)

        if hit_eos:
            # Flush words, then start a new utterance that keeps the time base
            words = self._flush_all_words()
            logger.debug(
                "[voxtral-mlx] EOS hit during stream: flushed %d words, "
                "samples_encoded=%d (%.2fs), text='%s'",
                len(words), self._samples_encoded,
                self._samples_encoded / self.SAMPLING_RATE,
                self._full_text[-60:] if self._full_text else "",
            )
            self._restart_utterance()
            return words, self.end

        # 5. Extract committed words (all but the last, which may still grow)
        return self._extract_committed_words(), self.end

    def _encode_pending(self):
        """Feed whole tokens of pending audio through the incremental encoder.

        Only multiples of ``SAMPLES_PER_TOKEN`` are consumed; the remainder
        stays pending. The very first chunk of an utterance is prefixed with
        ``LEFT_PAD_TOKENS`` tokens of silence, which is not counted in
        ``_samples_encoded``.
        """
        if self._pending_len < SAMPLES_PER_TOKEN:
            return

        pending = self._get_pending()
        n_take = (len(pending) // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
        chunk = pending[:n_take]
        if self._first_chunk:
            # First chunk: prepend silence for left-padding
            left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
            chunk = np.concatenate([left_pad, chunk])
            self._first_chunk = False
        self._set_pending(pending[n_take:])
        self._samples_encoded += n_take

        mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)

        embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = (
            self._model.encode_incremental(
                mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder
            )
        )

        if embeds is not None:
            mx.eval(embeds)
            if self._audio_embeds is not None:
                self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
                mx.eval(self._audio_embeds)
            else:
                self._audio_embeds = embeds

    def _do_prefill(self):
        """Run the decoder prefill pass over the prompt + first audio embeddings."""
        n_dec_layers = len(self._model.decoder.blocks)
        self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)]

        prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len]
        prefix_embeds = prefix_embeds[None, :, :]  # [1, prefix_len, dim]

        logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache)
        mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)])

        self._last_token = self._sample(logits)
        mx.async_eval(self._last_token)

        # Remove consumed prefix embeddings
        self._audio_embeds = self._audio_embeds[self._prefix_len :]
        if self._audio_embeds.shape[0] == 0:
            self._audio_embeds = None
        self._positions_decoded = self._prefix_len
        self._prefilled = True

    def _decode_positions(self, n: int) -> bool:
        """Autoregressively decode *n* positions. Returns True on EOS.

        At step ``i`` the previously sampled token (``_last_token``) is
        embedded, added to audio embedding ``i`` and fed to the decoder. The
        token's text is then appended to ``_full_text``.
        """
        base_pos = self._positions_decoded  # absolute position before this batch
        for i in range(n):
            tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0]
            combined = (self._audio_embeds[i] + tok_embed)[None, None, :]
            logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache)
            next_tok = self._sample(logits)
            mx.async_eval(next_tok)

            token_id = self._last_token.item()
            if token_id == self._eos_id:
                # Close the current word if one is being built
                self._close_open_word(base_pos + i - self._prefix_len)
                self._trim_embeds(i)
                self._positions_decoded += i
                return True

            if self._last_token_emitted:
                # _flush_tail() already appended this token's text (and its word
                # bookkeeping); it was fed to the decoder above but must not be
                # appended again.
                self._last_token_emitted = False
            else:
                text = self._tokenizer.decode(
                    [token_id], special_token_policy=SpecialTokenPolicy.IGNORE
                )
                if text:
                    self._append_text(text, base_pos + i - self._prefix_len)

            if i > 0 and i % 256 == 0:
                mx.clear_cache()

            self._last_token = next_tok

        self._positions_decoded += n
        self._trim_embeds(n)
        return False

    def _append_text(self, text: str, audio_pos: int):
        """Append decoded *text* and update per-word start/end positions.

        A token with leading whitespace (or the first text of the utterance)
        starts a new word. Otherwise it starts one only if no word is open.
        """
        if text.lstrip() != text or not self._full_text:
            # Close previous word if exists, then start a new one
            if self._current_word_pos is not None:
                self._word_audio_ends.append(audio_pos)
            self._word_audio_starts.append(audio_pos)
            self._current_word_pos = audio_pos
        elif self._current_word_pos is None:
            # First token of a word without a leading space
            self._word_audio_starts.append(audio_pos)
            self._current_word_pos = audio_pos
        self._full_text += text
        self._n_text_tokens += 1

    def _close_open_word(self, audio_pos: int):
        """Record *audio_pos* as the end of the word being built, if any."""
        if self._current_word_pos is not None:
            self._word_audio_ends.append(audio_pos)
            self._current_word_pos = None

    def _trim_embeds(self, n_consumed: int):
        if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed:
            self._audio_embeds = self._audio_embeds[n_consumed:]
        else:
            self._audio_embeds = None

    def _sample(self, logits: mx.array) -> mx.array:
        # Greedy decoding: argmax over the last position's logits.
        return mx.argmax(logits[0, -1:], axis=-1).squeeze()

    # -- end-of-speech draining --

    def _pad_pending_audio(self) -> int:
        """Pad pending audio with silence; return the number of samples added.

        The pending audio is first aligned to a ``SAMPLES_PER_TOKEN`` boundary
        so nothing is left un-encoded. ``RIGHT_PAD_TOKENS`` tokens of silence
        are then added to give the decoder future context.
        """
        remainder = self._pending_len % SAMPLES_PER_TOKEN
        align_pad = (SAMPLES_PER_TOKEN - remainder) if remainder > 0 else 0
        total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
        if total_pad > 0:
            self._pending_chunks.append(np.zeros(total_pad, dtype=np.float32))
            self._pending_len += total_pad
        return total_pad

    def _decode_remaining(self) -> bool:
        """Decode every already-encoded position. Returns True on EOS.

        If the utterance was too short for ``_step`` to prefill, it is
        prefilled here once enough embeddings exist. Otherwise short
        utterances would never be transcribed.
        """
        if (
            not self._prefilled
            and self._audio_embeds is not None
            and self._audio_embeds.shape[0] >= self._prefix_len
        ):
            self._do_prefill()
        if self._audio_embeds is not None and self._prefilled:
            return self._decode_positions(self._audio_embeds.shape[0])
        return False

    def _flush_tail(self):
        """Append the last sampled token's text and close the open word.

        ``_last_token`` only becomes text when it is fed back on the next
        decode step. When no such step is imminent (silence / end of stream),
        its text is appended here and ``_last_token_emitted`` is set, so the
        next step does not append it again.
        """
        last_pos = self._positions_decoded - self._prefix_len
        if self._last_token is not None and not self._last_token_emitted:
            tid = self._last_token.item()
            if tid != self._eos_id:
                text = self._tokenizer.decode(
                    [tid], special_token_policy=SpecialTokenPolicy.IGNORE
                )
                if text:
                    self._append_text(text, last_pos)
                self._last_token_emitted = True
        # Close the last word if still open
        self._close_open_word(last_pos)

    # -- word extraction --

    def _audio_pos_to_time(self, pos: int) -> float:
        """Convert an audio position (relative to prefix end) to seconds."""
        return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset)

    def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]:
        """Compute (start, end) time for a word using tracked word positions."""
        starts = self._word_audio_starts
        ends = self._word_audio_ends

        if not starts:
            return self._time_offset, self._time_offset

        # Get start position for this word
        if word_idx < len(starts):
            t0 = self._audio_pos_to_time(starts[word_idx])
        else:
            # Fallback: estimate from last known position
            last_pos = ends[-1] if ends else starts[-1]
            t0 = self._audio_pos_to_time(last_pos + 1)

        # Get end position: use the start of the next word, or the end of this word
        if word_idx + 1 < len(starts):
            t1 = self._audio_pos_to_time(starts[word_idx + 1])
        elif word_idx < len(ends):
            t1 = self._audio_pos_to_time(ends[word_idx] + 1)
        else:
            # Last word, still being built: use last known position + 1 token
            last_pos = starts[word_idx] if word_idx < len(starts) else (ends[-1] if ends else 0)
            t1 = self._audio_pos_to_time(last_pos + 1)

        return t0, t1

    def _commit_words(self, keep_last: bool) -> List[ASRToken]:
        """Turn not-yet-committed words of ``_full_text`` into ``ASRToken``s.

        With *keep_last* the final word is held back because it may still
        grow. Every word except the very first one of the session carries a
        leading space.
        """
        if not self._full_text:
            return []
        words = self._full_text.split()
        n_total = max(len(words), 1)
        stop = len(words) - 1 if keep_last else len(words)
        tokens: List[ASRToken] = []

        while self._n_committed_words < stop:
            idx = self._n_committed_words
            t0, t1 = self._word_time_range(idx, n_total)
            word = words[idx]
            label = " " + word if self._any_word_emitted else word
            tokens.append(ASRToken(start=t0, end=t1, text=label))
            self._any_word_emitted = True
            self._n_committed_words += 1

        return tokens

    def _extract_committed_words(self) -> List[ASRToken]:
        """Return complete words (all except the last which may still grow)."""
        return self._commit_words(keep_last=True)

    def _flush_all_words(self) -> List[ASRToken]:
        """Flush every word including the last partial one."""
        return self._commit_words(keep_last=False)

    # -- interface methods --

    def get_buffer(self) -> Transcript:
        if not self._full_text:
            return Transcript(start=None, end=None, text="")
        words = self._full_text.split()
        remaining = words[self._n_committed_words :]
        if remaining:
            return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
        return Transcript(start=None, end=None, text="")

    def start_silence(self) -> Tuple[List[ASRToken], float]:
        """Flush all pending words when silence starts.

        Adds right-padding silence and forces a full decode pass so the
        decoder emits tokens for the last words of speech. Without this,
        the model holds back the final tokens waiting for future context.

        The injected padding is not real stream time, so ``_time_offset`` is
        reduced by its duration after the words are flushed. Since every word
        is now committed, a separator is appended to ``_full_text`` so a
        continuation token after the silence starts a new word instead of
        merging into an already-committed one.
        """
        total_pad = self._pad_pending_audio()

        # Encode remaining audio (including right-padding) and decode it
        self._encode_pending()
        hit_eos = self._decode_remaining()

        self._flush_tail()
        words = self._flush_all_words()

        self._time_offset -= total_pad / self.SAMPLING_RATE
        if hit_eos:
            self._restart_utterance()
        elif self._full_text and not self._full_text[-1].isspace():
            self._full_text += " "

        logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
        return words, self.end

    def end_silence(self, silence_duration: float, offset: float):
        self._time_offset += silence_duration
        self.end += silence_duration

    def new_speaker(self, change_speaker):
        """Flush pending words on a speaker change.

        Returns ``(words, end)`` from ``start_silence`` instead of discarding
        it, so the flushed words are not lost.
        """
        return self.start_silence()

    def warmup(self, audio, init_prompt=""):
        pass

    def finish(self) -> Tuple[List[ASRToken], float]:
        logger.debug(
            "[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
            "samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
            self._pending_len,
            self._audio_embeds.shape if self._audio_embeds is not None else None,
            self._samples_encoded,
            self._positions_decoded,
            self._prefilled,
            self._full_text[-80:] if self._full_text else "",
        )

        # Align pending audio to a token boundary and add right-padding silence
        self._pad_pending_audio()

        # Encode remaining audio (including right-padding)
        self._encode_pending()

        logger.debug(
            "[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
            self._audio_embeds.shape if self._audio_embeds is not None else None,
            self._pending_len,
        )

        # Decode everything that's left from right-padding
        hit_eos = self._decode_remaining()
        logger.debug(
            "[voxtral-mlx] finish decode: hit_eos=%s, text='%s'",
            hit_eos, self._full_text[-80:] if self._full_text else "",
        )

        self._flush_tail()

        words = self._flush_all_words()
        logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
        return words, self.end
|