4 changes: 2 additions & 2 deletions bionemo-recipes/recipes/evo2_megatron/pyproject.toml
@@ -95,8 +95,8 @@ bionemo-core = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch
nvidia-resiliency-ext = { git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git", rev = "54f85fe422d296cf04ea524130014bd3a2c3add1" } # pragma: allowlist secret

# Megatron Bundle. This points to a version that still supports the deprecated no_weight_decay_cond field until the API for an alternative has been finalized.
megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "18ef1b61309dd45bc0535fb7c60064b9d8829a35" } # pragma: allowlist secret
megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "18ef1b61309dd45bc0535fb7c60064b9d8829a35", subdirectory = "3rdparty/Megatron-LM" } # pragma: allowlist secret
megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "c415f4616340a0431c0eae776d0482ab5cc3770e" } # pragma: allowlist secret
megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "c415f4616340a0431c0eae776d0482ab5cc3770e", subdirectory = "3rdparty/Megatron-LM" } # pragma: allowlist secret

[tool.uv.extra-build-dependencies]
warp-lang = ["wheel_stub"]
@@ -25,7 +25,11 @@
import torch
from megatron.bridge.models.model_provider import ModelProviderMixin
from megatron.bridge.models.transformer_config import TransformerConfig
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.config import (
ConfigContainer,
OptimizerConfigOverrideProvider,
OptimizerConfigOverrideProviderContext,
)
from megatron.bridge.training.gpt_step import get_batch_from_iterator
from megatron.bridge.training.losses import masked_next_token_loss
from megatron.bridge.training.state import GlobalState
@@ -34,24 +38,21 @@
from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size
from megatron.core import parallel_state
from megatron.core.inference.contexts import StaticInferenceContext
from megatron.core.optimizer import (
ParamGroupOverride,
ParamKey,
ParamPredicate,
)
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
from megatron.core.transformer.enums import AttnBackend
from megatron.core.utils import get_batch_on_this_cp_rank, get_model_config

from bionemo.evo2.models.megatron.hyena.hyena_config import HyenaConfig as _HyenaConfigForFlops

# from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step # FIXME do megatron bridge thing instead of this
from bionemo.evo2.models.megatron.hyena.hyena_layer_specs import get_hyena_stack_spec
from bionemo.evo2.models.megatron.hyena.hyena_model import HyenaModel as MCoreHyenaModel
from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond


# from nemo.lightning import get_vocab_size, io, teardown
# from nemo.lightning.base import NEMO_MODELS_CACHE
# from nemo.lightning.io.state import TransformFns
# from nemo.utils import logging


def get_vocab_size(*args, **kwargs):
raise NotImplementedError("FIXME get_vocab_size is not implemented Find it in megatron bridge")

@@ -60,7 +61,47 @@ def gpt_data_step(*args, **kwargs):
raise NotImplementedError("FIXME gpt_data_step is not implemented Find it in megatron bridge")


# FIXME convert the nemo style configs to megatron bridge style configs
@dataclass
class HyenaOptimizerConfigOverrideProvider(OptimizerConfigOverrideProvider):
"""Hyena-specific optimizer config override provider."""

no_weight_decay_embeddings: bool = False

def build_config_overrides(
self, context: OptimizerConfigOverrideProviderContext
) -> dict[ParamKey, ParamGroupOverride] | None:
"""Build config overrides for weight decay based on scheduler configuration.

This function creates parameter-specific overrides for weight decay behavior.
By default, weight decay is skipped for bias parameters and 1D parameters.
For Qwen3-Next models, weight decay is applied to q_layernorm and k_layernorm.
"""
optimizer_config = context.optimizer_config
config_overrides: dict[ParamKey, ParamGroupOverride] = {}
param_length_1_match = ParamPredicate(name="param_len_1", fn=lambda param: len(param.shape) == 1)
name_tuple: tuple[str, ...] = (
"*.bias",
"*.filter.p",
"*.filter.R",
"*.filter.gamma",
"*.short_conv.short_conv_weight",
)
if self.no_weight_decay_embeddings:
name_tuple += ("*embedding*",)
param_wd_mult_key = ParamKey(
name=name_tuple, # type: ignore
predicate=param_length_1_match,
)

config_overrides[param_wd_mult_key] = ParamGroupOverride(wd_mult=0.0) # type: ignore

if optimizer_config.decoupled_lr is not None:
decoupled_lr_config: ParamGroupOverride = {"max_lr": optimizer_config.decoupled_lr}
decoupled_param_key = ParamKey(attr="is_embedding_or_output_parameter")
if optimizer_config.decoupled_min_lr is not None:
decoupled_lr_config["min_lr"] = optimizer_config.decoupled_min_lr
config_overrides[decoupled_param_key] = decoupled_lr_config
return config_overrides
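
For orientation, a minimal sketch of how this provider is consumed (it mirrors the updated unit test further down in this PR; the model, optimizer_config, and scheduler_config objects here are stand-ins, and _get_param_groups is the megatron-core helper exactly as used in that test):

    from megatron.core.optimizer import _get_param_groups  # as imported in the unit test

    provider = HyenaOptimizerConfigOverrideProvider(no_weight_decay_embeddings=False)
    overrides = provider.build_config_overrides(
        context=OptimizerConfigOverrideProviderContext(
            model=model,                        # stand-in: any Megatron model chunk
            optimizer_config=optimizer_config,  # stand-in: an OptimizerConfig
            scheduler_config=scheduler_config,  # stand-in: a SchedulerConfig
        )
    )
    # The overrides are then handed to the optimizer's param-group construction.
    param_groups = _get_param_groups(
        model_chunks=[model],
        config=optimizer_config,
        config_overrides=overrides,
    )
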


class HyenaInferenceContext(StaticInferenceContext):
@@ -75,103 +116,6 @@ def reset(self):
delattr(self, key)


# FIXME convert this to the megatron bridge style config for inference.
# class HyenaModel(GPTModel):
# """This is a wrapper around the MCoreHyenaModel to allow for inference.

# Our model follows the same API as the GPTModel, but the megatron model class is different so we need to handle the inference wrapper slightly differently.
# """

# def get_inference_wrapper(
# self, params_dtype, inference_batch_times_seqlen_threshold, inference_max_seq_length=None
# ) -> torch.Tensor:
# """Gets the inference wrapper for the Hyena model.

# Args:
# params_dtype: The data type for model parameters
# inference_batch_times_seqlen_threshold: Threshold for batch size * sequence length during inference
# inference_max_seq_length: Maximum sequence length for inference

# Returns:
# GPTInferenceWrapper: The inference wrapper for the model

# Raises:
# ValueError: If MCoreHyenaModel instance not found or vocab size cannot be determined
# """
# # This is to get the MCore model required in GPTInferenceWrapper.
# mcore_model = self.module
# while mcore_model:
# if type(mcore_model) is MCoreHyenaModel:
# break
# mcore_model = getattr(mcore_model, "module", None)
# if mcore_model is None or type(mcore_model) is not MCoreHyenaModel:
# raise ValueError("Exact MCoreHyenaModel instance not found in the model structure.")

# vocab_size = None
# if self.tokenizer is not None:
# vocab_size = self.tokenizer.vocab_size
# elif hasattr(self.config, "vocab_size"):
# vocab_size = self.config.vocab_size
# else:
# raise ValueError(
# "Unable to find vocab size."
# " Either pass in a tokenizer with vocab size, or set vocab size in the model config"
# )

# inference_wrapper_config = InferenceWrapperConfig(
# hidden_size=mcore_model.config.hidden_size,
# params_dtype=params_dtype,
# inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
# padded_vocab_size=vocab_size,
# inference_max_seq_length=inference_max_seq_length,
# inference_max_requests=1,
# )

# inference_context = HyenaInferenceContext.from_config(inference_wrapper_config)
# model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config, inference_context)
# return model_inference_wrapper

# def forward(
# self,
# input_ids: torch.Tensor,
# position_ids: torch.Tensor,
# attention_mask: Optional[torch.Tensor] = None,
# labels: Optional[torch.Tensor] = None,
# decoder_input: Optional[torch.Tensor] = None,
# loss_mask: Optional[torch.Tensor] = None,
# inference_context=None,
# packed_seq_params=None,
# ) -> torch.Tensor:
# """Forward pass of the Hyena model.

# Args:
# input_ids: Input token IDs
# position_ids: Position IDs for input tokens
# attention_mask: Optional attention mask
# labels: Optional labels for loss computation
# decoder_input: Optional decoder input
# loss_mask: Optional loss mask
# inference_context: Optional inference parameters
# packed_seq_params: Optional parameters for packed sequences


# Returns:
# torch.Tensor: Output tensor from the model
# """
# extra_kwargs = {"packed_seq_params": packed_seq_params} if packed_seq_params is not None else {}
# output_tensor = self.module(
# input_ids,
# position_ids,
# attention_mask,
# decoder_input=decoder_input,
# labels=labels,
# inference_context=inference_context,
# loss_mask=loss_mask,
# **extra_kwargs,
# )
# return output_tensor


def get_batch(
data_iterator: Iterable, cfg: ConfigContainer, use_mtp: bool = False, *, pg_collection
) -> tuple[
@@ -39,8 +39,8 @@
from bionemo.evo2.models.evo2_provider import (
Hyena1bModelProvider,
HyenaModelProvider,
HyenaOptimizerConfigOverrideProvider,
)
from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond


class Evo2CommonKwargs(TypedDict, total=False):
@@ -87,7 +87,8 @@ class Evo2CommonKwargs(TypedDict, total=False):
# Precision / overlap configs
precision_config: MixedPrecisionConfig | str | None
comm_overlap_config: CommOverlapConfig | None
pad_eod_loss_mask: bool = False
pad_eod_loss_mask: bool
no_weight_decay_embeddings: bool


def evo2_1b_pretrain_config(**user_kwargs: Unpack[Evo2CommonKwargs]) -> ConfigContainer:
@@ -107,6 +108,7 @@ def evo2_1b_pretrain_config(**user_kwargs: Unpack[Evo2CommonKwargs]) -> ConfigContainer:
"pipeline_model_parallel_size": 1,
"sequence_parallel": False,
"precision_config": "bf16_mixed",
"no_weight_decay_embeddings": False,
}
kwargs: Evo2CommonKwargs = {**recommended, **user_kwargs}
return _evo2_common(**kwargs)
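
As a usage sketch (hypothetical call; it assumes the remaining recipe arguments have workable defaults and that the attribute name on ConfigContainer matches the keyword passed in _evo2_common below):

    cfg = evo2_1b_pretrain_config(no_weight_decay_embeddings=True)
    # The returned container carries the Hyena-specific override provider, which the
    # training loop uses when building optimizer parameter groups.
    assert isinstance(cfg.optimizer_config_override_provider, HyenaOptimizerConfigOverrideProvider)
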
@@ -155,6 +157,7 @@ def _evo2_common(
# TODO fp8 etc
precision_config: MixedPrecisionConfig | str | None = "bf16_mixed",
comm_overlap_config: CommOverlapConfig | None = None,
no_weight_decay_embeddings: bool = False,
pad_eod_loss_mask: bool = False,
) -> ConfigContainer:
"""Create a pre-training configuration for Mamba 2.x models.
@@ -229,7 +232,6 @@
max_lr=lr,
min_lr=min_lr,
)
scheduler.no_weight_decay_cond_type = hyena_no_weight_decay_cond

cfg = ConfigContainer(
model=model_cfg,
@@ -241,6 +243,9 @@
micro_batch_size=micro_batch_size,
),
optimizer=opt_config,
optimizer_config_override_provider=HyenaOptimizerConfigOverrideProvider(
no_weight_decay_embeddings=no_weight_decay_embeddings,
),
scheduler=scheduler,
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
@@ -38,7 +38,6 @@

from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH
from bionemo.evo2.models.evo2_provider import HYENA_MODEL_OPTIONS, hyena_forward_step
from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond_with_embeddings
from bionemo.evo2.recipes.evo2 import evo2_1b_pretrain_config as pretrain_config


@@ -759,6 +758,9 @@ def train(args: argparse.Namespace) -> None:
recipe_kwargs["dir"] = args.result_dir
recipe_kwargs["name"] = args.experiment_name

if args.no_weight_decay_embeddings:
recipe_kwargs["no_weight_decay_embeddings"] = True

# 2. Generate Base Configuration
cfg: ConfigContainer = pretrain_config(**recipe_kwargs)

@@ -828,9 +830,6 @@ def train(args: argparse.Namespace) -> None:
else:
if args.activation_checkpoint_recompute_num_layers is not None:
cfg.model.recompute_num_layers = args.activation_checkpoint_recompute_num_layers

if args.no_weight_decay_embeddings:
cfg.scheduler.no_weight_decay_cond_type = hyena_no_weight_decay_cond_with_embeddings
# Optimizer
if args.wd is not None:
cfg.optimizer.weight_decay = args.wd
@@ -20,19 +20,10 @@

import pytest
import torch
from megatron.bridge.training.utils.weight_decay_utils import get_no_weight_decay_cond
from megatron.bridge.training.config import OptimizerConfig, OptimizerConfigOverrideProviderContext, SchedulerConfig
from megatron.core.optimizer import _get_param_groups

from bionemo.evo2.models.evo2_provider import HyenaNVTestModelProvider
from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond


def test_no_weight_decay_cond_fn():
"""Verify that the get_no_weight_decay_cond function returns our lambda properly."""
assert (
get_no_weight_decay_cond(hyena_no_weight_decay_cond, default_skip_embedding_weight_decay=False)
is hyena_no_weight_decay_cond
)
from bionemo.evo2.models.evo2_provider import HyenaNVTestModelProvider, HyenaOptimizerConfigOverrideProvider


class _FakePGCollection:
@@ -92,27 +83,35 @@ def test_weight_decay_conditions():
config.finalize()
assert config.init_method is not None
model = config.provide(pre_process=True, post_process=True)
optimizer_config_override_provider = HyenaOptimizerConfigOverrideProvider(
no_weight_decay_embeddings=False,
)
optimizer_config = OptimizerConfig(
optimizer="adam",
lr=1.0,
weight_decay=1.0,
)
scheduler_config = SchedulerConfig(
lr_decay_style="linear",
lr_decay_iters=1000,
lr_decay_samples=1000000,
)
hyena_config_overrides = optimizer_config_override_provider.build_config_overrides(
context=OptimizerConfigOverrideProviderContext(
model=model,
optimizer_config=optimizer_config,
scheduler_config=scheduler_config,
)
)
param_groups = _get_param_groups(
model_chunks=[model],
no_weight_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0,
lr=1.0,
min_lr=0.0,
decoupled_lr=None,
decoupled_min_lr=None,
default_skip_embedding_weight_decay=False,
config=optimizer_config,
config_overrides=None, # default config overrides
)
param_groups2 = _get_param_groups(
model_chunks=[model],
no_weight_decay_cond=hyena_no_weight_decay_cond,
scale_lr_cond=None,
lr_mult=1.0,
lr=1.0,
min_lr=0.0,
decoupled_lr=None,
decoupled_min_lr=None,
default_skip_embedding_weight_decay=False,
config=optimizer_config,
config_overrides=hyena_config_overrides,
)
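# Both setups should produce the same grouping: two parameter groups
# (presumably the decayed and non-decayed sets).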
assert len(param_groups2) == len(param_groups)
assert len(param_groups2) == 2