diff --git a/README.md b/README.md
index 97f26c7..d26d040 100644
--- a/README.md
+++ b/README.md
@@ -143,8 +143,74 @@ To start learning, execute the following command.
 GPU is required for learning; we have tested on Ubuntu 20.04, CUDA 11.7.
 
+# Training (without Trainer)
+
+We offer `train_ds.py`, a training script that does not depend on Hugging Face's `Trainer` class, for more flexible training configurations.
+For example, [projects/opt/exp002_ds.yml](projects/opt/exp002_ds.yml) contains the following:
+
+```yaml
+training_config:
+  per_device_train_batch_size: 2
+  per_device_eval_batch_size: 2
+  gradient_accumulation_steps: 4
+  num_train_epochs: 5
+  dataloader_num_workers: 16
+  learning_rate: 5.0e-5
+  output_dir: ./output/
+  report_to: "wandb"
+  zero_stage: 2
+  precision: "fp16"
+  enable_tensorboard: False
+  seed: 0
+  weight_decay: 0.
+  learning_rate_pretraining_components: 0.
+  num_warmup_steps: 0.
+  optim_betas:
+    - 0.9
+    - 0.95
+  lr_scheduler_type: "cosine"
+  gradient_checkpointing: False
+  cpu_offload: False
+
+
+model_config:
+  pretrained_path: # None or path to model weight
+  model_type: git_llm
+  language_model_name: facebook/opt-125m
+  vision_model_name: openai/clip-vit-base-patch16
+  num_image_with_embedding: 1 # if 1, no img_temporal_embedding
+  max_length: 512
+  keys_to_finetune:
+    - visual_projection
+    - num_image_with_embedding
+  keys_to_freeze: []
+
+  # TODO: support LoRA
+  # use_lora: false
+  # lora:
+  #   r: 8
+  #   lora_alpha: 32
+  #   target_modules:
+  #     - q_proj
+  #     - k_proj
+  #     - v_proj
+  #   lora_dropout: 0.01
+  #   bias: none
+  #   task_type: CAUSAL_LM
+
+dataset_config_path:
+  - ./configs/datasets/m3it_coco.yaml # only coco dataset
+```
+
+To start learning, execute the following command.
+
+```bash
+./scripts/run_ds.sh
+```
+
 # Evaluation
+If you have a model trained with ZeRO-3, see the note below on how to load its checkpoint.
 
 You can get the pretrained weight from Hugging Face Hub: [turing-motors/heron-chat-git-ja-stablelm-base-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-ja-stablelm-base-7b-v0)
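+
+If you prefer to fetch the weights to a local directory first (for example, for offline use), the following is a minimal, optional sketch using the `huggingface_hub` library (not otherwise used in this README); the inference example below loads the checkpoint directly by its repository name, so this step can be skipped:
+
+```python
+from huggingface_hub import snapshot_download
+
+# Download the checkpoint repository once and reuse the returned local path.
+local_dir = snapshot_download(repo_id="turing-motors/heron-chat-git-ja-stablelm-base-7b-v0")
+print(local_dir)  # this path can be passed to from_pretrained in place of the repo id
+```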
 
 See also [notebooks](./notebooks).
 
@@ -197,6 +263,26 @@ with torch.no_grad():
     print(processor.tokenizer.batch_decode(out)[0])
 ```
 
+If you have a model trained using ZeRO-3, the model-loading part of the inference code must be modified as follows:
+
+```diff
+- # prepare a pretrained model
+- model = GitLlamaForCausalLM.from_pretrained(
+-     'turing-motors/heron-chat-git-Llama-2-7b-v0', torch_dtype=torch.float16
+- )
++ from heron.models.utils import load_model, load_pretrained_weight
++ import yaml
++
++ config_file = f"./projects/opt/exp002_ds.yml"
++
++ # get config
++ with open(config_file, "r") as i_:
++     config = yaml.safe_load(i_)
++
++ model = load_model(config["model_config"])
++ model.load_state_dict(torch.load('./output/opt/exp002_ds/epoch-1/pytorch_model.bin'), strict=True)
+```
+
 ### Pretrained Models
 
 |model|LLM module|adapter|size|
@@ -225,3 +311,4 @@ Released under the [Apache License 2.0](./LICENSE).
 - [GenerativeImage2Text](https://github.com/microsoft/GenerativeImage2Text): The main idea of the model is based on the original GIT.
 - [Llava](https://github.com/haotian-liu/LLaVA): This project learned a lot from the great Llava project.
 - [GIT-LLM](https://github.com/Ino-Ichan/GIT-LLM)
+- [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)
diff --git a/configs/datasets/m3it_coco.yaml b/configs/datasets/m3it_coco.yaml
new file mode 100644
index 0000000..6ebc917
--- /dev/null
+++ b/configs/datasets/m3it_coco.yaml
@@ -0,0 +1,3 @@
+dataset_type: m3it
+dataset_names:
+  - coco
diff --git a/docs/README_CN.md b/docs/README_CN.md
index 8f710d5..92daaef 100644
--- a/docs/README_CN.md
+++ b/docs/README_CN.md
@@ -143,6 +143,70 @@ training_config "为训练设置, "model_config "为模型设置,"dataset_conf
 学习需要 GPU;我们在 Ubuntu 20.04 和 CUDA 11.7 上对系统进行了测试.
 
+# 学习方法 (不含 Trainer)
+我们提供 `train_ds.py`，一个独立于 Hugging Face `Trainer` 类的训练脚本，用于更灵活的学习配置。例如，[projects/opt/exp002_ds.yml](../projects/opt/exp002_ds.yml) 的内容如下:
+
+```yaml
+training_config:
+  per_device_train_batch_size: 2
+  per_device_eval_batch_size: 2
+  gradient_accumulation_steps: 4
+  num_train_epochs: 5
+  dataloader_num_workers: 16
+  learning_rate: 5.0e-5
+  output_dir: ./output/
+  report_to: "wandb"
+  zero_stage: 2
+  precision: "fp16"
+  enable_tensorboard: False
+  seed: 0
+  weight_decay: 0.
+  learning_rate_pretraining_components: 0.
+  num_warmup_steps: 0.
+  optim_betas:
+    - 0.9
+    - 0.95
+  lr_scheduler_type: "cosine"
+  gradient_checkpointing: False
+  cpu_offload: False
+
+
+model_config:
+  pretrained_path: # None or path to model weight
+  model_type: git_llm
+  language_model_name: facebook/opt-125m
+  vision_model_name: openai/clip-vit-base-patch16
+  num_image_with_embedding: 1 # if 1, no img_temporal_embedding
+  max_length: 512
+  keys_to_finetune:
+    - visual_projection
+    - num_image_with_embedding
+  keys_to_freeze: []
+
+  # TODO: support LoRA
+  # use_lora: false
+  # lora:
+  #   r: 8
+  #   lora_alpha: 32
+  #   target_modules:
+  #     - q_proj
+  #     - k_proj
+  #     - v_proj
+  #   lora_dropout: 0.01
+  #   bias: none
+  #   task_type: CAUSAL_LM
+
+dataset_config_path:
+  - ./configs/datasets/m3it_coco.yaml # only coco dataset
+```
+
+要开始学习, 请执行以下命令.
+
+
+```bash
+./scripts/run_ds.sh
+```
+
 # 如何使用
 
 您可以从 Hugging Face Hub 下载训练好的模型:[turing-motors/heron-chat-git-ja-stablelm-base-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-ja-stablelm-base-7b-v0)
@@ -195,6 +259,26 @@ with torch.no_grad(): print(processor.tokenizer.batch_decode(out)) ``` +如果模型是用 ZeRO-3 训练的,请进行以下更改. + +```diff +- # prepare a pretrained model +- model = GitLlamaForCausalLM.from_pretrained( +- 'turing-motors/heron-chat-git-Llama-2-7b-v0', torch_dtype=torch.float16 +- ) ++ from heron.models.utils import load_model, load_pretrained_weight ++ import yaml ++ ++ config_file = f"./projects/opt/exp002_ds.yml" ++ ++ # get config ++ with open(config_file, "r") as i_: ++ config = yaml.safe_load(i_) ++ ++ model = load_model(config["model_config"]) ++ model.load_state_dict(torch.load('./output/opt/exp002_ds/epoch-1/pytorch_model.bin'), strict=True) +``` + ### 训练有素的模型列表 |model|LLM module|adapter|size| @@ -222,3 +306,4 @@ print(processor.tokenizer.batch_decode(out)) - [GenerativeImage2Text](https://github.com/microsoft/GenerativeImage2Text) - [Llava](https://github.com/haotian-liu/LLaVA) - [GIT-LLM](https://github.com/Ino-Ichan/GIT-LLM) +- [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) diff --git a/docs/README_JP.md b/docs/README_JP.md index 437f1d4..09e1d9e 100644 --- a/docs/README_JP.md +++ b/docs/README_JP.md @@ -142,6 +142,70 @@ dataset_config_path: 学習にはGPUが必要です。Ubuntu20.04, CUDA11.7で動作確認をしています。 +# 学習方法 (Trainerなし) +Hugging Faceの `Trainer` クラスに依存しない訓練スクリプト `train_ds.py` を提供しています。
+例えば、[projects/opt/exp002_ds.yml](../projects/opt/exp002_ds.yml)の内容は次のようになっています。
+
+```yaml
+training_config:
+  per_device_train_batch_size: 2
+  per_device_eval_batch_size: 2
+  gradient_accumulation_steps: 4
+  num_train_epochs: 5
+  dataloader_num_workers: 16
+  learning_rate: 5.0e-5
+  output_dir: ./output/
+  report_to: "wandb"
+  zero_stage: 2
+  precision: "fp16"
+  enable_tensorboard: False
+  seed: 0
+  weight_decay: 0.
+  learning_rate_pretraining_components: 0.
+  num_warmup_steps: 0.
+  optim_betas:
+    - 0.9
+    - 0.95
+  lr_scheduler_type: "cosine"
+  gradient_checkpointing: False
+  cpu_offload: False
+
+
+model_config:
+  pretrained_path: # None or path to model weight
+  model_type: git_llm
+  language_model_name: facebook/opt-125m
+  vision_model_name: openai/clip-vit-base-patch16
+  num_image_with_embedding: 1 # if 1, no img_temporal_embedding
+  max_length: 512
+  keys_to_finetune:
+    - visual_projection
+    - num_image_with_embedding
+  keys_to_freeze: []
+
+  # TODO: support LoRA
+  # use_lora: false
+  # lora:
+  #   r: 8
+  #   lora_alpha: 32
+  #   target_modules:
+  #     - q_proj
+  #     - k_proj
+  #     - v_proj
+  #   lora_dropout: 0.01
+  #   bias: none
+  #   task_type: CAUSAL_LM
+
+dataset_config_path:
+  - ./configs/datasets/m3it_coco.yaml # only coco dataset
+```
+
+学習を開始する場合は、次のコマンドを実行してください。
+
+```bash
+./scripts/run_ds.sh
+```
+
 # 利用方法
 
 Hugging Face Hubから学習済みモデルをダウンロードすることができます: [turing-motors/heron-chat-git-ja-stablelm-base-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-ja-stablelm-base-7b-v0)
@@ -194,6 +258,26 @@ with torch.no_grad(): print(processor.tokenizer.batch_decode(out)) ``` +もしZeRO-3で訓練されたモデルならば、推論用コードに次の変更を加えます。 + +```diff +- # prepare a pretrained model +- model = GitLlamaForCausalLM.from_pretrained( +- 'turing-motors/heron-chat-git-Llama-2-7b-v0', torch_dtype=torch.float16 +- ) ++ from heron.models.utils import load_model, load_pretrained_weight ++ import yaml ++ ++ config_file = f"./projects/opt/exp002_ds.yml" ++ ++ # get config ++ with open(config_file, "r") as i_: ++ config = yaml.safe_load(i_) ++ ++ model = load_model(config["model_config"]) ++ model.load_state_dict(torch.load('./output/opt/exp002_ds/epoch-1/pytorch_model.bin'), strict=True) +``` + ### 学習済みモデル一覧 |model|LLM module|adapter|size| @@ -221,3 +305,4 @@ print(processor.tokenizer.batch_decode(out)) - [GenerativeImage2Text](https://github.com/microsoft/GenerativeImage2Text): モデルの構成方法の着想はGITに基づいています。 - [Llava](https://github.com/haotian-liu/LLaVA): 本ライブラリはLlavaプロジェクトを参考にしています。 - [GIT-LLM](https://github.com/Ino-Ichan/GIT-LLM) +- [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) diff --git a/heron/models/utils.py b/heron/models/utils.py index 3aa0a5c..85e489f 100644 --- a/heron/models/utils.py +++ b/heron/models/utils.py @@ -191,6 +191,9 @@ def set_trainable_params( untrainable_list.append(name) else: - raise ValueError("either keys_to_freeze or keys_to_finetune should be specified") + # Full parameter Tuning + for name, p in model.named_parameters(): + p.requires_grad = True + trainable_list.append(name) return trainable_list, untrainable_list diff --git a/heron/utils/ds_utils.py b/heron/utils/ds_utils.py new file mode 100755 index 0000000..ace5212 --- /dev/null +++ b/heron/utils/ds_utils.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 +# Modifications copyright 2023 Turing Inc. + +""" +NOTICE: This code is subject to the terms of the Apache License 2.0. + +The code is modified from the original one. +original code: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-VisualChat/utils/ds_utils.py + +Additional contributions by Turing Inc. 
team +""" + +# DeepSpeed Team +GLOBAL_BATCH_SIZE = 32 +MICRO_BATCH_SIZE = 4 + + +def get_train_ds_config( + config, + offload, + stage=2, + enable_hybrid_engine=False, + inference_tp_size=1, + release_inference_cache=False, + pin_parameters=True, + tp_gather_partition_size=8, + max_out_tokens=512, +): + if config["precision"] == "fp16": + enable_fp16 = True + enable_bf16 = False + elif config["precision"] == "bf16": + enable_fp16 = False + enable_bf16 = True + else: + raise ValueError(f"Invalid precision {config['precision']}") + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "offload_param": {"device": device}, + "offload_optimizer": {"device": device}, + "stage3_param_persistence_threshold": 1e4, + "stage3_max_live_parameters": 3e7, + "stage3_prefetch_bucket_size": 0, + "memory_efficient_linear": False, + } + output = { + "train_batch_size": GLOBAL_BATCH_SIZE, + "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE, + "steps_per_print": 10, + "zero_optimization": zero_opt_dict, + "zero_allow_untested_optimizer": True, + "zero_force_ds_cpu_optimizer": False, + "fp16": {"enabled": enable_fp16, "loss_scale_window": 100}, + "bf16": { + "enabled": enable_bf16, + }, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "hybrid_engine": { + "enabled": enable_hybrid_engine, + "max_out_tokens": max_out_tokens, + "inference_tp_size": inference_tp_size, + "release_inference_cache": release_inference_cache, + "pin_parameters": pin_parameters, + "tp_gather_partition_size": tp_gather_partition_size, + }, + } + if config["enable_tensorboard"]: + output.update( + { + "tensorboard": { + "enabled": True, + "output_path": config["output_dir"], + "job_name": "tb_logging", + } + } + ) + return output + + +def get_eval_ds_config(offload, stage=0): + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "stage3_param_persistence_threshold": 1e4, + "offload_param": {"device": device}, + "memory_efficient_linear": False, + } + return { + "train_batch_size": GLOBAL_BATCH_SIZE, + "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE, + "steps_per_print": 10, + "zero_optimization": zero_opt_dict, + "fp16": {"enabled": True}, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } diff --git a/heron/utils/utils.py b/heron/utils/utils.py new file mode 100644 index 0000000..fa8204a --- /dev/null +++ b/heron/utils/utils.py @@ -0,0 +1,226 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +""" +NOTICE: This code is subject to the terms of the Apache License 2.0. + +The code is modified from the original one. +original code: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-VisualChat/utils/ds_utils.py + +Additional contributions by Turing Inc. 
team +""" + +import json + +# DeepSpeed Team +import os +import random + +import deepspeed +import numpy as np +import torch +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from transformers import AutoTokenizer, set_seed + + +def print_rank_0(msg, rank=None): + if rank is not None and rank <= 0: + print(msg) + elif is_rank_0(): + print(msg) + + +def is_rank_0(): + """Check whether it is rank 0.""" + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + return True + else: + return False + else: + return True + + +def get_rank(): + """Check whether it is rank 0.""" + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def to_device(batch, device): + output = {} + for k, v in batch.items(): + try: + output[k] = v.to(device) + except: + output[k] = v + return output + + +class MovingAverage: + def __init__(self): + self.count = 0 + self.total = 0 + self.mean = 0 + + def update(self, num): + self.total += num + self.count += 1 + self.mean = self.total / self.count + + return self.mean + + +def set_random_seed(seed): + if seed is not None: + set_seed(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_all_reduce_mean(tensor): + torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) + tensor = tensor / torch.distributed.get_world_size() + return tensor + + +def get_optimizer_grouped_parameters( + model, + weight_decay, + no_decay_name_list=["bias", "LayerNorm.weight"], + small_learning_rate_list=["embed"], + small_lr=1e-4, +): + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if ( + not any(nd in n for nd in no_decay_name_list) + and (not any(nd in n for nd in small_learning_rate_list)) + and p.requires_grad + ) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if ( + any(nd in n for nd in no_decay_name_list) + and (not any(nd in n for nd in small_learning_rate_list)) + and p.requires_grad + ) + ], + "weight_decay": 0.0, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if ( + not any(nd in n for nd in no_decay_name_list) + and (any(nd in n for nd in small_learning_rate_list)) + and p.requires_grad + ) + ], + "weight_decay": weight_decay, + "lr": small_lr, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if ( + any(nd in n for nd in no_decay_name_list) + and (any(nd in n for nd in small_learning_rate_list)) + and p.requires_grad + ) + ], + "weight_decay": 0.0, + "lr": small_lr, + }, + ] + return optimizer_grouped_parameters + + +def _z3_params_to_fetch(param_list): + return [ + p + for p in param_list + if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE + ] + + +def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0): + zero_stage_3 = zero_stage == 3 + with torch.no_grad(): + for param, param_ema in zip(model.parameters(), model_ema.parameters()): + # TODO: use prefiltering for efficiency + params_to_fetch = _z3_params_to_fetch([param, param_ema]) if zero_stage_3 else [] + should_gather_param = len(params_to_fetch) > 0 + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param): + data = param.data + if device is not None: + data = data.to(device) + param_ema.data.copy_(torch.lerp(data, param_ema.data, beta)) + + +def save_hf_format(model, tokenizer, args, sub_folder=""): + # used to save huggingface 
format, so we can use it for hf.from_pretrained + model_to_save = model.module if hasattr(model, "module") else model + CONFIG_NAME = "config.json" + WEIGHTS_NAME = "pytorch_model.bin" + output_dir = os.path.join(args.output_dir, sub_folder) + os.makedirs(output_dir, exist_ok=True) + output_model_file = os.path.join(output_dir, WEIGHTS_NAME) + output_config_file = os.path.join(output_dir, CONFIG_NAME) + save_dict = model_to_save.state_dict() + # for key in list(save_dict.keys()): + # if "lora" in key: + # del save_dict[key] + torch.save(save_dict, output_model_file) + try: + model_to_save.config.to_json_file(output_config_file) + except: + args_dict = vars(args) + torch.save(args_dict, os.path.join(output_dir, "train_args.pt")) + print("config can't be saved") + # tokenizer.save_vocabulary(output_dir) + tokenizer.save_pretrained(output_dir) # this will save all tokenizer files + + +def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0, sub_folder=""): + zero_stage_3 = zero_stage == 3 + output_dir = os.path.join(save_dir, sub_folder) + os.makedirs(output_dir, exist_ok=True) + WEIGHTS_NAME = "pytorch_model.bin" + output_model_file = os.path.join(output_dir, WEIGHTS_NAME) + + model_to_save = model_ema.module if hasattr(model_ema, "module") else model_ema + if not zero_stage_3: + if global_rank == 0: + torch.save(model_to_save.state_dict(), output_model_file) + else: + output_state_dict = {} + for k, v in model_to_save.named_parameters(remove_duplicate=False): + if hasattr(v, "ds_id"): + with deepspeed.zero.GatheredParameters( + _z3_params_to_fetch([v]), enabled=zero_stage_3 + ): + v_p = ( + v.data.clone().detach().cpu() + ) # this is a hack to get around the fact that we can't get the data from the param + else: + v_p = v.cpu() + if global_rank == 0 and "lora" not in k: + output_state_dict[k] = v_p + if global_rank == 0: + torch.save(output_state_dict, output_model_file) + del output_state_dict diff --git a/projects/opt/exp002_ds.yml b/projects/opt/exp002_ds.yml new file mode 100644 index 0000000..0cfdd4a --- /dev/null +++ b/projects/opt/exp002_ds.yml @@ -0,0 +1,49 @@ +training_config: + per_device_train_batch_size: 2 + per_device_eval_batch_size: 2 + gradient_accumulation_steps: 4 + num_train_epochs: 3 + dataloader_num_workers: 16 + learning_rate: 5.0e-5 + output_dir: ./output/ + report_to: "wandb" + zero_stage: 3 + precision: "fp16" + enable_tensorboard: False + seed: 0 + weight_decay: 0. + learning_rate_pretraining_components: 0. + num_warmup_steps: 0. 
+ optim_betas: + - 0.9 + - 0.95 + lr_scheduler_type: "cosine" + gradient_checkpointing: False + cpu_offload: False + + +model_config: + pretrained_path: # None or path to model weight + model_type: git_llm + language_model_name: facebook/opt-125m + vision_model_name: openai/clip-vit-base-patch16 + num_image_with_embedding: 1 # if 1, no img_temporal_embedding + max_length: 512 + keys_to_finetune: [] + keys_to_freeze: [] + + # TODO: support LoRA + # use_lora: false + # lora: + # r: 8 + # lora_alpha: 32 + # target_modules: + # - q_proj + # - k_proj + # - v_proj + # lora_dropout: 0.01 + # bias: none + # task_type: CAUSAL_LM + +dataset_config_path: + - ./configs/datasets/m3it_coco.yaml # only coco dataset diff --git a/scripts/run_ds.sh b/scripts/run_ds.sh new file mode 100755 index 0000000..b01678a --- /dev/null +++ b/scripts/run_ds.sh @@ -0,0 +1,6 @@ +#!/bin/bash +export WANDB_PROJECT=heron +export PROJECT_NAME=opt/exp002_ds +export WANDB_NAME=$PROJECT_NAME + +deepspeed train_ds.py --config_file projects/$PROJECT_NAME.yml diff --git a/train_ds.py b/train_ds.py new file mode 100644 index 0000000..7df7417 --- /dev/null +++ b/train_ds.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +""" +NOTICE: This code is subject to the terms of the Apache License 2.0. + +The code is modified from the original one. +original code: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-VisualChat/training/main.py + +Additional contributions by Turing Inc. team +""" + +import math +import os +import random +import sys + +import deepspeed +import fire +import numpy as np +import torch +import yaml +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm +from transformers import AdamW, AutoTokenizer, SchedulerType, get_scheduler +from transformers.integrations import HfDeepSpeedConfig + +import wandb +from heron.datasets.utils import get_dataset +from heron.models.utils import ( + apply_lora_model, + load_model, + load_pretrained_weight, + set_trainable_params, + unload_and_merge_lora, +) +from heron.utils.ds_utils import get_train_ds_config +from heron.utils.utils import ( + get_all_reduce_mean, + get_optimizer_grouped_parameters, + print_rank_0, + save_zero_three_model, + set_random_seed, + to_device, +) + + +def main(config_file: str, local_rank: int = 0): + with open(config_file, "r") as i_: + config = yaml.safe_load(i_) + model_config = config["model_config"] + training_config = config["training_config"] + + if os.environ.get("WANDB_NAME") is not None: + training_config["output_dir"] = os.path.join( + training_config["output_dir"], os.environ["WANDB_NAME"] + ) + + if local_rank == -1: + device = torch.device("cuda") + elif local_rank == -100: + # for mpirun launcher + # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + deepspeed.init_distributed() + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + else: + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + deepspeed.init_distributed() + + training_config["global_rank"] = torch.distributed.get_rank() + + set_random_seed(training_config["seed"]) + + # Get configs for initialize DeepSpeed + ds_config = get_train_ds_config( + training_config, + offload=training_config["cpu_offload"], + 
stage=training_config["zero_stage"], + ) + ds_config["train_micro_batch_size_per_gpu"] = training_config["per_device_train_batch_size"] + ds_config["train_batch_size"] = ( + training_config["per_device_train_batch_size"] + * torch.distributed.get_world_size() + * training_config["gradient_accumulation_steps"] + ) + + if training_config["zero_stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + + # Initialization of wandb + if os.environ.get("WANDB_NAME") is not None and local_rank == 0: + wandb.init(project=os.environ["WANDB_PROJECT"], config=config) + + # Wait for all processes + torch.distributed.barrier() + + # Load model + model = load_model(model_config) + + # Set trainable params + keys_to_finetune = config["model_config"]["keys_to_finetune"] + keys_to_freeze = config["model_config"]["keys_to_freeze"] + trainable_list, untrainable_list = set_trainable_params( + model, keys_to_finetune, keys_to_freeze, train_lora=False + ) + print_rank_0(f"trainable_list {trainable_list}", training_config["global_rank"]) + print_rank_0(f"untrainable_list {untrainable_list}", training_config["global_rank"]) + + print_rank_0(model, training_config["global_rank"]) + + # Load datasets + train_dataset, eval_dataset = get_dataset(config) + + train_dataloader = DataLoader( + train_dataset, + batch_size=training_config["per_device_train_batch_size"], + sampler=DistributedSampler(train_dataset, shuffle=True, drop_last=True), + num_workers=training_config["dataloader_num_workers"], + ) + + eval_dataloader = DataLoader( + eval_dataset, + batch_size=training_config["per_device_eval_batch_size"], + sampler=DistributedSampler(eval_dataset, shuffle=False), + num_workers=training_config["dataloader_num_workers"], + ) + + # Split weights in two groups, one with weight decay and the other not. 
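+    # get_optimizer_grouped_parameters (heron/utils/utils.py) actually returns four groups:
+    # (weight decay / no decay) x (base LR / small LR). Bias and LayerNorm weights get no
+    # weight decay, and parameters whose name contains "embed" use the small learning rate
+    # taken from learning_rate_pretraining_components.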
+ optimizer_grouped_parameters = get_optimizer_grouped_parameters( + model, + training_config["weight_decay"], + small_lr=training_config["learning_rate_pretraining_components"], + ) + + optimizer = AdamW( + optimizer_grouped_parameters, + lr=training_config["learning_rate"], + betas=tuple(training_config["optim_betas"]), + ) + + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / training_config["gradient_accumulation_steps"] + ) + if training_config["num_warmup_steps"] <= 1: + training_config["num_warmup_steps"] = int( + training_config["num_warmup_steps"] + * training_config["num_train_epochs"] + * num_update_steps_per_epoch + ) + else: + training_config["num_warmup_steps"] = int(training_config["num_warmup_steps"]) + + lr_scheduler = get_scheduler( + name=training_config["lr_scheduler_type"], + optimizer=optimizer, + num_warmup_steps=training_config["num_warmup_steps"], + num_training_steps=training_config["num_train_epochs"] * num_update_steps_per_epoch, + ) + + model, optimizer, _, lr_scheduler = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config, + lr_scheduler=lr_scheduler, + dist_init_required=True, + ) + + start_epoch = 0 + # let load checkpoint + if os.path.exists(os.path.join(training_config["output_dir"], "latest")): + _, client_state = model.load_checkpoint(training_config["output_dir"]) + start_epoch = client_state["epoch"] + best_loss = client_state["best_loss"] + random.setstate(client_state["random_rng_state"]) + np.random.set_state(client_state["np_rng_state"]) + torch.set_rng_state(client_state["torch_rng_state"]) + torch.cuda.set_rng_state(client_state["torch_cuda_rng_state"]) + + if training_config["gradient_checkpointing"]: + model.gradient_checkpointing_enable() + + def evaluation(model, eval_dataloader): + model.eval() + print_rank_0("***** Evaluation *****", training_config["global_rank"]) + acc_loss = 0 + progress_bar = tqdm(eval_dataloader, dynamic_ncols=True) + for step, batch in enumerate(progress_bar): + with torch.no_grad(): + batch = to_device(batch, device) + loss = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + pixel_values=batch["pixel_values"].half(), + labels=batch["labels"], + )[0] + acc_loss += loss.float() + text = f"step {step}, loss: {loss:.5f} the average_loss: {acc_loss.item()/(step+1)=}" + # print_rank_0(text) + progress_bar.set_description(text) + model.train() + ave_loss = acc_loss / (step + 1) + print_rank_0(f"the eval average_loss: {ave_loss}", training_config["global_rank"]) + return ave_loss + + # Train! 
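+    # One engine step is taken per micro-batch: model.backward(loss) accumulates gradients and
+    # model.step() only updates the weights (and the LR scheduler) every
+    # gradient_accumulation_steps micro-batches. After each epoch a DeepSpeed checkpoint
+    # (including RNG states) is written so training can resume from "latest", and for ZeRO-3
+    # a consolidated pytorch_model.bin is additionally saved via save_zero_three_model.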
+ if start_epoch == 0: + print_rank_0("***** Before training *****", training_config["global_rank"]) + # evaluation(model, eval_dataloader) + best_loss = 1e6 + + print_rank_0("***** Running training *****", training_config["global_rank"]) + for epoch in range(start_epoch, training_config["num_train_epochs"]): + print_rank_0( + f"Beginning of Epoch {epoch+1}/{training_config['num_train_epochs']}, Total Micro Batches {len(train_dataloader)}", + training_config["global_rank"], + ) + model.train() + acc_loss = 0 + progress_bar = tqdm(train_dataloader, dynamic_ncols=True) + for step, batch in enumerate(progress_bar): + batch = to_device(batch, device) + + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + pixel_values = batch["pixel_values"].half() + labels = batch["labels"] + loss = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + labels=labels, + )[0] + + acc_loss += loss.float() + model.backward(loss) + # Attention: gradient accumulation in the function + model.step() + + # Log to wandb + if os.environ.get("WANDB_NAME") is not None and local_rank == 0: + now_lr = lr_scheduler.get_lr()[0] + wandb.log( + { + "Train/epoch": epoch, + "Train/step": step, + "Train/loss": loss, + "Train/average_loss": acc_loss / (step + 1), + "Train/learning_rate": now_lr, + } + ) + + text = f"step {step}, loss: {loss.detach():.5f} the average_loss: {acc_loss/(step + 1):.5f}" + progress_bar.set_description(text) + + model.tput_timer.update_epoch_count() + print_rank_0( + f"Epoch {epoch+1}, the average_loss: {acc_loss/step}", training_config["global_rank"] + ) + eval_loss = evaluation(model, eval_dataloader) + + if eval_loss < best_loss: + best_loss = eval_loss + + # Log to wandb + if os.environ.get("WANDB_NAME") is not None and local_rank == 0: + wandb.log( + { + "Eval/loss": eval_loss, + } + ) + + # Save the checkpoint + client_state = { + "random_rng_state": random.getstate(), + "np_rng_state": np.random.get_state(), + "torch_rng_state": torch.get_rng_state(), + "torch_cuda_rng_state": torch.cuda.get_rng_state(), + "epoch": epoch + 1, # start from next epoch + "best_loss": best_loss, + } + model.save_checkpoint( + training_config["output_dir"], client_state=client_state + ) # save to the latest + + if training_config["zero_stage"] == 3: + save_zero_three_model( + model, + training_config["global_rank"], + training_config["output_dir"], + zero_stage=training_config["zero_stage"], + sub_folder=f"epoch-{epoch}", + ) + + # TODO: support merging LoRA for ZeRO-3 training + if training_config["zero_stage"] != 3: + model = model.module + + save_path = os.path.join(training_config["output_dir"], f"epoch_final") + model.save_pretrained(save_path) + + +if __name__ == "__main__": + fire.Fire(main)
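+
+# Example launch (see scripts/run_ds.sh):
+#   deepspeed train_ds.py --config_file projects/opt/exp002_ds.yml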