82 lines
No EOL
3.2 KiB
Python
82 lines
No EOL
3.2 KiB
Python
"""The 1D discrete wavelet transform for PyTorch."""
|
|
|
|
from einops import rearrange
|
|
import pywt
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from typing import Literal
|
|
|
|
|
|
def get_filter_bank(wavelet):
    """Load a pywt wavelet's filter bank as one tensor.

    Args:
        wavelet: pywt wavelet name, e.g. ``"bior4.4"``.

    Returns:
        Tensor of shape ``(4, K)`` whose rows are pywt's
        ``Wavelet.filter_bank`` (``dec_lo, dec_hi, rec_lo, rec_hi``).
        For ``"bior*"`` wavelets whose four filters all begin with a zero tap,
        that leading zero column is dropped, so ``K`` is one tap shorter.
    """
    bank = torch.tensor(pywt.Wavelet(wavelet).filter_bank)
    if not wavelet.startswith("bior"):
        return bank
    # Only strip when every filter shares the zero first tap; otherwise the
    # filters would be misaligned relative to each other.
    if torch.all(bank[:, 0] == 0):
        return bank[:, 1:]
    return bank
|
|
|
|
class WaveletEncode1d(nn.Module):
    """Multi-level 1D wavelet analysis (forward DWT) as an ``nn.Module``.

    At every level the first ``channels`` channels (the current low band) are
    filtered with the analysis low-pass and high-pass filters and downsampled
    by 2. All remaining channels (bands produced by earlier levels) have their
    even/odd time steps folded into the channel axis, so every part shares the
    new, halved length.

    Args:
        channels: Number of signal channels decomposed at each level.
        levels: Number of decomposition levels applied by ``forward``.
        wavelet: pywt wavelet name; after ``get_filter_bank`` its filters
            must have odd length.

    Raises:
        ValueError: If the wavelet's filter bank has even length.
    """

    def __init__(
        self,
        channels: int,
        levels: int,
        wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4",
    ):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels

        filt = get_filter_bank(wavelet)
        # With odd K, padding K // 2 on both sides and striding by 2 maps an
        # even length L to exactly L / 2. Use an explicit raise rather than
        # ``assert`` so the check is not stripped under ``python -O``.
        if filt.shape[-1] % 2 != 1:
            raise ValueError(
                f"wavelet {wavelet!r} gives filters of even length "
                f"{filt.shape[-1]}; WaveletEncode1d requires odd-length filters"
            )

        # Analysis filters (rows 0 and 1 of the bank), reversed because
        # F.conv1d computes a cross-correlation, not a convolution.
        kernel = torch.flip(filt[:2, None], dims=(-1,))

        # Build a (2 * channels, channels, K) weight: output channel j gets the
        # low-pass filter of input channel j, output channel channels + j gets
        # the high-pass filter of input channel j; no cross-channel mixing.
        index_i = torch.repeat_interleave(torch.arange(2), channels)
        index_j = torch.tile(torch.arange(channels), (2,))
        kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
        kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
        self.register_buffer("kernel", kernel_final)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.levels`` analysis levels to ``x``.

        Args:
            x: 3-D tensor ``(n, channels * m, length)``; the first
                ``channels`` channels are decomposed at each level. Each level
                halves the length and the rearrange needs it to be even, so
                ``length`` must be divisible by ``2 ** levels``.

        Returns:
            Tensor ``(n, channels * m * 2 ** levels, length // 2 ** levels)``.
        """
        # Loop-invariant: K // 2 reflect padding per side.
        pad = self.kernel.shape[-1] // 2
        for _ in range(self.levels):
            low, rest = x[:, : self.channels], x[:, self.channels :]
            low = F.pad(low, (pad, pad), "reflect")
            # -> (n, 2 * channels, length / 2): low band then high band.
            low = F.conv1d(low, self.kernel, stride=2)
            # Fold even/odd time steps of earlier bands into the channel axis.
            rest = rearrange(
                rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels
            )
            x = torch.cat([low, rest], dim=1)
        return x
|
|
|
|
|
|
class WaveletDecode1d(nn.Module):
    """Multi-level 1D wavelet synthesis (inverse DWT) as an ``nn.Module``.

    Undoes the channel layout of ``WaveletEncode1d``. At every level the first
    ``2 * channels`` channels (low band, then high band) are upsampled by 2 and
    combined with the synthesis filters into ``channels`` channels. The
    remaining channels have their folded even/odd time steps moved back into
    the time axis. Exact reconstruction of the encoder input is presumably the
    intent; TODO confirm, including boundary behaviour.

    Args:
        channels: Number of signal channels produced at each level.
        levels: Number of synthesis levels applied by ``forward``.
        wavelet: pywt wavelet name; after ``get_filter_bank`` its filters
            must have odd length.

    Raises:
        ValueError: If the wavelet's filter bank has even length.
    """

    def __init__(
        self,
        channels: int,
        levels: int,
        wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4",
    ):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels

        filt = get_filter_bank(wavelet)
        # The padding/cropping arithmetic in ``forward`` assumes odd K. Use an
        # explicit raise rather than ``assert`` so it survives ``python -O``.
        if filt.shape[-1] % 2 != 1:
            raise ValueError(
                f"wavelet {wavelet!r} gives filters of even length "
                f"{filt.shape[-1]}; WaveletDecode1d requires odd-length filters"
            )

        # Synthesis filters (rows 2 and 3 of the bank). Not flipped:
        # conv_transpose1d accumulates input[i] * w[k] into output[2 * i + k],
        # which is already a (strided) convolution.
        kernel = filt[2:, None]

        # conv_transpose1d weight layout is (in, out, K) = (2 * channels,
        # channels, K). Input channel j (low) and channels + j (high) both feed
        # output channel j, with the low- and high-pass filter respectively.
        index_i = torch.repeat_interleave(torch.arange(2), channels)
        index_j = torch.tile(torch.arange(channels), (2,))
        kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
        kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
        self.register_buffer("kernel", kernel_final)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.levels`` synthesis levels to ``x``.

        Args:
            x: 3-D tensor ``(n, 2 * channels * m, l)`` laid out like the
                output of ``WaveletEncode1d``: the first ``2 * channels``
                channels hold the low and high band of the coarsest level.

        Returns:
            Tensor ``(n, channels * m * 2 // 2 ** levels ...)``; each level
            halves the channel count and doubles the length.
        """
        half_k = self.kernel.shape[-1] // 2
        # Extra coefficients of reflect context per side on the interleaved
        # stream. For the listed bior filters this value is even, so the
        # low/high interleaving phase is preserved.
        pad = half_k + 2
        for _ in range(self.levels):
            low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :]
            # Interleave low/high coefficients along time, reflect-pad that
            # stream, then split it back into the two bands.
            low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2)
            low = F.pad(low, (pad, pad), "reflect")
            low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2)
            # Each band now has l + pad samples, so the output length is
            # 2 * (l + pad - 1) - 2 * half_k + K = 2 * l + 2 * pad - 1.
            low = F.conv_transpose1d(low, self.kernel, stride=2, padding=half_k)
            # Crop (pad - 1) + pad samples, leaving exactly 2 * l.
            low = low[..., pad - 1 : -pad]
            # Move folded even/odd time steps of the other bands back into time.
            rest = rearrange(
                rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels
            )
            x = torch.cat([low, rest], dim=1)
        return x