Commit f82b896 (1 parent: a7d4fd1)

x

Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>

6 files changed: 2 additions & 101 deletions

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 0 additions & 13 deletions
@@ -18,11 +18,7 @@
 import warnings
 from collections import OrderedDict
 from contextlib import nullcontext
-<<<<<<< HEAD
 from typing import ClassVar, ContextManager, Unpack
-=======
-from typing import ClassVar, Unpack
->>>>>>> 4067915d (adds llama3 MXFP8 NVFP4)
 
 import torch
 import torch.nn as nn
@@ -346,14 +342,6 @@ def forward(
         if te_rope_emb.dtype != torch.float32:
             warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
 
-<<<<<<< HEAD
-        with self.get_autocast_context(None, outer=True):
-            for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
-                if output_hidden_states:
-                    all_hidden_states = (*all_hidden_states, hidden_states)
-
-                with self.get_autocast_context(layer_idx):
-=======
         # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
         # by get_layer_autocast(), which nests inside this context.
         with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
@@ -362,7 +350,6 @@ def forward(
                 all_hidden_states = (*all_hidden_states, hidden_states)
 
                 with self.get_layer_autocast(layer_number):
->>>>>>> 4067915d (adds llama3 MXFP8 NVFP4)
                     hidden_states = decoder_layer(
                         hidden_states,
                         attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,

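The kept comment describes the quantization nesting: an outer FP8 autocast wraps the decoder stack, and a per-layer context can override it. Below is a minimal sketch of that pattern, not the recipe's actual code; `te.autocast(enabled=..., recipe=...)` follows the call in the hunk above, while `get_layer_autocast` and `layer_overrides` are simplified stand-ins.

```python
# Sketch of the nested-autocast pattern described in the diff comment.
# Assumes a TransformerEngine version whose `autocast` takes `enabled` and
# `recipe` (matching the hunk above), and a CUDA GPU with FP8 support.
from contextlib import nullcontext

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
layer_overrides = {}  # illustrative: layer_number -> alternate recipe (e.g. NVFP4)

def get_layer_autocast(recipe):
    # Per-layer override nests inside the outer context; None falls through.
    return te.autocast(enabled=True, recipe=recipe) if recipe is not None else nullcontext()

layers = [te.Linear(256, 256, params_dtype=torch.bfloat16, device="cuda") for _ in range(4)]
hidden_states = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)

# Outer context enables FP8 compute for the whole stack.
with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
    for layer_number, decoder_layer in enumerate(layers):
        with get_layer_autocast(layer_overrides.get(layer_number)):
            hidden_states = decoder_layer(hidden_states)
```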
bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml

Lines changed: 2 additions & 6 deletions
@@ -41,12 +41,8 @@ fp8_config:
   fp8_recipe: transformer_engine.common.recipe.DelayedScaling
   fp8_format: "HYBRID"
   fp8_recipe_kwargs: {}
-
-fp4_config:
-  enabled: false
-  fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling
-  fp4_format: "E2M1"
-  fp4_recipe_kwargs: {}
+  quantized_model_init_kwargs:
+    enabled: false # If this is set to true, fp8_config.enabled must also be set to true.
 
 fp4_config:
   enabled: false

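At runtime, the `fp8_recipe` dotted path and `fp8_format` string in this config resolve to a TransformerEngine recipe object, following the construction pattern visible in the deleted `train_*.py` lines below. A self-contained sketch, with the config values copied from this file:

```python
# How fp8_config resolves to a recipe object; the construction pattern is
# taken from the deleted train_*.py lines in this commit.
import hydra
from omegaconf import OmegaConf
from transformer_engine.common.recipe import Format

args = OmegaConf.create({
    "fp8_config": {
        "enabled": True,
        "fp8_recipe": "transformer_engine.common.recipe.DelayedScaling",
        "fp8_format": "HYBRID",
        "fp8_recipe_kwargs": {},
    }
})

fp8_recipe = None
if args.fp8_config.enabled:
    # get_class resolves the dotted path; Format["HYBRID"] is an enum lookup by name.
    fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
        fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
    )
print(fp8_recipe)
```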
bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 0 additions & 13 deletions
@@ -18,11 +18,7 @@
 import warnings
 from collections import OrderedDict
 from contextlib import nullcontext
-<<<<<<< HEAD
 from typing import ClassVar, ContextManager, Unpack
-=======
-from typing import ClassVar, Unpack
->>>>>>> 4067915d (adds llama3 MXFP8 NVFP4)
 
 import torch
 import torch.nn as nn
@@ -346,14 +342,6 @@ def forward(
         if te_rope_emb.dtype != torch.float32:
             warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
 
-<<<<<<< HEAD
-        with self.get_autocast_context(None, outer=True):
-            for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
-                if output_hidden_states:
-                    all_hidden_states = (*all_hidden_states, hidden_states)
-
-                with self.get_autocast_context(layer_idx):
-=======
         # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
         # by get_layer_autocast(), which nests inside this context.
         with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
@@ -362,7 +350,6 @@ def forward(
                 all_hidden_states = (*all_hidden_states, hidden_states)
 
                 with self.get_layer_autocast(layer_number):
->>>>>>> 4067915d (adds llama3 MXFP8 NVFP4)
                     hidden_states = decoder_layer(
                         hidden_states,
                         attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,

bionemo-recipes/recipes/llama3_native_te/train_ddp.py

Lines changed: 0 additions & 22 deletions
@@ -72,27 +72,6 @@ def main(args: DictConfig) -> float | None:
     # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2.
     device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
 
-<<<<<<< HEAD
-    # --- Model Configuration ---
-    # Create quantization recipes -- only used if FP8/FP4 is enabled in the config.
-    fp8_recipe = None
-    if args.fp8_config.enabled:
-        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
-            fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
-        )
-
-    fp4_recipe = None
-    if args.fp4_config.enabled:
-        fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs)
-
-    # --- Model Initialization ---
-    if args.use_te:
-        config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
-        model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
-    else:
-        config = LlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
-        model = LlamaForCausalLM(config)
-=======
     if args.use_te:
         config_class = NVLlamaConfig
         model_class = NVLlamaForCausalLM
@@ -141,7 +120,6 @@ def main(args: DictConfig) -> float | None:
         recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
     ):
         model = model_class(config)
->>>>>>> 4067915d (adds llama3 MXFP8 NVFP4)
 
     logger.info("Initialized Model:\n%s", model)

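The surviving branch builds the recipe once and constructs the model inside a quantized-init context manager whose name sits above the hunk and is not shown; only its `recipe=...` and `**quantized_model_init_kwargs` arguments are visible. A sketch assuming TE's `fp8_model_init`, which accepts those same arguments:

```python
# Sketch of the unified init path. Assumption: the context manager hidden
# above the hunk is te.fp8_model_init; the diff shows only its arguments.
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
quantized_model_init_kwargs = {"enabled": False}  # mirrors defaults.yaml above

# With enabled=True, weights are allocated directly in the quantized format
# instead of being quantized from high precision on the first forward pass.
with te.fp8_model_init(recipe=fp8_recipe, **quantized_model_init_kwargs):
    model = te.Linear(256, 256)  # stand-in for model_class(config)
```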
bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py

Lines changed: 0 additions & 24 deletions
@@ -74,29 +74,6 @@ def main(args: DictConfig) -> float | None:
 
     device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
 
-<<<<<<< HEAD
-    # --- Model Configuration ---
-    # Create quantization recipes -- only used if FP8/FP4 is enabled in the config.
-    fp8_recipe = None
-    if args.fp8_config.enabled:
-        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
-            fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
-        )
-
-    fp4_recipe = None
-    if args.fp4_config.enabled:
-        fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs)
-
-    # --- Model Initialization ---
-    if args.use_te:
-        config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
-        with torch.device("meta") if args.use_meta_device else nullcontext():
-            model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
-    else:
-        config = LlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
-        with torch.device("meta") if args.use_meta_device else nullcontext():
-            model = LlamaForCausalLM(config)
-=======
     if args.use_te:
         config_class = NVLlamaConfig
         model_class = NVLlamaForCausalLM
@@ -152,7 +129,6 @@ def main(args: DictConfig) -> float | None:
         ),
     ):
         model = model_class(config)
->>>>>>> 4067915d (adds llama3 MXFP8 NVFP4)
 
     logger.info("Initialized Model:\n%s", model)

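This trainer keeps the meta-device guard, `torch.device("meta") if args.use_meta_device else nullcontext()`, so parameters are allocated shape-only and FSDP2 can decide sharding before real GPU memory is committed. A self-contained illustration, with a plain `nn.Linear` standing in for `model_class(config)`:

```python
# Meta-device init: parameters get shape and dtype but no storage.
from contextlib import nullcontext

import torch
import torch.nn as nn

use_meta_device = True
with torch.device("meta") if use_meta_device else nullcontext():
    model = nn.Linear(4096, 4096)  # stand-in for model_class(config)

print(model.weight.device)  # prints "meta"; materialize later with to_empty()
```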
bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py

Lines changed: 0 additions & 23 deletions
@@ -81,33 +81,11 @@ def main(args: DictConfig) -> float | None:
     logger.info("Created device mesh: %s", device_mesh)
 
     # --- Model Configuration ---
-<<<<<<< HEAD
-<<<<<<< HEAD
-    # Create quantization recipes -- only used if FP8/FP4 is enabled in the config.
-    fp8_recipe = None
-    if args.fp8_config.enabled:
-        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
-            fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
-        )
-
-    fp4_recipe = None
-    if args.fp4_config.enabled:
-        fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs)
-
-    # --- Model Initialization ---
-    config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
-
-    with torch.device("meta") if args.use_meta_device else nullcontext():
-        model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
-=======
-    config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
-=======
     config = NVLlamaConfig.from_pretrained(
         args.config_name_or_path,
         dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16,
         **args.config_kwargs,
     )
->>>>>>> 80e4897e (fixed quant stats init and adds fp32 master weights)
 
     # Resolve layer-wise quantization assignments and store on config.
     layer_precision = resolve_layer_precision(
@@ -150,7 +128,6 @@ def main(args: DictConfig) -> float | None:
         ),
     ):
         model = NVLlamaForCausalLM(config)
->>>>>>> 4067915d (adds llama3 MXFP8 NVFP4)
 
     logger.info("Initialized Model:\n%s", model)

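`resolve_layer_precision` is new in this file and the hunk cuts off before its arguments, so its real signature is not shown. A purely hypothetical sketch of what "resolve layer-wise quantization assignments" could mean; every name below is illustrative:

```python
# Hypothetical reading of resolve_layer_precision: map each decoder layer
# index to a precision tag that get_layer_autocast() can later translate
# into a recipe. All names here are illustrative, not the recipe's API.
def resolve_layer_precision(num_hidden_layers: int, overrides: dict[int, str]) -> list[str]:
    precisions = ["fp8"] * num_hidden_layers  # default to the outer recipe
    for layer_number, precision in overrides.items():
        precisions[layer_number] = precision  # e.g. keep first/last layers in bf16
    return precisions

print(resolve_layer_precision(4, {0: "bf16", 3: "bf16"}))
# ['bf16', 'fp8', 'fp8', 'bf16']
```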