Llama2-70b - remat - MLP Layer doesn't recompute mlpwo #1229

Open

sopiko99 opened this issue Feb 3, 2025 · 1 comment
sopiko99 commented Feb 3, 2025

Hi,
I've been trying to find out why there is no recomputation in the backward pass for the MLP block's mlpwo projection. I have tested different remat policies, but mlpwo doesn't appear in the profiler traces. I've also tried explicitly defining a custom policy with the key mlpwo: 'remat', but still no luck. I checked the XLA dumps as well: the mlpwo remat wasn't present, although others, wi_0 and wi_1, were there.
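
For reference, the custom-policy attempt looked roughly like this (a sketch only: mlpwo is the key I actually set; the other tensor names are assumptions based on MaxText's base.yml options for remat_policy: 'custom' and may differ between versions):

remat_policy: 'custom'
decoder_layer_input: 'device'  # assumed default: kept on device
context: 'remat'
mlpwi: 'remat'
mlpwo: 'remat'   # expected to show up under rematted_computation, but never does
query_proj: 'remat'
key_proj: 'remat'
value_proj: 'remat'
out_proj: 'remat'

The base config I'm running with (here with remat_policy: "full"):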

base_config: "base.yml"
model_name: "llama2-7b"
enable_checkpointing: False
attention: "cudnn_flash_te"

dcn_data_parallelism: -1
dcn_fsdp_parallelism: 1
dcn_pipeline_parallelism: 1
dcn_tensor_parallelism: 1
dcn_sequence_parallelism: 1
ici_fsdp_parallelism: 8
ici_data_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
remat_policy: "full"
param_scan_axis: 1
use_iota_embed: True
scan_layers: True
async_checkpointing: False
logits_dot_in_fp32: False
megablox: False
dtype: "bfloat16"
quantization: ""
quantize_kvcache: False
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8"
weight_dtype: bfloat16
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint
per_device_batch_size: 6 
max_target_length: 8192
learning_rate: 3.e-5

XLA Flags used:
--xla_gpu_enable_cublaslt=False --xla_gpu_graph_level=0 --xla_gpu_autotune_level=5 --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=137438953472 --xla_gpu_enable_all_gather_combine_by_dim=FALSE

Platform: AMD GPU - MI300

  • ROCm 6.2.0
  • JAX 0.4.35
  • Python 3.10.14

Example screenshot of rematted_computation/mlp recomputing only wi_0 and wi_1:

[screenshot]

What would you suggest as the next step in investigating this issue?

Please let me know if more information is needed.

Looking forward to your answer 👾

@gobbleturk
Collaborator

My guess is that the output of wo is derived from other activations (e.g. wo + attention_out = layer_output), so given attention_out and layer_output you can derive wo, and it doesn't need to be recomputed or saved directly.

You can confirm this theory by looking at what memory is used in the "memory view" of the trace. You should see only one activation tensor that is saved in full: the decoder layer inputs (and outputs) of size [layers, batch, sequence, embed]. The wo output and attention_out also have this shape, but attention_out looks like it's being recomputed, and wo is probably being derived as decoder_layer_output - attention_out.
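
To make that arithmetic concrete, here is a minimal JAX sketch (a toy layer with the residual structure described above, not the actual MaxText decoder; names like w_attn, w_i, w_o are illustrative):

import jax
import jax.numpy as jnp

# Toy decoder layer with the usual residual structure (norms omitted):
#   layer_output = attention_out + mlp_out, where attention_out = x + attention(x)
def attention(x, w_attn):
    return x @ w_attn

def mlp(h, w_i, w_o):
    # up-projection (wi) followed by the wo down-projection
    return jax.nn.silu(h @ w_i) @ w_o

def decoder_layer(x, w_attn, w_i, w_o):
    attention_out = x + attention(x, w_attn)   # residual folded in
    mlp_out = mlp(attention_out, w_i, w_o)     # the "mlpwo" output
    return attention_out + mlp_out             # layer_output

key = jax.random.PRNGKey(0)
embed, hidden = 8, 16
x = jax.random.normal(key, (2, 4, embed))        # [batch, sequence, embed]
w_attn = jax.random.normal(key, (embed, embed))
w_i = jax.random.normal(key, (embed, hidden))
w_o = jax.random.normal(key, (hidden, embed))

attention_out = x + attention(x, w_attn)
layer_output = decoder_layer(x, w_attn, w_i, w_o)

# Given layer_output and attention_out, the wo output is just their difference,
# so the backward pass never needs to rerun (or save) the wo matmul itself:
derived_mlp_out = layer_output - attention_out
recomputed_mlp_out = mlp(attention_out, w_i, w_o)
assert jnp.allclose(derived_mlp_out, recomputed_mlp_out, atol=1e-4)

If the wo output falls out of a subtraction of two tensors that are needed anyway, there is nothing for the remat pass to recompute for mlpwo, which would match it being absent from the rematted_computation trace and the XLA dumps.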
