
Commit 76e60f3

dnikolaev-amd authored and jeffdaily committed
[ROCm] fix and unskip tests on rocm (pytorch#169827)
This PR fixes:

- `torch.nonzero` for large tensors on ROCm. It was malfunctioning due to a known hip compiler problem with `::min` for int64_t arguments; fixed by explicitly typing the call as `std::min<int64_t>`.
- batch norm tests on ROCm, by using `torch.ops.aten.miopen_batch_norm` instead of `torch.ops.aten.cudnn_batch_norm`.

Fixed tests:

- Fixes pytorch#168878.
- Fixes pytorch#168879.
- Fixes pytorch#168553.
- Fixes pytorch#168554.

Pull Request resolved: pytorch#169827
Approved by: https://github.com/jeffdaily, https://github.com/mlazos, https://github.com/cyyever
Co-authored-by: Jeff Daily <[email protected]>
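For context, a minimal standalone sketch of the `std::min` fix pattern; the wrapper function and its name below are illustrative assumptions, and only the `std::min<int64_t>` call itself mirrors the change:

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical standalone version of the per-chunk size computation patched
// in Nonzero.cu. Per the commit message, a known hip compiler problem with
// ::min on int64_t arguments broke the deduced form, so the comparison type
// is pinned explicitly instead of being left to template deduction.
int64_t remaining_in_chunk(int64_t chunk_size, int64_t numel, int64_t idx) {
  // Before: std::min(chunk_size, numel - idx * chunk_size);
  // After:  the explicit template argument forces the int64_t overload.
  return std::min<int64_t>(chunk_size, numel - idx * chunk_size);
}
```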
1 parent e09550e · commit 76e60f3

3 files changed: +9 additions, -8 deletions


aten/src/ATen/native/cuda/Nonzero.cu

Lines changed: 3 additions & 3 deletions
```diff
@@ -183,7 +183,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
   auto& allocator = *c10::cuda::CUDACachingAllocator::get();
   auto num_nonzeros = allocator.allocate(sizeof(int) * num_chunks);
   for (int64_t idx = 0; idx < num_chunks; idx++) {
-    int64_t remaining = std::min(chunk_size, self.numel() - idx * chunk_size);
+    int64_t remaining = std::min<int64_t>(chunk_size, self.numel() - idx * chunk_size);
     ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp<scalar_t>, const scalar_t*) itr(
         self_.const_data_ptr<scalar_t>() + idx * chunk_size,
         NonZeroOp<scalar_t>());
@@ -241,7 +241,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
   int64_t curr_nonzeros = 0;
   if (self.dim() > 0) {
     for (int64_t idx = 0; idx < num_chunks; idx++) {
-      int remaining = std::min(chunk_size, self.numel() - idx * chunk_size);
+      int remaining = std::min<int64_t>(chunk_size, self.numel() - idx * chunk_size);
 
       ATEN_CUB_COUNTING_ITERATOR(int64_t) counting_itr(idx * chunk_size);
       ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp<scalar_t>, const scalar_t*)
@@ -353,7 +353,7 @@ void nonzero_static_cuda_out_impl(
       <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
           in_data_ptr, out_data_ptr, (int64_t*)agg_cum.get(), self.numel(), size, iters_per_cta);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
-  int64_t out_grid = std::min(num_sms, (size + BLOCK_THREADS - 1)/BLOCK_THREADS);
+  int64_t out_grid = std::min<int64_t>(num_sms, (size + BLOCK_THREADS - 1)/BLOCK_THREADS);
   write_fill_value<<<out_grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(out_data_ptr, (int64_t *)agg_cum.get() + grid_size - 1, fill_value, size);
   if (self.dim() > 1) {
     TensorDims<int64_t> dims;
```
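The launch-geometry computation in the last hunk can be sanity-checked in isolation. A minimal sketch, assuming illustrative constants (the block width, SM count, and output size below are not values from the PR):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Ceil-divide the number of output slots by the block width, then cap the
// grid at the SM count so write_fill_value never launches more blocks than
// the device can usefully schedule. All constants here are assumptions.
int main() {
  const int64_t BLOCK_THREADS = 512; // assumed threads per block
  const int64_t num_sms = 108;       // assumed SM count
  const int64_t size = 1000000;      // assumed number of output slots
  const int64_t out_grid =
      std::min<int64_t>(num_sms, (size + BLOCK_THREADS - 1) / BLOCK_THREADS);
  std::printf("out_grid = %lld\n", static_cast<long long>(out_grid)); // 108
  return 0;
}
```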

test/functorch/test_aotdispatch.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -87,7 +87,6 @@
     outs_and_grads,
     parametrize,
     run_tests,
-    skipIfRocm,
     TEST_MKL,
     TestCase,
     xfail_inherited_tests,
@@ -3900,7 +3899,6 @@ def f(self_s_emb, add_3):
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     @unittest.skipIf(not torch.backends.cudnn.is_available(), "CUDNN is unavailable")
-    @skipIfRocm  # https://github.com/pytorch/pytorch/issues/96560
     def test_batch_norm_amp(self):
         device = "cuda"
         input_dtype = torch.float16
@@ -3914,7 +3912,12 @@ def test_batch_norm_amp(self):
         )
 
         def bn(x):
-            return torch.ops.aten.cudnn_batch_norm(
+            fn = (
+                torch.ops.aten.cudnn_batch_norm
+                if torch.version.hip is None
+                else torch.ops.aten.miopen_batch_norm
+            )
+            return fn(
                 x,
                 weight,
                 bias,
```

test/test_unary_ufuncs.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -44,7 +44,6 @@
     numpy_to_torch_dtype_dict,
     run_tests,
     skipIfNoSciPy,
-    skipIfRocm,
     slowTest,
     suppress_warnings,
     TEST_SCIPY,
@@ -1613,7 +1612,6 @@ def assert_tuple_empty(tup, dim):
     @onlyCUDA
     @dtypes(torch.int8)
     @largeTensorTest("8GB")
-    @skipIfRocm(msg="ROCM tries to allocate 60GB")
     def test_nonzero_large(self, device, dtype):
         indices = (
             torch.tensor((0, 2, 3, 4, 6, 100, 103, 2**30, 2**31 - 3, 2**31 - 2)),
```
