Describe the bug
In rl4co/models/nn/graph/gcn.py, the GCNEncoder class processes batched node embeddings without adjusting the edge_index for each graph in the batch. This leads to incorrect message passing in GCNConv, where nodes from different graphs may be incorrectly connected or ignored, resulting in unintended behavior of the model.
Issue in Code:
# In rl4co/models/nn/graph/gcn.py, inside the GCNEncoder class
def forward(self, td):
    # Transfer to embedding space
    init_h = self.init_embedding(td)
    bs, num_nodes, emb_dim = init_h.shape
    # (bs * num_nodes, emb_dim)
    update_node_feature = init_h.reshape(-1, emb_dim)  # Flatten batch
    # shape = (2, num_edges)
    edge_index = self.edge_idx_fn(td, num_nodes)  # Edge index for a single graph
    for layer in self.gcn_layers[:-1]:
        update_node_feature = layer(update_node_feature, edge_index)
    # ...
Problem: edge_index is generated for a single graph and is not adjusted for batching.
Consequence: When node features from multiple graphs are concatenated, GCNConv processes the nodes incorrectly, only properly updating the nodes of the first graph.
To Reproduce
Minimal Example to Reproduce the Bug:
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from functools import lru_cache

# Function to generate the full graph edge index (as in the original code)
@lru_cache(5)
def get_full_graph_edge_index(num_node: int, self_loop=False) -> torch.Tensor:
    adj_matrix = torch.ones(num_node, num_node)
    if not self_loop:
        adj_matrix.fill_diagonal_(0)
    edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0))
    return edge_index

# Parameters
num_nodes_per_graph = 5
num_node_features = 3
batch_size = 2  # Number of graphs in the batch

# Create node features for two graphs
x1 = torch.randn(num_nodes_per_graph, num_node_features)
x2 = torch.randn(num_nodes_per_graph, num_node_features)

# Concatenate node features to simulate batching
x = torch.cat([x1, x2], dim=0)  # Shape: [batch_size * num_nodes_per_graph, num_node_features]

# Generate edge_index using the original function
edge_index_single = get_full_graph_edge_index(num_nodes_per_graph, self_loop=False)

# Use the same edge_index for both graphs without adjusting
edge_index = edge_index_single  # Edge indices from graph 1, node indices 0 to 4

# Initialize GCNConv
conv = GCNConv(in_channels=num_node_features, out_channels=2, add_self_loops=False)

# Apply GCNConv without adjusting edge_index
out = conv(x, edge_index)

# Print the output
print("Output without adjusting edge_index:")
print(out)

# Now adjust edge_index for each graph to reflect batching
edge_indices = []
for i in range(batch_size):
    offset = i * num_nodes_per_graph
    adjusted_edge_index = edge_index_single + offset
    edge_indices.append(adjusted_edge_index)
edge_index_adjusted = torch.cat(edge_indices, dim=1)

# Apply GCNConv with adjusted edge_index
out_adjusted = conv(x, edge_index_adjusted)

# Compare outputs for nodes belonging to each graph
print("\nDifference in outputs for nodes belonging to the first graph:")
print(out[:num_nodes_per_graph] - out_adjusted[:num_nodes_per_graph])
print("\nDifference in outputs for nodes belonging to the second graph:")
print(out[num_nodes_per_graph:] - out_adjusted[num_nodes_per_graph:])
Explanation:
Edge Index Generation: We use the get_full_graph_edge_index function from the original code to generate the edge indices for a fully connected graph without self-loops.
Without Adjusted edge_index:
The edge_index only references nodes 0 to 4.
Nodes 5 to 9 (from the second graph) are not properly connected and thus not processed correctly by GCNConv.
With Adjusted edge_index:
By adding an offset to edge_index for each graph in the batch, nodes are correctly connected within their own graphs.
The outputs for nodes in the second graph are now properly updated based on their graph structure.
Observation:
The differences in outputs for nodes belonging to the second graph indicate that without adjusting edge_index, those nodes were not processed correctly.
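As an extra sanity check (a small sketch, not part of the original report, under the same assumptions as the script above: fully connected 5-node graphs, batch size 2), the index ranges can be inspected directly; the sketch also shows a vectorized way to apply the per-graph offsets instead of the Python loop:
import torch

num_nodes_per_graph, batch_size = 5, 2
adj = torch.ones(num_nodes_per_graph, num_nodes_per_graph)
adj.fill_diagonal_(0)
edge_index_single = torch.nonzero(adj).T        # shape (2, 20), node indices 0..4
num_edges = edge_index_single.shape[1]

# Tile the single-graph edges and shift each copy by its graph's node offset
offsets = torch.arange(batch_size).repeat_interleave(num_edges) * num_nodes_per_graph
edge_index_batched = edge_index_single.repeat(1, batch_size) + offsets

print(edge_index_single.max().item())   # 4 -> nodes 5..9 of the second graph are never referenced
print(edge_index_batched.max().item())  # 9 -> nodes of both graphs are referenced
With the unadjusted edge_index, the maximum referenced node index is 4, so the second graph's nodes never receive any messages.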
Additional context
The issue arises because in the GCNEncoder, the edge_index is generated once for a single graph and used for all graphs in the batch without adjustment. This causes incorrect node indexing when processing batched graphs, leading to unintended behavior.
Reason and Possible fixes
Reason:
The edge_index is not adjusted for batched graphs, causing nodes from different graphs to be incorrectly connected or not processed correctly.
In batched graphs, node indices in edge_index need to be offset for each graph to reflect their positions in the concatenated node feature tensor.
Possible Fix:
Adjust the edge_index for each graph in the batch by adding an offset based on the cumulative number of nodes. Here's a modification to the GCNEncoder class:
def forward(self, td):
    init_h = self.init_embedding(td)
    bs, num_nodes, emb_dim = init_h.shape
    update_node_feature = init_h.reshape(-1, emb_dim)

    # Original edge_index for a single graph
    edge_index_single = self.edge_idx_fn(td, num_nodes)

    # Adjust edge_index for batching
    edge_indices = []
    for i in range(bs):
        offset = i * num_nodes
        adjusted_edge_index = edge_index_single + offset
        edge_indices.append(adjusted_edge_index)
    edge_index = torch.cat(edge_indices, dim=1).to(td.device)

    # Proceed with GCN layers using the adjusted edge_index
    for layer in self.gcn_layers[:-1]:
        update_node_feature = layer(update_node_feature, edge_index)
        update_node_feature = F.relu(update_node_feature)
        update_node_feature = F.dropout(
            update_node_feature, training=self.training, p=self.dropout
        )

    # Last layer without activation and dropout
    update_node_feature = self.gcn_layers[-1](update_node_feature, edge_index)

    # De-batch the graph
    update_node_feature = update_node_feature.view(bs, num_nodes, emb_dim)

    # Residual connection
    if self.residual:
        update_node_feature = update_node_feature + init_h

    return update_node_feature, init_h
Explanation:
Adjustment of edge_index:
For each graph in the batch, we add an offset to the edge_index so that node indices correctly map to their positions in the concatenated update_node_feature.
Benefit:
This ensures that nodes are only connected within their own graph, preventing unintended cross-graph connections.
GCNConv can now process each graph correctly within the batch.
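One way to convince oneself that the adjusted edge_index behaves as intended (a hedged verification sketch with toy shapes, not code from the repository): on the batched disjoint union, GCNConv should produce exactly the same per-graph outputs as running the convolution on each graph separately.
import torch
from torch_geometric.nn import GCNConv

torch.manual_seed(0)
num_nodes, bs, dim = 5, 2, 3
x = torch.randn(bs * num_nodes, dim)

adj = torch.ones(num_nodes, num_nodes)
adj.fill_diagonal_(0)
ei = torch.nonzero(adj).T                                    # single-graph edge_index
ei_batched = torch.cat([ei + i * num_nodes for i in range(bs)], dim=1)

conv = GCNConv(dim, 4, add_self_loops=False)
out_batched = conv(x, ei_batched)
out_per_graph = torch.cat(
    [conv(x[i * num_nodes:(i + 1) * num_nodes], ei) for i in range(bs)], dim=0
)
print(torch.allclose(out_batched, out_per_graph, atol=1e-6))  # expected: True
This works because the offset edge_index turns the batch into a disjoint union of graphs, so message passing never crosses graph boundaries.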
Checklist
[√] I have checked that there is no similar issue in the repo (required)
[√] I have provided a minimal working example to reproduce the bug (required)
It seems that you are indeed correct: the GCN encoder should be adjusted depending on the application, for example this one. The default should be changed so that message passing is done properly (cc: @cbhua @LTluttmann)
PS: here is another fix with Batch, which uses the same indexing
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
from functools import lru_cache

# Function to generate the full graph edge index (as in the original code)
@lru_cache(5)
def get_full_graph_edge_index(num_node: int, self_loop=False) -> torch.Tensor:
    adj_matrix = torch.ones(num_node, num_node)
    if not self_loop:
        adj_matrix.fill_diagonal_(0)
    edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0))
    return edge_index

# Parameters
num_nodes_per_graph = 5
num_node_features = 3
batch_size = 2  # Number of graphs in the batch

# Create node features for two graphs
x1 = torch.randn(num_nodes_per_graph, num_node_features)
x2 = torch.randn(num_nodes_per_graph, num_node_features)

# Concatenate node features to simulate batching
x = torch.cat([x1, x2], dim=0)  # Shape: [batch_size * num_nodes_per_graph, num_node_features]

# Generate edge_index using the original function
edge_index_single = get_full_graph_edge_index(num_nodes_per_graph, self_loop=False)

# Use the same edge_index for both graphs without adjusting
edge_index = edge_index_single  # Edge indices from graph 1, node indices 0 to 4

# Initialize GCNConv
conv = GCNConv(in_channels=num_node_features, out_channels=2, add_self_loops=False)

# Apply GCNConv without adjusting edge_index
# print(x.shape, edge_index.shape)
out = conv(x, edge_index)

# Print the output
print("Output without adjusting edge_index:")
print(out)

# Now adjust edge_index for each graph to reflect batching
edge_indices = []
for i in range(batch_size):
    offset = i * num_nodes_per_graph
    adjusted_edge_index = edge_index_single + offset
    edge_indices.append(adjusted_edge_index)
edge_index_adjusted = torch.cat(edge_indices, dim=1)

# Apply GCNConv with adjusted edge_index
out_adjusted = conv(x, edge_index_adjusted)

# Create Data objects for each graph
data1 = Data(x=x1, edge_index=edge_index_single)
data2 = Data(x=x2, edge_index=edge_index_single)

# Batch the graphs
batch = Batch.from_data_list([data1, data2])

# Apply GCNConv on batched data
out_batched = conv(batch.x, batch.edge_index)

# Compare outputs for nodes belonging to each graph
print("\n" + 20 * "=" + "Fix 1" + 20 * "=")
print("\nDifference in outputs for nodes belonging to the first graph:")
print(out[:num_nodes_per_graph] - out_adjusted[:num_nodes_per_graph])
print("\nDifference in outputs for nodes belonging to the second graph:")
print(out[num_nodes_per_graph:] - out_adjusted[num_nodes_per_graph:])

print("\n" + 20 * "=" + "Fix 2" + 20 * "=")
print("\nDifference in outputs for nodes belonging to the first graph:")
print(out[:num_nodes_per_graph] - out_batched[:num_nodes_per_graph])
print("\nDifference in outputs for nodes belonging to the second graph:")
print(out[num_nodes_per_graph:] - out_batched[num_nodes_per_graph:])
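For context on why the Batch-based fix needs no manual offsets (a minimal illustration, not taken from the comment above): Batch.from_data_list shifts each graph's node indices by the number of nodes collated so far and records the graph membership of every node in batch.batch.
import torch
from torch_geometric.data import Data, Batch

# Tiny illustration: Batch.from_data_list offsets node indices automatically
ei = torch.tensor([[0, 1, 2], [1, 2, 0]])     # a 3-node directed cycle
g1 = Data(x=torch.randn(3, 2), edge_index=ei)
g2 = Data(x=torch.randn(3, 2), edge_index=ei)
batch = Batch.from_data_list([g1, g2])

print(batch.edge_index)
# tensor([[0, 1, 2, 3, 4, 5],
#         [1, 2, 0, 4, 5, 3]])
print(batch.batch)                            # tensor([0, 0, 0, 1, 1, 1]), graph id of each node
When all graphs in the batch have the same number of nodes, the convolved output can then be reshaped back to (bs, num_nodes, emb_dim) exactly as in the proposed forward above.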