Commit acc7b79
Blackwell DistGEMM bug fixes

1. If using a preferred cluster, there needs to be a branch so that the universal GEMM wrapper finds the correct base params.
2. Workspace sizes can change with problem shape on Blackwell, and DistGEMM was previously using the per-device shape to evaluate workspace size instead of the per-GEMM shape.
3. The flattened size used to initialize host tensors can overflow (in the Hopper example as well).
4. Preferred and fallback cluster args need to be set explicitly; otherwise, if someone modifies the example to use a preferred cluster, it will simply fail.
1 parent: b2ca083

4 files changed: +74 additions, -13 deletions

examples/65_distributed_gemm/65_distributed_gemm.cu (4 additions, 4 deletions)

@@ -403,9 +403,9 @@ void initialize(const Options &options) {
   stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
   stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
 
-  auto a_coord = cutlass::make_Coord(size(shape_A), 1);
-  auto b_coord = cutlass::make_Coord(size(shape_B), 1);
-  auto c_coord = cutlass::make_Coord(size(shape_C), 1);
+  auto a_coord = cutlass::make_Coord(size<2>(shape_A)*size<0>(shape_A), size<1>(shape_A));
+  auto b_coord = cutlass::make_Coord(size<2>(shape_B)*size<0>(shape_B), size<1>(shape_B));
+  auto c_coord = cutlass::make_Coord(size<2>(shape_C)*size<0>(shape_C), size<1>(shape_C));
 
   tensor_A.resize(a_coord);
   tensor_B.resize(b_coord);
@@ -650,7 +650,7 @@ int run(Options &options) {
     arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
 
     // Using the arguments, query for extra workspace required for matrix multiplication computation
-    size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
+    size_t workspace_size = DistGemm::get_workspace_size(arguments_, device_idx);
     size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
 
     workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
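The coord change above is the overflow fix from point 3 of the commit message: cutlass::make_Coord builds a coordinate of ints, so the old single flattened extent size(shape_A) = M * K * L wraps past INT32_MAX at problem sizes these examples support, while the new rank-2 coord keeps every individual extent small. A self-contained sketch of the arithmetic (the extents below are illustrative, not taken from the example defaults):

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative per-GEMM extents: M = K = 16384, batch L = 16.
  std::int64_t M = 16384, K = 16384, L = 16;

  // Old scheme: a single flattened extent M * K * L, stored as an int.
  std::int64_t flat = M * K * L;  // 2^32 elements, past INT32_MAX
  std::printf("flattened extent = %lld, fits in int32: %s\n",
              static_cast<long long>(flat), flat <= INT32_MAX ? "yes" : "no");

  // New scheme: a rank-2 coord (L*M, K); each extent stays well within
  // int range, and total capacity is computed in a 64-bit index type.
  std::int64_t rows = L * M, cols = K;
  std::printf("2D extents = (%lld, %lld), each fits in int32: %s\n",
              static_cast<long long>(rows), static_cast<long long>(cols),
              (rows <= INT32_MAX && cols <= INT32_MAX) ? "yes" : "no");
  return 0;
}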

examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu (10 additions, 4 deletions)

@@ -405,9 +405,9 @@ void initialize(const Options &options) {
   stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
   stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
 
-  auto a_coord = cutlass::make_Coord(size(shape_A), 1);
-  auto b_coord = cutlass::make_Coord(size(shape_B), 1);
-  auto c_coord = cutlass::make_Coord(size(shape_C), 1);
+  auto a_coord = cutlass::make_Coord(size<2>(shape_A)*size<0>(shape_A), size<1>(shape_A));
+  auto b_coord = cutlass::make_Coord(size<2>(shape_B)*size<0>(shape_B), size<1>(shape_B));
+  auto c_coord = cutlass::make_Coord(size<2>(shape_C)*size<0>(shape_C), size<1>(shape_C));
 
   tensor_A.resize(a_coord);
   tensor_B.resize(b_coord);
@@ -475,6 +475,9 @@ GemmArguments gemm_args_from_options(const Options &options) {
       tensor_ref_D.device_data(), stride_D
     }
   };
+  // Preferred cluster can fail if these aren't set explicitly
+  arguments.hw_info.cluster_shape = dim3(1,1,1);
+  arguments.hw_info.cluster_shape_fallback = dim3(1,1,1);
 
   return arguments;
 }
@@ -548,6 +551,9 @@ DistGemmArguments dist_gemm_args_from_options(
     {}, // hw_info
     {}  // scheduler
   };
+  // Preferred cluster can fail if these aren't set explicitly
+  arguments.hw_info.cluster_shape = dim3(1,1,1);
+  arguments.hw_info.cluster_shape_fallback = dim3(1,1,1);
 
   return arguments;
 }
@@ -652,7 +658,7 @@ int run(Options &options) {
     arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
 
     // Using the arguments, query for extra workspace required for matrix multiplication computation
-    size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
+    size_t workspace_size = DistGemm::get_workspace_size(arguments_, device_idx);
     size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
 
     workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
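The two dim3 assignments deliberately pin a trivial 1x1x1 cluster; the point is that kernels with a dynamic cluster shape read both fields, so leaving them default-initialized is what made the preferred-cluster path fail. If someone does modify the example to use a preferred cluster, the change would look roughly like this fragment (the 2x1x1 preferred / 1x1x1 fallback pair is a hypothetical choice for illustration, not part of this commit):

// Hypothetical edit inside gemm_args_from_options() or
// dist_gemm_args_from_options(): request a 2x1x1 preferred cluster and
// let the runtime fall back to 1x1x1 when it cannot be scheduled.
arguments.hw_info.cluster_shape          = dim3(2, 1, 1);
arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1);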

include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp (46 additions, 3 deletions)

@@ -253,16 +253,59 @@ class DistributedGemmUniversalAdapter {
     return DistSchedule::get_tensor_D(tensor_D, tensor_buffer, device_idx, iteration);
   }
 
+  static
+  auto make_dummy_base_args(Arguments const* args, int device_idx, int iteration, void** buffer_space) {
+
+    // Set up GEMM arguments for the current stage/iteration
+    auto tensor_a_iter = get_tensor_A_for_iter(args, buffer_space, device_idx, iteration);
+    auto tensor_b_iter = get_tensor_B_for_iter(args, buffer_space, device_idx, iteration);
+    auto tensor_c_iter = get_tensor_C_for_iter(args, buffer_space, device_idx, iteration);
+    auto tensor_d_iter = get_tensor_D_for_iter(args, buffer_space, device_idx, iteration);
+
+    Arguments base_args = args[device_idx];
+    base_args.problem_shape = DistSchedule::get_local_gemm_shape(args[device_idx].problem_shape);
+    base_args.mainloop = {
+      reinterpret_cast<const ElementA*>(tensor_a_iter.data()),
+      tensor_a_iter.stride(),
+      reinterpret_cast<const ElementB*>(tensor_b_iter.data()),
+      tensor_b_iter.stride()
+    };
+    base_args.epilogue = {
+      base_args.epilogue.thread,
+      reinterpret_cast<const ElementC*>(tensor_c_iter.data()),
+      tensor_c_iter.stride(),
+      reinterpret_cast<ElementD*>(tensor_d_iter.data()),
+      tensor_d_iter.stride()
+    };
+
+    if constexpr (DistSchedule::RemoteC) {
+      if (iteration > 0) {
+        base_args.epilogue.thread.beta = 1.0;
+      }
+      else if (iteration == 0) {
+        base_args.epilogue.thread.beta = 0.0;
+      }
+    }
+
+    return base_args;
+  }
+
   static size_t
-  get_workspace_size(Arguments const& args) {
+  get_workspace_size(Arguments const* args, int device_idx) {
     size_t workspace_bytes = 0;
 
-    workspace_bytes = get_buffer_space_size(args);
+    workspace_bytes = get_buffer_space_size(args[device_idx]);
+
+    void* dummy_buffer_space[TP_];
 
     for (int iteration = 0; iteration < TP_; ++iteration) {
+      // Workspace sizes can vary if arguments change, therefore we must
+      // construct args for each iteration exactly as it will be run.
+      auto args_base = make_dummy_base_args(args, device_idx, iteration, dummy_buffer_space);
+
       // NOTE: assumes underlying kernels align up to alignment requirements on their own,
       // and that the alignment requirements of the individual kernels match.
-      workspace_bytes += GemmKernel::get_workspace_size(args);
+      workspace_bytes += GemmKernel::get_workspace_size(args_base);
     }
 
     return workspace_bytes;
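This signature change is why both examples now call DistGemm::get_workspace_size(arguments_, device_idx): per-iteration GEMM arguments are derived from the whole per-device argument array, not from a single element, so the adapter needs all of it plus the index. A minimal sketch of the updated call-site pattern, assuming the TP constant and the array names used in the examples:

// Per-device arguments and workspaces, named as in the examples.
DistGemmArguments arguments_[TP];
cutlass::device_memory::allocation<uint8_t> workspace_arr[TP];

for (int device_idx = 0; device_idx < TP; ++device_idx) {
  // Pass the whole array: the size of each iteration's workspace is
  // evaluated against that iteration's reconstructed arguments.
  size_t workspace_size = DistGemm::get_workspace_size(arguments_, device_idx);
  workspace_arr[device_idx] =
      cutlass::device_memory::allocation<uint8_t>(workspace_size);
}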

include/cutlass/gemm/device/gemm_universal_adapter.h (14 additions, 2 deletions)

@@ -110,6 +110,13 @@ constexpr int stages_member(DispatchPolicy) {
   }
 }
 
+template <class GemmKernel, class = void>
+struct IsDistGemmKernel : cute::false_type { };
+
+template <typename GemmKernel>
+struct IsDistGemmKernel<GemmKernel, cute::void_t<typename GemmKernel::TP>>
+  : cute::true_type { };
+
 } // namespace detail
 
 template <class GemmKernel_>
@@ -396,8 +403,13 @@ class GemmUniversalAdapter<
       || GemmKernel::ArchTag::kMinComputeCapability == 103
     ) {
       if constexpr (!cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape>) {
-        fallback_cluster = params.hw_info.cluster_shape_fallback;
-        cluster = params.hw_info.cluster_shape;
+        if constexpr (detail::IsDistGemmKernel<GemmKernel>::value) {
+          fallback_cluster = params.base.hw_info.cluster_shape_fallback;
+          cluster = params.base.hw_info.cluster_shape;
+        } else {
+          fallback_cluster = params.hw_info.cluster_shape_fallback;
+          cluster = params.hw_info.cluster_shape;
+        }
       }
     }
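IsDistGemmKernel is the standard void_t member-type detection idiom: the partial specialization is selected only when typename GemmKernel::TP is well-formed, which distinguishes DistGEMM kernels (whose params nest the base GEMM params under params.base) from plain ones. A self-contained illustration of the same idiom using std::void_t; the PlainKernel and DistKernel types are made up for the demo:

#include <type_traits>

// Detection idiom: the specialization participates only when T::TP
// names a valid type, mirroring detail::IsDistGemmKernel.
template <class T, class = void>
struct has_TP : std::false_type { };

template <class T>
struct has_TP<T, std::void_t<typename T::TP>> : std::true_type { };

struct PlainKernel { };
struct DistKernel { using TP = std::integral_constant<int, 8>; };

static_assert(!has_TP<PlainKernel>::value, "no TP member type");
static_assert( has_TP<DistKernel>::value,  "TP member type detected");

int main() { return 0; }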
