[BUG] GCNEncoder in rl4co fails to adjust edge_index for batched graphs #227

Open
OceanHWang opened this issue Oct 17, 2024 · 1 comment
Labels: bug (Something isn't working)

Describe the bug

In rl4co/models/nn/graph/gcn.py, the GCNEncoder class flattens batched node embeddings into a single node-feature tensor but does not adjust the edge_index for each graph in the batch. This leads to incorrect message passing in GCNConv: the edge indices only reference the node slots of a single graph, so nodes of the remaining graphs in the batch are effectively ignored, resulting in unintended model behavior.

Issue in Code:

# In rl4co/models/nn/graph/gcn.py, inside the GCNEncoder class
def forward(self, td):
    # Transfer to embedding space
    init_h = self.init_embedding(td)
    bs, num_nodes, emb_dim = init_h.shape
    # (bs*num_nodes, emb_dim)
    update_node_feature = init_h.reshape(-1, emb_dim)  # Flatten batch
    # shape=(2, num_edges)
    edge_index = self.edge_idx_fn(td, num_nodes)  # Edge index for a single graph

    for layer in self.gcn_layers[:-1]:
        update_node_feature = layer(update_node_feature, edge_index)
        # ...

Problem: edge_index is generated for a single graph and is not adjusted for batching.
Consequence: when node features from multiple graphs are concatenated, GCNConv only properly updates the nodes of the first graph; nodes of the remaining graphs receive no messages.

To Reproduce

Minimal Example to Reproduce the Bug:

import torch
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from functools import lru_cache

# Function to generate full graph edge index (as in the original code)
@lru_cache(5)
def get_full_graph_edge_index(num_node: int, self_loop=False) -> torch.Tensor:
    adj_matrix = torch.ones(num_node, num_node)
    if not self_loop:
        adj_matrix.fill_diagonal_(0)
    edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0))
    return edge_index

# Parameters
num_nodes_per_graph = 5
num_node_features = 3
batch_size = 2  # Number of graphs in the batch

# Create node features for two graphs
x1 = torch.randn(num_nodes_per_graph, num_node_features)
x2 = torch.randn(num_nodes_per_graph, num_node_features)

# Concatenate node features to simulate batching
x = torch.cat([x1, x2], dim=0)  # Shape: [batch_size * num_nodes_per_graph, num_node_features]

# Generate edge_index using the original function
edge_index_single = get_full_graph_edge_index(num_nodes_per_graph, self_loop=False)

# Use the same edge_index for both graphs without adjusting
edge_index = edge_index_single  # Edge indices from graph1, node indices 0 to 4

# Initialize GCNConv
conv = GCNConv(in_channels=num_node_features, out_channels=2, add_self_loops=False)

# Apply GCNConv without adjusting edge_index
out = conv(x, edge_index)

# Print the output
print("Output without adjusting edge_index:")
print(out)

# Now adjust edge_index for each graph to reflect batching
edge_indices = []
for i in range(batch_size):
    offset = i * num_nodes_per_graph
    adjusted_edge_index = edge_index_single + offset
    edge_indices.append(adjusted_edge_index)
edge_index_adjusted = torch.cat(edge_indices, dim=1)

# Apply GCNConv with adjusted edge_index
out_adjusted = conv(x, edge_index_adjusted)

# Compare outputs for nodes belonging to each graph
print("\nDifference in outputs for nodes belonging to the first graph:")
print(out[:num_nodes_per_graph] - out_adjusted[:num_nodes_per_graph])

print("\nDifference in outputs for nodes belonging to the second graph:")
print(out[num_nodes_per_graph:] - out_adjusted[num_nodes_per_graph:])

Explanation:

  • Edge index generation: we use the get_full_graph_edge_index function from the original code to generate the edge indices of a fully connected graph without self-loops.
  • Without adjusted edge_index: the edge_index only references nodes 0 to 4, so nodes 5 to 9 (from the second graph) receive no edges and are not processed correctly by GCNConv.
  • With adjusted edge_index: adding an offset to edge_index for each graph in the batch connects nodes only within their own graph, so the outputs for nodes of the second graph are now updated based on their graph structure.
  • Observation: the differences in outputs for nodes belonging to the second graph show that, without adjusting edge_index, those nodes were not processed correctly (see the quick check below).
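To make the observation concrete, here is a small check (an addition to the report, assuming the reproduction script above has just been run):

# Assumes out, out_adjusted, and num_nodes_per_graph from the script above.
# First graph: its edges are the same in both cases, so the outputs should match.
print(torch.allclose(out[:num_nodes_per_graph], out_adjusted[:num_nodes_per_graph]))   # expected: True
# Second graph: without the offset it receives no edges, so the outputs should differ.
print(torch.allclose(out[num_nodes_per_graph:], out_adjusted[num_nodes_per_graph:]))   # expected: False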

Additional context

The issue arises because in the GCNEncoder, the edge_index is generated once for a single graph and used for all graphs in the batch without adjustment. This causes incorrect node indexing when processing batched graphs, leading to unintended behavior.

Reason and Possible fixes

Reason:

  • The edge_index is not adjusted for batched graphs, so nodes outside the first graph are left disconnected and are not processed correctly.
  • In batched graphs, the node indices in edge_index need to be offset for each graph to reflect their positions in the concatenated node-feature tensor (a toy illustration follows below).
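As a toy illustration (not from the original report; a directed 3-node cycle is assumed purely for clarity), the offsetting turns a per-graph edge_index into one that indexes the flattened node tensor:

import torch

# Hypothetical toy graph: a directed 3-node cycle, batched twice (bs = 2).
edge_index_single = torch.tensor([[0, 1, 2],
                                  [1, 2, 0]])
num_nodes, bs = 3, 2
# Offset copy i by i * num_nodes so indices point into the flattened (bs * num_nodes) tensor.
edge_index = torch.cat([edge_index_single + i * num_nodes for i in range(bs)], dim=1)
print(edge_index)
# tensor([[0, 1, 2, 3, 4, 5],
#         [1, 2, 0, 4, 5, 3]])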

Possible Fix:

Adjust the edge_index for each graph in the batch by adding an offset based on the cumulative number of nodes. Here's a modification to the GCNEncoder class:

def forward(self, td):
    init_h = self.init_embedding(td)
    bs, num_nodes, emb_dim = init_h.shape
    update_node_feature = init_h.reshape(-1, emb_dim)

    # Original edge_index for a single graph
    edge_index_single = self.edge_idx_fn(td, num_nodes)

    # Adjust edge_index for batching
    edge_indices = []
    for i in range(bs):
        offset = i * num_nodes
        adjusted_edge_index = edge_index_single + offset
        edge_indices.append(adjusted_edge_index)
    edge_index = torch.cat(edge_indices, dim=1).to(td.device)

    # Proceed with GCN layers using the adjusted edge_index
    for layer in self.gcn_layers[:-1]:
        update_node_feature = layer(update_node_feature, edge_index)
        update_node_feature = F.relu(update_node_feature)
        update_node_feature = F.dropout(
            update_node_feature, training=self.training, p=self.dropout
        )

    # Last layer without activation and dropout
    update_node_feature = self.gcn_layers[-1](update_node_feature, edge_index)

    # De-batch the graph
    update_node_feature = update_node_feature.view(bs, num_nodes, emb_dim)

    # Residual connection
    if self.residual:
        update_node_feature = update_node_feature + init_h

    return update_node_feature, init_h

Explanation:

  • Adjustment of edge_index: for each graph in the batch, we add an offset to edge_index so that node indices correctly map to their positions in the concatenated update_node_feature (a vectorized variant of this offsetting is sketched below).
  • Benefit: nodes are only connected within their own graph, which prevents unintended cross-graph connections, and GCNConv can now process each graph correctly within the batch.
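If the Python loop over the batch is a concern, the same offsetting can be done with broadcasting. This is only a sketch of an alternative, not part of the proposed patch; the helper name batch_edge_index is made up here, and the variables follow the fix above:

import torch

def batch_edge_index(edge_index_single: torch.Tensor, bs: int, num_nodes: int) -> torch.Tensor:
    # Replicate a single-graph edge_index of shape (2, E) for bs graphs,
    # offsetting copy i by i * num_nodes; returns shape (2, bs * E).
    offsets = torch.arange(bs, device=edge_index_single.device) * num_nodes   # (bs,)
    batched = edge_index_single.unsqueeze(0) + offsets.view(-1, 1, 1)         # (bs, 2, E)
    return batched.permute(1, 0, 2).reshape(2, -1)                            # (2, bs * E)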

Checklist

  • [√] I have checked that there is no similar issue in the repo (required)
  • [√] I have provided a minimal working example to reproduce the bug (required)
OceanHWang added the bug label on Oct 17, 2024
OceanHWang changed the title from "[BUG]" to "[BUG] GCNEncoder in rl4co fails to adjust edge_index for batched graphs" on Oct 17, 2024
@fedebotu
Member

Thanks for reporting!

It seems that you are indeed correct: the GCN encoder should be adjusted based on the application (for example, this one), and the default behavior should be changed so that message passing is done properly (cc: @cbhua @LTluttmann).

PS: here is another fix with torch_geometric's Batch, which applies the same indexing internally

import torch
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
from functools import lru_cache

# Function to generate full graph edge index (as in the original code)
@lru_cache(5)
def get_full_graph_edge_index(num_node: int, self_loop=False) -> torch.Tensor:
    adj_matrix = torch.ones(num_node, num_node)
    if not self_loop:
        adj_matrix.fill_diagonal_(0)
    edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0))
    return edge_index

# Parameters
num_nodes_per_graph = 5
num_node_features = 3
batch_size = 2  # Number of graphs in the batch

# Create node features for two graphs
x1 = torch.randn(num_nodes_per_graph, num_node_features)
x2 = torch.randn(num_nodes_per_graph, num_node_features)

# Concatenate node features to simulate batching
x = torch.cat([x1, x2], dim=0)  # Shape: [batch_size * num_nodes_per_graph, num_node_features]

# Generate edge_index using the original function
edge_index_single = get_full_graph_edge_index(num_nodes_per_graph, self_loop=False)

# Use the same edge_index for both graphs without adjusting
edge_index = edge_index_single  # Edge indices from graph1, node indices 0 to 4

# Initialize GCNConv
conv = GCNConv(in_channels=num_node_features, out_channels=2, add_self_loops=False)

# Apply GCNConv without adjusting edge_index
# print(x.shape, edge_index.shape)
out = conv(x, edge_index)

# Print the output
print("Output without adjusting edge_index:")
print(out)

# Now adjust edge_index for each graph to reflect batching
edge_indices = []
for i in range(batch_size):
    offset = i * num_nodes_per_graph
    adjusted_edge_index = edge_index_single + offset
    edge_indices.append(adjusted_edge_index)
edge_index_adjusted = torch.cat(edge_indices, dim=1)

# Apply GCNConv with adjusted edge_index
out_adjusted = conv(x, edge_index_adjusted)

# Create Data objects for each graph
data1 = Data(x=x1, edge_index=edge_index_single)
data2 = Data(x=x2, edge_index=edge_index_single)

# Batch the graphs
batch = Batch.from_data_list([data1, data2])

# Apply GCNConv on batched data
out_batched = conv(batch.x, batch.edge_index)

# Compare outputs for nodes belonging to each graph

print("\n"+20*"="+"Fix 1"+20*"=")

print("\nDifference in outputs for nodes belonging to the first graph:")
print(out[:num_nodes_per_graph] - out_adjusted[:num_nodes_per_graph])

print("\nDifference in outputs for nodes belonging to the second graph:")
print(out[num_nodes_per_graph:] - out_adjusted[num_nodes_per_graph:])


print("\n"+20*"="+"Fix 2"+20*"=")

print("\nDifference in outputs for nodes belonging to the first graph:")
print(out[:num_nodes_per_graph] - out_batched[:num_nodes_per_graph])

print("\nDifference in outputs for nodes belonging to the second graph:")
print(out[num_nodes_per_graph:] - out_batched[num_nodes_per_graph:])
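As a small addition (assuming the script above has been run), one can also confirm that the two fixes agree, since Batch.from_data_list applies the same per-graph index offsets used in Fix 1:

# Fix 1 (manual offsets) and Fix 2 (Batch) construct the same flattened graph,
# so their outputs are expected to be identical.
print("\nFix 1 and Fix 2 agree:", torch.allclose(out_adjusted, out_batched))  # expected: True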
