Skip to content

Commit 7fc2b7a

Browse files
JobDefinition env parameter to set pip index url and host
1 parent 745772e commit 7fc2b7a

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

src/codeflare_sdk/job/jobs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from torchx.schedulers.ray_scheduler import RayScheduler
2323
from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo
2424

25+
from ..utils.generate_yaml import update_pip_requirements
26+
2527

2628
if TYPE_CHECKING:
2729
from ..cluster.cluster import Cluster
@@ -90,6 +92,12 @@ def __init__(
9092
)
9193
self.image = image
9294
self.workspace = workspace
95+
if "PIP_TRUSTED_HOST" in self.env or "PIP_INDEX_URL" in self.env:
96+
update_pip_requirements(self)
97+
else:
98+
self.env.setdefault("PIP_TRUSTED_HOST", "pypi.org")
99+
self.env.setdefault("PIP_INDEX_URL", "https://pypi.org/simple")
100+
update_pip_requirements(self)
93101

94102
def _dry_run(self, cluster: "Cluster"):
95103
j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}" # # of proc. = # of gpus

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import sys
2222
import os
2323
import argparse
24+
from pathlib import Path
2425
import uuid
2526
from kubernetes import client, config
2627
from .kube_api_helpers import _kube_api_error_handling
@@ -689,3 +690,38 @@ def generate_appwrapper(
689690
else:
690691
write_user_appwrapper(user_yaml, outfile)
691692
return outfile
693+
694+
695+
def update_pip_requirements(self):
696+
pip_trusted_host = self.env.get("PIP_TRUSTED_HOST")
697+
pip_index_url = self.env.get("PIP_INDEX_URL")
698+
requirements_path = Path("requirements.txt")
699+
700+
if requirements_path.exists():
701+
with requirements_path.open("r") as file:
702+
requirements = file.readlines()
703+
704+
# Check and replace or add --trusted-host and --index-url
705+
trusted_host = f"--trusted-host {pip_trusted_host}\n"
706+
index_url = f"--index-url {pip_index_url}\n"
707+
modified_requirements = []
708+
709+
for line in requirements:
710+
if line.startswith("--trusted-host"):
711+
modified_requirements.append(trusted_host)
712+
trusted_host = None
713+
elif line.startswith("--index-url"):
714+
modified_requirements.append(index_url)
715+
index_url = None
716+
else:
717+
modified_requirements.append(line)
718+
719+
# Append the lines if they were not replaced
720+
if trusted_host:
721+
modified_requirements.insert(0, trusted_host)
722+
if index_url:
723+
modified_requirements.insert(0, index_url)
724+
725+
# Write back the modified requirements
726+
with requirements_path.open("w") as file:
727+
file.writelines(modified_requirements)

tests/unit_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2091,7 +2091,11 @@ def test_DDPJobDefinition_creation():
20912091
assert ddp.memMB == 1024
20922092
assert ddp.h == None
20932093
assert ddp.j == "2x1"
2094-
assert ddp.env == {"test": "test"}
2094+
assert ddp.env == {
2095+
"PIP_TRUSTED_HOST": "pypi.org",
2096+
"PIP_INDEX_URL": "https://pypi.org/simple",
2097+
"test": "test",
2098+
}
20952099
assert ddp.max_retries == 0
20962100
assert ddp.mounts == []
20972101
assert ddp.rdzv_port == 29500

0 commit comments

Comments
 (0)