Skip to content

Commit d1ef0e8

Browse files
authored
DistGEMM bug fixes (#2713)
* Blackwell DistGEMM bug fixes:
  1. If using preferred cluster, there needs to be a branch so that the universal GEMM wrapper finds the correct base params.
  2. Workspace sizes can change depending on problem shape in Blackwell, and DistGEMM was previously using the per-device shape to evaluate workspace size instead of the per-GEMM shape.
  3. Flattened size used to initialize host tensors can overflow (in the Hopper example as well).
  4. Preferred and fallback cluster args need to be set explicitly; otherwise, if someone modifies the example to use the preferred cluster, it will just fail.
* Fix example runtimes
* Set default fallback cluster shapes to the static ones
1 parent 020c700 commit d1ef0e8

File tree

4 files changed

+84
-27
lines changed

4 files changed

+84
-27
lines changed

examples/65_distributed_gemm/65_distributed_gemm.cu

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ using namespace cute;
132132
using TP = _8;
133133
static constexpr int TP_ = TP{};
134134

135-
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
135+
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
136136
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
137137

138138
// Distributed GEMM tiling/sharding schedule
@@ -252,7 +252,7 @@ HostTensorB tensor_B_arr[TP_];
252252
HostTensorD tensor_C_arr[TP_];
253253
HostTensorD tensor_D_arr[TP_];
254254

255-
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
255+
#endif // (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) &&
256256
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
257257

258258
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -345,8 +345,7 @@ struct Result {
345345

346346
};
347347

348-
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
349-
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
348+
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
350349

351350
/////////////////////////////////////////////////////////////////////////////////////////////////
352351
/// GEMM setup and evaluation
@@ -403,9 +402,9 @@ void initialize(const Options &options) {
403402
stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
404403
stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
405404

406-
auto a_coord = cutlass::make_Coord(size(shape_A), 1);
407-
auto b_coord = cutlass::make_Coord(size(shape_B), 1);
408-
auto c_coord = cutlass::make_Coord(size(shape_C), 1);
405+
auto a_coord = cutlass::make_Coord(size<2>(shape_A)*size<0>(shape_A), size<1>(shape_A));
406+
auto b_coord = cutlass::make_Coord(size<2>(shape_B)*size<0>(shape_B), size<1>(shape_B));
407+
auto c_coord = cutlass::make_Coord(size<2>(shape_C)*size<0>(shape_C), size<1>(shape_C));
409408

410409
tensor_A.resize(a_coord);
411410
tensor_B.resize(b_coord);
@@ -650,7 +649,7 @@ int run(Options &options) {
650649
arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
651650

652651
// Using the arguments, query for extra workspace required for matrix multiplication computation
653-
size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
652+
size_t workspace_size = DistGemm::get_workspace_size(arguments_, device_idx);
654653
size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
655654

656655
workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
@@ -804,8 +803,7 @@ int run(Options &options) {
804803
return 0;
805804
}
806805

807-
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
808-
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
806+
#endif //(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
809807

810808
///////////////////////////////////////////////////////////////////////////////////////////////////
811809

@@ -859,7 +857,7 @@ int main(int argc, char const **args) {
859857
// Evaluate CUTLASS kernels
860858
//
861859

862-
#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
860+
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
863861
run(options);
864862
#else
865863
std::cerr

examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ using namespace cute;
132132
using TP = _8;
133133
static constexpr int TP_ = TP{};
134134

135-
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
135+
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
136136
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
137137

138138
// Distributed GEMM tiling/sharding schedule
@@ -254,7 +254,7 @@ HostTensorB tensor_B_arr[TP_];
254254
HostTensorD tensor_C_arr[TP_];
255255
HostTensorD tensor_D_arr[TP_];
256256

257-
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
257+
#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) &&
258258
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
259259

260260
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -347,8 +347,7 @@ struct Result {
347347

348348
};
349349

350-
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
351-
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
350+
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
352351

353352
/////////////////////////////////////////////////////////////////////////////////////////////////
354353
/// GEMM setup and evaluation
@@ -405,9 +404,9 @@ void initialize(const Options &options) {
405404
stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
406405
stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
407406

408-
auto a_coord = cutlass::make_Coord(size(shape_A), 1);
409-
auto b_coord = cutlass::make_Coord(size(shape_B), 1);
410-
auto c_coord = cutlass::make_Coord(size(shape_C), 1);
407+
auto a_coord = cutlass::make_Coord(size<2>(shape_A)*size<0>(shape_A), size<1>(shape_A));
408+
auto b_coord = cutlass::make_Coord(size<2>(shape_B)*size<0>(shape_B), size<1>(shape_B));
409+
auto c_coord = cutlass::make_Coord(size<2>(shape_C)*size<0>(shape_C), size<1>(shape_C));
411410

412411
tensor_A.resize(a_coord);
413412
tensor_B.resize(b_coord);
@@ -475,6 +474,9 @@ GemmArguments gemm_args_from_options(const Options &options) {
475474
tensor_ref_D.device_data(), stride_D
476475
}
477476
};
477+
// Preferred cluster can fail if these aren't set explicitly
478+
arguments.hw_info.cluster_shape = dim3(2,1,1);
479+
arguments.hw_info.cluster_shape_fallback = dim3(2,1,1);
478480

479481
return arguments;
480482
}
@@ -548,6 +550,9 @@ DistGemmArguments dist_gemm_args_from_options(
548550
{}, // hw_info
549551
{} // scheduler
550552
};
553+
// Preferred cluster can fail if these aren't set explicitly
554+
arguments.hw_info.cluster_shape = dim3(2,1,1);
555+
arguments.hw_info.cluster_shape_fallback = dim3(2,1,1);
551556

552557
return arguments;
553558
}
@@ -652,7 +657,7 @@ int run(Options &options) {
652657
arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
653658

654659
// Using the arguments, query for extra workspace required for matrix multiplication computation
655-
size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
660+
size_t workspace_size = DistGemm::get_workspace_size(arguments_, device_idx);
656661
size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
657662

658663
workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
@@ -806,8 +811,7 @@ int run(Options &options) {
806811
return 0;
807812
}
808813

809-
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
810-
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
814+
#endif // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
811815

812816
///////////////////////////////////////////////////////////////////////////////////////////////////
813817

@@ -861,7 +865,7 @@ int main(int argc, char const **args) {
861865
// Evaluate CUTLASS kernels
862866
//
863867

864-
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
868+
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
865869
run(options);
866870
#else
867871
std::cerr

include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,16 +253,59 @@ class DistributedGemmUniversalAdapter {
253253
return DistSchedule::get_tensor_D(tensor_D, tensor_buffer, device_idx, iteration);
254254
}
255255

256+
static
257+
auto make_dummy_base_args(Arguments const* args, int device_idx, int iteration, void ** buffer_space) {
258+
259+
// Set up GEMM arguments for the current stage/iteration
260+
auto tensor_a_iter = get_tensor_A_for_iter(args, buffer_space, device_idx, iteration);
261+
auto tensor_b_iter = get_tensor_B_for_iter(args, buffer_space, device_idx, iteration);
262+
auto tensor_c_iter = get_tensor_C_for_iter(args, buffer_space, device_idx, iteration);
263+
auto tensor_d_iter = get_tensor_D_for_iter(args, buffer_space, device_idx, iteration);
264+
265+
Arguments base_args = args[device_idx];
266+
base_args.problem_shape = DistSchedule::get_local_gemm_shape(args[device_idx].problem_shape);
267+
base_args.mainloop = {
268+
reinterpret_cast<const ElementA*>(tensor_a_iter.data()),
269+
tensor_a_iter.stride(),
270+
reinterpret_cast<const ElementB*>(tensor_b_iter.data()),
271+
tensor_b_iter.stride()
272+
};
273+
base_args.epilogue = {
274+
base_args.epilogue.thread,
275+
reinterpret_cast<const ElementC*>(tensor_c_iter.data()),
276+
tensor_c_iter.stride(),
277+
reinterpret_cast<ElementD*>(tensor_d_iter.data()),
278+
tensor_d_iter.stride()
279+
};
280+
281+
if constexpr (DistSchedule::RemoteC) {
282+
if (iteration > 0) {
283+
base_args.epilogue.thread.beta = 1.0;
284+
}
285+
else if (iteration == 0){
286+
base_args.epilogue.thread.beta = 0.0;
287+
}
288+
}
289+
290+
return base_args;
291+
}
292+
256293
static size_t
257-
get_workspace_size(Arguments const& args) {
294+
get_workspace_size(Arguments const* args, int device_idx) {
258295
size_t workspace_bytes = 0;
259296

260-
workspace_bytes = get_buffer_space_size(args);
297+
workspace_bytes = get_buffer_space_size(args[device_idx]);
298+
299+
void* dummy_buffer_space[TP_];
261300

262301
for (int iteration = 0; iteration < TP_; ++iteration) {
302+
// Workspace sizes can vary if arguments change, therefore we must
303+
// construct args for each iteration exactly as it will be run.
304+
auto args_base = make_dummy_base_args(args, device_idx, iteration, dummy_buffer_space);
305+
263306
// NOTE: assumes underlying kernels align up to alignment requirements on their own,
264307
// and that the alignment requirements of the individual kernels match.
265-
workspace_bytes += GemmKernel::get_workspace_size(args);
308+
workspace_bytes += GemmKernel::get_workspace_size(args_base);
266309
}
267310

268311
return workspace_bytes;

include/cutlass/gemm/device/gemm_universal_adapter.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,13 @@ constexpr int stages_member(DispatchPolicy) {
110110
}
111111
}
112112

113+
template <class GemmKernel, class = void>
114+
struct IsDistGemmKernel : cute::false_type { };
115+
116+
template <typename GemmKernel>
117+
struct IsDistGemmKernel<GemmKernel, cute::void_t<typename GemmKernel::TP>>
118+
: cute::true_type { };
119+
113120
} // namespace detail
114121

115122
template <class GemmKernel_>
@@ -396,8 +403,13 @@ class GemmUniversalAdapter<
396403
|| GemmKernel::ArchTag::kMinComputeCapability == 103
397404
) {
398405
if constexpr (!cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape>) {
399-
fallback_cluster = params.hw_info.cluster_shape_fallback;
400-
cluster = params.hw_info.cluster_shape;
406+
if constexpr (detail::IsDistGemmKernel<GemmKernel>::value) {
407+
fallback_cluster = params.base.hw_info.cluster_shape_fallback;
408+
cluster = params.base.hw_info.cluster_shape;
409+
} else {
410+
fallback_cluster = params.hw_info.cluster_shape_fallback;
411+
cluster = params.hw_info.cluster_shape;
412+
}
401413
}
402414
}
403415

0 commit comments

Comments
 (0)