diff --git a/mobilevlm/model/mobilevlm.py b/mobilevlm/model/mobilevlm.py
index 307f78c..e6af382 100644
--- a/mobilevlm/model/mobilevlm.py
+++ b/mobilevlm/model/mobilevlm.py
@@ -64,6 +64,9 @@ def initialize_vision_modules(self, model_args, fsdp=None):
         # Build Vision-Projector
         if getattr(self, 'mm_projector', None) is None:
             self.mm_projector = build_vision_projector(self.config)
+        # In case it is frozen by LoRA
+        for p in self.mm_projector.parameters():
+            p.requires_grad = True
         if pretrain_mm_mlp_adapter is not None:
             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
             def get_w(weights, keyword):
diff --git a/mobilevlm/train/train.py b/mobilevlm/train/train.py
index 53f69ee..9e9d1bb 100644
--- a/mobilevlm/train/train.py
+++ b/mobilevlm/train/train.py
@@ -142,7 +142,11 @@ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
 def find_all_linear_names(model):
     cls = torch.nn.Linear
     lora_module_names = set()
+    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
+
     for name, module in model.named_modules():
+        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+            continue
         if isinstance(module, cls):
             names = name.split('.')
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
@@ -854,7 +858,8 @@ def make_inputs_require_grad(module, input, output):
         model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
         model.config.vision_tower_type = training_args.vision_tower_type = model_args.vision_tower_type
 
-        model.requires_grad_(True)
+        if not training_args.lora_enable:
+            model.requires_grad_(True)
         if model_args.tune_mm_mlp_adapter:
             model.requires_grad_(False)
@@ -910,6 +915,7 @@ def make_inputs_require_grad(module, input, output):
         if training_args.local_rank == 0 or training_args.local_rank == -1:
             model.config.save_pretrained(training_args.output_dir)
             model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+            print('non_lora_trainable...', non_lora_state_dict.keys())
             torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
     else:
         safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
diff --git a/run.sh b/run.sh
index 2301606..569f509 100644
--- a/run.sh
+++ b/run.sh
@@ -113,6 +113,57 @@ case ${TASK} in
         cd ${WORK_DIR}
         OUTPUT_DIR=$3
         bash scripts/benchmark.sh ${OUTPUT_DIR}
+        ;;
+    "finetune.lora")
+        echo ">>> Start Visual-Instruction Tuning with LoRA..."
+        cd ${WORK_DIR}
+        LANGUAGE_MODEL=$3
+        VISION_MODEL=$4
+        OUTPUT_DIR=$5
+        OUTPUT_DIR_PT=${OUTPUT_DIR}/mobilevlm_v2-1.pretrain
+        OUTPUT_DIR_FT=${OUTPUT_DIR}/mobilevlm_v2-2.finetune-lora
+        mkdir -p ${OUTPUT_DIR_FT}
+        declare -A DS_CONF
+        deepspeed mobilevlm/train/train_mem.py \
+            --deepspeed scripts/deepspeed/zero3.json \
+            --lora_enable True --lora_r 128 --lora_alpha 256 \
+            --learning_rate 2e-4 \
+            --model_name_or_path ${OUTPUT_DIR_PT} \
+            --version v1 \
+            --data_path data/finetune_data/MobileVLM_V2_FT_Mix2M.json \
+            --image_folder data/finetune_data \
+            --vision_tower ${VISION_MODEL} \
+            --vision_tower_type clip \
+            --mm_projector_type ldpnetv2 \
+            --mm_vision_select_layer -2 \
+            --mm_use_im_start_end False \
+            --mm_use_im_patch_token False \
+            --image_aspect_ratio pad \
+            --group_by_modality_length True \
+            --bf16 True \
+            --output_dir ${OUTPUT_DIR_FT} \
+            --num_train_epochs 1 \
+            --per_device_train_batch_size 16 \
+            --per_device_eval_batch_size 4 \
+            --gradient_accumulation_steps 1 \
+            --evaluation_strategy "no" \
+            --save_strategy "steps" \
+            --save_steps 50000 \
+            --save_total_limit 1 \
+            --weight_decay 0. \
+            --warmup_ratio 0.03 \
+            --lr_scheduler_type "cosine" \
+            --logging_steps 1 \
+            --tf32 True \
+            --model_max_length 2048 \
+            --gradient_checkpointing True \
+            --dataloader_num_workers 4 \
+            --lazy_preprocess True \
+            --report_to none \
+            2>&1 | tee -a ${OUTPUT_DIR_FT}/log.txt &&
+        python3 scripts/mergelora.py ${OUTPUT_DIR_PT} ${OUTPUT_DIR}/mobilevlm_v2-2.finetune-lora ${OUTPUT_DIR}/mobilevlm_v2-2.finetune \
+            2>&1 | tee -a ${OUTPUT_DIR_FT}/log.txt &&
+        echo "Done."
         ;;
     *)
         echo "error with ${DATASET_ID}"
diff --git a/run_v1.sh b/run_v1.sh
index cf1de12..3650f77 100644
--- a/run_v1.sh
+++ b/run_v1.sh
@@ -130,10 +130,12 @@ case ${TASK} in
         mkdir -p ${OUTPUT_DIR_FT}
         declare -A DS_CONF
         DS_CONF=([mobilevlm1.7b]=zero2 [mobilevlm3b]=zero3)
+        declare -A LR_CONF
+        LR_CONF=([mobilevlm1.7b]=1e-4 [mobilevlm3b]=2e-4)
         deepspeed mobilevlm/train/train_mem.py \
             --deepspeed scripts/deepspeed/${DS_CONF[${ARCH}]}.json \
             --lora_enable True --lora_r 128 --lora_alpha 256 \
-            --learning_rate 2e-4 \
+            --learning_rate ${LR_CONF[${ARCH}]} \
             --model_name_or_path ${LANGUAGE_MODEL} \
             --version v1 \
             --data_path data/finetune_data/llava_v1_5_mix665k.json \