[Mosaic GPU] Pass in TMA descriptors through kernel parameters #22175

Merged · 1 commit · Jul 2, 2024
67 changes: 29 additions & 38 deletions jax/experimental/mosaic/gpu/__init__.py
@@ -208,6 +208,7 @@ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
@dataclasses.dataclass()
class LaunchContext:
launch_op: gpu.LaunchOp
gmem_scratch_ptr: ir.Value
profiler: OnDeviceProfiler | None = None
next_scratch_offset: int = 0
host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
@@ -251,36 +252,31 @@ def host_init_wrapped(host_ptr):
llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8)
)
self.host_scratch_init.append(host_init_wrapped)

with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]):
ptr_ty = ir.Type.parse("!llvm.ptr")
const_ptr_ty = ir.Type.parse("!llvm.ptr<4>")
gmem_scratch_ptr = llvm.call_intrinsic(
ptr_ty,
"llvm.nvvm.ptr.constant.to.gen.p0.p4",
[llvm.mlir_addressof(const_ptr_ty, "global_scratch")],
)
return device_init(llvm.getelementptr(
ptr_ty, gmem_scratch_ptr, [], [alloc_base], i8
))
# with ir.InsertionPoint(self.gmem_scratch_ptr.owner):
# There is no way to create an insertion point after an operation...
gep = llvm.GEPOp(
ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8
)
gep.move_after(self.gmem_scratch_ptr.owner)
return device_init(gep.result)

def _get_tma_desc(
self,
ref,
gmem_ref,
gmem_transform: tuple[MemRefTransform, ...],
transformed_slice_shape: tuple[int, ...],
swizzle: int | None,
):
tma_desc_key = (ref, transformed_slice_shape, swizzle, gmem_transform)
tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform)
if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
with ir.InsertionPoint(self.launch_op):
for t in gmem_transform:
ref = t.apply(ref)
ref_ty = ir.MemRefType(ref.type)

i64 = ir.IntegerType.get_signless(64)
ptr_ty = ir.Type.parse("!llvm.ptr")
def init_tma_desc(host_ptr):
ref = gmem_ref
for t in gmem_transform:
ref = t.apply(ref)
ref_ty = ir.MemRefType(ref.type)
# TODO(apaszke): Use utils.memref_ptr to compute base_ptr
_, offset, *sizes_and_strides = memref.extract_strided_metadata(ref)
aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref)
as_i64 = lambda i: arith.index_cast(i64, i)
@@ -466,6 +462,7 @@ def _launch(
token,
grid,
block,
scratch_arr,
smem_buffers: ShapeTree | Union[ShapeTree],
profiler_spec: profiler.ProfilerSpec | None = None,
maybe_prof_buffer: ir.Value | None = None,
@@ -506,8 +503,8 @@ def _launch(
(ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem
)
)
smem_ref_trees = []

smem_ref_trees = []
for smem_live_buffers_collection in smem_disjoint_live_buffers_collections:
smem_ref_tree = _construct_smem_reftree(
dynamic_smem, smem_live_buffers_collection)
@@ -532,7 +529,9 @@ def _launch(
else:
smem_ref_tree: RefTree = smem_ref_trees[0] if smem_ref_trees else []

yield LaunchContext(launch_op, prof), smem_ref_tree
ptr_ty = ir.Type.parse("!llvm.ptr")
scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
yield LaunchContext(launch_op, scratch_ptr, prof), smem_ref_tree
if prof is not None:
prof.finalize(grid=grid, block=block)
gpu.terminator()
@@ -590,32 +589,24 @@ def main(token_ptr, buffers, gmem_scratch_ptr):
in_refs = arg_refs[:len(in_ref_tys)]
out_refs = arg_refs[len(in_ref_tys):]
prof_buffer = out_refs.pop() if prof_spec is not None else None
empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>")
scratch_alloc = llvm.AllocaOp(ptr_ty, c(1, i64), empty_arr_ty)
scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result)
with _launch(
token, grid, block, smem_scratch_shape,
token, grid, block, scratch_arr, smem_scratch_shape,
prof_spec, prof_buffer
) as (launch_ctx, smem_refs):
body(launch_ctx, *in_refs, *out_refs, smem_refs)
gmem_scratch_bytes = launch_ctx.next_scratch_offset
# Allocate and initialize the host buffer right before the launch.
# Note that we couldn't do that before, because we had to run the body
# to learn what the scratch contains.
with ir.InsertionPoint(launch_ctx.launch_op):
host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8)
with ir.InsertionPoint(scratch_arr.owner):
scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty)
scratch_arr.set_type(scratch_arr_ty)
for init_callback in launch_ctx.host_scratch_init:
init_callback(host_scratch_ptr)
global_scratch.global_type = ir.TypeAttr.get(
ir.Type.parse("!llvm.array<" + str(gmem_scratch_bytes) + " x i8>")
)
func.call(
[],
"mosaic_gpu_memcpy_async_h2d",
[
gmem_scratch_ptr,
host_scratch_ptr,
c(gmem_scratch_bytes, i64),
token_ptr,
],
)
init_callback(scratch_alloc.result)
main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
sym_tab = ir.SymbolTable(module.operation)
sym_tab.insert(main.func_op)
5 changes: 5 additions & 0 deletions jaxlib/mosaic/gpu/custom_call.cc
@@ -111,14 +111,18 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
mlir::memref::registerMemRefPasses();
mlir::registerConvertToLLVMPass();
mlir::registerGPUPasses();
mlir::registerGpuLaunchSinkIndexComputations();
mosaic::gpu::registerGpuLaunchLoweringPass();
mosaic::gpu::registerConvertGpuToLLVMPass();
mosaic::gpu::registerByvalInsertionPass();
return true;
}();
(void)register_once;
return mlir::parsePassPipeline(
R"(
builtin.module(
canonicalize,
gpu-launch-sink-index-computations,
convert-nvgpu-to-nvvm,
gpu-kernel-outlining{data-layout-str=},
convert-vector-to-scf{full-unroll=false lower-tensors=false target-rank=1},
@@ -135,6 +139,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false}),
gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}),
gpu.module(cse),
gpu.module(mosaic-byval-insertion),
gpu.module(reconcile-unrealized-casts),
mosaic-convert-gpu-to-llvm,
gpu-module-to-binary{format=)" +
9 changes: 5 additions & 4 deletions jaxlib/mosaic/gpu/launch_lowering.cc
@@ -80,14 +80,15 @@ mlir::Value packKernelArgs(mlir::OpBuilder &builder,
c1);

for (auto [i, operand] : llvm::enumerate(launch.getKernelOperands())) {
mlir::LLVM::GEPArg gep_arg(i);
mlir::Value storage_ptr = builder.create<mlir::LLVM::GEPOp>(
launch.getLoc(), ptr_ty, operand.getType(), kernel_args_struct,
gep_arg);
launch.getLoc(), ptr_ty, kernel_args_struct_ty, kernel_args_struct,
mlir::ArrayRef<mlir::LLVM::GEPArg>{mlir::LLVM::GEPArg(0),
mlir::LLVM::GEPArg(i)});
builder.create<mlir::LLVM::StoreOp>(launch.getLoc(), operand, storage_ptr);
mlir::LLVM::GEPArg arr_gep_arg(i);
mlir::Value array_slot_ptr = builder.create<mlir::LLVM::GEPOp>(
launch.getLoc(), ptr_ty, builder.getI64Type(), kernel_args_array,
gep_arg);
mlir::LLVM::GEPArg(i));
builder.create<mlir::LLVM::StoreOp>(launch.getLoc(), storage_ptr,
array_slot_ptr);
}
58 changes: 58 additions & 0 deletions jaxlib/mosaic/gpu/passes.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/SymbolTable.h"
#include "mlir/include/mlir/Pass/PassRegistry.h"
@@ -65,6 +66,57 @@ class ConvertGpuToLLVMPass
}
};

// Convert all array parameters to GPU kernels into byval pointers.
// NVVM backend converts them into arrays in the .param memory space.
// We only use arrays to pass in TMA descriptors, which is why we also
// require 64-byte alignment.
class ByvalInsertionPass
: public mosaic::gpu::Pass<ByvalInsertionPass, mlir::gpu::GPUModuleOp> {
public:
using mosaic::gpu::Pass<ByvalInsertionPass, mlir::gpu::GPUModuleOp>::Pass;
static constexpr llvm::StringLiteral kArgumentName = "mosaic-byval-insertion";
static constexpr llvm::StringLiteral kPassName = "ByvalInsertionPass";

void runOnOperation() override {
auto result = getOperation().walk([](mlir::LLVM::LLVMFuncOp op) {
// TODO(apaszke): op.isDeclaration() always returns false...
if (op.getFunctionBody().empty()) { // Skip over declarations.
return mlir::WalkResult::advance();
}
auto ptr_ty = mlir::LLVM::LLVMPointerType::get(op.getContext());
mlir::LLVM::LLVMFunctionType func_ty = op.getFunctionType();
std::vector<mlir::Type> new_arg_types = func_ty.getParams().vec();
for (unsigned i = 0; i < op.getNumArguments(); ++i) {
mlir::BlockArgument arg = op.getArgument(i);
if (!mlir::isa<mlir::LLVM::LLVMArrayType>(arg.getType())) {
continue;
}
if (op.getArgAttrDict(i)) {
op->emitOpError(
"!llvm.array argument already has some argument attributes");
return mlir::WalkResult::interrupt();
}
// It would be a lot simpler to use op.insertArgument, but the
// impl of FunctionOpInterface for llvm.func is _completely_ broken
new_arg_types[i] = ptr_ty;
op.setArgAttr(i, "llvm.byval", mlir::TypeAttr::get(arg.getType()));
op.setArgAttr(i, "nvvm.grid_constant",
mlir::UnitAttr::get(op.getContext()));
op.setArgAttr(i, "llvm.align",
mlir::IntegerAttr::get(
mlir::IntegerType::get(op.getContext(), 32), 64));
arg.setType(ptr_ty);
}
op.setFunctionType(mlir::LLVM::LLVMFunctionType::get(
func_ty.getReturnType(), new_arg_types, func_ty.isVarArg()));
return mlir::WalkResult::advance();
});
if (result.wasInterrupted()) {
signalPassFailure();
}
}
};

} // namespace

void registerConvertGpuToLLVMPass() {
@@ -73,5 +125,11 @@ void registerConvertGpuToLLVMPass() {
});
}

void registerByvalInsertionPass() {
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return std::make_unique<ByvalInsertionPass>();
});
}

} // namespace gpu
} // namespace mosaic
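
For context on the ByvalInsertionPass comment above: once an array argument is rewritten into a 64-byte-aligned llvm.byval pointer carrying nvvm.grid_constant, the NVVM backend keeps the TMA descriptor in the kernel's .param memory space. A minimal CUDA sketch of the equivalent hand-written kernel signature (hypothetical kernel name and body, not produced by this pass):

```cuda
#include <cuda.h>  // CUtensorMap (CUDA 12.x driver API header)

// Hypothetical kernel, written by hand for illustration only. The TMA
// descriptor is received by value; __grid_constant__ guarantees it is never
// copied out of the .param memory space, so its address can be handed
// directly to the TMA unit (cp.async.bulk.tensor.* in PTX). This mirrors the
// llvm.byval + llvm.align 64 + nvvm.grid_constant attributes that
// ByvalInsertionPass attaches to rewritten arguments.
__global__ void copy_tile(const __grid_constant__ CUtensorMap tma_desc,
                          float *out) {
  const CUtensorMap *desc_ptr = &tma_desc;  // valid: points into .param space
  (void)desc_ptr;  // the actual bulk-tensor copy (inline PTX) is omitted
  if (threadIdx.x == 0 && blockIdx.x == 0) out[0] = 0.0f;
}
```

The 64-byte alignment requested by the pass matches the hardware requirement that a CUtensorMap consumed by the TMA unit be 64-byte aligned.
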
1 change: 1 addition & 0 deletions jaxlib/mosaic/gpu/passes.h
@@ -19,6 +19,7 @@ limitations under the License.
namespace mosaic {
namespace gpu {

void registerByvalInsertionPass();
void registerConvertGpuToLLVMPass();

} // namespace gpu
11 changes: 0 additions & 11 deletions jaxlib/mosaic/gpu/runtime.cc
@@ -81,17 +81,6 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
}
}

void mosaic_gpu_memcpy_async_h2d(CUdeviceptr dst, void *src, uint64_t bytes,
CUstream stream) {
CUresult result = cuMemcpyHtoDAsync(dst, src, bytes, stream);
if (result != CUDA_SUCCESS) {
const char *ptr = nullptr;
cuGetErrorString(result, &ptr);
fprintf(stderr, "cuMemcpyAsync failed: %s\n", ptr);
abort();
}
}

void* mosaic_gpu_module_load(void *data) {
CUmodule module = nullptr;
if (auto result = cuModuleLoadData(&module, data); result != CUDA_SUCCESS) {
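
With the descriptors travelling as kernel parameters, the staging copy that mosaic_gpu_memcpy_async_h2d used to perform (deleted above) is no longer needed: the host-initialized descriptor bytes are handed to the driver at launch time. A rough driver-API sketch of that flow (hypothetical helper; it assumes the CUtensorMap was already encoded on the host, e.g. by mosaic_gpu_init_tma_desc, which this change keeps):

```cuda
#include <cstdio>
#include <cstdlib>
#include <cuda.h>

// Hypothetical launch helper: the CUtensorMap stays in ordinary host memory
// and is passed by value through kernelParams, so no cuMemcpyHtoDAsync into
// a device-side descriptor scratch buffer is required.
void launch_with_tma_desc(CUfunction kernel, CUstream stream,
                          CUtensorMap *host_desc, CUdeviceptr out) {
  // kernelParams holds a pointer to each argument's value; the driver copies
  // the descriptor into the kernel's parameter space as part of the launch.
  void *params[] = {host_desc, &out};
  CUresult result = cuLaunchKernel(kernel,
                                   /*gridDimX,Y,Z=*/1, 1, 1,
                                   /*blockDimX,Y,Z=*/128, 1, 1,
                                   /*sharedMemBytes=*/0, stream,
                                   params, /*extra=*/nullptr);
  if (result != CUDA_SUCCESS) {
    const char *msg = nullptr;
    cuGetErrorString(result, &msg);
    fprintf(stderr, "cuLaunchKernel failed: %s\n", msg);
    abort();
  }
}
```
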