Skip to content

Commit f5e4014

Browse files
committed
fix: Rename variable to control PG timeout to not refer to NCCL
The variable is not backend-specific. In the future, if/when we support other backends, the NCCL-specific name will become a more evident problem. Signed-off-by: Ihar Hrachyshka <[email protected]>
1 parent 1532531 commit f5e4014

File tree

5 files changed

+15
-10
lines changed

5 files changed

+15
-10
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,4 +375,4 @@ run_training(
375375

376376
Below is a list of custom environment variables users can set in the training library.
377377

378-
1. `INSTRUCTLAB_NCCL_TIMEOUT_MS`, this environment variable controls the NCCL timeout in milliseconds. Consider increasing if seeing FSDP related NCCL errors.
378+
1. `INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS`: controls the process group timeout in milliseconds. Consider increasing it if you see FSDP-related errors.

src/instructlab/training/const.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS = "INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS"

src/instructlab/training/main_ds.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161

6262
# pylint: disable=no-name-in-module
6363
from instructlab.training.config import DistributedBackend, TorchrunArgs, TrainingArgs
64+
from instructlab.training.const import INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS
6465
from instructlab.training.logger import setup_metric_logger, setup_root_logger
6566
from instructlab.training.multipack_sampler import (
6667
find_packing_max_batch_len_and_grad_accum,
@@ -550,7 +551,7 @@ def train(
550551
# time of writing) default: to cover the unlikely event torch decides to tweak
551552
# the default.
552553
def _get_collective_timeout() -> datetime.timedelta | None:
553-
timeout_var = os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS")
554+
timeout_var = os.getenv(INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS)
554555
if timeout_var is None:
555556
return None
556557

@@ -561,7 +562,7 @@ def _get_collective_timeout() -> datetime.timedelta | None:
561562

562563
if timeout <= 0:
563564
raise ValueError(
564-
f"Invalid value for INSTRUCTLAB_NCCL_TIMEOUT_MS: {timeout_var}. Must be a positive integer."
565+
f"Invalid value for {INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS}: {timeout_var}. Must be a positive integer."
565566
)
566567

567568
return datetime.timedelta(milliseconds=timeout)

tests/unit/test_main_ds.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# First Party
99
from instructlab.training import main_ds
10+
from instructlab.training.const import INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS
1011

1112

1213
def test__get_collective_timeout():
@@ -16,7 +17,7 @@ def test__get_collective_timeout():
1617
# Test with custom timeout
1718
timeout = 1234
1819
with mock.patch.dict(
19-
main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": str(timeout)}
20+
main_ds.os.environ, {INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS: str(timeout)}
2021
):
2122
assert main_ds._get_collective_timeout() == datetime.timedelta(
2223
milliseconds=timeout
@@ -25,15 +26,15 @@ def test__get_collective_timeout():
2526
# Test with invalid timeout (negative)
2627
invalid_timeout = "-100"
2728
with mock.patch.dict(
28-
main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout}
29+
main_ds.os.environ, {INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS: invalid_timeout}
2930
):
3031
with pytest.raises(ValueError):
3132
main_ds._get_collective_timeout()
3233

3334
# Test with invalid timeout (string)
3435
invalid_timeout = "invalid"
3536
with mock.patch.dict(
36-
main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout}
37+
main_ds.os.environ, {INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS: invalid_timeout}
3738
):
3839
with pytest.raises(ValueError):
3940
main_ds._get_collective_timeout()

tox.ini

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ basepython = python3.11
1717
[testenv:py3-unit]
1818
description = run unit tests with pytest
1919
passenv =
20-
HF_HOME
21-
INSTRUCTLAB_NCCL_TIMEOUT_MS
20+
HF_HOME
21+
INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS
2222
deps =
2323
pytest
2424
wandb
@@ -33,8 +33,8 @@ commands = {envpython} -m pytest tests/unit {posargs}
3333
[testenv:py3-smoke]
3434
description = run accelerated smoke tests with pytest
3535
passenv =
36-
HF_HOME
37-
INSTRUCTLAB_NCCL_TIMEOUT_MS
36+
HF_HOME
37+
INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS
3838
deps =
3939
pytest
4040
-r requirements-dev.txt

0 commit comments

Comments
 (0)