@@ -92,6 +92,7 @@ def spmd(
9292 h : str = "gpu.small" ,
9393 j : str = "1x1" ,
9494 env : Optional [Dict [str , str ]] = None ,
95+ metadata : Optional [Dict [str , str ]] = None ,
9596 max_retries : int = 0 ,
9697 mounts : Optional [List [str ]] = None ,
9798 debug : bool = False ,
@@ -131,6 +132,7 @@ def spmd(
131132 h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
132133 j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
133134 env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
135+ metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
134136 max_retries: the number of scheduler retries allowed
135137 mounts: (for docker based runs only) mounts to mount into the worker environment/container
136138 (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
@@ -150,6 +152,7 @@ def spmd(
150152 h = h ,
151153 j = str (StructuredJArgument .parse_from (h , j )),
152154 env = env ,
155+ metadata = metadata ,
153156 max_retries = max_retries ,
154157 mounts = mounts ,
155158 debug = debug ,
@@ -168,6 +171,7 @@ def ddp(
168171 memMB : int = 1024 ,
169172 j : str = "1x2" ,
170173 env : Optional [Dict [str , str ]] = None ,
174+ metadata : Optional [Dict [str , str ]] = None ,
171175 max_retries : int = 0 ,
172176 rdzv_port : int = 29500 ,
173177 rdzv_backend : str = "c10d" ,
@@ -201,6 +205,7 @@ def ddp(
201205 h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
202206 j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
203207 env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
208+ metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
204209 max_retries: the number of scheduler retries allowed
205210 rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
206211 Only takes effect when running multi-node. When running single node, this parameter
@@ -237,8 +242,8 @@ def ddp(
237242 # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
238243 rdzv_endpoint = _noquote (f"$${{{ macros .rank0_env } :=localhost}}:{ rdzv_port } " )
239244
240- if env is None :
241- env = {}
245+ env = env or {}
246+ metadata = metadata or {}
242247
243248 argname = StructuredNameArgument .parse_from (
244249 name = name ,
@@ -299,6 +304,7 @@ def ddp(
299304 mounts = specs .parse_mounts (mounts ) if mounts else [],
300305 )
301306 ],
307+ metadata = metadata ,
302308 )
303309
304310
0 commit comments