diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index f9723c4651..2508bc7a26 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -95,6 +95,8 @@ def cuda_install_path(): else: this.use_rmm = False +this._use_sycl = False + def set_log_level(level: int): """ @@ -172,6 +174,35 @@ def initialize_cuda_context(): this._device_id = int(device_id) +import dpctl + +this._sycl_device: dpctl.SyclDevice = None + +def initialize_sycl_context(): + if this._device_id is not None and this._sycl_device is not None: + return + + device_id = int(os.getenv("CUTLASS_SYCL_DEVICE_ID", default=0)) + sycl_gpus = dpctl.get_devices( + dpctl.backend_type.level_zero, dpctl.device_type.gpu) + + if len(sycl_gpus) <= device_id: + raise Exception("No LevelZero device found") + + this._device_id = device_id + this._sycl_device = sycl_gpus[device_id] + + def device_id() -> int: - initialize_cuda_context() + if os.getenv("CUTLASS_USE_SYCL"): + initialize_sycl_context() + this._use_sycl = True + else: + this._use_sycl = False + initialize_cuda_context() return this._device_id + + +def sycl_device() -> dpctl.SyclDevice: + initialize_sycl_context() + return this._sycl_device diff --git a/python/cutlass/backend/arguments.py b/python/cutlass/backend/arguments.py index b91cdf1f0c..9e6049170a 100644 --- a/python/cutlass/backend/arguments.py +++ b/python/cutlass/backend/arguments.py @@ -39,8 +39,11 @@ import cutlass from cutlass.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend from cutlass.backend.memory_manager import DevicePtrWrapper +from cutlass.backend.utils.device import default_stream from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor +import dpctl + class ArgumentBase: """ @@ -58,7 +61,7 @@ def __init__( # tensor_C can be interpreted as the bias with bias=True in keyword args self.bias = kwargs.get("bias", False) - self.stream = kwargs.get("stream", cuda.CUstream(0)) + self.stream = kwargs.get("stream", default_stream()) # RMM buffers used to track tensor lifetime self.buffers = {} @@ -83,34 +86,43 @@ def tensor_to_ptr(self, tensor, name, is_output=False): if is_numpy_tensor(tensor): if is_output: assert name - self.buffers[name] = NumpyFrontend.argument(tensor, is_output) + self.buffers[name] = NumpyFrontend.argument(tensor, is_output, self.stream) if is_output: self.host_tensors[name] = tensor return self.buffers[name].ptr elif is_torch_tensor(tensor): - return TorchFrontend.argument(tensor) + return TorchFrontend.argument(tensor, self.stream) elif isinstance(tensor, cuda.CUdeviceptr): return tensor elif is_cupy_tensor(tensor): return CupyFrontend.argument(tensor) else: - raise TypeError("Unsupported Frontend. Only support numpy and torch") + raise TypeError( + "Unsupported Frontend. 
Only support numpy and torch") def sync(self, stream_sync=True): + is_sycl = isinstance(self.stream, dpctl.SyclQueue) if stream_sync: - (err,) = cudart.cudaDeviceSynchronize() - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) + if is_sycl: + self.stream.wait() + else: + (err,) = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) for key in self.host_tensors.keys(): host_tensor = self.host_tensors[key] - (err,) = cuda.cuMemcpyDtoH( - host_tensor, - self.buffers[key].ptr, - host_tensor.size * host_tensor.itemsize, - ) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) + if is_sycl: + self.stream.memcpy(host_tensor, self.buffers[key].usm_mem, + host_tensor.size * host_tensor.itemsize) + else: + (err,) = cuda.cuMemcpyDtoH( + host_tensor, + self.buffers[key].ptr, + host_tensor.size * host_tensor.itemsize, + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) self.free() diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index d72af78eae..a265212704 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -161,7 +161,8 @@ def get_mainloop_arguments_3x( element_A, element_B, alignment_A: int, - alignment_B: int) -> ctypes.Structure: + alignment_B: int, + use_sycl: bool = False) -> ctypes.Structure: """ Returns the ctypes structure to be used for the 3.x kernel's mainloop parameters. @@ -207,10 +208,15 @@ def from_generic_mainloop_args(args: GenericMainloopArguments3x_): args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, ) - # Currently all 3.x kernels (CpAsync and Tma) have the same argument structure. - # Should that become not the case, this is the place to return custom ctypes - # structures based on selected kernel schedule. - return _MainloopArgumentsTma + if use_sycl: + # For SYCL, we don't have the additional 'mma_promotion_interval' arg. + return _MainloopArgumentsMultistage + else: + # Currently all 3.x kernels (CpAsync and Tma) for Nvidia devices have + # the same argument structure. Should that become not the case, this is + # the place to return custom ctypes structures based on selected kernel + # schedule. + return _MainloopArgumentsTma def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue): diff --git a/python/cutlass/backend/compiler.py b/python/cutlass/backend/compiler.py index f52b181831..c9b47c4cff 100644 --- a/python/cutlass/backend/compiler.py +++ b/python/cutlass/backend/compiler.py @@ -32,7 +32,7 @@ import ctypes import json -import os +import pathlib import sqlite3 import subprocess import tempfile @@ -40,6 +40,8 @@ from cuda import cuda, nvrtc from cutlass_library import SubstituteTemplate +import dpctl + import cutlass from cutlass import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger from cutlass.backend.gemm_operation import GemmOperationUniversal @@ -75,43 +77,44 @@ class CompilationOptions: Compilation options. 
""" - def __init__(self, flags, arch, include_paths=[]): + def __init__(self, flags, arch, include_paths=[], for_sycl=False): self.includes = [] self.include_paths = include_paths self.flags = flags self.arch = arch + self.for_sycl = for_sycl - def get_str(self): + def _encode(self): opts = [] for flag in self.flags: opts.append(flag) for incl in self.include_paths: - opts.append(f"--include-path={incl}") - - arch_flag = f"-arch=sm_{self.arch}" - if self.arch == 90: - arch_flag += "a" - opts.append(arch_flag) + if self.for_sycl: + opts.append(f"-I{incl}") + else: + opts.append(f"--include-path={incl}") - return " ".join(opts) + if self.for_sycl: + arch_flag = f"-fsycl-targets={self.arch}" + else: + arch_flag = f"-arch=sm_{self.arch}" + if self.arch == 90: + arch_flag += "a" - def get(self): - options = [] + opts.append(arch_flag) - for flag in self.flags: - options.append(bytes(str.encode(flag))) + return opts - for incl in self.include_paths: - options.append(bytes(str.encode(f" --include-path={incl}"))) - arch_flag = f" -arch=sm_{self.arch}" - if self.arch == 90: - arch_flag += "a" + def get_str(self): + opts = self._encode() - options.append(bytes(str.encode(arch_flag))) + return " ".join(opts) - return options + def get(self): + options = self._encode() + return [bytes(str.encode(s)) for s in options] def convertToBinaryData(filename): @@ -122,7 +125,8 @@ def convertToBinaryData(filename): def CDLLBin(host_binary): tempfile.tempdir = "./" - temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True) + temp_so = tempfile.NamedTemporaryFile( + prefix="host_func", suffix=".so", delete=True) with open(temp_so.name, "wb") as file: file.write(host_binary) host_lib = ctypes.CDLL(temp_so.name) @@ -155,6 +159,13 @@ def __init__(self) -> None: "--expt-relaxed-constexpr", "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored", ] + self._dpcpp_compile_options = ["-fsycl", "-std=c++17", + "-DCUTLASS_ENABLE_SYCL", + "-fsycl-rtc-mode", + "-DSYCL_INTEL_TARGET", + "-shared", "-fPIC", + "-fno-sycl-dead-args-optimization", + "-fsycl-range-rounding=disable"] self.nvcc() self.compiled_cache_device = {} self.compiled_cache_host = {} @@ -167,6 +178,13 @@ def nvcc(self): self.backend = "nvcc" self.default_compile_options = self._nvcc_compile_options + def dpcpp(self): + self.backend = "dpcpp" + self.default_compile_options = self._dpcpp_compile_options + + def _is_sycl(self): + return self.backend == "dpcpp" + def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): connection = sqlite3.connect(CACHE_FILE) cursor = connection.cursor() @@ -191,11 +209,17 @@ def load_operation(self, op_key, extra_funcs): for row in record: key, cubin_image, host_binary, operation_name, op_attr = row op_attr = json.loads(op_attr) - err, module = cuda.cuModuleLoadData(cubin_image) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("Cuda Error: {}".format(err)) + if self._is_sycl(): + q = dpctl.SyclQueue(cutlass.sycl_device()) + module = dpctl.program.create_program_from_spirv(q, cubin_image) + kernel = module.get_sycl_kernel(operation_name) + else: + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("Cuda Error: {}".format(err)) + err, kernel = cuda.cuModuleGetFunction( + module, bytes(str.encode(operation_name))) - err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name))) self.compiled_cache_device[key] = kernel compiled_host_fns = {} @@ -237,7 +261,10 @@ def emit_compile_(self, 
operation_list, compilation_options, host_compilation_op if incl not in includes: includes.append(incl) - includes_host = ["builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes + includes_host = ["stddef.h"] + includes + if not self._is_sycl(): + includes_host.extend(["device_launch_parameters.h", "builtin_types.h"]) + for incl in includes: source_buffer_device += SubstituteTemplate( IncludeTemplate, @@ -262,10 +289,11 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op operation.KernelTemplate, values, ) - source_buffer_host += SubstituteTemplate(operation.HostTemplate, values) + source_buffer_host += SubstituteTemplate( + operation.HostTemplate, values) + # 3. compile if self.backend == "nvrtc": - # 3. compile err, program = nvrtc.nvrtcCreateProgram( str.encode(source_buffer_device), bytes(str.encode("module.cu")), @@ -291,7 +319,8 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise RuntimeError("NVRTC Error: {}".format(err)) - raise RuntimeError(error_string + log.decode() + source_buffer_device) + raise RuntimeError( + error_string + log.decode() + source_buffer_device) # Get data from compilation err, dataSize = nvrtc.nvrtcGetCUBINSize(program) @@ -303,6 +332,50 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise RuntimeError("NVRTC Error: {}".format(err)) + elif self.backend == "dpcpp": + # Emit code to file + tempfile.tempdir = "./" + temp_cpp = tempfile.NamedTemporaryFile( + prefix="kernel_", suffix=".cpp", delete=True) + temp_dump_dir = tempfile.TemporaryDirectory( + prefix="kernel_", suffix="_dpcpp") + ignore_out = tempfile.NamedTemporaryFile( + prefix="kernel_", suffix=".o", delete=True) + with open(temp_cpp.name, "w") as file: + file.write(source_buffer_device) + + # Compile with DPC++ + cmd_template = "clang++ ${options} ${srcfile} -o ${outfile} -fsycl-dump-device-code=${tmpdir}" + values = { + "options": compilation_options.get_str(), + "srcfile": temp_cpp.name, + "outfile": ignore_out.name, + "tmpdir": temp_dump_dir.name + } + cmd = SubstituteTemplate(cmd_template, values) + compile_with_nvcc(cmd.split(" "), source_buffer_device, + "./cutlass_python_compilation_device_error.txt") + + # Find SPIR-V device code in temporary directory + spv_files = list(pathlib.Path(temp_dump_dir.name).glob("*.spv")) + + # When specifying a specific subgroup size, DPC++ currently + # generates multiple SPIR-V files. We create a program from each of + # them to find the one containing the kernel with the correct + # subgroup size. 
+ q = dpctl.SyclQueue(cutlass.sycl_device()) + op_name = f"__sycl_kernel_{operation_list[0].name()}" + for f in spv_files: + with open(f, "rb") as spirv_file: + spirv_image = spirv_file.read() + program = dpctl.program.create_program_from_spirv(q, spirv_image) + if not program.has_sycl_kernel(op_name): + continue + spirv_kernel = program.get_sycl_kernel(op_name) + if spirv_kernel.max_sub_group_size == 16: + cubin_image = spirv_image + break + else: # with nvcc backend # emit code tempfile.tempdir = "./" @@ -322,15 +395,17 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op "tarfile": temp_cubin.name, } cmd = SubstituteTemplate(cmd_template, values) - compile_with_nvcc(cmd.split(" "), source_buffer_device, "./cutlass_python_compilation_device_error.txt") + compile_with_nvcc(cmd.split(" "), source_buffer_device, + "./cutlass_python_compilation_device_error.txt") # load the cubin image with open(temp_cubin.name, "rb") as file: cubin_image = file.read() tempfile.tempdir = "./" + host_suffix = ".cpp" if self._is_sycl() else ".cu" temp_src = tempfile.NamedTemporaryFile( - prefix="host_src", suffix=".cu", delete=True) + prefix="host_src", suffix=host_suffix, delete=True) # Write the host source with open(temp_src.name, "w") as outfile: @@ -341,13 +416,23 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op # Set up host compilation arguments cmd = [] - cmd.append(f"{cuda_install_path()}/bin/nvcc") - cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive", "-Xcompiler=-w", "-Xcompiler=-fPIC"]) - cmd.extend(host_compilation_options.get_str().split(" ")) - cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name, "-lcudart", "-lcuda"]) + if not self._is_sycl(): + cmd.append(f"{cuda_install_path()}/bin/nvcc") + cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive", + "-Xcompiler=-w", "-Xcompiler=-fPIC"]) + cmd.extend(host_compilation_options.get_str().split(" ")) + cmd.extend(["-shared", "-o", temp_dst.name, + temp_src.name, "-lcudart", "-lcuda"]) + else: + cmd.append("clang++") + # Clang does not support "-fpermissive" + cmd.extend(["-fsycl", "-w", "-fPIC"]) + cmd.extend(host_compilation_options.get_str().split(" ")) + cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name]) # Comile and load the library - compile_with_nvcc( cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt") + compile_with_nvcc(cmd, source_buffer_host, + error_file="./cutlass_python_compilation_host_error.txt") host_lib = ctypes.CDLL(temp_dst.name) return cubin_image, host_lib, temp_dst @@ -357,20 +442,28 @@ def add_module(self, operations, compile_options=None, bypass_cache=False): Insert a new compiled device module """ include_paths = [ - cuda_install_path() + "/include", CUTLASS_PATH + "/include", CUTLASS_PATH + "/tools/util/include", CUTLASS_PATH + "/python/cutlass/cpp/include", ] - cutlass.initialize_cuda_context() - arch = device_cc() + if not self._is_sycl(): + include_paths.append(cuda_install_path() + "/include") + + cutlass.initialize_cuda_context() + arch = device_cc() + host_compile_options = CompilationOptions( + self._nvcc_compile_options, arch, include_paths, False) + else: + cutlass.initialize_sycl_context() + arch = "spir64" + host_compile_options = CompilationOptions( + ["-std=c++17", "-DCUTLASS_ENABLE_SYCL", "-DSYCL_INTEL_TARGET"], + arch, include_paths, True) - host_compile_options = CompilationOptions( - self._nvcc_compile_options, arch, include_paths) if compile_options is None: compile_options = CompilationOptions( - 
self.default_compile_options, arch, include_paths) + self.default_compile_options, arch, include_paths, self._is_sycl()) # save the cubin operation_key = [] operation_list = [] @@ -380,7 +473,9 @@ def add_module(self, operations, compile_options=None, bypass_cache=False): # step 1: check if the operation is in cache compiled_kernel = self.compiled_cache_device.get(key) - if compiled_kernel is None and not bypass_cache: + # TODO(Lukas): Caching is currently deactivated for SYCL and needs + # to be enabled. + if compiled_kernel is None and not bypass_cache and not self._is_sycl(): hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {})) if hit: compiled_kernel = self.compiled_cache_device.get(key) @@ -400,18 +495,27 @@ def add_module(self, operations, compile_options=None, bypass_cache=False): cubin_image, host_lib, host_file = self.emit_compile_( operation_list, compile_options, host_compile_options) - err, module = cuda.cuModuleLoadData(cubin_image) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("Cuda Error: {}".format(err)) + if self._is_sycl(): + q = dpctl.SyclQueue(cutlass.sycl_device()) + program = dpctl.program.create_program_from_spirv(q, cubin_image) + else: + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("Cuda Error: {}".format(err)) operation_name = [] operation_attr = [] for operation, key in zip(operation_list, operation_key): # get device kernels - err, operation.kernel = cuda.cuModuleGetFunction( - module, - bytes(str.encode(operation.name())) - ) + if self._is_sycl(): + # Free function kernels always have a name prefix. + fnName = f"__sycl_kernel_{operation.name()}" + operation.kernel = program.get_sycl_kernel(fnName) + else: + err, operation.kernel = cuda.cuModuleGetFunction( + module, + bytes(str.encode(operation.name())) + ) operation_name.append(operation.name()) self.compiled_cache_device[key] = operation.kernel # get host functions @@ -442,7 +546,7 @@ def add_module(self, operations, compile_options=None, bypass_cache=False): op_attr.append(param_size) if hasattr(operation, "extra_funcs"): - for suffix, ret_type in operation.extra_funcs.items(): + for suffix, ret_type in operation.extra_funcs.items(): func_name = operation.name() + "_" + suffix func = getattr(host_lib, func_name) if ret_type is not None: diff --git a/python/cutlass/backend/frontend.py b/python/cutlass/backend/frontend.py index 2b907cc765..6c3c1b6248 100644 --- a/python/cutlass/backend/frontend.py +++ b/python/cutlass/backend/frontend.py @@ -33,8 +33,16 @@ from cuda import cuda import numpy as np +import dpctl + from cutlass.backend.memory_manager import device_mem_alloc, todevice -from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor +from cutlass.utils.datatypes import ( + is_cupy_tensor, + is_numpy_tensor, + is_torch_tensor, + is_xpu_tensor, + is_xpu_available +) class NumpyFrontend: @@ -43,19 +51,19 @@ class NumpyFrontend: """ @staticmethod - def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr: + def argument(np_tensor: "np.ndarray", is_output: "bool", stream=None): """Convert the input numpy tensor to CUDA device pointer :param np_tensor: input numpy nd array :param is_output: whether the tensor is output - :return: CUDA device pointer + :return: Wrapped device pointer """ # copy the data to device if is_output: - return device_mem_alloc(np_tensor.size * np_tensor.itemsize) + return device_mem_alloc(np_tensor.size * np_tensor.itemsize, 
stream=stream) else: - return todevice(np_tensor) + return todevice(np_tensor, stream=stream) class TorchFrontend: @@ -64,15 +72,23 @@ class TorchFrontend: """ @staticmethod - def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr: + def argument(torch_tensor: "torch.Tensor", stream=None): """Convert the input torch tensor to CUDA device pointer :param torch_tensor: input torch tensor :param is_output: whether the tensor is output - :return: CUDA device pointer + :return: Device pointer """ + if isinstance(stream, dpctl.SyclQueue): + if not is_xpu_available(): + raise Exception("No XPU support in Torch available") + if not is_xpu_tensor(torch_tensor): + torch_tensor = torch_tensor.to("xpu") + + return torch_tensor.data_ptr() + # check the device of torch_tensor if not torch_tensor.is_cuda: torch_tensor = torch_tensor.to("cuda") diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index 62ac6c272d..4ad85cae6e 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -38,6 +38,8 @@ from cutlass_library import SubstituteTemplate import numpy as np +import dpctl + from cutlass_library import ( ComplexTransformTag, DataType, @@ -512,12 +514,14 @@ def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalM super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) def get_arguments(self): + use_sycl = isinstance(self.stream, dpctl.SyclQueue) mainloop_args = get_mainloop_arguments_3x( self.operation.tile_description.kernel_schedule, self.operation.A.element, self.operation.B.element, self.operation.A.alignment, - self.operation.B.alignment + self.operation.B.alignment, + use_sycl ) scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler) uses_default_epilogue = self.operation.rt_module.uses_default_epilogue() @@ -908,6 +912,9 @@ def get_device_workspace_size(self, arguments): return 0 def initialize(self): + if self.operation.arch == 11: + return + err, = cuda.cuFuncSetAttribute( self.kernel, attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, @@ -1112,11 +1119,19 @@ class GemmRTUniversal3x(GemmRTUniversal): using Operator = ${operation_name}${operation_suffix}; extern "C" +#if defined(CUTLASS_ENABLE_SYCL) +SYCL_EXTERNAL SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( + (sycl::ext::oneapi::experimental::nd_range_kernel< + 3>)) [[sycl::reqd_sub_group_size(16)]] +void ${operation_name}(typename Operator::Params const params, + sycl::ext::oneapi::experimental::work_group_memory mem) { + auto* smem = &mem[0]; +#else __global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) void ${operation_name}(__grid_constant__ typename Operator::Params const params) { // Dynamic shared memory base pointer extern __shared__ char smem[]; - +#endif // Declare pointer to dynamic shared memory. 
Operator op; op(params, smem); @@ -1300,8 +1315,9 @@ def __init__(self, operation_suffix=""): def emit(self, operation): # Support built-in epilogue functors or user-defined functions - - if operation.tile_description.stages is None or operation.tile_description.stages == 0: + if operation.arch == 11: + stage_count_type = "cutlass::gemm::collective::StageCountAuto" + elif operation.tile_description.stages is None or operation.tile_description.stages == 0: stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>" else: stage_count_type = "_" + str(operation.tile_description.stages) @@ -1321,6 +1337,7 @@ def emit(self, operation): if operation.tile_description.tile_scheduler is not None: tschedule = operation.tile_description.tile_scheduler + arch = "cutlass::arch::IntelPVC" if operation.arch == 11 else f"cutlass::arch::Sm{operation.arch}" values = { "operation_name": operation.procedural_name(), "operation_suffix": self.operation_suffix, @@ -1335,7 +1352,7 @@ def emit(self, operation): "element_accumulator": DataTypeTag[operation.accumulator_type()], "element_epilogue": DataTypeTag[operation.epilogue_functor.element_epilogue], "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - "arch": "cutlass::arch::Sm%d" % operation.arch, + "arch": arch, "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), diff --git a/python/cutlass/backend/library.py b/python/cutlass/backend/library.py index 49cb537a2d..4b374ba930 100644 --- a/python/cutlass/backend/library.py +++ b/python/cutlass/backend/library.py @@ -471,6 +471,9 @@ def api_version(arch, opclass, dtype): :return: API version to be used in code emission :rtype: ApiVersion """ + if opclass == OpcodeClass.TensorOp and arch == 11: + return ApiVersion.v3x + if (arch >= 90 and opclass == OpcodeClass.TensorOp and (dtype != DataType.f64)): diff --git a/python/cutlass/backend/memory_manager.py b/python/cutlass/backend/memory_manager.py index 89e6908395..dee8fd3d0f 100644 --- a/python/cutlass/backend/memory_manager.py +++ b/python/cutlass/backend/memory_manager.py @@ -40,6 +40,8 @@ else: from cuda import cudart +from dpctl.memory import MemoryUSMDevice + class PoolMemoryManager: def __init__(self, init_pool_size: int, max_pool_size: int) -> None: @@ -67,13 +69,33 @@ def __init__(self, dev_ptr): def ptr(self): return self.dev_ptr +class SYCLPtrWrapper: + """ + Wrapper around a pointer to USM device memory to provide a uniform interface. 
+ """ + def __init__(self, usm): + self.usm = usm + + @property + def ptr(self): + return self.usm.__sycl_usm_array_interface__["data"][0] + + @property + def usm_mem(self): + return self.usm -def _todevice(host_data): + +def _todevice(host_data, stream): """ Helper for transferring host data to device memory """ if cutlass.use_rmm: return rmm.DeviceBuffer.to_device(host_data.tobytes()) + if cutlass._use_sycl: + nbytes = len(host_data.tobytes()) + usm_device_ptr = device_mem_alloc(nbytes, stream) + stream.memcpy(usm_device_ptr.usm_mem, host_data.tobytes(), nbytes) + return usm_device_ptr else: nbytes = len(host_data.tobytes()) dev_ptr_wrapper = device_mem_alloc(nbytes) @@ -88,19 +110,22 @@ def _todevice(host_data): return dev_ptr_wrapper -def todevice(host_data, dtype=np.float32): +def todevice(host_data, dtype=np.float32, stream = None): """ Pass the host_data to device memory """ if isinstance(host_data, list): - return _todevice(np.array(host_data, dtype=dtype)) + return _todevice(np.array(host_data, dtype=dtype), stream) elif is_numpy_tensor(host_data): - return _todevice(host_data) + return _todevice(host_data, stream) -def device_mem_alloc(size): +def device_mem_alloc(size, stream = None): if cutlass.use_rmm: return rmm.DeviceBuffer(size=size) + elif cutlass._use_sycl: + device_usm = MemoryUSMDevice(size, queue=stream) + return SYCLPtrWrapper(device_usm) else: err, ptr = cudart.cudaMalloc(size) if err != cudart.cudaError_t.cudaSuccess: diff --git a/python/cutlass/backend/operation.py b/python/cutlass/backend/operation.py index a73cef6857..ac5857bc81 100644 --- a/python/cutlass/backend/operation.py +++ b/python/cutlass/backend/operation.py @@ -34,6 +34,8 @@ from cuda import __version__, cuda +import dpctl + from cutlass.backend.utils.device import device_cc _version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")] @@ -122,10 +124,22 @@ def run_without_clusters(self, launch_config, kernel_params, stream=cuda.CUstrea return err + def run_with_sycl(self, launch_config, kernel_params, param_size, stream): + local_mem = dpctl.experimental.WorkGroupMemory(launch_config.shared_memory_capacity) + raw_arg = dpctl.experimental.RawKernelArg(param_size, kernel_params) + globalSize = [g * l for g, l in zip(launch_config.grid, launch_config.block)] + globalSize.reverse() + localSize = launch_config.block + localSize.reverse() + stream.submit(self.kernel, [raw_arg, local_mem], globalSize, localSize) + def run(self, host_workspace, device_workspace, launch_config, stream=cuda.CUstream(0)): cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace) packed = (ctypes.c_void_p * 1)() packed[0] = ctypes.addressof(cArg) + if isinstance(stream, dpctl.SyclQueue): + self.run_with_sycl(launch_config, packed[0], len(host_workspace), stream) + return cuda.CUresult.CUDA_SUCCESS if supports_cluster_launch(): return self.run_with_clusters(launch_config, packed, stream) diff --git a/python/cutlass/backend/utils/device.py b/python/cutlass/backend/utils/device.py index 7ccf6ee981..7952069e46 100644 --- a/python/cutlass/backend/utils/device.py +++ b/python/cutlass/backend/utils/device.py @@ -35,6 +35,7 @@ """ from cuda import cuda, cudart +import dpctl import cutlass from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor @@ -76,6 +77,10 @@ def device_cc(device: int = -1) -> int: if device == -1: device = cutlass.device_id() + if cutlass._use_sycl: + # Using '11' to encode Intel PVC as an integer in the expected format. 
+ return 11 + deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device)) major = str(deviceProp.major) minor = str(deviceProp.minor) @@ -85,6 +90,10 @@ def device_cc(device: int = -1) -> int: def device_sm_count(device: int = -1): if device == -1: device = cutlass.device_id() + + if cutlass._use_sycl: + return cutlass._sycl_device.max_compute_units + err, device_sm_count = cuda.cuDeviceGetAttribute( cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device ) @@ -121,3 +130,9 @@ def to_device_ptr(tensor) -> cuda.CUdeviceptr: raise NotImplementedError(tensor) return ptr + + +def default_stream(): + if cutlass._use_sycl: + return dpctl.SyclQueue(cutlass._sycl_device) + return cuda.CUstream(0) diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index 7c16cc6855..87c959a8e2 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -46,7 +46,8 @@ from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op -_generator_ccs = [50, 60, 61, 70, 75, 80, 90] +# The value '11' is used to encode Intel PVC GPU in the expected format. +_generator_ccs = [11, 50, 60, 61, 70, 75, 80, 90] # Strip any additional information from the CUDA version _cuda_version = __version__.split("rc")[0] @@ -264,7 +265,7 @@ def __init__( # Identify the method within CUTLASS generator script that generates kernel # descriptions for the target CC - generate_function_name = "GenerateSM" + str(kernel_cc) + generate_function_name = "GeneratePVC" if kernel_cc == 11 else "GenerateSM" + str(kernel_cc) if not hasattr(cutlass_library.generator, generate_function_name): cutlass.logger.warning(f"No generator found for architecture {kernel_cc}") return @@ -276,6 +277,9 @@ def __init__( "--kernels=all", f"--log-level={logging.getLevelName(cutlass.logger.level)}" ] + if self.cc == 11: + args.append("--architectures=11") + manifest_args = cutlass_library.generator.define_parser().parse_args(args) manifest = cutlass_library.manifest.Manifest(manifest_args) generate_function(manifest, _cuda_version) diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index e74c40786f..97a591f9ec 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -129,6 +129,7 @@ from cutlass.backend.evt import EpilogueFunctorVisitor from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal from cutlass.backend.library import TensorDescription, TileDescription +from cutlass.backend.utils.device import default_stream from cutlass.op.op import OperationBase from cutlass.shape import GemmCoord from cutlass.utils import check, datatypes @@ -623,7 +624,7 @@ def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): def run(self, A=None, B=None, C=None, D=None, alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None, - stream: cuda.CUstream = cuda.CUstream(0)) -> GemmArguments: + stream = default_stream()) -> GemmArguments: """ Runs the kernel currently specified. If it has not already been, the kernel is emitted and compiled. 
Tensors holding operands and outputs of the kernel are sourced either from the diff --git a/python/cutlass/utils/check.py b/python/cutlass/utils/check.py index 2a37b72c33..4d498e1933 100644 --- a/python/cutlass/utils/check.py +++ b/python/cutlass/utils/check.py @@ -118,10 +118,21 @@ def valid_stage_count( "stage count, and shared memory requirement of the epilogue exceeds " "the available shared memory per SM.") + if kernel_cc == 11: + if (td.stages is None or td.stages == 0): + # Support for Intel PVC GPU currently does not allow explicit + # specification of the stage count. With None or 0, the + # CollectiveBuilder automatically determines the stage count to use. + return (True, "") + elif verbose: + cutlass.logger.warning( + "Setting an explicit stage count for Intel PVC GPU is currently " + "not supported.") + if td.stages <= 0: return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.") - if cc < 80 and td.stages != 2: + if cc >= 50 and cc < 80 and td.stages != 2: return (False, f"Tile description has stage count of {td.stages}, " f"but only 2 stages are supported on SM{cc}.") diff --git a/python/cutlass/utils/datatypes.py b/python/cutlass/utils/datatypes.py index 8ef50ad8ca..25915e486f 100644 --- a/python/cutlass/utils/datatypes.py +++ b/python/cutlass/utils/datatypes.py @@ -200,6 +200,17 @@ def is_torch_tensor(inp) -> bool: return isinstance(inp, torch.Tensor) return False +def is_xpu_available(): + if is_torch_available(): + import torch + return torch.xpu.is_available() + return False + +def is_xpu_tensor(inp) -> bool: + if is_torch_tensor(inp): + return inp.device.type == "xpu" + return False + def torch_library_type(inp) -> cutlass.DataType: return _torch_to_library_dict.get(inp, None) diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index e6a9f9e8e5..82deb2ca51 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -6902,6 +6902,47 @@ def GenerateSM90(manifest, cuda_version): ################################################################################################### +def GeneratePVC_TensorOp_16b_gemm(manifest, cuda_version): + # TODO: Add remaining supported configurations + layouts = [ + [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.RowMajor, 4]] + ] + + math_instructions = [ + MathInstruction( + [8, 16, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ] + + min_cc = 11 + max_cc = 11 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([math_inst.instruction_shape[0] * 32, math_inst.instruction_shape[1] * 16, math_inst.instruction_shape[2] * 2], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1, 1, 1]) + ] + + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules, tile_schedulers=[TileSchedulerType.Persistent]) + +def GeneratePVC(manifest, cuda_version): + GeneratePVC_TensorOp_16b_gemm(manifest, cuda_version) + +################################################################################################### + def numeric_log_level(log_level: str) -> int: """ 
Converts the string identifier of the log level diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index be9eef20ed..0955e06c89 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -626,6 +626,7 @@ class Target(enum.Enum): library = enum_auto() # ArchitectureNames = { + 11: 'pvc', 50: 'maxwell', 60: 'pascal', 61: 'pascal', @@ -638,6 +639,7 @@ class Target(enum.Enum): # SharedMemPerCC = { + 11: 128, # 128 KiB of SMEM on Intel PVC 70: 96, # 96KB of SMEM 72: 96, # 96KB of SMEM 75: 64, # 64KB of SMEM
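
For reference, a minimal usage sketch of the SYCL path introduced above. This is illustrative only and not part of the change: it assumes a PyTorch build with XPU support, an installed dpctl, and that the new dpcpp() compiler backend hook is selected elsewhere in the compile flow (only the hook itself is added here). The bf16/f32 row-major configuration mirrors the single kernel emitted by GeneratePVC_TensorOp_16b_gemm.

    # Illustrative sketch, not part of the diff.
    import os
    os.environ["CUTLASS_USE_SYCL"] = "1"        # route device_id()/default_stream() through dpctl
    os.environ["CUTLASS_SYCL_DEVICE_ID"] = "0"  # optional, defaults to Level Zero device 0

    import torch
    import cutlass

    # bf16 inputs with f32 accumulation match the PVC kernel generated above.
    M, N, K = 512, 512, 512
    A = torch.randn(M, K, dtype=torch.bfloat16, device="xpu")
    B = torch.randn(K, N, dtype=torch.bfloat16, device="xpu")
    C = torch.zeros(M, N, dtype=torch.float32, device="xpu")
    D = torch.zeros_like(C)

    plan = cutlass.op.Gemm(element_A=torch.bfloat16, element_B=torch.bfloat16,
                           element_C=torch.float32, element_D=torch.float32,
                           element_accumulator=torch.float32,
                           layout=cutlass.LayoutType.RowMajor)

    # Gemm.run() now defaults its stream to default_stream(), which returns a
    # dpctl.SyclQueue when CUTLASS_USE_SYCL is set; TorchFrontend moves any
    # non-XPU tensors to "xpu" before taking their data pointer, and sync()
    # waits on the SYCL queue instead of calling cudaDeviceSynchronize().
    plan.run(A, B, C, D)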