
Commit 2065c35
Merge branch 'trintamaki/multimodal-eval-dataset' into 'main'
Use torch dataloader in multimodal evaluation

See merge request ADLR/megatron-lm!2110
jon-barker committed Sep 23, 2024
Parents: 0fd4617 + ede39b8. Commit: 2065c35.
Showing 2 changed files with 483 additions and 248 deletions.
examples/multimodal/dataloader_provider.py: 10 changes (5 additions, 5 deletions)

--- a/examples/multimodal/dataloader_provider.py
+++ b/examples/multimodal/dataloader_provider.py
@@ -4,7 +4,7 @@
 import torch
 from dataset_helpers import TaskEncoder, print_error_handler
 
-from megatron.core import mpu
+from megatron.core import parallel_state
 from megatron.energon import (
     LimitDataset,
     RepeatDataset,
@@ -71,9 +71,9 @@ def train_valid_test_dataloaders_provider(train_val_test_num_samples):
         worker_debug_path = None
         worker_log_level = 0
 
-    rank = mpu.get_data_parallel_rank()
-    world_size = mpu.get_data_parallel_world_size()
-    data_parallel_group = mpu.get_data_parallel_group()
+    rank = parallel_state.get_data_parallel_rank()
+    world_size = parallel_state.get_data_parallel_world_size()
+    data_parallel_group = parallel_state.get_data_parallel_group()
 
     worker_config = WorkerConfig(
         rank=rank,
@@ -88,7 +88,7 @@ def train_valid_test_dataloaders_provider(train_val_test_num_samples):
     train_dataloader = get_savable_loader(train_ds, worker_config=worker_config)
     if args.load is not None:
         if getattr(args, "dataloader_save", None):
-            dp_rank = mpu.get_data_parallel_rank()
+            dp_rank = parallel_state.get_data_parallel_rank()
             data_save_name = get_checkpoint_name(
                 args.dataloader_save,
                 args.iteration,
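The rename above appears to be cosmetic rather than behavioral: as far as I know, megatron/core/__init__.py re-exports parallel_state under its legacy name mpu, so both spellings resolve to the same module. A minimal sketch of that equivalence, assuming the alias is still present:

    # Sketch, assuming megatron/core/__init__.py still defines the legacy
    # alias `mpu = parallel_state`; not code from this commit.
    from megatron.core import mpu, parallel_state

    assert mpu is parallel_state
    assert mpu.get_data_parallel_rank is parallel_state.get_data_parallel_rank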
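The bulk of the commit (483 additions across 2 changed files) lands in a second file whose diff did not load above, so the actual torch-dataloader change in the evaluation path is not visible here. Purely as an illustrative sketch of the pattern the commit title describes, with every name below hypothetical rather than taken from the commit:

    # Hypothetical sketch of "use a torch dataloader in evaluation": wrap the
    # eval samples in a Dataset and let DataLoader workers do the fetching,
    # instead of a hand-rolled loop. Not the commit's actual code.
    from torch.utils.data import DataLoader, Dataset

    class EvalDataset(Dataset):
        """Stand-in for the multimodal evaluation samples."""

        def __init__(self, samples):
            self.samples = samples

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            return self.samples[idx]

    loader = DataLoader(
        EvalDataset(list(range(8))),  # placeholder data
        batch_size=1,
        num_workers=2,  # worker processes prefetch samples in parallel
        pin_memory=True,  # faster host-to-GPU copies for tensor batches
    )

    for batch in loader:
        pass  # run generation / scoring on each batch here

One practical benefit of this pattern is that sharding the evaluation set across data-parallel ranks can be expressed with a sampler (e.g. torch.utils.data.distributed.DistributedSampler) instead of manual index arithmetic.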
