[CLI] trl env for printing system info (#2104)
qgallouedec authored Sep 24, 2024
1 parent 6859e04 commit 2cad48d
Showing 9 changed files with 194 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug-report.yml
@@ -15,7 +15,7 @@ body:
id: system-info
attributes:
label: System Info
description: Please share your system info with us. You can run the command `transformers-cli env` and copy-paste its output below.
description: Please share your system info with us. You can run the command `trl env` and copy-paste its output below.
placeholder: trl version, transformers version, platform, python version, ...
validations:
required: true
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -62,7 +62,7 @@ Once you've confirmed the bug hasn't already been reported, please include the f
To get the OS and software versions automatically, run the following command:

```bash
transformers-cli env
trl env
```

### Do you want a new feature?
52 changes: 52 additions & 0 deletions docs/source/clis.mdx
@@ -7,6 +7,7 @@ Currently supported CLIs are:
- `trl sft`: fine-tune a LLM on a text/instruction dataset
- `trl dpo`: fine-tune a LLM with DPO on a preference dataset
- `trl chat`: quickly spin up a LLM fine-tuned for chatting
- `trl env`: get the system information

## Fine-tuning with the CLI

@@ -117,3 +118,54 @@ Besides talking to the model there are a few commands you can use:
- **exit**: closes the interface

The default examples are defined in `examples/scripts/config/default_chat_config.yaml`, but you can pass your own with `--config CONFIG_FILE`, where you can also specify the default generation parameters.
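
For example, a custom config can be passed like this; the model name and config file below are placeholders for illustration, not files shipped with TRL:

```bash
# Hypothetical invocation: start the chat CLI with your own config file
trl chat --model_name_or_path Qwen/Qwen2-0.5B-Instruct --config my_chat_config.yaml
```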

## Getting the system information

You can get the system information by running the following command:

```bash
trl env
```

This command prints the system information, including the platform, the Python and PyTorch versions, the CUDA device, the Transformers and Accelerate versions, the TRL version, and the versions of any optional dependencies that are installed.

```txt
Copy-paste the following information when reporting an issue:
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- CUDA device: NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
- compute_environment: LOCAL_MACHINE
- distributed_type: DEEPSPEED
- mixed_precision: no
- use_cpu: False
- debug: False
- num_processes: 4
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
```

This information is required when reporting an issue.
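
A convenient workflow (illustrative, not part of this commit) is to redirect the report to a file and paste its contents into the "System Info" field of the bug report template:

```bash
# Save the environment report to a file so it can be pasted into an issue
trl env > trl_env_report.txt
cat trl_env_report.txt
```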
2 changes: 1 addition & 1 deletion setup.py
@@ -102,7 +102,7 @@
"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.14.4"],
"quantization": ["bitsandbytes<=0.41.1"],
"llm_judge": ["openai>=1.23.2", "huggingface_hub>=0.22.2", "llm-blender>=0.0.2"],
"llm_judge": ["openai>=1.23.2", "llm-blender>=0.0.2"],
}
EXTRAS["dev"] = []
for reqs in EXTRAS.values():
5 changes: 5 additions & 0 deletions tests/test_cli.py
@@ -38,3 +38,8 @@ def test_dpo_cli():
)
except BaseException as exc:
raise AssertionError("An error occurred while running the CLI, please double check") from exc


def test_env_cli():
output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True)
assert "- Python version: " in output.stdout
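
To exercise the new test locally, an invocation along these lines should work, assuming `pytest` is installed and you are in the repository root:

```bash
# Run only the new CLI environment test
pytest tests/test_cli.py::test_env_cli -q
```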
8 changes: 7 additions & 1 deletion trl/__init__.py
@@ -32,6 +32,7 @@
"BestOfNSampler",
],
"import_utils": [
"is_deepspeed_available",
"is_diffusers_available",
"is_liger_kernel_available",
"is_llmblender_available",
@@ -131,7 +132,12 @@
from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import is_diffusers_available, is_liger_kernel_available, is_llmblender_available
from .import_utils import (
is_deepspeed_available,
is_diffusers_available,
is_liger_kernel_available,
is_llmblender_available,
)
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
114 changes: 96 additions & 18 deletions trl/commands/cli.py
@@ -1,5 +1,3 @@
# This file is a copy of trl/examples/scripts/sft.py so that we could
# use it together with rich and the TRL CLI in a more customizable manner.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,41 +12,106 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
import subprocess
import sys
from importlib.metadata import version
from subprocess import CalledProcessError

import torch
from accelerate.commands.config import default_config_file, load_config_from_file
from rich.console import Console
from transformers import is_bitsandbytes_available
from transformers.utils import is_openai_available, is_peft_available

from .. import (
__version__,
is_deepspeed_available,
is_diffusers_available,
is_liger_kernel_available,
is_llmblender_available,
)
from .cli_utils import get_git_commit_hash

SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto"]

SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto", "env"]

def main():

def print_env():
accelerate_config = accelerate_config_str = "not found"

# Get the default from the config file.
if os.path.isfile(default_config_file):
accelerate_config = load_config_from_file(default_config_file).to_dict()

accelerate_config_str = (
"\n" + "\n".join([f" - {prop}: {val}" for prop, val in accelerate_config.items()])
if isinstance(accelerate_config, dict)
else accelerate_config
)

commit_hash = get_git_commit_hash("trl")

info = {
"Platform": platform.platform(),
"Python version": platform.python_version(),
"PyTorch version": version("torch"),
"CUDA device": torch.cuda.get_device_name() if torch.cuda.is_available() else "not available",
"Transformers version": version("transformers"),
"Accelerate version": version("accelerate"),
"Accelerate config": accelerate_config_str,
"Datasets version": version("datasets"),
"HF Hub version": version("huggingface_hub"),
"TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__,
"bitsandbytes version": version("bitsandbytes") if is_bitsandbytes_available() else "not installed",
"DeepSpeed version": version("deepspeed") if is_deepspeed_available() else "not installed",
"Diffusers version": version("diffusers") if is_diffusers_available() else "not installed",
"Liger-Kernel version": version("liger_kernel") if is_liger_kernel_available() else "not installed",
"LLM-Blender version": version("llm_blender") if is_llmblender_available() else "not installed",
"OpenAI version": version("openai") if is_openai_available() else "not installed",
"PEFT version": version("peft") if is_peft_available() else "not installed",
}

info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()])
print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa


def train(command_name):
console = Console()
# Make sure to import things locally to avoid verbose from third party libs.
with console.status("[bold purple]Welcome! Initializing the TRL CLI..."):
from trl.commands.cli_utils import init_zero_verbose

init_zero_verbose()

command_name = sys.argv[1]
trl_examples_dir = os.path.dirname(__file__)

if command_name not in SUPPORTED_COMMANDS:
raise ValueError(
f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}"
)
command = f"accelerate launch {trl_examples_dir}/scripts/{command_name}.py {' '.join(sys.argv[2:])}"

try:
subprocess.run(
command.split(),
text=True,
check=True,
encoding="utf-8",
cwd=os.getcwd(),
env=os.environ.copy(),
)
except (CalledProcessError, ChildProcessError) as exc:
console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.")
raise ValueError("TRL CLI failed! Check the traceback above..") from exc


def chat():
console = Console()
# Make sure to import things locally to avoid verbose from third party libs.
with console.status("[bold purple]Welcome! Initializing the TRL CLI..."):
from trl.commands.cli_utils import init_zero_verbose

init_zero_verbose()
trl_examples_dir = os.path.dirname(__file__)

if command_name == "chat":
command = f"""
python {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
"""
else:
command = f"""
accelerate launch {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
"""
command = f"accelerate launch {trl_examples_dir}/scripts/chat.py {' '.join(sys.argv[2:])}"

try:
subprocess.run(
@@ -60,9 +123,24 @@ def main():
env=os.environ.copy(),
)
except (CalledProcessError, ChildProcessError) as exc:
console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.")
console.log("TRL - CHAT failed! See the logs above for further details.")
raise ValueError("TRL CLI failed! Check the traceback above..") from exc


def main():
command_name = sys.argv[1]

if command_name in ["sft", "dpo", "kto"]:
train(command_name)
elif command_name == "chat":
chat()
elif command_name == "env":
print_env()
else:
raise ValueError(
f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}"
)


if __name__ == "__main__":
main()
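
With this dispatch in place, every subcommand goes through the same `trl` entry point; the following smoke test is illustrative and not part of the commit:

```bash
# main() routes each of these to the right handler
trl env
trl chat --help
trl sft --help
```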
26 changes: 26 additions & 0 deletions trl/commands/cli_utils.py
@@ -13,8 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import logging
import os
import subprocess
import sys
from argparse import Namespace
from dataclasses import dataclass, field
@@ -279,3 +282,26 @@ def set_defaults_with_config(self, **kwargs):
if action.dest in kwargs:
action.default = kwargs[action.dest]
action.required = False


def get_git_commit_hash(package_name):
try:
# Import the package to locate its path
package = importlib.import_module(package_name)
# Get the path to the package using inspect
package_path = os.path.dirname(inspect.getfile(package))

# Navigate up to the Git repository root if the package is inside a subdirectory
git_repo_path = os.path.abspath(os.path.join(package_path, ".."))
git_dir = os.path.join(git_repo_path, ".git")

if os.path.isdir(git_dir):
# Run the git command to get the current commit hash
commit_hash = (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=git_repo_path).strip().decode("utf-8")
)
return commit_hash
else:
return None
except Exception as e:
return f"Error: {str(e)}"
5 changes: 5 additions & 0 deletions trl/import_utils.py
@@ -27,13 +27,18 @@
_is_python_greater_3_8 = True

# Use same as transformers.utils.import_utils
_deepspeed_available = _is_package_available("deepspeed")
_diffusers_available = _is_package_available("diffusers")
_unsloth_available = _is_package_available("unsloth")
_rich_available = _is_package_available("rich")
_liger_kernel_available = _is_package_available("liger_kernel")
_llmblender_available = _is_package_available("llm_blender")


def is_deepspeed_available() -> bool:
return _deepspeed_available


def is_diffusers_available() -> bool:
return _diffusers_available

