First steps to enable SYCL backend in Python Interface #155

Merged
31 changes: 30 additions & 1 deletion python/cutlass/__init__.py
@@ -172,6 +172,35 @@ def initialize_cuda_context():
this._device_id = int(device_id)


import dpctl

this._sycl_device: dpctl.SyclDevice = None
this._use_sycl = False

def initialize_sycl_context():
if this._device_id is not None and this._sycl_device is not None:
return

device_id = int(os.getenv("CUTLASS_SYCL_DEVICE_ID", default=0))
sycl_gpus = dpctl.get_devices(
dpctl.backend_type.level_zero, dpctl.device_type.gpu)

if len(sycl_gpus) <= device_id:
    raise Exception(f"No Level Zero GPU found for device id {device_id}")

this._device_id = device_id
this._sycl_device = sycl_gpus[device_id]


def device_id() -> int:
initialize_cuda_context()
if os.getenv("CUTLASS_USE_SYCL"):
initialize_sycl_context()
this._use_sycl = True
else:
initialize_cuda_context()
return this._device_id


def sycl_device() -> dpctl.SyclDevice:
initialize_sycl_context()
return this._sycl_device
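
The hunk above adds lazy SYCL context initialization driven by the CUTLASS_USE_SYCL and CUTLASS_SYCL_DEVICE_ID environment variables. A minimal usage sketch (not part of the diff; it assumes the patched package is importable as cutlass, dpctl is installed, and at least one Level Zero GPU is visible):

    import os
    os.environ["CUTLASS_USE_SYCL"] = "1"        # route device_id() through initialize_sycl_context()
    os.environ["CUTLASS_SYCL_DEVICE_ID"] = "0"  # index into dpctl's Level Zero GPU list

    import cutlass
    idx = cutlass.device_id()     # lazily initializes the SYCL context
    dev = cutlass.sycl_device()   # the dpctl.SyclDevice selected above
    print(idx, dev.name)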
194 changes: 142 additions & 52 deletions python/cutlass/backend/compiler.py
@@ -32,14 +32,16 @@

import ctypes
import json
import os
import pathlib
import sqlite3
import subprocess
import tempfile

from cuda import cuda, nvrtc
from cutlass_library import SubstituteTemplate

import dpctl

import cutlass
from cutlass import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger
from cutlass.backend.gemm_operation import GemmOperationUniversal
@@ -75,43 +77,44 @@ class CompilationOptions:
Compilation options.
"""

def __init__(self, flags, arch, include_paths=[]):
def __init__(self, flags, arch, include_paths=[], for_sycl=False):
self.includes = []
self.include_paths = include_paths
self.flags = flags
self.arch = arch
self.for_sycl = for_sycl

def get_str(self):
def _encode(self):
opts = []
for flag in self.flags:
opts.append(flag)

for incl in self.include_paths:
opts.append(f"--include-path={incl}")
if self.for_sycl:
opts.append(f"-I{incl}")
else:
opts.append(f"--include-path={incl}")

arch_flag = f"-arch=sm_{self.arch}"
if self.arch == 90:
arch_flag += "a"
opts.append(arch_flag)
if self.for_sycl:
arch_flag = f"-fsycl-targets={self.arch}"
else:
arch_flag = f"-arch=sm_{self.arch}"
if self.arch == 90:
arch_flag += "a"

return " ".join(opts)
opts.append(arch_flag)

def get(self):
options = []
return opts

for flag in self.flags:
options.append(bytes(str.encode(flag)))

for incl in self.include_paths:
options.append(bytes(str.encode(f" --include-path={incl}")))

arch_flag = f" -arch=sm_{self.arch}"
if self.arch == 90:
arch_flag += "a"
def get_str(self):
opts = self._encode()

options.append(bytes(str.encode(arch_flag)))
return " ".join(opts)

return options
def get(self):
options = self._encode()
return [bytes(str.encode(s)) for s in options]
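
For the two backends, the refactored _encode() path above yields flag strings along these lines (a sketch; the include path is a placeholder and the SYCL flags are illustrative):

    # SYCL flavor: -I include paths, arch encoded as an -fsycl-targets value.
    sycl_opts = CompilationOptions(["-fsycl", "-std=c++17"], "spir64",
                                   include_paths=["/opt/cutlass/include"], for_sycl=True)
    sycl_opts.get_str()   # "-fsycl -std=c++17 -I/opt/cutlass/include -fsycl-targets=spir64"

    # CUDA flavor: --include-path entries, arch encoded as -arch=sm_XX ("a" suffix for 90).
    cuda_opts = CompilationOptions(["-std=c++17"], 90,
                                   include_paths=["/opt/cutlass/include"], for_sycl=False)
    cuda_opts.get_str()   # "-std=c++17 --include-path=/opt/cutlass/include -arch=sm_90a"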


def convertToBinaryData(filename):
@@ -122,7 +125,8 @@ def convertToBinaryData(filename):

def CDLLBin(host_binary):
tempfile.tempdir = "./"
temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True)
temp_so = tempfile.NamedTemporaryFile(
prefix="host_func", suffix=".so", delete=True)
with open(temp_so.name, "wb") as file:
file.write(host_binary)
host_lib = ctypes.CDLL(temp_so.name)
@@ -155,6 +159,11 @@ def __init__(self) -> None:
"--expt-relaxed-constexpr",
"-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
]
self._dpcpp_compile_options = ["-fsycl", "-std=c++17",
"-DCUTLASS_ENABLE_SYCL",
"-fsycl-rtc-mode",
"-DSYCL_INTEL_TARGET",
"-shared", "-fPIC"]
self.nvcc()
self.compiled_cache_device = {}
self.compiled_cache_host = {}
@@ -167,6 +176,13 @@ def nvcc(self):
self.backend = "nvcc"
self.default_compile_options = self._nvcc_compile_options

def dpcpp(self):
self.backend = "dpcpp"
self.default_compile_options = self._dpcpp_compile_options

def _is_sycl(self):
return self.backend == "dpcpp"

def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
connection = sqlite3.connect(CACHE_FILE)
cursor = connection.cursor()
@@ -191,11 +207,20 @@ def load_operation(self, op_key, extra_funcs):
for row in record:
key, cubin_image, host_binary, operation_name, op_attr = row
op_attr = json.loads(op_attr)
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
if self._is_sycl():
q = dpctl.SyclQueue(cutlass.sycl_device())
module = dpctl.program.create_program_from_spirv(q, cubin_image)
else:
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))

if self._is_sycl():
kernel = module.get_sycl_kernel(operation_name)
else:
err, kernel = cuda.cuModuleGetFunction(
module, bytes(str.encode(operation_name)))

err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
self.compiled_cache_device[key] = kernel

compiled_host_fns = {}
@@ -237,7 +262,10 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op
if incl not in includes:
includes.append(incl)

includes_host = ["builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes
includes_host = ["stddef.h"] + includes
if not self._is_sycl():
includes_host.extend(["device_launch_parameters.h", "builtin_types.h"])

for incl in includes:
source_buffer_device += SubstituteTemplate(
IncludeTemplate,
@@ -262,10 +290,11 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op
operation.KernelTemplate,
values,
)
source_buffer_host += SubstituteTemplate(operation.HostTemplate, values)
source_buffer_host += SubstituteTemplate(
operation.HostTemplate, values)

if self.backend == "nvrtc":
# 3. compile
# 3. compile
if self.backend == "nvrtc": # with nvrtc backend
err, program = nvrtc.nvrtcCreateProgram(
str.encode(source_buffer_device),
bytes(str.encode("module.cu")),
@@ -291,7 +320,8 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))

raise RuntimeError(error_string + log.decode() + source_buffer_device)
raise RuntimeError(
error_string + log.decode() + source_buffer_device)

# Get data from compilation
err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
@@ -303,6 +333,39 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))

elif self.backend == "dpcpp": # with DPC++ backend
# Emit code to file
tempfile.tempdir = "./"
temp_cpp = tempfile.NamedTemporaryFile(
prefix="kernel_", suffix=".cpp", delete=True)
temp_dump_dir = tempfile.TemporaryDirectory(
prefix="kernel_", suffix="_dpcpp")
ignore_out = tempfile.NamedTemporaryFile(
prefix="kernel_", suffix=".o", delete=True)
with open(temp_cpp.name, "w") as file:
file.write(source_buffer_device)

# Compile with DPC++
cmd_template = "clang++ ${options} ${srcfile} -o ${outfile} -fsycl-dump-device-code=${tmpdir}"
values = {
"options": compilation_options.get_str(),
"srcfile": temp_cpp.name,
"outfile": ignore_out.name,
"tmpdir": temp_dump_dir.name
}
cmd = SubstituteTemplate(cmd_template, values)
compile_with_nvcc(cmd.split(" "), source_buffer_device,
"./cutlass_python_compilation_device_error.txt")

# Find SPIR-V device code in temporary directory
spv_files = list(pathlib.Path(temp_dump_dir.name).glob("*.spv"))
if len(spv_files) != 1:
    raise RuntimeError(f"Expected exactly one SPIR-V device image, found {len(spv_files)}")

# Load the SPIR-V image
with open(spv_files[0], "rb") as file:
cubin_image = file.read()

else: # with nvcc backend
# emit code
tempfile.tempdir = "./"
@@ -322,15 +385,17 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op
"tarfile": temp_cubin.name,
}
cmd = SubstituteTemplate(cmd_template, values)
compile_with_nvcc(cmd.split(" "), source_buffer_device, "./cutlass_python_compilation_device_error.txt")
compile_with_nvcc(cmd.split(" "), source_buffer_device,
"./cutlass_python_compilation_device_error.txt")

# load the cubin image
with open(temp_cubin.name, "rb") as file:
cubin_image = file.read()

tempfile.tempdir = "./"
host_suffix = ".cpp" if self._is_sycl() else ".cu"
temp_src = tempfile.NamedTemporaryFile(
prefix="host_src", suffix=".cu", delete=True)
prefix="host_src", suffix=host_suffix, delete=True)

# Write the host source
with open(temp_src.name, "w") as outfile:
@@ -341,13 +406,23 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op

# Set up host compilation arguments
cmd = []
cmd.append(f"{cuda_install_path()}/bin/nvcc")
cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive", "-Xcompiler=-w", "-Xcompiler=-fPIC"])
cmd.extend(host_compilation_options.get_str().split(" "))
cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name, "-lcudart", "-lcuda"])
if not self._is_sycl():
cmd.append(f"{cuda_install_path()}/bin/nvcc")
cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive",
"-Xcompiler=-w", "-Xcompiler=-fPIC"])
cmd.extend(host_compilation_options.get_str().split(" "))
cmd.extend(["-shared", "-o", temp_dst.name,
temp_src.name, "-lcudart", "-lcuda"])
else:
cmd.append("clang++")
# Clang does not support "-fpermissive"
cmd.extend(["-fsycl", "-w", "-fPIC"])
cmd.extend(host_compilation_options.get_str().split(" "))
cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name])

# Compile and load the library
compile_with_nvcc( cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt")
compile_with_nvcc(cmd, source_buffer_host,
error_file="./cutlass_python_compilation_host_error.txt")
host_lib = ctypes.CDLL(temp_dst.name)

return cubin_image, host_lib, temp_dst
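
Condensing the new DPC++ branch of emit_compile_ above into a standalone sketch (compile_to_spirv is a hypothetical helper, not part of the PR; it assumes a SYCL-enabled clang++ on PATH and that options carries the -fsycl flags and include paths built by CompilationOptions):

    import pathlib
    import subprocess
    import tempfile

    def compile_to_spirv(source: str, options: str) -> bytes:
        # Compile a SYCL translation unit and return the dumped SPIR-V device image.
        with tempfile.TemporaryDirectory(prefix="kernel_", suffix="_dpcpp") as tmpdir:
            src = pathlib.Path(tmpdir, "kernel.cpp")
            src.write_text(source)
            cmd = ["clang++", *options.split(), str(src),
                   "-o", str(pathlib.Path(tmpdir, "kernel.o")),
                   f"-fsycl-dump-device-code={tmpdir}"]
            subprocess.run(cmd, check=True)  # raises CalledProcessError on failure
            spv_files = list(pathlib.Path(tmpdir).glob("*.spv"))
            if len(spv_files) != 1:
                raise RuntimeError(f"Expected exactly one SPIR-V file, found {len(spv_files)}")
            return spv_files[0].read_bytes()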
@@ -357,20 +432,28 @@ def add_module(self, operations, compile_options=None, bypass_cache=False):
Insert a new compiled device module
"""
include_paths = [
cuda_install_path() + "/include",
CUTLASS_PATH + "/include",
CUTLASS_PATH + "/tools/util/include",
CUTLASS_PATH + "/python/cutlass/cpp/include",
]

cutlass.initialize_cuda_context()
arch = device_cc()
if not self._is_sycl():
include_paths.append(cuda_install_path() + "/include")

cutlass.initialize_cuda_context()
arch = device_cc()
host_compile_options = CompilationOptions(
self._nvcc_compile_options, arch, include_paths, False)
else:
cutlass.initialize_sycl_context()
arch = "spir64"
host_compile_options = CompilationOptions(
["-std=c++17", "-DCUTLASS_ENABLE_SYCL", "-DSYCL_INTEL_TARGET"],
arch, include_paths, True)

host_compile_options = CompilationOptions(
self._nvcc_compile_options, arch, include_paths)
if compile_options is None:
compile_options = CompilationOptions(
self.default_compile_options, arch, include_paths)
self.default_compile_options, arch, include_paths, self._is_sycl())
# save the cubin
operation_key = []
operation_list = []
@@ -400,18 +483,25 @@ def add_module(self, operations, compile_options=None, bypass_cache=False):
cubin_image, host_lib, host_file = self.emit_compile_(
operation_list, compile_options, host_compile_options)

err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
if self._is_sycl():
q = dpctl.SyclQueue(cutlass.sycl_device())
program = dpctl.program.create_program_from_spirv(q, cubin_image)
else:
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))

operation_name = []
operation_attr = []
for operation, key in zip(operation_list, operation_key):
# get device kernels
err, operation.kernel = cuda.cuModuleGetFunction(
module,
bytes(str.encode(operation.name()))
)
if self._is_sycl():
operation.kernel = program.get_sycl_kernel(operation.name())
else:
err, operation.kernel = cuda.cuModuleGetFunction(
module,
bytes(str.encode(operation.name()))
)
operation_name.append(operation.name())
self.compiled_cache_device[key] = operation.kernel
# get host functions
@@ -442,7 +532,7 @@ def add_module(self, operations, compile_options=None, bypass_cache=False):
op_attr.append(param_size)

if hasattr(operation, "extra_funcs"):
for suffix, ret_type in operation.extra_funcs.items():
for suffix, ret_type in operation.extra_funcs.items():
func_name = operation.name() + "_" + suffix
func = getattr(host_lib, func_name)
if ret_type is not None:
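
Both loading sites above (load_operation and add_module) follow the same dpctl pattern for turning the SPIR-V image into a callable kernel. A condensed sketch (load_sycl_kernel is a hypothetical helper, not part of the PR; dpctl.program is imported explicitly here):

    import dpctl
    import dpctl.program

    def load_sycl_kernel(spirv_image: bytes, kernel_name: str, device: dpctl.SyclDevice):
        # Build a SYCL program from the SPIR-V image and look up the named kernel.
        queue = dpctl.SyclQueue(device)
        program = dpctl.program.create_program_from_spirv(queue, spirv_image)
        return program.get_sycl_kernel(kernel_name)

Used with cutlass.sycl_device() as the device argument, this mirrors the flow the PR adds to add_module.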