Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support separate value model #1624

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/llmtuner/hparams/finetuning_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ class RLHFArguments:
default="lora",
metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
)
ppo_use_separate_value_model: Optional[bool] = field(
default=False,
metadata={"help": "Use a separate value model which does not share parameters with policy."}
)
value_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the value model."}
)


@dataclass
Expand Down
12 changes: 11 additions & 1 deletion src/llmtuner/train/ppo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import torch
from tqdm import tqdm
from types import MethodType
from typing import TYPE_CHECKING, List, Optional, Tuple

from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
Expand Down Expand Up @@ -296,7 +297,16 @@ def batched_forward_pass(
attention_mask = input_kwargs["attention_mask"]

with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
logits, _, values = model(**input_kwargs)
unwrapped_model = self.accelerator.unwrap_model(model)
if "value" in unwrapped_model.pretrained_model.peft_config:
# this model has a separate value model and policy model
unwrapped_model.pretrained_model.set_adapter("value")
_, _, values = model(**input_kwargs)
unwrapped_model.pretrained_model.set_adapter("default")
logits, _, _ = model(**input_kwargs)
else:
# this model has a shared value model and policy model
logits, _, values = model(**input_kwargs)

if values.size(0) != input_ids.size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
Expand Down
19 changes: 19 additions & 0 deletions src/llmtuner/train/ppo/workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py

import math
import os
from peft import TaskType, LoraConfig
from trl import PPOConfig
import torch
from torch.optim import AdamW
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
Expand Down Expand Up @@ -29,6 +32,22 @@ def run_ppo(
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
if finetuning_args.ppo_use_separate_value_model:
if finetuning_args.value_model is not None:
model.pretrained_model.load_adapter(finetuning_args.value_model, "value", is_trainable=True)
state_dict = torch.load(os.path.join(finetuning_args.value_model, "pytorch_model.bin"), map_location="cpu")
model.load_state_dict(state_dict, strict=False)
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=finetuning_args.lora_target,
modules_to_save=finetuning_args.additional_target
)
model.pretrained_model.add_adapter("value", lora_config)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")

tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
Expand Down