logfile reviewed, whisper_timestamped loading module and vad

PR #10, issues #9, #30
This commit is contained in:
Dominik Macháček 2023-11-28 12:14:54 +01:00
parent bd0d848e7f
commit 8f32dea5ca

View file

@ -26,12 +26,15 @@ class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed)
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = lan
self.model = self.load_model(modelsize, cache_dir, model_dir)
def load_model(self, modelsize, cache_dir, model_dir=None):
    """Load and return the backend model; abstract hook for child classes.

    The constructor invokes this as
    ``self.load_model(modelsize, cache_dir, model_dir)``, so ``model_dir``
    is accepted here as well (optional, to stay compatible with
    two-argument callers).

    Raises:
        NotImplementedError: always; child classes must override this method.
    """
    # `NotImplemented` is a comparison sentinel, not an exception class:
    # calling it raised TypeError instead of the intended error.
    raise NotImplementedError("must be implemented in the child class")
@ -50,15 +53,18 @@ class WhisperTimestampedASR(ASRBase):
sep = " "
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
global whisper_timestamped # has to be global as it is used at each `transcribe` call
import whisper
import whisper_timestamped
from whisper_timestamped import transcribe_timestamped
self.transcribe_timestamped = transcribe_timestamped
if model_dir is not None:
print("ignoring model_dir, not implemented",file=self.logfile)
return whisper.load_model(modelsize, download_root=cache_dir)
def transcribe(self, audio, init_prompt=""):
result = whisper_timestamped.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, verbose=None, condition_on_previous_text=True)
result = self.transcribe_timestamped(self.model,
audio, language=self.original_language,
initial_prompt=init_prompt, verbose=None,
condition_on_previous_text=True, **self.transcribe_kargs)
return result
def ts_words(self,r):
@ -74,7 +80,12 @@ class WhisperTimestampedASR(ASRBase):
return [s["end"] for s in res["segments"]]
def use_vad(self):
raise NotImplemented("Feature use_vad is not implemented for whisper_timestamped backend.")
self.transcribe_kargs["vad"] = True
def set_translate_task(self):
    """Record ``task="translate"`` in the extra transcribe keyword arguments."""
    self.transcribe_kargs.update(task="translate")
class FasterWhisperASR(ASRBase):
@ -135,7 +146,6 @@ class FasterWhisperASR(ASRBase):
class HypothesisBuffer:
def __init__(self, logfile=sys.stderr):
"""output: where to store the log. Leave it unchanged to print to terminal."""
self.commited_in_buffer = []
self.buffer = []
self.new = []
@ -205,7 +215,7 @@ class OnlineASRProcessor:
def __init__(self, asr, tokenizer, logfile=sys.stderr):
"""asr: WhisperASR object
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
output: where to store the log. Leave it unchanged to print to terminal.
logfile: where to store the log.
"""
self.asr = asr
self.tokenizer = tokenizer
@ -468,21 +478,24 @@ if __name__ == "__main__":
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
args = parser.parse_args()
# reset to store stderr to different file stream, e.g. open(os.devnull,"w")
logfile = sys.stderr
if args.offline and args.comp_unaware:
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=sys.stderr)
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=logfile)
sys.exit(1)
audio_path = args.audio_path
SAMPLING_RATE = 16000
duration = len(load_audio(audio_path))/SAMPLING_RATE
print("Audio duration is: %2.2f seconds" % duration, file=sys.stderr)
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
size = args.model
language = args.lan
t = time.time()
print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
if args.backend == "faster-whisper":
asr_cls = FasterWhisperASR
@ -499,15 +512,15 @@ if __name__ == "__main__":
e = time.time()
print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
if args.vad:
print("setting VAD filter",file=sys.stderr)
print("setting VAD filter",file=logfile)
asr.use_vad()
min_chunk = args.min_chunk_size
online = OnlineASRProcessor(asr,create_tokenizer(tgt_language))
online = OnlineASRProcessor(asr,create_tokenizer(tgt_language),logfile=logfile)
# load the audio into the LRU cache before we start the timer
@ -529,10 +542,10 @@ if __name__ == "__main__":
if now is None:
now = time.time()-start
if o[0] is not None:
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=sys.stderr,flush=True)
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
else:
print(o,file=sys.stderr,flush=True)
print(o,file=logfile,flush=True)
if args.offline: ## offline mode processing (for testing/debugging)
a = load_audio(audio_path)
@ -540,7 +553,7 @@ if __name__ == "__main__":
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=sys.stderr)
print("assertion error",file=logfile)
pass
else:
output_transcript(o)
@ -553,12 +566,12 @@ if __name__ == "__main__":
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=sys.stderr)
print("assertion error",file=logfile)
pass
else:
output_transcript(o, now=end)
print(f"## last processed {end:.2f}s",file=sys.stderr,flush=True)
print(f"## last processed {end:.2f}s",file=logfile,flush=True)
beg = end
end += min_chunk
@ -580,12 +593,12 @@ if __name__ == "__main__":
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=sys.stderr)
print("assertion error",file=logfile)
pass
else:
output_transcript(o)
now = time.time() - start
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr,flush=True)
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)
if end >= duration:
break