@@ -139,10 +139,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
         )
 
-        # TODO: Not sure how to use lists of modules with PTL.
-        # This means we can only use pipeline parallelism without the interleaved schedule.
-        self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0]
-
         # Prompt tuning initialization
         self.use_soft_prompts = self.cfg.get('use_soft_prompts', False)
 
@@ -156,12 +152,27 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self.num_prompt_tokens = cfg.get('num_prompt_tokens', 100)
 
         if self.cfg.get('existing_prompt_tags', None):
+            # Assign prompt tag ids if none were present in the config
+            if type(self.cfg.existing_prompt_tags[0]) == str:
+                existing_prompt_tags = self.cfg.existing_prompt_tags
+                num_prompt_tags = len(existing_prompt_tags)
+                existing_prompt_tags = [
+                    (existing_prompt_tags[tag_id], tag_id + 1) for tag_id in range(num_prompt_tags)
+                ]
+
+                with open_dict(self.cfg):
+                    self.cfg.existing_prompt_tags = existing_prompt_tags
+
             # Fill table with prev tuned prompt tags and their ids
             self.prompt_table = set(self.cfg.existing_prompt_tags)
 
             # Get max prompt id from table for starting point of new prompt ids
             self.next_prompt_id = max(self.prompt_table, key=lambda x: x[1])[1]
 
+        # TODO: Not sure how to use lists of modules with PTL.
+        # This means we can only use pipeline parallelism without the interleaved schedule.
+        self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0]
+
         self.setup_optimizer_param_groups()
 
         self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False)
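For context, the tag-id normalization the second hunk introduces can be exercised in isolation. The following is a minimal runnable sketch, assuming an OmegaConf config as NeMo uses; the tag names and the module-level `prompt_table`/`next_prompt_id` variables are illustrative, not part of the commit:

# Sketch of the tag-id assignment added in this commit.
# The tag names below are made up for illustration.
from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({'existing_prompt_tags': ['boolq', 'sentiment']})

if cfg.get('existing_prompt_tags', None):
    # Assign prompt tag ids if the config holds bare tag names
    if type(cfg.existing_prompt_tags[0]) == str:
        tags = cfg.existing_prompt_tags
        # Pair each tag with a 1-based id taken from its list position
        tags = [(tags[tag_id], tag_id + 1) for tag_id in range(len(tags))]
        with open_dict(cfg):  # allow writing back into the config
            cfg.existing_prompt_tags = tags

    # OmegaConf stores the assigned tuples back as ListConfig nodes, which
    # are not hashable, so convert each entry before building the set
    prompt_table = set(tuple(tag) for tag in cfg.existing_prompt_tags)

    # Highest existing id: the starting point for new prompt ids
    next_prompt_id = max(prompt_table, key=lambda x: x[1])[1]

    print(sorted(prompt_table, key=lambda x: x[1]))  # [('boolq', 1), ('sentiment', 2)]
    print(next_prompt_id)                            # 2

Note the ids start at 1 and follow list order, so a config written before this change (plain strings) and one written after (tag, id pairs) both yield the same `prompt_table`.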