Upgrade SimulStreaming Whisper core from version 20230918 to 20250625
This commit is contained in:
parent
9dcfb38967
commit
8e056cbdf2
11 changed files with 2085 additions and 641 deletions
2
setup.py
2
setup.py
|
|
@ -1,7 +1,7 @@
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
setup(
|
setup(
|
||||||
name="whisperlivekit",
|
name="whisperlivekit",
|
||||||
version="0.2.2",
|
version="0.2.4.dev0",
|
||||||
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import os
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
@ -15,8 +14,6 @@ from .model import ModelDimensions, Whisper
|
||||||
from .transcribe import transcribe
|
from .transcribe import transcribe
|
||||||
from .version import __version__
|
from .version import __version__
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||||
|
|
@ -74,7 +71,6 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||||
)
|
)
|
||||||
|
|
||||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||||
logger.info(f'Downloading model weights to {download_target}')
|
|
||||||
with tqdm(
|
with tqdm(
|
||||||
total=int(source.info().get("Content-Length")),
|
total=int(source.info().get("Content-Length")),
|
||||||
ncols=80,
|
ncols=80,
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,9 @@ def detect_language(
|
||||||
list of dictionaries containing the probability distribution over all languages.
|
list of dictionaries containing the probability distribution over all languages.
|
||||||
"""
|
"""
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
tokenizer = get_tokenizer(model.is_multilingual)
|
tokenizer = get_tokenizer(
|
||||||
|
model.is_multilingual, num_languages=model.num_languages
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
tokenizer.language is None
|
tokenizer.language is None
|
||||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||||
|
|
@ -111,9 +113,6 @@ class DecodingOptions:
|
||||||
# implementation details
|
# implementation details
|
||||||
fp16: bool = True # use fp16 for most of the calculation
|
fp16: bool = True # use fp16 for most of the calculation
|
||||||
|
|
||||||
# streaming
|
|
||||||
add_sot: Optional[bool] = True
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class DecodingResult:
|
class DecodingResult:
|
||||||
|
|
@ -513,19 +512,17 @@ class DecodingTask:
|
||||||
logit_filters: List[LogitFilter]
|
logit_filters: List[LogitFilter]
|
||||||
|
|
||||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||||
self.options: DecodingOptions = self._verify_options(options)
|
self.model = model
|
||||||
if self.options.fp16:
|
|
||||||
self.model = model.half()
|
|
||||||
else:
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
language = options.language or "en"
|
language = options.language or "en"
|
||||||
tokenizer = get_tokenizer(
|
tokenizer = get_tokenizer(
|
||||||
model.is_multilingual, language=language, task=options.task
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=options.task,
|
||||||
)
|
)
|
||||||
self.tokenizer: Tokenizer = tokenizer
|
self.tokenizer: Tokenizer = tokenizer
|
||||||
|
self.options: DecodingOptions = self._verify_options(options)
|
||||||
# print(self.options)
|
|
||||||
|
|
||||||
self.n_group: int = options.beam_size or options.best_of or 1
|
self.n_group: int = options.beam_size or options.best_of or 1
|
||||||
self.n_ctx: int = model.dims.n_text_ctx
|
self.n_ctx: int = model.dims.n_text_ctx
|
||||||
|
|
@ -589,7 +586,7 @@ class DecodingTask:
|
||||||
|
|
||||||
def _get_initial_tokens(self) -> Tuple[int]:
|
def _get_initial_tokens(self) -> Tuple[int]:
|
||||||
tokens = list(self.sot_sequence)
|
tokens = list(self.sot_sequence)
|
||||||
# print("prefix", prefix)
|
|
||||||
if prefix := self.options.prefix:
|
if prefix := self.options.prefix:
|
||||||
prefix_tokens = (
|
prefix_tokens = (
|
||||||
self.tokenizer.encode(" " + prefix.strip())
|
self.tokenizer.encode(" " + prefix.strip())
|
||||||
|
|
@ -607,15 +604,12 @@ class DecodingTask:
|
||||||
if isinstance(prompt, str)
|
if isinstance(prompt, str)
|
||||||
else prompt
|
else prompt
|
||||||
)
|
)
|
||||||
# if self.options.add_sot:
|
|
||||||
tokens = (
|
tokens = (
|
||||||
[self.tokenizer.sot_prev]
|
[self.tokenizer.sot_prev]
|
||||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||||
+ tokens
|
+ tokens
|
||||||
)
|
)
|
||||||
#else:
|
|
||||||
# tokens = ([self.tokenizer.sot_prev] + tokens + prompt_tokens[-(self.n_ctx // 2 - 1) :])
|
|
||||||
# print("return", tokens)
|
|
||||||
return tuple(tokens)
|
return tuple(tokens)
|
||||||
|
|
||||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||||
|
|
@ -663,7 +657,7 @@ class DecodingTask:
|
||||||
if audio_features.dtype != (
|
if audio_features.dtype != (
|
||||||
torch.float16 if self.options.fp16 else torch.float32
|
torch.float16 if self.options.fp16 else torch.float32
|
||||||
):
|
):
|
||||||
raise TypeError(
|
return TypeError(
|
||||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -689,10 +683,9 @@ class DecodingTask:
|
||||||
no_speech_probs = [np.nan] * n_batch
|
no_speech_probs = [np.nan] * n_batch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for i in range(self.sample_len): # 最多循环448次
|
for i in range(self.sample_len):
|
||||||
# print("in decode main loop", i , tokens[0].tolist())
|
|
||||||
logits = self.inference.logits(tokens, audio_features)
|
logits = self.inference.logits(tokens, audio_features)
|
||||||
# print(logits)
|
|
||||||
if (
|
if (
|
||||||
i == 0 and self.tokenizer.no_speech is not None
|
i == 0 and self.tokenizer.no_speech is not None
|
||||||
): # save no_speech_probs
|
): # save no_speech_probs
|
||||||
|
|
@ -724,7 +717,7 @@ class DecodingTask:
|
||||||
|
|
||||||
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
||||||
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||||
# print("initial_tokens", self.initial_tokens)
|
|
||||||
# detect language if requested, overwriting the language token
|
# detect language if requested, overwriting the language token
|
||||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||||
if self.options.task == "lang_id":
|
if self.options.task == "lang_id":
|
||||||
|
|
|
||||||
|
|
@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
|
||||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||||
"""
|
"""
|
||||||
return "".join(
|
return "".join(
|
||||||
c
|
(
|
||||||
if c in keep
|
c
|
||||||
else ADDITIONAL_DIACRITICS[c]
|
if c in keep
|
||||||
if c in ADDITIONAL_DIACRITICS
|
else (
|
||||||
else ""
|
ADDITIONAL_DIACRITICS[c]
|
||||||
if unicodedata.category(c) == "Mn"
|
if c in ADDITIONAL_DIACRITICS
|
||||||
else " "
|
else (
|
||||||
if unicodedata.category(c)[0] in "MSP"
|
""
|
||||||
else c
|
if unicodedata.category(c) == "Mn"
|
||||||
|
else " " if unicodedata.category(c)[0] in "MSP" else c
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
for c in unicodedata.normalize("NFKD", s)
|
for c in unicodedata.normalize("NFKD", s)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
1741
whisperlivekit/simul_whisper/whisper/normalizers/english.json
Normal file
1741
whisperlivekit/simul_whisper/whisper/normalizers/english.json
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -56,9 +56,8 @@ def median_filter(x: torch.Tensor, filter_width: int):
|
||||||
|
|
||||||
@numba.jit(nopython=True)
|
@numba.jit(nopython=True)
|
||||||
def backtrace(trace: np.ndarray):
|
def backtrace(trace: np.ndarray):
|
||||||
i = trace.shape[0] - 1 # trace: (N+1, M+1), i=N
|
i = trace.shape[0] - 1
|
||||||
j = trace.shape[1] - 1 # j=M
|
j = trace.shape[1] - 1
|
||||||
# 边界点其实无意义?
|
|
||||||
trace[0, :] = 2
|
trace[0, :] = 2
|
||||||
trace[:, 0] = 1
|
trace[:, 0] = 1
|
||||||
|
|
||||||
|
|
@ -83,8 +82,8 @@ def backtrace(trace: np.ndarray):
|
||||||
@numba.jit(nopython=True, parallel=True)
|
@numba.jit(nopython=True, parallel=True)
|
||||||
def dtw_cpu(x: np.ndarray):
|
def dtw_cpu(x: np.ndarray):
|
||||||
N, M = x.shape
|
N, M = x.shape
|
||||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf # cost: x[0, 0]到x[i-1, j-1]的最小代价
|
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
||||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32) # trace:
|
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
||||||
|
|
||||||
cost[0, 0] = 0
|
cost[0, 0] = 0
|
||||||
for j in range(1, M + 1):
|
for j in range(1, M + 1):
|
||||||
|
|
@ -118,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||||
x_skew = x_skew.T.contiguous()
|
x_skew = x_skew.T.contiguous()
|
||||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||||
cost[0, 0] = 0
|
cost[0, 0] = 0
|
||||||
cost = cost.cuda()
|
cost = cost.to(x.device)
|
||||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||||
|
|
||||||
dtw_kernel[(1,)](
|
dtw_kernel[(1,)](
|
||||||
|
|
@ -192,21 +191,19 @@ def find_alignment(
|
||||||
for i, block in enumerate(model.decoder.blocks)
|
for i, block in enumerate(model.decoder.blocks)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 进行前传,获得token概率
|
from .model import disable_sdpa
|
||||||
with torch.no_grad():
|
|
||||||
|
with torch.no_grad(), disable_sdpa():
|
||||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||||
token_probs = sampled_logits.softmax(dim=-1)
|
token_probs = sampled_logits.softmax(dim=-1)
|
||||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||||
text_token_probs = text_token_probs.tolist()
|
text_token_probs = text_token_probs.tolist()
|
||||||
|
|
||||||
# 移除钩子
|
|
||||||
for hook in hooks:
|
for hook in hooks:
|
||||||
hook.remove()
|
hook.remove()
|
||||||
|
|
||||||
# heads * tokens * frames
|
# heads * tokens * frames
|
||||||
# print(model.alignment_heads)
|
|
||||||
# exit(0)
|
|
||||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||||
weights = weights[:, :, : num_frames // 2]
|
weights = weights[:, :, : num_frames // 2]
|
||||||
weights = (weights * qk_scale).softmax(dim=-1)
|
weights = (weights * qk_scale).softmax(dim=-1)
|
||||||
|
|
@ -215,18 +212,9 @@ def find_alignment(
|
||||||
weights = median_filter(weights, medfilt_width)
|
weights = median_filter(weights, medfilt_width)
|
||||||
|
|
||||||
matrix = weights.mean(axis=0)
|
matrix = weights.mean(axis=0)
|
||||||
print("attention", matrix.shape, matrix[:5, :5])
|
|
||||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||||
print("attention", matrix.shape, matrix[:5, :5])
|
|
||||||
text_indices, time_indices = dtw(-matrix)
|
text_indices, time_indices = dtw(-matrix)
|
||||||
|
|
||||||
print("num_frames", num_frames)
|
|
||||||
print("attention", matrix.shape, matrix[:5, :5])
|
|
||||||
print("text_indices", text_indices)
|
|
||||||
print("time", time_indices)
|
|
||||||
print("text_tokens", text_tokens, tokenizer.decode(text_tokens), len(text_tokens))
|
|
||||||
print("eot", tokenizer.eot)
|
|
||||||
|
|
||||||
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||||
if len(word_tokens) <= 1:
|
if len(word_tokens) <= 1:
|
||||||
# return on eot only
|
# return on eot only
|
||||||
|
|
@ -238,9 +226,7 @@ def find_alignment(
|
||||||
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||||
|
|
||||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||||
# print("jumps", jumps, jumps.shape)
|
|
||||||
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||||
# print("jump_times", jump_times)
|
|
||||||
start_times = jump_times[word_boundaries[:-1]]
|
start_times = jump_times[word_boundaries[:-1]]
|
||||||
end_times = jump_times[word_boundaries[1:]]
|
end_times = jump_times[word_boundaries[1:]]
|
||||||
word_probabilities = [
|
word_probabilities = [
|
||||||
|
|
@ -315,6 +301,7 @@ def add_word_timestamps(
|
||||||
word_durations = np.array([t.end - t.start for t in alignment])
|
word_durations = np.array([t.end - t.start for t in alignment])
|
||||||
word_durations = word_durations[word_durations.nonzero()]
|
word_durations = word_durations[word_durations.nonzero()]
|
||||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||||
|
median_duration = min(0.7, float(median_duration))
|
||||||
max_duration = median_duration * 2
|
max_duration = median_duration * 2
|
||||||
|
|
||||||
# hack: truncate long words at sentence boundaries.
|
# hack: truncate long words at sentence boundaries.
|
||||||
|
|
|
||||||
|
|
@ -1,501 +0,0 @@
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from whisper.audio import (
|
|
||||||
FRAMES_PER_SECOND,
|
|
||||||
HOP_LENGTH,
|
|
||||||
N_FRAMES,
|
|
||||||
N_SAMPLES,
|
|
||||||
SAMPLE_RATE,
|
|
||||||
log_mel_spectrogram,
|
|
||||||
pad_or_trim,
|
|
||||||
)
|
|
||||||
from whisper.decoding import DecodingOptions, DecodingResult
|
|
||||||
from whisper.timing import add_word_timestamps
|
|
||||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
|
||||||
from whisper.utils import (
|
|
||||||
exact_div,
|
|
||||||
format_timestamp,
|
|
||||||
get_writer,
|
|
||||||
make_safe,
|
|
||||||
optional_float,
|
|
||||||
optional_int,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from whisper.model import Whisper
|
|
||||||
|
|
||||||
|
|
||||||
def transcribe(
|
|
||||||
model: "Whisper",
|
|
||||||
audio: Union[str, np.ndarray, torch.Tensor],
|
|
||||||
*,
|
|
||||||
verbose: Optional[bool] = None,
|
|
||||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
|
||||||
compression_ratio_threshold: Optional[float] = 2.4,
|
|
||||||
logprob_threshold: Optional[float] = -1.0,
|
|
||||||
no_speech_threshold: Optional[float] = 0.6,
|
|
||||||
condition_on_previous_text: bool = True,
|
|
||||||
initial_prompt: Optional[str] = None,
|
|
||||||
word_timestamps: bool = False,
|
|
||||||
prepend_punctuations: str = "\"'“¿([{-",
|
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
|
||||||
**decode_options,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Transcribe an audio file using Whisper
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model: Whisper
|
|
||||||
The Whisper model instance
|
|
||||||
|
|
||||||
audio: Union[str, np.ndarray, torch.Tensor]
|
|
||||||
The path to the audio file to open, or the audio waveform
|
|
||||||
|
|
||||||
verbose: bool
|
|
||||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
|
||||||
If False, displays minimal details. If None, does not display anything
|
|
||||||
|
|
||||||
temperature: Union[float, Tuple[float, ...]]
|
|
||||||
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
|
||||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
|
||||||
|
|
||||||
compression_ratio_threshold: float
|
|
||||||
If the gzip compression ratio is above this value, treat as failed
|
|
||||||
|
|
||||||
logprob_threshold: float
|
|
||||||
If the average log probability over sampled tokens is below this value, treat as failed
|
|
||||||
|
|
||||||
no_speech_threshold: float
|
|
||||||
If the no_speech probability is higher than this value AND the average log probability
|
|
||||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
|
||||||
|
|
||||||
condition_on_previous_text: bool
|
|
||||||
if True, the previous output of the model is provided as a prompt for the next window;
|
|
||||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
|
||||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
|
||||||
|
|
||||||
word_timestamps: bool
|
|
||||||
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
|
||||||
and include the timestamps for each word in each segment.
|
|
||||||
|
|
||||||
prepend_punctuations: str
|
|
||||||
If word_timestamps is True, merge these punctuation symbols with the next word
|
|
||||||
|
|
||||||
append_punctuations: str
|
|
||||||
If word_timestamps is True, merge these punctuation symbols with the previous word
|
|
||||||
|
|
||||||
initial_prompt: Optional[str]
|
|
||||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
|
||||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
|
||||||
to make it more likely to predict those word correctly.
|
|
||||||
|
|
||||||
decode_options: dict
|
|
||||||
Keyword arguments to construct `DecodingOptions` instances
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
|
||||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
|
||||||
"""
|
|
||||||
# print("HACKED")
|
|
||||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
|
||||||
if model.device == torch.device("cpu"):
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
|
||||||
if dtype == torch.float16:
|
|
||||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
decode_options["fp16"] = False
|
|
||||||
|
|
||||||
# Pad 30-seconds of silence to the input audio, for slicing
|
|
||||||
mel = log_mel_spectrogram(audio, padding=0) # log_mel_spectrogram(audio, padding=N_SAMPLES) # 添加16000*30 = 480000个点
|
|
||||||
# mel = pad_or_trim(mel, 3000)
|
|
||||||
content_frames = mel.shape[-1] # - N_FRAMES # 对应3000帧;真正有内容的是去掉尾部3000的那些数据
|
|
||||||
|
|
||||||
# 判断语种
|
|
||||||
if decode_options.get("language", None) is None:
|
|
||||||
# 如果是单语种模型,直接设成英文
|
|
||||||
if not model.is_multilingual:
|
|
||||||
decode_options["language"] = "en"
|
|
||||||
# 否则需要前传一次
|
|
||||||
else:
|
|
||||||
if verbose:
|
|
||||||
print(
|
|
||||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
|
||||||
)
|
|
||||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
|
||||||
# print(mel_segment.shape)
|
|
||||||
_, probs = model.detect_language(mel_segment)
|
|
||||||
decode_options["language"] = max(probs, key=probs.get)
|
|
||||||
if verbose is not None:
|
|
||||||
print(
|
|
||||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
language: str = decode_options["language"]
|
|
||||||
task: str = decode_options.get("task", "transcribe")
|
|
||||||
# 输出编码器
|
|
||||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
|
||||||
|
|
||||||
# 词级别时间戳
|
|
||||||
if word_timestamps and task == "translate":
|
|
||||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
|
||||||
|
|
||||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
|
||||||
temperatures = (
|
|
||||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
|
||||||
)
|
|
||||||
decode_result = None
|
|
||||||
|
|
||||||
for t in temperatures:
|
|
||||||
kwargs = {**decode_options}
|
|
||||||
if t > 0:
|
|
||||||
# disable beam_size and patience when t > 0
|
|
||||||
kwargs.pop("beam_size", None)
|
|
||||||
kwargs.pop("patience", None)
|
|
||||||
else:
|
|
||||||
# disable best_of when t == 0
|
|
||||||
kwargs.pop("best_of", None)
|
|
||||||
|
|
||||||
options = DecodingOptions(**kwargs, temperature=t)
|
|
||||||
decode_result = model.decode(segment, options)
|
|
||||||
|
|
||||||
# 几种解码可能失败的情况。这些情况下会重复解码
|
|
||||||
# 感觉是一种KnowHow的东西 或许ChatGPT里有不少这种trick
|
|
||||||
needs_fallback = False
|
|
||||||
if (
|
|
||||||
compression_ratio_threshold is not None
|
|
||||||
and decode_result.compression_ratio > compression_ratio_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = True # too repetitive
|
|
||||||
if (
|
|
||||||
logprob_threshold is not None
|
|
||||||
and decode_result.avg_logprob < logprob_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = True # average log probability is too low
|
|
||||||
if (
|
|
||||||
no_speech_threshold is not None
|
|
||||||
and decode_result.no_speech_prob > no_speech_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = False # silence
|
|
||||||
if not needs_fallback:
|
|
||||||
break
|
|
||||||
# print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
|
|
||||||
# t,
|
|
||||||
# decode_result.compression_ratio, compression_ratio_threshold,
|
|
||||||
# -decode_result.avg_logprob, -logprob_threshold,
|
|
||||||
# decode_result.no_speech_prob, no_speech_threshold
|
|
||||||
# ))
|
|
||||||
|
|
||||||
return decode_result
|
|
||||||
|
|
||||||
seek = 0
|
|
||||||
input_stride = exact_div(
|
|
||||||
N_FRAMES, model.dims.n_audio_ctx
|
|
||||||
) # mel frames per output token: 2
|
|
||||||
# 这里output token指的应该是CNN输出的那个东西
|
|
||||||
|
|
||||||
time_precision = (
|
|
||||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
|
||||||
) # time per output token: 0.02 (seconds)
|
|
||||||
all_tokens = []
|
|
||||||
all_segments = []
|
|
||||||
prompt_reset_since = 0
|
|
||||||
|
|
||||||
if initial_prompt is not None:
|
|
||||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
|
||||||
else:
|
|
||||||
initial_prompt_tokens = []
|
|
||||||
|
|
||||||
def new_segment(
|
|
||||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
|
||||||
):
|
|
||||||
tokens = tokens.tolist()
|
|
||||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
|
||||||
return {
|
|
||||||
"seek": seek,
|
|
||||||
"start": start,
|
|
||||||
"end": end,
|
|
||||||
"text": tokenizer.decode(text_tokens),
|
|
||||||
"tokens": tokens,
|
|
||||||
"temperature": result.temperature,
|
|
||||||
"avg_logprob": result.avg_logprob,
|
|
||||||
"compression_ratio": result.compression_ratio,
|
|
||||||
"no_speech_prob": result.no_speech_prob,
|
|
||||||
}
|
|
||||||
|
|
||||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
|
||||||
with tqdm.tqdm(
|
|
||||||
total=content_frames, unit="frames", disable=verbose is not False
|
|
||||||
) as pbar:
|
|
||||||
last_speech_timestamp = 0.0
|
|
||||||
while seek < content_frames: # seek:标记mel频谱当前帧的位置 直接跳过Padding上的部分
|
|
||||||
# print("seek segments", seek, content_frames)
|
|
||||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) # 本片段的开始时间
|
|
||||||
# mel_segment = mel[:, seek : seek + N_FRAMES] # 获得当前片段的数据
|
|
||||||
mel_segment = mel[:, seek:]
|
|
||||||
segment_size = min(N_FRAMES, content_frames - seek) # segment_size: 排除padding的真的长度。content_frames:有内容的段的真正长度 如果不够N_FRAMES的话就会截断
|
|
||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE # 当前片段的时长
|
|
||||||
mel_segment = mel_segment.to(model.device).to(dtype) # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) # 补到mel_segment帧
|
|
||||||
|
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
|
||||||
tokens = torch.tensor(result.tokens)
|
|
||||||
|
|
||||||
# 跳过静音部分
|
|
||||||
if no_speech_threshold is not None:
|
|
||||||
# no voice activity check
|
|
||||||
should_skip = result.no_speech_prob > no_speech_threshold
|
|
||||||
if (
|
|
||||||
logprob_threshold is not None
|
|
||||||
and result.avg_logprob > logprob_threshold
|
|
||||||
):
|
|
||||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
|
||||||
should_skip = False
|
|
||||||
|
|
||||||
if should_skip:
|
|
||||||
seek += segment_size # fast-forward to the next segment boundary
|
|
||||||
continue
|
|
||||||
|
|
||||||
previous_seek = seek
|
|
||||||
current_segments = []
|
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) # timestamp begin是<|0.00|>的token;bos比文字token大,eos的值比bos还大,所以是ge
|
|
||||||
timestamp_tokens[-1] = False
|
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] # 如果最后是[False,True]:本段里一个句子结束了
|
|
||||||
|
|
||||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
|
||||||
# torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
|
|
||||||
# timestamp_token就是个一维向量吧 那为啥不直接nonzero
|
|
||||||
# 如果有两个连续的时间戳 这个会是一个一维tensor 是这两个连续时间戳的结尾位置
|
|
||||||
# 多个的话指向第二个 那如果有三个怎么办?
|
|
||||||
# 否则是个0维tensor
|
|
||||||
|
|
||||||
consecutive.add_(1) # 0维tensor+1还是0维 哪儿找的这些edge cases js是吧
|
|
||||||
if len(consecutive) > 0:
|
|
||||||
# if the output contains two consecutive timestamp tokens
|
|
||||||
slices = consecutive.tolist()
|
|
||||||
if single_timestamp_ending:
|
|
||||||
slices.append(len(tokens)) # 把最后一段的结尾也加进去
|
|
||||||
# print("many sentenses", consecutive)
|
|
||||||
last_slice = 0
|
|
||||||
for current_slice in slices:
|
|
||||||
sliced_tokens = tokens[last_slice:current_slice]
|
|
||||||
# 看起来语音开始帧、语音结束帧的位置会被编码到start_timestamp中
|
|
||||||
start_timestamp_pos = (
|
|
||||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
end_timestamp_pos = (
|
|
||||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
# 获取一个新的语音段
|
|
||||||
current_segments.append(
|
|
||||||
new_segment(
|
|
||||||
start=time_offset + start_timestamp_pos * time_precision,
|
|
||||||
end=time_offset + end_timestamp_pos * time_precision,
|
|
||||||
tokens=sliced_tokens,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
last_slice = current_slice
|
|
||||||
|
|
||||||
if single_timestamp_ending:
|
|
||||||
# single timestamp at the end means no speech after the last timestamp.
|
|
||||||
seek += segment_size
|
|
||||||
else:
|
|
||||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
|
||||||
# 如果语音尚未结束,那么seek变为上一个结束的语段的位置
|
|
||||||
# 换句话说就是针对30s长的chunk的语音设计的
|
|
||||||
last_timestamp_pos = (
|
|
||||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
seek += last_timestamp_pos * input_stride
|
|
||||||
else:
|
|
||||||
duration = segment_duration
|
|
||||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
|
||||||
# print(timestamps)
|
|
||||||
if (
|
|
||||||
len(timestamps) > 0
|
|
||||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
|
||||||
):
|
|
||||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
|
||||||
# 取最后一个;假设要么有一个结束的time stamp;要么有一对儿?
|
|
||||||
# 如果里面只有一个开始的timestamp 似乎后面的东西都会被丢掉?
|
|
||||||
last_timestamp_pos = (
|
|
||||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
duration = last_timestamp_pos * time_precision
|
|
||||||
|
|
||||||
current_segments.append(
|
|
||||||
new_segment(
|
|
||||||
start=time_offset,
|
|
||||||
end=time_offset + duration,
|
|
||||||
tokens=tokens,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seek += segment_size
|
|
||||||
|
|
||||||
# 每个token有自己的时间戳
|
|
||||||
if word_timestamps:
|
|
||||||
add_word_timestamps(
|
|
||||||
segments=current_segments,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
mel=mel_segment,
|
|
||||||
num_frames=segment_size,
|
|
||||||
prepend_punctuations=prepend_punctuations,
|
|
||||||
append_punctuations=append_punctuations,
|
|
||||||
last_speech_timestamp=last_speech_timestamp,
|
|
||||||
)
|
|
||||||
word_end_timestamps = [
|
|
||||||
w["end"] for s in current_segments for w in s["words"]
|
|
||||||
]
|
|
||||||
if len(word_end_timestamps) > 0:
|
|
||||||
last_speech_timestamp = word_end_timestamps[-1]
|
|
||||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
|
||||||
seek_shift = round(
|
|
||||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
|
||||||
)
|
|
||||||
if seek_shift > 0:
|
|
||||||
seek = previous_seek + seek_shift
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
for segment in current_segments:
|
|
||||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
|
||||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
|
||||||
print(make_safe(line))
|
|
||||||
|
|
||||||
# if a segment is instantaneous or does not contain text, clear it
|
|
||||||
for i, segment in enumerate(current_segments):
|
|
||||||
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
|
||||||
segment["text"] = ""
|
|
||||||
segment["tokens"] = []
|
|
||||||
segment["words"] = []
|
|
||||||
|
|
||||||
# 更新结果
|
|
||||||
all_segments.extend(
|
|
||||||
[
|
|
||||||
{"id": i, **segment}
|
|
||||||
for i, segment in enumerate(
|
|
||||||
current_segments, start=len(all_segments)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
all_tokens.extend(
|
|
||||||
[token for segment in current_segments for token in segment["tokens"]]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not condition_on_previous_text or result.temperature > 0.5:
|
|
||||||
# do not feed the prompt tokens if a high temperature was used
|
|
||||||
prompt_reset_since = len(all_tokens)
|
|
||||||
|
|
||||||
# update progress bar
|
|
||||||
pbar.update(min(content_frames, seek) - previous_seek)
|
|
||||||
|
|
||||||
# print("太长了")
|
|
||||||
# break
|
|
||||||
|
|
||||||
return dict(
|
|
||||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
|
||||||
segments=all_segments,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def cli():
|
|
||||||
from . import available_models
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
|
||||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
|
||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
|
||||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
|
||||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
|
||||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
|
||||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
|
||||||
|
|
||||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
|
||||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
|
||||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
|
||||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
|
||||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
|
||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
|
||||||
|
|
||||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
|
||||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
|
||||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
|
||||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
|
||||||
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
|
||||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
|
||||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
|
||||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
|
||||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
|
||||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
|
||||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
args = parser.parse_args().__dict__
|
|
||||||
model_name: str = args.pop("model")
|
|
||||||
model_dir: str = args.pop("model_dir")
|
|
||||||
output_dir: str = args.pop("output_dir")
|
|
||||||
output_format: str = args.pop("output_format")
|
|
||||||
device: str = args.pop("device")
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
|
||||||
if args["language"] is not None:
|
|
||||||
warnings.warn(
|
|
||||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
|
||||||
)
|
|
||||||
args["language"] = "en"
|
|
||||||
|
|
||||||
temperature = args.pop("temperature")
|
|
||||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
|
||||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
|
||||||
else:
|
|
||||||
temperature = [temperature]
|
|
||||||
|
|
||||||
if (threads := args.pop("threads")) > 0:
|
|
||||||
torch.set_num_threads(threads)
|
|
||||||
|
|
||||||
from . import load_model
|
|
||||||
|
|
||||||
model = load_model(model_name, device=device, download_root=model_dir)
|
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
|
||||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
|
||||||
if not args["word_timestamps"]:
|
|
||||||
for option in word_options:
|
|
||||||
if args[option]:
|
|
||||||
parser.error(f"--{option} requires --word_timestamps True")
|
|
||||||
if args["max_line_count"] and not args["max_line_width"]:
|
|
||||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
|
||||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
|
||||||
for audio_path in args.pop("audio"):
|
|
||||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
|
||||||
writer(result, audio_path, writer_args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
cli()
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -22,6 +23,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||||
from .utils import (
|
from .utils import (
|
||||||
exact_div,
|
exact_div,
|
||||||
format_timestamp,
|
format_timestamp,
|
||||||
|
get_end,
|
||||||
get_writer,
|
get_writer,
|
||||||
make_safe,
|
make_safe,
|
||||||
optional_float,
|
optional_float,
|
||||||
|
|
@ -44,9 +46,12 @@ def transcribe(
|
||||||
no_speech_threshold: Optional[float] = 0.6,
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
condition_on_previous_text: bool = True,
|
condition_on_previous_text: bool = True,
|
||||||
initial_prompt: Optional[str] = None,
|
initial_prompt: Optional[str] = None,
|
||||||
|
carry_initial_prompt: bool = False,
|
||||||
word_timestamps: bool = False,
|
word_timestamps: bool = False,
|
||||||
prepend_punctuations: str = "\"'“¿([{-",
|
prepend_punctuations: str = "\"'“¿([{-",
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
|
clip_timestamps: Union[str, List[float]] = "0",
|
||||||
|
hallucination_silence_threshold: Optional[float] = None,
|
||||||
**decode_options,
|
**decode_options,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -98,15 +103,27 @@ def transcribe(
|
||||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||||
to make it more likely to predict those word correctly.
|
to make it more likely to predict those word correctly.
|
||||||
|
|
||||||
|
carry_initial_prompt: bool
|
||||||
|
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
|
||||||
|
`decode()` call. If there is not enough context space at the start of the prompt, it is
|
||||||
|
left-sliced to make space.
|
||||||
|
|
||||||
decode_options: dict
|
decode_options: dict
|
||||||
Keyword arguments to construct `DecodingOptions` instances
|
Keyword arguments to construct `DecodingOptions` instances
|
||||||
|
|
||||||
|
clip_timestamps: Union[str, List[float]]
|
||||||
|
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
||||||
|
The last end timestamp defaults to the end of the file.
|
||||||
|
|
||||||
|
hallucination_silence_threshold: Optional[float]
|
||||||
|
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||||
|
when a possible hallucination is detected
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||||
"""
|
"""
|
||||||
# print("transcribe")
|
|
||||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||||
if model.device == torch.device("cpu"):
|
if model.device == torch.device("cpu"):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|
@ -119,8 +136,9 @@ def transcribe(
|
||||||
decode_options["fp16"] = False
|
decode_options["fp16"] = False
|
||||||
|
|
||||||
# Pad 30-seconds of silence to the input audio, for slicing
|
# Pad 30-seconds of silence to the input audio, for slicing
|
||||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||||
content_frames = mel.shape[-1] - N_FRAMES
|
content_frames = mel.shape[-1] - N_FRAMES
|
||||||
|
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
||||||
|
|
||||||
if decode_options.get("language", None) is None:
|
if decode_options.get("language", None) is None:
|
||||||
if not model.is_multilingual:
|
if not model.is_multilingual:
|
||||||
|
|
@ -131,7 +149,6 @@ def transcribe(
|
||||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||||
)
|
)
|
||||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||||
# print(mel_segment.shape)
|
|
||||||
_, probs = model.detect_language(mel_segment)
|
_, probs = model.detect_language(mel_segment)
|
||||||
decode_options["language"] = max(probs, key=probs.get)
|
decode_options["language"] = max(probs, key=probs.get)
|
||||||
if verbose is not None:
|
if verbose is not None:
|
||||||
|
|
@ -141,7 +158,25 @@ def transcribe(
|
||||||
|
|
||||||
language: str = decode_options["language"]
|
language: str = decode_options["language"]
|
||||||
task: str = decode_options.get("task", "transcribe")
|
task: str = decode_options.get("task", "transcribe")
|
||||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
tokenizer = get_tokenizer(
|
||||||
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(clip_timestamps, str):
|
||||||
|
clip_timestamps = [
|
||||||
|
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||||
|
]
|
||||||
|
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
||||||
|
if len(seek_points) == 0:
|
||||||
|
seek_points.append(0)
|
||||||
|
if len(seek_points) % 2 == 1:
|
||||||
|
seek_points.append(content_frames)
|
||||||
|
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
||||||
|
|
||||||
|
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
||||||
|
|
||||||
if word_timestamps and task == "translate":
|
if word_timestamps and task == "translate":
|
||||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||||
|
|
@ -179,6 +214,8 @@ def transcribe(
|
||||||
if (
|
if (
|
||||||
no_speech_threshold is not None
|
no_speech_threshold is not None
|
||||||
and decode_result.no_speech_prob > no_speech_threshold
|
and decode_result.no_speech_prob > no_speech_threshold
|
||||||
|
and logprob_threshold is not None
|
||||||
|
and decode_result.avg_logprob < logprob_threshold
|
||||||
):
|
):
|
||||||
needs_fallback = False # silence
|
needs_fallback = False # silence
|
||||||
if not needs_fallback:
|
if not needs_fallback:
|
||||||
|
|
@ -186,7 +223,8 @@ def transcribe(
|
||||||
|
|
||||||
return decode_result
|
return decode_result
|
||||||
|
|
||||||
seek = 0
|
clip_idx = 0
|
||||||
|
seek = seek_clips[clip_idx][0]
|
||||||
input_stride = exact_div(
|
input_stride = exact_div(
|
||||||
N_FRAMES, model.dims.n_audio_ctx
|
N_FRAMES, model.dims.n_audio_ctx
|
||||||
) # mel frames per output token: 2
|
) # mel frames per output token: 2
|
||||||
|
|
@ -197,9 +235,11 @@ def transcribe(
|
||||||
all_segments = []
|
all_segments = []
|
||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
|
|
||||||
|
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
||||||
if initial_prompt is not None:
|
if initial_prompt is not None:
|
||||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
|
remaining_prompt_length -= len(initial_prompt_tokens)
|
||||||
else:
|
else:
|
||||||
initial_prompt_tokens = []
|
initial_prompt_tokens = []
|
||||||
|
|
||||||
|
|
@ -225,16 +265,33 @@ def transcribe(
|
||||||
total=content_frames, unit="frames", disable=verbose is not False
|
total=content_frames, unit="frames", disable=verbose is not False
|
||||||
) as pbar:
|
) as pbar:
|
||||||
last_speech_timestamp = 0.0
|
last_speech_timestamp = 0.0
|
||||||
while seek < content_frames:
|
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||||
|
# A later commit should turn this into a simpler nested loop.
|
||||||
|
# for seek_clip_start, seek_clip_end in seek_clips:
|
||||||
|
# while seek < seek_clip_end
|
||||||
|
while clip_idx < len(seek_clips):
|
||||||
|
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
||||||
|
if seek < seek_clip_start:
|
||||||
|
seek = seek_clip_start
|
||||||
|
if seek >= seek_clip_end:
|
||||||
|
clip_idx += 1
|
||||||
|
if clip_idx < len(seek_clips):
|
||||||
|
seek = seek_clips[clip_idx][0]
|
||||||
|
continue
|
||||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||||
mel_segment = mel[:, seek : seek + N_FRAMES]
|
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
||||||
segment_size = min(N_FRAMES, content_frames - seek)
|
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
||||||
|
mel_segment = mel[:, seek : seek + segment_size]
|
||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||||
|
|
||||||
# print("melshape", mel_segment.shape)
|
if carry_initial_prompt:
|
||||||
|
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
|
||||||
|
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
|
||||||
|
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
|
||||||
|
else:
|
||||||
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||||
|
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||||
tokens = torch.tensor(result.tokens)
|
tokens = torch.tensor(result.tokens)
|
||||||
|
|
||||||
|
|
@ -255,6 +312,30 @@ def transcribe(
|
||||||
previous_seek = seek
|
previous_seek = seek
|
||||||
current_segments = []
|
current_segments = []
|
||||||
|
|
||||||
|
# anomalous words are very long/short/improbable
|
||||||
|
def word_anomaly_score(word: dict) -> float:
|
||||||
|
probability = word.get("probability", 0.0)
|
||||||
|
duration = word["end"] - word["start"]
|
||||||
|
score = 0.0
|
||||||
|
if probability < 0.15:
|
||||||
|
score += 1.0
|
||||||
|
if duration < 0.133:
|
||||||
|
score += (0.133 - duration) * 15
|
||||||
|
if duration > 2.0:
|
||||||
|
score += duration - 2.0
|
||||||
|
return score
|
||||||
|
|
||||||
|
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
||||||
|
if segment is None or not segment["words"]:
|
||||||
|
return False
|
||||||
|
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
||||||
|
words = words[:8]
|
||||||
|
score = sum(word_anomaly_score(w) for w in words)
|
||||||
|
return score >= 3 or score + 0.01 >= len(words)
|
||||||
|
|
||||||
|
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
||||||
|
return next((s for s in segments if s["words"]), None)
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||||
|
|
||||||
|
|
@ -317,9 +398,7 @@ def transcribe(
|
||||||
)
|
)
|
||||||
seek += segment_size
|
seek += segment_size
|
||||||
|
|
||||||
# print("word_timestamps, ", word_timestamps)
|
|
||||||
if word_timestamps:
|
if word_timestamps:
|
||||||
# print("=========run timestamps here=========")
|
|
||||||
add_word_timestamps(
|
add_word_timestamps(
|
||||||
segments=current_segments,
|
segments=current_segments,
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -330,17 +409,71 @@ def transcribe(
|
||||||
append_punctuations=append_punctuations,
|
append_punctuations=append_punctuations,
|
||||||
last_speech_timestamp=last_speech_timestamp,
|
last_speech_timestamp=last_speech_timestamp,
|
||||||
)
|
)
|
||||||
word_end_timestamps = [
|
|
||||||
w["end"] for s in current_segments for w in s["words"]
|
if not single_timestamp_ending:
|
||||||
]
|
last_word_end = get_end(current_segments)
|
||||||
if len(word_end_timestamps) > 0:
|
if last_word_end is not None and last_word_end > time_offset:
|
||||||
last_speech_timestamp = word_end_timestamps[-1]
|
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
|
||||||
seek_shift = round(
|
# skip silence before possible hallucinations
|
||||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
if hallucination_silence_threshold is not None:
|
||||||
)
|
threshold = hallucination_silence_threshold
|
||||||
if seek_shift > 0:
|
if not single_timestamp_ending:
|
||||||
seek = previous_seek + seek_shift
|
last_word_end = get_end(current_segments)
|
||||||
|
if last_word_end is not None and last_word_end > time_offset:
|
||||||
|
remaining_duration = window_end_time - last_word_end
|
||||||
|
if remaining_duration > threshold:
|
||||||
|
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||||
|
else:
|
||||||
|
seek = previous_seek + segment_size
|
||||||
|
|
||||||
|
# if first segment might be a hallucination, skip leading silence
|
||||||
|
first_segment = next_words_segment(current_segments)
|
||||||
|
if first_segment is not None and is_segment_anomaly(first_segment):
|
||||||
|
gap = first_segment["start"] - time_offset
|
||||||
|
if gap > threshold:
|
||||||
|
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip silence before any possible hallucination that is surrounded
|
||||||
|
# by silence or more hallucinations
|
||||||
|
hal_last_end = last_speech_timestamp
|
||||||
|
for si in range(len(current_segments)):
|
||||||
|
segment = current_segments[si]
|
||||||
|
if not segment["words"]:
|
||||||
|
continue
|
||||||
|
if is_segment_anomaly(segment):
|
||||||
|
next_segment = next_words_segment(
|
||||||
|
current_segments[si + 1 :]
|
||||||
|
)
|
||||||
|
if next_segment is not None:
|
||||||
|
hal_next_start = next_segment["words"][0]["start"]
|
||||||
|
else:
|
||||||
|
hal_next_start = time_offset + segment_duration
|
||||||
|
silence_before = (
|
||||||
|
segment["start"] - hal_last_end > threshold
|
||||||
|
or segment["start"] < threshold
|
||||||
|
or segment["start"] - time_offset < 2.0
|
||||||
|
)
|
||||||
|
silence_after = (
|
||||||
|
hal_next_start - segment["end"] > threshold
|
||||||
|
or is_segment_anomaly(next_segment)
|
||||||
|
or window_end_time - segment["end"] < 2.0
|
||||||
|
)
|
||||||
|
if silence_before and silence_after:
|
||||||
|
seek = round(
|
||||||
|
max(time_offset + 1, segment["start"])
|
||||||
|
* FRAMES_PER_SECOND
|
||||||
|
)
|
||||||
|
if content_duration - segment["end"] < threshold:
|
||||||
|
seek = content_frames
|
||||||
|
current_segments[si:] = []
|
||||||
|
break
|
||||||
|
hal_last_end = segment["end"]
|
||||||
|
|
||||||
|
last_word_end = get_end(current_segments)
|
||||||
|
if last_word_end is not None:
|
||||||
|
last_speech_timestamp = last_word_end
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
for segment in current_segments:
|
for segment in current_segments:
|
||||||
|
|
@ -384,10 +517,17 @@ def transcribe(
|
||||||
def cli():
|
def cli():
|
||||||
from . import available_models
|
from . import available_models
|
||||||
|
|
||||||
|
def valid_model_name(name):
|
||||||
|
if name in available_models() or os.path.exists(name):
|
||||||
|
return name
|
||||||
|
raise ValueError(
|
||||||
|
f"model should be one of {available_models()} or path to a model checkpoint"
|
||||||
|
)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
|
||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||||
|
|
@ -405,6 +545,8 @@ def cli():
|
||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||||
|
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
|
||||||
|
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||||
|
|
||||||
|
|
@ -418,7 +560,10 @@ def cli():
|
||||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||||
|
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
|
||||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||||
|
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
|
||||||
|
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
args = parser.parse_args().__dict__
|
args = parser.parse_args().__dict__
|
||||||
|
|
@ -450,17 +595,28 @@ def cli():
|
||||||
model = load_model(model_name, device=device, download_root=model_dir)
|
model = load_model(model_name, device=device, download_root=model_dir)
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
writer = get_writer(output_format, output_dir)
|
||||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
word_options = [
|
||||||
|
"highlight_words",
|
||||||
|
"max_line_count",
|
||||||
|
"max_line_width",
|
||||||
|
"max_words_per_line",
|
||||||
|
]
|
||||||
if not args["word_timestamps"]:
|
if not args["word_timestamps"]:
|
||||||
for option in word_options:
|
for option in word_options:
|
||||||
if args[option]:
|
if args[option]:
|
||||||
parser.error(f"--{option} requires --word_timestamps True")
|
parser.error(f"--{option} requires --word_timestamps True")
|
||||||
if args["max_line_count"] and not args["max_line_width"]:
|
if args["max_line_count"] and not args["max_line_width"]:
|
||||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||||
|
if args["max_words_per_line"] and args["max_line_width"]:
|
||||||
|
warnings.warn("--max_words_per_line has no effect with --max_line_width")
|
||||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||||
for audio_path in args.pop("audio"):
|
for audio_path in args.pop("audio"):
|
||||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
try:
|
||||||
writer(result, audio_path, writer_args)
|
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||||
|
writer(result, audio_path, **writer_args)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
|
||||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||||
|
|
||||||
kernel = triton.JITFunction(kernel.fn)
|
kernel = triton.JITFunction(kernel.fn)
|
||||||
kernel.src = kernel.src.replace(
|
new_kernel = kernel.src.replace(
|
||||||
" LOAD_ALL_ROWS_HERE",
|
" LOAD_ALL_ROWS_HERE",
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
|
|
@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
kernel.src = kernel.src.replace(
|
|
||||||
|
new_kernel = new_kernel.replace(
|
||||||
" BUBBLESORT_HERE",
|
" BUBBLESORT_HERE",
|
||||||
"\n\n".join(
|
"\n\n".join(
|
||||||
[
|
[
|
||||||
|
|
@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
|
||||||
|
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||||
|
|
||||||
|
if hasattr(kernel, "_unsafe_update_src") is True:
|
||||||
|
kernel._unsafe_update_src(new_kernel)
|
||||||
|
kernel.hash = None
|
||||||
|
else:
|
||||||
|
kernel.src = new_kernel
|
||||||
|
|
||||||
return kernel
|
return kernel
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import zlib
|
import zlib
|
||||||
from typing import Callable, Optional, TextIO
|
from typing import Callable, List, Optional, TextIO
|
||||||
|
|
||||||
system_encoding = sys.getdefaultencoding()
|
system_encoding = sys.getdefaultencoding()
|
||||||
|
|
||||||
|
|
@ -68,13 +68,29 @@ def format_timestamp(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_start(segments: List[dict]) -> Optional[float]:
|
||||||
|
return next(
|
||||||
|
(w["start"] for s in segments for w in s["words"]),
|
||||||
|
segments[0]["start"] if segments else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_end(segments: List[dict]) -> Optional[float]:
|
||||||
|
return next(
|
||||||
|
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||||
|
segments[-1]["end"] if segments else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResultWriter:
|
class ResultWriter:
|
||||||
extension: str
|
extension: str
|
||||||
|
|
||||||
def __init__(self, output_dir: str):
|
def __init__(self, output_dir: str):
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
|
|
||||||
def __call__(self, result: dict, audio_path: str, options: dict):
|
def __call__(
|
||||||
|
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
audio_basename = os.path.basename(audio_path)
|
audio_basename = os.path.basename(audio_path)
|
||||||
audio_basename = os.path.splitext(audio_basename)[0]
|
audio_basename = os.path.splitext(audio_basename)[0]
|
||||||
output_path = os.path.join(
|
output_path = os.path.join(
|
||||||
|
|
@ -82,16 +98,20 @@ class ResultWriter:
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
self.write_result(result, file=f, options=options)
|
self.write_result(result, file=f, options=options, **kwargs)
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class WriteTXT(ResultWriter):
|
class WriteTXT(ResultWriter):
|
||||||
extension: str = "txt"
|
extension: str = "txt"
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
print(segment["text"].strip(), file=file, flush=True)
|
print(segment["text"].strip(), file=file, flush=True)
|
||||||
|
|
||||||
|
|
@ -100,48 +120,76 @@ class SubtitlesWriter(ResultWriter):
|
||||||
always_include_hours: bool
|
always_include_hours: bool
|
||||||
decimal_marker: str
|
decimal_marker: str
|
||||||
|
|
||||||
def iterate_result(self, result: dict, options: dict):
|
def iterate_result(
|
||||||
raw_max_line_width: Optional[int] = options["max_line_width"]
|
self,
|
||||||
max_line_count: Optional[int] = options["max_line_count"]
|
result: dict,
|
||||||
highlight_words: bool = options["highlight_words"]
|
options: Optional[dict] = None,
|
||||||
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
*,
|
||||||
preserve_segments = max_line_count is None or raw_max_line_width is None
|
max_line_width: Optional[int] = None,
|
||||||
|
max_line_count: Optional[int] = None,
|
||||||
|
highlight_words: bool = False,
|
||||||
|
max_words_per_line: Optional[int] = None,
|
||||||
|
):
|
||||||
|
options = options or {}
|
||||||
|
max_line_width = max_line_width or options.get("max_line_width")
|
||||||
|
max_line_count = max_line_count or options.get("max_line_count")
|
||||||
|
highlight_words = highlight_words or options.get("highlight_words", False)
|
||||||
|
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
|
||||||
|
preserve_segments = max_line_count is None or max_line_width is None
|
||||||
|
max_line_width = max_line_width or 1000
|
||||||
|
max_words_per_line = max_words_per_line or 1000
|
||||||
|
|
||||||
def iterate_subtitles():
|
def iterate_subtitles():
|
||||||
line_len = 0
|
line_len = 0
|
||||||
line_count = 1
|
line_count = 1
|
||||||
# the next subtitle to yield (a list of word timings with whitespace)
|
# the next subtitle to yield (a list of word timings with whitespace)
|
||||||
subtitle: list[dict] = []
|
subtitle: List[dict] = []
|
||||||
last = result["segments"][0]["words"][0]["start"]
|
last: float = get_start(result["segments"]) or 0.0
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
for i, original_timing in enumerate(segment["words"]):
|
chunk_index = 0
|
||||||
timing = original_timing.copy()
|
words_count = max_words_per_line
|
||||||
long_pause = not preserve_segments and timing["start"] - last > 3.0
|
while chunk_index < len(segment["words"]):
|
||||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
remaining_words = len(segment["words"]) - chunk_index
|
||||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
if max_words_per_line > len(segment["words"]) - chunk_index:
|
||||||
if line_len > 0 and has_room and not long_pause and not seg_break:
|
words_count = remaining_words
|
||||||
# line continuation
|
for i, original_timing in enumerate(
|
||||||
line_len += len(timing["word"])
|
segment["words"][chunk_index : chunk_index + words_count]
|
||||||
else:
|
):
|
||||||
# new line
|
timing = original_timing.copy()
|
||||||
timing["word"] = timing["word"].strip()
|
long_pause = (
|
||||||
|
not preserve_segments and timing["start"] - last > 3.0
|
||||||
|
)
|
||||||
|
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||||
|
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||||
if (
|
if (
|
||||||
len(subtitle) > 0
|
line_len > 0
|
||||||
and max_line_count is not None
|
and has_room
|
||||||
and (long_pause or line_count >= max_line_count)
|
and not long_pause
|
||||||
or seg_break
|
and not seg_break
|
||||||
):
|
):
|
||||||
# subtitle break
|
# line continuation
|
||||||
yield subtitle
|
line_len += len(timing["word"])
|
||||||
subtitle = []
|
else:
|
||||||
line_count = 1
|
# new line
|
||||||
elif line_len > 0:
|
timing["word"] = timing["word"].strip()
|
||||||
# line break
|
if (
|
||||||
line_count += 1
|
len(subtitle) > 0
|
||||||
timing["word"] = "\n" + timing["word"]
|
and max_line_count is not None
|
||||||
line_len = len(timing["word"].strip())
|
and (long_pause or line_count >= max_line_count)
|
||||||
subtitle.append(timing)
|
or seg_break
|
||||||
last = timing["start"]
|
):
|
||||||
|
# subtitle break
|
||||||
|
yield subtitle
|
||||||
|
subtitle = []
|
||||||
|
line_count = 1
|
||||||
|
elif line_len > 0:
|
||||||
|
# line break
|
||||||
|
line_count += 1
|
||||||
|
timing["word"] = "\n" + timing["word"]
|
||||||
|
line_len = len(timing["word"].strip())
|
||||||
|
subtitle.append(timing)
|
||||||
|
last = timing["start"]
|
||||||
|
chunk_index += max_words_per_line
|
||||||
if len(subtitle) > 0:
|
if len(subtitle) > 0:
|
||||||
yield subtitle
|
yield subtitle
|
||||||
|
|
||||||
|
|
@ -161,9 +209,11 @@ class SubtitlesWriter(ResultWriter):
|
||||||
|
|
||||||
yield start, end, "".join(
|
yield start, end, "".join(
|
||||||
[
|
[
|
||||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
(
|
||||||
if j == i
|
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||||
else word
|
if j == i
|
||||||
|
else word
|
||||||
|
)
|
||||||
for j, word in enumerate(all_words)
|
for j, word in enumerate(all_words)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -190,9 +240,11 @@ class WriteVTT(SubtitlesWriter):
|
||||||
always_include_hours: bool = False
|
always_include_hours: bool = False
|
||||||
decimal_marker: str = "."
|
decimal_marker: str = "."
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
print("WEBVTT\n", file=file)
|
print("WEBVTT\n", file=file)
|
||||||
for start, end, text in self.iterate_result(result, options):
|
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -201,9 +253,11 @@ class WriteSRT(SubtitlesWriter):
|
||||||
always_include_hours: bool = True
|
always_include_hours: bool = True
|
||||||
decimal_marker: str = ","
|
decimal_marker: str = ","
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
for i, (start, end, text) in enumerate(
|
for i, (start, end, text) in enumerate(
|
||||||
self.iterate_result(result, options), start=1
|
self.iterate_result(result, options, **kwargs), start=1
|
||||||
):
|
):
|
||||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
|
|
@ -220,7 +274,9 @@ class WriteTSV(ResultWriter):
|
||||||
|
|
||||||
extension: str = "tsv"
|
extension: str = "tsv"
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
print("start", "end", "text", sep="\t", file=file)
|
print("start", "end", "text", sep="\t", file=file)
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||||
|
|
@ -231,7 +287,9 @@ class WriteTSV(ResultWriter):
|
||||||
class WriteJSON(ResultWriter):
|
class WriteJSON(ResultWriter):
|
||||||
extension: str = "json"
|
extension: str = "json"
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
json.dump(result, file)
|
json.dump(result, file)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -249,9 +307,11 @@ def get_writer(
|
||||||
if output_format == "all":
|
if output_format == "all":
|
||||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||||
|
|
||||||
def write_all(result: dict, file: TextIO, options: dict):
|
def write_all(
|
||||||
|
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
for writer in all_writers:
|
for writer in all_writers:
|
||||||
writer(result, file, options)
|
writer(result, file, options, **kwargs)
|
||||||
|
|
||||||
return write_all
|
return write_all
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
__version__ = "20230918"
|
__version__ = "20250625"
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue