@@ -120,26 +120,3 @@ def measure_memory(func):
     # Our implementation should use less memory
     assert triton_mem < pytorch_mem, f"Triton implementation with {n_chunks} chunks uses more memory than PyTorch"
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_reduction_modes():
-    batch_size, seq_len, vocab_size = 4, 16, 100
-    logits = torch.randn(batch_size, seq_len, vocab_size, device='cuda')
-    targets = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
-
-    # Test mean reduction
-    triton_ce_mean = TritonCrossEntropyLoss(pad_token_id=-100, reduction="mean")
-    torch_ce_mean = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
-
-    triton_loss_mean = triton_ce_mean(logits, targets)
-    torch_loss_mean = torch_ce_mean(logits.view(-1, vocab_size), targets.view(-1))
-
-    torch.testing.assert_close(triton_loss_mean, torch_loss_mean, rtol=1e-3, atol=1e-3)
-
-    # Test sum reduction
-    triton_ce_sum = TritonCrossEntropyLoss(pad_token_id=-100, reduction="sum")
-    torch_ce_sum = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
-
-    triton_loss_sum = triton_ce_sum(logits, targets)
-    torch_loss_sum = torch_ce_sum(logits.view(-1, vocab_size), targets.view(-1))
-
-    torch.testing.assert_close(triton_loss_sum, torch_loss_sum, rtol=1e-3, atol=1e-3)