Merge branch 'main' into xren/cp_mask_type
xrennvidia authored Dec 5, 2024
2 parents 26c97b9 + 8c00424 commit 33e5be0
Showing 5 changed files with 80 additions and 29 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/deploy_nightly_docs.yml
@@ -16,13 +16,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Download artifact
uses: actions/download-artifact@v4.1.7
uses: actions/download-artifact@v4
with:
name: "te_docs"
path: "html"
- name: Prepare for pages
uses: actions/[email protected]
with:
name: github-pages
path: "html"
deploy:
needs: prepare
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -27,7 +27,7 @@ jobs:
cd docs
make html
- name: 'Upload docs'
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: te_docs
path: docs/_build/html
11 changes: 9 additions & 2 deletions tests/jax/test_distributed_fused_attn.py
@@ -341,8 +341,9 @@ def ref_func(query, kv, mask):
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@@ -423,6 +424,12 @@ def impl_test_contex_parallel_attn(
qkv_format = get_qkv_format(qkv_layout)

batch, seqlen, num_head, hidden = data_shape

# Scale the sequence length by 2*CP so it's never too small as we scale up the test.
# 2*CP is used since we split into two CP groups for load balancing.
seqlen = seqlen * cp_size * 2
data_shape = batch, seqlen, num_head, hidden

num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)

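As an aside on the test change above: the comment explains that the sequence length is multiplied by 2 * cp_size because the load-balanced causal split used by context parallelism cuts every sequence into 2*CP chunks and gives each rank one chunk from the front plus the mirrored chunk from the back. A minimal stand-alone sketch of that split follows; the helper name split_for_cp and the NumPy layout are illustrative assumptions, not part of this commit.

import numpy as np

def split_for_cp(tokens: np.ndarray, cp_size: int) -> list:
    """Per-rank slices for a load-balanced causal context-parallel split."""
    assert len(tokens) % (2 * cp_size) == 0, "seqlen must be a multiple of 2 * cp_size"
    chunks = np.array_split(tokens, 2 * cp_size)
    # Rank r pairs chunk r (front) with chunk 2*cp_size-1-r (back),
    # so every rank does roughly the same amount of causal-attention work.
    return [np.concatenate([chunks[r], chunks[2 * cp_size - 1 - r]]) for r in range(cp_size)]

tokens = np.arange(16)
for rank, part in enumerate(split_for_cp(tokens, cp_size=2)):
    print(rank, part)
# rank 0 -> [ 0  1  2  3 12 13 14 15], rank 1 -> [ 4  5  6  7  8  9 10 11]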
91 changes: 67 additions & 24 deletions transformer_engine/pytorch/module/base.py
@@ -588,20 +588,50 @@ def reset(key):

def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing."""
state = None

# This implementation is working around a few issues:
#
# (1) PyTorch's "extra state" infrastructure might be able to
# support any picklable type, but they make no guarantees.
# We have experienced problems (e.g. in ONNX export) with
# non-tensor extra state.
# (2) PyTorch's checkpointing infrastructure does not remap
# devices for "extra state" like it does for "state dict".
# Thus, we want to avoid putting extra state on the GPU
# since it may be loaded on the wrong device.
# (3) The extra state consists of many small tensors. If we
# want to copy them all to CPU, then we need to avoid the
# overhead of many GPU-CPU memory transfers.
#
# See: https://github.com/NVIDIA/TransformerEngine/pull/351
# See: https://github.com/NVIDIA/TransformerEngine/pull/363

def to_cpu(src: torch.Tensor) -> torch.Tensor:
"""Helper function to make CPU copy of tensor
Memory transfer is asynchronous w.r.t. host, so GPU should
be synchronized before using result.
"""
dst = torch.empty_like(src, device="cpu")
dst.copy_(src, non_blocking=True)
return dst

# Store FP8 state if needed
state = None
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration

if fp8_checkpoint:

# Copy tensors to CPU and store
state = {}
state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history

# Store other picklable values.
state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv)
state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv)

# Store other picklable values
extra = {}
for k, v in self.fp8_meta.items():
if k != "buffer_index_and_autocast_key" and isinstance(
@@ -610,22 +640,23 @@ def get_extra_state(self) -> torch.Tensor:
extra[k] = v
state["extra_fp8_variables"] = extra

if is_in_onnx_export_mode():
state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
else:
state_serialized = io.BytesIO()
torch.save(state, state_serialized)

# Serialize state into byte tensor
torch.cuda.synchronize()
state_serialized = bytearray(pickle.dumps(state))
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
return state_serialized
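The rewritten get_extra_state above follows a simple pattern: copy each FP8 tensor to the CPU with non_blocking=True, synchronize the GPU once, then pickle everything into a uint8 tensor. A stand-alone sketch of that save-side pattern, assuming plain PyTorch and hypothetical tensor names (this is not TransformerEngine API), looks like:

import pickle
import torch

def to_cpu(src: torch.Tensor) -> torch.Tensor:
    # Asynchronous device-to-host copy; only valid to read after a synchronize.
    dst = torch.empty_like(src, device="cpu")
    dst.copy_(src, non_blocking=True)
    return dst

device = "cuda" if torch.cuda.is_available() else "cpu"
state = {
    "scale_fwd": to_cpu(torch.ones(4, device=device)),              # hypothetical FP8 scale
    "amax_history_fwd": to_cpu(torch.zeros(16, 4, device=device)),  # hypothetical amax history
}
if torch.cuda.is_available():
    torch.cuda.synchronize()  # ensure the async copies have completed
serialized = torch.frombuffer(bytearray(pickle.dumps(state)), dtype=torch.uint8)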

def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state."""
if state is None:
return

# Load state
if isinstance(state, torch.Tensor):
# Default format: byte tensor with pickled data
state = pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
# Deprecated format with io.BytesIO
state.seek(0)
state = torch.load(state, map_location="cuda")
else:
Expand All @@ -634,20 +665,32 @@ def set_extra_state(self, state: torch.Tensor) -> None:
if state is None:
return

# Load extra items.
# Load extra items
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

# Initialize before loading.
# Initialize before loading
self.init_fp8_meta_tensors()
self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])

def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
"""Helper function to copy tensor from CPU
Memory transfer is asynchronous w.r.t. host, so GPU should
be synchronized before using result.
"""
dst.copy_(src, non_blocking=True)

# Load tensors
copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv)
copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv)
torch.cuda.synchronize()

def set_activation_dtype(self, inp: torch.Tensor) -> None:
"""Get activation data type for AMP."""
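For completeness, a matching load-side sketch under the same assumptions (plain PyTorch, hypothetical names): unpickle the byte tensor, copy the values back with non_blocking=True, and synchronize before they are read, mirroring the copy_tensor helper in set_extra_state above.

import pickle
import torch

def load_extra_state(serialized: torch.Tensor, dst: dict) -> None:
    """Unpickle a uint8 extra-state tensor and copy its values into dst tensors."""
    state = pickle.loads(serialized.detach().cpu().numpy().tobytes())
    for key, value in state.items():
        dst[key].copy_(value, non_blocking=True)  # asynchronous host-to-device copy
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # values are safe to use only after this point

device = "cuda" if torch.cuda.is_available() else "cpu"
dst = {
    "scale_fwd": torch.empty(4, device=device),
    "amax_history_fwd": torch.empty(16, 4, device=device),
}
load_extra_state(serialized, dst)  # 'serialized' comes from the save-side sketch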
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/op.py
@@ -514,7 +514,7 @@ def get_extra_state(self) -> torch.Tensor:
#
# (1) PyTorch's "extra state" infrastructure might be able to
# support any picklable type, but they make no guarantees.
# It seems that ONNX export experiences issues with
# We have experienced problems (e.g. in ONNX export) with
# non-tensor extra state.
# (2) PyTorch's checkpointing infrastructure does not remap
# devices for "extra state" like it does for "state dict".
