
I'm working on a PyTorch project with a special parameter that learns conv1d weights. After obtaining these weights, I apply a multiplicative factor that can differ for each item in the batch. My goal is to use the groups argument of the convolution so that each batch item gets its own unique filter.
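The per-item scaling described above can be expressed with broadcasting instead of an explicit repeat. A minimal sketch, assuming one scalar factor per batch item; the sizes and the name scale are hypothetical stand-ins:

```python
import torch

# Hypothetical sizes: batch, input channels, output channels, kernel size
B, C, C_out, K = 4, 3, 8, 5

# Shared learnable conv1d weight (stand-in for the "special parameter")
base_weight = torch.nn.Parameter(torch.randn(C_out, C, K))

# Hypothetical per-sample multiplicative factor: one scalar per batch item
scale = torch.rand(B)

# Broadcast (B, 1, 1, 1) * (1, C_out, C, K) -> (B, C_out, C, K); no explicit repeat needed
per_sample_weight = scale.view(B, 1, 1, 1) * base_weight.unsqueeze(0)
print(per_sample_weight.shape)  # torch.Size([4, 8, 3, 5])

# Gradients still flow back to the single shared parameter
per_sample_weight.sum().backward()
print(base_weight.grad.shape)   # torch.Size([8, 3, 5])
```

The resulting (B, C_out, C, K) tensor can then be flattened to (B*C_out, C, K) for a grouped call, exactly like the repeated weight.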

I know that if I have a time-series input with shape B[atch], C[hannels], T[ime] and create a weight with shape (C_out, C, K_size), I can repeat that weight B times to get a (B*C_out, C, K_size) weight and then apply it to the input with functional conv1d.
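The fold/convolve/unfold sequence can be wrapped in a small helper and checked against an ordinary per-sample loop. A minimal sketch, assuming no padding and stride 1; batched_conv1d is a hypothetical name, not a PyTorch function:

```python
import torch
import torch.nn.functional as F

def batched_conv1d(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convolve each batch item with its own filter bank.

    x:      (B, C, T)
    weight: (B, C_out, C, K), one (C_out, C, K) filter bank per batch item
    returns (B, C_out, T - K + 1)
    """
    B, C, T = x.shape
    _, C_out, _, K = weight.shape
    # Fold batch into channels, batch-major: (1, B*C, T)
    x_folded = x.reshape(1, B * C, T)
    # Stack the filter banks along the output-channel axis: (B*C_out, C, K)
    w_folded = weight.reshape(B * C_out, C, K)
    y = F.conv1d(x_folded, w_folded, groups=B)  # (1, B*C_out, T_out)
    # Unfold back to one row per batch item
    return y.view(B, C_out, -1)

torch.manual_seed(0)
x = torch.randn(2, 3, 10)
w = torch.randn(2, 4, 3, 3)
y = batched_conv1d(x, w)

# Reference: loop over the batch with an ordinary (ungrouped) convolution
ref = torch.stack([F.conv1d(x[b:b + 1], w[b])[0] for b in range(2)])
print(y.shape)  # torch.Size([2, 4, 8])
print(torch.allclose(y, ref, atol=1e-5))
```

The final view works because the grouped call returns its output channels in the same batch-major order the weight was stacked in.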

The issue is that I need to fold the batch dimension into the channel dimension, and I'm not clear on how to do that. In particular, I'm uncertain how to order the channels of the weight tensor. If the weights are arranged as either (C × B) or (B × C), where B is the number of batch items and C is the number of channels per group, and I set groups=B, which ordering keeps the channels properly grouped, without mixing channels from different batch items?
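One way to settle the ordering question empirically is to compare both layouts against a plain per-sample loop. A minimal sketch; the sizes and the random per-sample filters are arbitrary choices made only for the check:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, T, C_out, K = 2, 3, 7, 4, 3
x = torch.randn(B, C, T)
w = torch.randn(B, C_out, C, K)          # a different filter bank per batch item
w_folded = w.reshape(B * C_out, C, K)    # filter banks stacked batch-major

# Ground truth: convolve each batch item on its own, then concatenate channels
ref = torch.cat([F.conv1d(x[b:b + 1], w[b]) for b in range(B)], dim=1)  # (1, B*C_out, T_out)

# (B x C) layout: folded channel index = b*C + c
x_bc = x.reshape(1, B * C, T)
# (C x B) layout: folded channel index = c*B + b
x_cb = x.permute(1, 0, 2).reshape(1, C * B, T)

y_bc = F.conv1d(x_bc, w_folded, groups=B)
y_cb = F.conv1d(x_cb, w_folded, groups=B)

print("B x C matches per-sample conv:", torch.allclose(y_bc, ref, atol=1e-5))
print("C x B matches per-sample conv:", torch.allclose(y_cb, ref, atol=1e-5))
```

With random data, only the (B × C) layout should match, which is consistent with grouped convolution splitting the channel axis into groups contiguous blocks; the (C × B) layout interleaves batch items inside each block.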

Additionally, I've had difficulty finding the parts of the PyTorch source code that handle this ordering, and the docs don't seem to specify it. Any insights on the correct weight arrangement, or pointers to the relevant source code, would be greatly appreciated.
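Short of reading the kernels, the channel partition can be probed directly. The nn.Conv1d documentation describes groups=2 as two convolutions side by side, each seeing half the input channels; this probe checks which half. A minimal sketch with arbitrary sizes: each input channel carries its own index as a constant value and every weight is 1, so each output channel reports the sum of the input-channel indices its group sees.

```python
import torch
import torch.nn.functional as F

groups, C_per_group, T = 2, 3, 4
C_in = groups * C_per_group

# Each input channel is a constant signal equal to its own channel index: (1, C_in, T)
x = torch.arange(C_in, dtype=torch.float32).view(1, C_in, 1).repeat(1, 1, T)

# One output channel per group, kernel size 1, all-ones weights:
# weight shape is (out_channels, in_channels / groups, K) = (groups, C_per_group, 1)
w = torch.ones(groups, C_per_group, 1)

y = F.conv1d(x, w, groups=groups)  # (1, groups, T)
# Contiguous blocks give 0+1+2 = 3 and 3+4+5 = 12; interleaving would give 0+2+4 = 6 first
print(y[0, :, 0])
```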

Below is some dummy example code:

```python
import torch
import torch.nn.functional as F

# Batch, Channels, Time, Output Channels, Kernel Size
B, C, T, C_out, K_size = 2, 3, 5, 4, 3
x = torch.randn(B, C, T)  # (B, C, T)

# Define a filter with shape (C_out, C, K_size) for a single batch item
base_weight = torch.randn(C_out, C, K_size)

# Repeat the filter across the batch: (B, C_out, C, K_size)
weight = base_weight.repeat(B, 1, 1, 1)

# Flatten for grouped convolution: (B*C_out, C, K_size)
weight = weight.reshape(B * C_out, C, K_size)

# Option 1: fold batch into channels, (B, C, T) -> (1, B*C, T)
x_bc = x.reshape(1, B * C, T)

# Option 2: fold batch into channels, (B, C, T) -> (C, B, T) -> (1, C*B, T)
# Starts from the original x; permute() returns a non-contiguous tensor,
# so reshape() is used instead of view()
x_cb = x.permute(1, 0, 2).reshape(1, C * B, T)

# Apply grouped convolution with groups=B: output is (1, B*C_out, T - K_size + 1)
y_bc = F.conv1d(x_bc, weight, groups=B)
y_cb = F.conv1d(x_cb, weight, groups=B)
```

