From 0ef5803e960568561d7dd2ff662e7eb18dea1b2d Mon Sep 17 00:00:00 2001
From: Selvaraj Anandaraj
Date: Thu, 22 Aug 2024 21:34:30 -0700
Subject: [PATCH 1/5] Added offloading support FP8 attention

Signed-off-by: Selvaraj Anandaraj
---
 transformer_engine/pytorch/attention.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 6a46d6c3c1..619c0b658a 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -5456,16 +5456,27 @@ def forward(
         out_save = out_ret
         fp8_tensors = (None, None, None, None, None, None)
 
+        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+
         from .cpu_offload import CPUOffloadEnabled
 
         if CPUOffloadEnabled:
-            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
+            if ctx.fp8:
+                tensor_list = []
+                for t in fp8_tensors:
+                    tensor_list.append(t)
+                for t in aux_ctx_tensors:
+                    tensor_list.append(t)
+            else:
+                tensor_list = [q, k, v, out_save]
+                for t in aux_ctx_tensors:
+                    tensor_list.append(t)
+
             qkv_layout = "sbhd_sbhd_sbhd"
             for tensor in tensor_list:
                 if tensor is not None:
                     tensor.activation_offloading = True
 
-        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
         qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
         ctx.save_for_backward(
             *qkvo_tensors,

From e59e4baaa365d37362ca7f9c3d2a879b510cef53 Mon Sep 17 00:00:00 2001
From: Selvaraj Anandaraj
Date: Wed, 4 Sep 2024 11:08:16 -0700
Subject: [PATCH 2/5] Update transformer_engine/pytorch/attention.py

Co-authored-by: Kirthi Shankar Sivamani
Signed-off-by: Selvaraj Anandaraj
---
 transformer_engine/pytorch/attention.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 131723a4d9..9da9612c3f 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -5643,15 +5643,9 @@ def forward(
 
         if CPUOffloadEnabled:
             if ctx.fp8:
-                tensor_list = []
-                for t in fp8_tensors:
-                    tensor_list.append(t)
-                for t in aux_ctx_tensors:
-                    tensor_list.append(t)
+                tensor_list = fp8_tensors.extend(aux_ctx_tensors)
             else:
-                tensor_list = [q, k, v, out_save]
-                for t in aux_ctx_tensors:
-                    tensor_list.append(t)
+                tensor_list = [q, k, v, out_save].extend(aux_ctx_tensors)
 
             qkv_layout = "sbhd_sbhd_sbhd"
             for tensor in tensor_list:

From e54c03b3515ba0830d405b321f2b181b4d434b12 Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Thu, 5 Sep 2024 15:16:58 +0000
Subject: [PATCH 3/5] Fix

Signed-off-by: Kirthi Shankar Sivamani
---
 transformer_engine/pytorch/attention.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 9da9612c3f..4aa48a82c4 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -5643,9 +5643,11 @@ def forward(
 
         if CPUOffloadEnabled:
             if ctx.fp8:
-                tensor_list = fp8_tensors.extend(aux_ctx_tensors)
+                tensor_list = fp8_tensors
             else:
-                tensor_list = [q, k, v, out_save].extend(aux_ctx_tensors)
+                tensor_list = [q, k, v, out_save]
+
+            tensor_list.extend(aux_ctx_tensors)
 
             qkv_layout = "sbhd_sbhd_sbhd"
             for tensor in tensor_list:

From 69866e25560f08d939fdc37d581751a4dfe021dd Mon Sep 17 00:00:00 2001
From: Selvaraj Anandaraj
Date: Tue, 29 Oct 2024 14:21:29 -0700
Subject: [PATCH 4/5] Added example for CPU offloading

Signed-off-by: Selvaraj Anandaraj
---
 .../pytorch/cpu_offloading/cpu_offload.py     | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)
 create mode 100644 examples/pytorch/cpu_offloading/cpu_offload.py

diff --git a/examples/pytorch/cpu_offloading/cpu_offload.py b/examples/pytorch/cpu_offloading/cpu_offload.py
new file mode 100644
index 0000000000..3e8e29accb
--- /dev/null
+++ b/examples/pytorch/cpu_offloading/cpu_offload.py
@@ -0,0 +1,32 @@
+import torch
+import transformer_engine as te
+
+from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
+
+#Initialize a CPU offload context to enable activation offloading and set number of layers
+# to be offloaded to 1
+context, sync_func = get_cpu_offload_context(True, 1, True, False)
+
+
+#Define a 2 Linear layer model
+layer = []
+for i in range(2):
+    layer.append(te.pytorch.Linear(1024,1024,bias=False,device="cuda"))
+
+#Create dummy inputs on GPU
+input_state = torch.rand(1024,1024).cuda()
+
+#Wrap the forward prop under the context
+with context:
+    hidden = layer[0](input_state)
+
+#Use synchronize function to sync across layers
+hidden = sync_func(hidden)
+
+with context:
+    output = layer[1](hidden)
+
+output = sync_func(output)
+
+#Trigger backward
+output.sum().backward()

From dc89e7d7405266321b589d7651b90d87ff7290f7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 29 Oct 2024 21:22:33 +0000
Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 examples/pytorch/cpu_offloading/cpu_offload.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/examples/pytorch/cpu_offloading/cpu_offload.py b/examples/pytorch/cpu_offloading/cpu_offload.py
index 3e8e29accb..8a8815dd9c 100644
--- a/examples/pytorch/cpu_offloading/cpu_offload.py
+++ b/examples/pytorch/cpu_offloading/cpu_offload.py
@@ -3,24 +3,24 @@
 
 from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
 
-#Initialize a CPU offload context to enable activation offloading and set number of layers
+# Initialize a CPU offload context to enable activation offloading and set number of layers
 # to be offloaded to 1
 context, sync_func = get_cpu_offload_context(True, 1, True, False)
 
 
-#Define a 2 Linear layer model
+# Define a 2 Linear layer model
 layer = []
 for i in range(2):
-    layer.append(te.pytorch.Linear(1024,1024,bias=False,device="cuda"))
+    layer.append(te.pytorch.Linear(1024, 1024, bias=False, device="cuda"))
 
-#Create dummy inputs on GPU
-input_state = torch.rand(1024,1024).cuda()
+# Create dummy inputs on GPU
+input_state = torch.rand(1024, 1024).cuda()
 
-#Wrap the forward prop under the context
+# Wrap the forward prop under the context
 with context:
     hidden = layer[0](input_state)
 
-#Use synchronize function to sync across layers
+# Use synchronize function to sync across layers
 hidden = sync_func(hidden)
 
 with context:
@@ -28,5 +28,5 @@
 
 output = sync_func(output)
 
-#Trigger backward
+# Trigger backward
 output.sum().backward()
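Note on patches 2 and 3: the churn there hinges on a standard Python pitfall. list.extend() mutates its receiver and returns None, so patch 2's `tensor_list = fp8_tensors.extend(aux_ctx_tensors)` binds tensor_list to None (and would raise AttributeError outright when fp8_tensors is a tuple, as in patch 1's `(None, None, None, None, None, None)` default). Patch 3 restores correctness by choosing the base list first and extending it once afterwards. Below is a minimal standalone sketch of that final pattern; the SimpleNamespace stand-ins, the tensor counts, and the closing assert are illustrative assumptions for the sketch, not Transformer Engine code:

    from types import SimpleNamespace

    # Hypothetical stand-ins for the tensors the diff saves for backward,
    # chosen only so that the attribute assignment below works.
    q, k, v, out_save = (SimpleNamespace() for _ in range(4))
    fp8_tensors = [SimpleNamespace() for _ in range(6)]      # six entries, matching patch 1's default
    aux_ctx_tensors = [SimpleNamespace() for _ in range(2)]  # count is arbitrary here
    fp8 = True  # stands in for ctx.fp8

    # Patch 2's bug, for reference: extend() returns None, so this would bind
    # tensor_list to None and the loop below would crash.
    #   tensor_list = fp8_tensors.extend(aux_ctx_tensors)

    # Patch 3's fix: pick the base list, then extend it in place.
    if fp8:
        tensor_list = list(fp8_tensors)  # the diff assigns fp8_tensors directly; copied here to leave the original untouched
    else:
        tensor_list = [q, k, v, out_save]
    tensor_list.extend(aux_ctx_tensors)

    # Mark every saved tensor so the CPU-offload machinery moves it to host memory.
    for tensor in tensor_list:
        if tensor is not None:
            tensor.activation_offloading = True

    assert len(tensor_list) == 8 and all(t.activation_offloading for t in tensor_list)

Patch 4's example script should run as-is on a CUDA machine with Transformer Engine installed. Per its own comments, get_cpu_offload_context(True, 1, True, False) enables activation offloading with one layer offloaded at a time, and sync_func is applied at each layer boundary to synchronize the offloading across layers; the exact parameter names are defined in transformer_engine/pytorch/cpu_offload.py.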