The provided qkv memory layout is not supported! When using RoPE #681
Comments
I solved the problem by hard-coding ...
Hi @1049451037, could you provide more details about how to run the job, please? Currently I can start the job but am stuck at ...
Could you send me your complete ...
I don't see any problem in your log. You are running the server; you just need to start a client, e.g. with https://github.com/NVIDIA/Megatron-LM/blob/main/tools/text_generation_cli.py
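For reference, a client request can be as simple as the sketch below, mirroring what text_generation_cli.py sends; the localhost:5000 address is an assumption for a default local launch.

import json
import requests

# Assumed default host/port; the server exposes a PUT endpoint at /api.
url = "http://localhost:5000/api"
headers = {"Content-Type": "application/json"}
payload = {"prompts": ["The quick brown fox"], "tokens_to_generate": 32}

response = requests.put(url, data=json.dumps(payload), headers=headers)
print(response.json())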
Same issue. Any solutions?
I tested with TE main (8255f87) and Megatron-LM main (8957468). I'm not seeing the issue above. Let me know if I'm not using the same run script as you have. Thanks.
You don't hit the problem if you just run the example, because the example inference does not use the MCore model. It just uses the legacy model, as you can see in model_provider.
@cyanguwa You may replace the model provider in the text generation server with the following to reproduce the error:

from megatron import get_args, print_rank_0  # added: used below (import location may differ across Megatron-LM versions)
from megatron.core.models.gpt import GPTModel
import megatron.model
from megatron.training import get_model
from megatron.arguments import core_transformer_config_from_args
from megatron.text_generation_server import MegatronServer
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
import torch
from megatron.core.transformer.spec_utils import import_module
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec


def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    args = get_args()
    print_rank_0('building GPT model ...')
    config = core_transformer_config_from_args(args)

    if args.use_mcore_models:
        # MCore code path: this is the one that triggers the qkv layout error with RoPE.
        print("building megatron core model!!!!!!!!!!!!!!")
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                args.num_experts, args.moe_grouped_gemm)

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=False,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
        )
    else:
        # Legacy code path from the original example, kept only for reference.
        print("building megatron legacy model!!!!!!!!!!!!!!")
        assert False, "Never do this!"
        assert args.context_parallel_size == 1, "Context parallelism is only supported with Megatron Core!"
        model = megatron.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=False,
            pre_process=pre_process,
            post_process=post_process
        )

    return model
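For completeness, the server script then builds and loads the model roughly as sketched below, based on tools/run_text_generation_server.py; the exact calls may differ between Megatron-LM versions.

from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.training import get_model

# Build the model chunks via the provider above without DDP wrapping,
# then load weights from --load if given.
args = get_args()
model = get_model(model_provider, wrap_with_ddp=False)
if args.load is not None:
    _ = load_checkpoint(model, None, None)
model = model[0]  # get_model returns a list; text generation uses a single chunk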
We're aware of this bug and will push a fix to MCore.

# In PyTorch, the following two tensors are in fact the same:
#   - a tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
#   - a tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
# We unify them to the first one to pass the stride check in TE.
if value.shape == key.shape and value.stride() != key.stride():
    value = value.as_strided(value.shape, key.stride())
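To illustrate the idea, here is a small standalone sketch; the toy shapes and the key/value names are assumptions for demonstration, not the actual MCore patch.

import torch

S, H, D = 4, 2, 8  # assumed toy sizes: sequence length, heads, head dim

# A contiguous (1, S, H, D) tensor has stride (S*H*D, H*D, D, 1); a view with
# stride (H*D, H*D, D, 1) addresses exactly the same elements, because the
# leading dimension has size 1.
key = torch.randn(1, S, H, D)
value = key.clone().as_strided(key.shape, (H * D, H * D, D, 1))

print(key.stride(), value.stride())  # differ only in the first entry

# Unifying the strides, as in the proposed fix, restores matching K/V layouts.
if value.shape == key.shape and value.stride() != key.stride():
    value = value.as_strided(value.shape, key.stride())
assert value.stride() == key.stride()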
No, this won't fix the bug. It makes inference work, but it breaks training (the training loss does not converge).
My solution for now is just adding the ...
Maybe it is the ...
I believe the MCore issue is fixed now, is that correct @yaox12? Can we close this issue?
Yes, the issue is fixed in the latest main branch of Megatron-LM. |
I think this problem was solved before, but now it occurs again.
How can it be solved? I have tested both the stable branch and the main branch; neither of them works.
#544
#455
I just ran the official text generation example of Megatron-LM, adding the --position-embedding-type rope and --no-position-embedding args to https://github.com/NVIDIA/Megatron-LM/blob/main/examples/run_text_generation_server_345M.sh, and got the error:
The provided qkv memory layout is not supported!
Moreover, I use the MCore model instead of the legacy model, so you need to change the model provider in text_generation_server.py to reproduce the error.