I have created this data:
HeteroData(
user={ x=[100, 16] },
keyword={ x=[321, 16] },
tweet={ x=[1000, 16] },
(user, follow, user)={ edge_index=[2, 291] },
(user, tweetedby, tweet)={ edge_index=[2, 1000] },
(keyword, haskeyword, tweet)={ edge_index=[2, 3752] }
)
And these two variables derived from it:
x_dict:
user: torch.Size([100, 16])
keyword: torch.Size([321, 16])
tweet: torch.Size([1000, 16])
edge_index_dict:
('user', 'follow', 'user'): torch.Size([2, 291])
('user', 'tweetedby', 'tweet'): torch.Size([2, 1000])
('keyword', 'haskeyword', 'tweet'): torch.Size([2, 3752])
My nodes were indexed from 0 to 99 for users, from 100 to 420 for keywords, and from 421 onward for tweet nodes.
When I run this model:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import HeteroConv, GATConv

class HeteroGATBinaryClassifier(torch.nn.Module):
    def __init__(self, metadata, in_channels, hidden_channels, heads=1):
        super().__init__()
        self.metadata = metadata  # (node_types, edge_types)
        # One GATConv per edge type, combined by HeteroConv
        self.conv1 = HeteroConv({
            edge_type: GATConv(in_channels, hidden_channels, heads=heads, add_self_loops=False)
            for edge_type in metadata[1]
        }, aggr='mean')  # aggregate per-edge-type outputs with the mean
        self.conv2 = HeteroConv({
            edge_type: GATConv(hidden_channels * heads, hidden_channels, heads=heads, add_self_loops=False)
            for edge_type in metadata[1]
        }, aggr='mean')
        # Linear layer producing a single logit for binary classification
        self.classifier = Linear(hidden_channels * heads, 1)

    def forward(self, x_dict, edge_index_dict, target_node_type):
        # First GAT layer
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.elu(x) for key, x in x_dict.items()}
        # Second GAT layer
        x_dict = self.conv2(x_dict, edge_index_dict)
        x_dict = {key: F.elu(x) for key, x in x_dict.items()}
        # Apply the classifier only to the target node type
        logits = self.classifier(x_dict[target_node_type])
        return torch.sigmoid(logits)  # probabilities for binary classification
But I see the following error:
IndexError: Found indices in 'edge_index' that are larger than 999 (got 1420). Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 1000) in your node feature matrix and try again.
Do you have any idea how to handle it?
asked Dec 10, 2024 at 0:42 by aliiiiiiiiiiiiiiiiiiiii, edited Mar 10 at 10:16 by daqh · 1 Answer
Of course you are having trouble: there is an index mismatch between the node feature matrices and the edge_index.
The edge index must be a tensor with shape (2, number_of_edges) and with values < num_nodes.
Each column of the edge index represents one edge, and its values are used to index the node feature matrix during the convolution.
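A toy example of that shape contract, with illustrative tensors (three nodes, two edges):

```python
import torch

x = torch.randn(3, 16)              # 3 nodes, 16 features each
edge_index = torch.tensor([[0, 1],  # source node of each edge
                           [1, 2]]) # destination node: edges 0->1 and 1->2

assert edge_index.shape == (2, 2)        # (2, number_of_edges)
assert int(edge_index.max()) < x.size(0) # every index must address a row of x
```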
Probably, in the program you are running, you have 1000 tweet nodes, but the edge indices are not aligned with them: either node features were removed without updating the edge index, or the edge index uses a numbering that the node feature matrices don't. In a PyG HeteroData object, each node type is indexed independently starting from 0, so the global numbering described in the question (tweets at 421 to 1420) produces indices outside [0, 1000), which is exactly the 1420 in the error message.
It is very important that the indices in the edge index are aligned and consistent with the node features; if they are not, you must either add an offset or normalize the edge indices, depending on what your issue is.
I usually do something like this on dim 0 or 1 to normalize the src or dst of the edge index:
_, edge_index[0] = torch.unique(edge_index[0], return_inverse=True)
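Alternatively, since the question states the nodes were numbered globally (users 0–99, keywords 100–420, tweets 421–1420) while HeteroData expects each node type to start at 0, a minimal sketch of the fix is to subtract each type's starting offset from the matching row of every edge_index. The offsets come from the question; the edge tensor below is illustrative:

```python
import torch

# Global index at which each node type starts (from the question)
offsets = {'user': 0, 'keyword': 100, 'tweet': 421}

# Illustrative (keyword, haskeyword, tweet) edges in GLOBAL numbering
edge_index = torch.tensor([[100, 250, 420],    # keyword ids in 100..420
                           [421, 500, 1420]])  # tweet ids in 421..1420

src_type, dst_type = 'keyword', 'tweet'
edge_index[0] -= offsets[src_type]  # keywords now local, in [0, 321)
edge_index[1] -= offsets[dst_type]  # tweets now local, in [0, 1000)

print(edge_index.tolist())  # [[0, 150, 320], [0, 79, 999]]
```

Applying this to every entry of edge_index_dict (with the src/dst types taken from each edge-type triple) keeps all indices within the corresponding feature matrices.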
python - How to handle heterogenous GNN? - Stack Overflow
edge_index tensors are supposed to contain pairs of node indices denoting which two nodes are connected by an edge. This means all edge_index values should be less than the number of nodes. Your edge index has values outside the node count, i.e. pointing to nodes that don't exist. It requires more information to say why that is happening – Karl Commented Dec 10, 2024 at 6:40
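One way to locate which edge types contain the out-of-range indices the comment describes is to check every edge_index row against the size of its node type's feature matrix. A sketch, with illustrative tensors mirroring two of the question's shapes:

```python
import torch

# Illustrative dicts mirroring the question's 'user' and 'tweet' shapes
x_dict = {'user': torch.randn(100, 16), 'tweet': torch.randn(1000, 16)}
edge_index_dict = {
    # Bad on purpose: 1420 is a global tweet id, out of range for [0, 1000)
    ('user', 'tweetedby', 'tweet'): torch.tensor([[0, 99], [421, 1420]]),
}

for (src, rel, dst), ei in edge_index_dict.items():
    bad_src = int((ei[0] >= x_dict[src].size(0)).sum())  # invalid source ids
    bad_dst = int((ei[1] >= x_dict[dst].size(0)).sum())  # invalid destination ids
    print(f'{src} -{rel}-> {dst}: {bad_src} bad src, {bad_dst} bad dst')
# prints: user -tweetedby-> tweet: 0 bad src, 1 bad dst
```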