diff --git a/csrc/flashfftconv/setup.py b/csrc/flashfftconv/setup.py
index a2370af..e5dd625 100644
--- a/csrc/flashfftconv/setup.py
+++ b/csrc/flashfftconv/setup.py
@@ -1,75 +1,210 @@
+from __future__ import annotations
+
+import os
+import subprocess
+
 import torch
+
+from functools import cache
+from pathlib import Path
+from typing import Tuple, List
+
+from packaging.version import parse, Version
 from setuptools import setup
+
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
-import subprocess
-def get_last_arch_torch():
-    arch = torch.cuda.get_arch_list()[-1]
-    print(f"Found arch: {arch} from existing torch installation")
-    return arch
-
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-
-    return raw_output, bare_metal_major, bare_metal_minor
-
-def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
-        return nvcc_extra_args + ["--threads", "4"]
+
+CUDA_PATH: Path = Path(CUDA_HOME)
+TORCH_VERSION: Version = Version(torch.__version__)
+
+TORCH_MAJOR: int = TORCH_VERSION.major
+TORCH_MINOR: int = TORCH_VERSION.minor
+
+EXTENSION_NAME: str = 'monarch_cuda'
+EXTENDED_CAPABILITIES: Tuple[int, ...] = (89, 90)
+
+
+@cache
+def get_cuda_bare_metal_version(cuda_dir: Path) -> Tuple[str, Version]:
+
+    raw = (
+        subprocess.run(
+            [str(cuda_dir / 'bin' / 'nvcc'), '-V'],
+            capture_output=True,
+            check=True,
+            encoding='utf-8',
+        )
+        .stdout
+    )
+
+    output = raw.split()
+    version, _, _ = output[output.index('release') + 1].partition(',')
+
+    return raw, parse(version)
+
+
+def raise_if_cuda_home_none(global_option: str) -> None:
+
+    if CUDA_HOME is None:
+
+        raise RuntimeError(
+            f"{global_option} was requested, but nvcc was not found. Are you sure your "
+            "environment has nvcc available? If you're installing within a container from "
+            "https://hub.docker.com/r/pytorch/pytorch, only images whose names contain "
+            "'devel' will provide nvcc."
+        )
+
+
+def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
+
+    _, version = get_cuda_bare_metal_version(CUDA_PATH)
+
+    if version >= Version("11.2"):
+
+        nvcc_extra_args.extend(("--threads", "4"))
+
     return nvcc_extra_args
 
-arch = get_last_arch_torch()
-# [MP] make install more flexible here
-sm_num = arch[-2:]
-cc_flag = ['--generate-code=arch=compute_80,code=compute_80']
+
+def arch_flags(compute: int, ptx: bool = False) -> str:
+
+    build = 'compute' if ptx else 'sm'
+
+    return f'arch=compute_{compute},code={build}_{compute}'
+
+
+class CompilerFlags(List[str]):
+
+    def add_arch(self, compute: int, ptx: bool = True, sass: bool = False):
+
+        if ptx:
+
+            self.append("-gencode")
+            self.append(arch_flags(compute, True))
+
+        if sass:
+
+            self.append("-gencode")
+            self.append(arch_flags(compute, False))
+
+        return self
+
+
+def build_compiler_flags(
+    ptx: bool = True, sass: bool = False, multi_arch: bool = False
+) -> List[str]:
+
+    flags = (
+        CompilerFlags()
+        .add_arch(compute=80, ptx=ptx, sass=sass)
+    )
+
+    if multi_arch:
+
+        _, version = get_cuda_bare_metal_version(CUDA_PATH)
+
+        if version < Version("11.0"):
+
+            raise RuntimeError(f"{EXTENSION_NAME} is only supported on CUDA 11 and above")
+
+        elif version >= Version("11.8"):
+
+            for compute in EXTENDED_CAPABILITIES:
+
+                flags.add_arch(compute=compute, ptx=ptx, sass=sass)
+
+    return flags
+
+
+if not torch.cuda.is_available():
+
+    print(
+        "\nWarning: Torch did not find available GPUs on this system.\n",
+        "If your intention is to cross-compile, this is not an error.\n"
+        "By default, FlashFFTConv will cross-compile for Ampere (compute capabilities 8.0 and 8.6) "
+        "and if CUDA version >= 11.8, Ada (compute capability 8.9) and Hopper (compute capability 9.0).\n"
+        "If you wish to cross-compile for a single specific architecture,\n"
+        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
+    )
+
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_PATH)
+
+        if bare_metal_version >= Version("11.8"):
+
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0"
+
+        elif bare_metal_version >= Version("11.1"):
+
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
+
+        elif bare_metal_version == Version("11.0"):
+
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"
+
+        else:
+
+            raise RuntimeError(f"{EXTENSION_NAME} is only supported on CUDA 11 and above")
+
+
+# Log PyTorch version.
+print(f"\n\ntorch.__version__ = {TORCH_VERSION}\n\n")
+
+
+# Verify that CUDA_HOME exists.
+raise_if_cuda_home_none(EXTENSION_NAME)
 
 setup(
     name='monarch_cuda',
     ext_modules=[
-        CUDAExtension('monarch_cuda', [
-            'monarch.cpp',
-            'monarch_cuda/monarch_cuda_interface_fwd.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_complex.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_bf16.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_r2r.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_r2r_bf16.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_complex.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_bf16.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_r2r.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_r2r_bf16.cu',
-            'butterfly/butterfly_cuda.cu',
-            'butterfly/butterfly_padded_cuda.cu',
-            'butterfly/butterfly_padded_cuda_bf16.cu',
-            'butterfly/butterfly_ifft_cuda.cu',
-            'butterfly/butterfly_cuda_bf16.cu',
-            'butterfly/butterfly_ifft_cuda_bf16.cu',
-            'butterfly/butterfly_padded_ifft_cuda.cu',
-            'butterfly/butterfly_padded_ifft_cuda_bf16.cu',
-            'conv1d/conv1d_bhl.cu',
-            'conv1d/conv1d_blh.cu',
-            'conv1d/conv1d_bwd_cuda_bhl.cu',
-            'conv1d/conv1d_bwd_cuda_blh.cu',
-        ],
-        extra_compile_args={'cxx': ['-O3'],
-                            'nvcc': append_nvcc_threads(['-O3', '-lineinfo', '--use_fast_math', '-std=c++17'] + cc_flag)
-        })
+        CUDAExtension(
+            name=EXTENSION_NAME,
+            sources=[
+                'monarch.cpp',
+                'monarch_cuda/monarch_cuda_interface_fwd.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_complex.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_bf16.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_r2r.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_r2r_bf16.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_complex.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_bf16.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_r2r.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_r2r_bf16.cu',
+                'butterfly/butterfly_cuda.cu',
+                'butterfly/butterfly_padded_cuda.cu',
+                'butterfly/butterfly_padded_cuda_bf16.cu',
+                'butterfly/butterfly_ifft_cuda.cu',
+                'butterfly/butterfly_cuda_bf16.cu',
+                'butterfly/butterfly_ifft_cuda_bf16.cu',
+                'butterfly/butterfly_padded_ifft_cuda.cu',
+                'butterfly/butterfly_padded_ifft_cuda_bf16.cu',
+                'conv1d/conv1d_bhl.cu',
+                'conv1d/conv1d_blh.cu',
+                'conv1d/conv1d_bwd_cuda_bhl.cu',
+                'conv1d/conv1d_bwd_cuda_blh.cu',
+            ],
+            extra_compile_args=(
+                {
+                    'cxx': ['-O3'],
+                    'nvcc': append_nvcc_threads(
+                        ['-O3', '-lineinfo', '--use_fast_math', '-std=c++17']
+                        + build_compiler_flags(ptx=True, sass=False, multi_arch=False)
+                    ),
+                }
+            ),
+        ),
     ],
-    cmdclass={
-        'build_ext': BuildExtension
-    },
+    cmdclass={'build_ext': BuildExtension},
     version='0.0.0',
     description='Fast FFT algorithms for convolutions',
     url='https://github.com/HazyResearch/flash-fft-conv',
     author='Dan Fu, Hermann Kumbong',
     author_email='danfu@cs.stanford.edu',
-    license='Apache 2.0')
\ No newline at end of file
+    license='Apache 2.0'
+)
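
For reference only, not part of the patch: a minimal standalone sketch of the nvcc -gencode arguments that the new arch_flags/CompilerFlags helpers produce. The names arch_flags, EXTENDED_CAPABILITIES, and the PTX/SASS defaults come from the diff above; the gencode_args helper below is hypothetical and simply mirrors CompilerFlags.add_arch() so the flags can be inspected without importing torch or running the build.

# Illustrative sketch only; mirrors arch_flags()/CompilerFlags.add_arch() from the patch above.
from typing import List


def arch_flags(compute: int, ptx: bool = False) -> str:
    # PTX targets use code=compute_XX, SASS targets use code=sm_XX.
    build = 'compute' if ptx else 'sm'
    return f'arch=compute_{compute},code={build}_{compute}'


def gencode_args(capabilities: List[int], ptx: bool = True, sass: bool = False) -> List[str]:
    # Hypothetical helper: same flag layout as CompilerFlags.add_arch() in the patch.
    args: List[str] = []
    for compute in capabilities:
        if ptx:
            args += ['-gencode', arch_flags(compute, True)]
        if sass:
            args += ['-gencode', arch_flags(compute, False)]
    return args


# Single-arch build, as used by setup() with multi_arch=False:
print(gencode_args([80]))
# ['-gencode', 'arch=compute_80,code=compute_80']

# Multi-arch build on CUDA >= 11.8 (compute 8.0 plus the EXTENDED_CAPABILITIES 8.9 and 9.0):
print(gencode_args([80, 89, 90]))
# ['-gencode', 'arch=compute_80,code=compute_80',
#  '-gencode', 'arch=compute_89,code=compute_89',
#  '-gencode', 'arch=compute_90,code=compute_90']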