Skip to content

Commit

Permalink
🧓 Specify and test min versions (#2303)
Browse files Browse the repository at this point in the history
* Add conditional check for LLMBlender availability in test_judges.py

* Fix import issues and update test requirements

* Remove unused imports

* Add require_peft decorator to test cases

* Fix import_utils module to use correct package name for llm_blender

* Found min version and test

* Update Slack notification titles

* Update dependencies versions

* Update GitHub Actions workflow to include setup.py and reorder file paths

* Revert "Update Slack notification titles"

This reverts commit be02a7f.

* Update Slack notification titles

* Remove pull_request branch restriction in tests.yml

* add check code quality back

* Fix PairRMJudge model loading issue
  • Loading branch information
qgallouedec authored Oct 31, 2024
1 parent d57a181 commit 6138439
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 19 deletions.
60 changes: 54 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,36 @@ on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
paths:
# Run only when relevant files are modified
- "trl/**.py"
- ".github/**.yml"
- "examples/**.py"
- "scripts/**.py"
- ".github/**.yml"
- "tests/**.py"
- "trl/**.py"
- "setup.py"

env:
TQDM_DISABLE: 1
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}

jobs:
check_code_quality:
name: Check code quality
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: recursive
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: 3.12
- uses: pre-commit/[email protected]
with:
extra_args: --all-files

tests:
name: Tests
strategy:
Expand Down Expand Up @@ -49,7 +65,7 @@ jobs:
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the TRL CI with dev dependencies
title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} with lastest dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

Expand Down Expand Up @@ -81,7 +97,7 @@ jobs:
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the TRL CI with dev dependencies
title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} with dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

Expand Down Expand Up @@ -110,6 +126,38 @@ jobs:
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the TRL CI with dev dependencies
title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} without optional dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

tests_min_versions:
name: Tests with minimum versions
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install accelerate==0.34.0
python -m pip install datasets==2.21.0
python -m pip install transformers==4.46.0
python -m pip install ".[dev]"
- name: Test with pytest
run: |
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} with minimum versions
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@
__version__ = "0.12.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)

REQUIRED_PKGS = [
"accelerate",
"datasets",
"accelerate>=0.34.0",
"datasets>=2.21.0",
"rich", # rich shouldn't be a required package for trl, we should remove it from here
"transformers>=4.46.0",
]
Expand Down
23 changes: 13 additions & 10 deletions tests/test_judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import unittest

from trl import HfPairwiseJudge, PairRMJudge, RandomPairwiseJudge, RandomRankJudge, is_llm_blender_available
from trl import HfPairwiseJudge, PairRMJudge, RandomPairwiseJudge, RandomRankJudge

from .testing_utils import require_llm_blender


class TestJudges(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Initialize once to download the model. This ensures it’s downloaded before running tests, preventing issues
# where concurrent tests attempt to load the model while it’s still downloading.
if is_llm_blender_available():
PairRMJudge()

def _get_prompts_and_completions(self):
prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
Expand Down Expand Up @@ -56,9 +50,18 @@ def test_hugging_face_judge(self):
self.assertTrue(all(isinstance(rank, int) for rank in ranks))
self.assertEqual(ranks, [0, 1])

def load_pair_rm_judge(self):
# When using concurrent tests, PairRM may fail to load the model while another job is still downloading.
# This is a workaround to retry loading the model a few times.
for _ in range(5):
try:
return PairRMJudge()
except ValueError:
time.sleep(5)

@require_llm_blender
def test_pair_rm_judge(self):
judge = PairRMJudge()
judge = self.load_pair_rm_judge()
prompts, completions = self._get_prompts_and_completions()
ranks = judge.judge(prompts=prompts, completions=completions)
self.assertEqual(len(ranks), 2)
Expand All @@ -67,7 +70,7 @@ def test_pair_rm_judge(self):

@require_llm_blender
def test_pair_rm_judge_return_scores(self):
judge = PairRMJudge()
judge = self.load_pair_rm_judge()
prompts, completions = self._get_prompts_and_completions()
probs = judge.judge(prompts=prompts, completions=completions, return_scores=True)
self.assertEqual(len(probs), 2)
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
self.data_collator = data_collator
self.eval_dataset = eval_dataset
self.optimizer, self.lr_scheduler = optimizers
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47

#########
# calculate various batch sizes
Expand Down

0 comments on commit 6138439

Please sign in to comment.