24 changes: 24 additions & 0 deletions sub-packages/bionemo-evo2/src/bionemo/evo2/__init__.py
@@ -15,3 +15,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable

import torch.nn.functional as F
from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS, HyenaNV1bConfig


@dataclass
class HyenaNV1bConfig2(HyenaNV1bConfig):
    """A parallel-friendly version of HyenaNV1bConfig."""

    hidden_size: int = 2048  # was 1920 in HyenaNV1bConfig
    num_groups_hyena: int = 2048  # was 1920
    num_attention_heads: int = 16  # was 15
    ffn_hidden_size: int = 5120  # unchanged from HyenaNV1bConfig
    # Spike-no-more embedding init by default.
    share_embeddings_and_output_weights: bool = False
    embedding_init_method_std: float = 1.0
    # activation_func_clamp_value: Optional[float] = 7.0
    # glu_linear_offset: float = 1.0


# TODO: move this to a more permanent location.
HYENA_MODEL_OPTIONS["striped_hyena_1b_nv_parallel"] = HyenaNV1bConfig2
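
For reference, a minimal usage sketch (not part of this diff) of how the newly registered option can be pulled from the registry once bionemo.evo2 has been imported; the divisibility checks illustrate why the 2048/16 layout is friendlier to tensor parallelism than the original 1920/15:

import bionemo.evo2  # noqa: F401  # importing bionemo.evo2 registers the new option
from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS

config_cls = HYENA_MODEL_OPTIONS["striped_hyena_1b_nv_parallel"]
# hidden_size (2048) and num_attention_heads (16) divide evenly by common
# tensor-parallel sizes, unlike the original 1920/15 layout.
for tp_size in (2, 4, 8):
    assert config_cls.num_attention_heads % tp_size == 0
    assert config_cls.hidden_size % tp_size == 0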
56 changes: 17 additions & 39 deletions sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py
@@ -24,7 +24,7 @@
# TODO add back support for slurm resilience.
# import nvidia_resiliency_ext.ptl_resiliency as res_module
import torch
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary
from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.enums import Fp8Recipe
from megatron.core.optimizer import OptimizerConfig
@@ -52,8 +52,8 @@
from bionemo.evo2.models.llama import LLAMA_MODEL_OPTIONS
from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel, mamba_no_weight_decay_cond_with_embeddings
from bionemo.evo2.models.peft import Evo2LoRA
from bionemo.evo2.run.utils import infer_model_type, patch_eden_tokenizer
from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime
from bionemo.evo2.run.utils import infer_model_type, lookup_activation_func, patch_eden_tokenizer
from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime, _FirstBatchCudaSync
from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings
from bionemo.evo2.utils.logging.callbacks import TEVCallback
from bionemo.llm.utils.datamodule_utils import infer_global_batch_size
@@ -317,6 +317,14 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
default=False,
help="Add bias to the output layer to enable learning a simple prior.",
)
parser.add_argument(
"--activation-func",
type=str,
default=None,
help="Activation function to use for the FFN layers. See options in "
"https://docs.pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions. The default is "
"inferred from the model config, and is currently 'gelu'. Another good options is 'silu'.",
)
parser.add_argument(
"--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory."
)
@@ -791,6 +799,9 @@ def train(args: argparse.Namespace) -> nl.Trainer:
}
if args.add_bias_output:
config_modifiers_init["add_bias_output"] = args.add_bias_output
if args.activation_func:
# Override the activation function for the FFN layers.
config_modifiers_init["activation_func"] = lookup_activation_func(args.activation_func)
if args.spike_no_more_embedding_init:
config_modifiers_init["embedding_init_method_std"] = 1.0
# When using spike_no_more_embedding_init, we don't want to share embeddings and outputs.
@@ -853,27 +864,6 @@
TEVCallback(),
]

# First batch CUDA sync callback: adds barriers for the first training batch to avoid race condition
# See https://github.com/NVIDIA/bionemo-framework/issues/1301 for more details.
class _FirstBatchCudaSync(Callback):
def __init__(self):
self._done = False

def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
if not self._done and torch.cuda.is_available():
torch.cuda.synchronize()

def on_after_backward(self, trainer, pl_module):
if not self._done and torch.cuda.is_available():
torch.cuda.synchronize()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self._done and torch.cuda.is_available():
torch.cuda.synchronize()
# Unset blocking for subsequent batches
os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
self._done = True

callbacks.append(_FirstBatchCudaSync())

if args.garbage_collect_at_inference:
@@ -969,7 +959,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
f"-EWD{args.no_weight_decay_embeddings}-SNI{args.spike_no_more_embedding_init}"
f"-OGR{args.overlap_grad_reduce}-OPG{args.overlap_param_gather}"
f"-TVL{args.use_targeted_variance_loss}"
f"-NODES{args.num_nodes}-FP8{args.fp8}"
f"-NODES{args.num_nodes}-FP8{args.fp8_recipe}{args.fp8}"
f"-AF{args.activation_func or 'def'}"
)
if model_type == "mamba":
# Include this setting for mamba models.
@@ -1103,15 +1094,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
enable_checkpointing=args.create_checkpoint_callback,
)

# Logger setup
nemo_logger.setup(
trainer,
resume_if_exists=True,
)

if auto_resume is not None:
auto_resume.setup(trainer, model)

# Optimizer and scheduler setup
opt_config = OptimizerConfig(
optimizer="adam",
@@ -1139,12 +1121,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
opt = MegatronOptimizerModule(
opt_config, sched, no_weight_decay_cond=getattr(model_config, "hyena_no_weight_decay_cond_fn", None)
)
opt.connect(model)

# Remove earlier warmup and hook logic; first-batch blocking is sufficient.
llm.train(model, data_module, trainer, log=nemo_logger, resume=auto_resume, optim=opt, tokenizer="data")

# Start training
trainer.fit(model, data_module)
return trainer


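
As a hedged sketch of the new --activation-func path above (the literal name stands in for args.activation_func; the rest mirrors the config_modifiers_init pattern shown in train.py):

import torch.nn.functional as F

from bionemo.evo2.run.utils import lookup_activation_func

activation_func_name = "silu"  # stand-in for args.activation_func
config_modifiers_init = {}
if activation_func_name:
    # Override the activation function for the FFN layers, as train.py does.
    config_modifiers_init["activation_func"] = lookup_activation_func(activation_func_name)
assert config_modifiers_init["activation_func"] is F.silu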
16 changes: 15 additions & 1 deletion sub-packages/bionemo-evo2/src/bionemo/evo2/run/utils.py
@@ -14,14 +14,28 @@
# limitations under the License.
"""Utility functions for Evo2 run functions."""

from typing import Literal
from typing import Callable, Literal

import torch
from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS

from bionemo.evo2.models.llama import LLAMA_MODEL_OPTIONS
from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS


def lookup_activation_func(activation_func_name: str) -> Callable:
    """Look up an activation function in torch.nn.functional by name (case-insensitive)."""
    activation_func = getattr(torch.nn.functional, activation_func_name.lower(), None)
    if activation_func is None:
        raise ValueError(
            f"Invalid activation function: {activation_func_name}. See options in "
            "https://docs.pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions, "
            "and make sure the function is available in the installed torch version. "
            "Recommended options are 'silu', 'gelu', or 'relu'."
        )
    return activation_func


def patch_eden_tokenizer(tokenizer):
    """Patch the Eden tokenizer to work with the Evo2 tokenizer."""
    bos_id, eos_id, sep_id, pad_id = 1, 2, 3, 0
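
A small behavioral sketch of lookup_activation_func (assumes only that torch is installed): the name is lowercased before the torch.nn.functional lookup, and unknown names raise ValueError:

import torch.nn.functional as F

from bionemo.evo2.run.utils import lookup_activation_func

assert lookup_activation_func("SiLU") is F.silu  # lookup is case-insensitive
assert lookup_activation_func("gelu") is F.gelu
try:
    lookup_activation_func("not_an_activation")
except ValueError as err:
    print(err)  # message points at the torch.nn.functional docs and suggests silu/gelu/relu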
24 changes: 24 additions & 0 deletions sub-packages/bionemo-evo2/src/bionemo/evo2/utils/callbacks.py
@@ -14,11 +14,35 @@
# limitations under the License.

import gc
import os

import torch
from lightning.pytorch import Callback


class _FirstBatchCudaSync(Callback):
    """Temporary callback: synchronize CUDA around the first training batch to avoid a race condition.

    Remove once the underlying bug is fixed.
    See https://github.com/NVIDIA/bionemo-framework/issues/1301 for more details.
    """

    def __init__(self):
        self._done = False

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if not self._done and torch.cuda.is_available():
            torch.cuda.synchronize()

    def on_after_backward(self, trainer, pl_module):
        if not self._done and torch.cuda.is_available():
            torch.cuda.synchronize()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self._done and torch.cuda.is_available():
            torch.cuda.synchronize()
            # Unset CUDA_LAUNCH_BLOCKING for subsequent batches.
            os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
            self._done = True


class GarbageCollectAtInferenceTime(Callback):
    """Callback to clean up CUDA memory before validation to prevent initialization errors."""

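
A hedged usage sketch (not the PR's own wiring) of attaching the relocated _FirstBatchCudaSync callback to a Lightning trainer, mirroring the callbacks.append(_FirstBatchCudaSync()) call kept in train.py; the trainer arguments are illustrative:

import lightning.pytorch as pl

from bionemo.evo2.utils.callbacks import _FirstBatchCudaSync

# Serialize CUDA work around the first training batch only; later batches run normally
# once CUDA_LAUNCH_BLOCKING has been popped from the environment.
trainer = pl.Trainer(callbacks=[_FirstBatchCudaSync()], max_steps=10, devices=1)
# trainer.fit(model, datamodule)  # model and datamodule come from the surrounding training script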