Pytorch LSTM-VAE not able to learn

I'm having trouble getting an LSTM-VAE to work for anomaly detection on multivariate signals of variable duration. I found some information in this forum and in the original papers about good practices to apply. Even though I took the usual pitfalls into account (data scaling, differences between GPU and CPU computation, using MSE(reduction='sum'), using cyclical KL annealing, ...), my LSTM-VAE only ever learns a flat curve. I built an LSTM-VAE with optional bidirectionality and a configurable number of layers. Whatever parameters I try, the result is always the same: a flat curve. Have you ever run into this problem?
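
(For illustration, this is the kind of per-feature scaling I mean — a minimal sketch using scikit-learn's MinMaxScaler, where train_cycles is just a placeholder name for my list of variable-length cycles, not my exact preprocessing code:)

import numpy as np
from sklearn.preprocessing import MinMaxScaler

# train_cycles: list of arrays, each of shape (seq_len_i, n_features)
scaler = MinMaxScaler()
scaler.fit(np.concatenate(train_cycles, axis=0))   # fit on training data only
train_cycles_scaled = [scaler.transform(c) for c in train_cycles]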

Here is my model:

import torch
import torch.nn as nn


class LSTMVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, num_layers=1, dropout=0.1, use_residual=False,
                 bidirectional=False, device="cpu"):
        super(LSTMVAE, self).__init__()
        self.device = torch.device(device)
        self.use_residual = use_residual
        self.bidirectional = bidirectional

        # Encoder
        self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, dropout=dropout,
                                    bidirectional=self.bidirectional, batch_first=True)
        self.encoder_layernorm = nn.LayerNorm(hidden_dim * (2 if bidirectional else 1))
        self.fc_mean = nn.Linear(hidden_dim * (2 if bidirectional else 1), latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim * (2 if bidirectional else 1), latent_dim)

        # Decoder
        self.decoder_lstm = nn.LSTM(latent_dim, hidden_dim, num_layers=num_layers, dropout=dropout,
                                    bidirectional=self.bidirectional, batch_first=True)
        self.decoder_layernorm = nn.LayerNorm(hidden_dim * (2 if bidirectional else 1))
        self.fc_output = nn.Linear(hidden_dim * (2 if bidirectional else 1), input_dim)

        # Weight initialization (method not shown here)
        self._initialize_weights()

        # Move model to device
        self.to(self.device)

    def encode(self, x):
        _, (hidden, _) = self.encoder_lstm(x)
        if self.bidirectional:
            # Concatenate the final hidden states from both directions
            hidden = torch.cat((hidden[-2], hidden[-1]), dim=-1)
        else:
            hidden = hidden[-1]  # Hidden state of the last layer
        hidden = self.encoder_layernorm(hidden)
        mean = self.fc_mean(hidden)
        logvar = self.fc_logvar(hidden)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z, seq_len):
        # Repeat the latent vector along the time axis to feed the decoder LSTM
        z = z.unsqueeze(1).repeat(1, seq_len, 1)
        output, _ = self.decoder_lstm(z)
        if self.bidirectional:
            output = output.view(output.size(0), output.size(1), -1)  # Concatenate the states from both directions
        output = self.decoder_layernorm(output)
        output = self.fc_output(output)
        return output

    def forward(self, x):
        seq_len = x.size(1)
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        recon_x = self.decode(z, seq_len)
        if self.use_residual:
            recon_x = recon_x + x
        return recon_x, mean, logvar
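
For completeness, this is how I call it; the dimensions below are just placeholders for a quick shape check, not my real configuration (and it assumes the _initialize_weights method omitted above is defined):

import torch

model = LSTMVAE(input_dim=8, hidden_dim=64, latent_dim=16, num_layers=1,
                bidirectional=False, device="cpu")
dummy = torch.randn(4, 120, 8)                 # (batch, seq_len, input_dim)
recon, mean, logvar = model(dummy)
print(recon.shape, mean.shape, logvar.shape)   # (4, 120, 8) (4, 16) (4, 16)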

My training loop is the following:

total_loss = 0
step = 0
recon_errors_list = []  # per-batch reconstruction errors, for logging
for epoch in range(cfg.lstm.epoch):
    model.train()
    for cycle, id_cycle, activity, _ in train_loader:
        cycle = cycle.to(device)
        optimizer.zero_grad()
        recon_cycle, mean, logvar = model(cycle)
        recon_error = nn.MSELoss()(recon_cycle, cycle).item()
        logger.debug(
            f'Cycles: {id_cycle.item()}, activities: {activity}, seq_lens: {cycle.shape[1]}, recon_error: {recon_error}')
        recon_errors_list.append(recon_error)
        loss = loss_function(recon_cycle, cycle, mean, logvar, anneal_function=cfg.lstm.kl_anneal_function,
                             step=step, total_steps=total_steps, cycles=cfg.lstm.kl_anneal_cycles,
                             ratio=cfg.lstm.kl_anneal_ratio)
        loss.backward()
        # Gradient clipping
        if cfg.lstm.clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # Log parameter statistics whenever a gradient norm is exactly zero
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                if grad_norm == 0:
                    logger.warning(f"Zero gradient for {name}")
                    logger.warning("Parameter stats:")
                    logger.warning(f"  Mean: {param.data.mean().item():.6f}")
                    logger.warning(f"  Std: {param.data.std().item():.6f}")
                    logger.warning(f"  Max: {param.data.max().item():.6f}")
                    logger.warning(f"Parameter requires grad: {param.requires_grad}")
        optimizer.step()
        total_loss += loss.item()
        step += 1
    scheduler.step()
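
For reference, loss_function itself is not shown above; a minimal version consistent with the call signature (MSE with reduction='sum' plus a KL term weighted by a cyclical annealing schedule) would look roughly like this — a sketch, not my exact implementation:

import torch
import torch.nn.functional as F

def kl_anneal_weight(anneal_function, step, total_steps, cycles=4, ratio=0.5):
    # Annealing weight beta in [0, 1]; 'cyclical' splits training into `cycles`
    # periods, ramps beta up during the first `ratio` fraction of each period,
    # then holds it at 1.0 until the next period starts.
    if anneal_function == "linear":
        return min(1.0, step / (total_steps * ratio))
    period = total_steps / cycles
    tau = (step % period) / period
    return min(1.0, tau / ratio)

def loss_function(recon_x, x, mean, logvar, anneal_function="cyclical",
                  step=0, total_steps=1, cycles=4, ratio=0.5):
    recon_loss = F.mse_loss(recon_x, x, reduction="sum")
    kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    beta = kl_anneal_weight(anneal_function, step, total_steps, cycles, ratio)
    return recon_loss + beta * kld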

And the final results: [plot of a reconstructed cycle — the model output is a flat curve]

If you have experience with LSTM-VAEs and some tricks to suggest, thanks in advance!

Regards

Rémy
