I am currently working on a classifier using PyWavelets. Here is my calculation block:

import pywt
import torch
import torch.nn as nn


class WaveletLayer(nn.Module):
    def __init__(self):
        super(WaveletLayer, self).__init__()

    def forward(self, x):
        def wavelet_transform(img):
            # pywt works on NumPy arrays, so every call copies the channel to the
            # CPU, transforms it there, and copies the results back to the device
            coeffs = pywt.dwt2(img.cpu().numpy(), "haar")
            LL, (LH, HL, HH) = coeffs
            return (
                torch.from_numpy(LL).to(img.device),
                torch.from_numpy(LH).to(img.device),
                torch.from_numpy(HL).to(img.device),
                torch.from_numpy(HH).to(img.device),
            )

        # Apply wavelet transform to each channel separately
        LL, LH, HL, HH = zip(
            *[wavelet_transform(x[:, i : i + 1]) for i in range(x.shape[1])]
        )

        # Concatenate the results
        LL = torch.cat(LL, dim=1)
        LH = torch.cat(LH, dim=1)
        HL = torch.cat(HL, dim=1)
        HH = torch.cat(HH, dim=1)

        return torch.cat([LL, LH, HL, HH], dim=1)

The output of this module goes into a ResNet block for learning. While training, I find that my CPU is saturated, which slows down the whole training process.

I would like to use the GPU(s) for these calculations instead.

Asked Jan 29 at 13:26 by Aryan Raj; edited Jan 30 at 8:35 by simon.
  • See this or batch train. – aaa, Jan 29 at 13:36
  • (1) You tagged your question with both pytorch and tensorflow, while the code that you show is PyTorch; would a TensorFlow solution also be viable? (2) Consider providing a full example including the instantiation and calling of the WaveletLayer, so that others can reproduce your current results and compare potential solutions (see minimal, reproducible example). – simon, Jan 29 at 13:37
  • @simon Both are fine with me. I understand both frameworks, so anyone with knowledge of either framework can help. – Aryan Raj, Jan 29 at 14:27

1 Answer

Since you only seem to need the Haar wavelet, you can implement it yourself with plain tensor operations that run on the GPU:

  • The high-frequency component of the Haar wavelet along each dimension can be written as a pairwise difference.
  • The low-frequency component of the Haar wavelet along each dimension can be written as a pairwise sum.
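Written out as formulas, the two bullets above combine as follows. In 1D, the orthonormal Haar filters scale each pairwise sum or difference by 1/√2; applying them along both image dimensions gives the closed form below for each 2×2 block. The labels a, b, c, d are introduced here only for illustration, and the sub-band names follow the [LL, LH, HL, HH] order used in the question:

```latex
\begin{aligned}
\text{low}_k &= \tfrac{1}{\sqrt{2}}\,(x_{2k} + x_{2k+1}), &
\text{high}_k &= \tfrac{1}{\sqrt{2}}\,(x_{2k} - x_{2k+1}), \\[4pt]
\mathrm{LL}_{ij} &= \tfrac{1}{2}\,(a + b + c + d), &
\mathrm{LH}_{ij} &= \tfrac{1}{2}\,(a + b - c - d), \\
\mathrm{HL}_{ij} &= \tfrac{1}{2}\,(a - b + c - d), &
\mathrm{HH}_{ij} &= \tfrac{1}{2}\,(a - b - c + d), \\[4pt]
\text{with}\quad a &= x_{2i,\,2j},\; b = x_{2i,\,2j+1}, &
c &= x_{2i+1,\,2j},\; d = x_{2i+1,\,2j+1}.
\end{aligned}
```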

The following code achieves this in pure PyTorch:

import torch
import torch.nn as nn


class HaarWaveletLayer(nn.Module):

    # Each helper first pads an odd-sized dimension by repeating its last row/column.
    def l_0(self, t):  # pairwise sum ("low") of adjacent rows, i.e. along dim -2
        t = torch.cat([t, t[..., -1:, :]], dim=-2) if t.shape[-2] % 2 else t
        return (t[..., ::2, :] + t[..., 1::2, :])
    def l_1(self, t):  # pairwise sum ("low") of adjacent columns, i.e. along dim -1
        t = torch.cat([t, t[..., :, -1:]], dim=-1) if t.shape[-1] % 2 else t
        return (t[..., :, ::2] + t[..., :, 1::2])
    def h_0(self, t):  # pairwise difference ("high") of adjacent rows, along dim -2
        t = torch.cat([t, t[..., -1:, :]], dim=-2) if t.shape[-2] % 2 else t
        return (t[..., ::2, :] - t[..., 1::2, :])
    def h_1(self, t):  # pairwise difference ("high") of adjacent columns, along dim -1
        t = torch.cat([t, t[..., :, -1:]], dim=-1) if t.shape[-1] % 2 else t
        return (t[..., :, ::2] - t[..., :, 1::2])

    def forward(self, x):

        x = .5 * x  # normalization (see the notes below)
        l_1 = self.l_1(x)
        h_1 = self.h_1(x)
        ll = self.l_0(l_1)
        lh = self.h_0(l_1)
        hl = self.l_0(h_1)
        hh = self.h_0(h_1)

        # Same channel layout as WaveletLayer: [LL, LH, HL, HH], C channels each
        return torch.cat([ll, lh, hl, hh], dim=1)
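As an alternative I have not benchmarked, the same four sub-bands can be computed with a single strided `F.conv2d` call using fixed 2×2 Haar kernels, instead of several slicing and concatenation steps. This is a sketch assuming float input of shape (N, C, H, W) and the same [LL, LH, HL, HH] channel order as above; the function name `haar_dwt2_conv` is hypothetical, and whether it is faster than the slicing version will depend on your hardware.

```python
import torch
import torch.nn.functional as F


def haar_dwt2_conv(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one-level 2D Haar transform via a strided convolution.

    Input (N, C, H, W) -> output (N, 4*C, ceil(H/2), ceil(W/2)), with channels
    ordered [LL, LH, HL, HH], each block holding all C input channels.
    """
    n, c, h, w = x.shape
    # For odd sizes, repeat the last row/column (same idea as the torch.cat padding)
    if h % 2 or w % 2:
        x = F.pad(x, (0, w % 2, 0, h % 2), mode="replicate")
    # Fixed 2x2 kernels; 0.5 is the orthonormal normalization (1/sqrt(2) per dimension)
    kernels = 0.5 * torch.tensor(
        [
            [[1.0, 1.0], [1.0, 1.0]],    # LL: sum of the 2x2 block
            [[1.0, 1.0], [-1.0, -1.0]],  # LH: top row minus bottom row
            [[1.0, -1.0], [1.0, -1.0]],  # HL: left column minus right column
            [[1.0, -1.0], [-1.0, 1.0]],  # HH: diagonal difference
        ],
        dtype=x.dtype,
        device=x.device,
    ).unsqueeze(1)  # shape (4, 1, 2, 2): 4 output channels, 1 input channel
    # Treat every channel as its own single-channel image; stride 2 = non-overlapping blocks
    y = F.conv2d(x.reshape(n * c, 1, x.shape[-2], x.shape[-1]), kernels, stride=2)
    # (N*C, 4, H', W') -> (N, C, 4, H', W') -> (N, 4, C, H', W') -> (N, 4*C, H', W')
    y = y.reshape(n, c, 4, y.shape[-2], y.shape[-1]).transpose(1, 2)
    return y.reshape(n, 4 * c, y.shape[-2], y.shape[-1])
```

With the `t` from the comparison snippet, `haar_dwt2_conv(t)` should agree with `HaarWaveletLayer()(t)` up to floating-point rounding, since both compute the same per-block sums and differences.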

In combination with your given code, you can convince yourself of the equivalence as follows:

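import torch
from timeit import Timer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
t = torch.rand((7, 3, 127, 128), device=device)
given, proposed = WaveletLayer(), HaarWaveletLayer()
result_given = given(t)
result_proposed = proposed(t)

# Same result?
assert (result_given - result_proposed).abs().max() < 1e-5

# Time comparison; CUDA kernels run asynchronously, so synchronize
# to make sure the queued GPU work is included in each measurement
def run(layer):
    layer(t)
    if t.is_cuda:
        torch.cuda.synchronize()

num_timings = 100
print("time given:   ", Timer(lambda: run(given)).timeit(num_timings))
print("time proposed:", Timer(lambda: run(proposed)).timeit(num_timings))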

The timing shows a speedup of more than a factor of 10 on my machine.

Notes

  • The t = torch.cat(...) lines are only needed if you want to handle images with an odd height or width. In that case, they pad by replicating the last row or column, respectively, which mimics PyWavelets' default ("symmetric") padding mode for the Haar wavelet.
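  • Multiplying x by 0.5 provides the normalization: the orthonormal Haar filters carry a factor of 1/√2 per dimension, and applying them along both dimensions gives (1/√2)² = 1/2. Compare this discussion on the Signal Processing Stack Exchange for more details.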
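One way to sanity-check the 0.5 normalization is energy preservation: with orthonormal scaling, the sum of squares over the four sub-bands should equal the sum of squares of the input (for even-sized images, where no padding is involved). The helper `haar_subbands` below is hypothetical, handles only even H and W, and exists just for this check:

```python
import torch


def haar_subbands(x: torch.Tensor, scale: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: one-level 2D Haar on even-sized (N, C, H, W) input."""
    a = x[..., 0::2, 0::2]  # top-left of each 2x2 block
    b = x[..., 0::2, 1::2]  # top-right
    c = x[..., 1::2, 0::2]  # bottom-left
    d = x[..., 1::2, 1::2]  # bottom-right
    return scale * torch.cat(
        [a + b + c + d, a + b - c - d, a - b + c - d, a - b - c + d], dim=1
    )


x = torch.rand(2, 3, 8, 8, dtype=torch.float64)
energy_in = x.pow(2).sum()
# With scale 0.5 the energy is preserved; without it, it grows by a factor of 4
print(torch.allclose(haar_subbands(x).pow(2).sum(), energy_in))                 # True
print(torch.allclose(haar_subbands(x, scale=1.0).pow(2).sum(), 4 * energy_in))  # True
```

The factor of 4 in the unscaled case is exactly (1/0.5)², which is why 0.5 is the right normalization here.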
