📃 [ReST-MCTS*]
[GitHub]
[Website]
We develop a reinforced self-training approach, called ReST-MCTS*, that integrates process reward guidance with tree search (MCTS*) to collect higher-quality reasoning traces, together with per-step values, for training both policy and reward models. ReST-MCTS* avoids the per-step manual annotation typically needed to train process rewards by using tree-search-based reinforcement learning: given the oracle final answers, it infers process rewards by estimating the probability that each step leads to the correct answer. These inferred rewards serve two purposes: they act as value targets for further refining the process reward model, and they help select high-quality traces for self-training the policy model.
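To make "estimating the probability that a step leads to the correct answer" concrete, here is a minimal sketch that scores each prefix of a reasoning trace by the fraction of sampled completions that reach the oracle answer. This is a plain Monte Carlo simplification, not the exact value formula used by ReST-MCTS*; the `Rollout` callable and the toy sampler are hypothetical stand-ins for the policy model.

```python
import random
from typing import Callable, List

# Hypothetical interface: given a partial trace (list of steps), sample one
# completion with the policy and return its final answer as a string.
Rollout = Callable[[List[str]], str]


def estimate_step_value(prefix: List[str], rollout: Rollout,
                        oracle_answer: str, n_samples: int = 16) -> float:
    """Monte Carlo estimate of P(correct final answer | prefix).

    Simplified illustration only; it does not reproduce the
    ReST-MCTS* value formulation.
    """
    hits = sum(rollout(prefix) == oracle_answer for _ in range(n_samples))
    return hits / n_samples


def label_trace(trace: List[str], rollout: Rollout,
                oracle_answer: str, n_samples: int = 16) -> List[float]:
    """Assign an inferred value target to every prefix of a trace."""
    return [
        estimate_step_value(trace[: i + 1], rollout, oracle_answer, n_samples)
        for i in range(len(trace))
    ]


if __name__ == "__main__":
    rng = random.Random(0)

    # Toy sampler (hypothetical): once the "bad step" appears the trace never
    # recovers; otherwise each completion is correct with probability 0.8.
    def toy_rollout(prefix: List[str]) -> str:
        if "bad step" in prefix:
            return "wrong"
        return "129" if rng.random() < 0.8 else "wrong"

    trace = ["list the primes", "bad step", "add them"]
    print(label_trace(trace, toy_rollout, "129", n_samples=50))
```

Prefixes that still tend to reach the oracle answer receive values near 1, while a step after which no completion succeeds drops to 0; such labels are the kind of per-step targets a process reward model can be regressed on.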
We summarize the key differences between existing self-improvement methods and our approach. "Train" indicates whether a reward model is trained.
Because Mistral (or Llama) and SciGLM depend on different versions of transformers, create a separate miniconda environment for each and install the corresponding requirements:
For Mistral (or Llama):
pip install -r requirements_mistral.txt
For SciGLM:
pip install -r requirements_sciglm.txt
Note that some Hugging Face models, such as the GLM series, may require a specific version of transformers.
Use Python 3.11 for GLM and Python 3.12 for Mistral or Llama.
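The two environment setups can be captured in a small sketch that picks the requirements file for a model family and checks the running interpreter against the Python version stated above. The family labels are informal, hypothetical keys; the file names and versions come from this README.

```python
import sys
from typing import Dict, Optional, Tuple

# model family -> (requirements file, expected Python major.minor),
# as stated in this README. The keys are informal labels.
ENVIRONMENTS: Dict[str, Tuple[str, Tuple[int, int]]] = {
    "mistral": ("requirements_mistral.txt", (3, 12)),
    "llama": ("requirements_mistral.txt", (3, 12)),
    "glm": ("requirements_sciglm.txt", (3, 11)),
    "sciglm": ("requirements_sciglm.txt", (3, 11)),
}


def check_environment(model_family: str,
                      version: Optional[Tuple[int, int]] = None) -> str:
    """Return the requirements file for a model family.

    Raises RuntimeError if the interpreter's major.minor version
    (or the explicitly passed `version`) differs from the stated one.
    """
    req_file, expected = ENVIRONMENTS[model_family.lower()]
    actual = tuple(version if version is not None else sys.version_info[:2])
    if actual != expected:
        raise RuntimeError(
            f"{model_family} expects Python {expected[0]}.{expected[1]}, "
            f"found {actual[0]}.{actual[1]}"
        )
    return req_file


if __name__ == "__main__":
    print(check_environment("sciglm", (3, 11)))
```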
To run MCTS* search, you need a policy model and a process reward model (value model).
You can download the initial checkpoints and set these models directly by providing their paths in models/model.py, substituting INFERENCE_MODEL_DIR, VALUE_BASE_MODEL_DIR and VALUE_MODEL_STATE_DICT.
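As an illustration, the three settings might be filled in as below. Only the variable names come from models/model.py; the paths and the file extension are hypothetical placeholders, and `missing_model_paths` is a hypothetical helper for catching typos before a long search run.

```python
from pathlib import Path

# Hypothetical placeholder paths: replace them with the locations of the
# checkpoints you downloaded. Only the variable names come from models/model.py.
INFERENCE_MODEL_DIR = "/path/to/policy_model"        # e.g. Llama3-8B-Instruct
VALUE_BASE_MODEL_DIR = "/path/to/value_base_model"   # e.g. Mistral-7B backbone
VALUE_MODEL_STATE_DICT = "/path/to/value_model.pt"   # trained value-model weights


def missing_model_paths(paths: dict) -> list:
    """Return the names of settings whose paths do not exist on disk."""
    return [name for name, p in paths.items() if not Path(p).exists()]


if __name__ == "__main__":
    settings = {
        "INFERENCE_MODEL_DIR": INFERENCE_MODEL_DIR,
        "VALUE_BASE_MODEL_DIR": VALUE_BASE_MODEL_DIR,
        "VALUE_MODEL_STATE_DICT": VALUE_MODEL_STATE_DICT,
    }
    missing = missing_model_paths(settings)
    print("missing:", missing or "none")
```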
INFERENCE_MODEL_DIR is the local path to the policy model, which can be [Llama3-8B-Instruct], [Mistral-7B: MetaMATH], or [SciGLM-6B].
VALUE_BASE_MODEL_DIR is the local path to the backbone of the value model. Because of the differing transformers dependencies, Mistral-7B is used as the value-model backbone when the policy model is [Llama3-8B-Instruct] or [Mistral-7B: MetaMATH]. When the policy model is [SciGLM-6B], we use [ChatGLM3-6B] as the value-model backbone.
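The policy-to-backbone pairing above amounts to a simple lookup. In this sketch the dictionary keys are informal model names (hypothetical labels, not identifiers used by models/model.py).

```python
# Policy model -> value-model backbone, as described in this README.
# Keys are informal labels, not identifiers from models/model.py.
VALUE_BACKBONE = {
    "Llama3-8B-Instruct": "Mistral-7B",
    "Mistral-7B: MetaMATH": "Mistral-7B",
    "SciGLM-6B": "ChatGLM3-6B",
}


def value_backbone_for(policy_model: str) -> str:
    """Return the value-model backbone paired with a policy model."""
    try:
        return VALUE_BACKBONE[policy_model]
    except KeyError:
        raise ValueError(
            f"No documented value backbone for {policy_model!r}"
        ) from None


if __name__ == "__main__":
    print(value_backbone_for("SciGLM-6B"))
```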
To gather value training data for science, we integrate questions from a lean science dataset.
You can download [$D_{V_0}$] and put it in PRM/data to train Mistral-7B as the initial process reward model, which yields VALUE_MODEL_STATE_DICT.
We also provide the training scripts PRM/train_VM_chatglm.py and PRM/train_VM_mistral.py.
The experimental settings are as follows:
For ChatGLM3-6B, the learning rate (lr) is 2e-5, the number of epochs is 2 or 3, and the batch size is 3.
For Mistral, the learning rate (lr) is 3e-6, the number of epochs is 2 or 3, and the batch size is 3.
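The reported settings can be recorded as a small config object, for example to log them alongside a run. The class and field names here are hypothetical and are not the argument names of the PRM training scripts; `optimizer_steps` assumes no gradient accumulation.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ValueModelTrainConfig:
    """Value-model training settings reported in this README.

    Field names are hypothetical, not the PRM scripts' argument names.
    """
    backbone: str
    learning_rate: float
    epochs: int          # the README reports 2 or 3; choose one per run
    batch_size: int = 3

    def optimizer_steps(self, num_examples: int) -> int:
        """Total optimizer steps, assuming no gradient accumulation."""
        return math.ceil(num_examples / self.batch_size) * self.epochs


CHATGLM3_6B = ValueModelTrainConfig("ChatGLM3-6B", learning_rate=2e-5, epochs=2)
MISTRAL_7B = ValueModelTrainConfig("Mistral-7B", learning_rate=3e-6, epochs=2)

if __name__ == "__main__":
    print(MISTRAL_7B, MISTRAL_7B.optimizer_steps(10))
```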
Currently, models/model.py only implements llama, glm and mistral as policy models, and glm and mistral as value models.
To use other models, refer to our implementation and modify the relevant code to implement them.
Once you have implemented the policy and value models, set LOCAL_INFERENCE_IDX and LOCAL_VALUE_IDX in models/model.py to the corresponding model indices.
Before running search for evaluation or generation, make sure your question dataset is in the correct format: a JSON file whose items follow this format:
{
"content": "Calculate the sum of the first 10 prime numbers.",
"answer": "129"
}
The content entry is required and holds the question. The answer entry is optional and is used for evaluation.
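A minimal sketch for checking a question file against this format, plus a simple exact-match score over the optional answer entries. It assumes the file holds a JSON array of such objects; the function names are hypothetical, and whitespace-stripped exact match is a simplification, not the metric used by evaluate.py.

```python
import json
from typing import Dict, List, Optional


def validate_questions(items) -> List[Dict[str, str]]:
    """Check that `items` is a list of objects with a 'content' entry."""
    if not isinstance(items, list):
        raise ValueError("expected a JSON list of question objects")
    for i, item in enumerate(items):
        if not isinstance(item, dict) or "content" not in item:
            raise ValueError(f"item {i} is missing the required 'content' entry")
    return items


def load_questions(path: str) -> List[Dict[str, str]]:
    """Load and validate a question file (assumed to be a JSON array)."""
    with open(path, encoding="utf-8") as f:
        return validate_questions(json.load(f))


def exact_match_accuracy(items, predictions) -> Optional[float]:
    """Fraction of answered items whose prediction matches 'answer'.

    Items without an 'answer' are skipped; returns None if none have one.
    """
    scored = [(item["answer"], pred)
              for item, pred in zip(items, predictions) if "answer" in item]
    if not scored:
        return None
    return sum(str(a).strip() == str(p).strip() for a, p in scored) / len(scored)


if __name__ == "__main__":
    data = validate_questions([
        {"content": "Calculate the sum of the first 10 prime numbers.", "answer": "129"},
    ])
    print(exact_match_accuracy(data, ["129"]))
```

Items without an answer can still be searched; they simply do not contribute to the score.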
The MCTS* search is implemented in MCTS, and MCTS/task.py provides a search interface. To run MCTS* search on a single question, you can use the following script:
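from MCTS.task import MCTS_Task

question = "Calculate the sum of the first 10 prime numbers."
# 'llama' is the propose (policy) method and 'local' the value method;
# see MCTS/args.md for the search parameters
task = MCTS_Task(question, 'llama', 'local', lang='en')
output = task.run()  # run MCTS* search on this question
print(output['solution'])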
To evaluate MCTS* on benchmarks, use evaluate.py with --mode set to "mcts". Specify the benchmark name and the exact file (subset) you want to evaluate. A simple example:
python evaluate.py \
--task_name "scibench" \
--file "thermo" \
--propose_method "gpt" \
--value_method "local" \
--mode "mcts" \
--evaluate "scibench" \
--iteration_limit 50 \
--use_reflection "simple" \
--branch 3
See MCTS/args.md for more details on the search parameters.
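To evaluate several subsets in one go, a small sketch can build the same command programmatically. It reuses only the flags shown in the example above; `build_command` and `run_sweep` are hypothetical helpers, and with `dry_run=True` they only print the commands instead of running them.

```python
import subprocess
from typing import Dict, List, Optional

# Flag values copied from the example command above; no other flags are used.
BASE_ARGS: Dict[str, str] = {
    "--task_name": "scibench",
    "--propose_method": "gpt",
    "--value_method": "local",
    "--mode": "mcts",
    "--evaluate": "scibench",
    "--iteration_limit": "50",
    "--use_reflection": "simple",
    "--branch": "3",
}


def build_command(file_subset: str,
                  overrides: Optional[Dict[str, str]] = None) -> List[str]:
    """Build the argv list for one evaluate.py run on a benchmark subset."""
    args = {**BASE_ARGS, "--file": file_subset, **(overrides or {})}
    cmd = ["python", "evaluate.py"]
    for flag, value in args.items():
        cmd += [flag, value]
    return cmd


def run_sweep(subsets: List[str], dry_run: bool = True) -> None:
    """Print (dry run) or execute evaluate.py once per subset."""
    for subset in subsets:
        cmd = build_command(subset)
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_sweep(["thermo"])
```

Keeping the flags in one dictionary makes it easy to vary a single parameter, such as --branch or --iteration_limit, through `overrides` while leaving the rest of the configuration unchanged.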
Given question set
Download the policy data (positive samples) for training the first policy model (Llama3-8B-Instruct): [Hugging Face]
Download the PRM data (positive and negative samples) for training the first reward model (Mistral-7B: MetaMATH): [Hugging Face]
Download the trained policy model: [Hugging Face]
For our methods:
For Llama3-8B-Instruct and Mistral-7B: MetaMATH, we use the default [MAmmoTH] repo to train and evaluate the policy model.
For SciGLM-6B, we use the default [SciGLM] repo to train and evaluate the policy model.
We also implement Self-Rewarding as a baseline in ./self_train/self_train_dpo.py.
Self-training Results:
Accuracy of Different Verifiers:
Accuracy of Different Searches (we also provide the plot code in figures/plot_math_self_training.py):
If you find our work helpful, please kindly cite our paper:
@article{zhang2024rest,
title={ReST-MCTS*: LLM Self-Training via Process Reward Guided Tree Search},
author={Zhang, Dan and Zhoubian, Sining and Hu, Ziniu and Yue, Yisong and Dong, Yuxiao and Tang, Jie},
journal={arXiv preprint arXiv:2406.03816},
year={2024}
}