Merge pull request #26 from SilasK/fix-sentencesegmenter
Improve logging stil trying to fix sentence segmenter
This commit is contained in:
commit
c3d72cae7c
2 changed files with 71 additions and 48 deletions
|
|
@ -69,6 +69,7 @@ class HypothesisBuffer:
|
|||
return commit
|
||||
|
||||
def pop_commited(self, time):
|
||||
"Remove (from the beginning) of commited_in_buffer all the words that are finished before `time`"
|
||||
while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
|
||||
self.commited_in_buffer.pop(0)
|
||||
|
||||
|
|
@ -110,6 +111,15 @@ class OnlineASRProcessor:
|
|||
|
||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||
|
||||
if self.buffer_trimming_way not in ["sentence", "segment"]:
|
||||
raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
|
||||
if self.buffer_trimming_sec <= 0:
|
||||
raise ValueError("buffer_trimming_sec must be positive")
|
||||
elif self.buffer_trimming_sec > 30:
|
||||
logger.warning(
|
||||
f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
|
||||
)
|
||||
|
||||
def init(self, offset=None):
|
||||
"""run this when starting or restarting processing"""
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
|
|
@ -160,51 +170,52 @@ class OnlineASRProcessor:
|
|||
|
||||
# transform to [(beg,end,"word1"), ...]
|
||||
tsw = self.asr.ts_words(res)
|
||||
|
||||
# insert into HypothesisBuffer
|
||||
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
|
||||
o = self.transcript_buffer.flush()
|
||||
# Completed words
|
||||
self.commited.extend(o)
|
||||
completed = self.to_flush(o)
|
||||
completed = self.concatenate_tsw(o) # This will be returned at the end of the function
|
||||
logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
|
||||
the_rest = self.to_flush(self.transcript_buffer.complete())
|
||||
## The rest is incomplete
|
||||
the_rest = self.concatenate_tsw(self.transcript_buffer.complete())
|
||||
logger.debug(f"INCOMPLETE: {the_rest[2]}")
|
||||
|
||||
# there is a newly confirmed text
|
||||
|
||||
if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
|
||||
if (
|
||||
len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
|
||||
): # longer than this
|
||||
if self.buffer_trimming_way == "sentence":
|
||||
|
||||
self.chunk_completed_sentence(self.commited)
|
||||
|
||||
|
||||
|
||||
# TODO: new words in `completed` should not be reterned unless they form a sentence
|
||||
# TODO: only complete sentences should go to completed
|
||||
|
||||
|
||||
if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec :
|
||||
|
||||
logger.debug("chunking sentence")
|
||||
self.chunk_completed_sentence()
|
||||
|
||||
|
||||
else:
|
||||
logger.debug("not enough audio to trim as a sentence")
|
||||
|
||||
if self.buffer_trimming_way == "segment":
|
||||
s = self.buffer_trimming_sec # trim the completed segments longer than s,
|
||||
else:
|
||||
s = 30 # if the audio buffer is longer than 30s, trim it
|
||||
|
||||
if len(self.audio_buffer) / self.SAMPLING_RATE > s:
|
||||
if self.buffer_trimming_way == "sentence":
|
||||
logger.warning(f"Chunck segment after {self.buffer_trimming_sec} seconds!"
|
||||
" Even if no sentence was found!"
|
||||
)
|
||||
|
||||
|
||||
self.chunk_completed_segment(res)
|
||||
|
||||
|
||||
|
||||
# alternative: on any word
|
||||
# l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
|
||||
# let's find commited word that is less
|
||||
# k = len(self.commited)-1
|
||||
# while k>0 and self.commited[k][1] > l:
|
||||
# k -= 1
|
||||
# t = self.commited[k][1]
|
||||
logger.debug("chunking segment")
|
||||
# self.chunk_at(t)
|
||||
# alternative: on any word
|
||||
# l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
|
||||
# let's find commited word that is less
|
||||
# k = len(self.commited)-1
|
||||
# while k>0 and self.commited[k][1] > l:
|
||||
# k -= 1
|
||||
# t = self.commited[k][1]
|
||||
# self.chunk_at(t)
|
||||
|
||||
logger.debug(
|
||||
f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
|
||||
)
|
||||
return self.to_flush(o)
|
||||
|
||||
return completed
|
||||
|
||||
def chunk_completed_sentence(self):
|
||||
if self.commited == []:
|
||||
|
|
@ -214,7 +225,6 @@ class OnlineASRProcessor:
|
|||
sents = self.words_to_sentences(self.commited)
|
||||
|
||||
|
||||
|
||||
if len(sents) < 2:
|
||||
logger.debug(f"[Sentence-segmentation] no sentence segmented.")
|
||||
return
|
||||
|
|
@ -228,7 +238,7 @@ class OnlineASRProcessor:
|
|||
# we will continue with audio processing at this timestamp
|
||||
chunk_at = sents[-2][1]
|
||||
|
||||
logger.debug(f"[Sentence-segmentation]: sentence will be chunked at {chunk_at:2.2f}")
|
||||
|
||||
self.chunk_at(chunk_at)
|
||||
|
||||
def chunk_completed_segment(self, res):
|
||||
|
|
@ -239,7 +249,9 @@ class OnlineASRProcessor:
|
|||
|
||||
t = self.commited[-1][1]
|
||||
|
||||
if len(ends) > 1:
|
||||
if len(ends) <= 1:
|
||||
logger.debug(f"--- not enough segments to chunk (<=1 words)")
|
||||
else:
|
||||
|
||||
e = ends[-2] + self.buffer_time_offset
|
||||
while len(ends) > 2 and e > t:
|
||||
|
|
@ -250,16 +262,26 @@ class OnlineASRProcessor:
|
|||
self.chunk_at(e)
|
||||
else:
|
||||
logger.debug(f"--- last segment not within commited area")
|
||||
else:
|
||||
logger.debug(f"--- not enough segments to chunk")
|
||||
|
||||
|
||||
def chunk_at(self, time):
|
||||
"""trims the hypothesis and audio buffer at "time" """
|
||||
logger.debug(f"chunking at {time:2.2f}s")
|
||||
|
||||
logger.debug(
|
||||
f"len of audio buffer before chunking is: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s"
|
||||
)
|
||||
|
||||
|
||||
self.transcript_buffer.pop_commited(time)
|
||||
cut_seconds = time - self.buffer_time_offset
|
||||
self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
|
||||
self.buffer_time_offset = time
|
||||
|
||||
logger.debug(
|
||||
f"len of audio buffer is now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s"
|
||||
)
|
||||
|
||||
def words_to_sentences(self, words):
|
||||
"""Uses self.tokenize for sentence segmentation of words.
|
||||
Returns: [(beg,end,"sentence 1"),...]
|
||||
|
|
@ -292,14 +314,14 @@ class OnlineASRProcessor:
|
|||
Returns: the same format as self.process_iter()
|
||||
"""
|
||||
o = self.transcript_buffer.complete()
|
||||
f = self.to_flush(o)
|
||||
f = self.concatenate_tsw(o)
|
||||
logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2][0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
|
||||
self.buffer_time_offset += len(self.audio_buffer) / 16000
|
||||
return f
|
||||
|
||||
def to_flush(
|
||||
def concatenate_tsw(
|
||||
self,
|
||||
sents,
|
||||
tsw,
|
||||
sep=None,
|
||||
offset=0,
|
||||
):
|
||||
|
|
@ -308,13 +330,14 @@ class OnlineASRProcessor:
|
|||
# return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
|
||||
if sep is None:
|
||||
sep = self.asr.sep
|
||||
t = sep.join(s[2] for s in sents)
|
||||
if len(sents) == 0:
|
||||
|
||||
t = sep.join(s[2] for s in tsw)
|
||||
if len(tsw) == 0:
|
||||
b = None
|
||||
e = None
|
||||
else:
|
||||
b = offset + sents[0][0]
|
||||
e = offset + sents[-1][1]
|
||||
b = offset + tsw[0][0]
|
||||
e = offset + tsw[-1][1]
|
||||
return (b, e, t)
|
||||
|
||||
|
||||
|
|
@ -415,7 +438,7 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
|
|||
ret = self.online.process_iter()
|
||||
return ret
|
||||
else:
|
||||
print("no online update, only VAD", self.status, file=self.logfile)
|
||||
logger.debug("no online update, only VAD")
|
||||
return (None, None, "")
|
||||
|
||||
def finish(self):
|
||||
|
|
|
|||
|
|
@ -154,13 +154,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||
|
||||
full_transcription += trans
|
||||
if args.vac:
|
||||
buffer = online.online.to_flush(
|
||||
buffer = online.online.concatenate_tsw(
|
||||
online.online.transcript_buffer.buffer
|
||||
)[
|
||||
2
|
||||
] # We need to access the underlying online object to get the buffer
|
||||
else:
|
||||
buffer = online.to_flush(online.transcript_buffer.buffer)[2]
|
||||
buffer = online.concatenate_tsw(online.transcript_buffer.buffer)[2]
|
||||
if (
|
||||
buffer in full_transcription
|
||||
): # With VAC, the buffer is not updated until the next chunk is processed
|
||||
|
|
|
|||
Loading…
Reference in a new issue