I'm working on a multi-label classification problem using an autoencoder-based neural network built in PyTorch. The overall idea of my approach is as follows:

I load my dataset from a CSV file, perform cleaning (dropping missing values and unnecessary columns), and select a subset of features. I split the data into features (X) and multi-label targets (y) — where each target is binary (0 or 1). I then apply a group-based train/test split using GroupShuffleSplit to ensure that related samples remain together, which is needed because of the nature of my problem. Finally, I scale the features using StandardScaler.
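For context, the preprocessing looks roughly like this (the file name, label/group column names, and split size below are placeholders rather than my real ones):

import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import StandardScaler

# Load and clean (placeholder file/column names)
df = pd.read_csv("mutations.csv")
df = df.dropna().drop(columns=["sample_id"])

label_cols = ["label_a", "label_b", "label_c", "label_d"]   # four binary targets
feature_cols = [c for c in df.columns if c not in label_cols + ["group_id"]]

X = df[feature_cols].to_numpy(dtype="float32")
y = df[label_cols].to_numpy(dtype="float32")
groups = df["group_id"].to_numpy()

# Group-aware split so related samples stay on the same side
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(gss.split(X, y, groups=groups))
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# Fit the scaler on the training split only
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)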

I couldn't find many examples of this kind of approach, but so far this is what I did:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


class MutationDataset(Dataset):
    def __init__(self, features, labels):
        # Store everything as float32 tensors so dtypes match the model weights
        self.features = torch.as_tensor(features, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.float32)
        
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# Architecture of my model

class AutoencoderClassifier(nn.Module):
    def __init__(self, input_dim, latent_dim=32, dropout_rate=0.2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, latent_dim),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim)
        )
        # Classifier head outputs raw logits
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 4)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        logits = self.classifier(encoded)
        return decoded, logits

My classes are highly imbalanced (with the negative class dominating). To address this, I tried using Focal Loss with additional positive class weighting.

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=3, reduction='mean', pos_weight=None):
        """
        Args:
            alpha (float or Tensor): Weighting factor for the rare class. You can also pass a tensor for per-class weighting.
            gamma (float): Focusing parameter. Higher values put more focus on hard, misclassified examples.
            reduction (str): 'mean', 'sum', or 'none'
            pos_weight (Tensor, optional): Per-label weight for positive targets, forwarded to
                binary_cross_entropy_with_logits.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.pos_weight = pos_weight

    def forward(self, logits, targets):
        # Per-element BCE with logits; pos_weight up-weights the positive class of each label
        BCE_loss = F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none', pos_weight=self.pos_weight
        )
        # Probability assigned to the target class (approximate when pos_weight is set)
        pt = torch.exp(-BCE_loss)
        # Compute the focal loss scaling factor
        focal_term = (1 - pt) ** self.gamma
        loss = self.alpha * focal_term * BCE_loss
        
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

Finally, this is how I train the model:


def train_model(X_train, y_train, X_test, y_test, input_dim, 
                device='cuda', learning_rate=0.001, batch_size=32, epochs=50, dropout_rate=0.2):
    train_dataset = MutationDataset(X_train, y_train)
    test_dataset = MutationDataset(X_test, y_test)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    model = AutoencoderClassifier(input_dim=input_dim, dropout_rate=dropout_rate)
    model.to(device)
    
    # Compute positive class weights per label
    pos_weights = []
    for i in range(y_train.shape[1]):
        pos_count = np.sum(y_train[:, i])
        neg_count = len(y_train) - pos_count
        weight = neg_count / pos_count if pos_count != 0 else 1.0
        pos_weights.append(weight)
    pos_weights = torch.tensor(pos_weights, dtype=torch.float32, device=device)
    print("Pos weights:", pos_weights)
    
    recon_criterion = nn.MSELoss()
    # Pass the per-label positive weights into the loss so they are actually applied
    class_criterion = FocalLoss(alpha=1, gamma=2, reduction='mean', pos_weight=pos_weights)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            optimizer.zero_grad()
            recon, logits = model(batch_x)
            loss = 0.7 * recon_criterion(recon, batch_x) + 0.3 * class_criterion(logits, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch+1}: Loss {total_loss/len(train_loader):.4f}')
    
    return model, test_loader
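
I then call it roughly like this (input_dim is just the number of scaled features):

# Falls back to CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, test_loader = train_model(
    X_train, y_train, X_test, y_test,
    input_dim=X_train.shape[1],
    device=device,
)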

Needless to say, the performance is quite awful. The model's loss decreases steadily during training, from around 0.20 at the first epoch down to about 0.03 by the 50th epoch. On the training data, accuracy is relatively high (around 80–91% for most labels). However, the F1 scores vary: some labels reach moderate F1 scores (e.g., around 0.60 for one label), while others are much lower (or even zero). Test accuracy remains reasonably high for some labels (roughly 77–87%), but F1 scores drop significantly compared to training. For one of the labels, the model fails to predict any positive cases (F1 score of 0 on both training and test sets).
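
For completeness, the per-label numbers above are computed with an evaluation loop along these lines (sigmoid on the logits, thresholded at 0.5):

from sklearn.metrics import accuracy_score, f1_score

model.eval()
all_preds, all_targets = [], []
with torch.no_grad():
    for batch_x, batch_y in test_loader:
        _, logits = model(batch_x.to(device))
        all_preds.append((torch.sigmoid(logits) > 0.5).int().cpu())
        all_targets.append(batch_y.int())
all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()

for i in range(all_targets.shape[1]):
    acc = accuracy_score(all_targets[:, i], all_preds[:, i])
    f1 = f1_score(all_targets[:, i], all_preds[:, i], zero_division=0)
    print(f"Label {i}: accuracy={acc:.2f}, F1={f1:.2f}")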

There are a few things I can try and tweak here and there, but it feels like I might be approaching this wrongly on a more general note. What am I missing?
