Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions nemo_deploy/llm/inference/inference_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
get_default_load_sharded_strategy,
)
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.inference.config import InferenceConfig
from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
from megatron.core.inference.contexts.static_context import StaticInferenceContext
from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
GPTInferenceWrapper,
Expand Down Expand Up @@ -229,18 +232,27 @@ def setup_megatron_model_and_tokenizer_for_inference(
torch_distributed_init(dist_config)
model_config, mlm_args = load_model_config(checkpoint_path)

# MLA models require cache_mla_latents=True for the dynamic inference backend.
# The checkpoint may have saved it as False (training default), but inference
# with the dynamic engine always needs it enabled, so force it on whenever the
# loaded config exposes the field.
if hasattr(model_config, "cache_mla_latents"):
model_config.cache_mla_latents = True

# Convert attention_backend from string to enum if needed
if hasattr(model_config, "attention_backend") and isinstance(model_config.attention_backend, str):
if model_config.attention_backend == "AttnBackend.fused":
model_config.attention_backend = AttnBackend.fused
elif model_config.attention_backend == "AttnBackend.flash":
model_config.attention_backend = AttnBackend.flash
elif model_config.attention_backend == "AttnBackend.unfused":
model_config.attention_backend = AttnBackend.unfused
elif model_config.attention_backend == "AttnBackend.local":
model_config.attention_backend = AttnBackend.local
elif model_config.attention_backend == "AttnBackend.auto":
if hasattr(model_config, "attention_backend"):
if model_config.attention_backend is None:
# Deserialization of the AttnBackend enum failed (e.g. Hydra _target_ dict
# not reconstructed); fall back to auto so the engine can pick the best backend.
model_config.attention_backend = AttnBackend.auto
elif isinstance(model_config.attention_backend, str):
_str_to_attn_backend = {
"AttnBackend.fused": AttnBackend.fused,
"AttnBackend.flash": AttnBackend.flash,
"AttnBackend.unfused": AttnBackend.unfused,
"AttnBackend.local": AttnBackend.local,
"AttnBackend.auto": AttnBackend.auto,
}
model_config.attention_backend = _str_to_attn_backend.get(model_config.attention_backend, AttnBackend.auto)

if tensor_model_parallel_size is not None:
model_config.tensor_model_parallel_size = tensor_model_parallel_size
Expand Down Expand Up @@ -524,6 +536,35 @@ def create_mcore_engine(
buffer_size_gb=buffer_size_gb,
)

# MCoreEngine (StaticInferenceEngine) initialises its DynamicInferenceContext with
# block_size_tokens=256. MLA models require block_size_tokens=64 (Flash MLA), so the
# init silently fails and the engine falls back to legacy static batching — which is
# incompatible with cache_mla_latents=True. Detect that fallback and redo the dynamic
# engine setup with the correct block size.
if getattr(model.config, "cache_mla_latents", False) and mcore_engine.legacy:
LOGGER.info(
"MCoreEngine fell back to legacy static engine for MLA model; "
"reinitialising DynamicInferenceEngine with block_size_tokens=64."
)
dynamic_context = DynamicInferenceContext(
model_config=model.config,
inference_config=InferenceConfig(
max_sequence_length=inference_max_seq_length,
buffer_size_gb=buffer_size_gb,
max_requests=max_batch_size,
num_cuda_graphs=1,
block_size_tokens=64, # Flash MLA requirement
unified_memory_level=0,
),
)
mcore_engine.controller.inference_wrapped_model.inference_context = dynamic_context
mcore_engine.controller.inference_wrapped_model.prep_model_for_inference()
mcore_engine.controller._init_dynamic_sampling_tensors()
mcore_engine.dynamic_engine = DynamicInferenceEngine(
controller=mcore_engine.controller, context=dynamic_context
)
mcore_engine.legacy = False

# Wrap the engine to ensure cleanup
wrapped_engine = MCoreEngineWithCleanup(mcore_engine, model_inference_wrapper, tokenizer)

Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,14 @@ override-dependencies = [
"flash-linear-attention>=0.3.0,<0.4.dev0",
"patchelf; sys_platform=='never'",
"nvidia-resiliency-ext>=0.3.0,<0.6.0",
"transformer-engine[pytorch,core_cu13]>=2.12.0a0,<2.15.0; sys_platform != 'darwin'",
"transformer-engine-cu13>=2.12.0a0,<2.15.0; sys_platform != 'darwin'",
"transformer-engine-cu12; sys_platform == 'never'",
# The custom-built Transformer Engine (TE) in the container already includes the torch
# extension natively. Installing transformer-engine-torch from PyPI creates a dist-info
# that triggers TE's sanity check requiring the base package to also be a PyPI wheel,
# which fails for source/custom builds. Since the .so is already present, skip the PyPI
# package by pinning it to the never-true `sys_platform == 'never'` marker.
"transformer-engine-torch; sys_platform == 'never'",
"mamba-ssm>=2.3.0,<2.4.0",
"transformers>=5.0.0",
"transformers==5.2.0",
"protobuf~=6.33.5",
"opencv-python-headless; sys_platform == 'never'",
"cryptography>=43.0.0,<47",
Expand Down
Loading
Loading