@@ -92,6 +92,7 @@ def spmd(
9292 h : str = "gpu.small" ,
9393 j : str = "1x1" ,
9494 env : Optional [Dict [str , str ]] = None ,
95+ metadata : Optional [Dict [str , str ]] = None ,
9596 max_retries : int = 0 ,
9697 mounts : Optional [List [str ]] = None ,
9798 debug : bool = False ,
@@ -131,6 +132,7 @@ def spmd(
131132 h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
132133 j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
133134 env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
135+ metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
134136 max_retries: the number of scheduler retries allowed
135137 rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
136138 Only takes effect when running multi-node. When running single node, this parameter
@@ -153,6 +155,7 @@ def spmd(
153155 h = h ,
154156 j = str (StructuredJArgument .parse_from (h , j )),
155157 env = env ,
158+ metadata = metadata ,
156159 max_retries = max_retries ,
157160 mounts = mounts ,
158161 debug = debug ,
@@ -171,6 +174,7 @@ def ddp(
171174 memMB : int = 1024 ,
172175 j : str = "1x2" ,
173176 env : Optional [Dict [str , str ]] = None ,
177+ metadata : Optional [Dict [str , str ]] = None ,
174178 max_retries : int = 0 ,
175179 rdzv_port : int = 29500 ,
176180 rdzv_backend : str = "c10d" ,
@@ -203,6 +207,7 @@ def ddp(
203207 h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
204208 j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
205209 env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
210+ metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
206211 max_retries: the number of scheduler retries allowed
207212 rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
208213 Only takes effect when running multi-node. When running single node, this parameter
@@ -238,8 +243,8 @@ def ddp(
238243 # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
239244 rdzv_endpoint = _noquote (f"$${{{ macros .rank0_env } :=localhost}}:{ rdzv_port } " )
240245
241- if env is None :
242- env = {}
246+ env = env or {}
247+ metadata = metadata or {}
243248
244249 argname = StructuredNameArgument .parse_from (
245250 name = name ,
@@ -292,6 +297,7 @@ def ddp(
292297 resource = specs .resource (cpu = cpu , gpu = gpu , memMB = memMB , h = h ),
293298 args = ["-c" , _args_join (cmd )],
294299 env = env ,
300+ metadata = metadata ,
295301 port_map = {
296302 "c10d" : rdzv_port ,
297303 },
0 commit comments