diff --git a/whisper_online.py b/whisper_online.py index 26fe6db..266fc72 100644 --- a/whisper_online.py +++ b/whisper_online.py @@ -23,15 +23,19 @@ def load_audio_chunk(fname, beg, end): class ASRBase: - # join transcribe words with this character (" " for whisper_timestamped, "" for faster-whisper because it emits the spaces when neeeded) - sep = " " + sep = " " # join transcribe words with this character (" " for whisper_timestamped, + # "" for faster-whisper because it emits the spaces when neeeded) def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None): self.transcribe_kargs = {} self.original_language = lan + self.import_backend() self.model = self.load_model(modelsize, cache_dir, model_dir) + def import_backend(self): + raise NotImplemented("must be implemented in the child class") + def load_model(self, modelsize, cache_dir): raise NotImplemented("must be implemented in the child class") @@ -49,11 +53,14 @@ class ASRBase: class WhisperTimestampedASR(ASRBase): """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper. On the other hand, the installation for GPU could be easier. + """ - If used, requires imports: + sep = " " + + def import_backend(self): + global whisper, whisper_timestamped import whisper import whisper_timestamped - """ def load_model(self, modelsize=None, cache_dir=None, model_dir=None): if model_dir is not None: @@ -89,8 +96,12 @@ class FasterWhisperASR(ASRBase): sep = "" + def import_backend(self): + global faster_whisper + import faster_whisper + def load_model(self, modelsize=None, cache_dir=None, model_dir=None): - from faster_whisper import WhisperModel + #from faster_whisper import WhisperModel if model_dir is not None: @@ -465,11 +476,11 @@ if __name__ == "__main__": #asr = WhisperASR(lan=language, modelsize=size) if args.backend == "faster-whisper": - from faster_whisper import WhisperModel + #from faster_whisper import WhisperModel asr_cls = FasterWhisperASR else: - import whisper - import whisper_timestamped + #import whisper + #import whisper_timestamped # from whisper_timestamped_model import WhisperTimestampedASR asr_cls = WhisperTimestampedASR