diff --git a/src/rydberggpt/data/dataclasses.py b/src/rydberggpt/data/dataclasses.py index 2b806462..7d559727 100644 --- a/src/rydberggpt/data/dataclasses.py +++ b/src/rydberggpt/data/dataclasses.py @@ -46,6 +46,10 @@ def custom_collate(batch: List[Batch]) -> Batch: graph_batch = PyGBatch.from_data_list([b.graph for b in batch]) + # NOTE: The graphs, and measurement data are not of the same size. To ensure + # a padded tensor suitable for the neural network, we use the to_dense_batch function. This ensures that our + # data is padded with zeros. + # see: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/to_dense_batch.html batch = Batch( graph=graph_batch, m_onehot=to_dense_batch(