diff --git a/README.md b/README.md
index 97f26c7..d26d040 100644
--- a/README.md
+++ b/README.md
@@ -143,8 +143,74 @@ To start learning, execute the following command.
GPU is required for learning; we have tested on Ubuntu 20.04, CUDA 11.7.
+# Training (w/o Trainer)
+
+We provide `train_ds.py`, a training script that does not depend on Hugging Face's `Trainer` class, for more flexible training configurations.
+For example, [projects/opt/exp002_ds.yml](projects/opt/exp002_ds.yml) has the following contents:
+
+```yaml
+training_config:
+ per_device_train_batch_size: 2
+ per_device_eval_batch_size: 2
+ gradient_accumulation_steps: 4
+ num_train_epochs: 5
+ dataloader_num_workers: 16
+ learning_rate: 5.0e-5
+ output_dir: ./output/
+ report_to: "wandb"
+ zero_stage: 2
+ precision: "fp16"
+ enable_tensorboard: False
+ seed: 0
+ weight_decay: 0.
+ learning_rate_pretraining_components: 0.
+ num_warmup_steps: 0.
+ optim_betas:
+ - 0.9
+ - 0.95
+ lr_scheduler_type: "cosine"
+ gradient_checkpointing: False
+ cpu_offload: False
+
+
+model_config:
+ pretrained_path: # None or path to model weight
+ model_type: git_llm
+ language_model_name: facebook/opt-125m
+ vision_model_name: openai/clip-vit-base-patch16
+ num_image_with_embedding: 1 # if 1, no img_temporal_embedding
+ max_length: 512
+ keys_to_finetune:
+ - visual_projection
+ - num_image_with_embedding
+ keys_to_freeze: []
+
+ # TODO: support LoRA
+ # use_lora: false
+ # lora:
+ # r: 8
+ # lora_alpha: 32
+ # target_modules:
+ # - q_proj
+ # - k_proj
+ # - v_proj
+ # lora_dropout: 0.01
+ # bias: none
+ # task_type: CAUSAL_LM
+
+dataset_config_path:
+ - ./configs/datasets/m3it_coco.yaml # only coco dataset
+```
+
+To start training, execute the following command.
+
+```bash
+./scripts/run_ds.sh
+```
+
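+Under the hood, `./scripts/run_ds.sh` launches `train_ds.py` through the `deepspeed` launcher and passes it this YAML file. As a rough sketch (illustrative only, not part of the script's public API), the script splits the file into its three top-level blocks like this:
+
+```python
+import yaml
+
+with open("./projects/opt/exp002_ds.yml", "r") as f:
+    config = yaml.safe_load(f)
+
+training_config = config["training_config"]    # optimizer, DeepSpeed, and logging settings
+model_config = config["model_config"]          # vision / language model selection and tuning keys
+dataset_config_paths = config["dataset_config_path"]  # list of dataset config YAMLs
+```
+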
# Evaluation
You can get the pretrained weight form Hugging Face Hub: [turing-motors/heron-chat-git-ja-stablelm-base-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-ja-stablelm-base-7b-v0)
See also [notebooks](./notebooks).
@@ -197,6 +263,26 @@ with torch.no_grad():
print(processor.tokenizer.batch_decode(out)[0])
```
+If your model was trained with ZeRO-3, modify the inference code as follows:
+
+```diff
+- # prepare a pretrained model
+- model = GitLlamaForCausalLM.from_pretrained(
+- 'turing-motors/heron-chat-git-Llama-2-7b-v0', torch_dtype=torch.float16
+- )
++ from heron.models.utils import load_model, load_pretrained_weight
++ import yaml
++
++ config_file = "./projects/opt/exp002_ds.yml"
++
++ # get config
++ with open(config_file, "r") as i_:
++ config = yaml.safe_load(i_)
++
++ model = load_model(config["model_config"])
++ model.load_state_dict(torch.load('./output/opt/exp002_ds/epoch-1/pytorch_model.bin'), strict=True)
+```
+
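+Putting the pieces together, a minimal loading sketch for a ZeRO-3 checkpoint might look like this (the config and checkpoint paths are the ones used in the example above; adjust them to your own run). The generation code from the earlier snippet can then be used unchanged:
+
+```python
+import torch
+import yaml
+
+from heron.models.utils import load_model
+
+config_file = "./projects/opt/exp002_ds.yml"
+with open(config_file, "r") as f:
+    config = yaml.safe_load(f)
+
+# build the model from the config, then load the consolidated ZeRO-3 weights
+model = load_model(config["model_config"])
+state_dict = torch.load("./output/opt/exp002_ds/epoch-1/pytorch_model.bin", map_location="cpu")
+model.load_state_dict(state_dict, strict=True)
+model = model.half().eval().cuda()
+```
+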
### Pretrained Models
|model|LLM module|adapter|size|
@@ -225,3 +311,4 @@ Released under the [Apache License 2.0](./LICENSE).
- [GenerativeImage2Text](https://github.com/microsoft/GenerativeImage2Text): The main idia of the model is based on original GIT.
- [Llava](https://github.com/haotian-liu/LLaVA): This project is learned a lot from the great Llava project.
- [GIT-LLM](https://github.com/Ino-Ichan/GIT-LLM)
+- [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)
diff --git a/configs/datasets/m3it_coco.yaml b/configs/datasets/m3it_coco.yaml
new file mode 100644
index 0000000..6ebc917
--- /dev/null
+++ b/configs/datasets/m3it_coco.yaml
@@ -0,0 +1,3 @@
+dataset_type: m3it
+dataset_names:
+ - coco
diff --git a/docs/README_CN.md b/docs/README_CN.md
index 8f710d5..92daaef 100644
--- a/docs/README_CN.md
+++ b/docs/README_CN.md
@@ -143,6 +143,70 @@ training_config "为训练设置, "model_config "为模型设置,"dataset_conf
学习需要 GPU;我们在 Ubuntu 20.04 和 CUDA 11.7 上对系统进行了测试.
+# 学习方法 (不含 Trainer)
+我们提供 `train_ds.py`,一个不依赖 Hugging Face `Trainer` 类的训练脚本,可用于更灵活的训练配置。例如,[projects/opt/exp002_ds.yml](../projects/opt/exp002_ds.yml) 的内容如下:
+
+```yaml
+training_config:
+ per_device_train_batch_size: 2
+ per_device_eval_batch_size: 2
+ gradient_accumulation_steps: 4
+ num_train_epochs: 5
+ dataloader_num_workers: 16
+ learning_rate: 5.0e-5
+ output_dir: ./output/
+ report_to: "wandb"
+ zero_stage: 2
+ precision: "fp16"
+ enable_tensorboard: False
+ seed: 0
+ weight_decay: 0.
+ learning_rate_pretraining_components: 0.
+ num_warmup_steps: 0.
+ optim_betas:
+ - 0.9
+ - 0.95
+ lr_scheduler_type: "cosine"
+ gradient_checkpointing: False
+ cpu_offload: False
+
+
+model_config:
+ pretrained_path: # None or path to model weight
+ model_type: git_llm
+ language_model_name: facebook/opt-125m
+ vision_model_name: openai/clip-vit-base-patch16
+ num_image_with_embedding: 1 # if 1, no img_temporal_embedding
+ max_length: 512
+ keys_to_finetune:
+ - visual_projection
+ - num_image_with_embedding
+ keys_to_freeze: []
+
+ # TODO: support LoRA
+ # use_lora: false
+ # lora:
+ # r: 8
+ # lora_alpha: 32
+ # target_modules:
+ # - q_proj
+ # - k_proj
+ # - v_proj
+ # lora_dropout: 0.01
+ # bias: none
+ # task_type: CAUSAL_LM
+
+dataset_config_path:
+ - ./configs/datasets/m3it_coco.yaml # only coco dataset
+```
+
+要开始训练,请执行以下命令:
+
+```bash
+./scripts/run_ds.sh
+```
+
# 如何使用
您可以从 Hugging Face Hub 下载训练好的模型:[turing-motors/heron-chat-git-ja-stablelm-base-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-ja-stablelm-base-7b-v0)
@@ -195,6 +259,26 @@ with torch.no_grad():
print(processor.tokenizer.batch_decode(out))
```
+如果模型是用 ZeRO-3 训练的,请对推理代码进行以下修改:
+
+```diff
+- # prepare a pretrained model
+- model = GitLlamaForCausalLM.from_pretrained(
+- 'turing-motors/heron-chat-git-Llama-2-7b-v0', torch_dtype=torch.float16
+- )
++ from heron.models.utils import load_model, load_pretrained_weight
++ import yaml
++
++ config_file = "./projects/opt/exp002_ds.yml"
++
++ # get config
++ with open(config_file, "r") as i_:
++ config = yaml.safe_load(i_)
++
++ model = load_model(config["model_config"])
++ model.load_state_dict(torch.load('./output/opt/exp002_ds/epoch-1/pytorch_model.bin'), strict=True)
+```
+
### 训练有素的模型列表
|model|LLM module|adapter|size|
@@ -222,3 +306,4 @@ print(processor.tokenizer.batch_decode(out))
- [GenerativeImage2Text](https://github.com/microsoft/GenerativeImage2Text)
- [Llava](https://github.com/haotian-liu/LLaVA)
- [GIT-LLM](https://github.com/Ino-Ichan/GIT-LLM)
+- [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)
diff --git a/docs/README_JP.md b/docs/README_JP.md
index 437f1d4..09e1d9e 100644
--- a/docs/README_JP.md
+++ b/docs/README_JP.md
@@ -142,6 +142,70 @@ dataset_config_path:
学習にはGPUが必要です。Ubuntu20.04, CUDA11.7で動作確認をしています。
+# 学習方法 (Trainerなし)
+Hugging Faceの `Trainer` クラスに依存しない訓練スクリプト `train_ds.py` を提供しています。
+例えば、[projects/opt/exp002_ds.yml](../projects/opt/exp002_ds.yml)の内容は次のようになっています。
+
+```yaml
+training_config:
+ per_device_train_batch_size: 2
+ per_device_eval_batch_size: 2
+ gradient_accumulation_steps: 4
+ num_train_epochs: 5
+ dataloader_num_workers: 16
+ learning_rate: 5.0e-5
+ output_dir: ./output/
+ report_to: "wandb"
+ zero_stage: 2
+ precision: "fp16"
+ enable_tensorboard: False
+ seed: 0
+ weight_decay: 0.
+ learning_rate_pretraining_components: 0.
+ num_warmup_steps: 0.
+ optim_betas:
+ - 0.9
+ - 0.95
+ lr_scheduler_type: "cosine"
+ gradient_checkpointing: False
+ cpu_offload: False
+
+
+model_config:
+ pretrained_path: # None or path to model weight
+ model_type: git_llm
+ language_model_name: facebook/opt-125m
+ vision_model_name: openai/clip-vit-base-patch16
+ num_image_with_embedding: 1 # if 1, no img_temporal_embedding
+ max_length: 512
+ keys_to_finetune:
+ - visual_projection
+ - num_image_with_embedding
+ keys_to_freeze: []
+
+ # TODO: support LoRA
+ # use_lora: false
+ # lora:
+ # r: 8
+ # lora_alpha: 32
+ # target_modules:
+ # - q_proj
+ # - k_proj
+ # - v_proj
+ # lora_dropout: 0.01
+ # bias: none
+ # task_type: CAUSAL_LM
+
+dataset_config_path:
+ - ./configs/datasets/m3it_coco.yaml # only coco dataset
+```
+
+学習を開始する場合は、次のコマンドを実行してください。
+
+```bash
+./scripts/run_ds.sh
+```
+
# 利用方法
Hugging Face Hubから学習済みモデルをダウンロードすることができます: [turing-motors/heron-chat-git-ja-stablelm-base-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-ja-stablelm-base-7b-v0)
@@ -194,6 +258,26 @@ with torch.no_grad():
print(processor.tokenizer.batch_decode(out))
```
+もしZeRO-3で訓練されたモデルならば、推論用コードに次の変更を加えます。
+
+```diff
+- # prepare a pretrained model
+- model = GitLlamaForCausalLM.from_pretrained(
+- 'turing-motors/heron-chat-git-Llama-2-7b-v0', torch_dtype=torch.float16
+- )
++ from heron.models.utils import load_model, load_pretrained_weight
++ import yaml
++
++ config_file = "./projects/opt/exp002_ds.yml"
++
++ # get config
++ with open(config_file, "r") as i_:
++ config = yaml.safe_load(i_)
++
++ model = load_model(config["model_config"])
++ model.load_state_dict(torch.load('./output/opt/exp002_ds/epoch-1/pytorch_model.bin'), strict=True)
+```
+
### 学習済みモデル一覧
|model|LLM module|adapter|size|
@@ -221,3 +305,4 @@ print(processor.tokenizer.batch_decode(out))
- [GenerativeImage2Text](https://github.com/microsoft/GenerativeImage2Text): モデルの構成方法の着想はGITに基づいています。
- [Llava](https://github.com/haotian-liu/LLaVA): 本ライブラリはLlavaプロジェクトを参考にしています。
- [GIT-LLM](https://github.com/Ino-Ichan/GIT-LLM)
+- [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)
diff --git a/heron/models/utils.py b/heron/models/utils.py
index 3aa0a5c..85e489f 100644
--- a/heron/models/utils.py
+++ b/heron/models/utils.py
@@ -191,6 +191,9 @@ def set_trainable_params(
untrainable_list.append(name)
else:
- raise ValueError("either keys_to_freeze or keys_to_finetune should be specified")
+ # Full parameter Tuning
+ for name, p in model.named_parameters():
+ p.requires_grad = True
+ trainable_list.append(name)
return trainable_list, untrainable_list
diff --git a/heron/utils/ds_utils.py b/heron/utils/ds_utils.py
new file mode 100755
index 0000000..ace5212
--- /dev/null
+++ b/heron/utils/ds_utils.py
@@ -0,0 +1,101 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+# Modifications copyright 2023 Turing Inc.
+
+"""
+NOTICE: This code is subject to the terms of the Apache License 2.0.
+
+The code is modified from the original one.
+original code: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-VisualChat/utils/ds_utils.py
+
+Additional contributions by Turing Inc. team
+"""
+
+# DeepSpeed Team
+GLOBAL_BATCH_SIZE = 32
+MICRO_BATCH_SIZE = 4
+
+
+def get_train_ds_config(
+ config,
+ offload,
+ stage=2,
+ enable_hybrid_engine=False,
+ inference_tp_size=1,
+ release_inference_cache=False,
+ pin_parameters=True,
+ tp_gather_partition_size=8,
+ max_out_tokens=512,
+):
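+    """Build a DeepSpeed config dict for training.
+
+    `config` is the `training_config` block of the project YAML; `offload` moves
+    optimizer/parameter state to CPU and `stage` selects the ZeRO stage. The batch
+    sizes returned here are placeholders that `train_ds.py` overrides from the YAML.
+    """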
+ if config["precision"] == "fp16":
+ enable_fp16 = True
+ enable_bf16 = False
+ elif config["precision"] == "bf16":
+ enable_fp16 = False
+ enable_bf16 = True
+ else:
+ raise ValueError(f"Invalid precision {config['precision']}")
+ device = "cpu" if offload else "none"
+ zero_opt_dict = {
+ "stage": stage,
+ "offload_param": {"device": device},
+ "offload_optimizer": {"device": device},
+ "stage3_param_persistence_threshold": 1e4,
+ "stage3_max_live_parameters": 3e7,
+ "stage3_prefetch_bucket_size": 0,
+ "memory_efficient_linear": False,
+ }
+ output = {
+ "train_batch_size": GLOBAL_BATCH_SIZE,
+ "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
+ "steps_per_print": 10,
+ "zero_optimization": zero_opt_dict,
+ "zero_allow_untested_optimizer": True,
+ "zero_force_ds_cpu_optimizer": False,
+ "fp16": {"enabled": enable_fp16, "loss_scale_window": 100},
+ "bf16": {
+ "enabled": enable_bf16,
+ },
+ "gradient_clipping": 1.0,
+ "prescale_gradients": False,
+ "wall_clock_breakdown": False,
+ "hybrid_engine": {
+ "enabled": enable_hybrid_engine,
+ "max_out_tokens": max_out_tokens,
+ "inference_tp_size": inference_tp_size,
+ "release_inference_cache": release_inference_cache,
+ "pin_parameters": pin_parameters,
+ "tp_gather_partition_size": tp_gather_partition_size,
+ },
+ }
+ if config["enable_tensorboard"]:
+ output.update(
+ {
+ "tensorboard": {
+ "enabled": True,
+ "output_path": config["output_dir"],
+ "job_name": "tb_logging",
+ }
+ }
+ )
+ return output
+
+
+def get_eval_ds_config(offload, stage=0):
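+    """Build a minimal DeepSpeed config dict for evaluation (no optimizer state needed)."""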
+ device = "cpu" if offload else "none"
+ zero_opt_dict = {
+ "stage": stage,
+ "stage3_param_persistence_threshold": 1e4,
+ "offload_param": {"device": device},
+ "memory_efficient_linear": False,
+ }
+ return {
+ "train_batch_size": GLOBAL_BATCH_SIZE,
+ "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
+ "steps_per_print": 10,
+ "zero_optimization": zero_opt_dict,
+ "fp16": {"enabled": True},
+ "gradient_clipping": 1.0,
+ "prescale_gradients": False,
+ "wall_clock_breakdown": False,
+ }
diff --git a/heron/utils/utils.py b/heron/utils/utils.py
new file mode 100644
index 0000000..fa8204a
--- /dev/null
+++ b/heron/utils/utils.py
@@ -0,0 +1,226 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+NOTICE: This code is subject to the terms of the Apache License 2.0.
+
+The code is modified from the original one.
+original code: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-VisualChat/utils/utils.py
+
+Additional contributions by Turing Inc. team
+"""
+
+import json
+
+# DeepSpeed Team
+import os
+import random
+
+import deepspeed
+import numpy as np
+import torch
+from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from transformers import AutoTokenizer, set_seed
+
+
+def print_rank_0(msg, rank=None):
+ if rank is not None and rank <= 0:
+ print(msg)
+ elif is_rank_0():
+ print(msg)
+
+
+def is_rank_0():
+ """Check whether it is rank 0."""
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_rank() == 0:
+ return True
+ else:
+ return False
+ else:
+ return True
+
+
+def get_rank():
+    """Return the global rank of the current process (0 if torch.distributed is not initialized)."""
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_rank()
+ else:
+ return 0
+
+
+def to_device(batch, device):
+ output = {}
+ for k, v in batch.items():
+ try:
+ output[k] = v.to(device)
+ except:
+ output[k] = v
+ return output
+
+
+class MovingAverage:
+ def __init__(self):
+ self.count = 0
+ self.total = 0
+ self.mean = 0
+
+ def update(self, num):
+ self.total += num
+ self.count += 1
+ self.mean = self.total / self.count
+
+ return self.mean
+
+
+def set_random_seed(seed):
+ if seed is not None:
+ set_seed(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_all_reduce_mean(tensor):
+ torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
+ tensor = tensor / torch.distributed.get_world_size()
+ return tensor
+
+
+def get_optimizer_grouped_parameters(
+ model,
+ weight_decay,
+ no_decay_name_list=["bias", "LayerNorm.weight"],
+ small_learning_rate_list=["embed"],
+ small_lr=1e-4,
+):
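+    """Split parameters into four optimizer groups.
+
+    The groups are the cross product of two conditions: names matching
+    `no_decay_name_list` get weight_decay=0.0, and names matching
+    `small_learning_rate_list` get the separate learning rate `small_lr`.
+    Only parameters with requires_grad=True are included.
+    """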
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if (
+ not any(nd in n for nd in no_decay_name_list)
+ and (not any(nd in n for nd in small_learning_rate_list))
+ and p.requires_grad
+ )
+ ],
+ "weight_decay": weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if (
+ any(nd in n for nd in no_decay_name_list)
+ and (not any(nd in n for nd in small_learning_rate_list))
+ and p.requires_grad
+ )
+ ],
+ "weight_decay": 0.0,
+ },
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if (
+ not any(nd in n for nd in no_decay_name_list)
+ and (any(nd in n for nd in small_learning_rate_list))
+ and p.requires_grad
+ )
+ ],
+ "weight_decay": weight_decay,
+ "lr": small_lr,
+ },
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if (
+ any(nd in n for nd in no_decay_name_list)
+ and (any(nd in n for nd in small_learning_rate_list))
+ and p.requires_grad
+ )
+ ],
+ "weight_decay": 0.0,
+ "lr": small_lr,
+ },
+ ]
+ return optimizer_grouped_parameters
+
+
+def _z3_params_to_fetch(param_list):
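+    """Return the ZeRO-3 partitioned params in `param_list` that are not currently materialized."""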
+ return [
+ p
+ for p in param_list
+ if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
+ ]
+
+
+def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0):
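+    """Update `model_ema` in place as an exponential moving average of `model`
+    (decay `beta`), gathering ZeRO-3 partitioned parameters when necessary."""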
+ zero_stage_3 = zero_stage == 3
+ with torch.no_grad():
+ for param, param_ema in zip(model.parameters(), model_ema.parameters()):
+ # TODO: use prefiltering for efficiency
+ params_to_fetch = _z3_params_to_fetch([param, param_ema]) if zero_stage_3 else []
+ should_gather_param = len(params_to_fetch) > 0
+ with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
+ data = param.data
+ if device is not None:
+ data = data.to(device)
+ param_ema.data.copy_(torch.lerp(data, param_ema.data, beta))
+
+
+def save_hf_format(model, tokenizer, args, sub_folder=""):
+ # used to save huggingface format, so we can use it for hf.from_pretrained
+ model_to_save = model.module if hasattr(model, "module") else model
+ CONFIG_NAME = "config.json"
+ WEIGHTS_NAME = "pytorch_model.bin"
+ output_dir = os.path.join(args.output_dir, sub_folder)
+ os.makedirs(output_dir, exist_ok=True)
+ output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
+ output_config_file = os.path.join(output_dir, CONFIG_NAME)
+ save_dict = model_to_save.state_dict()
+ # for key in list(save_dict.keys()):
+ # if "lora" in key:
+ # del save_dict[key]
+ torch.save(save_dict, output_model_file)
+ try:
+ model_to_save.config.to_json_file(output_config_file)
+ except:
+ args_dict = vars(args)
+ torch.save(args_dict, os.path.join(output_dir, "train_args.pt"))
+ print("config can't be saved")
+ # tokenizer.save_vocabulary(output_dir)
+ tokenizer.save_pretrained(output_dir) # this will save all tokenizer files
+
+
+def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0, sub_folder=""):
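+    """Save a consolidated `pytorch_model.bin` on rank 0.
+
+    Under ZeRO-3 each parameter is gathered from its partitions before being copied
+    into the output state dict; for other stages the state dict is saved directly.
+    """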
+ zero_stage_3 = zero_stage == 3
+ output_dir = os.path.join(save_dir, sub_folder)
+ os.makedirs(output_dir, exist_ok=True)
+ WEIGHTS_NAME = "pytorch_model.bin"
+ output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
+
+ model_to_save = model_ema.module if hasattr(model_ema, "module") else model_ema
+ if not zero_stage_3:
+ if global_rank == 0:
+ torch.save(model_to_save.state_dict(), output_model_file)
+ else:
+ output_state_dict = {}
+ for k, v in model_to_save.named_parameters(remove_duplicate=False):
+ if hasattr(v, "ds_id"):
+ with deepspeed.zero.GatheredParameters(
+ _z3_params_to_fetch([v]), enabled=zero_stage_3
+ ):
+ v_p = (
+ v.data.clone().detach().cpu()
+ ) # this is a hack to get around the fact that we can't get the data from the param
+ else:
+ v_p = v.cpu()
+ if global_rank == 0 and "lora" not in k:
+ output_state_dict[k] = v_p
+ if global_rank == 0:
+ torch.save(output_state_dict, output_model_file)
+ del output_state_dict
diff --git a/projects/opt/exp002_ds.yml b/projects/opt/exp002_ds.yml
new file mode 100644
index 0000000..0cfdd4a
--- /dev/null
+++ b/projects/opt/exp002_ds.yml
@@ -0,0 +1,49 @@
+training_config:
+ per_device_train_batch_size: 2
+ per_device_eval_batch_size: 2
+ gradient_accumulation_steps: 4
+ num_train_epochs: 3
+ dataloader_num_workers: 16
+ learning_rate: 5.0e-5
+ output_dir: ./output/
+ report_to: "wandb"
+ zero_stage: 3
+ precision: "fp16"
+ enable_tensorboard: False
+ seed: 0
+ weight_decay: 0.
+ learning_rate_pretraining_components: 0.
+ num_warmup_steps: 0.
+ optim_betas:
+ - 0.9
+ - 0.95
+ lr_scheduler_type: "cosine"
+ gradient_checkpointing: False
+ cpu_offload: False
+
+
+model_config:
+ pretrained_path: # None or path to model weight
+ model_type: git_llm
+ language_model_name: facebook/opt-125m
+ vision_model_name: openai/clip-vit-base-patch16
+ num_image_with_embedding: 1 # if 1, no img_temporal_embedding
+ max_length: 512
+ keys_to_finetune: []
+ keys_to_freeze: []
+
+ # TODO: support LoRA
+ # use_lora: false
+ # lora:
+ # r: 8
+ # lora_alpha: 32
+ # target_modules:
+ # - q_proj
+ # - k_proj
+ # - v_proj
+ # lora_dropout: 0.01
+ # bias: none
+ # task_type: CAUSAL_LM
+
+dataset_config_path:
+ - ./configs/datasets/m3it_coco.yaml # only coco dataset
diff --git a/scripts/run_ds.sh b/scripts/run_ds.sh
new file mode 100755
index 0000000..b01678a
--- /dev/null
+++ b/scripts/run_ds.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+export WANDB_PROJECT=heron
+export PROJECT_NAME=opt/exp002_ds
+export WANDB_NAME=$PROJECT_NAME
+
+deepspeed train_ds.py --config_file projects/$PROJECT_NAME.yml
diff --git a/train_ds.py b/train_ds.py
new file mode 100644
index 0000000..7df7417
--- /dev/null
+++ b/train_ds.py
@@ -0,0 +1,308 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+NOTICE: This code is subject to the terms of the Apache License 2.0.
+
+The code is modified from the original one.
+original code: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-VisualChat/training/main.py
+
+Additional contributions by Turing Inc. team
+"""
+
+import math
+import os
+import random
+import sys
+
+import deepspeed
+import fire
+import numpy as np
+import torch
+import yaml
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+from transformers import AdamW, AutoTokenizer, SchedulerType, get_scheduler
+from transformers.integrations import HfDeepSpeedConfig
+
+import wandb
+from heron.datasets.utils import get_dataset
+from heron.models.utils import (
+ apply_lora_model,
+ load_model,
+ load_pretrained_weight,
+ set_trainable_params,
+ unload_and_merge_lora,
+)
+from heron.utils.ds_utils import get_train_ds_config
+from heron.utils.utils import (
+ get_all_reduce_mean,
+ get_optimizer_grouped_parameters,
+ print_rank_0,
+ save_zero_three_model,
+ set_random_seed,
+ to_device,
+)
+
+
+def main(config_file: str, local_rank: int = 0):
+ with open(config_file, "r") as i_:
+ config = yaml.safe_load(i_)
+ model_config = config["model_config"]
+ training_config = config["training_config"]
+
+ if os.environ.get("WANDB_NAME") is not None:
+ training_config["output_dir"] = os.path.join(
+ training_config["output_dir"], os.environ["WANDB_NAME"]
+ )
+
+ if local_rank == -1:
+ device = torch.device("cuda")
+ elif local_rank == -100:
+ # for mpirun launcher
+ # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
+ deepspeed.init_distributed()
+ local_rank = int(os.environ["LOCAL_RANK"])
+ torch.cuda.set_device(local_rank)
+ device = torch.device("cuda", local_rank)
+ else:
+ torch.cuda.set_device(local_rank)
+ device = torch.device("cuda", local_rank)
+ # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
+ deepspeed.init_distributed()
+
+ training_config["global_rank"] = torch.distributed.get_rank()
+
+ set_random_seed(training_config["seed"])
+
+ # Get configs for initialize DeepSpeed
+ ds_config = get_train_ds_config(
+ training_config,
+ offload=training_config["cpu_offload"],
+ stage=training_config["zero_stage"],
+ )
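+    # Replace the placeholder batch sizes from get_train_ds_config with the values
+    # implied by the YAML config (per-GPU micro batch, world size, grad accumulation).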
+ ds_config["train_micro_batch_size_per_gpu"] = training_config["per_device_train_batch_size"]
+ ds_config["train_batch_size"] = (
+ training_config["per_device_train_batch_size"]
+ * torch.distributed.get_world_size()
+ * training_config["gradient_accumulation_steps"]
+ )
+
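+    # ZeRO-3 only: transformers detects ZeRO-3 through a live HfDeepSpeedConfig instance,
+    # so it has to be created before the model is built and kept referenced (hence `dschf`).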
+ if training_config["zero_stage"] == 3:
+ dschf = HfDeepSpeedConfig(ds_config)
+
+ # Initialization of wandb
+ if os.environ.get("WANDB_NAME") is not None and local_rank == 0:
+ wandb.init(project=os.environ["WANDB_PROJECT"], config=config)
+
+ # Wait for all processes
+ torch.distributed.barrier()
+
+ # Load model
+ model = load_model(model_config)
+
+ # Set trainable params
+ keys_to_finetune = config["model_config"]["keys_to_finetune"]
+ keys_to_freeze = config["model_config"]["keys_to_freeze"]
+ trainable_list, untrainable_list = set_trainable_params(
+ model, keys_to_finetune, keys_to_freeze, train_lora=False
+ )
+ print_rank_0(f"trainable_list {trainable_list}", training_config["global_rank"])
+ print_rank_0(f"untrainable_list {untrainable_list}", training_config["global_rank"])
+
+ print_rank_0(model, training_config["global_rank"])
+
+ # Load datasets
+ train_dataset, eval_dataset = get_dataset(config)
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=training_config["per_device_train_batch_size"],
+ sampler=DistributedSampler(train_dataset, shuffle=True, drop_last=True),
+ num_workers=training_config["dataloader_num_workers"],
+ )
+
+ eval_dataloader = DataLoader(
+ eval_dataset,
+ batch_size=training_config["per_device_eval_batch_size"],
+ sampler=DistributedSampler(eval_dataset, shuffle=False),
+ num_workers=training_config["dataloader_num_workers"],
+ )
+
+ # Split weights in two groups, one with weight decay and the other not.
+ optimizer_grouped_parameters = get_optimizer_grouped_parameters(
+ model,
+ training_config["weight_decay"],
+ small_lr=training_config["learning_rate_pretraining_components"],
+ )
+
+ optimizer = AdamW(
+ optimizer_grouped_parameters,
+ lr=training_config["learning_rate"],
+ betas=tuple(training_config["optim_betas"]),
+ )
+
+ num_update_steps_per_epoch = math.ceil(
+ len(train_dataloader) / training_config["gradient_accumulation_steps"]
+ )
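+    # A num_warmup_steps value of at most 1 is treated as a fraction of the total
+    # number of update steps; larger values are used as an absolute step count.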
+ if training_config["num_warmup_steps"] <= 1:
+ training_config["num_warmup_steps"] = int(
+ training_config["num_warmup_steps"]
+ * training_config["num_train_epochs"]
+ * num_update_steps_per_epoch
+ )
+ else:
+ training_config["num_warmup_steps"] = int(training_config["num_warmup_steps"])
+
+ lr_scheduler = get_scheduler(
+ name=training_config["lr_scheduler_type"],
+ optimizer=optimizer,
+ num_warmup_steps=training_config["num_warmup_steps"],
+ num_training_steps=training_config["num_train_epochs"] * num_update_steps_per_epoch,
+ )
+
+ model, optimizer, _, lr_scheduler = deepspeed.initialize(
+ model=model,
+ optimizer=optimizer,
+ config=ds_config,
+ lr_scheduler=lr_scheduler,
+ dist_init_required=True,
+ )
+
+ start_epoch = 0
+ # let load checkpoint
+ if os.path.exists(os.path.join(training_config["output_dir"], "latest")):
+ _, client_state = model.load_checkpoint(training_config["output_dir"])
+ start_epoch = client_state["epoch"]
+ best_loss = client_state["best_loss"]
+ random.setstate(client_state["random_rng_state"])
+ np.random.set_state(client_state["np_rng_state"])
+ torch.set_rng_state(client_state["torch_rng_state"])
+ torch.cuda.set_rng_state(client_state["torch_cuda_rng_state"])
+
+ if training_config["gradient_checkpointing"]:
+ model.gradient_checkpointing_enable()
+
+ def evaluation(model, eval_dataloader):
+ model.eval()
+ print_rank_0("***** Evaluation *****", training_config["global_rank"])
+ acc_loss = 0
+ progress_bar = tqdm(eval_dataloader, dynamic_ncols=True)
+ for step, batch in enumerate(progress_bar):
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ loss = model(
+ input_ids=batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ pixel_values=batch["pixel_values"].half(),
+ labels=batch["labels"],
+ )[0]
+ acc_loss += loss.float()
+ text = f"step {step}, loss: {loss:.5f} the average_loss: {acc_loss.item()/(step+1)=}"
+ # print_rank_0(text)
+ progress_bar.set_description(text)
+ model.train()
+ ave_loss = acc_loss / (step + 1)
+ print_rank_0(f"the eval average_loss: {ave_loss}", training_config["global_rank"])
+ return ave_loss
+
+ # Train!
+ if start_epoch == 0:
+ print_rank_0("***** Before training *****", training_config["global_rank"])
+ # evaluation(model, eval_dataloader)
+ best_loss = 1e6
+
+ print_rank_0("***** Running training *****", training_config["global_rank"])
+ for epoch in range(start_epoch, training_config["num_train_epochs"]):
+ print_rank_0(
+ f"Beginning of Epoch {epoch+1}/{training_config['num_train_epochs']}, Total Micro Batches {len(train_dataloader)}",
+ training_config["global_rank"],
+ )
+ model.train()
+ acc_loss = 0
+ progress_bar = tqdm(train_dataloader, dynamic_ncols=True)
+ for step, batch in enumerate(progress_bar):
+ batch = to_device(batch, device)
+
+ input_ids = batch["input_ids"]
+ attention_mask = batch["attention_mask"]
+ pixel_values = batch["pixel_values"].half()
+ labels = batch["labels"]
+ loss = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ pixel_values=pixel_values,
+ labels=labels,
+ )[0]
+
+ acc_loss += loss.float()
+ model.backward(loss)
+ # Attention: gradient accumulation in the function
+ model.step()
+
+ # Log to wandb
+ if os.environ.get("WANDB_NAME") is not None and local_rank == 0:
+ now_lr = lr_scheduler.get_lr()[0]
+ wandb.log(
+ {
+ "Train/epoch": epoch,
+ "Train/step": step,
+ "Train/loss": loss,
+ "Train/average_loss": acc_loss / (step + 1),
+ "Train/learning_rate": now_lr,
+ }
+ )
+
+ text = f"step {step}, loss: {loss.detach():.5f} the average_loss: {acc_loss/(step + 1):.5f}"
+ progress_bar.set_description(text)
+
+ model.tput_timer.update_epoch_count()
+ print_rank_0(
+            f"Epoch {epoch+1}, the average_loss: {acc_loss/(step + 1):.5f}", training_config["global_rank"]
+ )
+ eval_loss = evaluation(model, eval_dataloader)
+
+ if eval_loss < best_loss:
+ best_loss = eval_loss
+
+ # Log to wandb
+ if os.environ.get("WANDB_NAME") is not None and local_rank == 0:
+ wandb.log(
+ {
+ "Eval/loss": eval_loss,
+ }
+ )
+
+ # Save the checkpoint
+ client_state = {
+ "random_rng_state": random.getstate(),
+ "np_rng_state": np.random.get_state(),
+ "torch_rng_state": torch.get_rng_state(),
+ "torch_cuda_rng_state": torch.cuda.get_rng_state(),
+ "epoch": epoch + 1, # start from next epoch
+ "best_loss": best_loss,
+ }
+ model.save_checkpoint(
+ training_config["output_dir"], client_state=client_state
+ ) # save to the latest
+
+ if training_config["zero_stage"] == 3:
+ save_zero_three_model(
+ model,
+ training_config["global_rank"],
+ training_config["output_dir"],
+ zero_stage=training_config["zero_stage"],
+ sub_folder=f"epoch-{epoch}",
+ )
+
+ # TODO: support merging LoRA for ZeRO-3 training
+ if training_config["zero_stage"] != 3:
+ model = model.module
+
+    save_path = os.path.join(training_config["output_dir"], "epoch_final")
+ model.save_pretrained(save_path)
+
+
+if __name__ == "__main__":
+ fire.Fire(main)