From 1bef748fab064e2fc3beddcbda60fd51cb9612d2 Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Fri, 8 Nov 2024 13:26:54 -0800 Subject: [PATCH] [doc][c10d] fixup fsdp tutorial (#1297) Summary: Fix up the FSDP tutorial to get it functional again. 1. Add missing import for load_dataset. 2. Use `checkpoint` instead of `_shard.checkpoint` to get rid of a warning. 3. Add nlp to requirements.txt 4. Get rid of `load_metric` as this function does not exist in new `datasets` module. 5. Add `legacy=False` to get rid of tokenizer warnings. Test Plan: Ran the tutorial as follows and ensured that it ran successfully: ``` torchrun --nnodes=1 --nproc_per_node=2 T5_training.py W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] ***************************************** W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] ***************************************** dict_keys(['train', 'validation', 'test']) Size of train dataset: (157252, 3) Size of Validation dataset: (5599, 3) dict_keys(['train', 'validation', 'test']) Size of train dataset: (157252, 3) Size of Validation dataset: (5599, 3) bFloat16 enabled for mixed precision - using bfSixteen policy ``` --- distributed/FSDP/T5_training.py | 23 +++++------ .../model_checkpointing/checkpoint_handler.py | 16 ++++---- distributed/FSDP/requirements.txt | 1 + distributed/FSDP/summarization_dataset.py | 39 +++++++++---------- distributed/FSDP/utils/train_utils.py | 4 +- 5 files changed, 42 insertions(+), 41 deletions(-) diff --git a/distributed/FSDP/T5_training.py b/distributed/FSDP/T5_training.py index 4ab136eace..762e70c436 100644 --- a/distributed/FSDP/T5_training.py +++ b/distributed/FSDP/T5_training.py @@ -14,6 +14,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler from transformers.models.t5.modeling_t5 import T5Block +from nlp import load_dataset from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, @@ -86,11 +87,11 @@ def fsdp_main(args): print("Size of train dataset: ", dataset['train'].shape) print("Size of Validation dataset: ", dataset['validation'].shape) - + #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False) - train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False) + train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False) val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False) - + sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True) sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size) @@ -107,12 +108,12 @@ def fsdp_main(args): train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs) val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs) - + torch.cuda.set_device(local_rank) - + # Set up FSDP parameters mixed_precision_policy, t5_auto_wrap_policy = get_policies(train_config, rank) - + # Apply FSDP wrapping to the model model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy, @@ -120,7 +121,7 @@ def fsdp_main(args): sharding_strategy=fsdp_config.sharding_strategy, device_id=torch.cuda.current_device(), limit_all_gathers=fsdp_config.limit_all_gathers) - + # Enabling this causes https://github.com/pytorch/examples/issues/1210 if fsdp_config.fsdp_activation_checkpointing: policies.apply_fsdp_checkpointing(model) @@ -150,7 +151,7 @@ def fsdp_main(args): if args.run_validation: curr_val_loss = validation(model, rank, world_size, val_loader) scheduler.step() - + if rank == 0: print(f"--> epoch {epoch} completed...entering save and stats zone") @@ -170,7 +171,7 @@ def fsdp_main(args): ) if train_config.save_model and curr_val_loss < best_val_loss: - + if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: model_checkpointing.save_model_checkpoint( model, optimizer, rank, fsdp_config, epoch=1 @@ -183,7 +184,7 @@ def fsdp_main(args): if fsdp_config.save_optimizer: model_checkpointing.save_optimizer_checkpoint( model, optimizer, rank, fsdp_config, epoch=1 - ) + ) if curr_val_loss < best_val_loss: best_val_loss = curr_val_loss @@ -212,5 +213,5 @@ def fsdp_main(args): args = parser.parse_args() torch.manual_seed(args.seed) - + fsdp_main(args) diff --git a/distributed/FSDP/model_checkpointing/checkpoint_handler.py b/distributed/FSDP/model_checkpointing/checkpoint_handler.py index 5f6858476f..5d2ea84695 100644 --- a/distributed/FSDP/model_checkpointing/checkpoint_handler.py +++ b/distributed/FSDP/model_checkpointing/checkpoint_handler.py @@ -11,7 +11,7 @@ # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. ) -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint import ( FileSystemReader, FileSystemWriter, save_state_dict, @@ -24,7 +24,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType -import torch.distributed._shard.checkpoint as dist_cp +import torch.distributed.checkpoint as dist_cp import torch.distributed as dist @@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg, verbose=True): if rank == 0: ck = checkpoint.keys() print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") - + dist_cp.load_state_dict( state_dict=checkpoint, storage_reader=reader, @@ -108,7 +108,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True): state_dict=state_dict, storage_writer=distributed_writer, planner=DefaultSavePlanner(), - + ) dist.barrier() t1 = time.perf_counter() @@ -117,7 +117,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True): print( f"Checkpoint Time = {t1-t0:.4f}\n using {cfg.save_using_num_threads=} total threads" ) - + def save_model_checkpoint( model, optimizer, @@ -138,7 +138,7 @@ def save_model_checkpoint( if cfg.verbose: print(f"saving process: rank {rank} done w model state_dict\n") - + if rank == 0: print(f"--> saving model ...") @@ -153,7 +153,7 @@ def save_model_checkpoint( if cfg.verbose: print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") - + def load_model_checkpoint(model, rank, cfg, verbose=True): @@ -299,7 +299,7 @@ def save_distributed_model_checkpoint(model, rank, cfg, epoch=1): StateDictType.LOCAL_STATE_DICT, ): state_dict = model.state_dict() - + # write out distributed checkpoint save_state_dict(state_dict, writer) diff --git a/distributed/FSDP/requirements.txt b/distributed/FSDP/requirements.txt index a59c5bacb2..904bf752db 100644 --- a/distributed/FSDP/requirements.txt +++ b/distributed/FSDP/requirements.txt @@ -3,3 +3,4 @@ datasets tqdm protobuf SentencePiece +nlp diff --git a/distributed/FSDP/summarization_dataset.py b/distributed/FSDP/summarization_dataset.py index 679ea48ec0..b9854e4e7f 100644 --- a/distributed/FSDP/summarization_dataset.py +++ b/distributed/FSDP/summarization_dataset.py @@ -14,8 +14,7 @@ import torch from torch.utils.data import Dataset, DataLoader -from datasets import load_dataset, load_metric - +from nlp import load_dataset from transformers import ( AdamW, @@ -25,7 +24,7 @@ ) class wikihow(Dataset): - def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False): + def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False): self.dataset = load_dataset('wikihow', 'all', data_dir='data/', split=type_path) if num_samples: self.dataset = self.dataset.select(list(range(0, num_samples))) @@ -33,43 +32,43 @@ def __init__(self, tokenizer, type_path, num_samples, input_length, output_lengt self.tokenizer = tokenizer self.output_length = output_length self.print_text = print_text - + def __len__(self): return self.dataset.shape[0] - + def clean_text(self, text): text = text.replace('Example of text:', '') text = text.replace('Example of Summary:', '') text = text.replace('\n','') text = text.replace('``', '') text = text.replace('"', '') - + return text - - + + def convert_to_features(self, example_batch): # Tokenize contexts and questions (as pairs of inputs) - + if self.print_text: print("Input Text: ", self.clean_text(example_batch['text'])) # input_ = self.clean_text(example_batch['text']) + " " # target_ = self.clean_text(example_batch['headline']) + " " - + input_ = self.clean_text(example_batch['text']) target_ = self.clean_text(example_batch['headline']) - - source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length, + + source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length, padding='max_length', truncation=True, return_tensors="pt") - - targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length, + + targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length, padding='max_length', truncation=True, return_tensors="pt") - - + + return source, targets - + def __getitem__(self, index): source, targets = self.convert_to_features(self.dataset[index]) - + source_ids = source["input_ids"].squeeze() target_ids = targets["input_ids"].squeeze() @@ -77,7 +76,7 @@ def __getitem__(self, index): target_mask = targets["attention_mask"].squeeze() return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask} - + def get_dataset(tokenizer, type_path, num_samples, args): - return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples, input_length=max_input_length, + return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples, input_length=max_input_length, output_length=max_output_length) diff --git a/distributed/FSDP/utils/train_utils.py b/distributed/FSDP/utils/train_utils.py index 24cf239e7c..60e5593ec7 100644 --- a/distributed/FSDP/utils/train_utils.py +++ b/distributed/FSDP/utils/train_utils.py @@ -36,7 +36,7 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler model.train() local_rank = int(os.environ['LOCAL_RANK']) fsdp_loss = torch.zeros(2).to(local_rank) - + if sampler: sampler.set_epoch(epoch) if rank==0: @@ -98,5 +98,5 @@ def validation(model, rank, world_size, val_loader): def setup_model(model_name): model = T5ForConditionalGeneration.from_pretrained(model_name) - tokenizer = T5Tokenizer.from_pretrained(model_name) + tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False) return model, tokenizer