Skip to content

Commit

Permalink
[sft] feat: fix sft dataset with latest preprocess code (#49)
Browse files Browse the repository at this point in the history
* api: rename tracking logger to wandb logger type

* [sft] feat: add tests for sft dataset

* refresh dataset

* force refresh

* use ds model for tokenizer

* add option for trainer.val_only

* fix path

* fix lint

* add sft test for cot and raw q&a

* add hf_tokenizer api to patch gemma tokenizer

* fix test
  • Loading branch information
eric-haibin-lin authored Dec 17, 2024
1 parent c7534db commit d60f843
Show file tree
Hide file tree
Showing 18 changed files with 192 additions and 101 deletions.
14 changes: 8 additions & 6 deletions examples/data_preprocess/gsm8k.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,17 @@ def extract_solution(solution_str):
def make_map_fn(split):

def process_fn(example, idx):
question = example.pop('question')
question_raw = example.pop('question')

question = question + ' ' + instruction_following
question = question_raw + ' ' + instruction_following

answer = example.pop('answer')
solution = extract_solution(answer)
answer_raw = example.pop('answer')
solution = extract_solution(answer_raw)
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": question
"content": question,
}],
"ability": "math",
"reward_model": {
Expand All @@ -71,7 +71,9 @@ def process_fn(example, idx):
},
"extra_info": {
'split': split,
'index': idx
'index': idx,
'answer': answer_raw,
"question": question_raw,
}
}
return data
Expand Down
19 changes: 11 additions & 8 deletions examples/sft/gsm8k/run_gemma_2b.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Tested in 4 GPUs
# Tested with 2 & 4 GPUs

set -x

Expand All @@ -8,7 +8,7 @@ if [ "$#" -lt 2 ]; then
fi

nproc_per_node=$1
hdfs_path=$2
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2
Expand All @@ -17,12 +17,15 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
-m verl.trainer.fsdp_sft_trainer \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.prompt_key=prompt \
data.response_key=answer \
data.micro_batch_size=32 \
data.prompt_key=extra_info \
data.response_key=extra_info \
+data.prompt_dict_keys=['question'] \
+data.response_dict_keys=['answer'] \
data.micro_batch_size=8 \
model.partial_pretrain=google/gemma-2b-it \
trainer.default_hdfs_dir=$hdfs_path \
trainer.default_local_dir=$save_path \
trainer.project_name=gsm8k-sft \
trainer.experiment_name=gsm8k-sft-gemma-2b-it \
trainer.total_epochs=3 \
trainer.logger=['console','wandb'] $@
trainer.total_epochs=2 \
trainer.logger=['console','wandb'] \
trainer.default_hdfs_dir=null $@
5 changes: 2 additions & 3 deletions examples/split_placement/main_ppo_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,8 @@ def main_task(config):
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

# instantiate tokenizer
tokenizer = AutoTokenizer.from_pretrained(local_path)
from verl.utils import set_pad_token_id
set_pad_token_id(tokenizer)
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)

# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
Expand Down
9 changes: 2 additions & 7 deletions tests/verl/utils/dataset/test_rl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,13 @@ def get_gsm8k_data():
local_folder = os.path.expanduser('~/verl-data/gsm8k/')
local_path = os.path.join(local_folder, 'train.parquet')
os.makedirs(local_folder, exist_ok=True)
# import fsspec
# with fsspec.open(url, mode='rb') as fin, fsspec.open(local_path, mode='wb') as fout:
# content = fin.read()
# fout.write(content)
return local_path


def test_rl_dataset():
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
tokenizer = AutoTokenizer.from_pretrained('deepseek-ai/deepseek-coder-1.3b-instruct')
from verl.utils import set_pad_token_id
set_pad_token_id(tokenizer)
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer('deepseek-ai/deepseek-coder-1.3b-instruct')
local_path = get_gsm8k_data()
dataset = RLHFDataset(parquet_files=local_path, tokenizer=tokenizer, prompt_key='prompt', max_prompt_length=256)

