generalize the system
lucidrains committed Feb 1, 2024
1 parent 465de3e commit 7c4ba1a
Showing 5 changed files with 263 additions and 128 deletions.
61 changes: 46 additions & 15 deletions README.md
@@ -26,7 +26,8 @@ from torch import Tensor

from self_rewarding_lm_pytorch import (
SelfRewardingTrainer,
create_mock_dataset
create_mock_dataset,
create_default_paper_config
)

from x_transformers import TransformerWrapper, Decoder
@@ -41,7 +41,7 @@ transformer = TransformerWrapper(
)
)

sft_train_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1)))
sft_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1)))
prompt_dataset = create_mock_dataset(100, lambda: 'mock prompt')

def decode_tokens(tokens: Tensor) -> str:
@@ -53,22 +54,21 @@ def encode_str(seq_str: str) -> Tensor:

trainer = SelfRewardingTrainer(
transformer,
train_sft_dataset = sft_train_dataset,
num_spin_cycles = 0,
num_preference_pairs = [1, 1],
preference_max_seq_len = 64,
prompt_dataset = prompt_dataset,
tokenizer_encode = encode_str,
finetune_configs = create_default_paper_config(
train_sft_dataset = sft_dataset,
self_reward_prompt_dataset = prompt_dataset,
dpo_num_train_steps = 1000
),
tokenizer_decode = decode_tokens,
tokenizer_encode = encode_str,
accelerate_kwargs = dict(
cpu = True
),
dpo_trainer_kwargs = dict(
batch_size = 1
)
)

trainer(overwrite_checkpoints = True)

# checkpoints after each finetuning stage will be saved to ./checkpoints
```

SPIN can either be enabled on `SelfRewardingTrainer` with the `spin = True` flag, or trained standalone as shown below
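A minimal sketch of the flag route, reusing the setup from the first example above; treating `spin = True` as still accepted alongside the new `finetune_configs` argument is an assumption here, not something this diff shows:

```python
# flag route: same trainer setup as above, with a SPIN stage enabled via the flag (assumption)
trainer = SelfRewardingTrainer(
    transformer,
    finetune_configs = create_default_paper_config(
        train_sft_dataset = sft_dataset,
        self_reward_prompt_dataset = prompt_dataset,
        dpo_num_train_steps = 1000
    ),
    tokenizer_decode = decode_tokens,
    tokenizer_encode = encode_str,
    spin = True  # assumed to add self-play fine-tuning on top of the default paper configuration
)

trainer(overwrite_checkpoints = True)
```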
@@ -123,8 +123,7 @@ from self_rewarding_lm_pytorch import RewardConfig
trainer = SelfRewardingTrainer(
transformer,
...,
num_candidate_responses = 4, # in the paper, they try 4 responses, and pick the max and min rewards for forming the preference pairs
reward_prompt_config = RewardConfig(
self_reward_prompt_config = RewardConfig(
prompt_template = """
Pretty please rate the following user prompt and response
User: {{ prompt }}
@@ -140,6 +139,39 @@
)
```
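As an aside on the `num_candidate_responses` comment above, here is a tiny illustration of how a preference pair can be formed from the highest- and lowest-rewarded of several candidate responses; the variable names are hypothetical, not the trainer's internals:

```python
# hypothetical illustration: pick max- and min-reward candidates as a preference pair for DPO
candidates = ['response a', 'response b', 'response c', 'response d']
rewards = [3.0, 7.5, 1.0, 5.5]  # one self-assigned reward per candidate

best_index = max(range(len(rewards)), key = lambda i: rewards[i])
worst_index = min(range(len(rewards)), key = lambda i: rewards[i])

preference_pair = dict(
    chosen = candidates[best_index],    # highest reward becomes the preferred response
    rejected = candidates[worst_index]  # lowest reward becomes the rejected response
)
```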

Finally, if you would like to experiment with arbitrary orders of fine-tuning, you also have that flexibility: simply pass `FinetuneConfig` instances into `finetune_configs` as a list

For example, say you want to carry out research on interleaving SPIN, External Rewarding, and Self-Rewarding

This idea originated with <a href="https://github.com/teknium1">Teknium</a> in a private Discord channel.

```python

# import the configs

from self_rewarding_lm_pytorch import (
SFTConfig,
SelfRewardDPOConfig,
ExternalRewardDPOConfig,
SelfPlayConfig,
)

trainer = SelfRewardingTrainer(
finetune_configs = [
SFTConfig(...),
SelfPlayConfig(...),
ExternalRewardDPOConfig(...),
SelfRewardDPOConfig(...),
SelfPlayConfig(...),
SelfRewardDPOConfig(...)
]
)

trainer()

# checkpoints after each finetuning stage will be saved to ./checkpoints
```

## Todo

- [x] generalize the sampling so that it can progress at different positions in the batch, fix all sampling to be batched. also allow for left-padded sequences, in case some transformers use relative positions that allow for that
@@ -149,12 +181,11 @@ trainer = SelfRewardingTrainer(
- [x] early stopper
- [x] handle break signal if all done on main process
- [x] accept eval module, could be either validation loss or something more sophisticated. returns a scalar tensor or single int / float
- [x] any order of sft, spin, self-rewarding dpo, dpo with external reward model

- [ ] figure out how best to handle different impl of kv cache, for now just do without
- [ ] consider KTO
- [ ] any order of sft, spin, self-rewarding dpo, dpo with external reward model
- [ ] allow for a validation function on the rewards (say the reward must be an integer or float, within some range, etc.)
- [ ] create a variant of both self-rewarding and SPIN where there are no iterations; both create their synthesized data live and the reference model is updated with an EMA (see the sketch after this list)
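A minimal sketch of the EMA idea in that last item, assuming the reference model is a frozen copy of the policy that is nudged toward it after every optimizer step; `ema_update_` and the decay value are illustrative, not part of this repository:

```python
import copy

import torch
from torch import nn

@torch.no_grad()
def ema_update_(reference_model: nn.Module, policy_model: nn.Module, decay: float = 0.999):
    # ref <- decay * ref + (1 - decay) * policy, applied parameter-wise
    for ref_param, param in zip(reference_model.parameters(), policy_model.parameters()):
        ref_param.lerp_(param, 1. - decay)

# toy usage: the reference model starts as a frozen copy of the policy,
# then slowly tracks it instead of being reset at iteration boundaries

policy_model = nn.Linear(8, 8)                                       # stand-in for the transformer policy
reference_model = copy.deepcopy(policy_model).requires_grad_(False)  # frozen reference

# ... after each optimizer step on the policy ...
ema_update_(reference_model, policy_model)
```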

## Citation

10 changes: 10 additions & 0 deletions self_rewarding_lm_pytorch/__init__.py
@@ -14,3 +14,13 @@
)

from self_rewarding_lm_pytorch.mocks import create_mock_dataset

# fine tune configs

from self_rewarding_lm_pytorch.self_rewarding_lm_pytorch import (
SFTConfig,
SelfRewardDPOConfig,
ExternalRewardDPOConfig,
SelfPlayConfig,
create_default_paper_config
)
2 changes: 1 addition & 1 deletion self_rewarding_lm_pytorch/dpo.py
Expand Up @@ -455,7 +455,7 @@ def forward(

iter_dl = cycle(train_dataloader)

pbar = tqdm(desc = 'dpo finetuning', total = self.num_train_steps)
pbar = tqdm(desc = 'dpo fine-tuning', total = self.num_train_steps)

set_dropout_(self.model, self.dropout)
