diff --git a/ramalama/model.py b/ramalama/model.py
index 11ab530d..8ef6bbc9 100644
--- a/ramalama/model.py
+++ b/ramalama/model.py
@@ -142,9 +142,15 @@ def setup_container(self, args):
         if os.path.exists("/dev/kfd"):
             conman_args += ["--device", "/dev/kfd"]
 
+        env_vars = {k: v for k, v in os.environ.items() if k.startswith(("ASAHI_", "CUDA_", "HIP_", "HSA_"))}
+
         gpu_type, gpu_num = get_gpu()
-        if gpu_type == "HIP_VISIBLE_DEVICES" or gpu_type == "ASAHI_VISIBLE_DEVICES":
-            conman_args += ["-e", f"{gpu_type}={gpu_num}"]
+        if gpu_type not in env_vars and gpu_type in {"HIP_VISIBLE_DEVICES", "ASAHI_VISIBLE_DEVICES"}:
+            env_vars[gpu_type] = str(gpu_num)
+
+        for k, v in env_vars.items():
+            conman_args += ["-e", f"{k}={v}"]
+
         return conman_args
 
     def run_container(self, args, shortnames):
diff --git a/test/system/030-run.bats b/test/system/030-run.bats
index 5ad8cab6..469903a4 100755
--- a/test/system/030-run.bats
+++ b/test/system/030-run.bats
@@ -42,6 +42,23 @@ load helpers
     fi
 }
 
+@test "ramalama --dryrun run ensure env vars are respected" {
+    skip_if_nocontainer
+    model=tiny
+
+    ASAHI_VISIBLE_DEVICES=99 run_ramalama --dryrun run ${model}
+    is "$output" ".*-e ASAHI_VISIBLE_DEVICES=99" "ensure ASAHI_VISIBLE_DEVICES is set from environment"
+
+    CUDA_LAUNCH_BLOCKING=1 run_ramalama --dryrun run ${model}
+    is "$output" ".*-e CUDA_LAUNCH_BLOCKING=1" "ensure CUDA_LAUNCH_BLOCKING is set from environment"
+
+    HIP_VISIBLE_DEVICES=99 run_ramalama --dryrun run ${model}
+    is "$output" ".*-e HIP_VISIBLE_DEVICES=99" "ensure HIP_VISIBLE_DEVICES is set from environment"
+
+    HSA_OVERRIDE_GFX_VERSION=0.0.0 run_ramalama --dryrun run ${model}
+    is "$output" ".*-e HSA_OVERRIDE_GFX_VERSION=0.0.0" "ensure HSA_OVERRIDE_GFX_VERSION is set from environment"
+}
+
 @test "ramalama run tiny with prompt" {
     skip_if_notlocal
     run_ramalama run --name foobar tiny "Write a 1 line poem"