diff --git a/whisper_online.py b/whisper_online.py index c11e53c..53c8417 100644 --- a/whisper_online.py +++ b/whisper_online.py @@ -156,6 +156,63 @@ class FasterWhisperASR(ASRBase): def set_translate_task(self): self.transcribe_kargs["task"] = "translate" +class MLXWhisper(ASRBase): + """ + Uses MPX Whisper library as the backend, optimized for Apple Silicon. + Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc + Significantly faster than faster-whisper (without CUDA) on Apple M1. Model used by default: mlx-community/whisper-large-v3-mlx + """ + + sep = " " + + def load_model(self, modelsize=None, model_dir=None): + from mlx_whisper import transcribe + + if model_dir is not None: + logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.") + model_size_or_path = model_dir + elif modelsize is not None: + logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so make sure you use a mlx-compatible model.") + model_size_or_path = modelsize + elif modelsize == None: + logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.") + model_size_or_path = "mlx-community/whisper-large-v3-mlx" + + self.model_size_or_path = model_size_or_path + return transcribe + + def transcribe(self, audio, init_prompt=""): + segments = self.model( + audio, + language=self.original_language, + initial_prompt=init_prompt, + word_timestamps=True, + condition_on_previous_text=True, + path_or_hf_repo=self.model_size_or_path, + **self.transcribe_kargs + ) + return segments.get("segments", []) + + + def ts_words(self, segments): + """ + Extract timestamped words from transcription segments and skips words with high no-speech probability. + """ + return [ + (word["start"], word["end"], word["word"]) + for segment in segments + for word in segment.get("words", []) + if segment.get("no_speech_prob", 0) <= 0.9 + ] + + def segments_end_ts(self, res): + return [s['end'] for s in res] + + def use_vad(self): + self.transcribe_kargs["vad_filter"] = True + + def set_translate_task(self): + self.transcribe_kargs["task"] = "translate" class OpenaiApiASR(ASRBase): """Uses OpenAI's Whisper API for audio transcription.""" @@ -660,7 +717,7 @@ def add_shared_args(parser): parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.") parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.") parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.") - parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.') + parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],help='Load only this backend for Whisper processing.') parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.') parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.') parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.') @@ -679,6 +736,8 @@ def asr_factory(args, logfile=sys.stderr): else: if backend == "faster-whisper": asr_cls = FasterWhisperASR + elif backend == "mlx-whisper": + asr_cls = MLXWhisper else: asr_cls = WhisperTimestampedASR