ALT for LLM Alignment

This repository is the official implementation of Towards Aligning Language Models with Textual Feedback.

[Figure: ALT diagram]

Getting Started

  1. Clone the repository and navigate to the project directory:
   git clone git@github.com:sauc-abadal/ALT.git
   cd ALT
  2. We suggest using conda to set up the environments. You first need to replace the prefix field in tox.yml, gptj_training.yml, and sample.yml with your home path (see the example below). With conda installed, create the environments with:
conda env create -f tox.yml
conda env create -f gptj_training.yml
conda env create -f sample.yml
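
The prefix field sits at the bottom of each environment YAML file and tells conda where to create the environment. The line below is only an illustration; the actual environment names in the YAML files and the conda install location on your machine may differ:

prefix: /home/<your-username>/anaconda3/envs/tox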

Where to find What

The repository is organized as follows. alt contains the building blocks of our ALT method, namely policy definition, vLLM sampling, datapool, training dataset and data collators, and trainer classes. tasks has three separate directories for the tasks we explore in our paper, i.e., toxicity, summarization, and dialogue. The toxicity task has its own separate modules since we reused and modified most of the code from Quark.

  • In alt/models we have the policy.py module, which defines the sample() method (for generating completions) and the forward_pass() method (for obtaining the logits, logprobs, etc. of generated text) of our LLM, and the reward.py module for the summarization task, where we define the RM trained by CarperAI on the preference dataset collected in Stiennon et al. The preference TL;DR dataset employed by CarperAI can be found here. Their RM achieves an accuracy of 75.74% on their test set (5,000 randomly drawn samples). We leverage this pre-trained RM for implementing the Quark baseline and our ALTRM model, which relies on scores from a Reward Model. In addition, we extended the RM with a get_reward() method for more efficient reward inference.

  • Note that, even though we define a sample() method in policy.py, we leverage the vLLM library for faster LLM inference during the sampling stage of our data-collection phase. The script for vLLM sampling can be found in alt/vllm_sampling.py.

  • In alt/trainer there are the two trainer modules we defined for training Quark and ALT. Each module contains two trainer classes, supporting training either with or without a KL-divergence penalty with respect to a reference policy (a minimal sketch of the KL-penalty term is included at the end of this list).

  • In alt you can find the data_pool.py module with 3 distinct classes, namely ALT_RM_DataPool (used both for Quark and ALTRM), ALT_LMC_DataPool, and ALT_LMU_DataPool. These classes are in charge of storing the sampled generations, along with the prompts and either the collected textual feedback (for ALTLMC and ALTLMU) or the rewards and associated quantile tokens (for Quark and ALTRM). The dataset employed during the training phase of our approach is drawn from this datapool with rejection sampling, both to control the generations' length and to balance the different quantiles (for Quark and ALTRM), feedback categories (for ALTLMC), and feedback scores (for ALTLMU).

  • In the alt/training_dataset_and_collator.py module there are the Dataset and DataCollator classes employed during training for both our ALT approach and our Quark-like baseline. These classes are in charge of prepending either the textual feedback or the reward quantile token to the input prompt, for maximum-likelihood training on triplets of (feedback/quantile token, prompt, generation); the conditioning idea is sketched right after this list.
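
To illustrate the conditioning idea shared by Quark, ALTRM, and ALTLMU, here is a minimal Python sketch that prepends a conditioning string (a reward quantile token or textual feedback) to the prompt before tokenization. It is a simplified stand-in for the actual Dataset and DataCollator classes; the helper name and the example strings are illustrative only.

# Minimal sketch of conditional-SFT input construction (illustrative only).
def build_conditional_input(condition: str, prompt: str, generation: str) -> str:
    """Prepend the conditioning text (a reward-quantile token or textual
    feedback) to the prompt, then append the sampled generation as the target."""
    return f"{condition}{prompt}{generation}"

# Hypothetical triplets drawn from the datapool.
quark_example = build_conditional_input(
    condition="QUANTILE_TOKEN_0",                      # reward-quantile token (Quark)
    prompt="SUBREDDIT: r/... TITLE: ... POST: ... TL;DR:",
    generation=" A short, faithful summary of the post.",
)
alt_lmu_example = build_conditional_input(
    condition="feedback: Excellent summary. input: ",  # textual feedback (ALT_LMU)
    prompt="SUBREDDIT: r/... TITLE: ... POST: ... TL;DR:",
    generation=" A short, faithful summary of the post.",
)
print(quark_example)
print(alt_lmu_example)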

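For the *_train_KL.py trainers, the idea is to regularize the learned policy towards the reference (SFT) policy. The snippet below is a generic sketch of a KL-penalized loss term, with assumed tensor names and a hypothetical beta coefficient; it is not the repository's trainer code.

import torch

def kl_penalized_loss(nll_loss: torch.Tensor,
                      policy_logprobs: torch.Tensor,
                      ref_logprobs: torch.Tensor,
                      mask: torch.Tensor,
                      beta: float = 0.05) -> torch.Tensor:
    """Add a per-token KL estimate between the policy and the frozen
    reference policy to the maximum-likelihood (NLL) training loss."""
    # policy_logprobs / ref_logprobs: per-token log-probabilities of the
    # generated tokens; their difference is a simple per-token KL estimate.
    kl = (policy_logprobs - ref_logprobs) * mask
    kl_term = kl.sum() / mask.sum().clamp(min=1)
    return nll_loss + beta * kl_term
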
Usage

TL;DR Summarization example

In tasks/summarization you will find scripts specific to the summarization task. We tackle the same task as the one outlined in the DPO paper and aim to generate better summaries of Reddit TL;DR posts. We employ the TLDR dataset from CarperAI hosted on HuggingFace, whose prompts are formatted as "SUBREDDIT: r/... TITLE: ... POST: ... TL;DR:" and whose labels are the human-written summaries. The dataset contains train, valid, and test splits, amounting to 117k, 6.45k, and 6.55k samples respectively. We start training from the SFT model (GPT-J) trained by CarperAI on the TL;DR summarization dataset, which adapts the pre-trained GPT-J model to the summarization downstream task. During training, at every iteration, we draw at random (with replacement) 2048 prompts from the train split and sample multiple generations per prompt. Then, we provide feedback (reward-based or LLM-based) on those generations and perform conditional supervised fine-tuning. In tasks/summarization/bash_scripts there is an orchestrator script in charge of launching the different bash scripts that make up the full training and validation pipeline for each iteration. In tasks/summarization/configs there are the config files that you will need to modify before running any script.
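
To make the iteration structure concrete, here is a high-level Python sketch of one data-collection-plus-training iteration as described above; the object and method names are illustrative placeholders, not the repository's actual API.

# Illustrative pseudocode of one ALT iteration (placeholder names, not the repo's API).
def run_iteration(policy, feedback_provider, datapool, train_prompts, num_generations):
    # 1) Draw 2048 prompts at random (with replacement) from the train split.
    prompts = train_prompts.sample(n=2048, replace=True)

    # 2) Sample multiple generations per prompt (done with vLLM in this repository).
    generations = policy.sample(prompts, num_return_sequences=num_generations)

    # 3) Collect feedback: reward-based (RM scores) or LLM-based textual feedback.
    feedback = feedback_provider.annotate(prompts, generations)

    # 4) Store everything in the datapool and draw a balanced training set
    #    via rejection sampling.
    datapool.add(prompts, generations, feedback)
    train_dataset = datapool.draw_balanced_dataset()

    # 5) Conditional supervised fine-tuning on (feedback, prompt, generation) triplets.
    policy.train_on(train_dataset)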

  • Sampling is done with vllm_sampling.py. The script takes in a JSONL file with a single key 'prompt', and outputs a new JSONL file with the key 'generations' added for every prompt. The script handles sampling for every method in the same way, but you must manually prepend the conditioning feedback for each method accordingly. That is, prompts for iteration 1 should be unconditioned, while prompts from iteration 2 onwards should be conditioned on the exemplar feedback, namely 'QUANTILE_TOKEN_0{prompt}' for Quark, 'Excellent. input: {prompt}' for ALTRM, or 'feedback: {3.0-score exemplar feedback} input: {prompt}' for ALTLMU (a sketch of the JSONL record format at each stage is shown at the end of this section). An example usage for launching sampling is:
python alt/vllm_sampling.py \
    --input_file "$input_prompts_file_train" \
    --output_dir "$output_dir" \
    --split_number "0" \
    --total_splits "1" \
    --model_path "$model_path" \
    --tokenizer_path "$tokenizer_path" \
    --data_split "$data_split_train" \
    --num_generations "$num_generations_train" \
    --temperature "$temperature_train" \
    --top_p "${top_p_train}" \
    --max_new_tokens "${max_new_tokens_train}"
  • Reward-based feedback for Quark or ALTRM is obtained with reward.py. The script takes in a JSONL file with keys 'prompt' and 'generations', and outputs a new JSONL file with the key 'rewards' added for every prompt. An example usage for getting reward-based feedback is:
python tasks/summarization/reward.py \
    --config "$config" \
    --input_sampling_file "$input_sampling_file_train" \
    --output_dir "$output_dir" \
    --split_number "0" \
    --total_splits "1" \
    --num_generations "$num_generations_train" \
    --ALT
  • LLM-based feedback for ALTLMU is obtained with ALT_LMU_feedback.py. The script takes in a JSONL file with keys 'prompt' and 'generations', and outputs a new JSONL file with keys 'analysis', 'feedbacks', and 'scores' added for every prompt. 'feedbacks' are used for conditional training, while 'scores' are used both for rejection sampling to balance the different feedbacks and for obtaining the high-scoring exemplar feedbacks to condition on during subsequent sampling iterations. An example usage for getting LLM-based feedback is:
python tasks/summarization/ALT_LMU_feedback.py \
    --config "$config" \
    --input_sampling_file "$input_sampling_file_train" \
    --output_dir "$output_dir" \
    --split_number "0" \
    --total_splits "1" \
    --num_generations "$num_generations_train" \
    --ALT
  • Training is launched by calling '{method}_train_KL.py' or '{method}_train_noKL.py', where method is one of [QUARK, ALT_RM, ALT_LMU], for training with or without the KL-divergence penalty with respect to the reference policy, respectively. Our code leverages the DeepSpeed integration within Accelerate to handle distributed training and to take advantage of features such as the ZeRO optimizer and CPU offloading for better scale and speed. An example usage for launching ALT_RM_train_noKL.py is:
accelerate launch --config_file "$accelerate_config" tasks/summarization/ALT_RM_train_noKL.py \
    --config "$config" \
    --iteration "$iteration" \
    --input_sampling_file "$input_sampling_file_train" \
    --model_path "$model_path" \
    --ds_optimizer \
    --ds_scheduler
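
For reference, the sketch below shows how a single JSONL record could evolve through the pipeline above: sampling adds 'generations', reward annotation adds 'rewards', and LLM-based feedback adds 'analysis', 'feedbacks', and 'scores'. The key names follow the descriptions above, but the concrete values are made up for illustration.

# Illustrative JSONL records at each stage of the pipeline (values are made up).
after_sampling = {
    "prompt": "SUBREDDIT: r/... TITLE: ... POST: ... TL;DR:",
    "generations": ["Summary candidate 1.", "Summary candidate 2."],
}

after_reward_feedback = {   # output of tasks/summarization/reward.py (Quark / ALT_RM)
    **after_sampling,
    "rewards": [1.37, -0.52],
}

after_llm_feedback = {      # output of tasks/summarization/ALT_LMU_feedback.py
    **after_sampling,
    "analysis": ["...", "..."],
    "feedbacks": ["Captures the main point but omits key context.", "Excellent summary."],
    "scores": [2.0, 3.0],
}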

Citation

tbd