Expand Down
9 changes: 2 additions & 7 deletions tests/verl/utils/dataset/test_rm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os

from transformers import AutoTokenizer
from verl.utils import set_pad_token_id
from verl.utils import hf_tokenizer
from verl.utils.dataset.rm_dataset import RMDataset


Expand All @@ -24,16 +24,11 @@ def get_rm_data():
local_folder = os.path.expanduser('~/verl-data/full_hh_rlhf/rm/')
local_path = os.path.join(local_folder, 'test.parquet')
os.makedirs(local_folder, exist_ok=True)
# import fsspec
# with fsspec.open(url, mode='rb') as fin, fsspec.open(local_path, mode='wb') as fout:
# content = fin.read()
# fout.write(content)
return local_path


def test_rm_dataset():
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
set_pad_token_id(tokenizer)
tokenizer = hf_tokenizer("facebook/opt-1.3b")
local_path = get_rm_data()
dataset = RMDataset(parquet_files=local_path, tokenizer=tokenizer, max_length=512)
data = dataset[0]['input_ids']
Expand Down
60 changes: 60 additions & 0 deletions tests/verl/utils/dataset/test_sft_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from transformers import AutoTokenizer
from verl.utils import hf_tokenizer
from verl.utils.dataset.sft_dataset import SFTDataset


def get_gsm8k_data():
# prepare test dataset
url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/gsm8k/train.parquet"
local_folder = os.path.expanduser('~/verl-data/gsm8k/')
local_path = os.path.join(local_folder, 'train.parquet')
return local_path


def test_sft_cot_dataset():
tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct')
local_path = get_gsm8k_data()
dataset = SFTDataset(parquet_files=local_path,
tokenizer=tokenizer,
prompt_key='prompt',
prompt_dict_keys=['content'],
response_key='extra_info',
response_dict_keys=['answer'],
max_length=512)

data = dataset[0]['input_ids']
output = tokenizer.batch_decode([data])[0]
assert len(output) > 1
assert type(output) == str


def test_sft_dataset():
tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct')
local_path = get_gsm8k_data()
dataset = SFTDataset(parquet_files=local_path,
tokenizer=tokenizer,
prompt_key='extra_info',
prompt_dict_keys=['question'],
response_key='extra_info',
response_dict_keys=['answer'],
max_length=512)

data = dataset[0]['input_ids']
output = tokenizer.batch_decode([data])[0]
assert len(output) > 1
assert type(output) == str
18 changes: 11 additions & 7 deletions verl/trainer/fsdp_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ def __init__(self, config, device_mesh: DeviceMesh):
self.device_mesh = device_mesh
# build tokenizer first
local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
self.tokenizer = AutoTokenizer.from_pretrained(local_model_path,
trust_remote_code=self.config.model.trust_remote_code)
from verl.utils import set_pad_token_id
set_pad_token_id(self.tokenizer)
from verl.utils import hf_tokenizer
self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
if self.config.data.chat_template is not None:
raise ValueError('Apply Chat template from config is not supported yet.')

Expand All @@ -77,6 +75,8 @@ def __init__(self, config, device_mesh: DeviceMesh):
self._build_model_optimizer()

# TODO: add checkpoint manager
if self.device_mesh.get_rank() == 0:
print(self.config)

def _normalize_config_bsz(self):
dp_size = self.device_mesh.size()
Expand All @@ -95,13 +95,17 @@ def _build_dataloader(self):
self.train_dataset = SFTDataset(parquet_files=config.data.train_files,
tokenizer=self.tokenizer,
prompt_key=config.data.prompt_key,
prompt_dict_keys=config.data.get('prompt_dict_keys', None),
response_key=config.data.response_key,
response_dict_keys=config.data.get('response_dict_keys', None),
max_length=config.data.max_length,
truncation=config.data.truncation)
self.val_dataset = SFTDataset(parquet_files=config.data.val_files,
tokenizer=self.tokenizer,
prompt_key=config.data.prompt_key,
prompt_dict_keys=config.data.get('prompt_dict_keys', None),
response_key=config.data.response_key,
response_dict_keys=config.data.get('response_dict_keys', None),
max_length=config.data.max_length,
truncation=config.data.truncation)

