diff --git a/.github/workflows/gpu_test.yml b/.github/workflows/dataset.yml similarity index 72% rename from .github/workflows/gpu_test.yml rename to .github/workflows/dataset.yml index 8955c208..45f26bef 100644 --- a/.github/workflows/gpu_test.yml +++ b/.github/workflows/dataset.yml @@ -1,4 +1,4 @@ -name: ray +name: dataset on: # Trigger the workflow on push or pull request, @@ -8,13 +8,13 @@ on: - main paths: - "**/*.py" - - .github/workflows/ray_test.yml + - .github/workflows/dataset.yml pull_request: branches: - main paths: - "**/*.py" - - .github/workflows/ray_test.yml + - .github/workflows/dataset.yml jobs: ray: @@ -30,7 +30,7 @@ jobs: run: | [ ! -d "$HOME/verl-data" ] && git clone --depth 1 https://github.com/eric-haibin-lin/verl-data ~/verl-data pytest -s -x tests/verl - - name: Running ray tests that need 2 GPUs + - name: Running ray test using cupy (move it to L20 when dockerfile ready) run: | cd tests/ray - pytest -s -x test_rvdz.py test_driverfunc_to_worker.py test_data_transfer.py test_colocated_workers.py test_check_worker_alive.py + pytest -s -x test_rvdz.py \ No newline at end of file diff --git a/.github/workflows/e2e_gpu.yml b/.github/workflows/e2e_gpu.yml new file mode 100644 index 00000000..d8760a83 --- /dev/null +++ b/.github/workflows/e2e_gpu.yml @@ -0,0 +1,38 @@ +name: e2e_gpu + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/e2e_gpu.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/e2e_gpu.yml + +jobs: + e2e_gpu: + runs-on: [self-hosted, l20-1] + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1" + container: + image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install -e .[test] + - name: Running digit completion e2e training tests on 8 L20 GPUs + run: | + bash tests/e2e/run_ray_trainer.sh diff --git a/.github/workflows/ray_test.yml b/.github/workflows/ray_test.yml new file mode 100644 index 00000000..83ec8711 --- /dev/null +++ b/.github/workflows/ray_test.yml @@ -0,0 +1,42 @@ +name: ray + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/ray_test.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/ray_test.yml + +jobs: + ray: + runs-on: [self-hosted, l20-0] + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1" + HF_HUB_ENABLE_HF_TRANSFER: 1 + container: + image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip install hf_transfer + pip install -e .[test] + pip install --upgrade "ray>=2.40.0" + - name: Running ray tests that need 8 GPUs + run: | + cd tests/ray + pytest -s -x --ignore=test_check_worker_alive.py --ignore=test_rvdz.py .
diff --git a/.github/workflows/sanity.yml b/.github/workflows/sanity.yml new file mode 100644 index 00000000..ff76c663 --- /dev/null +++ b/.github/workflows/sanity.yml @@ -0,0 +1,39 @@ +name: sanity + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/sanity.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/sanity.yml + +jobs: + sanity: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install the current repository + run: | + pip install -e .[test] + - name: Run sanity test + run: | + pytest -s -x tests/sanity + - name: Run utility test + run: | + pytest -s -x tests/utility diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml new file mode 100644 index 00000000..bffdfaea --- /dev/null +++ b/.github/workflows/vllm.yml @@ -0,0 +1,42 @@ +name: vllm + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/vllm.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/vllm.yml + +jobs: + vllm: + runs-on: [self-hosted, l20-0] + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1" + HF_HUB_ENABLE_HF_TRANSFER: 1 + container: + image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install hf_transfer + pip3 install -e .[test] + pip3 install vllm==0.5.4 + - name: Running vllm tests on 8 L20 GPUs + run: | + cd tests/rollout + torchrun --standalone --nnodes=1 --nproc_per_node=8 $(which pytest) -s test_vllm_hf_loader.py diff --git a/.github/workflows/yapf_format.yml b/.github/workflows/yapf_format.yml index 548df4c4..9260b0ab 100644 --- a/.github/workflows/yapf_format.yml +++ b/.github/workflows/yapf_format.yml @@ -38,7 +38,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install yapf + pip install --upgrade yapf pip install toml==0.10.2 - name: Running yapf run: | diff --git a/requirements.txt b/requirements.txt index aaec5690..b282b34a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ datasets dill hydra-core numpy +pandas pybind11 ray tensordict<0.6 diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/arithmetic_sequence/data/create_dataset.py b/tests/e2e/arithmetic_sequence/data/create_dataset.py new file mode 100644 index 00000000..e023a291 --- /dev/null +++ b/tests/e2e/arithmetic_sequence/data/create_dataset.py @@ -0,0 +1,46 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.e2e.envs.digit_completion import DigitCompletion, generate_ground_truth_response +from torch.utils import data +import os + +if __name__ == '__main__': + simple_task = DigitCompletion(max_number=9, max_diff=9, max_num_in_response=9) + all_prompts = simple_task.get_all_prompts() + + # 21 * 6 * 4 + train_data, test_data = data.random_split(all_prompts, lengths=[0.8, 0.2]) + train_data = list(train_data) + test_data = list(test_data) + + train_data = [[{'role': 'user', 'content': str(item)}] \ + for item in train_data] + test_data = [[{'role': 'user', 'content': str(item)}] \ + for item in test_data] + + print(f'Size of train: {len(train_data)}, size of test: {len(test_data)}') + + train_data = {'prompt': train_data} + test_data = {'prompt': test_data} + + model_folder = os.path.join(os.path.dirname(os.path.abspath(__file__))) + + import pandas as pd + + train_data_frame = pd.DataFrame(train_data) + test_data_frame = pd.DataFrame(test_data) + + train_data_frame.to_parquet(os.path.join(model_folder, 'train.parquet')) + test_data_frame.to_parquet(os.path.join(model_folder, 'test.parquet')) diff --git a/tests/e2e/arithmetic_sequence/data/test.parquet b/tests/e2e/arithmetic_sequence/data/test.parquet new file mode 100644 index 00000000..d0729dc3 Binary files /dev/null and b/tests/e2e/arithmetic_sequence/data/test.parquet differ diff --git a/tests/e2e/arithmetic_sequence/data/train.parquet b/tests/e2e/arithmetic_sequence/data/train.parquet new file mode 100644 index 00000000..0a03a61a Binary files /dev/null and b/tests/e2e/arithmetic_sequence/data/train.parquet differ diff --git a/tests/e2e/arithmetic_sequence/model/config.json b/tests/e2e/arithmetic_sequence/model/config.json new file mode 100644 index 00000000..87944c51 --- /dev/null +++ b/tests/e2e/arithmetic_sequence/model/config.json @@ -0,0 +1,29 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": null, + "eos_token_id": 1, + "hidden_act": "silu", + "hidden_size": 128, + "initializer_range": 0.02, + "intermediate_size": 344, + "max_position_embeddings": 2048, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 4, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pad_token_id": 2, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.43.3", + "use_cache": true, + "vocab_size": 16 +} diff --git a/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py b/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py new file mode 100644 index 00000000..9a3135f0 --- /dev/null +++ b/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py @@ -0,0 +1,61 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Create a random model and tokenizer for PPO training +""" + +import torch +import os +from transformers import AutoModelForCausalLM, LlamaConfig, AutoTokenizer + +from tests.e2e.envs.digit_completion import CharTokenizer + +tokenizer = CharTokenizer( + characters=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ',', ':'], + model_max_length=2048, + chat_template= + "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}" +) + +config = LlamaConfig(vocab_size=(tokenizer.vocab_size + 16 - 1) // 16 * 16, + hidden_size=128, + intermediate_size=344, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id) + +model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) + +model_folder = os.path.join(os.path.dirname(os.path.abspath(__file__))) +os.makedirs(model_folder, exist_ok=True) + +model.save_pretrained(model_folder) + +tokenizer_folder = model_folder +tokenizer.save_pretrained(tokenizer_folder) + +load_tokenizer = AutoTokenizer.from_pretrained(tokenizer_folder) + +chat = [{'role': 'user', 'content': '1,0:2,3'}] + +load_tokenizer.padding_side = 'left' +print( + load_tokenizer.apply_chat_template(chat, + tokenize=True, + add_generation_prompt=True, + max_length=10, + padding='max_length')) diff --git a/tests/e2e/arithmetic_sequence/model/generation_config.json b/tests/e2e/arithmetic_sequence/model/generation_config.json new file mode 100644 index 00000000..578d3750 --- /dev/null +++ b/tests/e2e/arithmetic_sequence/model/generation_config.json @@ -0,0 +1,6 @@ +{ + "_from_model_config": true, + "eos_token_id": 1, + "pad_token_id": 2, + "transformers_version": "4.43.3" +} diff --git a/tests/e2e/arithmetic_sequence/model/model.safetensors b/tests/e2e/arithmetic_sequence/model/model.safetensors new file mode 100644 index 00000000..509e6e97 Binary files /dev/null and b/tests/e2e/arithmetic_sequence/model/model.safetensors differ diff --git a/tests/e2e/arithmetic_sequence/model/tokenizer_config.json b/tests/e2e/arithmetic_sequence/model/tokenizer_config.json new file mode 100644 index 00000000..d01bf75f --- /dev/null +++ b/tests/e2e/arithmetic_sequence/model/tokenizer_config.json @@ -0,0 +1,18 @@ +{ + "char_ords": [ + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 44, + 58 + ], + "model_max_length": 2048, + "chat_template": "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ 
message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}" +} \ No newline at end of file diff --git a/tests/e2e/arithmetic_sequence/rl/README.md b/tests/e2e/arithmetic_sequence/rl/README.md new file mode 100644 index 00000000..a9b3f9e2 --- /dev/null +++ b/tests/e2e/arithmetic_sequence/rl/README.md @@ -0,0 +1,37 @@ +# Digit completion + +This is an example of solving a digit completion problem. The problem is defined as below: + +The prompt is a sequence of numbers with fixed difference. The agent's goal is to complete the next N numbers. +If the max number is reached, the next number should be modulo with max number. + +For example, +- prompt = [1, 2, 3] +- N = 5 +- max_number = 6 + +The response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1]. + +# Environment definition + +The core definition of the task is defined in verl/envs/digit_completion/task.py + +It is highly recommended to take a look at it for better understanding. + + + +# Run experiments + +The users are required to specify the config path and config name (and the relative model config path to the current working directory) + +```bash +# cd examples/arithmetic_sequence/rl + +# Specify the config path and config name (current working dir) +python3 -m verl.trainer.ppo.ray_megatron_train_synchronous --config-path=$(pwd)/config --config-name='ray_megatron' + +# The default relative path of model config is 'config/model_config', if you want to change it, you can rewrite it in ray_megatron.yaml or using: +python3 -m verl.trainer.ppo.ray_megatron_train_synchronous --config-path=$(pwd)/config --config-name='ray_megatron' ++model.base_path=config/model_config + +``` + diff --git a/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml b/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml new file mode 100644 index 00000000..52488979 --- /dev/null +++ b/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml @@ -0,0 +1,132 @@ +data: + tokenizer: null + train_files: ~/verl/tests/e2e/arithmetic_sequence/data/train.parquet + val_files: ~/verl/tests/e2e/arithmetic_sequence/data/test.parquet + prompt_key: prompt + max_prompt_length: 16 + max_response_length: 32 + train_batch_size: 800 + val_batch_size: 200 + return_raw_input_ids: True # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/verl/tests/e2e/arithmetic_sequence/model + external_lib: tests.e2e.envs.digit_completion + override_config: {} + enable_gradient_checkpointing: False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 200 + ppo_micro_batch_size: 200 + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.0 + ppo_epochs: 1 + shuffle: True + optim: + lr: 1e-4 + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + grad_offload: False + optimizer_offload: False + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + micro_batch_size: 200 + rollout: + name: hf + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # for xperf_gpt + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.1 + ignore_eos: False + micro_batch_size: 200 + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + 
tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_num_seqs: 1024 + log_prob_micro_batch_size: 200 + # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 1 # > 1 for grpo + +critic: + strategy: fsdp + optim: + lr: 1e-3 + model: + path: ~/verl/tests/e2e/arithmetic_sequence/model + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: {} + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: False + fsdp_config: + param_offload: False + grad_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: 200 + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + + # the following parameters are for backward-compatibility and should be removed + kl_ctrl: + type: fixed + kl_coef: 0.001 + +reward_model: + strategy: fsdp + enable: False + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + offload: False + fsdp_config: + min_num_params: 0 + micro_batch_size: 8 + max_length: null + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.005 + +trainer: + total_epochs: 200 + project_name: verl_examples + experiment_name: arithmetic_sequences + logger: ['console'] + nnodes: 1 + n_gpus_per_node: 1 + save_freq: -1 + test_freq: 1 + critic_warmup: 0 + default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} diff --git a/tests/e2e/arithmetic_sequence/rl/main_trainer.py b/tests/e2e/arithmetic_sequence/rl/main_trainer.py new file mode 100644 index 00000000..6ce87715 --- /dev/null +++ b/tests/e2e/arithmetic_sequence/rl/main_trainer.py @@ -0,0 +1,161 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Using FSDPTrainer +""" +import re +import os +import hydra +import numpy as np +import ray +import torch +from torch.utils.data import DataLoader +from transformers import PreTrainedTokenizer, AutoTokenizer + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.model import compute_position_id_with_mask +from tests.e2e.envs.digit_completion import CharTokenizer +import pandas as pd + + +def make_reward_function(tokenizer, num_examine): + + def arithmetic_sequence_reward_function(data: DataProto): + from tests.e2e.envs.digit_completion.task import compute_reward + reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + + for i in range(data.batch.batch_size[0]): + data_item = data[i] # DataProtoItem + + prompt_ids = data_item.batch['prompts'] + + prompt_length = prompt_ids.shape[-1] + + # extract raw prompt + valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + + # extract response + response_ids = data_item.batch['responses'] + response_length = response_ids.shape[-1] + response_mask = data.batch['attention_mask'][i][-response_length:] + valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + prompt = tokenizer.decode(valid_prompt_ids) + response = tokenizer.decode(valid_response_ids) + # remove bos and eos + prompt = prompt.replace(tokenizer.sep_token, '') + response = response.replace(tokenizer.eos_token, '') + if i < num_examine: + print(prompt, response) + + reward_output = compute_reward(prompt, response) + dense_reward = reward_output[0].tolist() + ground_truth_response = reward_output[1]['ground_truth_response'] + if len(dense_reward) > 0: + last_reward = dense_reward[-1] + else: + if len(ground_truth_response) == 0: + last_reward = 1 + else: + last_reward = 0 + + # pad to response_length + for _ in range(reward_tensor.shape[-1] - len(dense_reward)): + dense_reward.append(last_reward) + + dense_reward = torch.as_tensor(dense_reward, dtype=torch.float32, device=reward_tensor.device) + reward_tensor[i] = dense_reward * response_mask + + return reward_tensor + + return arithmetic_sequence_reward_function + + +@hydra.main(config_path='config', config_name='ray_trainer', version_base=None) +def main(config): + ray.init( + runtime_env={ + 'env_vars': { + 'MEGATRON_USE_CUDA_TIMER': '0', + 'MEGATRON_START_PROCESS_TIMER': 'False', + 'TOKENIZERS_PARALLELISM': 'true', + 'NCCL_DEBUG': 'WARN' + } + }) + + # print initial config + from pprint import pprint + from omegaconf import OmegaConf + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + + dp_size = config.trainer.n_gpus_per_node * config.trainer.nnodes + # normalize batch_size + # TODO: move this inside each role + config.actor_rollout_ref.actor.ppo_mini_batch_size //= dp_size + config.actor_rollout_ref.actor.ppo_micro_batch_size //= dp_size + config.critic.ppo_micro_batch_size //= dp_size + config.actor_rollout_ref.rollout.micro_batch_size //= dp_size + + # print the config + # print initial config + print('Config after normalizing batch_size') + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + + # download the checkpoint from hdfs + local_path = 
copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + local_path = os.path.expanduser(local_path) + # instantiate tokenizer + tokenizer = AutoTokenizer.from_pretrained(local_path) + print(f'Tokenizer vocab_size: {tokenizer.vocab_size}') + + # define worker classes + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = 'global_pool' + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + reward_fn = make_reward_function(tokenizer=tokenizer, num_examine=1) + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayPPOTrainer(config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + reward_fn=reward_fn, + val_reward_fn=reward_fn) + trainer.init_workers() + trainer.fit() + + +if __name__ == '__main__': + main() diff --git a/tests/e2e/check_results.py b/tests/e2e/check_results.py new file mode 100644 index 00000000..bd3151f0 --- /dev/null +++ b/tests/e2e/check_results.py @@ -0,0 +1,52 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import numpy as np + + +def extract_reward_from_line(line): + # TODO: this function needs error handling + try: + key_vals = line.split(' - ') + for key_val in key_vals: + key, val = key_val.split(':') + if key == 'critic/rewards/mean': + reward = float(val) + return reward + return -np.inf + except Exception: + return -np.inf + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_file', required=True, type=str) + + args = parser.parse_args() + + with open(args.output_file, 'r') as f: + output = f.read().split('\n') + + best_reward = -np.inf + for line in output: + if line.startswith('step'): + reward = extract_reward_from_line(line) + if reward > best_reward: + best_reward = reward + + print(f'Best reward is {best_reward}') + assert best_reward > 0.2, f'Best reward must be greater than 0.2. best_reward: {best_reward}' + print('Check passes') diff --git a/tests/e2e/envs/__init__.py b/tests/e2e/envs/__init__.py new file mode 100644 index 00000000..7d391479 --- /dev/null +++ b/tests/e2e/envs/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .digit_completion import DigitCompletion + +__all__ = ['DigitCompletion'] \ No newline at end of file diff --git a/tests/e2e/envs/digit_completion/__init__.py b/tests/e2e/envs/digit_completion/__init__.py new file mode 100644 index 00000000..3e3aa76d --- /dev/null +++ b/tests/e2e/envs/digit_completion/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .task import DigitCompletion, generate_ground_truth_response +from .tokenizer import CharTokenizer + +from transformers import AutoTokenizer, LlamaConfig + +AutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True) + +__all__ = ['DigitCompletion', 'generate_ground_truth_response', 'CharTokenizer'] \ No newline at end of file diff --git a/tests/e2e/envs/digit_completion/task.py b/tests/e2e/envs/digit_completion/task.py new file mode 100644 index 00000000..7322027c --- /dev/null +++ b/tests/e2e/envs/digit_completion/task.py @@ -0,0 +1,177 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Task and environment definition for digit completion.""" + +import numpy as np + + +class DigitCompletion(object): + """ + The implementation of a simple digit completion task. + The prompt is a sequence of numbers with fixed difference. The task is to complete the next N numbers. + If the max number is reached, the next number should be modulo with max number. + + For example, + - prompt = [1, 2, 3] + - N = 5 + - max_number = 6 + + the response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1] + + Note that the tokenizer is char-level to increase the difficulty. + """ + + def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0): + """ + + Args: + max_number: the maximum number allowed in the arithmetic sequence + max_diff: the maximum diff. 
The actual common diff will be sampled from [0, max_diff] + max_num_in_response: the maximum number in the response + """ + super().__init__() + self.max_number = max_number + self.max_diff = max_diff + self.max_num_in_response = max_num_in_response + assert self.max_num_in_response < 10 + assert self.max_number > 0 + assert self.max_diff > 0 + self.max_number_length = len(str(max_number)) + # {num1},{num2}:{max_num_in_response},{max_number} + self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length # no negative is allowed + + self.np_rng = np.random.default_rng(seed=seed) + + def __str__(self): + return f'Prompt length: {self.prompt_length}. Response length: {self.response_length}, ' \ + f'Max number: {self.max_number}. Max diff: {self.max_diff}, ' \ + f'Max number in response: {self.max_num_in_response}' + + def get_state(self): + return {'rng': self.np_rng} + + def set_state(self, state): + assert 'rng' in state, 'rng must be inside state' + self.np_rng = state['rng'] + + @property + def prompt_length(self): + return self._prompt_length + + @property + def response_length(self): + # number length + comma length + [EOS] + # The actual number times 1.5 to allow 'U' + return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2 + + def add(self, a, b): + return (a + b) % self.max_number + + def get_all_prompts(self): + all_prompts = [] + for first_num in range(self.max_number + 1): + for diff in range(0, self.max_diff + 1): + second_num = self.add(first_num, diff) + for num_to_complete in range(self.max_num_in_response + 1): + prompt = str(first_num) + ',' + str(second_num) + f':{self.max_number},{num_to_complete}' + all_prompts.append(prompt) + return all_prompts + + def sample_str_prompts(self): + # step 1: sample initial numbers + first_num = self.np_rng.integers(self.max_number + 1) + diff = self.np_rng.integers(self.max_diff + 1) + second_num = self.add(first_num, diff) + num_to_complete = self.np_rng.integers(self.max_num_in_response + 1) + prompt = str(first_num) + ',' + str(second_num) + f':{self.max_number},{num_to_complete}' + return prompt + + def sample_batch_str_prompts(self, batch_size): + str_prompts = [] + for _ in range(batch_size): + str_prompts.append(self.sample_str_prompts()) + return str_prompts + + +def compute_attention_mask(prompts, pad_token_id): + mask = np.ones_like(prompts) + mask[prompts == pad_token_id] = 0 + return mask + + +def compute_position_id_with_mask(mask): + return np.clip(np.cumsum(mask, axis=-1) - 1, a_min=0, a_max=None) + + +def generate_ground_truth_response(prompt: str): + """Generate ground truth response given a prompt.""" + num, info = prompt.split(':') + num1, num2 = num.split(',') + max_number, num_to_gen = info.split(',') + num1 = int(num1) + num2 = int(num2) + max_number = int(max_number) + num_to_gen = int(num_to_gen) + diff = (num2 - num1) % max_number + results = [] + last_num = num2 + for _ in range(num_to_gen): + curr = (last_num + diff) % max_number + results.append(str(curr)) + last_num = curr + response = ','.join(results) + return response + + +def compute_reward(prompt: str, response: str, sequence_reward=1.): + """We compute dense reward here so that we can directly train RL without SFT""" + response_length = len(response) + ground_truth_response = generate_ground_truth_response(prompt) + per_token_reward = sequence_reward / (len(ground_truth_response) + 1) # including [EOS] + + # pad + reward = np.zeros(response_length, dtype=np.float32) # this assumes that each 
char is a token + # assign reward until mismatches + ground_truth_idx = 0 + for i in range(response_length): + if ground_truth_idx == len(ground_truth_response): + break + + ground_truth_response_token = ground_truth_response[ground_truth_idx] + response_token = response[i] + if ground_truth_response_token == response_token: + reward[i] = per_token_reward + ground_truth_idx += 1 + else: + # no matches + break + + return reward, {'ground_truth_response': ground_truth_response} + + +if __name__ == '__main__': + task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5) + print(task.sample_str_prompts()) + + prompt = '7,8:20,0' + response = '' + print(compute_reward(prompt, response)) + + prompt = '7,8:20,0' + response = 'E000' + print(compute_reward(prompt, response)) + + prompt = '9,10:20,2' + response = '11,12,13' + print(compute_reward(prompt, response)) diff --git a/tests/e2e/envs/digit_completion/tokenizer.py b/tests/e2e/envs/digit_completion/tokenizer.py new file mode 100644 index 00000000..1b8d94ee --- /dev/null +++ b/tests/e2e/envs/digit_completion/tokenizer.py @@ -0,0 +1,158 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Adapted from https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py + +CharacterTokenizer for Hugging Face Transformers. + +This is heavily inspired by CanineTokenizer in the transformers package. +""" + +import json +import os +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Union + +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer + + +class CharTokenizer(PreTrainedTokenizer): + + def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs): + """Character tokenizer for Hugging Face transformers. + + Args: + characters (Sequence[str]): List of desired characters. Any character which + is not included in this list will be replaced by the unknown token "U". + Following are the special tokens with their corresponding ids: + "S" (sep): 0 + "E" (eos): 1 + "P" (pad): 2 + "U" (unk): 3 + an id (starting at 4) will be assigned to each character. + + model_max_length (int): Model maximum sequence length.
+ """ + eos_token_str = 'E' + sep_token_str = 'S' + pad_token_str = 'P' + unk_token_str = 'U' + + self.characters = characters + self.model_max_length = model_max_length + eos_token = AddedToken(eos_token_str, lstrip=False, rstrip=False) + sep_token = AddedToken(sep_token_str, lstrip=False, rstrip=False) + pad_token = AddedToken(pad_token_str, lstrip=False, rstrip=False) + unk_token = AddedToken(unk_token_str, lstrip=False, rstrip=False) + + self._vocab_str_to_int = { + sep_token_str: 0, + eos_token_str: 1, + pad_token_str: 2, + unk_token_str: 3, + **{ + ch: i + 4 for i, ch in enumerate(characters) + }, + } + self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} + + super().__init__( + eos_token=eos_token, + sep_token=sep_token, + pad_token=pad_token, + unk_token=unk_token, + add_prefix_space=False, + model_max_length=model_max_length, + **kwargs, + ) + + self.chat_template = chat_template + + @property + def vocab_size(self) -> int: + return len(self._vocab_str_to_int) + + def get_vocab(self): + return self._vocab_str_to_int + + def _tokenize(self, text: str) -> List[str]: + return list(text) + + def _convert_token_to_id(self, token: str) -> int: + return self._vocab_str_to_int.get(token, self._vocab_str_to_int["U"]) + + def _convert_id_to_token(self, index: int) -> str: + return self._vocab_int_to_str[index] + + def convert_tokens_to_string(self, tokens): + return "".join(tokens) + + def build_inputs_with_special_tokens(self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + sep = [self.sep_token_id] + cls = [self.cls_token_id] + result = cls + token_ids_0 + sep + if token_ids_1 is not None: + result += token_ids_1 + sep + return result + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, + ) + + result = [1] + ([0] * len(token_ids_0)) + [1] + if token_ids_1 is not None: + result += ([0] * len(token_ids_1)) + [1] + return result + + def get_config(self) -> Dict: + return { + "char_ords": [ord(ch) for ch in self.characters], + "model_max_length": self.model_max_length, + "chat_template": self.chat_template + } + + @classmethod + def from_config(cls, config: Dict) -> "DigitCompletionTokenizer": + cfg = {} + cfg["characters"] = [chr(i) for i in config["char_ords"]] + cfg["model_max_length"] = config["model_max_length"] + cfg["chat_template"] = config["chat_template"] + return cls(**cfg) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): + cfg_file = Path(save_directory) / "tokenizer_config.json" + cfg = self.get_config() + with open(cfg_file, "w") as f: + json.dump(cfg, f, indent=4) + + @classmethod + def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs): + cfg_file = Path(save_directory) / "tokenizer_config.json" + with open(cfg_file) as f: + cfg = json.load(f) + return cls.from_config(cfg) diff --git a/tests/e2e/run_ray_trainer.sh b/tests/e2e/run_ray_trainer.sh new file mode 100644 index 00000000..51d18fcc --- /dev/null +++ b/tests/e2e/run_ray_trainer.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +set -e -x + +OUTPUT_FILE="/tmp/output_ray_trainer.txt" + +export PATH=$PATH:~/.local/bin + +rm -rf $OUTPUT_FILE +python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \ + 
data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \ + data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \ + actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \ + critic.model.path=tests/e2e/arithmetic_sequence/model | tee $OUTPUT_FILE; + +python3 tests/e2e/check_results.py --output_file=$OUTPUT_FILE +rm -rf $OUTPUT_FILE diff --git a/tests/gpu_utility/test_memory_buffers.py b/tests/gpu_utility/test_memory_buffers.py new file mode 100644 index 00000000..0116c8be --- /dev/null +++ b/tests/gpu_utility/test_memory_buffers.py @@ -0,0 +1,70 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test memory buffers +- We start with two models with the same weights +- We use Memory buffer to make one of the models and then compare the parameters +""" + +import torch +import gc + +from transformers import LlamaModel, LlamaConfig +from verl.utils.memory_buffer import MemoryBufferModuleWrapper + + +def test_memory_buffers(): + llama_config = LlamaConfig(vocab_size=256, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=2, + num_attention_heads=16, + num_key_value_heads=16) + + model = LlamaModel(config=llama_config).cuda() + model_copy = LlamaModel(config=llama_config).cuda() + model_copy.load_state_dict(model.state_dict()) + + model_named_params = dict(model.named_parameters()) + model_copy_named_params = dict(model_copy.named_parameters()) + + norm_factor = 1024**3 + + t_before = torch.cuda.get_device_properties(0).total_memory / norm_factor + r_before = torch.cuda.memory_reserved(0) / norm_factor + a_before = torch.cuda.memory_allocated(0) / norm_factor + + print(f'Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB') + + model_wrapper = MemoryBufferModuleWrapper(model) + + t = torch.cuda.get_device_properties(0).total_memory / norm_factor + r = torch.cuda.memory_reserved(0) / norm_factor + a = torch.cuda.memory_allocated(0) / norm_factor + + gc.collect() + torch.cuda.empty_cache() + + print(f'After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB') + + change_ratio = (a - a_before) / a_before + assert change_ratio < 0.01, f'make sure the allocated change is less than 1%, Got {change_ratio}' + + for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters()): + assert name1 == name2 + assert torch.eq(param1.data, param2.data).all(), f'{param1.data}, {param2.data}, {name1}' + + +if __name__ == '__main__': + test_memory_buffers() diff --git a/tests/gpu_utility/test_ops.py b/tests/gpu_utility/test_ops.py new file mode 100644 index 00000000..b7663cfd --- /dev/null +++ b/tests/gpu_utility/test_ops.py @@ -0,0 +1,47 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def test_flash_attn_cross_entropy(): + from verl.utils.torch_functional import logprobs_from_logits_naive + + from verl.utils.debug import log_gpu_memory_usage + + from flash_attn.ops.triton.cross_entropy import cross_entropy_loss + + import torch + from torch import nn + + log_gpu_memory_usage('At start') + + hidden_states = torch.randn(size=(2048, 5120), device='cuda', requires_grad=True, dtype=torch.bfloat16) + + linear = nn.Linear(in_features=5120, out_features=155136, bias=False, device='cuda', dtype=torch.bfloat16) + + logits = linear(hidden_states) + + # logits = logits.float() + labels = torch.randint(low=0, high=155136, size=(2048,), device='cuda') + + log_gpu_memory_usage('before computation') + # output = checkpoint.checkpoint(logprobs_from_logits, logits, labels, use_reentrant=True) + output = -cross_entropy_loss(logits, labels)[0] + # output = logprobs_from_logits(logits, labels) + log_gpu_memory_usage('After forward') + output.sum().backward() + log_gpu_memory_usage('After backward') + + groundtruth = logprobs_from_logits_naive(logits.float(), labels) + + torch.testing.assert_close(output, groundtruth) diff --git a/tests/gpu_utility/test_torch_functional.py b/tests/gpu_utility/test_torch_functional.py new file mode 100644 index 00000000..292be4de --- /dev/null +++ b/tests/gpu_utility/test_torch_functional.py @@ -0,0 +1,81 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from verl.utils.model import create_random_mask +from flash_attn.bert_padding import unpad_input +import torch + + +def test_log_probs_from_logits_response_rmpad(): + from verl.utils.torch_functional import log_probs_from_logits_response, log_probs_from_logits_response_rmpad + vocab_size = 32000 + batch_size = 2 + prompt_length = 256 + response_length = 256 + + input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, prompt_length + response_length), device='cuda') + attention_mask = create_random_mask(input_ids=input_ids, + max_ratio_of_left_padding=0.2, + max_ratio_of_valid_token=0.8, + min_ratio_of_valid_token=0.6) + + response_mask = attention_mask[:, -response_length:] + + assert torch.all(response_mask[:, 0] == 1) + + logits = torch.randn(batch_size, prompt_length + response_length, vocab_size, device='cuda') + logits_rmpad = unpad_input(logits, attention_mask)[0] + + expected_output = log_probs_from_logits_response(input_ids=input_ids, + logits=logits, + response_length=response_length) + actual_output = log_probs_from_logits_response_rmpad(input_ids=input_ids, + attention_mask=attention_mask, + logits_rmpad=logits_rmpad, + response_length=response_length) + + # This should bitwise align as only this operation only contains gather operators + assert torch.all(torch.eq(actual_output * response_mask, expected_output * response_mask)) + + +def test_lr_scheduler(): + from torch import nn + model = nn.Linear(10, 10) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + from verl.utils.torch_functional import get_constant_schedule_with_warmup + constant_lr = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=2) + + lr_lst = [] + + for _ in range(5): + lr_lst.append(constant_lr.get_last_lr()[0]) + constant_lr.step() + + torch.testing.assert_close(lr_lst, [0.0, 0.0005, 0.001, 0.001, 0.001]) + + from verl.utils.torch_functional import get_cosine_schedule_with_warmup + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + cosine_lr = get_cosine_schedule_with_warmup(optimizer=optimizer, + num_warmup_steps=2, + num_training_steps=5, + min_lr_ratio=0.1) + + lr_lst = [] + + for _ in range(5): + lr_lst.append(cosine_lr.get_last_lr()[0]) + cosine_lr.step() + + torch.testing.assert_close(lr_lst, [0.0, 0.0005, 0.001, 0.0007750000000000002, 0.0003250000000000002]) diff --git a/tests/ray/test_data_transfer.py b/tests/ray/test_data_transfer.py index 46b962cd..a17affd3 100644 --- a/tests/ray/test_data_transfer.py +++ b/tests/ray/test_data_transfer.py @@ -52,7 +52,7 @@ def do_nothing(self, data): def test_data_transfer(): ray.init() # construct resource pool - resource_pool = RayResourcePool([2]) + resource_pool = RayResourcePool([8]) cls_with_init = RayClassWithInitArgs(cls=DummyWorker) # construct worker group wg = RayWorkerGroup(resource_pool, cls_with_init) diff --git a/tests/ray/test_remote_api.py b/tests/ray/test_remote_api.py deleted file mode 100644 index b7a64b6e..00000000 --- a/tests/ray/test_remote_api.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from verl.single_controller.remote import remote, RemoteBackend, SharedResourcePool -from verl.single_controller.base.decorator import register, Dispatch -from verl.single_controller.base.worker import Worker - - -@remote(process_on_nodes=[3], use_gpu=True, name_prefix="actor", sharing=SharedResourcePool) -class Actor(Worker): - ... - - -@remote(process_on_nodes=[3], use_gpu=True, name_prefix="critic", sharing=SharedResourcePool) -class Critic(Worker): - ... - - -@remote(process_on_nodes=[2], use_gpu=True, name_prefix="reward", sharing=SharedResourcePool.from_role("actor")) -class Reward(Worker): - ... - - -@remote(process_on_nodes=[2], use_gpu=True, name_prefix="ref", sharing=SharedResourcePool.from_role("actor", "critic")) -class Ref(Worker): - ... - - -@remote(process_on_nodes=[1], use_gpu=True, name_prefix="sec_rm", sharing=SharedResourcePool.from_role("any")) -class SecRM(Worker): - ... - - -def test(): - print("Remote.init_distributed") - remote.init_distributed(backend=RemoteBackend.RAY) - - print("create actor worker group") - actor = Actor() - - print("create critic worker group") - critic = Critic() - - print("create rm worker group") - reward = Reward() - - print("create ref worker group") - ref = Ref() - - print("create sec_rm worker group") - sec_rm = SecRM() - - actor_gpus = actor.execute_all_sync("get_cuda_visible_devices") - critic_gpus = critic.execute_all_sync("get_cuda_visible_devices") - reward_gpus = reward.execute_all_sync("get_cuda_visible_devices") - ref_gpus = ref.execute_all_sync("get_cuda_visible_devices") - sec_rm_gpus = sec_rm.execute_all_sync("get_cuda_visible_devices") - - for gpu in actor_gpus: - assert gpu not in critic_gpus, f"actor gpus = {actor_gpus}, critic gpus = {critic_gpus}" - - for gpu in critic_gpus: - assert gpu not in actor_gpus, f"actor gpus = {actor_gpus}, critic gpus = {critic_gpus}" - - for gpu in reward_gpus: - assert gpu in actor_gpus, f"actor gpus = {actor_gpus}, reward gpus = {reward_gpus}" - - for gpu in ref_gpus: - assert gpu in actor_gpus + critic_gpus, \ - f"actor gpus = {actor_gpus}, critic gpus = {critic_gpus}, ref gpus = {ref_gpus}" - - for gpu in sec_rm_gpus: - assert gpu in actor_gpus + critic_gpus, \ - f"actor gpus = {actor_gpus}, critic gpus = {critic_gpus}, sec rm gpus = {sec_rm_gpus}" - - # for ci only - import ray - ray.shutdown() diff --git a/tests/rollout/run_fsdp_vllm.py b/tests/rollout/run_fsdp_vllm.py new file mode 100644 index 00000000..d9e165a9 --- /dev/null +++ b/tests/rollout/run_fsdp_vllm.py @@ -0,0 +1,138 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload +from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType +import torch + +from verl.utils.distributed import initialize_global_process_group +from verl.third_party.vllm import LLM + +from vllm import SamplingParams + + +def main(): + assert torch.cuda.is_available(), 'CUDA must be present to run FSDP vLLM example' + local_rank, rank, world_size = initialize_global_process_group() + + local_cache_path = '~/.cache/verl/rlhf' + local_cache_path = os.path.expanduser(local_cache_path) + hdfs_path = 'Qwen/Qwen2-7B-Instruct' + + from verl.utils.fs import copy_local_path_from_hdfs + local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path) + tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True) + actor_model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=True) + with torch.device("cuda"): + actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True) + actor_model.to(torch.bfloat16) + + max_prompt_length = 16 + response_length = 32 + preencode_prompts = [ + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + tokenizer.pad_token = tokenizer.eos_token + prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True) + input_ids = prompts['input_ids'] + attention_mask = prompts['attention_mask'] + from verl.utils.torch_functional import pad_sequence_to_length + input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True).cuda() + attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True).cuda() + + from transformers import GenerationConfig + generation_config = GenerationConfig(do_sample=False) + actor_model.cuda() + output = actor_model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=32, + # max_length=max_length, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + generation_config=generation_config, + # renormalize_logits=True, + output_scores=False, # this is potentially very large + return_dict_in_generate=True, + use_cache=False) # may OOM when use_cache = True + seq = output.sequences + response = seq[:, max_prompt_length:] + + print(f'hf response: {tokenizer.batch_decode(response)}') + + tensor_model_parallel_size = 4 + from torch.distributed.device_mesh import init_device_mesh + device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + + mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + fsdp_model = FSDP(actor_model, + use_orig_params=True, + auto_wrap_policy=None, + device_id=torch.cuda.current_device(), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + cpu_offload=CPUOffload(offload_params=False), + sync_module_states=False, + device_mesh=device_mesh) + + FSDP.set_state_dict_type(fsdp_model, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig()) + + state_dict = fsdp_model.state_dict() + + sampling_params = SamplingParams(temperature=0, + top_p=1, + n=1, + max_tokens=response_length, + logprobs=1, + ignore_eos=True, + detokenize=False) + + 
print(actor_model_config) + llm = LLM(model=None, + tokenizer=tokenizer, + model_hf_config=actor_model_config, + tensor_parallel_size=tensor_model_parallel_size, + enforce_eager=True, + dtype='bfloat16', + load_format='dummy_dtensor', + gpu_memory_utilization=0.1, + trust_remote_code=True) + + llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor') + + input_ids = input_ids.cuda() + attention_mask = attention_mask.cuda() + idx_list = [] + batch_size = input_ids.shape[0] + + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs + for i in range(batch_size): + idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) + print('start generation') + outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False) + vllm_output = outputs[0].cuda() + if torch.distributed.get_rank() == 0: + print(f'hf response: {tokenizer.batch_decode(response)}') + print(f'vllm response: {tokenizer.batch_decode(vllm_output)}') + + +if __name__ == "__main__": + main() diff --git a/tests/rollout/test_vllm_hf_loader.py b/tests/rollout/test_vllm_hf_loader.py new file mode 100644 index 00000000..64e956e0 --- /dev/null +++ b/tests/rollout/test_vllm_hf_loader.py @@ -0,0 +1,174 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import torch
+import transformers
+
+from verl.third_party.vllm import LLM, vllm_version
+from verl.utils.model import update_model_config
+from vllm import SamplingParams
+from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
+
+from transformers import GenerationConfig
+
+from verl.utils.torch_functional import pad_sequence_to_length
+from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs
+
+
+def levenshtein(s1, s2):
+    m, n = len(s1), len(s2)
+    # Initialize matrix of zeros
+    dp = [[0] * (n + 1) for _ in range(m + 1)]
+    # Initialize first column and first row of the matrix
+    for i in range(m + 1):
+        dp[i][0] = i  # Deletion from s1 to empty string
+    for j in range(n + 1):
+        dp[0][j] = j  # Insertion to s1 from empty string
+    # Compute the Levenshtein distance matrix
+    for i in range(1, m + 1):
+        for j in range(1, n + 1):
+            cost = 0 if s1[i - 1] == s2[j - 1] else 1  # No cost if characters match
+            dp[i][j] = min(
+                dp[i - 1][j] + 1,  # Deletion
+                dp[i][j - 1] + 1,  # Insertion
+                dp[i - 1][j - 1] + cost  # Substitution
+            )
+    return dp[m][n]
+
+
+def are_lists_similar(a, b):
+    if len(a) != len(b):
+        print("The lists are of different lengths.")
+        return False
+
+    total_length = 0
+    total_diff = 0
+
+    for s1, s2 in zip(a, b):
+        max_len = max(len(s1), len(s2))
+        total_length += max_len
+        diff = levenshtein(s1, s2)
+        total_diff += diff
+        print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n")
+
+    percentage_difference = (total_diff / total_length) * 100
+    print(f"Total difference: {percentage_difference:.2f}%")
+
+    return percentage_difference <= 10
+
+
+def test_vllm_with_hf():
+    assert torch.cuda.device_count() >= 2, 'At least 2 GPUs are required to run tp+dp tests.'
+
+    # fill rollout config
+    max_prompt_length = 16
+    max_response_length = 16
+
+    # Initialize model and tokenizer
+    local_cache_path = '~/.cache/verl/rlhf'
+    local_cache_path = os.path.expanduser(local_cache_path)
+    hdfs_path = 'deepseek-ai/deepseek-llm-7b-chat'
+    from verl.utils.fs import copy_local_path_from_hdfs
+    local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path)
+    tokenizer = AutoTokenizer.from_pretrained(local_model_path)
+
+    preencode_prompts = [
+        "Who won the Champions League in 2019?",
+        "The founder of Apple is",
+        "What's your name",
+    ]
+    tokenizer.pad_token = tokenizer.eos_token
+    prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True)
+    input_ids = prompts['input_ids']
+    attention_mask = prompts['attention_mask']
+
+    input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
+    attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)
+
+    actor_model = AutoModelForCausalLM.from_pretrained(local_model_path)
+    actor_model.to(torch.bfloat16)
+
+    actor_model_config = AutoConfig.from_pretrained(local_model_path)
+
+    temperature = 0
+    top_p = 1
+
+    kwargs = dict(n=1,
+                  temperature=temperature,
+                  top_p=top_p,
+                  max_tokens=max_response_length,
+                  logprobs=1,
+                  ignore_eos=True)
+
+    if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
+        kwargs['detokenize'] = False
+    sampling_params = SamplingParams(**kwargs)
+
+    tensor_parallel_size = 2
+
+    llm = LLM(model=actor_model,
+              tokenizer=tokenizer,
+              model_hf_config=actor_model_config,
+              tensor_parallel_size=tensor_parallel_size,
+              dtype='bfloat16',
+              gpu_memory_utilization=0.1,
+              load_format='hf')
+
+    print('start generation')
+    input_ids = input_ids.cuda()
+    attention_mask = attention_mask.cuda()
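+    # _pre_process_inputs below strips the left padding from each row so that
+    # vLLM receives only the real prompt token ids (it expects ragged
+    # per-sample token-id lists rather than a padded batch tensor).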
+    batch_size = input_ids.size(0)
+
+    idx_list = []
+    # parse each row from a padded torch.Tensor into a List[int] of prompt token ids
+    for i in range(batch_size):
+        idx_list.append(_pre_process_inputs(tokenizer.pad_token_id, input_ids[i]))
+    outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False)
+    vllm_output = outputs[0].cuda()
+    llm.free_cache_engine()
+    llm = None
+    import gc
+    torch.cuda.empty_cache()
+    gc.collect()
+
+    generation_config = GenerationConfig(do_sample=False)
+    actor_model.cuda()
+    output = actor_model.generate(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        max_new_tokens=max_response_length,
+        # max_length=max_length,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.pad_token_id,
+        generation_config=generation_config,
+        # renormalize_logits=True,
+        output_scores=False,  # this is potentially very large
+        return_dict_in_generate=True,
+        use_cache=False)  # may OOM when use_cache = True
+    seq = output.sequences
+    response = seq[:, max_prompt_length:]
+
+    hf_response_tokens = tokenizer.batch_decode(response)
+    vllm_response_tokens = tokenizer.batch_decode(vllm_output)
+
+    print(f'hf response: {hf_response_tokens}')
+    print(f'vllm response: {vllm_response_tokens}')
+    assert are_lists_similar(hf_response_tokens, vllm_response_tokens), \
+        'HF and vLLM responses differ by more than 10%'
+    print('Check Pass')
+
+
+# if __name__ == "__main__":
+#     test_vllm_with_hf()
diff --git a/tests/sanity/test_import.py b/tests/sanity/test_import.py
new file mode 100644
index 00000000..2adf63a1
--- /dev/null
+++ b/tests/sanity/test_import.py
@@ -0,0 +1,23 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
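+#
+# Sanity check: verl and verl.single_controller should import cleanly and
+# expose a __version__ attribute.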
+ + +def test_import(): + import verl + print(verl.__version__) + + +def test_single_controller_import(): + import verl.single_controller + print(verl.single_controller.__version__) diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 2ca961d5..ad6bab93 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -17,7 +17,7 @@ import os import socket from dataclasses import dataclass -from verl.single_controller.base.decorator import register, Dispatch +from verl.single_controller.base.decorator import register, Dispatch, Execute @dataclass @@ -179,3 +179,8 @@ def rank(self): def execute_with_func_generator(self, func, *args, **kwargs): ret_proto = func(self, *args, **kwargs) return ret_proto + + @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) + def execute_func_rank_zero(self, func, *args, **kwargs): + result = func(*args, **kwargs) + return result \ No newline at end of file diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py index 2f32a913..f146a0ea 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py @@ -196,6 +196,9 @@ def __init__(self, load_config: LoadConfig): raise ValueError(f"Model loader extra config is not supported for " f"load format {load_config.load_format}") + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): if isinstance(actor_model, Dict): return actor_model.items() @@ -241,6 +244,9 @@ def __init__(self, load_config: LoadConfig): raise ValueError(f"Model loader extra config is not supported for " f"load format {load_config.load_format}") + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): # NOTE(shengguangming) Load the weights from the actor model pass diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index b1ef76b5..75c1c9b5 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -335,7 +335,7 @@ def _validate(self): # test_batch = test_batch.to('cuda') # we only do validation on rule-based rm - if test_batch[0].non_tensor_batch['reward_model']['style'] == 'model': + if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model': return {} test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids'])