coreml conversion
This commit is contained in:
parent
1bbbb7903c
commit
4d2ffb24f8
1 changed files with 49 additions and 5 deletions
|
|
@ -11,11 +11,11 @@ from tqdm import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||||
from .model import ModelDimensions, Whisper
|
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
||||||
from .transcribe import transcribe
|
from whisperlivekit.whisper.transcribe import transcribe
|
||||||
from .version import __version__
|
from whisperlivekit.whisper.version import __version__
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||||
|
|
@ -417,3 +417,47 @@ def load_model(
|
||||||
model.set_alignment_heads(alignment_heads)
|
model.set_alignment_heads(alignment_heads)
|
||||||
|
|
||||||
return model.to(device)
|
return model.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_encoder_to_coreml(
|
||||||
|
model_name = "base",
|
||||||
|
output_path= "whisper_encoder.mlpackage",
|
||||||
|
dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing
|
||||||
|
precision = "float16",
|
||||||
|
):
|
||||||
|
|
||||||
|
import coremltools as ct
|
||||||
|
model = load_model(model_name, device="cpu", decoder_only=False)
|
||||||
|
encoder = model.encoder.eval().cpu()
|
||||||
|
|
||||||
|
dummy_input = torch.randn(
|
||||||
|
1,
|
||||||
|
model.dims.n_mels,
|
||||||
|
dummy_frames,
|
||||||
|
dtype=next(encoder.parameters()).dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
traced_encoder = torch.jit.trace(encoder, dummy_input)
|
||||||
|
|
||||||
|
precision_map = {
|
||||||
|
"float16": ct.precision.FLOAT16,
|
||||||
|
"fp16": ct.precision.FLOAT16,
|
||||||
|
"float32": ct.precision.FLOAT32,
|
||||||
|
"fp32": ct.precision.FLOAT32,
|
||||||
|
}
|
||||||
|
coreml_precision = precision_map[precision.lower()]
|
||||||
|
|
||||||
|
mlmodel = ct.convert(
|
||||||
|
traced_encoder,
|
||||||
|
inputs=[ct.TensorType(name="mel", shape=dummy_input.shape)],
|
||||||
|
convert_to= "mlprogram",
|
||||||
|
compute_precision=coreml_precision,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_path = Path(output_path)
|
||||||
|
mlmodel.save(str(output_path))
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")
|
||||||
Loading…
Reference in a new issue