Use reshape instead of view and update PyTorch autocast API
Signed-off-by: eljandoubi <[email protected]>
eljandoubi committed Oct 18, 2024
1 parent d0b2c57 commit 5cb5e2e
Showing 8 changed files with 245 additions and 246 deletions.
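
The commit makes two kinds of mechanical updates across the PyTorch tests: swapping `Tensor.view` for `Tensor.reshape` where the input's contiguity is not guaranteed, and migrating from the deprecated `torch.cuda.amp` entry points to the device-agnostic `torch.amp` API. A minimal sketch of the autocast/GradScaler rename, grounded in the hunks below with everything else omitted:

import torch

# Deprecated spellings (what the tests used before this commit):
#   scaler = torch.cuda.amp.GradScaler(enabled=True)
#   with torch.cuda.amp.autocast(enabled=True):
#       ...

# Current spellings (what the tests use after this commit):
scaler = torch.amp.GradScaler("cuda", enabled=True)
with torch.amp.autocast(device_type="cuda", enabled=True):
    pass  # forward pass goes here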
10 changes: 0 additions & 10 deletions qa/L0_pytorch_distributed_unittest/test.sh

This file was deleted.

3 changes: 2 additions & 1 deletion tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -222,7 +222,8 @@ def run_dpa_with_cp(
seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [
x.reshape(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
x.reshape(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
for x in [q_, k_, v_, dout_]
]
elif qkv_format == "thd":
seq_idx_q = tex.thd_get_partitioned_indices(
4 changes: 3 additions & 1 deletion tests/pytorch/fused_attn/test_fused_attn.py
@@ -1845,7 +1845,9 @@ def _run_ref_mha_f16(dtype, config, backend):
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
out_grad = (
torch.load("out_grad.pt").to(device="cuda").reshape(config.batch_size, config.max_seqlen_q, -1)
torch.load("out_grad.pt")
.to(device="cuda")
.reshape(config.batch_size, config.max_seqlen_q, -1)
)

_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
24 changes: 12 additions & 12 deletions tests/pytorch/test_fused_optimizer.py
@@ -323,8 +323,8 @@ def setup_method(self, *, seed: int = 0) -> None:
def test_grad_scaler(self):
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
scaler = torch.amp.GradScaler("cuda", enabled=True)
scaler_ = torch.amp.GradScaler("cuda", enabled=True)

for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
@@ -333,7 +333,7 @@ def test_grad_scaler(self):
gt_ = gt.clone()

# Reference
with torch.cuda.amp.autocast(enabled=True):
with torch.amp.autocast(device_type="cuda", enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()

@@ -342,7 +342,7 @@ def test_grad_scaler(self):
scaler.update()

# DUT
with torch.cuda.amp.autocast(enabled=True):
with torch.amp.autocast(device_type="cuda", enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()

@@ -374,8 +374,8 @@ def test_grad_scaler(self):
def test_grad_scaler_capturable(self):
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=True)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
scaler = torch.amp.GradScaler("cuda", enabled=True)
scaler_ = torch.amp.GradScaler("cuda", enabled=True)

for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
@@ -384,7 +384,7 @@ def test_grad_scaler_capturable(self):
gt_ = gt.clone()

# Reference
with torch.cuda.amp.autocast(enabled=True):
with torch.amp.autocast(device_type="cuda", enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()

@@ -393,7 +393,7 @@ def test_grad_scaler_capturable(self):
scaler.update()

# DUT
with torch.cuda.amp.autocast(enabled=True):
with torch.amp.autocast(device_type="cuda", enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()

@@ -432,8 +432,8 @@ def test_grad_scaler_capturable_master(self):
optimizer_ = te.optimizers.FusedAdam(
params_, lr=self.lr, capturable=True, master_weights=master_weights
)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
scaler = torch.amp.GradScaler("cuda", enabled=True)
scaler_ = torch.amp.GradScaler("cuda", enabled=True)

for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
@@ -442,7 +442,7 @@ def test_grad_scaler_capturable_master(self):
gt_ = gt.clone()

# Reference
with torch.cuda.amp.autocast(enabled=True):
with torch.amp.autocast(device_type="cuda", enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()

@@ -451,7 +451,7 @@ def test_grad_scaler_capturable_master(self):
scaler.update()

# DUT
with torch.cuda.amp.autocast(enabled=True):
with torch.amp.autocast(device_type="cuda", enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()

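
For reference, the scaled training step these optimizer tests exercise follows the standard GradScaler pattern under the updated autocast API. The sketch below uses a placeholder model and target, and assumes the `transformer_engine.pytorch` import path for `FusedAdam` as in the tests:

import torch
import transformer_engine.pytorch as te

model = torch.nn.Conv2d(1, 8, 3).cuda()                      # placeholder model
optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3, capturable=False)
scaler = torch.amp.GradScaler("cuda", enabled=True)

x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
gt = torch.rand([32, 8, 26, 26]).cuda()                      # placeholder target

with torch.amp.autocast(device_type="cuda", enabled=True):
    y = model(x)
    loss = ((gt - y) ** 2).mean()

scaler.scale(loss).backward()   # scale the loss, backprop scaled gradients
scaler.step(optimizer)          # unscale gradients, skip the step on inf/nan
scaler.update()                 # adjust the scale factor for the next step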
4 changes: 2 additions & 2 deletions tests/pytorch/test_numerics.py
@@ -214,7 +214,7 @@ def forward(
)

# change view to [b, np, sq, sk]
attention_scores = matmul_result.reshape(*output_size)
attention_scores = matmul_result.view(*output_size)

# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
@@ -233,7 +233,7 @@ def forward(
value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

# change view [b * np, sq, sk]
attention_probs = attention_probs.reshape(output_size[0] * output_size[1], output_size[2], -1)
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
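
The `view`/`reshape` swaps come down to contiguity: `Tensor.view` only reinterprets the existing storage and raises an error when the requested shape is incompatible with the tensor's strides, while `Tensor.reshape` returns a view when it can and silently falls back to a copy otherwise. A small illustration (the shapes are arbitrary, not taken from the tests):

import torch

x = torch.arange(24).reshape(2, 3, 4)
t = x.transpose(0, 1)        # shape (3, 2, 4), non-contiguous strides

# t.view(3, 8) raises RuntimeError: the new shape is not compatible with
# the tensor's size and stride, so no zero-copy view exists.
y = t.reshape(3, 8)          # reshape detects this and copies instead
assert y.is_contiguous()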
