[Mosaic GPU] Pass in TMA descriptors through kernel parameters
As we've established (sigh), we can't pass in TMA descriptors through global memory.
The previous workaround was to use constant memory instead, but that raises a number of
potential concurrency issues. So, instead, we use the freshly added support for grid_constant
parameters in upstream LLVM to pass the descriptors as kernel arguments. This seems to work
fine and should in fact have lower overhead than either of the previous methods.

PiperOrigin-RevId: 648744363
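
For context, the mechanism this relies on is the one CUDA C++ exposes as __grid_constant__ kernel parameters: the descriptor is copied into the kernel's .param space at launch, so no global- or constant-memory staging is needed. Below is a minimal CUDA sketch of that pattern, assuming an sm_90 target; the kernel name and its placeholder body are illustrative, not code from this commit.

#include <cuda.h>  // CUtensorMap is defined by the CUDA driver API header.

// A TMA descriptor passed by value as a kernel parameter. Marking it
// __grid_constant__ keeps it in .param space, which is where the TMA unit
// expects to read it from; the byval/grid_constant lowering added in this
// commit arranges the same thing at the LLVM level.
__global__ void tma_copy_kernel(
    const __grid_constant__ CUtensorMap tensor_map) {
  // Real kernels hand &tensor_map to cp.async.bulk.tensor.* instructions;
  // taking the address is legal precisely because the parameter is a grid
  // constant. The body here is a placeholder.
  const CUtensorMap *desc = &tensor_map;
  (void)desc;
}
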
apaszke authored and jax authors committed Jul 2, 2024
1 parent 9f5483b commit 265a54d
Showing 6 changed files with 98 additions and 53 deletions.
67 changes: 29 additions & 38 deletions jax/experimental/mosaic/gpu/__init__.py
@@ -208,6 +208,7 @@ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
@dataclasses.dataclass()
class LaunchContext:
launch_op: gpu.LaunchOp
gmem_scratch_ptr: ir.Value
profiler: OnDeviceProfiler | None = None
next_scratch_offset: int = 0
host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
@@ -251,36 +252,31 @@ def host_init_wrapped(host_ptr):
llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8)
)
self.host_scratch_init.append(host_init_wrapped)

with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]):
ptr_ty = ir.Type.parse("!llvm.ptr")
const_ptr_ty = ir.Type.parse("!llvm.ptr<4>")
gmem_scratch_ptr = llvm.call_intrinsic(
ptr_ty,
"llvm.nvvm.ptr.constant.to.gen.p0.p4",
[llvm.mlir_addressof(const_ptr_ty, "global_scratch")],
)
return device_init(llvm.getelementptr(
ptr_ty, gmem_scratch_ptr, [], [alloc_base], i8
))
# with ir.InsertionPoint(self.gmem_scratch_ptr.owner):
# There is no way to create an insertion point after an operation...
gep = llvm.GEPOp(
ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8
)
gep.move_after(self.gmem_scratch_ptr.owner)
return device_init(gep.result)

def _get_tma_desc(
self,
ref,
gmem_ref,
gmem_transform: tuple[MemRefTransform, ...],
transformed_slice_shape: tuple[int, ...],
swizzle: int | None,
):
tma_desc_key = (ref, transformed_slice_shape, swizzle, gmem_transform)
tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform)
if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
with ir.InsertionPoint(self.launch_op):
for t in gmem_transform:
ref = t.apply(ref)
ref_ty = ir.MemRefType(ref.type)

i64 = ir.IntegerType.get_signless(64)
ptr_ty = ir.Type.parse("!llvm.ptr")
def init_tma_desc(host_ptr):
ref = gmem_ref
for t in gmem_transform:
ref = t.apply(ref)
ref_ty = ir.MemRefType(ref.type)
# TODO(apaszke): Use utils.memref_ptr to compute base_ptr
_, offset, *sizes_and_strides = memref.extract_strided_metadata(ref)
aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref)
as_i64 = lambda i: arith.index_cast(i64, i)
@@ -466,6 +462,7 @@ def _launch(
token,
grid,
block,
scratch_arr,
smem_buffers: ShapeTree | Union[ShapeTree],
profiler_spec: profiler.ProfilerSpec | None = None,
maybe_prof_buffer: ir.Value | None = None,
@@ -506,8 +503,8 @@ def _launch(
(ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem
)
)
smem_ref_trees = []

smem_ref_trees = []
for smem_live_buffers_collection in smem_disjoint_live_buffers_collections:
smem_ref_tree = _construct_smem_reftree(
dynamic_smem, smem_live_buffers_collection)
@@ -532,7 +529,9 @@
else:
smem_ref_tree: RefTree = smem_ref_trees[0] if smem_ref_trees else []

yield LaunchContext(launch_op, prof), smem_ref_tree
ptr_ty = ir.Type.parse("!llvm.ptr")
scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
yield LaunchContext(launch_op, scratch_ptr, prof), smem_ref_tree
if prof is not None:
prof.finalize(grid=grid, block=block)
gpu.terminator()
@@ -590,32 +589,24 @@ def main(token_ptr, buffers, gmem_scratch_ptr):
in_refs = arg_refs[:len(in_ref_tys)]
out_refs = arg_refs[len(in_ref_tys):]
prof_buffer = out_refs.pop() if prof_spec is not None else None
empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>")
scratch_alloc = llvm.AllocaOp(ptr_ty, c(1, i64), empty_arr_ty)
scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result)
with _launch(
token, grid, block, smem_scratch_shape,
token, grid, block, scratch_arr, smem_scratch_shape,
prof_spec, prof_buffer
) as (launch_ctx, smem_refs):
body(launch_ctx, *in_refs, *out_refs, smem_refs)
gmem_scratch_bytes = launch_ctx.next_scratch_offset
# Allocate and initialize the host buffer right before the launch.
# Note that we couldn't do that before, because we had to run the body
# to learn what the scratch contains.
with ir.InsertionPoint(launch_ctx.launch_op):
host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8)
with ir.InsertionPoint(scratch_arr.owner):
scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty)
scratch_arr.set_type(scratch_arr_ty)
for init_callback in launch_ctx.host_scratch_init:
init_callback(host_scratch_ptr)
global_scratch.global_type = ir.TypeAttr.get(
ir.Type.parse("!llvm.array<" + str(gmem_scratch_bytes) + " x i8>")
)
func.call(
[],
"mosaic_gpu_memcpy_async_h2d",
[
gmem_scratch_ptr,
host_scratch_ptr,
c(gmem_scratch_bytes, i64),
token_ptr,
],
)
init_callback(scratch_alloc.result)
main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
sym_tab = ir.SymbolTable(module.operation)
sym_tab.insert(main.func_op)
5 changes: 5 additions & 0 deletions jaxlib/mosaic/gpu/custom_call.cc
@@ -111,14 +111,18 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
mlir::memref::registerMemRefPasses();
mlir::registerConvertToLLVMPass();
mlir::registerGPUPasses();
mlir::registerGpuLaunchSinkIndexComputations();
mosaic::gpu::registerGpuLaunchLoweringPass();
mosaic::gpu::registerConvertGpuToLLVMPass();
mosaic::gpu::registerByvalInsertionPass();
return true;
}();
(void)register_once;
return mlir::parsePassPipeline(
R"(
builtin.module(
canonicalize,
gpu-launch-sink-index-computations,
convert-nvgpu-to-nvvm,
gpu-kernel-outlining{data-layout-str=},
convert-vector-to-scf{full-unroll=false lower-tensors=false target-rank=1},
@@ -135,6 +139,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false}),
gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}),
gpu.module(cse),
gpu.module(mosaic-byval-insertion),
gpu.module(reconcile-unrealized-casts),
mosaic-convert-gpu-to-llvm,
gpu-module-to-binary{format=)" +
9 changes: 5 additions & 4 deletions jaxlib/mosaic/gpu/launch_lowering.cc
@@ -80,14 +80,15 @@ mlir::Value packKernelArgs(mlir::OpBuilder &builder,
c1);

for (auto [i, operand] : llvm::enumerate(launch.getKernelOperands())) {
mlir::LLVM::GEPArg gep_arg(i);
mlir::Value storage_ptr = builder.create<mlir::LLVM::GEPOp>(
launch.getLoc(), ptr_ty, operand.getType(), kernel_args_struct,
gep_arg);
launch.getLoc(), ptr_ty, kernel_args_struct_ty, kernel_args_struct,
mlir::ArrayRef<mlir::LLVM::GEPArg>{mlir::LLVM::GEPArg(0),
mlir::LLVM::GEPArg(i)});
builder.create<mlir::LLVM::StoreOp>(launch.getLoc(), operand, storage_ptr);
mlir::LLVM::GEPArg arr_gep_arg(i);
mlir::Value array_slot_ptr = builder.create<mlir::LLVM::GEPOp>(
launch.getLoc(), ptr_ty, builder.getI64Type(), kernel_args_array,
gep_arg);
mlir::LLVM::GEPArg(i));
builder.create<mlir::LLVM::StoreOp>(launch.getLoc(), storage_ptr,
array_slot_ptr);
}
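
The packKernelArgs fix above targets the CUDA driver's launch calling convention: kernelParams is an array of pointers, one per kernel argument, each pointing at that argument's value, so a by-value descriptor argument only needs the address of its host copy placed in the right slot. A small driver-API sketch of that convention follows, with illustrative function and variable names that are not part of this commit.

#include <cuda.h>
#include <stdio.h>

// Launch a kernel whose single parameter is a by-value CUtensorMap: the
// kernelParams entry points at the host-side descriptor, and the driver
// copies it into the kernel's .param space as part of the launch.
void launch_with_descriptor(CUfunction kernel, CUstream stream,
                            const CUtensorMap *tma_desc) {
  void *params[] = {const_cast<CUtensorMap *>(tma_desc)};
  CUresult result = cuLaunchKernel(kernel,
                                   /*gridDimX=*/1, /*gridDimY=*/1, /*gridDimZ=*/1,
                                   /*blockDimX=*/128, /*blockDimY=*/1, /*blockDimZ=*/1,
                                   /*sharedMemBytes=*/0, stream,
                                   params, /*extra=*/nullptr);
  if (result != CUDA_SUCCESS) {
    const char *msg = nullptr;
    cuGetErrorString(result, &msg);
    fprintf(stderr, "cuLaunchKernel failed: %s\n", msg);
  }
}
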
58 changes: 58 additions & 0 deletions jaxlib/mosaic/gpu/passes.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/SymbolTable.h"
#include "mlir/include/mlir/Pass/PassRegistry.h"
@@ -65,6 +66,57 @@ class ConvertGpuToLLVMPass
}
};

// Convert all array parameters to GPU kernels into byval pointers.
// NVVM backend converts them into arrays in the .param memory space.
// We only use arrays to pass in TMA descriptors, which is why we also
// require 64-byte alignment.
class ByvalInsertionPass
: public mosaic::gpu::Pass<ByvalInsertionPass, mlir::gpu::GPUModuleOp> {
public:
using mosaic::gpu::Pass<ByvalInsertionPass, mlir::gpu::GPUModuleOp>::Pass;
static constexpr llvm::StringLiteral kArgumentName = "mosaic-byval-insertion";
static constexpr llvm::StringLiteral kPassName = "ByvalInsertionPass";

void runOnOperation() override {
auto result = getOperation().walk([](mlir::LLVM::LLVMFuncOp op) {
// TODO(apaszke): op.isDeclaration() always returns false...
if (op.getFunctionBody().empty()) { // Skip over declarations.
return mlir::WalkResult::advance();
}
auto ptr_ty = mlir::LLVM::LLVMPointerType::get(op.getContext());
mlir::LLVM::LLVMFunctionType func_ty = op.getFunctionType();
std::vector<mlir::Type> new_arg_types = func_ty.getParams().vec();
for (unsigned i = 0; i < op.getNumArguments(); ++i) {
mlir::BlockArgument arg = op.getArgument(i);
if (!mlir::isa<mlir::LLVM::LLVMArrayType>(arg.getType())) {
continue;
}
if (op.getArgAttrDict(i)) {
op->emitOpError(
"!llvm.array argument already has some argument attributes");
return mlir::WalkResult::interrupt();
}
// It would be a lot simpler to use op.insertArgument, but the
// impl of FunctionOpInterface for llvm.func is _completely_ broken
new_arg_types[i] = ptr_ty;
op.setArgAttr(i, "llvm.byval", mlir::TypeAttr::get(arg.getType()));
op.setArgAttr(i, "nvvm.grid_constant",
mlir::UnitAttr::get(op.getContext()));
op.setArgAttr(i, "llvm.align",
mlir::IntegerAttr::get(
mlir::IntegerType::get(op.getContext(), 32), 64));
arg.setType(ptr_ty);
}
op.setFunctionType(mlir::LLVM::LLVMFunctionType::get(
func_ty.getReturnType(), new_arg_types, func_ty.isVarArg()));
return mlir::WalkResult::advance();
});
if (result.wasInterrupted()) {
signalPassFailure();
}
}
};

} // namespace

void registerConvertGpuToLLVMPass() {
@@ -73,5 +125,11 @@ void registerConvertGpuToLLVMPass() {
});
}

void registerByvalInsertionPass() {
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return std::make_unique<ByvalInsertionPass>();
});
}

} // namespace gpu
} // namespace mosaic
1 change: 1 addition & 0 deletions jaxlib/mosaic/gpu/passes.h
@@ -19,6 +19,7 @@ limitations under the License.
namespace mosaic {
namespace gpu {

void registerByvalInsertionPass();
void registerConvertGpuToLLVMPass();

} // namespace gpu
11 changes: 0 additions & 11 deletions jaxlib/mosaic/gpu/runtime.cc
@@ -81,17 +81,6 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
}
}

void mosaic_gpu_memcpy_async_h2d(CUdeviceptr dst, void *src, uint64_t bytes,
CUstream stream) {
CUresult result = cuMemcpyHtoDAsync(dst, src, bytes, stream);
if (result != CUDA_SUCCESS) {
const char *ptr = nullptr;
cuGetErrorString(result, &ptr);
fprintf(stderr, "cuMemcpyAsync failed: %s\n", ptr);
abort();
}
}

void* mosaic_gpu_module_load(void *data) {
CUmodule module = nullptr;
if (auto result = cuModuleLoadData(&module, data); result != CUDA_SUCCESS) {