Expand Down Expand Up @@ -292,10 +296,11 @@ def save_checkpoint(self, step):
# save huggingface model
if self.device_mesh.get_rank() == 0:
os.makedirs(path, exist_ok=True)
hdfs_io.makedirs(self.config.trainer.default_hdfs_dir)
self.model.save_pretrained(path, state_dict=state_dict)
self.tokenizer.save_pretrained(path)
hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir)
if self.config.trainer.default_hdfs_dir:
hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
torch.distributed.barrier()

def fit(self):
Expand Down Expand Up @@ -349,7 +354,6 @@ def main(config):
local_rank, rank, world_size = initialize_global_process_group()

device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))

trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh)
trainer.fit()

Expand Down
5 changes: 2 additions & 3 deletions verl/trainer/main_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ def main(config):
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
local_path = copy_local_path_from_hdfs(config.model.path)
tokenizer = AutoTokenizer.from_pretrained(local_path)
from verl.utils import set_pad_token_id
set_pad_token_id(tokenizer)
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)

if config.rollout.temperature == 0.:
assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.'
Expand Down
7 changes: 4 additions & 3 deletions verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def _select_rm_score_fn(data_source):


class RewardManager():
"""The reward manager.
"""

def __init__(self, tokenizer, num_examine) -> None:
self.tokenizer = tokenizer
Expand Down Expand Up @@ -112,9 +114,8 @@ def main_task(config):
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

# instantiate tokenizer
tokenizer = AutoTokenizer.from_pretrained(local_path)
from verl.utils import set_pad_token_id
set_pad_token_id(tokenizer)
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)

# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
Expand Down
4 changes: 4 additions & 0 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ def fit(self):
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
logger.log(data=val_metrics, step=global_steps)
if self.config.trainer.get('val_only', False):
return

for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
Expand Down Expand Up @@ -527,3 +530,4 @@ def fit(self):
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f'Final validation metrics: {val_metrics}')
logger.log(data=val_metrics, step=global_steps)
20 changes: 6 additions & 14 deletions verl/trainer/ppo/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from verl.utils.import_utils import import_external_libs
from verl.utils.debug import log_gpu_memory_usage
import verl.utils.hdfs_io as hdfs_io
from verl.utils import set_pad_token_id
from verl.utils import hf_tokenizer

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
Expand Down Expand Up @@ -107,8 +107,7 @@ def _build_model_optimizer(self,

# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
# TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
self.tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=trust_remote_code)
set_pad_token_id(self.tokenizer)
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

torch_dtype = fsdp_config.get('model_dtype', None)
if torch_dtype is None:
Expand Down Expand Up @@ -467,9 +466,7 @@ def _build_critic_model_optimizer(self, config):
from transformers import AutoTokenizer

tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
trust_remote_code=config.model.get('trust_remote_code', False))
set_pad_token_id(self.tokenizer)
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False))

from omegaconf import OmegaConf
override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
Expand Down Expand Up @@ -673,14 +670,9 @@ def _build_model(self, config):
else:
self._do_switch_chat_template = True
input_tokenizer_local_path = copy_local_path_from_hdfs(config.model.input_tokenizer)
self.input_tokenizer = AutoTokenizer.from_pretrained(input_tokenizer_local_path,
trust_remote_code=config.model.get(
'trust_remote_code', False))
self.tokenizer = AutoTokenizer.from_pretrained(local_path,
trust_remote_code=config.model.get(
'trust_remote_code', False))
set_pad_token_id(self.tokenizer)
set_pad_token_id(self.input_tokenizer)
self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path,
trust_remote_code=config.model.get('trust_remote_code', False))
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get('trust_remote_code', False))

trust_remote_code = config.model.get('trust_remote_code', False)
model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
Expand Down
Loading

0 comments on commit d60f843

Please sign in to comment.