Skip to content

Commit 21ccb3d

Browse files
committed
update test
1 parent 73d77dc commit 21ccb3d

File tree

1 file changed

+0
-23
lines changed

1 file changed

+0
-23
lines changed

tests/test_cross_entropy.py

Lines changed: 0 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -120,26 +120,3 @@ def measure_memory(func):
120120
# Our implementation should use less memory
121121
assert triton_mem < pytorch_mem, f"Triton implementation with {n_chunks} chunks uses more memory than PyTorch"
122122

123-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
124-
def test_reduction_modes():
125-
batch_size, seq_len, vocab_size = 4, 16, 100
126-
logits = torch.randn(batch_size, seq_len, vocab_size, device='cuda')
127-
targets = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
128-
129-
# Test mean reduction
130-
triton_ce_mean = TritonCrossEntropyLoss(pad_token_id=-100, reduction="mean")
131-
torch_ce_mean = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
132-
133-
triton_loss_mean = triton_ce_mean(logits, targets)
134-
torch_loss_mean = torch_ce_mean(logits.view(-1, vocab_size), targets.view(-1))
135-
136-
torch.testing.assert_close(triton_loss_mean, torch_loss_mean, rtol=1e-3, atol=1e-3)
137-
138-
# Test sum reduction
139-
triton_ce_sum = TritonCrossEntropyLoss(pad_token_id=-100, reduction="sum")
140-
torch_ce_sum = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
141-
142-
triton_loss_sum = triton_ce_sum(logits, targets)
143-
torch_loss_sum = torch_ce_sum(logits.view(-1, vocab_size), targets.view(-1))
144-
145-
torch.testing.assert_close(triton_loss_sum, torch_loss_sum, rtol=1e-3, atol=1e-3)

0 commit comments

Comments (0)