fix: handle numpy object_ dtype from ctranslate2 encoder (#337)
This commit is contained in:
parent
4c7706e2cf
commit
b8d9d7d289
1 changed files with 6 additions and 6 deletions
|
|
@ -280,13 +280,13 @@ class AlignAtt(AlignAttBase):
|
|||
if self.device == 'cpu':
|
||||
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
|
||||
try:
|
||||
encoder_feature = torch.as_tensor(
|
||||
encoder_feature_ctranslate, device=self.device,
|
||||
)
|
||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||
except TypeError:
|
||||
encoder_feature = torch.as_tensor(
|
||||
np.array(encoder_feature_ctranslate), device=self.device,
|
||||
)
|
||||
# Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32
|
||||
arr = np.array(encoder_feature_ctranslate)
|
||||
if arr.dtype == np.object_:
|
||||
arr = np.array(arr.tolist(), dtype=np.float32)
|
||||
encoder_feature = torch.as_tensor(arr, device=self.device)
|
||||
else:
|
||||
mel_padded = log_mel_spectrogram(
|
||||
input_segments, n_mels=self.model.dims.n_mels,
|
||||
|
|
|
|||
Loading…
Reference in a new issue