Skip to content

Commit

Permalink
use chunked dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
davidfitzek committed Aug 22, 2023
1 parent 106826f commit 09275bd
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pytorch_lightning.strategies import DDPStrategy

from rydberggpt.data.loading.rydberg_dataset import get_rydberg_dataloader
from rydberggpt.data.loading.rydberg_dataset_chunked import get_chunked_dataloader
from rydberggpt.models.rydberg_encoder_decoder import get_rydberg_graph_encoder_decoder
from rydberggpt.training.callbacks.module_info_callback import ModelInfoCallback
from rydberggpt.training.callbacks.stop_on_loss_threshold_callback import (
Expand Down Expand Up @@ -43,7 +44,7 @@ def main(config_path: str, config_name: str, dataset_path: str):
config.device = device

# https://lightning.ai/docs/pytorch/stable/data/datamodule.html
train_loader, val_loader = get_rydberg_dataloader(
train_loader, val_loader = get_chunked_dataloader(
config.batch_size,
test_size=0.2,
num_workers=config.num_workers,
Expand Down

0 comments on commit 09275bd

Please sign in to comment.