Skip to content

Commit

Permalink
adapt new version of Yi
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangchengSilent committed Dec 1, 2023
1 parent 1d28c16 commit acad836
Showing 1 changed file with 0 additions and 34 deletions.
34 changes: 0 additions & 34 deletions finetune/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,6 @@ def save_hf_format(model, tokenizer, args, sub_folder=""):
print(os.listdir(output_dir))
print(os.getcwd())

source = args.model_name_or_path
target = os.path.abspath(os.path.join(os.getcwd(), args.output_dir))
copy(os.path.join(source, "configuration_yi.py"), target)
copy(os.path.join(source, "modeling_yi.py"), target)
copy(os.path.join(source, "tokenization_yi.py"), target)


def get_all_reduce_mean(tensor):
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
Expand Down Expand Up @@ -259,31 +253,3 @@ def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
torch.save(output_state_dict, output_model_file)
del output_state_dict


def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
zero_stage_3 = (zero_stage == 3)
os.makedirs(save_dir, exist_ok=True)
WEIGHTS_NAME = "pytorch_model.bin"
output_model_file = os.path.join(save_dir, WEIGHTS_NAME)

model_to_save = model_ema.module if hasattr(model_ema,
'module') else model_ema
if not zero_stage_3:
if global_rank == 0:
torch.save(model_to_save.state_dict(), output_model_file)
else:
output_state_dict = {}
for k, v in model_to_save.named_parameters():
if hasattr(v, 'ds_id'):
with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([v]),
enabled=zero_stage_3):
v_p = v.data.cpu()
else:
v_p = v.cpu()
if global_rank == 0 and "lora" not in k:
output_state_dict[k] = v_p

if global_rank == 0:

torch.save(output_state_dict, output_model_file)
del output_state_dict

0 comments on commit acad836

Please sign in to comment.