From 966c3adc207e3d82eff1bfd74f4070340567e6cc Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 18 Oct 2024 12:26:48 -0700 Subject: [PATCH] Convert arg to str for template substitution Signed-off-by: Hemil Desai --- src/nemo_run/core/execution/base.py | 2 +- test/core/execution/test_base.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/nemo_run/core/execution/base.py b/src/nemo_run/core/execution/base.py index aa11e73..cadd328 100644 --- a/src/nemo_run/core/execution/base.py +++ b/src/nemo_run/core/execution/base.py @@ -115,7 +115,7 @@ def substitute(self, arg: str) -> str: """ substitute applies the values to the template arg. """ - return Template(arg).safe_substitute(**asdict(self)) + return Template(str(arg)).safe_substitute(**asdict(self)) @runtime_checkable diff --git a/test/core/execution/test_base.py b/test/core/execution/test_base.py index 1dc23d0..588e101 100644 --- a/test/core/execution/test_base.py +++ b/test/core/execution/test_base.py @@ -16,6 +16,8 @@ import fiddle as fdl import pytest +from torchx.specs import Role + from nemo_run.config import Config from nemo_run.core.execution.base import ( Executor, @@ -25,7 +27,6 @@ Torchrun, ) from nemo_run.core.execution.slurm import SlurmExecutor -from torchx.specs import Role class TestExecutorMacros: @@ -59,6 +60,8 @@ def test_substitute(self): assert macros.substitute("${head_node_ip_var}") == "192.168.0.1" assert macros.substitute("${nproc_per_node_var}") == "4" + assert macros.substitute(1) == "1" + def test_group_host(self): macros = SlurmExecutor(account="a").macro_values() assert macros