From 7eed38aa73f140ab5f214753105e0f37c40bdf00 Mon Sep 17 00:00:00 2001
From: Ashish Farmer
Date: Wed, 4 Nov 2020 09:44:15 -0800
Subject: [PATCH] Fix LayerNorm op on ROCm (#36)

* fix warp size in WARP_SHFL* in layernorm
* enable fused_layer_norm tests on ROCm
---
 csrc/layer_norm_cuda_kernel.cu | 24 ++++++++++++------------
 tests/L0/run_test.py           |  1 -
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/csrc/layer_norm_cuda_kernel.cu b/csrc/layer_norm_cuda_kernel.cu
index a6301b2ee..c935fa67f 100644
--- a/csrc/layer_norm_cuda_kernel.cu
+++ b/csrc/layer_norm_cuda_kernel.cu
@@ -88,9 +88,9 @@ void cuWelfordMuSigma2(
       // intra-warp reductions
       for (int l = 0;  l <= 4;  ++l) {
         int srcLaneB = (threadIdx.x+(1<<l))&31;
-        U muB = WARP_SHFL(mu, srcLaneB);
-        U countB = WARP_SHFL(count, srcLaneB);
-        U sigma2B = WARP_SHFL(sigma2, srcLaneB);
+        U muB = WARP_SHFL(mu, srcLaneB, 32);
+        U countB = WARP_SHFL(count, srcLaneB, 32);
+        U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
         cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
       }
       // threadIdx.x == 0 has correct values for each warp
@@ -126,8 +126,8 @@ void cuWelfordMuSigma2(
         sigma2 = ubuf[1]/U(n2);
         // don't care about final value of count, we know count == n2
       } else {
-        mu = WARP_SHFL(mu, 0);
-        sigma2 = WARP_SHFL(sigma2/U(n2), 0);
+        mu = WARP_SHFL(mu, 0, 32);
+        sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
       }
     }
   }
@@ -183,9 +183,9 @@ void cuWelfordMuSigma2(
       // intra-warp reductions
       for (int l = 0;  l <= 4;  ++l) {
         int srcLaneB = (threadIdx.x+(1<<l))&31;
-        float muB = WARP_SHFL(mu, srcLaneB);
-        float countB = WARP_SHFL(count, srcLaneB);
-        float sigma2B = WARP_SHFL(sigma2, srcLaneB);
+        float muB = WARP_SHFL(mu, srcLaneB, 32);
+        float countB = WARP_SHFL(count, srcLaneB, 32);
+        float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
         cuChanOnlineSum<float>(muB,sigma2B,countB,mu,sigma2,count);
       }
       // threadIdx.x == 0 has correct values for each warp
@@ -221,8 +221,8 @@ void cuWelfordMuSigma2(
         sigma2 = ubuf[1]/float(n2);
         // don't care about final value of count, we know count == n2
      } else {
-        mu = WARP_SHFL(mu, 0);
-        sigma2 = WARP_SHFL(sigma2/float(n2), 0);
+        mu = WARP_SHFL(mu, 0, 32);
+        sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
      }
    }
  }
@@ -679,8 +679,8 @@ void cuComputeGradInput(
      }
      // intra-warp reductions
      for (int mask = blockDim.x/2;  mask > 0;  mask /= 2) {
-        sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
-        sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
+        sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
+        sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
      }
      // inter-warp reductions
      if (blockDim.y > 1) {
diff --git a/tests/L0/run_test.py b/tests/L0/run_test.py
index 2678c1902..7299cf6ef 100644
--- a/tests/L0/run_test.py
+++ b/tests/L0/run_test.py
@@ -6,7 +6,6 @@
 test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
 
 ROCM_BLACKLIST = [
-    'run_fused_layer_norm',
     'run_pyprof_nvtx',
     'run_pyprof_data',
 ]