condition on encoder_feature_ctranslate type
This commit is contained in:
parent
c7b3bb5e58
commit
1f7798c7c1
1 changed files with 2 additions and 0 deletions
|
|
@ -409,6 +409,8 @@ class PaddedAlignAttWhisper:
|
|||
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
|
||||
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
|
||||
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
|
||||
if type(encoder_feature_ctranslate).__module__ == 'ctranslate2._ext':
|
||||
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
|
||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||
else:
|
||||
# mel + padding to 30s
|
||||
|
|
|
|||
Loading…
Reference in a new issue