Skip to content

Commit 4c87e8d

Browse files
committed
fix weight-initialization issue for critic/reward model
Signed-off-by: jouw <[email protected]>
1 parent 1344ffd commit 4c87e8d

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,7 @@ def train_rlhf(self, inputs):
236236
value = self.critic_model.forward_value(**batch,
237237
return_value_only=True,
238238
use_cache=False)[:, :-1]
239-
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,
240-
start:],
239+
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, start:],
241240
returns, action_mask[:, start:])
242241
self.critic_model.backward(critic_loss)
243242

applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from huggingface_hub import snapshot_download
1313
from transformers.integrations.deepspeed import HfDeepSpeedConfig
14+
from transformers.modeling_utils import no_init_weights
1415

1516
from dschat.utils.model.reward_model import RewardModel
1617
from dschat.utils.utils import load_state_dict_into_model, print_rank_0
@@ -99,7 +100,8 @@ def create_hf_model(model_class,
99100
dschf = None
100101
if rlhf_training:
101102
# the weight loading is handled by create critic model
102-
model = model_class.from_config(model_config)
103+
with no_init_weights():
104+
model = model_class.from_config(model_config)
103105
else:
104106
model = model_class.from_pretrained(
105107
model_name_or_path,

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,8 +594,7 @@ def main():
594594
"-------------------------------------------------------------------------------------",
595595
args.global_rank)
596596

597-
if args.enable_tensorboard and torch.distributed.get_rank(
598-
) == 0:
597+
if args.enable_tensorboard and torch.distributed.get_rank() == 0:
599598
writer.add_scalar('reward',
600599
average_reward / inner_iter,
601600
global_step=step)

0 commit comments

Comments (0)