Doing PyWavelets calculation on GPU
I am currently working on a classifier using PyWavelets; here is my calculation block:
import pywt
import torch
from torch import nn

class WaveletLayer(nn.Module):
    def __init__(self):
        super(WaveletLayer, self).__init__()

    def forward(self, x):
        def wavelet_transform(img):
            # Round trip: GPU tensor -> NumPy on the CPU -> back to the device
            coeffs = pywt.dwt2(img.cpu().numpy(), "haar")
            LL, (LH, HL, HH) = coeffs
            return (
                torch.from_numpy(LL).to(img.device),
                torch.from_numpy(LH).to(img.device),
                torch.from_numpy(HL).to(img.device),
                torch.from_numpy(HH).to(img.device),
            )

        # Apply wavelet transform to each channel separately
        LL, LH, HL, HH = zip(
            *[wavelet_transform(x[:, i : i + 1]) for i in range(x.shape[1])]
        )
        # Concatenate the results
        LL = torch.cat(LL, dim=1)
        LH = torch.cat(LH, dim=1)
        HL = torch.cat(HL, dim=1)
        HH = torch.cat(HH, dim=1)
        return torch.cat([LL, LH, HL, HH], dim=1)
The output of this module is fed into a ResNet block for learning. While training, I find my CPU clogged, which slows down the whole training process.
I am trying to use the GPUs for these calculations instead.
Asked Jan 29 at 13:26 by Aryan Raj; edited Jan 30 at 8:35 by simon.
Answer
Since you only seem to be interested in the Haar wavelet, you can pretty much implement it yourself:
- The high-frequency component of the Haar wavelet along each dimension can be written as a (scaled) pairwise difference of neighbouring samples.
- The low-frequency component of the Haar wavelet along each dimension can be written as a (scaled) pairwise sum of neighbouring samples.
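The answer works in PyTorch. The sketch below restates the two bullet points in one dimension using NumPy, which is an assumption on my part, and the helper name `haar_1d` is hypothetical. It pairs up neighbouring samples and scales their sum and difference by 1/√2. For odd lengths it repeats the last sample, which is the same padding the layer below applies.

```python
import numpy as np

def haar_1d(signal):
    """One level of the orthonormal Haar transform along the last axis.

    Hypothetical helper for illustration; returns (approximation, detail).
    """
    s = np.asarray(signal, dtype=float)
    if s.shape[-1] % 2:
        # Odd length: repeat the last sample so every element has a partner
        s = np.concatenate([s, s[..., -1:]], axis=-1)
    even, odd = s[..., ::2], s[..., 1::2]
    scale = 1.0 / np.sqrt(2.0)      # keeps the transform orthonormal
    approx = scale * (even + odd)   # pairwise sum -> low-frequency part
    detail = scale * (even - odd)   # pairwise difference -> high-frequency part
    return approx, detail

approx, detail = haar_1d([1, 2, 3, 4, 5, 6])
print(approx)  # approx ≈ [2.121, 4.950, 7.778], i.e. [3, 7, 11] / sqrt(2)
print(detail)  # approx ≈ [-0.707, -0.707, -0.707], i.e. -1 / sqrt(2) each
```

To my knowledge these values agree with what `pywt.dwt([1, 2, 3, 4, 5, 6], "haar")` returns. The detail sign follows the "even minus odd" convention that the layer below also uses.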
The following code achieves this in pure PyTorch:
import torch
from torch import nn

class HaarWaveletLayer(nn.Module):

    def l_0(self, t):  # sum ("low") along cols
        t = torch.cat([t, t[..., -1:, :]], dim=-2) if t.shape[-2] % 2 else t
        return (t[..., ::2, :] + t[..., 1::2, :])

    def l_1(self, t):  # sum ("low") along rows
        t = torch.cat([t, t[..., :, -1:]], dim=-1) if t.shape[-1] % 2 else t
        return (t[..., :, ::2] + t[..., :, 1::2])

    def h_0(self, t):  # diff ("hi") along cols
        t = torch.cat([t, t[..., -1:, :]], dim=-2) if t.shape[-2] % 2 else t
        return (t[..., ::2, :] - t[..., 1::2, :])

    def h_1(self, t):  # diff ("hi") along rows
        t = torch.cat([t, t[..., :, -1:]], dim=-1) if t.shape[-1] % 2 else t
        return (t[..., :, ::2] - t[..., :, 1::2])

    def forward(self, x):
        x = .5 * x
        l_1 = self.l_1(x)
        h_1 = self.h_1(x)
        ll = self.l_0(l_1)
        lh = self.h_0(l_1)
        hl = self.l_0(h_1)
        hh = self.h_0(h_1)
        return torch.cat([ll, lh, hl, hh], dim=1)
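To inspect the output layout without a GPU, the forward pass above can be re-expressed in NumPy. This sketch is not part of the original answer, and the names `haar_dwt2` and `_pairs` are hypothetical. It follows `forward` step by step:

- halve the input;
- take the pairwise sum and difference along the columns (W);
- do the same along the rows (H);
- concatenate `ll, lh, hl, hh` on the channel axis.

An (N, C, H, W) input therefore becomes (N, 4C, ceil(H/2), ceil(W/2)).

```python
import numpy as np

def _pairs(t, axis):
    """Split t into even/odd slices along axis, repeating the last slice if odd."""
    t = np.moveaxis(t, axis, -1)
    if t.shape[-1] % 2:
        t = np.concatenate([t, t[..., -1:]], axis=-1)
    even, odd = t[..., ::2], t[..., 1::2]
    return np.moveaxis(even, -1, axis), np.moveaxis(odd, -1, axis)

def haar_dwt2(x):
    """NumPy sketch of HaarWaveletLayer.forward for an (N, C, H, W) array."""
    x = 0.5 * np.asarray(x, dtype=float)
    ev, od = _pairs(x, -1)            # pair up neighbouring columns (along W)
    l_1, h_1 = ev + od, ev - od       # low / high along W
    ev, od = _pairs(l_1, -2)          # pair up neighbouring rows (along H)
    ll, lh = ev + od, ev - od
    ev, od = _pairs(h_1, -2)
    hl, hh = ev + od, ev - od
    return np.concatenate([ll, lh, hl, hh], axis=1)

x = np.random.rand(7, 3, 127, 128)
print(haar_dwt2(x).shape)  # (7, 12, 64, 64): 4 * 3 channels, odd H padded to 128
```

Keeping the four sub-bands stacked on the channel axis (rather than returning a tuple, as `pywt.dwt2` does) is what lets the result go straight into a convolutional block.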
In combination with your given code, you can convince yourself of the equivalence as follows:
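import torch
from timeit import Timer

t = torch.rand((7, 3, 127, 128)).to("cuda:0")
result_given = WaveletLayer()(t)
result_proposed = HaarWaveletLayer()(t)
# Same result?
assert (result_given - result_proposed).abs().max() < 1e-5

# Time comparison. CUDA kernels run asynchronously, so wait for the GPU to
# finish inside each timed call; otherwise the pure-PyTorch version would
# only be timed for launching its kernels, not for running them.
def timed_call(layer_cls):
    out = layer_cls()(t)
    torch.cuda.synchronize()
    return out

num_timings = 100
print("time given:   ", Timer(lambda: timed_call(WaveletLayer)).timeit(num_timings))
print("time proposed:", Timer(lambda: timed_call(HaarWaveletLayer)).timeit(num_timings))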
The timing shows a speedup of more than a factor of 10 on my machine.
Notes
- The t = torch.cat... parts are only necessary if you want to be able to handle odd-shaped images. In that case, we pad by replicating the last row and column, respectively. This mimics the default padding of PyWavelets (the "symmetric" mode, which for the Haar wavelet amounts to repeating the last sample).
- Multiplying x by .5 is done for normalization. Compare this discussion on the Signal Processing Stack Exchange for more details.
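The normalization can be worked out for a single 2×2 block. Write a, b for its top row and c, d for its bottom row. The orthonormal 1D Haar step scales each pairwise sum or difference by 1/√2. The layer applies one step along the rows and one along the columns, so each coefficient picks up a factor of (1/√2)² = 1/2. That is the single `.5 * x` at the start of `forward`:

```latex
\begin{aligned}
\mathrm{ll} &= \tfrac{1}{\sqrt{2}}\cdot\tfrac{1}{\sqrt{2}}\,(a+b+c+d) = \tfrac{1}{2}\,(a+b+c+d),\\
\mathrm{lh} &= \tfrac{1}{2}\,\big((a+b)-(c+d)\big),\\
\mathrm{hl} &= \tfrac{1}{2}\,\big((a-b)+(c-d)\big),\\
\mathrm{hh} &= \tfrac{1}{2}\,\big((a-b)-(c-d)\big),\\
\mathrm{ll}^2+\mathrm{lh}^2+\mathrm{hl}^2+\mathrm{hh}^2 &= a^2+b^2+c^2+d^2 .
\end{aligned}
```

The last line holds because the four coefficient vectors are orthonormal. For example, the block [[1, 2], [3, 4]] maps to 5, -2, -1 and 0, and both sums of squares equal 30.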
… pytorch and tensorflow, while the code that you show is PyTorch; would a TensorFlow solution also be viable? (2) Consider providing a full example, including the instantiation and calling of the WaveletLayer, so that others can reproduce your current results and compare potential solutions (see minimal, reproducible example). – simon, commented Jan 29 at 13:37