fixes #261
Co-authored-by: yosagi <11404771+yosagi@users.noreply.github.com>
This commit is contained in:
parent
0c5365e7c6
commit
416dce7975
2 changed files with 47 additions and 4 deletions
|
|
@@ -167,7 +167,10 @@ class PaddedAlignAttWhisper:
|
||||||
self.inference.kv_cache = self.kv_cache
|
self.inference.kv_cache = self.kv_cache
|
||||||
|
|
||||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||||
|
|
||||||
|
# Tokens to carry over to next chunk for incomplete UTF-8 characters
|
||||||
|
self.pending_incomplete_tokens = []
|
||||||
|
|
||||||
def remove_hooks(self):
|
def remove_hooks(self):
|
||||||
for hook in self.l_hooks:
|
for hook in self.l_hooks:
|
||||||
hook.remove()
|
hook.remove()
|
||||||
|
|
@@ -261,6 +264,7 @@ class PaddedAlignAttWhisper:
|
||||||
self.segments = []
|
self.segments = []
|
||||||
self.log_segments += 1
|
self.log_segments += 1
|
||||||
|
|
||||||
|
self.pending_incomplete_tokens = []
|
||||||
|
|
||||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||||
if self.always_fire: return True
|
if self.always_fire: return True
|
||||||
|
|
@@ -562,6 +566,12 @@ class PaddedAlignAttWhisper:
|
||||||
|
|
||||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||||
|
|
||||||
|
# Prepend pending tokens from previous chunk if any
|
||||||
|
if self.pending_incomplete_tokens:
|
||||||
|
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
|
||||||
|
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
||||||
|
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
|
||||||
|
|
||||||
if fire_detected or is_last: #or punctuation_stop:
|
if fire_detected or is_last: #or punctuation_stop:
|
||||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||||
|
|
@@ -590,7 +600,14 @@ class PaddedAlignAttWhisper:
|
||||||
|
|
||||||
timestamped_words = []
|
timestamped_words = []
|
||||||
timestamp_idx = 0
|
timestamp_idx = 0
|
||||||
|
replacement_char = "\ufffd"
|
||||||
for word, word_tokens in zip(split_words, split_tokens):
|
for word, word_tokens in zip(split_words, split_tokens):
|
||||||
|
# Skip words containing incomplete UTF-8 from client output
|
||||||
|
if replacement_char in word:
|
||||||
|
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
|
||||||
|
timestamp_idx += len(word_tokens)
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||||
except:
|
except:
|
||||||
|
|
@@ -608,5 +625,11 @@ class PaddedAlignAttWhisper:
|
||||||
self.global_time_offset
|
self.global_time_offset
|
||||||
)
|
)
|
||||||
timestamped_words.append(timestamp_entry)
|
timestamped_words.append(timestamp_entry)
|
||||||
|
|
||||||
return timestamped_words
|
# Hold incomplete tokens for next chunk
|
||||||
|
self.pending_incomplete_tokens = []
|
||||||
|
if split_words and replacement_char in split_words[-1]:
|
||||||
|
self.pending_incomplete_tokens = split_tokens[-1]
|
||||||
|
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
|
||||||
|
|
||||||
|
return timestamped_words
|
||||||
|
|
|
||||||
|
|
@@ -7,6 +7,7 @@ class TokenBuffer:
|
||||||
self.prefix_token_ids = prefix_token_ids
|
self.prefix_token_ids = prefix_token_ids
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.pending_token_ids = []
|
||||||
|
|
||||||
def as_token_ids(self, tokenizer=None):
|
def as_token_ids(self, tokenizer=None):
|
||||||
|
|
||||||
|
|
@@ -64,7 +65,26 @@ class TokenBuffer:
|
||||||
def append_token_ids(self, token_ids):
|
def append_token_ids(self, token_ids):
|
||||||
tokenizer = self.tokenizer
|
tokenizer = self.tokenizer
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
assert tokenizer is not None, "Tokenizer is not set."
|
||||||
self.text += self.tokenizer.decode(token_ids)
|
|
||||||
|
all_tokens = self.pending_token_ids + token_ids
|
||||||
|
|
||||||
|
decoded = tokenizer.decode(all_tokens)
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
|
if replacement_char in decoded:
|
||||||
|
if len(all_tokens) > 1:
|
||||||
|
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
||||||
|
|
||||||
|
if replacement_char not in decoded_partial:
|
||||||
|
self.text += decoded_partial
|
||||||
|
self.pending_token_ids = [all_tokens[-1]]
|
||||||
|
else:
|
||||||
|
self.pending_token_ids = all_tokens
|
||||||
|
else:
|
||||||
|
self.pending_token_ids = all_tokens
|
||||||
|
else:
|
||||||
|
self.text += decoded
|
||||||
|
self.pending_token_ids = []
|
||||||
|
|
||||||
def as_split_word_tokens(self):
|
def as_split_word_tokens(self):
|
||||||
tokenizer = self.tokenizer
|
tokenizer = self.tokenizer
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue