From 55138e88891215de560e498b3b2cdd40076900a7 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Tue, 3 Mar 2026 10:56:54 -0400 Subject: [PATCH] Use grid-stride loop for `fill!` --- src/host/construction.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/host/construction.jl b/src/host/construction.jl index 454ff3d9..a69be3fe 100644 --- a/src/host/construction.jl +++ b/src/host/construction.jl @@ -14,12 +14,22 @@ function Base.fill!(A::AnyGPUArray{T}, x) where T @kernel function fill_kernel!(a, val) idx = @index(Global, Linear) - @inbounds a[idx] = val + stride = prod(@ndrange()) + while idx <= length(a) + @inbounds a[idx] = val + idx += stride + end end # ndims check for 0D support kernel = fill_kernel!(get_backend(A)) - kernel(A, x; ndrange = length(A)) + + # Calculate ndrange to ensure that a total grid size >typemax(UInt32) is never + # chosen. Grid stride to accommodate grid size limitations on AMD and Metal backends + len = length(A) + ndrange = cld(len, cld(len, typemax(UInt32) - 1024)) + + kernel(A, x; ndrange) A end