Skip to content

Commit

Permalink
add comment to custom_collate fn
Browse files Browse the repository at this point in the history
  • Loading branch information
davidfitzek committed Aug 22, 2023
1 parent dfb26df commit f47414f
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/rydberggpt/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def custom_collate(batch: List[Batch]) -> Batch:

graph_batch = PyGBatch.from_data_list([b.graph for b in batch])

# NOTE: The graphs, and measurement data are not of the same size. To ensure
# a padded tensor suitable for the neural network, we use the to_dense_batch function. This ensures that our
# data is padded with zeros.
# see: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/to_dense_batch.html
batch = Batch(
graph=graph_batch,
m_onehot=to_dense_batch(
Expand Down

0 comments on commit f47414f

Please sign in to comment.