From 7c4ba1a86190b57561d8a1b7ed6e3dbf69e23bec Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 1 Feb 2024 09:29:09 -0800
Subject: [PATCH] generalize the system

---
 README.md                                      |  61 +++-
 self_rewarding_lm_pytorch/__init__.py          |  10 +
 self_rewarding_lm_pytorch/dpo.py               |   2 +-
 .../self_rewarding_lm_pytorch.py               | 316 ++++++++++++------
 setup.py                                       |   2 +-
 5 files changed, 263 insertions(+), 128 deletions(-)

diff --git a/README.md b/README.md
index e87dbe5..94db031 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,8 @@ from torch import Tensor

 from self_rewarding_lm_pytorch import (
     SelfRewardingTrainer,
-    create_mock_dataset
+    create_mock_dataset,
+    create_default_paper_config
 )

 from x_transformers import TransformerWrapper, Decoder
@@ -41,7 +42,7 @@ transformer = TransformerWrapper(
     )
 )

-sft_train_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1)))
+sft_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1)))
 prompt_dataset = create_mock_dataset(100, lambda: 'mock prompt')

 def decode_tokens(tokens: Tensor) -> str:
@@ -53,22 +54,21 @@ def encode_str(seq_str: str) -> Tensor:

 trainer = SelfRewardingTrainer(
     transformer,
-    train_sft_dataset = sft_train_dataset,
-    num_spin_cycles = 0,
-    num_preference_pairs = [1, 1],
-    preference_max_seq_len = 64,
-    prompt_dataset = prompt_dataset,
-    tokenizer_encode = encode_str,
+    finetune_configs = create_default_paper_config(
+        train_sft_dataset = sft_dataset,
+        self_reward_prompt_dataset = prompt_dataset,
+        dpo_num_train_steps = 1000
+    ),
     tokenizer_decode = decode_tokens,
+    tokenizer_encode = encode_str,
     accelerate_kwargs = dict(
         cpu = True
-    ),
-    dpo_trainer_kwargs = dict(
-        batch_size = 1
     )
 )

 trainer(overwrite_checkpoints = True)
+
+# checkpoints after each finetuning stage will be saved to ./checkpoints
 ```

SPIN can either be enabled on `SelfRewardingTrainer` with the `spin = True` flag, or trained standalone as shown below
@@ -123,8 +123,7 @@ from self_rewarding_lm_pytorch import RewardConfig
 trainer = SelfRewardingTrainer(
     transformer,
     ...,
-    num_candidate_responses = 4, # in the paper, they try 4 responses, and pick the max and min rewards for forming the preference pairs
-    reward_prompt_config = RewardConfig(
+    self_reward_prompt_config = RewardConfig(
         prompt_template = """
         Pretty please rate the following user prompt and response
         User: {{ prompt }}
         Response: {{ response }}
@@ -140,6 +139,39 @@ trainer = SelfRewardingTrainer(
     )
 )
 ```

+Finally, if you would like to experiment with arbitrary orders of fine-tuning, you will also have that flexibility by passing in a list of `FinetuneConfig` instances as `finetune_configs`
+
+ex. say you want to carry out research on interleaving SPIN, External Rewarding, and Self-Rewarding
+
+This idea originated with Teknium, in a private Discord channel.
+
+```python
+
+# import the configs
+
+from self_rewarding_lm_pytorch import (
+    SFTConfig,
+    SelfRewardDPOConfig,
+    ExternalRewardDPOConfig,
+    SelfPlayConfig,
+)
+
+trainer = SelfRewardingTrainer(
+    finetune_configs = [
+        SFTConfig(...),
+        SelfPlayConfig(...),
+        ExternalRewardDPOConfig(...),
+        SelfRewardDPOConfig(...),
+        SelfPlayConfig(...),
+        SelfRewardDPOConfig(...)
+    ]
+)
+
+trainer()
+
+# checkpoints after each finetuning stage will be saved to ./checkpoints
+```
+
 ## Todo

- [x] generalize the sampling so that it can progress at different positions in the batch, fix all sampling to be batched. 
also allow for left padded sequences, in the case some people have transformers with relative positions that allow for that @@ -149,12 +181,11 @@ trainer = SelfRewardingTrainer( - [x] early stopper - [x] handle break signal if all done on main process - [x] accept eval module, could be either validation loss or something more sophisticated. returns a scalar tensor or single int / float +- [x] any order of sft, spin, self-rewarding dpo, dpo with external reward model - [ ] figure out how best to handle different impl of kv cache, for now just do without - [ ] consider KTO -- [ ] any order of sft, spin, self-rewarding dpo, dpo with external reward model - [ ] allow for a validation function on the rewards (say reward must be integer, float, in between some range etc) -- [ ] create a variant for both self-rewarding and SPIN where there are no iterations. both create their synthesized data live and reference model is updated with an EMA. ## Citation diff --git a/self_rewarding_lm_pytorch/__init__.py b/self_rewarding_lm_pytorch/__init__.py index 061b885..b84716b 100644 --- a/self_rewarding_lm_pytorch/__init__.py +++ b/self_rewarding_lm_pytorch/__init__.py @@ -14,3 +14,13 @@ ) from self_rewarding_lm_pytorch.mocks import create_mock_dataset + +# fine tune configs + +from self_rewarding_lm_pytorch.self_rewarding_lm_pytorch import ( + SFTConfig, + SelfRewardDPOConfig, + ExternalRewardDPOConfig, + SelfPlayConfig, + create_default_paper_config +) diff --git a/self_rewarding_lm_pytorch/dpo.py b/self_rewarding_lm_pytorch/dpo.py index 4486001..dab3c67 100644 --- a/self_rewarding_lm_pytorch/dpo.py +++ b/self_rewarding_lm_pytorch/dpo.py @@ -455,7 +455,7 @@ def forward( iter_dl = cycle(train_dataloader) - pbar = tqdm(desc = 'dpo finetuning', total = self.num_train_steps) + pbar = tqdm(desc = 'dpo fine-tuning', total = self.num_train_steps) set_dropout_(self.model, self.dropout) diff --git a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py index 774b51a..87637f0 100644 --- a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py +++ b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py @@ -3,7 +3,7 @@ from random import randrange from copy import deepcopy from pathlib import Path -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import wraps from textwrap import dedent @@ -209,13 +209,15 @@ def init(self): # config, allowing for different types of reward prompting # colocate with functions for extracting the response and reward -REWARD_PROMPT_CONFIG = dict( +SELF_REWARD_PROMPT_CONFIG = dict( default = RewardConfig( prompt_template = DEFAULT_LLM_AS_JUDGE_PROMPT, reward_regex_template = DEFAULT_REWARD_REGEX_TEMPLATE ) ) +default_is_valid_reward_pair = lambda preferred_reward, unpreferred_reward: (preferred_reward != unpreferred_reward).all() + @beartype def default_pick_paired_rewards_fn(rewards: Tensor): is_nan_mask = torch.isnan(rewards) @@ -614,6 +616,102 @@ def forward(self) -> DPODataset: return DPODataset(**self.dpo_dataset_kwargs) +# fine tuning configs + +class FinetuneConfig: + pass + +@dataclass +class SFTConfig(FinetuneConfig): + train_dataset: Dataset + valid_dataset: Optional[Dataset] = None + dropout: float = 0.1 + trainer_kwargs: dict = field(default_factory = dict) + +@dataclass +class SelfRewardDPOConfig(FinetuneConfig): + prompt_dataset: Dataset + num_generated_preference_pairs: int + dpo_beta: float = 0.1 + max_seq_len: int = 1024 + self_reward_config_keyname: str = 'default' + 
is_valid_reward_pair: Callable[[Tensor, Tensor], bool] = default_is_valid_reward_pair
+    is_picked_pair_reward_fn: Callable[[Tensor], Tensor] = default_pick_paired_rewards_fn
+    dropout: float = 0.1
+    early_stopper_eval_module: Optional[Module] = None
+    num_train_steps: Optional[int] = None
+    num_candidate_responses: int = 4
+    num_sampled_reward_responses: int = 3
+    gen_temperature: float = 0.7
+    gen_filter_fn: Callable = top_p
+    gen_filter_kwargs: dict = field(default_factory = dict)
+    eval_temperature: float = 0.7
+    eval_filter_fn: Callable = top_p
+    eval_filter_kwargs: dict = field(default_factory = dict)
+    trainer_kwargs: dict = field(default_factory = dict)
+    reward_generator_kwargs: dict = field(default_factory = dict)
+
+# same generation / dpo settings, but the preference pairs are scored by an external reward model rather than by the language model judging itself
+
+@dataclass
+class ExternalRewardDPOConfig(FinetuneConfig):
+    reward_model: Module
+    prompt_dataset: Dataset
+    num_generated_preference_pairs: int
+    dpo_beta: float = 0.1
+    max_seq_len: int = 1024
+    dropout: float = 0.1
+    is_valid_reward_pair: Callable[[Tensor, Tensor], bool] = default_is_valid_reward_pair
+    is_picked_pair_reward_fn: Callable[[Tensor], Tensor] = default_pick_paired_rewards_fn
+    early_stopper_eval_module: Optional[Module] = None
+    num_train_steps: Optional[int] = None
+    trainer_kwargs: dict = field(default_factory = dict)
+    reward_generator_kwargs: dict = field(default_factory = dict)
+
+@dataclass
+class SelfPlayConfig(FinetuneConfig):
+    train_dataset: Dataset
+    valid_dataset: Optional[Dataset] = None
+    max_seq_len: int = 1024
+    spin_λ: float = 0.1
+    dropout: float = 0.1
+    gen_temperature: float = 0.7
+    gen_filter_fn: Callable = top_p
+    gen_filter_kwargs: dict = field(default_factory = dict)
+    trainer_kwargs: dict = field(default_factory = dict)
+
+# generated default config for paper
+
+def create_default_paper_config(
+    *,
+    train_sft_dataset: Dataset,
+    self_reward_prompt_dataset: Union[Dataset, Tuple[Dataset, Dataset]],
+    valid_sft_dataset: Optional[Dataset] = None,
+    num_generated_preference_pairs = (3964, 6942),
+    early_stopper_eval_module: Optional[Module] = None,
+    dpo_num_train_steps: Optional[int] = None,
+    sft_config: dict = dict(),
+    self_reward_config: dict = dict()
+
+) -> List[FinetuneConfig]:
+
+    prompt_dataset_iter1, prompt_dataset_iter2 = cast_tuple(self_reward_prompt_dataset, 2, validate = True)
+    num_generated_iter1, num_generated_iter2 = num_generated_preference_pairs
+
+    return [
+        SFTConfig(
+            train_dataset = train_sft_dataset,
+            valid_dataset = valid_sft_dataset,
+            **sft_config
+        ),
+        SelfRewardDPOConfig(
+            num_generated_preference_pairs = num_generated_iter1,
+            prompt_dataset = prompt_dataset_iter1,
+            num_train_steps = dpo_num_train_steps,
+            early_stopper_eval_module = early_stopper_eval_module,
+            **self_reward_config
+        ),
+        SelfRewardDPOConfig(
+            num_generated_preference_pairs = num_generated_iter2,
+            prompt_dataset = prompt_dataset_iter2,
+            num_train_steps = dpo_num_train_steps,
+            early_stopper_eval_module = early_stopper_eval_module,
+            **self_reward_config
+        )
+    ]
+
 # self-rewarding trainer class

 class SelfRewardingTrainer(Module):
@@ -622,38 +720,10 @@ def __init__(
         self,
         model: Module,
         *,
+        finetune_configs: List[FinetuneConfig],
         tokenizer_encode: Callable[[str], TensorType['seq', int]],
         tokenizer_decode: Callable[[TensorType['seq', int]], str],
-        prompt_dataset: Union[Tuple[Dataset, ...], Dataset],
-        train_sft_dataset: Optional[Union[Tuple[Dataset], Dataset]] = None,
-        valid_sft_dataset: Optional[Dataset] = None,
-        initial_sft: bool = True,
-        dpo_beta = 0.1,
-        num_spin_cycles = 0,
-        spin_λ = 0.1,
-        preference_max_seq_len: int = 1024,
-        self_reward_num_iterations = 2,
-        reward_prompt_config: Union[RewardConfig, Dict[str, RewardConfig]] = REWARD_PROMPT_CONFIG,
-        reward_iteration_type = [
-            'default',
-            'default'
-        ],
-        reward_model: Tuple[Optional[Module], ...] 
= (None, None), - is_valid_reward_pair: Optional[Callable[[Tensor, Tensor], bool]] = lambda preferred_reward, unpreferred_reward: (preferred_reward != unpreferred_reward).all(), - pick_paired_rewards: Callable[[Tensor], Tensor] = default_pick_paired_rewards_fn, - num_preference_pairs: List[int] = [ - 3964, - 6942 - ], - reward_generator_kwargs: dict = dict( - num_candidate_responses = 4, - gen_temperature = 0.7, - gen_nucleus_p = 0.9, - eval_temperature = 0.7, - eval_nucleus_p = 0.9 - ), - early_stopper_eval_module: Optional[Module] = None, - dropout: float = 0.1, + self_reward_prompt_config: Union[RewardConfig, Dict[str, RewardConfig]] = SELF_REWARD_PROMPT_CONFIG, pad_id: int = -1, checkpoints_folder: str = './checkpoints', accelerate_kwargs: dict = dict(), @@ -664,14 +734,8 @@ def __init__( ): super().__init__() - if isinstance(reward_prompt_config, RewardConfig): - reward_prompt_config = dict(default = reward_prompt_config) - - assert all([key in reward_prompt_config for key in reward_iteration_type]), f'reward prompt must be one of {reward_prompt_config.keys()}' - - # prompts need to be pre-generated. in paper, it seems to be coming from llama 70B chat - - prompt_dataset = cast_tuple(prompt_dataset, self_reward_num_iterations, validate = True) + if isinstance(self_reward_prompt_config, RewardConfig): + self_reward_prompt_config = dict(default = self_reward_prompt_config) # model and accelerator @@ -681,90 +745,120 @@ def __init__( # all trainers - self.trainers = [] + self.trainers: List[Tuple[str, Callable]] = [] - # sft + # config -> trainers - if initial_sft: - assert exists(train_sft_dataset) + for ind, config in enumerate(finetune_configs): + finetune_stage = ind + 1 - sft_trainer = SFTTrainer( - model, - accelerator = self.accelerator, - dropout = dropout, - train_dataset = train_sft_dataset, - valid_dataset = valid_sft_dataset, - **sft_trainer_kwargs - ) + if isinstance(config, SFTConfig): + trainer = SFTTrainer( + self.model, + accelerator = self.accelerator, + dropout = config.dropout, + train_dataset = config.train_dataset, + valid_dataset = config.valid_dataset, + **config.trainer_kwargs + ) - self.trainers.append(('sft', sft_trainer)) - - # spin - - for _ in range(num_spin_cycles): - assert exists(train_sft_dataset) - - spin_trainer = SPINTrainer( - self.model, - accelerator = self.accelerator, - dropout = dropout, - train_sft_dataset = train_sft_dataset, - valid_sft_dataset = valid_sft_dataset, - max_seq_len = spin_trainer_kwargs.pop('max_seq_len', preference_max_seq_len), - pad_id = pad_id, - spin_kwargs = { - 'λ': spin_λ, - **spin_kwargs - } - ) + self.trainers.append(('sft', trainer)) - self.trainers.append(('spin', spin_trainer)) + elif isinstance(config, SelfRewardDPOConfig): - # external reward model (not in paper, but just allow for it) + assert exists(config.early_stopper_eval_module) ^ exists(config.num_train_steps), 'either a validation module is passed in for early stopping, or a max number of training steps is specified' - reward_model = cast_tuple(reward_model, self_reward_num_iterations, validate = True) + assert config.self_reward_config_keyname in self_reward_prompt_config, f'reward prompt must be one of {self_reward_prompt_config.keys()}' - reward_models = [self.accelerator.prepare(model) if exists(model) else None for model in reward_model] + self_reward_config = self_reward_prompt_config[config.self_reward_config_keyname] - # self-reward related + self_reward_dataset_generator = DPODatasetGenerator( + model = model, + prompt_dataset = 
config.prompt_dataset,
+                    reward_config = self_reward_config,
+                    num_preference_pairs = config.num_generated_preference_pairs,
+                    preference_max_seq_len = config.max_seq_len,
+                    tokenizer_encode = tokenizer_encode,
+                    tokenizer_decode = tokenizer_decode,
+                    is_valid_reward_pair = config.is_valid_reward_pair,
+                    pick_paired_rewards = config.is_picked_pair_reward_fn,
+                    **config.reward_generator_kwargs
+                )

-        trainer = DPOTrainer(
-            dpo = model,
-            accelerator = self.accelerator,
-            dataset_generator = self_reward_dataset_generator,
-            dropout = dropout,
-            early_stopper_eval_module = early_stopper_eval_module,
-            early_stopper_kwargs = dict(
-                early_stop_checkpoint_folder = f'./early-stop-checkpoint.{dpo_iteration}',
-            ),
-            dpo_kwargs = dict(
-                beta = dpo_beta,
-                pad_id = pad_id
-            ),
-            **dpo_trainer_kwargs
-        )
+                trainer = DPOTrainer(
+                    dpo = model,
+                    accelerator = self.accelerator,
+                    dataset_generator = self_reward_dataset_generator,
+                    num_train_steps = config.num_train_steps,
+                    dropout = config.dropout,
+                    early_stopper_eval_module = config.early_stopper_eval_module,
+                    early_stopper_kwargs = dict(
+                        early_stop_checkpoint_folder = f'./early-stop-checkpoint.{finetune_stage}',
+                    ),
+                    dpo_kwargs = dict(
+                        beta = config.dpo_beta,
+                        pad_id = pad_id
+                    ),
+                    **config.trainer_kwargs
+                )

-        self.trainers.append(('dpo', trainer))
+                self.trainers.append(('dpo', trainer))
+
+            elif isinstance(config, ExternalRewardDPOConfig):
+                # same dpo flow, but candidate responses are scored by the passed-in reward model (not in the paper, but made available)
+
+                reward_model = self.accelerator.prepare(config.reward_model)
+
+                self_reward_dataset_generator = DPODatasetGenerator(
+                    model = model,
+                    prompt_dataset = config.prompt_dataset,
+                    reward_model = reward_model,
+                    num_preference_pairs = config.num_generated_preference_pairs,
+                    preference_max_seq_len = config.max_seq_len,
+                    tokenizer_encode = tokenizer_encode,
+                    tokenizer_decode = tokenizer_decode,
+                    is_valid_reward_pair = config.is_valid_reward_pair,
+                    pick_paired_rewards = config.is_picked_pair_reward_fn,
+                    **config.reward_generator_kwargs
+                )
+
+                trainer = DPOTrainer(
+                    dpo = model,
+                    accelerator = self.accelerator,
+                    dataset_generator = self_reward_dataset_generator,
+                    num_train_steps = config.num_train_steps,
+                    dropout = config.dropout,
+                    early_stopper_eval_module = config.early_stopper_eval_module,
+                    early_stopper_kwargs = dict(
+                        early_stop_checkpoint_folder = f'./early-stop-checkpoint.{finetune_stage}',
+                    ),
+                    dpo_kwargs = dict(
+                        beta = config.dpo_beta,
+                        pad_id = pad_id
+                    ),
+                    **config.trainer_kwargs
+                )
+
+                self.trainers.append(('dpo', trainer))
+
+            elif isinstance(config, SelfPlayConfig):
+                trainer = SPINTrainer(
+                    self.model,
+                    accelerator = self.accelerator,
+                    dropout = config.dropout,
+                    train_sft_dataset = config.train_dataset,
+                    valid_sft_dataset = config.valid_dataset,
+                    max_seq_len = config.max_seq_len,
+                    pad_id = pad_id,
+                    spin_kwargs = {
+                        'λ': config.spin_λ,
                    },
+                    **config.trainer_kwargs
                )

+                self.trainers.append(('spin', trainer))
+
+            else:
+                raise ValueError(f'unhandled finetune config type {type(config).__name__} - write the corresponding trainer construction logic for your custom finetune config')

-            self.trainers.append(('dpo', trainer))
+        assert len(self.trainers) == len(finetune_configs)

        # checkpoints folder

@@ -806,6 +900,6 @@ def forward(
            finetuning_stage = ind + 1
            trainer()

-            self.save(f'{finetuning_stage}.{trainer_type}.ckpt.pt')
+            self.save(f'{finetuning_stage}.{trainer_type}.ckpt.pt', overwrite = overwrite_checkpoints)

        self.print(f'self-reward training done')

diff --git a/setup.py b/setup.py
index 36e59f1..0b6d25e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.2.0',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
  author = 'Phil Wang',
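
For reference, below is one possible end-to-end sketch of the generalized `finetune_configs` API this patch introduces, concretizing the `SFTConfig(...)` / `SelfPlayConfig(...)` placeholders from the README snippet above. The transformer hyperparameters and the toy tokenizer helpers are illustrative assumptions mirroring the README's mock setup, not part of the patch itself.

```python
import torch
from torch import Tensor

from x_transformers import TransformerWrapper, Decoder

from self_rewarding_lm_pytorch import (
    SelfRewardingTrainer,
    SFTConfig,
    SelfPlayConfig,
    SelfRewardDPOConfig,
    create_mock_dataset
)

# toy decoder-only transformer (hyperparameters chosen arbitrarily for illustration)

transformer = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 256, depth = 1, heads = 8)
)

# mock sft pairs and self-reward prompts, as in the README example

sft_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1)))
prompt_dataset = create_mock_dataset(100, lambda: 'mock prompt')

# toy byte-level tokenizer helpers (illustrative stand-ins)

def decode_tokens(tokens: Tensor) -> str:
    return ''.join(chr(max(32, int(t))) for t in tokens)

def encode_str(seq_str: str) -> Tensor:
    return torch.tensor([ord(c) for c in seq_str])

# interleave sft -> spin -> self-rewarding dpo by passing the configs explicitly

trainer = SelfRewardingTrainer(
    transformer,
    finetune_configs = [
        SFTConfig(train_dataset = sft_dataset),
        SelfPlayConfig(train_dataset = sft_dataset, max_seq_len = 64),
        SelfRewardDPOConfig(
            prompt_dataset = prompt_dataset,
            num_generated_preference_pairs = 4,
            num_train_steps = 10,
            max_seq_len = 64
        )
    ],
    tokenizer_encode = encode_str,
    tokenizer_decode = decode_tokens,
    accelerate_kwargs = dict(cpu = True)
)

trainer(overwrite_checkpoints = True)

# a checkpoint is saved to ./checkpoints after each finetuning stage
```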