
Llama2-70b - remat - MLP Layer doesn't recompute mlpwo #1229

Closed
@sopiko99

Description


Hi,
I've been trying to find out why there is no recomputation in the backward pass for the MLP block's mlpwo. I have tested different remat policies, but mlpwo doesn't appear in the profiler traces. I also tried explicitly defining a custom policy with the key mlpwo: 'remat', but still no luck. I checked the XLA dumps as well: a remat of mlpwo wasn't present, although the others (wi_0 and wi_1) were there.
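For context (this is not MaxText's actual code), the name-based remat machinery that such policies rely on can be sketched in plain JAX. The MLP shape, the parameter names, and the tag strings "mlpwi_0"/"mlpwi_1"/"mlpwo" below are assumptions chosen to mirror the names in this issue; the point is that an activation only participates in a name-based policy if it is tagged with checkpoint_name, and the policy then decides whether it is saved or recomputed:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Hypothetical MLP mirroring the naming scheme from the issue: each
# intermediate is tagged so a remat policy can match it by name.
def mlp(params, x):
    h0 = checkpoint_name(x @ params["wi_0"], "mlpwi_0")
    h1 = checkpoint_name(x @ params["wi_1"], "mlpwi_1")
    return checkpoint_name((jax.nn.silu(h0) * h1) @ params["wo"], "mlpwo")

# Policy: save only the activation tagged "mlpwo"; the wi_0/wi_1
# intermediates are then recomputed during the backward pass.
policy = jax.checkpoint_policies.save_only_these_names("mlpwo")
mlp_remat = jax.checkpoint(mlp, policy=policy)

d = 4
params = {
    "wi_0": jnp.ones((d, d)),
    "wi_1": jnp.ones((d, d)),
    "wo": jnp.ones((d, d)),
}
x = jnp.ones((2, d))
grads = jax.grad(lambda p: mlp_remat(p, x).sum())(params)
```

One way to debug which residuals actually survive is jax.ad_checkpoint.print_saved_residuals(mlp_remat, params, x), which reports per-tensor whether it is saved or rematerialized; comparing its output against the profiler traces might narrow down whether the policy or a later XLA pass is responsible here.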

base_config: "base.yml"
model_name: "llama2-7b"
enable_checkpointing: False
attention: "cudnn_flash_te"

dcn_data_parallelism: -1
dcn_fsdp_parallelism: 1
dcn_pipeline_parallelism: 1
dcn_tensor_parallelism: 1
dcn_sequence_parallelism: 1
ici_fsdp_parallelism: 8
ici_data_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
remat_policy: "full"
param_scan_axis: 1
use_iota_embed: True
scan_layers: True
async_checkpointing: False
logits_dot_in_fp32: False
megablox: False
dtype: "bfloat16"
quantization: ""
quantize_kvcache: False
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8"
weight_dtype: bfloat16
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint
per_device_batch_size: 6 
max_target_length: 8192
learning_rate: 3.e-5

XLA Flags used:
--xla_gpu_enable_cublaslt=False --xla_gpu_graph_level=0 --xla_gpu_autotune_level=5 --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=137438953472 --xla_gpu_enable_all_gather_combine_by_dim=FALSE

Platform: AMD GPU - MI300

  • ROCm 6.2.0
  • JAX 0.4.35
  • Python 3.10.14

Example screenshot of rematted_computation/mlp recomputing only wi_0 and wi_1:

[screenshot]

What would you suggest as the next step of investigating this issue?

Please let me know if more information is needed.

Looking forward to your answer 👾
