From Drafts to Answers: Unlocking LLM Potential via Aggregation Fine-Tuning

*(Figure: framework of aggregation fine-tuning and propose-and-aggregate inference.)*

This repository contains the official PyTorch implementation of the paper: From Drafts to Answers: Unlocking LLM Potential via Aggregation Fine-Tuning. In this work, we introduce Aggregation Fine-Tuning (AFT), a supervised fine-tuning paradigm in which the model learns to synthesize multiple draft responses, referred to as proposals, into a single, refined answer, termed the aggregation. An AFT model, fine-tuned from Llama3.1-8B-Base with only 64K training examples, achieves a 41.3% LC win rate on AlpacaEval 2, surpassing significantly larger LLMs such as Llama3.1-405B-Instruct and GPT-4.
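
To make the paradigm concrete, here is a minimal sketch of what a single AFT training instance could look like. The field names and contents are illustrative assumptions, not the released data schema:

```python
# Illustrative AFT training instance (field names are assumptions, not the
# released schema). The model is fine-tuned to map a query plus several draft
# proposals to one refined, aggregated answer.
aft_example = {
    "query": "Why is the sky blue?",
    "proposals": [  # draft responses, either off-policy or on-policy
        "The sky is blue because of Rayleigh scattering ...",
        "Sunlight contains all visible wavelengths ...",
        "Shorter (blue) wavelengths scatter more strongly ...",
    ],
    "aggregation": "The sky appears blue because shorter wavelengths of "
                   "sunlight are scattered more strongly by air molecules ...",
}
```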

🔔 News

🚀 Quick Start

Install

To install the inference framework, follow the steps below:

  1. Create and activate a new environment:

    conda create -n AFT python=3.11
    conda activate AFT
  2. Install PyTorch based on your device configuration:

    Our device uses CUDA 12.1, so install PyTorch with the following command:

    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 
  3. Clone the repository, navigate to the directory, and install our inference framework:

    git clone [email protected]:Linzwcs/AFT.git
    cd AFT
    pip install -e .

Inference

We provide sample inference code for propose-and-aggregate in inference.py. You can run it with the following command:

  python inference.py \
        --config ./configs/default.yaml \
        --input_file ./data/sample.jsonl \
        --output_file ./output/output.jsonl \
        --batch_size 16
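
The input file is JSON Lines, one record per line; the authoritative schema is whatever ./data/sample.jsonl uses. As a rough, hypothetical illustration (the "instruction" key below is an assumption, not a confirmed field name):

```python
import json

# Hypothetical input format: check ./data/sample.jsonl for the actual schema;
# the "instruction" key here is an illustrative assumption.
records = [
    {"instruction": "Summarize the plot of Hamlet in two sentences."},
    {"instruction": "Write a haiku about autumn."},
]
with open("data/my_input.jsonl", "w", encoding="utf-8") as f:
    for r in records:
        f.write(json.dumps(r) + "\n")
```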

The detailed meanings of the keys in the config file are illustrated below:

  model_name: <path to AFT model> # path to AFT model

  # The generation_params are used to build SamplingParams instances
  # that are passed to vllm.chat().
  generation_params:
    temperature: 0.7
    top_p: 0.95
    max_tokens: 4096
    n: 5  # number of proposals per step
    final_layer_temperature: 0.7
    final_layer_top_p: 1

  vllm_seed: 2024 # vllm backend seed


  num_aggregation: 2  # Number of aggregation layers
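
For intuition, here is a hedged sketch of how these keys could drive a propose-and-aggregate loop. This is not the repository's implementation (see inference.py for that), and the aggregation prompt template is a placeholder assumption:

```python
from vllm import LLM, SamplingParams

# Hedged sketch of propose-and-aggregate; the real pipeline lives in
# inference.py, and the aggregation prompt below is a placeholder assumption.
llm = LLM(model="<path to AFT model>", seed=2024)  # vllm_seed

proposal_params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=4096, n=5)
final_params = SamplingParams(temperature=0.7, top_p=1.0, max_tokens=4096)  # final_layer_*

def propose_and_aggregate(query: str, num_aggregation: int = 2) -> str:
    # Proposal step: sample n drafts for the query.
    outputs = llm.chat([{"role": "user", "content": query}], proposal_params)
    drafts = [o.text for o in outputs[0].outputs]
    # Aggregation layers: each layer synthesizes the current drafts into new
    # candidates; the last layer samples one answer with final-layer settings.
    for layer in range(num_aggregation):
        params = final_params if layer == num_aggregation - 1 else proposal_params
        prompt = query + "\n\nDraft responses:\n" + "\n---\n".join(drafts)
        outputs = llm.chat([{"role": "user", "content": prompt}], params)
        drafts = [o.text for o in outputs[0].outputs]
    return drafts[0]
```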

Datasets

We consider two types of training datasets, differing in the proposal type: off-policy proposals and on-policy proposals. Off-policy proposals are derived from existing preference alignment datasets, such as UltraFeedback, where each query is accompanied by multiple responses generated by models different from the one being fine-tuned. On-policy proposals are obtained by leveraging in-context learning (ICL) with demonstrations, prompting the base LLM to generate multiple responses for a given query (see the sketch after this paragraph). These two proposal types, combined with the two base models, yield our four AFT models.
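
As a rough sketch of the on-policy case (the few-shot prompt below is an illustrative assumption, not the paper's exact template), the base model is prompted with demonstrations and sampled several times per query:

```python
from vllm import LLM, SamplingParams

# Hedged sketch of on-policy proposal generation via in-context learning;
# the demonstration template is an illustrative assumption.
base_llm = LLM(model="meta-llama/Llama-3.1-8B")  # base (non-instruct) model
icl_template = (
    "Q: What causes ocean tides?\n"
    "A: Tides are driven mainly by the Moon's gravitational pull ...\n\n"
    "Q: {query}\nA:"
)
params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=1024, n=5)

def on_policy_proposals(query: str) -> list[str]:
    outputs = base_llm.generate([icl_template.format(query=query)], params)
    return [o.text for o in outputs[0].outputs]
```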

Datasets to be Released:

  1. Linzwcs/AFT-off-policy
  2. Linzwcs/AFT-on-policy-llama
  3. Linzwcs/AFT-on-policy-mistral

Our datasets will be released soon.

Models

We build our models on two model families, Llama-3.1-8B and Mistral-7B-v0.1, and train each on both on-policy and off-policy data, yielding four AFT models in total, as listed below:

Released Models:

  1. Linzwcs/Llama-AFT-Off-Policy
  2. Linzwcs/Llama-AFT-On-Policy
  3. Linzwcs/Mistral-AFT-Off-Policy
  4. Linzwcs/Mistral-AFT-On-Policy

You can use these models by setting model_name in the configuration file to the corresponding model name. Please refer to our paper for training details.
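
For example, to run propose-and-aggregate with the on-policy Llama model, point the config (e.g., ./configs/default.yaml) at the released checkpoint:

```yaml
model_name: Linzwcs/Llama-AFT-On-Policy
```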

🔍 Benchmark Performance

We evaluate our models on two benchmark datasets: MT-Bench and AlpacaEval 2, and the results are presented below:

| Model | MT-Bench 1st turn | MT-Bench 2nd turn | MT-Bench Avg. | AlpacaEval LC (%) | AlpacaEval WR (%) |
|---|---|---|---|---|---|
| **Mistral-7B-v0.1-Base** | | | | | |
| SFT | 6.6 | 6.1 | 6.4 | 6.7 | 6.1 |
| AFT-off-policy | 7.7 | 6.3 | 7.0 | 19.8 | 20.0 |
| &nbsp;&nbsp;w/ Agg. | 8.0 | 7.0 | 7.5 | 33.8 | 47.8 |
| AFT-on-policy | 7.5 | 6.4 | 6.9 | 23.4 | 24.9 |
| &nbsp;&nbsp;w/ Agg. | 8.3 | 7.0 | 7.6 | 30.7 | 48.4 |
| **Llama3.1-8B-Base** | | | | | |
| SFT | 7.3 | 6.2 | 6.8 | 8.0 | 7.3 |
| AFT-off-policy | 7.7 | 6.9 | 7.3 | 20.3 | 19.6 |
| &nbsp;&nbsp;w/ Agg. | 8.3 | 7.6 | 7.9 | 40.3 | 47.8 |
| AFT-on-policy | 7.9 | 6.9 | 7.4 | 21.5 | 21.8 |
| &nbsp;&nbsp;w/ Agg. | 8.5 | 7.6 | 8.1 | 41.3 | 51.3 |

Acknowledgement

This project is mainly motivated and supported by several existing works:

  1. Mixture-of-Agents Enhances Large Language Model Capabilities: we draw inspiration from MoA to propose aggregation learning.

  2. vLLM: we construct the generation pipeline based on vLLM.

📝 Citation

@misc{li2025draftsanswersunlockingllm,
      title={From Drafts to Answers: Unlocking LLM Potential via Aggregation Fine-Tuning}, 
      author={Yafu Li and Zhilin Wang and Tingchen Fu and Ganqu Cui and Sen Yang and Yu Cheng},
      year={2025},
      eprint={2501.11877},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2501.11877}, 
}

License

The code and model weights are released under the terms in LICENSE.
