Update finetune code to the latest version of Yi (#222)
* adapt new version of Yi

* adapt new version of Yi

* apply autoflake

---------

Co-authored-by: Jun Tian <[email protected]>
jiangchengSilent and findmyway authored Dec 2, 2023
1 parent 231bf09 commit 831d2e5
Showing 2 changed files with 1 addition and 35 deletions.
1 change: 1 addition & 0 deletions finetune/utils/model/model_utils.py
@@ -32,6 +32,7 @@ def create_hf_model(
             from_tf=bool(".ckpt" in model_name_or_path),
             config=model_config,
             trust_remote_code=True,
+            use_flash_attention_2=True,
         )
     else:
         model = model_class.from_pretrained(
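For context: the added kwarg switches on the Hugging Face Transformers FlashAttention-2 integration at load time. A minimal standalone sketch of the same flag (the checkpoint name and dtype below are illustrative, not part of this commit; newer Transformers releases deprecate the flag in favor of attn_implementation="flash_attention_2"):

import torch
from transformers import AutoModelForCausalLM

# FlashAttention-2 in Transformers requires the flash-attn package
# and a half-precision dtype; loading fails otherwise.
model = AutoModelForCausalLM.from_pretrained(
    "01-ai/Yi-6B",               # illustrative checkpoint
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,  # the flag added in this diff
)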
35 changes: 0 additions & 35 deletions finetune/utils/utils.py
@@ -1,6 +1,5 @@
 import json
 import os
-from shutil import copy
 
 import deepspeed
 import torch
@@ -90,12 +89,6 @@ def save_hf_format(model, tokenizer, args, sub_folder=""):
     print(os.listdir(output_dir))
     print(os.getcwd())
 
-    source = args.model_name_or_path
-    target = os.path.abspath(os.path.join(os.getcwd(), args.output_dir))
-    copy(os.path.join(source, "configuration_yi.py"), target)
-    copy(os.path.join(source, "modeling_yi.py"), target)
-    copy(os.path.join(source, "tokenization_yi.py"), target)
-
 
 def get_all_reduce_mean(tensor):
     torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
@@ -258,31 +251,3 @@ def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
         if global_rank == 0:
             torch.save(output_state_dict, output_model_file)
         del output_state_dict
-
-
-def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
-    zero_stage_3 = zero_stage == 3
-    os.makedirs(save_dir, exist_ok=True)
-    WEIGHTS_NAME = "pytorch_model.bin"
-    output_model_file = os.path.join(save_dir, WEIGHTS_NAME)
-
-    model_to_save = model_ema.module if hasattr(model_ema, "module") else model_ema
-    if not zero_stage_3:
-        if global_rank == 0:
-            torch.save(model_to_save.state_dict(), output_model_file)
-    else:
-        output_state_dict = {}
-        for k, v in model_to_save.named_parameters():
-            if hasattr(v, "ds_id"):
-                with deepspeed.zero.GatheredParameters(
-                    _z3_params_to_fetch([v]), enabled=zero_stage_3
-                ):
-                    v_p = v.data.cpu()
-            else:
-                v_p = v.cpu()
-            if global_rank == 0 and "lora" not in k:
-                output_state_dict[k] = v_p
-
-        if global_rank == 0:
-            torch.save(output_state_dict, output_model_file)
-        del output_state_dict
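The block removed above is a second, duplicate definition of save_zero_three_model; the copy defined earlier in the file is kept unchanged. For reference, a minimal sketch of the ZeRO stage-3 gathering pattern that function relies on (gather_full_state_dict is a hypothetical helper name, and the rank/LoRA filtering of the original is omitted):

import deepspeed


def gather_full_state_dict(model):
    # Parameters partitioned by ZeRO stage 3 carry a `ds_id` attribute;
    # GatheredParameters temporarily reassembles each one on every rank
    # so the full tensor can be copied to CPU.
    state_dict = {}
    for name, param in model.named_parameters():
        if hasattr(param, "ds_id"):
            with deepspeed.zero.GatheredParameters([param]):
                state_dict[name] = param.data.cpu().clone()
        else:
            state_dict[name] = param.data.cpu().clone()
    return state_dict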