208 changes: 124 additions & 84 deletions llava/train/llava_trainer.py
@@ -2,31 +2,129 @@
import torch
import torch.nn as nn
import datetime
import math
import sys
import time
import shutil
import numpy as np

from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin
from torch.utils.data import Dataset, Sampler, DataLoader
from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin, DistributedType
from accelerate.data_loader import SeedableRandomSampler
from torch.utils.data import Dataset, Sampler, DataLoader, RandomSampler
from torch import distributed as dist
from packaging import version

from trl.trainer import DPOTrainer
from trl.trainer.utils import DPODataCollatorWithPadding

from transformers import Trainer
from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, logger, is_accelerate_available, is_datasets_available, GradientAccumulationPlugin
from transformers.trainer_utils import seed_worker
from transformers.trainer import (
is_sagemaker_mp_enabled,
get_parameter_names,
has_length,
logger,
is_accelerate_available,
is_datasets_available,
TrainerState,
TRAINER_STATE_NAME,
get_model_param_count,
DebugOption,
DebugUnderflowOverflow,
ParallelMode,
deepspeed_init,
deepspeed_load_checkpoint,
_is_peft_model,
)
from transformers.trainer_utils import seed_worker, speed_metrics, HPSearchBackend, TrainOutput
from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.trainer_pt_utils import AcceleratorConfig, get_dataloader_sampler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.integrations import hp_params
from typing import List, Optional
from datetime import timedelta

if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches, InitProcessGroupKwargs
import accelerate

accelerate_version = accelerate.__version__

if is_datasets_available():
import datasets

# Optional TPU/XLA imports
try:
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
from torch_xla.distributed.spmd import XlaShardedTensorCheckpointHandler
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
# from torch_xla.amp import syncfree # Removed unused import

is_torch_xla_available = lambda: True
except ImportError:
xm = None
met = None
is_torch_xla_available = lambda: False

# Optional SageMaker imports
try:
import smdistributed.modelparallel.torch as smp
except ImportError:
smp = None

# Optional function for TPU
try:
from transformers.trainer_pt_utils import tpu_spmd_dataloader
except ImportError:
tpu_spmd_dataloader = None

from llava.utils import rank0_print


def plot_graphs_based_on_log_history(log_history, output_dir, metrics):
"""
Plot graphs based on log history. This is a stub function.

Args:
log_history: Training log history
output_dir: Output directory for plots
metrics: List of metrics to plot
"""
try:
import matplotlib.pyplot as plt

if not log_history:
return

os.makedirs(output_dir, exist_ok=True)

for metric in metrics:
values = []
steps = []
for log in log_history:
if metric in log:
values.append(log[metric])
steps.append(log.get("step", len(steps)))

if values:
plt.figure(figsize=(10, 6))
plt.plot(steps, values)
plt.xlabel("Step")
plt.ylabel(metric)
plt.title(f"{metric} over training")
plt.savefig(os.path.join(output_dir, f"{metric}.png"))
plt.close()

except ImportError:
# matplotlib not available, skip plotting
pass
except Exception as e:
# Don't fail training if plotting fails
logger.warning(f"Failed to plot graphs: {e}")


def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
@@ -290,10 +388,7 @@ def zo_perturb_parameters(self, scaling_factor=1.0):
"""
torch.manual_seed(self.zo_random_seed)
for name, param in self.trainable_params:
z = torch.normal(
mean=0, std=1, size=param.data.size(),
device=param.device, dtype=param.dtype
)
z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.device, dtype=param.dtype)
param.data += scaling_factor * z * self.zo_eps

def zo_forward(self, model, inputs):
@@ -385,31 +480,19 @@ def zo_update(self, learning_rate):
torch.manual_seed(seed)

for name, param in self.trainable_params:
z = torch.normal(
mean=0, std=1, size=param.data.size(),
device=param.device, dtype=param.dtype
)
z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.device, dtype=param.dtype)
if "bias" not in name and "layer_norm" not in name and "layernorm" not in name:
param.data -= learning_rate * (grad_sum * z + self.args.weight_decay * param.data)
else:
param.data -= learning_rate * grad_sum * z

print(
f"Applied MeZO update with aggregated grad estimates "
f"from {len(self.zo_direction_accumulator)} directions "
f"(aggregated to {len(seed_group)} unique seeds)."
f"Seed group: {seed_group}")
print(f"Applied MeZO update with aggregated grad estimates " f"from {len(self.zo_direction_accumulator)} directions " f"(aggregated to {len(seed_group)} unique seeds)." f"Seed group: {seed_group}")

self.zo_direction_accumulator = []
self.zo_accumulation_count = 0
self.batch_zo_seeds = None

update_entry = {
"type": "mezo_update",
"global_step": self.state.global_step,
"learning_rate": learning_rate,
"seed_group": seed_group
}
update_entry = {"type": "mezo_update", "global_step": self.state.global_step, "learning_rate": learning_rate, "seed_group": seed_group}

self.mezo_update_history.append(update_entry)

@@ -423,7 +506,7 @@ def _save_mezo_state(self, output_dir: str):
"zo_num_directions": self.zo_num_directions,
"trainable_params_names": [name for name, _ in self.trainable_params],
"trainable_params_sizes": {name: param.size() for name, param in self.trainable_params},
"update_history": self.mezo_update_history
"update_history": self.mezo_update_history,
}
mezo_checkpoint_path = os.path.join(output_dir, "mezo_state.pt")
torch.save(mezo_state, mezo_checkpoint_path)
@@ -675,9 +758,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
else:
super(LLaVATrainer, self)._save(output_dir, state_dict)

def _inner_training_loop(
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):
if args is None:
args = self.args

@@ -720,16 +801,12 @@ def _inner_training_loop(
num_examples = self.num_examples(train_dataloader)
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0
)
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(args.max_steps % num_update_steps_per_epoch > 0)
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
# the best we can do.
num_train_samples = args.max_steps * total_train_batch_size
if args.include_tokens_per_second:
num_train_tokens = (
self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
)
num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
else:
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(args.num_train_epochs)
@@ -746,19 +823,13 @@ def _inner_training_loop(
if args.include_tokens_per_second:
num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
else:
raise ValueError(
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
f" {args.max_steps}"
)
raise ValueError("args.max_steps must be set to a positive value if dataloader does not have a length, was" f" {args.max_steps}")

if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
if self.args.n_gpu > 1:
# nn.DataParallel(model) replicates the model, creating new variables and module
# references registered here no longer work on other gpus, breaking the module
raise ValueError(
"Currently --debug underflow_overflow is not supported under DP. Please use DDP"
" (torchrun or torch.distributed.launch (deprecated))."
)
raise ValueError("Currently --debug underflow_overflow is not supported under DP. Please use DDP" " (torchrun or torch.distributed.launch (deprecated)).")
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa

@@ -838,9 +909,7 @@ def _inner_training_loop(
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
else:
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.model, self.optimizer, self.lr_scheduler)

if self.is_fsdp_enabled:
self.model = self.model_wrapped = model
@@ -856,9 +925,7 @@ def _inner_training_loop(
# ckpt loading
if resume_from_checkpoint is not None:
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(
self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
)
deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model))
elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

@@ -889,9 +956,7 @@ def _inner_training_loop(
steps_trained_progress_bar = None

# Check if continuing training from a checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
self.compare_trainer_and_checkpoint_args(self.args, self.state)
epochs_trained = self.state.global_step // num_update_steps_per_epoch
@@ -905,10 +970,7 @@ def _inner_training_loop(
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not args.ignore_data_skip:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch."
)
logger.info(f" Will skip the first {epochs_trained} epochs then the first" f" {steps_trained_in_current_epoch} batches in the first epoch.")

# Update the references
self.callback_handler.model = self.model
@@ -969,11 +1031,7 @@ def _inner_training_loop(
if args.past_index >= 0:
self._past = None

steps_in_epoch = (
len(epoch_iterator)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
steps_in_epoch = len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
@@ -1001,11 +1059,7 @@ def _inner_training_loop(
)
else:
input_device = inputs[main_input_name].device
self.state.num_input_tokens_seen += torch.sum(
self.accelerator.gather(
torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
)
).item()
self.state.num_input_tokens_seen += torch.sum(self.accelerator.gather(torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64))).item()
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
@@ -1043,25 +1097,17 @@ def _inner_training_loop(
########################

# Accumulate loss
if (
args.logging_nan_inf_filter
and not is_torch_xla_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
if args.logging_nan_inf_filter and not is_torch_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
if tr_loss.device != tr_loss_step.device:
raise ValueError(
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
)
raise ValueError(f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}")
tr_loss += tr_loss_step

self.current_flos += float(self.floating_point_ops(inputs))

is_last_step_and_steps_less_than_grad_acc = (
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)
is_last_step_and_steps_less_than_grad_acc = steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch

if (
total_batched_samples % args.gradient_accumulation_steps == 0
@@ -1102,10 +1148,7 @@ def _inner_training_loop(
args.max_grad_norm,
)

if (
is_accelerate_available()
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
):
if is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED:
grad_norm = model.get_global_grad_norm()
# In some cases the grad norm may not return a float
if hasattr(grad_norm, "item"):
@@ -1155,10 +1198,7 @@ def _inner_training_loop(
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
else:
logger.warning(
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected."
)
logger.warning("You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected.")
if self.control.should_training_stop:
break

Expand Down
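Note on the `zo_perturb_parameters` / `zo_update` hunks above: they follow the MeZO (memory-efficient zeroth-order) pattern, where the random perturbation `z` is never stored but regenerated from a saved seed, and the gradient is estimated from two perturbed forward passes instead of backpropagation. The sketch below is a minimal, illustrative reconstruction of that pattern, not code from this PR; `model`, `loss_fn`, `batch`, `eps`, and `lr` are assumed names, and it omits the seed-group accumulation, weight-decay handling, and checkpointing that the trainer adds.

```python
import torch

def mezo_step(model, loss_fn, batch, eps=1e-3, lr=1e-6):
    """One zeroth-order step: two forward passes, no backward pass."""
    params = [p for p in model.parameters() if p.requires_grad]
    seed = torch.randint(0, 2**31 - 1, (1,)).item()

    def perturb(scale):
        # Replaying the seed regenerates the exact same z on every call,
        # so z never has to be kept in memory.
        torch.manual_seed(seed)
        for p in params:
            z = torch.normal(mean=0.0, std=1.0, size=p.shape, device=p.device, dtype=p.dtype)
            p.data.add_(scale * eps * z)

    with torch.no_grad():
        perturb(+1.0)
        loss_plus = loss_fn(model, batch)    # loss at theta + eps * z
        perturb(-2.0)
        loss_minus = loss_fn(model, batch)   # loss at theta - eps * z
        perturb(+1.0)                        # restore original weights
        grad_est = (loss_plus - loss_minus) / (2 * eps)

        # Regenerate z once more and take the projected-gradient step.
        torch.manual_seed(seed)
        for p in params:
            z = torch.normal(mean=0.0, std=1.0, size=p.shape, device=p.device, dtype=p.dtype)
            p.data.add_(-lr * grad_est * z)

    return loss_plus
```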