From e2db65623cf07ab289702b336fba38a8618c422a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Apr 2025 21:22:12 -0400 Subject: [PATCH] feat: add debug option for always sharding --- src/Types.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/Types.jl b/src/Types.jl index a55bedcbd8..b46c9dd75c 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -32,6 +32,8 @@ function get_padding(x) return ntuple(Returns(0), ndims(x)) end +const DEBUG_ENSURE_ALWAYS_SHARDED = Ref(false) + # Traced Types ## MissingTracedValue -- defined in ReactantCore @@ -95,6 +97,9 @@ mutable struct ConcretePJRTNumber{T,D,S<:Sharding.ShardInfo} <: AbstractConcrete function ConcretePJRTNumber{T,D,S}( data::NTuple{D,XLA.PJRT.AsyncBuffer}, sharding::S ) where {T,D,S} + if DEBUG_ENSURE_ALWAYS_SHARDED[] + @assert Sharding.is_sharded(sharding) + end return new{T,D,S}(data, sharding, false) end end @@ -148,6 +153,9 @@ mutable struct ConcretePJRTArray{T,N,D,S<:Sharding.ShardInfo} <: AbstractConcret function ConcretePJRTArray{T,N,D,S}( data::NTuple{D,XLA.PJRT.AsyncBuffer}, shape::NTuple{N,Int}, sharding::S ) where {T,N,D,S} + if DEBUG_ENSURE_ALWAYS_SHARDED[] + @assert Sharding.is_sharded(sharding) + end return new{T,N,D,S}(data, shape, sharding, false) end end @@ -251,6 +259,9 @@ mutable struct ConcreteIFRTNumber{T,S<:Sharding.ShardInfo} <: AbstractConcreteNu donated::Bool function ConcreteIFRTNumber{T,S}(data::XLA.IFRT.AsyncArray, sharding::S) where {T,S} + if DEBUG_ENSURE_ALWAYS_SHARDED[] + @assert Sharding.is_sharded(sharding) + end return new{T,S}(data, sharding, false) end end @@ -301,6 +312,9 @@ mutable struct ConcreteIFRTArray{ sharding::S, padding::Union{Nothing,NTuple{N,Int}}=nothing, ) where {T,N,S} + if DEBUG_ENSURE_ALWAYS_SHARDED[] + @assert Sharding.is_sharded(sharding) + end return new{T,N,S,typeof(padding)}(data, shape, sharding, false, padding) end end