Hi,

I've been trying to find out why there is no recomputation in the backward pass for the MLP block mlpwo. I have tested different remat policies, but mlpwo doesn't appear in the profiler traces. I've also tried explicitly defining a custom policy with the key mlpwo: 'remat', but still no luck (a minimal sketch of the kind of name-keyed policy I mean is shown below). I checked the XLA dumps as well: a rematerialized mlpwo wasn't present, although others (wi_0, wi_1) were.

XLA flags used:
--xla_gpu_enable_cublaslt=False --xla_gpu_graph_level=0 --xla_gpu_autotune_level=5 --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=137438953472 --xla_gpu_enable_all_gather_combine_by_dim=FALSE

Platform: AMD GPU - MI300

Example screenshot of rematted_computation/mlp recomputing only wi_0 and wi_1 is attached.

What would you suggest as the next step for investigating this issue?
Please let me know if more information is needed.
Looking forward to your answer 👾
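For context on the custom policy mentioned above, this is roughly how a name-keyed remat policy is expressed at the JAX level. It is a minimal sketch, not MaxText's actual config plumbing: the gated-MLP structure, the shapes, and the choice of which names are saved are assumptions made for illustration.

```python
import jax
from jax.ad_checkpoint import checkpoint_name

def mlp(x, wi_0, wi_1, wo):
    # Tag intermediates with names so a remat policy can refer to them,
    # analogous to MaxText-style activation names like mlpwi_0 / mlpwi_1 / mlpwo.
    h0 = checkpoint_name(x @ wi_0, "mlpwi_0")
    h1 = checkpoint_name(x @ wi_1, "mlpwi_1")
    out = checkpoint_name((jax.nn.silu(h0) * h1) @ wo, "mlpwo")
    return out

# Names-based policy: only the listed names are saved for the backward pass;
# everything else is a candidate for recomputation (remat).
policy = jax.checkpoint_policies.save_only_these_names("mlpwi_0", "mlpwi_1")
mlp_remat = jax.checkpoint(mlp, policy=policy)

# Usage: take gradients through the rematerialized MLP (illustrative shapes).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 16))
wi_0 = jax.random.normal(key, (16, 64))
wi_1 = jax.random.normal(key, (16, 64))
wo = jax.random.normal(key, (64, 16))
grads = jax.grad(lambda *a: mlp_remat(*a).sum(), argnums=(1, 2, 3))(x, wi_0, wi_1, wo)
```

Even with a policy like this, XLA is free to drop a recomputation entirely if it can obtain the value more cheaply from tensors it already keeps, which is what the reply below suggests is happening for mlpwo.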
My guess is that wo is derived from other activations (e.g. wo + attention_out = layer_output), so given attention_out and layer_output you can derive wo, and it doesn't need to be recomputed or saved directly.

You can confirm this theory by looking at what memory is used in the "memory view" of the trace. You should see only one activation tensor that is saved in full: the decoder layer inputs (and outputs), of shape [layers, batch, sequence, embed]. The wo and attention_out tensors also have this shape, but attention_out looks like it is being recomputed, and wo is probably being derived as decoder_layer_output - attention_out.
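To make the algebra in this explanation concrete, here is a minimal sketch of the residual wiring being described. The attn_fn/mlp_fn placeholders and the shapes are illustrative only, not the actual MaxText decoder layer.

```python
import jax
import jax.numpy as jnp

def decoder_layer(x, attn_fn, mlp_fn):
    # Simplified residual wiring of a transformer block (norms omitted).
    attention_out = x + attn_fn(x)           # residual stream after attention
    mlp_out = mlp_fn(attention_out)          # the wo ("mlpwo") output
    layer_output = attention_out + mlp_out   # residual stream fed to the next layer
    return layer_output

# Placeholder sublayers, just to run the numbers.
attn_fn = lambda h: jnp.tanh(h)
mlp_fn = lambda h: 2.0 * jnp.tanh(h)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
attention_out = x + attn_fn(x)
layer_output = decoder_layer(x, attn_fn, mlp_fn)

# Because layer_output = attention_out + mlp_out, the MLP output never needs
# to be saved or recomputed on its own; it can be derived by a subtraction:
derived_mlp_out = layer_output - attention_out   # == mlp_fn(attention_out)
```

If that is what XLA is doing here, a rematted mlpwo computation would never show up in the trace even with mlpwo: 'remat' in the policy, since the subtraction is cheaper than recomputing the wo projection.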