logfile reviewed, whisper_timestamped loading module and vad
PR #10, issues #9, #30
This commit is contained in:
parent
bd0d848e7f
commit
8f32dea5ca
1 changed files with 33 additions and 20 deletions
|
|
@ -26,12 +26,15 @@ class ASRBase:
|
|||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
# "" for faster-whisper because it emits the spaces when needed)
|
||||
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
|
||||
self.transcribe_kargs = {}
|
||||
self.original_language = lan
|
||||
|
||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
||||
|
||||
|
||||
def load_model(self, modelsize, cache_dir):
|
||||
raise NotImplemented("must be implemented in the child class")
|
||||
|
||||
|
|
@ -50,15 +53,18 @@ class WhisperTimestampedASR(ASRBase):
|
|||
sep = " "
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
global whisper_timestamped # has to be global as it is used at each `transcribe` call
|
||||
import whisper
|
||||
import whisper_timestamped
|
||||
from whisper_timestamped import transcribe_timestamped
|
||||
self.transcribe_timestamped = transcribe_timestamped
|
||||
if model_dir is not None:
|
||||
print("ignoring model_dir, not implemented",file=self.logfile)
|
||||
return whisper.load_model(modelsize, download_root=cache_dir)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
result = whisper_timestamped.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, verbose=None, condition_on_previous_text=True)
|
||||
result = self.transcribe_timestamped(self.model,
|
||||
audio, language=self.original_language,
|
||||
initial_prompt=init_prompt, verbose=None,
|
||||
condition_on_previous_text=True, **self.transcribe_kargs)
|
||||
return result
|
||||
|
||||
def ts_words(self,r):
|
||||
|
|
@ -74,7 +80,12 @@ class WhisperTimestampedASR(ASRBase):
|
|||
return [s["end"] for s in res["segments"]]
|
||||
|
||||
def use_vad(self):
|
||||
raise NotImplemented("Feature use_vad is not implemented for whisper_timestamped backend.")
|
||||
self.transcribe_kargs["vad"] = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.transcribe_kargs["task"] = "translate"
|
||||
|
||||
|
||||
|
||||
|
||||
class FasterWhisperASR(ASRBase):
|
||||
|
|
@ -135,7 +146,6 @@ class FasterWhisperASR(ASRBase):
|
|||
class HypothesisBuffer:
|
||||
|
||||
def __init__(self, logfile=sys.stderr):
|
||||
"""output: where to store the log. Leave it unchanged to print to terminal."""
|
||||
self.commited_in_buffer = []
|
||||
self.buffer = []
|
||||
self.new = []
|
||||
|
|
@ -205,7 +215,7 @@ class OnlineASRProcessor:
|
|||
def __init__(self, asr, tokenizer, logfile=sys.stderr):
|
||||
"""asr: WhisperASR object
|
||||
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
|
||||
output: where to store the log. Leave it unchanged to print to terminal.
|
||||
logfile: where to store the log.
|
||||
"""
|
||||
self.asr = asr
|
||||
self.tokenizer = tokenizer
|
||||
|
|
@ -468,21 +478,24 @@ if __name__ == "__main__":
|
|||
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# reassign this to send the log output to a different file stream, e.g. open(os.devnull,"w")
|
||||
logfile = sys.stderr
|
||||
|
||||
if args.offline and args.comp_unaware:
|
||||
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=sys.stderr)
|
||||
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=logfile)
|
||||
sys.exit(1)
|
||||
|
||||
audio_path = args.audio_path
|
||||
|
||||
SAMPLING_RATE = 16000
|
||||
duration = len(load_audio(audio_path))/SAMPLING_RATE
|
||||
print("Audio duration is: %2.2f seconds" % duration, file=sys.stderr)
|
||||
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
|
||||
|
||||
size = args.model
|
||||
language = args.lan
|
||||
|
||||
t = time.time()
|
||||
print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
|
||||
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
|
||||
|
||||
if args.backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
|
|
@ -499,15 +512,15 @@ if __name__ == "__main__":
|
|||
|
||||
|
||||
e = time.time()
|
||||
print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
|
||||
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
||||
|
||||
if args.vad:
|
||||
print("setting VAD filter",file=sys.stderr)
|
||||
print("setting VAD filter",file=logfile)
|
||||
asr.use_vad()
|
||||
|
||||
|
||||
min_chunk = args.min_chunk_size
|
||||
online = OnlineASRProcessor(asr,create_tokenizer(tgt_language))
|
||||
online = OnlineASRProcessor(asr,create_tokenizer(tgt_language),logfile=logfile)
|
||||
|
||||
|
||||
# load the audio into the LRU cache before we start the timer
|
||||
|
|
@ -529,10 +542,10 @@ if __name__ == "__main__":
|
|||
if now is None:
|
||||
now = time.time()-start
|
||||
if o[0] is not None:
|
||||
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=sys.stderr,flush=True)
|
||||
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
|
||||
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
|
||||
else:
|
||||
print(o,file=sys.stderr,flush=True)
|
||||
print(o,file=logfile,flush=True)
|
||||
|
||||
if args.offline: ## offline mode processing (for testing/debugging)
|
||||
a = load_audio(audio_path)
|
||||
|
|
@ -540,7 +553,7 @@ if __name__ == "__main__":
|
|||
try:
|
||||
o = online.process_iter()
|
||||
except AssertionError:
|
||||
print("assertion error",file=sys.stderr)
|
||||
print("assertion error",file=logfile)
|
||||
pass
|
||||
else:
|
||||
output_transcript(o)
|
||||
|
|
@ -553,12 +566,12 @@ if __name__ == "__main__":
|
|||
try:
|
||||
o = online.process_iter()
|
||||
except AssertionError:
|
||||
print("assertion error",file=sys.stderr)
|
||||
print("assertion error",file=logfile)
|
||||
pass
|
||||
else:
|
||||
output_transcript(o, now=end)
|
||||
|
||||
print(f"## last processed {end:.2f}s",file=sys.stderr,flush=True)
|
||||
print(f"## last processed {end:.2f}s",file=logfile,flush=True)
|
||||
|
||||
beg = end
|
||||
end += min_chunk
|
||||
|
|
@ -580,12 +593,12 @@ if __name__ == "__main__":
|
|||
try:
|
||||
o = online.process_iter()
|
||||
except AssertionError:
|
||||
print("assertion error",file=sys.stderr)
|
||||
print("assertion error",file=logfile)
|
||||
pass
|
||||
else:
|
||||
output_transcript(o)
|
||||
now = time.time() - start
|
||||
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr,flush=True)
|
||||
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)
|
||||
|
||||
if end >= duration:
|
||||
break
|
||||
|
|
|
|||
Loading…
Reference in a new issue