diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py
index f8e0132bb..8593c1045 100755
--- a/llava/train/llava_trainer.py
+++ b/llava/train/llava_trainer.py
@@ -2,31 +2,129 @@
 import torch
 import torch.nn as nn
 import datetime
+import math
+import sys
+import time
+import shutil
+import numpy as np
 from accelerate import Accelerator
-from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin
-from torch.utils.data import Dataset, Sampler, DataLoader
+from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin, DistributedType
+from accelerate.data_loader import SeedableRandomSampler
+from torch.utils.data import Dataset, Sampler, DataLoader, RandomSampler
+from torch import distributed as dist
+from packaging import version
 from trl.trainer import DPOTrainer
 from trl.trainer.utils import DPODataCollatorWithPadding
 from transformers import Trainer
-from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, logger, is_accelerate_available, is_datasets_available, GradientAccumulationPlugin
-from transformers.trainer_utils import seed_worker
+from transformers.trainer import (
+    is_sagemaker_mp_enabled,
+    get_parameter_names,
+    has_length,
+    logger,
+    is_accelerate_available,
+    is_datasets_available,
+    TrainerState,
+    TRAINER_STATE_NAME,
+    get_model_param_count,
+    DebugOption,
+    DebugUnderflowOverflow,
+    ParallelMode,
+    deepspeed_init,
+    deepspeed_load_checkpoint,
+    _is_peft_model,
+)
+from transformers.trainer_utils import seed_worker, speed_metrics, HPSearchBackend, TrainOutput
 from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf
-from transformers.trainer_pt_utils import AcceleratorConfig
+from transformers.trainer_pt_utils import AcceleratorConfig, get_dataloader_sampler
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.integrations import hp_params
 from typing import List, Optional
 from datetime import timedelta
 
 if is_accelerate_available():
     from accelerate import Accelerator, skip_first_batches, InitProcessGroupKwargs
+    import accelerate
+
+    accelerate_version = accelerate.__version__
 
 if is_datasets_available():
     import datasets
 
+# Optional TPU/XLA imports
+try:
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.metrics as met
+    from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+    from torch_xla.distributed.spmd import XlaShardedTensorCheckpointHandler
+    from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
+    # from torch_xla.amp import syncfree  # Removed unused import
+
+    is_torch_xla_available = lambda: True
+except ImportError:
+    xm = None
+    met = None
+    is_torch_xla_available = lambda: False
+
+# Optional SageMaker imports
+try:
+    import smdistributed.modelparallel.torch as smp
+except ImportError:
+    smp = None
+
+# Optional function for TPU
+try:
+    from transformers.trainer_pt_utils import tpu_spmd_dataloader
+except ImportError:
+    tpu_spmd_dataloader = None
+
 from llava.utils import rank0_print
 
 
+def plot_graphs_based_on_log_history(log_history, output_dir, metrics):
+    """
+    Plot graphs based on log history. Best-effort helper: plotting problems never interrupt training.
+
+    Args:
+        log_history: Training log history
+        output_dir: Output directory for plots
+        metrics: List of metrics to plot
+    """
+    try:
+        import matplotlib.pyplot as plt
+
+        if not log_history:
+            return
+
+        os.makedirs(output_dir, exist_ok=True)
+
+        for metric in metrics:
+            values = []
+            steps = []
+            for log in log_history:
+                if metric in log:
+                    values.append(log[metric])
+                    steps.append(log.get("step", len(steps)))
+
+            if values:
+                plt.figure(figsize=(10, 6))
+                plt.plot(steps, values)
+                plt.xlabel("Step")
+                plt.ylabel(metric)
+                plt.title(f"{metric} over training")
+                plt.savefig(os.path.join(output_dir, f"{metric}.png"))
+                plt.close()
+
+    except ImportError:
+        # matplotlib not available, skip plotting
+        pass
+    except Exception as e:
+        # Don't fail training if plotting fails
+        logger.warning(f"Failed to plot graphs: {e}")
+
+
 def maybe_zero_3(param, ignore_status=False, name=None):
     from deepspeed import zero
     from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
@@ -290,10 +388,7 @@ def zo_perturb_parameters(self, scaling_factor=1.0):
     """
     torch.manual_seed(self.zo_random_seed)
     for name, param in self.trainable_params:
-        z = torch.normal(
-            mean=0, std=1, size=param.data.size(),
-            device=param.device, dtype=param.dtype
-        )
+        z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.device, dtype=param.dtype)
         param.data += scaling_factor * z * self.zo_eps
 
 def zo_forward(self, model, inputs):
@@ -385,31 +480,19 @@ def zo_update(self, learning_rate):
     torch.manual_seed(seed)
     for name, param in self.trainable_params:
-        z = torch.normal(
-            mean=0, std=1, size=param.data.size(),
-            device=param.device, dtype=param.dtype
-        )
+        z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.device, dtype=param.dtype)
         if "bias" not in name and "layer_norm" not in name and "layernorm" not in name:
             param.data -= learning_rate * (grad_sum * z + self.args.weight_decay * param.data)
         else:
             param.data -= learning_rate * grad_sum * z
 
-    print(
-        f"Applied MeZO update with aggregated grad estimates "
-        f"from {len(self.zo_direction_accumulator)} directions "
-        f"(aggregated to {len(seed_group)} unique seeds)."
-        f"Seed group: {seed_group}")
+    print(f"Applied MeZO update with aggregated grad estimates " f"from {len(self.zo_direction_accumulator)} directions " f"(aggregated to {len(seed_group)} unique seeds). " f"Seed group: {seed_group}")
 
     self.zo_direction_accumulator = []
     self.zo_accumulation_count = 0
     self.batch_zo_seeds = None
 
-    update_entry = {
-        "type": "mezo_update",
-        "global_step": self.state.global_step,
-        "learning_rate": learning_rate,
-        "seed_group": seed_group
-    }
+    update_entry = {"type": "mezo_update", "global_step": self.state.global_step, "learning_rate": learning_rate, "seed_group": seed_group}
     self.mezo_update_history.append(update_entry)
 
@@ -423,7 +506,7 @@ def _save_mezo_state(self, output_dir: str):
         "zo_num_directions": self.zo_num_directions,
         "trainable_params_names": [name for name, _ in self.trainable_params],
         "trainable_params_sizes": {name: param.size() for name, param in self.trainable_params},
-        "update_history": self.mezo_update_history
+        "update_history": self.mezo_update_history,
     }
     mezo_checkpoint_path = os.path.join(output_dir, "mezo_state.pt")
     torch.save(mezo_state, mezo_checkpoint_path)
@@ -675,9 +758,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         else:
             super(LLaVATrainer, self)._save(output_dir, state_dict)
 
-    def _inner_training_loop(
-        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
-    ):
+    def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):
         if args is None:
             args = self.args
 
@@ -720,16 +801,12 @@
             num_examples = self.num_examples(train_dataloader)
             if args.max_steps > 0:
                 max_steps = args.max_steps
-                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
-                    args.max_steps % num_update_steps_per_epoch > 0
-                )
+                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(args.max_steps % num_update_steps_per_epoch > 0)
                 # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                 # the best we can do.
                 num_train_samples = args.max_steps * total_train_batch_size
                 if args.include_tokens_per_second:
-                    num_train_tokens = (
-                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
-                    )
+                    num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
             else:
                 max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                 num_train_epochs = math.ceil(args.num_train_epochs)
@@ -746,19 +823,13 @@
             if args.include_tokens_per_second:
                 num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
         else:
-            raise ValueError(
-                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
-                f" {args.max_steps}"
-            )
+            raise ValueError("args.max_steps must be set to a positive value if dataloader does not have a length, was" f" {args.max_steps}")
 
         if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
             if self.args.n_gpu > 1:
                 # nn.DataParallel(model) replicates the model, creating new variables and module
                 # references registered here no longer work on other gpus, breaking the module
-                raise ValueError(
-                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
-                    " (torchrun or torch.distributed.launch (deprecated))."
-                )
+                raise ValueError("Currently --debug underflow_overflow is not supported under DP. Please use DDP" " (torchrun or torch.distributed.launch (deprecated)).")
             else:
                 debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
@@ -838,9 +909,7 @@
                 model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
             else:
                 # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
-                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
-                    self.model, self.optimizer, self.lr_scheduler
-                )
+                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.model, self.optimizer, self.lr_scheduler)
 
         if self.is_fsdp_enabled:
             self.model = self.model_wrapped = model
@@ -856,9 +925,7 @@
         # ckpt loading
         if resume_from_checkpoint is not None:
             if self.is_deepspeed_enabled:
-                deepspeed_load_checkpoint(
-                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
-                )
+                deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model))
             elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                 self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
 
@@ -889,9 +956,7 @@
             steps_trained_progress_bar = None
 
         # Check if continuing training from a checkpoint
-        if resume_from_checkpoint is not None and os.path.isfile(
-            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
-        ):
+        if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             self.compare_trainer_and_checkpoint_args(self.args, self.state)
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
@@ -905,10 +970,7 @@
             logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
-                logger.info(
-                    f"  Will skip the first {epochs_trained} epochs then the first"
-                    f" {steps_trained_in_current_epoch} batches in the first epoch."
-                )
+                logger.info(f"  Will skip the first {epochs_trained} epochs then the first" f" {steps_trained_in_current_epoch} batches in the first epoch.")
 
         # Update the references
         self.callback_handler.model = self.model
@@ -969,11 +1031,7 @@
             if args.past_index >= 0:
                 self._past = None
 
-            steps_in_epoch = (
-                len(epoch_iterator)
-                if len_dataloader is not None
-                else args.max_steps * args.gradient_accumulation_steps
-            )
+            steps_in_epoch = len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
 
             if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
@@ -1001,11 +1059,7 @@
                         )
                     else:
                         input_device = inputs[main_input_name].device
-                        self.state.num_input_tokens_seen += torch.sum(
-                            self.accelerator.gather(
-                                torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
-                            )
-                        ).item()
+                        self.state.num_input_tokens_seen += torch.sum(self.accelerator.gather(torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64))).item()
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
                     rng_to_sync = False
@@ -1043,25 +1097,17 @@
                 ########################
 
                 # Accumulate loss
-                if (
-                    args.logging_nan_inf_filter
-                    and not is_torch_xla_available()
-                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
-                ):
+                if args.logging_nan_inf_filter and not is_torch_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):
                     # if loss is nan or inf simply add the average of previous logged losses
                     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                 else:
                     if tr_loss.device != tr_loss_step.device:
-                        raise ValueError(
-                            f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
-                        )
+                        raise ValueError(f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}")
                     tr_loss += tr_loss_step
 
                 self.current_flos += float(self.floating_point_ops(inputs))
 
-                is_last_step_and_steps_less_than_grad_acc = (
-                    steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
-                )
+                is_last_step_and_steps_less_than_grad_acc = steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
 
                 if (
                     total_batched_samples % args.gradient_accumulation_steps == 0
@@ -1102,10 +1148,7 @@
                             args.max_grad_norm,
                         )
 
-                    if (
-                        is_accelerate_available()
-                        and self.accelerator.distributed_type == DistributedType.DEEPSPEED
-                    ):
+                    if is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                         grad_norm = model.get_global_grad_norm()
                         # In some cases the grad norm may not return a float
                         if hasattr(grad_norm, "item"):
@@ -1155,10 +1198,7 @@
                     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                     xm.master_print(met.metrics_report())
                 else:
-                    logger.warning(
-                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
-                        "configured. Check your training configuration if this is unexpected."
-                    )
+                    logger.warning("You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected.")
 
             if self.control.should_training_stop:
                 break
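
Reviewer note (not part of the patch): for anyone new to MeZO (Malladi et al., 2023), the sketch below restates the two-point, seed-replay update that `zo_perturb_parameters` / `zo_forward` / `zo_update` implement. It is a minimal single-direction version under assumed names: `zo_step` and the `loss_fn` closure are illustrative, not from this diff. The patched trainer additionally accumulates several directions per update (`zo_direction_accumulator`, `seed_group`) and records them in `mezo_update_history`.

```python
import torch
import torch.nn as nn


def zo_step(model: nn.Module, loss_fn, eps=1e-3, lr=1e-6, weight_decay=0.0):
    """One MeZO step: estimate the directional gradient, then update in place."""
    params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    seed = torch.randint(0, 2**31 - 1, (1,)).item()

    def perturb(scale):
        # Regenerate z from the seed instead of storing it; this is what keeps
        # MeZO's memory footprint at inference level.
        torch.manual_seed(seed)
        for _, p in params:
            z = torch.normal(mean=0.0, std=1.0, size=p.data.size(), device=p.device, dtype=p.dtype)
            p.data.add_(scale * eps * z)

    perturb(+1.0)  # theta + eps * z
    with torch.inference_mode():
        loss_plus = loss_fn(model).item()
    perturb(-2.0)  # theta - eps * z
    with torch.inference_mode():
        loss_minus = loss_fn(model).item()
    perturb(+1.0)  # restore theta

    grad_est = (loss_plus - loss_minus) / (2 * eps)  # projected gradient along z

    torch.manual_seed(seed)  # replay the same z for the update
    for name, p in params:
        z = torch.normal(mean=0.0, std=1.0, size=p.data.size(), device=p.device, dtype=p.dtype)
        if "bias" not in name and "layer_norm" not in name and "layernorm" not in name:
            p.data -= lr * (grad_est * z + weight_decay * p.data)  # decoupled weight decay
        else:
            p.data -= lr * grad_est * z
    return loss_plus
```

Because each direction is fully determined by an integer seed, `_save_mezo_state` can record the update history as small `(global_step, learning_rate, seed_group)` entries rather than saving any perturbation tensors.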