I am training a diffusion model for an image-to-image translation task. Image size is 128 by 128, and batch size is 8 (due to memory limitations).

This is the model I am using:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_attention=False):
        super(ResidualBlock, self).__init__()
        self.conv_block = ConvBlock(in_channels, out_channels, use_attention=use_attention)
        
        # If input and output channels do not match, add a projection layer to match channels
        if in_channels != out_channels:
            self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.proj = None

    def forward(self, x):
        residual = x
        print(f"Input to ResidualBlock x shape: {x.shape}")
        
        # Apply the convolution block
        x = self.conv_block(x)
        print(f"After ConvBlock x shape: {x.shape}")
        
        # If input and output channels do not match, apply the projection
        if self.proj is not None:
            residual = self.proj(residual)
            print(f"Projected residual shape: {residual.shape}")

        # Add the residual connection
        x = x + residual
        print(f"Output from ResidualBlock x shape: {x.shape}")
        return x


# Attention block using the "Self-Attention" mechanism
class AttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(AttentionBlock, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, height * width)
        attention = torch.bmm(query, key)  # Compute attention map
        attention = F.softmax(attention, dim=-1)
        value = self.value_conv(x).view(batch_size, -1, height * width)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)
        out = self.gamma * out + x
        return out

# Convolution block (Conv + BatchNorm + LeakyReLU)
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_attention=False):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.use_attention = use_attention
        self.attention = AttentionBlock(out_channels) if use_attention else None

    def forward(self, x):
        residual = x

        # Apply convolution, batchnorm, and activation
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        # Apply attention if needed
        if self.use_attention:
            x = self.attention(x)

        # Add the residual connection
        x = x + residual
        return x

# Downsampling block (Convolution + MaxPooling)
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_attention=False):
        super(DownBlock, self).__init__()
        self.residual_block = ResidualBlock(in_channels, out_channels, use_attention=use_attention)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        print(f"Input to DownBlock x shape: {x.shape}")
        x = self.residual_block(x)
        x = self.pool(x)  # MaxPooling
        print(f"Output from DownBlock x shape: {x.shape}")
        return x

# Upsampling block (Transposed Convolution + Attention)
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels, use_attention=False):
        super(UpBlock, self).__init__()
        self.residual_block = ResidualBlock(in_channels + skip_channels, out_channels, use_attention=use_attention)
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x, skip):
        print(f"Input to UpBlock x shape: {x.shape} and skip shape: {skip.shape}")
        x = self.upconv(x)

        # Concatenate the skip connection
        x = torch.cat([x, skip], dim=1)

        # Pass through the residual block
        x = self.residual_block(x)
        print(f"Output from UpBlock x shape: {x.shape}")
        return x

# UNet with Attention and Residual Blocks
class UNetWithAttention(nn.Module):
    def __init__(self, in_channels, out_channels, base_channels=[64, 128, 256, 512], position_encoding_dim=128, timestep_dim=1, use_attention=True):
        super(UNetWithAttention, self).__init__()

        self.timestep_embed_proj = nn.Linear(position_encoding_dim, base_channels[3])  # Match the bottleneck dimension

        # Downsampling path
        self.down1 = DownBlock(in_channels, base_channels[0], use_attention=use_attention)
        self.down2 = DownBlock(base_channels[0], base_channels[1], use_attention=use_attention)
        self.down3 = DownBlock(base_channels[1], base_channels[2], use_attention=use_attention)
        self.down4 = DownBlock(base_channels[2], base_channels[3], use_attention=use_attention)

        # Bottleneck
        self.bottleneck = ConvBlock(base_channels[3], base_channels[3], use_attention=use_attention)

        # Upsampling path
        self.up1 = UpBlock(base_channels[3], base_channels[2], base_channels[2], use_attention=use_attention)
        self.up2 = UpBlock(base_channels[2], base_channels[1], base_channels[1], use_attention=use_attention)
        self.up3 = UpBlock(base_channels[1], base_channels[0], base_channels[0], use_attention=use_attention)
        self.up4 = ConvBlock(base_channels[0], out_channels, use_attention=False)

    def forward(self, x, t=None):
        # Downsampling
        x1, skip1 = self.down1(x)  # Store skip connection here
        x2, skip2 = self.down2(x1)  # Store skip connection here
        x3, skip3 = self.down3(x2)  # Store skip connection here
        x4, skip4 = self.down4(x3)  # Store skip connection here

        # Bottleneck
        bottleneck = self.bottleneck(x4)

        # Incorporate timestep embedding (add or concatenate it, depending on design)
        if t is not None:
            timestep_embed = t.squeeze(1)
            timestep_embed = self.timestep_embed_proj(timestep_embed)
            _, _, h, w = bottleneck.shape
            timestep_embed = timestep_embed.unsqueeze(-1).unsqueeze(-1)
            timestep_embed = timestep_embed.repeat(1, 1, h, w)
            bottleneck = bottleneck + timestep_embed

        # Upsampling
        x = self.up1(bottleneck, skip4)  # Use skip4 as input to up1
        x = self.up2(x, skip3)  # Use skip3 as input to up2
        x = self.up3(x, skip2)  # Use skip2 as input to up3
        x = self.up4(x)  # No skip connection for the final layer

        # Before returning x, apply an upsampling operation (if needed):
        x = F.interpolate(x, size=(img_size, img_size), mode='bilinear', align_corners=False)

        return x

I use this to initialize the model.

unet = UNetWithAttention(in_channels=6, out_channels=3,
                         base_channels=[64, 128, 256, 512],
                         position_encoding_dim=position_encoding_dim,
                         timestep_dim=1,
                         use_attention=True)  # Provide value here
unet.to(device)

This is my timestep encoding function.

def positional_encoding(t, enc_dim):
    """Encode position information with a sinusoid."""
    inv_freq = 1.0 / (10000 ** (torch.arange(0, enc_dim, 2).float() / enc_dim)).to(t.device)
    pos_enc_a = torch.sin(t.unsqueeze(-1) * inv_freq)  # Use broadcasting
    pos_enc_b = torch.cos(t.unsqueeze(-1) * inv_freq)  # Use broadcasting
    pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
    return pos_enc

I use this to prepare the data, and my position_encoding_dim is 256.

def prepare_data(input_image, target_image, device=device):
    batch_size = input_image.shape[0]
    input_image = input_image.to(device)
    target_image = target_image.to(device)
    t = torch.randint(low=0, high=noise_steps, size=(batch_size,)).to(device)
    x_t, noise = diffusion.forward_diffusion(target_image, t)
    x_t = torch.cat((input_image, x_t), dim=1)
    t = positional_encoding(t.unsqueeze(1), position_encoding_dim)

    return x_t.to(device), t.to(device), noise.to(device)

I run this code to train the model.

for epoch in range(epochs):
    start_time = time.time()
    num_batches = len(train_loader)
    print("\n" + f"Epoch {epoch + 1}/{epochs}" + "\n" + "_" * 10)

    unet.train()
    running_loss = 0.0

    for batch_idx, (input_images, target_images) in enumerate(train_loader, start=0):
        x_t, t, noise = prepare_data(input_images, target_images)
        outputs = unet(x=x_t, t=t)
        optimizer.zero_grad()
        loss = criterion(outputs, noise)
        loss.backward()
        optimizer.step()

I keep getting this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-25-1c932665f6dc> in <cell line: 0>()
      9     for batch_idx, (input_images, target_images) in enumerate(train_loader, start=0):
     10         x_t, t, noise = prepare_data(input_images, target_images)
---> 11         outputs = unet(x=x_t, t=t)
     12         optimizer.zero_grad()
     13         loss = criterion(outputs, noise)

11 frames
<ipython-input-16-2b9fa24bdc2f> in forward(self, x)
     73 
     74         # Add the residual connection
---> 75         x = x + residual
     76         return x
     77 

RuntimeError: The size of tensor a (64) must match the size of tensor b (6) at non-singleton dimension 1

I don't know how to fix it.
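
For reference, the error can be reproduced with a single forward pass on dummy tensors (a minimal sketch assuming the same shapes as my real data and a noise schedule of roughly 1000 steps):

# Hypothetical minimal reproduction using random tensors with my data's shapes.
x_t = torch.randn(8, 6, 128, 128).to(device)                    # 3 input + 3 noisy target channels
t = positional_encoding(torch.randint(0, 1000, (8, 1)).to(device), 256)
outputs = unet(x=x_t, t=t)                                       # raises the RuntimeError above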

I added print statements to inspect the tensor shapes at different points in the model so I can find where the problem is, but I still have not been able to pinpoint it. How can I fix it, please?

2 Answers

Maybe your prepare_data() function is not working correctly? For example, if diffusion.forward_diffusion(target_image, t) is not behaving as expected, the x_t variable may not hold the values you think it does, which would lead to a tensor size mismatch later on.
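
A quick way to rule this out (a minimal sanity-check sketch, assuming the prepare_data function and training loop from the question, and that forward_diffusion returns a noisy image with the same shape as the target) is to inspect one batch before training:

# Hypothetical shape check on a single batch before starting the training loop.
input_images, target_images = next(iter(train_loader))
x_t, t, noise = prepare_data(input_images, target_images)
print(x_t.shape, t.shape, noise.shape)   # expected: (8, 6, 128, 128), (8, 1, 256), (8, 3, 128, 128)
assert x_t.shape[1] == 6, f"expected 6 channels after concatenation, got {x_t.shape[1]}"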

To use a residual block, the input tensor has to have the same shape as the output; otherwise you won't be able to add them, which is exactly what you are seeing. In your case, the convolution in your first ConvBlock changes the number of channels from 6 to 64, which produces the error you see.

Within, e.g., the ResNet available in torchvision, each residual layer whose input and output channel sizes do not match adds a downsample step (in this case another convolutional layer with a 1x1 kernel) to the block so that the shape of the input matches that of the output (see here). You'll need to add a similar "downsample" (although, despite the name, in your case it would actually be increasing the number of channels of the input).
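
As a rough sketch of that fix (not your exact architecture, and ignoring the attention branch for brevity), the ConvBlock could project the residual with a 1x1 convolution whenever the channel counts differ, in the same spirit as torchvision's downsample layer:

import torch
import torch.nn as nn

# Sketch of a ConvBlock whose residual add works when in_channels != out_channels.
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        # Project the skip path to out_channels when the counts differ, otherwise identity.
        self.proj = (nn.Conv2d(in_channels, out_channels, kernel_size=1)
                     if in_channels != out_channels else nn.Identity())

    def forward(self, x):
        residual = self.proj(x)                  # (B, out_channels, H, W)
        x = self.relu(self.bn(self.conv(x)))     # (B, out_channels, H, W)
        return x + residual                      # shapes now match

With that change, the first ConvBlock can accept the 6-channel input and still add a (projected) residual to its 64-channel output; the same pattern applies anywhere else in the network where the channel count changes across a residual add.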
