Assume num_classes = batch_size = N

mention_embeddings = [M, dim] 

class_index = [M, 1] # (already sorted by the number of mentions in each class)
# E.g. [0, 0, 0, 1, 1, 2, 3, 4, ...] means the first 3 mentions share the same class, and so on.

# Note: max(class_index) = N - 1 as we have N classes and M >= N 
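
For concreteness, here is a minimal toy setup I will reuse below; the sizes M = 8, N = 5, dim = 4 and the index values are just assumptions matching the example above, and class_index is kept 1-D for simpler indexing:

import torch

M, N, dim = 8, 5, 4                                   # assumed toy sizes
mention_embeddings = torch.randn(M, dim)
class_index = torch.tensor([0, 0, 0, 1, 1, 2, 3, 4])  # sorted by class, max = N - 1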

Besides that, I have:

positive_labels # has shape of [N, 1]
negative_labels # has shape of [N, 1]

Here is the tricky part: positive_labels does not have the same shape as mention_embeddings, but you can infer the positive label of each mention by duplicate indexing (fancy indexing):

positive_labels = positive_labels[class_index.view(-1)]  # [N, 1] -> [M, 1] (flatten the [M, 1] index to 1-D first)
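
A minimal sketch of this duplicate indexing on the toy tensors above (mention_positive is just a hypothetical name to avoid overwriting positive_labels):

positive_labels = torch.randn(N, 1)              # one label per class
mention_positive = positive_labels[class_index]  # rows duplicated per mention
assert mention_positive.shape == (M, 1)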

Each class might have multiple clusters. After detecting the clusters of each class, I have:

cluster_index = [M, 1] # an array indicating which mentions are in the same cluster.
# E.g. [0, 1, 0, 2, 3, 3, ...] means the mentions at indices 0 and 2 are in the same cluster, and so on.

num_clusters = torch.unique(cluster_index).numel()  # K clusters, with N <= K <= M
cluster_embeddings = torch.zeros((num_clusters,
                                  mention_embeddings.size(1)))

# sum the mention embeddings of each cluster in place (index_add_ needs a 1-D index)
cluster_embeddings.index_add_(0, cluster_index.view(-1), mention_embeddings)

# cluster_embeddings shape: [K, dim]
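
Continuing the toy example with an assumed cluster_index that nests clusters inside classes (K = 7 here, so N <= K <= M holds):

cluster_index = torch.tensor([0, 1, 0, 2, 3, 4, 5, 6])  # class 0 -> clusters {0, 1}, class 1 -> {2, 3}, ...
K = torch.unique(cluster_index).numel()                 # 7
cluster_embeddings = torch.zeros(K, dim).index_add_(0, cluster_index, mention_embeddings)
# cluster 0 is the sum of the embeddings of mentions 0 and 2
assert torch.allclose(cluster_embeddings[0], mention_embeddings[0] + mention_embeddings[2])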

Question: how do I create the corresponding positive_cluster_labels and negative_cluster_labels, which should have shape [K, 1]?


My current approach:

Create a mapping from each cluster label to a mention index of that cluster (I keep only the first index of each cluster; see the explanation below):

cluster_label_index_mapping = {
    unique_label: torch.tensor([idx
                                for idx, label in enumerate(cluster_index)
                                if label == unique_label][0])
    for unique_label in torch.unique(cluster_index)}


# => a dict mapping each cluster label to the first mention index of that cluster, e.g.
# {0: tensor(0),
#  1: tensor(1),
#  2: tensor(3), ...
# }

Because clusters are created inside a class, all mentions in a cluster share the same positive and negative labels, so I just take one representative mention index per cluster. E.g. for cluster 0, I only need the positive and negative labels of mention index 0:

cluster_label_index = torch.cat([item.view(-1)
                                 for item in cluster_label_index_mapping.values()])
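
On the toy data above, and assuming torch.unique iterates over the labels in ascending order, this should pick one representative mention index per cluster:

# cluster label:        0  1  2  3  4  5  6
# first mention index:  0  1  3  4  5  6  7
# cluster_label_index -> tensor([0, 1, 3, 4, 5, 6, 7]), shape [K]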

But note that the mentions have length M while the original positive and negative labels have length N, so I map twice:

positive_cluster_labels = positive_labels[class_index.view(-1)]
# duplicate some rows: [N, 1] -> [M, 1]
positive_cluster_labels = positive_cluster_labels[cluster_label_index]
# choose K of the M rows: [M, 1] -> [K, 1]
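
With the toy data, a quick sanity check that the two-step mapping lands on the right shape:

assert positive_cluster_labels.shape == (K, 1)  # K = 7 in the toy example
# row i should be the positive label of the class that cluster i belongs to;
# the same two steps apply to negative_labels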

I'm just worried about whether the order stays consistent between the cluster embeddings and their corresponding positive and negative labels.

I can see that in this approach I have to name the cluster labels with indices from 0 to K-1, so that when I compute the cluster embeddings with the index_add_ method, they are aggregated by the indices 0 to K-1 in cluster_index.

Also, this line of code:

cluster_label_index = torch.cat([item.view(-1)
                                 for item in cluster_label_index_mapping.values()])

I'm not sure whether iterating over torch.unique(cluster_index) always gives me the same ordering from 0 to K-1, or whether I have to sort it first.
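
A quick empirical check of that assumption on a toy tensor (this only shows the behavior of my installed version, not a guarantee):

print(torch.unique(torch.tensor([3, 0, 2, 0, 1])))  # e.g. tensor([0, 1, 2, 3])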

This approach is quite long and not straightforward :) but by using class_index to indicate the class of each mention (the number of mentions per class differs, so I cannot store them in a 2D tensor), I save a lot of memory compared to padding, and avoid a lot of complication when calculating the cluster embeddings. What do you think? Do you have any other ideas?
