Merge pull request #45 from Meituan-AutoML/train_w_lora
Fix v1 & v2 LoRA conflict
sxu1997 authored Apr 15, 2024
2 parents 19d7684 + 7f5c81f commit 688fdec
Showing 4 changed files with 64 additions and 2 deletions.
3 changes: 3 additions & 0 deletions mobilevlm/model/mobilevlm.py
```diff
@@ -64,6 +64,9 @@ def initialize_vision_modules(self, model_args, fsdp=None):
         # Build Vision-Projector
         if getattr(self, 'mm_projector', None) is None:
             self.mm_projector = build_vision_projector(self.config)
+        # In case it is frozen by LoRA
+        for p in self.mm_projector.parameters():
+            p.requires_grad = True
         if pretrain_mm_mlp_adapter is not None:
             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
             def get_w(weights, keyword):
```
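Why the projector can end up frozen: LoRA setup typically starts by freezing every base-model parameter, and since `mm_projector` is excluded from the LoRA targets it stays frozen unless thawed explicitly. A minimal plain-PyTorch sketch of that failure mode (module names are illustrative stand-ins, not the repo's classes):

```python
import torch.nn as nn

# Illustrative stand-in for the VLM: one LoRA-targeted layer, one projector.
model = nn.ModuleDict({
    "q_proj": nn.Linear(8, 8),        # trained through LoRA adapters
    "mm_projector": nn.Linear(8, 8),  # excluded from LoRA targets
})

model.requires_grad_(False)  # LoRA setup: freeze the whole base model

# Without the added lines above, mm_projector would never train again.
for p in model["mm_projector"].parameters():
    p.requires_grad = True

assert all(p.requires_grad for p in model["mm_projector"].parameters())
```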
8 changes: 7 additions & 1 deletion mobilevlm/train/train.py
```diff
@@ -142,7 +142,11 @@ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
 def find_all_linear_names(model):
     cls = torch.nn.Linear
     lora_module_names = set()
+    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
+
     for name, module in model.named_modules():
+        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+            continue
         if isinstance(module, cls):
             names = name.split('.')
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
```
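A self-contained demo of the exclusion (toy module, illustrative names; the function body is condensed from the patched code above, with a return added for the demo):

```python
import torch.nn as nn

def find_all_linear_names(model):
    lora_module_names = set()
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue  # linear layers under a multimodal keyword are skipped
        if isinstance(module, nn.Linear):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    return sorted(lora_module_names)

toy = nn.ModuleDict({
    "q_proj": nn.Linear(4, 4),
    "mm_projector": nn.Sequential(nn.Linear(4, 4)),
})
print(find_all_linear_names(toy))  # ['q_proj'] — the projector is excluded
```

The resulting names typically feed `LoraConfig(target_modules=...)`, so the excluded multimodal modules receive no adapters at all.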
```diff
@@ -854,7 +858,8 @@ def make_inputs_require_grad(module, input, output):
     model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
     model.config.vision_tower_type = training_args.vision_tower_type = model_args.vision_tower_type

-    model.requires_grad_(True)
+    if not training_args.lora_enable:
+        model.requires_grad_(True)

     if model_args.tune_mm_mlp_adapter:
         model.requires_grad_(False)
```
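What the new guard prevents: calling `model.requires_grad_(True)` after LoRA has frozen the base weights would thaw everything, silently turning the LoRA run into full fine-tuning. A small sketch with the same plain-PyTorch stand-in as before:

```python
import torch.nn as nn

base = nn.Linear(8, 8)                 # frozen base weight
lora_A = nn.Linear(8, 2, bias=False)   # trainable LoRA factors
lora_B = nn.Linear(2, 8, bias=False)
base.requires_grad_(False)

lora_enable = True
if not lora_enable:   # the commit's guard: only thaw for full fine-tuning
    base.requires_grad_(True)

params = nn.ModuleDict({"base": base, "lora_A": lora_A, "lora_B": lora_B})
print([n for n, p in params.named_parameters() if p.requires_grad])
# ['lora_A.weight', 'lora_B.weight'] — the base stays frozen under LoRA
```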
```diff
@@ -910,6 +915,7 @@ def make_inputs_require_grad(module, input, output):
     if training_args.local_rank == 0 or training_args.local_rank == -1:
         model.config.save_pretrained(training_args.output_dir)
         model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+        print('non_lora_trainable...', non_lora_state_dict.keys())
         torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
     else:
         safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
```
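The saved `non_lora_trainables.bin` holds the weights that trained outside the adapters, such as the projector's. A hedged sketch of the load side, assuming the usual LLaVA-style `base_model.model.` key prefix — the repo's actual loader may differ:

```python
import torch
import torch.nn as nn

# Stand-in for the model being restored (illustrative).
model = nn.ModuleDict({"mm_projector": nn.Linear(8, 8)})

# In practice: torch.load(os.path.join(output_dir, 'non_lora_trainables.bin'))
non_lora = {"base_model.model.mm_projector.weight": torch.zeros(8, 8),
            "base_model.model.mm_projector.bias": torch.zeros(8)}

# Strip the PEFT wrapper prefix before loading into the bare model.
non_lora = {k.replace("base_model.model.", "", 1): v for k, v in non_lora.items()}
model.load_state_dict(non_lora, strict=False)
```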
51 changes: 51 additions & 0 deletions run.sh
```diff
@@ -113,6 +113,57 @@ case ${TASK} in
         cd ${WORK_DIR}
         OUTPUT_DIR=$3
         bash scripts/benchmark.sh ${OUTPUT_DIR}
         ;;
+    "finetune.lora")
+        echo ">>> Start Visual-Instruction Tuning with LoRA..."
+        cd ${WORK_DIR}
+        LANGUAGE_MODEL=$3
+        VISION_MODEL=$4
+        OUTPUT_DIR=$5
+        OUTPUT_DIR_PT=${OUTPUT_DIR}/mobilevlm_v2-1.pretrain
+        OUTPUT_DIR_FT=${OUTPUT_DIR}/mobilevlm_v2-2.finetune-lora
+        mkdir -p ${OUTPUT_DIR_FT}
+        declare -A DS_CONF
+        deepspeed mobilevlm/train/train_mem.py \
+            --deepspeed scripts/deepspeed/zero3.json \
+            --lora_enable True --lora_r 128 --lora_alpha 256 \
+            --learning_rate 2e-4 \
+            --model_name_or_path ${OUTPUT_DIR_PT} \
+            --version v1 \
+            --data_path data/finetune_data/MobileVLM_V2_FT_Mix2M.json \
+            --image_folder data/finetune_data \
+            --vision_tower ${VISION_MODEL} \
+            --vision_tower_type clip \
+            --mm_projector_type ldpnetv2 \
+            --mm_vision_select_layer -2 \
+            --mm_use_im_start_end False \
+            --mm_use_im_patch_token False \
+            --image_aspect_ratio pad \
+            --group_by_modality_length True \
+            --bf16 True \
+            --output_dir ${OUTPUT_DIR_FT} \
+            --num_train_epochs 1 \
+            --per_device_train_batch_size 16 \
+            --per_device_eval_batch_size 4 \
+            --gradient_accumulation_steps 1 \
+            --evaluation_strategy "no" \
+            --save_strategy "steps" \
+            --save_steps 50000 \
+            --save_total_limit 1 \
+            --weight_decay 0. \
+            --warmup_ratio 0.03 \
+            --lr_scheduler_type "cosine" \
+            --logging_steps 1 \
+            --tf32 True \
+            --model_max_length 2048 \
+            --gradient_checkpointing True \
+            --dataloader_num_workers 4 \
+            --lazy_preprocess True \
+            --report_to none \
+            2>&1 | tee -a ${OUTPUT_DIR_FT}/log.txt &&
+        python3 scripts/mergelora.py ${OUTPUT_DIR_PT} ${OUTPUT_DIR}/mobilevlm_v2-2.finetune-lora ${OUTPUT_DIR}/mobilevlm_v2-2.finetune \
+            2>&1 | tee -a ${OUTPUT_DIR_FT}/log.txt &&
+        echo "Done."
+        ;;
     *)
         echo "error with ${DATASET_ID}"
```
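The pipeline's last step, `scripts/mergelora.py`, folds the adapters back into a standalone checkpoint. A minimal sketch of such a merge with PEFT, assuming a PEFT-format adapter directory; the script's actual interface and logic are not shown in this diff:

```python
import sys
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_dir, lora_dir, out_dir = sys.argv[1:4]   # mirrors the three CLI arguments

base = AutoModelForCausalLM.from_pretrained(base_dir)
merged = PeftModel.from_pretrained(base, lora_dir).merge_and_unload()
merged.save_pretrained(out_dir)
```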
4 changes: 3 additions & 1 deletion run_v1.sh
```diff
@@ -130,10 +130,12 @@ case ${TASK} in
         mkdir -p ${OUTPUT_DIR_FT}
         declare -A DS_CONF
         DS_CONF=([mobilevlm1.7b]=zero2 [mobilevlm3b]=zero3)
+        declare -A LR_CONF
+        LR_CONF=([mobilevlm1.7b]=1e-4 [mobilevlm3b]=2e-4)
         deepspeed mobilevlm/train/train_mem.py \
             --deepspeed scripts/deepspeed/${DS_CONF[${ARCH}]}.json \
             --lora_enable True --lora_r 128 --lora_alpha 256 \
-            --learning_rate 2e-4 \
+            --learning_rate ${LR_CONF[${ARCH}]} \
             --model_name_or_path ${LANGUAGE_MODEL} \
             --version v1 \
             --data_path data/finetune_data/llava_v1_5_mix665k.json \
```
