Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CGPO] Mixture of judges #2159

Open
wants to merge 38 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
013aae4
base judge
gaetanlop Oct 3, 2024
0ea5a48
adding mixture of judges
gaetanlop Oct 3, 2024
517cfb0
update doc
gaetanlop Oct 3, 2024
9e5ed12
update doc
gaetanlop Oct 3, 2024
3406e53
formatting
gaetanlop Oct 3, 2024
568d2b9
fix small typo in doc
gaetanlop Oct 3, 2024
466292e
fix randomcontraintjudge
gaetanlop Oct 3, 2024
a3d90df
Merge branch 'main' into cgpo_mixture_of_judges
qgallouedec Oct 4, 2024
3f0b8b0
replace arxiv by hf papers
gaetanlop Oct 4, 2024
8995ab4
formatting
gaetanlop Oct 4, 2024
896259e
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 4, 2024
ef1feb0
fix naming in __init__
gaetanlop Oct 4, 2024
3da4a06
run precommi
gaetanlop Oct 4, 2024
765768b
adding gold answers to judges
gaetanlop Oct 7, 2024
8aaaaa1
cgpo llm judges
gaetanlop Oct 7, 2024
cfc84ed
fix init
gaetanlop Oct 7, 2024
a1e8eeb
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 7, 2024
6898285
output type
gaetanlop Oct 7, 2024
f5639a1
adjust booleans in test
gaetanlop Oct 7, 2024
289b855
adapt moj doc
gaetanlop Oct 7, 2024
308e743
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 9, 2024
2c6de87
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 11, 2024
dedc859
renaming and removing factuality and safety judges
gaetanlop Oct 11, 2024
ba0fffb
fix typo in import
gaetanlop Oct 11, 2024
226de82
fix small typo in naming
gaetanlop Oct 11, 2024
5626cd4
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 14, 2024
567b798
formatting
gaetanlop Oct 14, 2024
1c33494
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 21, 2024
64c9de8
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 24, 2024
559cd1b
Update trl/trainer/judges.py
gaetanlop Oct 24, 2024
2c29ef5
update parameter name
gaetanlop Oct 25, 2024
bd1bed8
update tests
gaetanlop Oct 25, 2024
21e3ccd
update doc
gaetanlop Oct 25, 2024
43d6cca
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Oct 29, 2024
9eca0f8
Update trl/trainer/judges.py
gaetanlop Oct 29, 2024
d5b32f0
Update doc
gaetanlop Oct 29, 2024
ac88c63
fix alltruejudge type
gaetanlop Oct 29, 2024
999154b
Merge branch 'main' into cgpo_mixture_of_judges
gaetanlop Nov 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/judges.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,18 @@ judge.judge(
) # Outputs: [0, 1]
```

## AllTrueJudge

[[autodoc]] AllTrueJudge

## BaseJudge

[[autodoc]] BaseJudge

## BaseBinaryJudge

[[autodoc]] BaseBinaryJudge

## BaseRankJudge

[[autodoc]] BaseRankJudge
Expand All @@ -58,6 +66,10 @@ judge.judge(

[[autodoc]] BasePairwiseJudge

## RandomBinaryJudge

[[autodoc]] RandomBinaryJudge

## RandomRankJudge

[[autodoc]] RandomRankJudge
Expand Down
41 changes: 34 additions & 7 deletions tests/test_judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,52 @@

import unittest

from trl import HfPairwiseJudge, PairRMJudge, RandomPairwiseJudge, RandomRankJudge, is_llmblender_available
from trl import (
AllTrueJudge,
HfPairwiseJudge,
PairRMJudge,
RandomBinaryJudge,
RandomPairwiseJudge,
RandomRankJudge,
is_llmblender_available,
)


class TestJudges(unittest.TestCase):
def _get_prompts_and_completions(self):
def _get_prompts_and_pairwise_completions(self):
prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
return prompts, completions

def _get_prompts_and_single_completions(self):
prompts = ["What's the capital of France?", "What's the color of the sky?"]
completions = ["Marseille", "blue"]
return prompts, completions

def test_all_true_judge(self):
moj = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
prompts, completions = self._get_prompts_and_single_completions()
judgements = moj.judge(prompts=prompts, completions=completions)
self.assertEqual(len(judgements), 2)
self.assertTrue(all(judgement in {True, False} for judgement in judgements))

def test_random_constraint_judge(self):
judge = RandomBinaryJudge()
prompts, completions = self._get_prompts_and_single_completions()
judgements = judge.judge(prompts=prompts, completions=completions)
self.assertEqual(len(judgements), 2)
self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements))

def test_random_pairwise_judge(self):
judge = RandomPairwiseJudge()
prompts, completions = self._get_prompts_and_completions()
prompts, completions = self._get_prompts_and_pairwise_completions()
ranks = judge.judge(prompts=prompts, completions=completions)
self.assertEqual(len(ranks), 2)
self.assertTrue(all(isinstance(rank, int) for rank in ranks))

def test_random_rank_judge(self):
judge = RandomRankJudge()
prompts, completions = self._get_prompts_and_completions()
prompts, completions = self._get_prompts_and_pairwise_completions()
ranks = judge.judge(prompts=prompts, completions=completions)
self.assertEqual(len(ranks), 2)
self.assertTrue(all(isinstance(rank, list) for rank in ranks))
Expand All @@ -41,7 +68,7 @@ def test_random_rank_judge(self):
@unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.")
def test_hugging_face_judge(self):
judge = HfPairwiseJudge()
prompts, completions = self._get_prompts_and_completions()
prompts, completions = self._get_prompts_and_pairwise_completions()
ranks = judge.judge(prompts=prompts, completions=completions)
self.assertEqual(len(ranks), 2)
self.assertTrue(all(isinstance(rank, int) for rank in ranks))
Expand All @@ -50,7 +77,7 @@ def test_hugging_face_judge(self):
@unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
def test_pair_rm_judge(self):
judge = PairRMJudge()
prompts, completions = self._get_prompts_and_completions()
prompts, completions = self._get_prompts_and_pairwise_completions()
ranks = judge.judge(prompts=prompts, completions=completions)
self.assertEqual(len(ranks), 2)
self.assertTrue(all(isinstance(rank, int) for rank in ranks))
Expand All @@ -59,7 +86,7 @@ def test_pair_rm_judge(self):
@unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
def test_pair_rm_judge_return_scores(self):
judge = PairRMJudge()
prompts, completions = self._get_prompts_and_completions()
prompts, completions = self._get_prompts_and_pairwise_completions()
probs = judge.judge(prompts=prompts, completions=completions, return_scores=True)
self.assertEqual(len(probs), 2)
self.assertTrue(all(isinstance(prob, float) for prob in probs))
Expand Down
6 changes: 6 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
"trainer": [
"AlignPropConfig",
"AlignPropTrainer",
"AllTrueJudge",
"BaseJudge",
"BaseBinaryJudge",
"BasePairwiseJudge",
"BaseRankJudge",
"BCOConfig",
Expand Down Expand Up @@ -82,6 +84,7 @@
"PPOTrainer",
"PPOv2Config",
"PPOv2Trainer",
"RandomBinaryJudge",
"RandomPairwiseJudge",
"RandomRankJudge",
"RewardConfig",
Expand Down Expand Up @@ -146,6 +149,8 @@
from .trainer import (
AlignPropConfig,
AlignPropTrainer,
AllTrueJudge,
BaseBinaryJudge,
BaseJudge,
BasePairwiseJudge,
BaseRankJudge,
Expand Down Expand Up @@ -178,6 +183,7 @@
PPOTrainer,
PPOv2Config,
PPOv2Trainer,
RandomBinaryJudge,
RandomPairwiseJudge,
RandomRankJudge,
RewardConfig,
Expand Down
6 changes: 6 additions & 0 deletions trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@
"gkd_trainer": ["GKDTrainer"],
"iterative_sft_trainer": ["IterativeSFTTrainer"],
"judges": [
"AllTrueJudge",
"BaseJudge",
"BaseBinaryJudge",
"BasePairwiseJudge",
"BaseRankJudge",
"HfPairwiseJudge",
"OpenAIPairwiseJudge",
"PairRMJudge",
"RandomBinaryJudge",
"RandomPairwiseJudge",
"RandomRankJudge",
],
Expand Down Expand Up @@ -98,12 +101,15 @@
from .gkd_trainer import GKDTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .judges import (
AllTrueJudge,
BaseBinaryJudge,
BaseJudge,
BasePairwiseJudge,
BaseRankJudge,
HfPairwiseJudge,
OpenAIPairwiseJudge,
PairRMJudge,
RandomBinaryJudge,
RandomPairwiseJudge,
RandomRankJudge,
)
Expand Down
79 changes: 79 additions & 0 deletions trl/trainer/judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,53 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order:
raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class BaseBinaryJudge(BaseJudge):
    """
    Base class for binary judges.
    """

    @abstractmethod
    def judge(
        self,
        prompts: List[str],
        completions: List[str],
        gold_completions: Optional[List[str]] = None,
        shuffle_order: bool = True,
    ) -> List[int]:
        """
        Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint.

        This base class should be used to implement binary evaluations as done in section 4.1.4 of the
        [CGPO paper](https://huggingface.co/papers/2409.20370).
        It is relevant for assessing whether or not a prompt completion pair satisfies a specific constraint.

        Args:
            prompts (`List[str]`): List of prompts.
            completions (`List[str]`): List of completions.
            gold_completions (`List[str]`, `optional`): List of gold completions if it exists.
            shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias.

        Returns:
            List[int]: A list of binary labels:
                - 1 indicates that the completion satisfies the evaluated constraint.
                - 0 indicates that the completion does not satisfy the evaluated constraint.

        Note:
            If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference has failed.
            For instance, this could occur if the underlying language model or rule-based constraint returned an invalid answer.
            In such cases, the caller should handle these invalid indices appropriately, possibly by implementing fallback logic or error handling.
        """
        raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class RandomBinaryJudge(BaseBinaryJudge):
    """
    Random binary judge, for testing purposes.

    Emits an independent random 0/1 label per prompt; the completions,
    gold completions, and shuffle flag are ignored.
    """

    def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
        # Iterate the prompts directly instead of range(len(...)); only their
        # count matters since the labels are random.
        return [random.choice([0, 1]) for _ in prompts]


class RandomRankJudge(BaseRankJudge):
"""
Random rank, for testing purposes.
Expand Down Expand Up @@ -361,3 +408,35 @@ def get_rank(prompt, candidates):

# Return the ranks
return ranks


class AllTrueJudge(BaseBinaryJudge):
    """
    Unify the decision of multiple [`BaseBinaryJudge`] instances.

    This class returns `False` if it fails on any of the binary judges (i.e., a judge returns 0 or -1)
    and returns `True` otherwise.

    It is an implementation of the Mixture of Judges as described in the
    [CGPO paper](https://huggingface.co/papers/2409.20370).

    Args:
        judges (`List[BaseBinaryJudge]`): A list of [`BaseBinaryJudge`].
    """

    def __init__(self, judges: List[BaseBinaryJudge]):
        self.judges = judges

    def judge(
        self,
        prompts: List[str],
        completions: List[str],
        gold_completions: Optional[List[str]] = None,
        shuffle_order: bool = True,
    ) -> List[bool]:
        """
        Run every inner judge on the prompt/completion pairs and combine their verdicts.

        Args:
            prompts (`List[str]`): List of prompts.
            completions (`List[str]`): List of completions.
            gold_completions (`List[str]`, `optional`): List of gold completions, forwarded to the inner judges.
            shuffle_order (`bool`): Whether the inner judges should shuffle completions to avoid positional bias.

        Returns:
            List[bool]: `True` where every inner judge returned 1, `False` otherwise
            (a 0 or a -1 failure label from any judge yields `False`).
        """
        # Each inner judge produces one label per prompt; zip(*...) regroups the
        # labels per prompt so they can be reduced with all().
        all_binary_judgments = [
            judge.judge(prompts, completions, gold_completions, shuffle_order) for judge in self.judges
        ]
        return [all(label == 1 for label in labels) for labels in zip(*all_binary_judgments)]