
Commit 256236f

Authored by vadam5 and ericharper
Prompt tuning bug fix (#3780)
* Making updated code backwards compatible with previous prompt tuned models
* Fixed backward compatibility bug
* Removed random import

Signed-off-by: Virginia Adams <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
1 parent afba754 · commit 256236f

File tree

1 file changed: +15 additions, -4 deletions


nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py

Lines changed: 15 additions & 4 deletions
@@ -139,10 +139,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
         )
 
-        # TODO: Not sure how to use lists of modules with PTL.
-        # This means we can only use pipeline parallelism without the interleaved schedule.
-        self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0]
-
         # Prompt tuning initialization
         self.use_soft_prompts = self.cfg.get('use_soft_prompts', False)
 
@@ -156,12 +152,27 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             self.num_prompt_tokens = cfg.get('num_prompt_tokens', 100)
 
             if self.cfg.get('existing_prompt_tags', None):
+                # Assign prompt tag ids if none were present in the config
+                if type(self.cfg.existing_prompt_tags[0]) == str:
+                    existing_prompt_tags = self.cfg.existing_prompt_tags
+                    num_prompt_tags = len(existing_prompt_tags)
+                    existing_prompt_tags = [
+                        (existing_prompt_tags[tag_id], tag_id + 1) for tag_id in range(num_prompt_tags)
+                    ]
+
+                    with open_dict(self.cfg):
+                        self.cfg.existing_prompt_tags = existing_prompt_tags
+
                 # Fill table with prev tuned prompt tags and their ids
                 self.prompt_table = set(self.cfg.existing_prompt_tags)
 
                 # Get max prompt id from table for starting point of new prompt ids
                 self.next_prompt_id = max(self.prompt_table, key=lambda x: x[1])[1]
 
+        # TODO: Not sure how to use lists of modules with PTL.
+        # This means we can only use pipeline parallelism without the interleaved schedule.
+        self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0]
+
         self.setup_optimizer_param_groups()
 
         self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False)
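
For context, the core of this fix is the string-to-tuple conversion: older prompt-tuned checkpoints saved existing_prompt_tags as a plain list of tag strings, while the updated code expects (tag, id) tuples. Below is a minimal standalone sketch of that conversion logic; the tag names are hypothetical, a plain list stands in for the OmegaConf config, and the open_dict write-back is omitted since there is no config object here.

# Minimal sketch of the backward-compatibility conversion above.
# Hypothetical old-style config value: a plain list of tag strings.
existing_prompt_tags = ["sentiment-task", "qa-task", "ner-task"]

# Assign prompt tag ids if none were present in the config (ids start at 1)
if isinstance(existing_prompt_tags[0], str):
    existing_prompt_tags = [
        (tag, tag_id + 1) for tag_id, tag in enumerate(existing_prompt_tags)
    ]

# Fill the table with previously tuned prompt tags and their ids
prompt_table = set(existing_prompt_tags)

# The max existing id is the starting point for new prompt ids
next_prompt_id = max(prompt_table, key=lambda x: x[1])[1]

print(sorted(prompt_table, key=lambda x: x[1]))
# [('sentiment-task', 1), ('qa-task', 2), ('ner-task', 3)]
print(next_prompt_id)  # 3

Note that the diff also moves the build_model call below the prompt-table setup, presumably so the soft-prompt bookkeeping is initialized before the model is constructed.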
