From 5d513616803019b3b6712ccb8d298f1e68a1ddba Mon Sep 17 00:00:00 2001 From: Faraz Shahsavan Date: Tue, 22 Oct 2024 15:49:18 +0000 Subject: [PATCH 01/92] Add cutlass 2:4 infrastructure --- .../semi_structured_benchmarks.py | 373 ++++++ csrc/ops.h | 3 + csrc/semi_structured/cusparselt/binding.py | 47 + .../cusparselt/cusparselt_mm.cu | 1077 +++++++++++++++++ .../cusparselt/cusparselt_mm_entry.cu | 135 +++ csrc/semi_structured/cutlass/common.hpp | 27 + .../cutlass/semi_structured_mm_c3x.cu | 223 ++++ .../cutlass/semi_structured_mm_entry.cu | 54 + csrc/torch_bindings.cpp | 7 + vllm/_custom_ops.py | 15 + 10 files changed, 1961 insertions(+) create mode 100644 benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py create mode 100644 csrc/semi_structured/cusparselt/binding.py create mode 100644 csrc/semi_structured/cusparselt/cusparselt_mm.cu create mode 100644 csrc/semi_structured/cusparselt/cusparselt_mm_entry.cu create mode 100644 csrc/semi_structured/cutlass/common.hpp create mode 100644 csrc/semi_structured/cutlass/semi_structured_mm_c3x.cu create mode 100644 csrc/semi_structured/cutlass/semi_structured_mm_entry.cu diff --git a/benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py b/benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py new file mode 100644 index 0000000000000..61eed3da41458 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py @@ -0,0 +1,373 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + +# helpers + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + a, b = make_rand_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n, ), device="cuda", 
dtype=torch.int32) + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "cutlass_bf16_bf16_bf16_semi_structured_mm", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # cutlass impl - float16 + timers.append( + bench_fn(label, sub_label, + "cutlass_fp16_fp16_fp16_semi_structured_mm", + torch.mm, a.to(dtype=torch.float16), + b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_semi_structured_mm", + ops.cutlass_semi_structured_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + timers = [] + + # pytorch impl w. bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # # pytorch impl: bf16 output, without fp8 fast accum + # timers.append( + # bench_fn(label, + # sub_label, + # "pytorch_fp8_fp8_bf16_semi_structured_mm", + # torch._semi_structured_mm, + # a, + # b, + # scale_a=scale_a, + # scale_b=scale_b, + # out_dtype=torch.bfloat16)) + + # # pytorch impl: bf16 output, with fp8 fast accum + # timers.append( + # bench_fn(label, + # sub_label, + # "pytorch_fp8_fp8_bf16_semi_structured_mm_fast_accum", + # torch._semi_structured_mm, + # a, + # b, + # scale_a=scale_a, + # scale_b=scale_b, + # out_dtype=torch.bfloat16, + # use_fast_accum=True)) + + # # pytorch impl: fp16 output, without fp8 fast accum + # timers.append( + # bench_fn(label, + # sub_label, + # "pytorch_fp8_fp8_fp16_semi_structured_mm", + # torch._semi_structured_mm, + # a, + # b, + # scale_a=scale_a, + # scale_b=scale_b, + # out_dtype=torch.float16)) + + # # pytorch impl: fp16 output, with fp8 fast accum + # timers.append( + # bench_fn(label, + # sub_label, + # "pytorch_fp8_fp8_fp16_semi_structured_mm_fast_accum", + # torch._semi_structured_mm, + # a, + # b, + # scale_a=scale_a, + # scale_b=scale_b, + # out_dtype=torch.float16, + # use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_semi_structured_mm", + ops.cutlass_semi_structured_mm, a, b, + torch.bfloat16)) + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_semi_structured_mm", + ops.cutlass_semi_structured_mm, a, b, + torch.float16)) + + # # cutlass impl: bf16 output, with bias + # timers.append( + # bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_semi_structured_mm_bias", + # ops.cutlass_semi_structured_mm, a, b, scale_a, scale_b, + # torch.bfloat16, bias)) + + # # cutlass impl: fp16 output, with bias + # timers.append( + # bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_semi_structured_mm_bias", + # 
ops.cutlass_semi_structured_mm, a, b, scale_a, scale_b, + # torch.float16, bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"semi_structured-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. 
+ + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/csrc/ops.h b/csrc/ops.h index c10c34e085750..c0b4fa7f5d15e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -115,6 +115,9 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& azp_adj, c10::optional const& azp, c10::optional const& bias); + +void cutlass_semi_structured_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/semi_structured/cusparselt/binding.py b/csrc/semi_structured/cusparselt/binding.py new file mode 100644 index 0000000000000..035c18abd312a --- /dev/null +++ b/csrc/semi_structured/cusparselt/binding.py @@ -0,0 +1,47 @@ +from torch.utils.cpp_extension import load +import os +import torch + +base_path = __file__.replace("spmm.py", "") + +if not os.path.exists(f"{base_path}/build"): + os.makedirs(f"{base_path}/build") + +if not os.path.exists(base_path + "/libcusparse_lt"): + os.system( + "wget https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-x86_64/libcusparse_lt-linux-x86_64-0.5.1.1-archive.tar.xz") + os.system("tar -xf libcusparse_lt-linux-x86_64-0.5.1.1-archive.tar.xz") + os.system(f"mv 
libcusparse_lt-linux-x86_64-0.5.1.1-archive {base_path}/libcusparse_lt") + os.system("rm libcusparse_lt-linux-x86_64-0.5.1.1-archive.tar.xz") + +pruner = load(name='pruner', + sources=[f'{base_path}/spmm_backend.cpp', + f'{base_path}/spmm_backend.cu', + ], + extra_cflags=[ + f'-L{base_path}/libcusparse_lt/lib', + '-lcusparse', + '-lcusparseLt', + '-ldl' + ], + extra_cuda_cflags=[ + f'-L{base_path}/libcusparse_lt/lib', + '-lcusparse', + '-lcusparseLt', + '-ldl' + ], + extra_ldflags=[ + f'-L{base_path}/libcusparse_lt/lib', + '-lcusparse', + '-lcusparseLt', + '-ldl' + ], + extra_include_paths=[ + base_path + '/libcusparse_lt/include' + ], + build_directory=f'{base_path}/build', + with_cuda=True, + verbose=False) + +init_flag = pruner.init_cusparse_lt() +assert init_flag == 0, "Failed to initialize CuSparseLT" \ No newline at end of file diff --git a/csrc/semi_structured/cusparselt/cusparselt_mm.cu b/csrc/semi_structured/cusparselt/cusparselt_mm.cu new file mode 100644 index 0000000000000..0e088b35c7b87 --- /dev/null +++ b/csrc/semi_structured/cusparselt/cusparselt_mm.cu @@ -0,0 +1,1077 @@ +/* + * Copyright 1993-2023 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. 
+ * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ +#include // cudaMalloc, cudaMemcpy, etc. +#include // cusparseLt header +#include // printf +#include // std::rand +#include // std::vector +#include +#include + + +#define INT8_OUTPUT_TYPE int32_t //at::Half //int8_t +#define INT8_OUTPUT_TYPE_CUDA CUDA_R_8I //CUDA_R_32I +#define INT8_OUTPUT_TYPE_TORCH torch::kInt32 //torch::kInt32 + + +#define MAX(a, b) ((abs(a) > abs(b) ? (a) : (b))) +#define MIN(a, b) ((abs(a) < abs(b) ? (a) : (b))) + + +#define CHECK_CUDA(func) \ +{ \ + cudaError_t status = (func); \ + if (status != cudaSuccess) { \ + printf("CUDA API failed at line %d with error: %s (%d)\n", \ + __LINE__, cudaGetErrorString(status), status); \ + return EXIT_FAILURE; \ + } \ +} + + +#define CHECK_CUDA_TORCH(func) \ +{ \ + cudaError_t status = (func); \ + if (status != cudaSuccess) { \ + printf("CUDA API failed at line %d with error: %s (%d)\n", \ + __LINE__, cudaGetErrorString(status), status); \ + return torch::ones(1); \ + } \ +} + + +#define CHECK_CUSPARSE(func) \ +{ \ + cusparseStatus_t status = (func); \ + if (status != CUSPARSE_STATUS_SUCCESS) { \ + printf("CUSPARSE API failed at line %d with error: %s (%d)\n", \ + __LINE__, cusparseGetErrorString(status), status); \ + return EXIT_FAILURE; \ + } \ +} + + +#define CHECK_CUSPARSE_TORCH(func) \ +{ \ + cusparseStatus_t status = (func); \ + if (status != CUSPARSE_STATUS_SUCCESS) { \ + printf("CUSPARSE API failed at line %d with error: %s (%d)\n", \ + __LINE__, cusparseGetErrorString(status), status); \ + return torch::ones(1); \ + } \ +} + +constexpr int EXIT_UNSUPPORTED = 2; + +cusparseLtHandle_t handle; + +float alpha = 1.0; +float beta = 0.0; + + +typedef struct { + at::Half data; + int index; +} indexed_half; + + +int init_cusparse_lt_cuda() +{ + int major_cc, minor_cc; + CHECK_CUDA( cudaDeviceGetAttribute(&major_cc, + cudaDevAttrComputeCapabilityMajor, 0) ) + CHECK_CUDA( cudaDeviceGetAttribute(&minor_cc, + cudaDevAttrComputeCapabilityMinor, 0) ) + if (!(major_cc == 8 && minor_cc == 0) && + !(major_cc == 8 && minor_cc == 6) && + !(major_cc == 8 && minor_cc == 9)) { + std::printf("\ncusparseLt is supported only on GPU devices with" + " compute capability == 8.0, 8.6, 8.9 current: %d.%d\n\n", + major_cc, minor_cc); + return EXIT_UNSUPPORTED; + } + CHECK_CUSPARSE( cusparseLtInit(&handle) ) + + return EXIT_SUCCESS; +} + + +typedef struct cusparseLtMatmulArgs_t { + cusparseLtMatmulPlan_t* plan; + cusparseLtMatmulDescriptor_t* matmul; + cusparseLtMatmulAlgSelection_t* alg_sel; + cudaStream_t* streams; + int num_streams; + cudaStream_t stream; + size_t workspace_size; +// void* d_workspace; + void *dCompressed; + int m; + int n; +// torch::Tensor grad; + + cusparseLtMatmulArgs_t() + { + plan = new cusparseLtMatmulPlan_t; + matmul = new cusparseLtMatmulDescriptor_t; + alg_sel = new cusparseLtMatmulAlgSelection_t; + streams = nullptr; + num_streams = 0; + stream = nullptr; + m = 0; + n = 0; + dCompressed = nullptr; + } + + ~cusparseLtMatmulArgs_t() + { + cusparseLtMatmulPlanDestroy(plan); +// cudaFree(d_workspace); + } +} cusparseLtMatmulArgs ; + + +std::vector matmul_args; + + +template +int setup_prune_matmul( const int m, + const int n, + const int k, + T *dSparse, + T *dDense, + int *index, + const bool transpose_A=false, + const bool transpose_B=false, + const bool sparseA=true, + const bool 
transposable_mask=false, + const bool is_sparse_pruned=false, + const bool check_sparsity=false, + cudaDataType_t input_type=CUDA_R_16F, + cudaDataType_t output_type=CUDA_R_16F, + cusparseComputeType compute_type=CUSPARSE_COMPUTE_16F) +{ + matmul_args.push_back(new cusparseLtMatmulArgs_t); + *index = matmul_args.size() - 1; + + auto args = matmul_args.back(); + args->m = m; + args->n = n; + + // Host problem definition, row-major order + // bigger sizes may require dynamic allocations + auto order = CUSPARSE_ORDER_ROW; + auto opA = transpose_A ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + auto opB = transpose_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + + bool is_rowmajor = (order == CUSPARSE_ORDER_ROW); + bool isA_transposed = (opA != CUSPARSE_OPERATION_NON_TRANSPOSE); + bool isB_transposed = (opB != CUSPARSE_OPERATION_NON_TRANSPOSE); + auto num_A_rows = (isA_transposed) ? k : m; + auto num_A_cols = (isA_transposed) ? m : k; + auto num_B_rows = (isB_transposed) ? n : k; + auto num_B_cols = (isB_transposed) ? k : n; + auto num_C_rows = m; + auto num_C_cols = n; + unsigned alignment = 16; + auto lda = (is_rowmajor) ? num_A_cols : num_A_rows; + auto ldb = (is_rowmajor) ? num_B_cols : num_B_rows; + auto ldc = (is_rowmajor) ? num_C_cols : num_C_rows; + auto C_height = (is_rowmajor) ? num_C_rows : num_C_cols; + auto C_size = C_height * ldc * sizeof(V); + + + cusparseLtMatDescriptor_t* matA; + cusparseLtMatDescriptor_t* matB; + cusparseLtMatDescriptor_t* matC; + matA = new cusparseLtMatDescriptor_t; + matB = new cusparseLtMatDescriptor_t; + matC = new cusparseLtMatDescriptor_t; + + V *dC, *dD; + CHECK_CUDA( cudaMalloc((void**) &dC, C_size) ) + dD = dC; + + int *d_valid; + CHECK_CUDA( cudaMalloc((void**) &d_valid, sizeof(int)) ) + + // matrix descriptor initialization + if(sparseA) + { + CHECK_CUSPARSE( cusparseLtStructuredDescriptorInit( + &handle, matA, num_A_rows, + num_A_cols, lda, alignment, + input_type, order, + CUSPARSELT_SPARSITY_50_PERCENT) ) + + CHECK_CUSPARSE( cusparseLtDenseDescriptorInit( + &handle, matB, num_B_rows, + num_B_cols, ldb, alignment, + input_type, order) ) + } + else + { + CHECK_CUSPARSE( cusparseLtStructuredDescriptorInit( + &handle, matB, num_B_rows, + num_B_cols, ldb, alignment, + input_type, order, + CUSPARSELT_SPARSITY_50_PERCENT) ) + + CHECK_CUSPARSE( cusparseLtDenseDescriptorInit( + &handle, matA, num_A_rows, + num_A_cols, lda, alignment, + input_type, order) ) + } + CHECK_CUSPARSE( cusparseLtDenseDescriptorInit( + &handle, matC, num_C_rows, + num_C_cols, ldc, alignment, + output_type, order) ) + + // matmul, algorithm selection, and plan initialization + CHECK_CUSPARSE( cusparseLtMatmulDescriptorInit( + &handle, args->matmul, opA, opB, + matA, matB, matC, matC, + compute_type) ) + + CHECK_CUSPARSE( cusparseLtMatmulAlgSelectionInit( + &handle, args->alg_sel, args->matmul, + CUSPARSELT_MATMUL_ALG_DEFAULT) ) + + CHECK_CUSPARSE( cusparseLtMatmulPlanInit(&handle, args->plan, args->matmul, args->alg_sel)) + + //-------------------------------------------------------------------------- + // Prune the A matrix (in-place) and check the correctness + if (!is_sparse_pruned){ + cusparseLtPruneAlg_t prune_alg = transposable_mask ? 
CUSPARSELT_PRUNE_SPMMA_TILE : CUSPARSELT_PRUNE_SPMMA_STRIP; + CHECK_CUSPARSE( cusparseLtSpMMAPrune(&handle, args->matmul, dSparse, dSparse, + prune_alg, args->stream) ) + } + if (check_sparsity) + { + CHECK_CUSPARSE( cusparseLtSpMMAPruneCheck(&handle, args->matmul, dSparse, d_valid, args->stream) ) + int is_valid; + CHECK_CUDA( cudaMemcpyAsync(&is_valid, d_valid, sizeof(int), cudaMemcpyDeviceToHost, args->stream) ) + CHECK_CUDA( cudaStreamSynchronize(args->stream) ) + if (is_valid != 0) { + std::printf("!!!! The matrix does not conform to the SpMMA sparsity pattern. " + "cusparseLtMatmul does not provide correct results\n"); + return EXIT_FAILURE; + } + } + + +// int *d_valid; +// CHECK_CUDA( cudaMalloc((void**) &d_valid, sizeof(int)) ) +// CHECK_CUSPARSE( cusparseLtSpMMAPruneCheck2( &handle, +// sparseA ? matA : matB, +// sparseA, +// sparseA ? opA : opB, +// dSparse, +// d_valid, +// args->stream) ) + +// int is_valid; +// CHECK_CUDA( cudaMemcpyAsync(&is_valid, d_valid, sizeof(int), +// cudaMemcpyDeviceToHost, args->stream) ) +// CHECK_CUDA( cudaStreamSynchronize(args->stream) ) +// if (is_valid != 0) { +// std::printf("!!!! The matrix has been pruned in a wrong way. " +// "cusparseLtMatmul will not provide correct results\n"); +// return EXIT_FAILURE; +// } + CHECK_CUDA( cudaFree(d_valid) ) + + //-------------------------------------------------------------------------- + // Compress the A matrix + size_t compressed_size, compressed_buffer_size; + void* dCompressedBuffer; + CHECK_CUSPARSE( cusparseLtSpMMACompressedSize(&handle, + args->plan, + &compressed_size, + &compressed_buffer_size) ) + + CHECK_CUDA( cudaMalloc((void**) &args->dCompressed, compressed_size) ) + CHECK_CUDA( cudaMalloc((void**) &dCompressedBuffer, + compressed_buffer_size) ) + + CHECK_CUSPARSE( cusparseLtSpMMACompress(&handle, + args->plan, + dSparse, + (T *) args->dCompressed, + dCompressedBuffer, + args->stream) ) + CHECK_CUDA( cudaFree(dCompressedBuffer) ) + + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Search the best kernel + if(sparseA) + { +// printf("%f, %f, %f, %f, %f, %f\n", alpha, beta, *dDense,0.,0.,0.);// , dDense[0], beta, dC[0], dD[0]); + CHECK_CUSPARSE( cusparseLtMatmulSearch(&handle, args->plan, &alpha, + (T*) args->dCompressed, dDense, &beta, + dC, dD, nullptr, + args->streams, args->num_streams) ) + } else { + CHECK_CUSPARSE( cusparseLtMatmulSearch(&handle, args->plan, &alpha, + dDense, (T*) args->dCompressed, &beta, + dC, dD, nullptr, + args->streams, args->num_streams) ) + } +// // otherwise, it is possible to set it directly: +// int alg = 0; +// CHECK_CUSPARSE( cusparseLtMatmulAlgSetAttribute( +// &handle, args->alg_sel, +// CUSPARSELT_MATMUL_ALG_CONFIG_ID, +// &alg, sizeof(alg))) + + + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + CHECK_CUSPARSE( cusparseLtMatmulPlanInit(&handle, args->plan, args->matmul, args->alg_sel)) + + CHECK_CUSPARSE( cusparseLtMatmulGetWorkspace(&handle, args->plan, + &args->workspace_size)) + +// printf("workspace_size: %lu (MB)\n", args->workspace_size / 1024 / 1024); + CHECK_CUDA( cudaFree(dC) ) + cusparseLtMatDescriptorDestroy(matA); + cusparseLtMatDescriptorDestroy(matB); + cusparseLtMatDescriptorDestroy(matC); + + return EXIT_SUCCESS; +} + +int destroy_cusparse_matmul_cuda(int index){ + if (index > matmul_args.size() - 1) + throw std::runtime_error("Index out of range of matmul_args"); + + auto args = matmul_args[index]; + cusparseLtMatmulPlanDestroy(args->plan); + 
CHECK_CUDA(cudaFree(args->streams)); + CHECK_CUDA(cudaFree(args->dCompressed)); + matmul_args.erase(matmul_args.begin() + index); + + return EXIT_SUCCESS; +} + +torch::Tensor setup_spmatmul_cuda(torch::Tensor A, + torch::Tensor B, + const bool transpose_A=false, + const bool transpose_B=false, + const bool sparseA=true, + const bool transposable_mask=false, + const bool is_sparse_pruned=false, + const bool check_sparsity=false) { + auto index = torch::zeros({1}, torch::kInt32); + int result; + int m, k, n; + if(transpose_A && transpose_B) + { + m = A.size(1); + k = A.size(0); + n = B.size(0); + } else if(transpose_A) + { + m = A.size(1); + k = A.size(0); + n = B.size(1); + } else if(transpose_B) + { + m = A.size(0); + k = A.size(1); + n = B.size(0); + } else { + m = A.size(0); + k = A.size(1); + n = B.size(1); + } + switch (A.type().scalarType()) { + case torch::ScalarType::Half: + { + auto sparse_mat = sparseA ? A.data_ptr() : B.data_ptr(); + auto dense_mat = sparseA ? B.data_ptr() : A.data_ptr(); + at::Half *dCompressed; + result = setup_prune_matmul( m, + n, + k, + sparse_mat, + dense_mat, + index.data_ptr(), + transpose_A, + transpose_B, + sparseA, + transposable_mask, + is_sparse_pruned, + check_sparsity, + CUDA_R_16F, + CUDA_R_16F, + CUSPARSE_COMPUTE_16F); + break; + } + case torch::ScalarType::Char: + { + auto sparse_mat = sparseA ? A.data_ptr() : B.data_ptr(); + auto dense_mat = sparseA ? B.data_ptr() : A.data_ptr(); + int8_t *dCompressed; + result = setup_prune_matmul( m, + n, + k, + sparse_mat, + dense_mat, + index.data_ptr(), + transpose_A, + transpose_B, + sparseA, + transposable_mask, + is_sparse_pruned, + check_sparsity, + CUDA_R_8I, + INT8_OUTPUT_TYPE_CUDA, + CUSPARSE_COMPUTE_32I); + break;} + default: + { + std::cout << A.type().scalarType() << std::endl; + throw std::runtime_error("Unsupported data type"); + } + } + if(result == EXIT_SUCCESS) { + return index; + } else { + return -torch::ones({1}, torch::kInt32); + } +} + + +template +torch::Tensor matmul( T* dDense, + int index, + bool sparseA, + int m, + torch::TensorOptions options=torch::TensorOptions() + ) +{ + auto args = matmul_args[index]; + + torch::Tensor C = torch::zeros({m, args->n}, options); + auto dC = C.data_ptr(); + auto dD = dC; + auto dA = sparseA ? (T*) args->dCompressed : dDense; + auto dB = sparseA ? 
dDense : (T*) args->dCompressed; + void *d_workspace; + CHECK_CUDA_TORCH( cudaMalloc((void**) &d_workspace, args->workspace_size) ) + // Perform the matrix multiplication + CHECK_CUSPARSE_TORCH( cusparseLtMatmul(&handle, args->plan, &alpha, dA, dB, + &beta, dC, dD, d_workspace, args->streams, + args->num_streams) ) + CHECK_CUDA_TORCH( cudaFree(d_workspace) ) + return C; +} + + +torch::Tensor spmatmul_cuda(torch::Tensor Dense, + int index, + bool sparseA) +{ + switch (Dense.type().scalarType()) { + case torch::ScalarType::Half: { + auto options = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA); + return matmul(Dense.data_ptr(), index, sparseA, Dense.size(0), options); + } + case torch::ScalarType::Char: { + auto options = torch::TensorOptions().dtype(INT8_OUTPUT_TYPE_TORCH).device(torch::kCUDA); + return matmul(Dense.data_ptr(), index, sparseA, Dense.size(0), options); + } + default: + { + throw std::runtime_error("Unsupported data type"); + } + } +} + + +void save_grad_cuda(torch::Tensor grad, int index) +{ + auto args = matmul_args[index]; +// args->grad = grad.clone().detach(); +} + + +__global__ void prune_kernel( + const float* __restrict__ input, + float* __restrict__ output, + bool* __restrict__ mask, + size_t row_size) { + const int column = 4 * (blockIdx.x * blockDim.x + threadIdx.x); + const int index = blockIdx.y * row_size + column; + if (column < row_size) { + reinterpret_cast(&output[index])[0] = reinterpret_cast(&input[index])[0]; + if(abs(output[index]) > abs(output[index + 1])){ + output[index + 1] = 0.; + mask[index + 1] = true; + } else { + output[index] = 0.; + mask[index] = true; + } + if(abs(output[index + 2]) > abs(output[index + 3])){ + output[index + 3] = 0.; + mask[index + 3] = true; + } else { + output[index + 2] = 0.; + mask[index + 2] = true; + } + } +} + + +__global__ void prune_kernel( + const at::Half* __restrict__ input, + at::Half* __restrict__ output, + bool* __restrict__ mask, + size_t row_size) { + const int column = 8 * (blockIdx.x * blockDim.x + threadIdx.x); + const int index = blockIdx.y * row_size + column; + if (column < row_size) { + reinterpret_cast(&output[index])[0] = reinterpret_cast(&input[index])[0]; + at::Half min1, min2; + int min_idx1, min_idx2; + min1 = output[index]; + min_idx1 = index; + if(MIN(min1, output[index + 1]) == output[index + 1]){ + min1 = output[index + 1]; + min_idx1 = index + 1; + } + if(MIN(min1, output[index + 2]) == output[index + 2]){ + min1 = output[index + 2]; + min_idx1 = index + 2; + } + if(MIN(min1, output[index + 3]) == output[index + 3]){ + min1 = output[index + 3]; + min_idx1 = index + 3; + } + min2 = min_idx1 == index ? output[index + 1] : output[index]; + min_idx2 = min_idx1 == index ? 
index + 1 : index; + if((MIN(min2, output[index + 1]) == output[index + 1]) && min_idx1 != index + 1){ + min2 = output[index + 1]; + min_idx2 = index + 1; + } + if((MIN(min2, output[index + 2]) == output[index + 2]) && min_idx1 != index + 2){ + min2 = output[index + 2]; + min_idx2 = index + 2; + } + if((MIN(min2, output[index + 3]) == output[index + 3]) && min_idx1 != index + 3){ + min2 = output[index + 3]; + min_idx2 = index + 3; + } + output[min_idx1] = 0.; mask[min_idx1] = true; + output[min_idx2] = 0.; mask[min_idx2] = true; + + min1 = output[index + 4]; + min_idx1 = index + 4; + if(MIN(min1, output[index + 5]) == output[index + 5]){ + min1 = output[index + 5]; + min_idx1 = index + 5; + } + if(MIN(min1, output[index + 6]) == output[index + 6]){ + min1 = output[index + 6]; + min_idx1 = index + 6; + } + if(MIN(min1, output[index + 7]) == output[index + 7]){ + min1 = output[index + 7]; + min_idx1 = index + 7; + } + min2 = min_idx1 == index + 4 ? output[index + 5] : output[index + 4]; + min_idx2 = min_idx1 == index + 4 ? index + 5 : index + 4; + if((MIN(min2, output[index + 5]) == output[index + 5]) && min_idx1 != index + 5){ + min2 = output[index + 5]; + min_idx2 = index + 5; + } + if((MIN(min2, output[index + 6]) == output[index + 6]) && min_idx1 != index + 6){ + min2 = output[index + 6]; + min_idx2 = index + 6; + } + if((MIN(min2, output[index + 7]) == output[index + 7]) && min_idx1 != index + 7){ + min2 = output[index + 7]; + min_idx2 = index + 7; + } + + output[min_idx1] = 0.; mask[min_idx1] = true; + output[min_idx2] = 0.; mask[min_idx2] = true; + } +} + + +template +__device__ void find_kth_smallest( + int *smallest_idx, + const T* __restrict__ input, + const int k, + const int M, int index) { + int min_idx = 0; + T min = 6.0e4; + + for(int i = 0; i < M; i++) + { + bool ignore = false; + for(int j = 0; j < k; j++) + { + if(smallest_idx[j] == i) + { + ignore = true; + } + } + if(ignore) + { + continue; + } + if(MIN(min, input[i]) == input[i]){ + min = input[i]; + min_idx = i; + } + } + smallest_idx[k] = min_idx; +} + + +__global__ void prune_kernel( + const at::Half* __restrict__ input, + at::Half* __restrict__ output, + bool* __restrict__ mask, + size_t row_size, + const int N, + const int M) { + + const int column = M * (blockIdx.x * blockDim.x + threadIdx.x); + const int index = blockIdx.y * row_size + column; + if (column < row_size) { + for(int i = 0; i < M / 8; i++) + { + reinterpret_cast(&output[index + 8 * i])[0] = reinterpret_cast(&input[index + 8 * i])[0]; + } + + int min_idx_list[16]; + for(int k = 0; k < (M - N); k++) + { + find_kth_smallest(min_idx_list, &input[index], k, M, index); + } + + for(int i = 0; i < (M - N); i++) + { + output[min_idx_list[i] + index] = 0.; mask[min_idx_list[i] + index] = true; + } + } +} + + +__global__ void prune_kernel( + const float* __restrict__ input, + float* __restrict__ output, + bool* __restrict__ mask, + size_t row_size, + const int N, + const int M) { + + const int column = M * (blockIdx.x * blockDim.x + threadIdx.x); + const int index = blockIdx.y * row_size + column; + if (column < row_size) { + for(int i = 0; i < M / 4; i++) + { + reinterpret_cast(&output[index + 4 * i])[0] = reinterpret_cast(&input[index + 4 * i])[0]; + } + + int *min_idx_list; + min_idx_list = (int*)malloc((M - N) * sizeof(int)); + for(int k = 0; k < (M - N); k++) + { + find_kth_smallest(min_idx_list, &input[index], k, M, index); + } + + for(int i = 0; i < (M - N); i++) + { + output[min_idx_list[i] + index] = 0.; mask[min_idx_list[i] + index] = true; + } + } +} 
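+
+// Note on the prune_kernel variants in this file: each kernel walks a row in
+// groups of M consecutive elements (1:2, 2:4, 2:8, 2:16, or a runtime N:M),
+// zeroes the (M - N) entries with the smallest absolute value in each group,
+// and records the zeroed positions in `mask`. Magnitude comparisons go through
+// the MIN/MAX macros defined at the top of this file, which compare by abs().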
+ + +template +__global__ void prune_kernel( + const float* __restrict__ input, + float* __restrict__ output, + bool* __restrict__ mask, + size_t row_size) { + + const int column = M * (blockIdx.x * blockDim.x + threadIdx.x); + const int index = blockIdx.y * row_size + column; + if (column < row_size) { + for(int i = 0; i < M / 4; i++) + { + reinterpret_cast(&output[index + 4 * i])[0] = reinterpret_cast(&input[index + 4 * i])[0]; + } + + int min_idx_list[M - N]; + for(int k = 0; k < (M - N); k++) + { + find_kth_smallest(min_idx_list, &input[index], k, M, index); + } + + for(int i = 0; i < (M - N); i++) + { + output[min_idx_list[i] + index] = 0.; mask[min_idx_list[i] + index] = true; + } + } +} + + +template +__global__ void prune_kernel( + const at::Half* __restrict__ input, + at::Half* __restrict__ output, + bool* __restrict__ mask, + size_t row_size) { + + const int column = M * (blockIdx.x * blockDim.x + threadIdx.x); + const int index = blockIdx.y * row_size + column; + if (column < row_size) { + for(int i = 0; i < M / 8; i++) + { + reinterpret_cast(&output[index + 8 * i])[0] = reinterpret_cast(&input[index + 8 * i])[0]; + } + + int min_idx_list[M - N]; + for(int k = 0; k < (M - N); k++) + { + find_kth_smallest(min_idx_list, &input[index], k, M, index); + } + + for(int i = 0; i < (M - N); i++) + { + output[min_idx_list[i] + index] = 0.; mask[min_idx_list[i] + index] = true; + } + } +} + + +std::vector prune_cuda( + torch::Tensor input, const int N, const int M) { + + auto output = torch::zeros_like(input); + auto options = torch::TensorOptions().dtype(torch::kBool); + auto mask = torch::zeros_like(input, options); + + const auto batch_size = input.size(0); + const auto row_size = input.size(1); + + const int threads = 1024; + + if(N == 1 && M == 2) { + switch (input.type().scalarType()) { + case torch::ScalarType::Float: { + const dim3 blocks(((row_size / 4) + threads - 1) / threads, batch_size); + prune_kernel<<>>( + input.data(), + output.data(), + mask.data(), + row_size); + break; + } + case torch::ScalarType::Half: { + throw std::runtime_error("Half precision not supported for N=1, M=2"); + } + } + } + else if(N == 2 && M == 4) + { + switch (input.type().scalarType()) { + case torch::ScalarType::Float: { + throw std::runtime_error("Full precision not supported for N=2, M=4"); + break; + } + case torch::ScalarType::Half: { + const dim3 blocks(((row_size / 8) + threads - 1) / threads, batch_size); + prune_kernel<<>>( + input.data(), + output.data(), + mask.data(), + row_size); + } + } + } + else if((N == 2 && M == 8)) + { + switch (input.type().scalarType()){ + case torch::ScalarType::Float: { + const dim3 blocks(((row_size / M) + threads - 1) / threads, batch_size); + prune_kernel<2, 8><<>>( + input.data(), + output.data(), + mask.data(), + row_size); + break; + } + case torch::ScalarType::Half: { + const dim3 blocks(((row_size / M) + threads - 1) / threads, batch_size); + prune_kernel<2, 8><<>>( + input.data(), + output.data(), + mask.data(), + row_size); + } + } + } + else if((N == 2 && M == 16)) + { + switch (input.type().scalarType()){ + case torch::ScalarType::Float: { + const dim3 blocks(((row_size / M) + threads - 1) / threads, batch_size); + prune_kernel<2, 16><<>>( + input.data(), + output.data(), + mask.data(), + row_size); + break; + } + case torch::ScalarType::Half: { + const dim3 blocks(((row_size / M) + threads - 1) / threads, batch_size); + prune_kernel<2, 16><<>>( + input.data(), + output.data(), + mask.data(), + row_size); + } + } + } + else + { + if(M < 8 || M % 8 
!= 0) + { + throw std::runtime_error("M must be a multiple of 8"); + } + switch (input.type().scalarType()) { + case torch::ScalarType::Float: + { + const dim3 blocks(((row_size / M) + threads - 1) / threads, batch_size); + prune_kernel<<>>( + input.data(), + output.data(), + mask.data(), + row_size, + N, + M); + break; + } + case torch::ScalarType::Half: + { + const dim3 blocks(((row_size / M) + threads - 1) / threads, batch_size); + prune_kernel<<>>( + input.data(), + output.data(), + mask.data(), + row_size, + N, + M); + } + } + } + return {output, mask}; +} + + +__global__ void prune_and_compress_kernel( + const at::Half* __restrict__ input, + at::Half* __restrict__ output, + bool* __restrict__ mask, + size_t row_size) { + const int input_column = 16 * (blockIdx.x * blockDim.x + threadIdx.x); + const int output_column = 8 * (blockIdx.x * blockDim.x + threadIdx.x); + const int input_row = blockIdx.y * row_size; + const int output_row = blockIdx.y * (row_size / 2); + const int input_index = input_row + input_column; + const int output_index = output_row + output_column; + if (input_column < row_size) { + bool local_mask[16]; + reinterpret_cast(local_mask)[0] = reinterpret_cast(&mask[input_index])[0]; + + int local_index = 0; + #pragma unroll (2) + for(int i = 0; i < 2; i++) + { + at::Half local_data[8]; + reinterpret_cast(local_data)[0] = reinterpret_cast(&input[input_index + 8 * i])[0]; + #pragma unroll (8) + for(int j = 0; j < 8; j++) + { + if(local_mask[8 * i + j]) + { + output[local_index + output_index] = local_data[j]; + local_index++; + } + } + } + } +} + + +torch::Tensor prune_and_compress_cuda(torch::Tensor dense, torch::Tensor mask) +{ + auto row_size = dense.size(1); + auto batch_size = dense.size(0); + if(row_size % 16 != 0) + { + throw std::runtime_error("Pruning dimension should be a multiple of 128."); + } + auto options = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA); + torch::Tensor result = torch::zeros({dense.size(0), dense.size(1) / 2}, options); + const int threads = 1024; + switch (dense.type().scalarType()) { + case torch::ScalarType::Float: + { + throw std::runtime_error("Full precision not supported for prune_and_compress"); + } + case torch::ScalarType::Half: + { + const dim3 blocks(((row_size / 16) + threads - 1) / threads, batch_size); + prune_and_compress_kernel<<>>( + dense.data(), + result.data(), + mask.data(), + row_size); + } + } + return result; +} + + +__global__ void sparse_add_kernel( + const at::Half* __restrict__ mat1, + const at::Half* __restrict__ mat2, + const at::Half alpha, + const at::Half beta, + at::Half* __restrict__ output, + size_t row_size) { + const int column = 8 * (blockIdx.x * blockDim.x + threadIdx.x); + const int index = blockIdx.y * row_size + column; + if (column < row_size) { + at::Half mat1_local[8], mat2_local[8]; + reinterpret_cast(&mat1_local)[0] = reinterpret_cast(&mat1[index])[0]; + reinterpret_cast(&mat2_local)[0] = reinterpret_cast(&mat2[index])[0]; + #pragma unroll (8) + for(int i = 0; i < 8; i++) + { + output[index + i] = alpha * mat1_local[i] + beta * mat2_local[i]; + } + } + +} + + +torch::Tensor sparse_add_cuda(torch::Tensor dense, torch::Tensor sparse_index, torch::Tensor alpha, torch::Tensor beta) +{ + int row_size = dense.size(1); + int batch_size = dense.size(0); + if(row_size % 8 != 0) + { + throw std::runtime_error("Pruning dimension should be a multiple of 8."); + } + int index = sparse_index.item(); + auto args = matmul_args[index]; + torch::Tensor result = torch::zeros_like(dense); + 
const int threads = 1024; + switch (dense.type().scalarType()) { + case torch::ScalarType::Float: + { + throw std::runtime_error("Full precision not supported for prune_and_compress"); + } + case torch::ScalarType::Half: + { + const dim3 blocks(((row_size / 8) + threads - 1) / threads, batch_size); + sparse_add_kernel<<>>( + dense.data(), + (at::Half*) args->dCompressed, + alpha.item(), + beta.item(), + result.data(), + row_size); + } + } + return result; +} + + +__global__ void update_sparse_matrix_kernel( + const at::Half* __restrict__ new_data, + at::Half* __restrict__ output, + size_t row_size) { + const int column = 8 * (blockIdx.x * blockDim.x + threadIdx.x); + const int index = blockIdx.y * row_size + column; + if (column < row_size) { + reinterpret_cast(&output[index])[0] = reinterpret_cast(&new_data[index])[0]; + } +} + + +void update_sparse_matrix_cuda(torch::Tensor new_data, torch::Tensor sparse_idx) +{ + auto args = matmul_args[sparse_idx.item()]; + const int threads = 1024; + switch (new_data.type().scalarType()) { + case torch::ScalarType::Float: + { + throw std::runtime_error("Full precision not supported for prune_and_compress"); + } + case torch::ScalarType::Half: + { + cudaMemcpy(args->dCompressed, new_data.data(), new_data.size(0) * new_data.size(1) * sizeof(at::Half), cudaMemcpyDeviceToDevice); + } + } +} + + +// sparse = prune_and_compress(dense, mask) +// result = add_sparse_dense(sparse_idx, dense, alpha, beta) +// update_sparse(data, sparse_idx, sparse_transpose_idx) diff --git a/csrc/semi_structured/cusparselt/cusparselt_mm_entry.cu b/csrc/semi_structured/cusparselt/cusparselt_mm_entry.cu new file mode 100644 index 0000000000000..ddc5ef090ec2b --- /dev/null +++ b/csrc/semi_structured/cusparselt/cusparselt_mm_entry.cu @@ -0,0 +1,135 @@ +#include +#include // cusparseLt header +#include + +#define CHECK_CUDA_DEVICE(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA_DEVICE(x); CHECK_CONTIGUOUS(x) + +int init_cusparse_lt_cuda(); +torch::Tensor setup_spmatmul_cuda(torch::Tensor A, + torch::Tensor B, + const bool transpose_A=false, + const bool transpose_B=false, + const bool sparseA=true, + const bool transposable_mask=false, + const bool is_sparse_pruned=false, + const bool check_sparsity=false); + + +torch::Tensor spmatmul_cuda(torch::Tensor Dense, + int index, + bool sparseA); + +int destroy_cusparse_matmul_cuda(int index); + +void save_grad_cuda(torch::Tensor grad, int index); + + +torch::Tensor init_cusparse_lt() { + int result = init_cusparse_lt_cuda(); + if(result == EXIT_SUCCESS) { + return torch::zeros({1}, torch::kInt32); + } else { + return torch::ones({1}, torch::kInt32); + } +} + + +torch::Tensor setup_spmatmul(torch::Tensor A, + torch::Tensor B, + const bool transpose_A=false, + const bool transpose_B=false, + const bool sparseA=true, + const bool transposable_mask=false, + const bool is_sparse_pruned=false, + const bool check_sparsity=false) { + + CHECK_INPUT(A); + CHECK_INPUT(B); + return setup_spmatmul_cuda(A, + B, + transpose_A, + transpose_B, + sparseA, + transposable_mask, + is_sparse_pruned, + check_sparsity); +} + + +torch::Tensor spmatmul( torch::Tensor Dense, + torch::Tensor index, + const bool sparseA=true) { + CHECK_INPUT(Dense); +// std::cout << Dense.data_ptr()[0] << std::endl; + auto result = spmatmul_cuda( Dense, + *index.data_ptr(), + sparseA); + return result; +} + +int destroy_cusparse_matmul(int index){ 
+ return destroy_cusparse_matmul_cuda(index); +} + +torch::Tensor save_grad(torch::Tensor input, torch::Tensor index) { + CHECK_INPUT(input); + save_grad_cuda(input, *index.data_ptr()); +} + + +std::vector prune_cuda(torch::Tensor input, const int N, const int M); + + +std::vector prune( + torch::Tensor input, const int N, const int M) { + CHECK_INPUT(input); + return prune_cuda(input, N, M); +} + + +torch::Tensor prune_and_compress_cuda(torch::Tensor input, torch::Tensor mask); + + +torch::Tensor prune_and_compress( + torch::Tensor input, torch::Tensor mask) { + CHECK_INPUT(input); + return prune_and_compress_cuda(input, mask); +} + + +torch::Tensor sparse_add_cuda(torch::Tensor dense, torch::Tensor sparse_index, torch::Tensor alpha, torch::Tensor beta); + + +torch::Tensor sparse_add( + torch::Tensor dense, torch::Tensor sparse_index, torch::Tensor alpha, torch::Tensor beta) { + CHECK_INPUT(dense); + return sparse_add_cuda(dense, sparse_index, alpha, beta); +} + + +void update_sparse_matrix_cuda(torch::Tensor new_data, torch::Tensor sparse_idx); + + +void update_sparse_matrix( + torch::Tensor new_data, torch::Tensor sparse_idx) { + CHECK_INPUT(new_data); + update_sparse_matrix_cuda(new_data, sparse_idx); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("init_cusparse_lt", &init_cusparse_lt, "Initialize CUSPARSE LT"); + m.def("setup_spmatmul", &setup_spmatmul, "Setup Sparse Matrix Multiplication"); + m.def("destroy_cusparse_matmul", &destroy_cusparse_matmul, "Destroy matmul arguments"); + m.def("spmatmul", &spmatmul, "Sparse Matrix Multiplication"); + m.def("save_grad", &save_grad, "Save Gradient"); + m.def("prune", &prune, "N:M Prune (CUDA)"); + m.def("prune_and_compress", &prune_and_compress, "Prune the dense matrix using the mask and store it in a " + "compressed tensor (CUDA)"); + m.def("sparse_add", &sparse_add, "Add the sparse matrix to the dense matrix and return a " + "compressed dense matrix(CUDA)"); + m.def("update_sparse_matrix", &update_sparse_matrix, "Update the sparse matrix with the new dense matrix " + "data (CUDA)"); +} diff --git a/csrc/semi_structured/cutlass/common.hpp b/csrc/semi_structured/cutlass/common.hpp new file mode 100644 index 0000000000000..bf04bb400790f --- /dev/null +++ b/csrc/semi_structured/cutlass/common.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + cutlassGetStatusString(status)) \ + } + +inline uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); + return max_shared_mem_per_block_opt_in; +} + diff --git a/csrc/semi_structured/cutlass/semi_structured_mm_c3x.cu b/csrc/semi_structured/cutlass/semi_structured_mm_c3x.cu new file mode 100644 index 0000000000000..794d325b36eba --- /dev/null +++ b/csrc/semi_structured/cutlass/semi_structured_mm_c3x.cu @@ -0,0 +1,223 @@ +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + +#include + +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include 
"cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "broadcast_load_epilogue_c3x.hpp" +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This file defines quantized GEMM operations using the CUTLASS 3.x API, for + NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogue functions can be defined to post-process the output before it is + written to GPU memory. + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace { + +template +struct cutlass_3x_sparse_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + // using ElementAcc = + // typename std::conditional, int32_t, + // float>::type; + using ElementAcc = ElementD; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + using LayoutTagC = cutlass::layout::ColumnMajor; + using StrideC = StrideD; + + constexpr int AlignmentAB = 128 / cutlass::sizeof_bits::value; + + // using CollectiveEpilogue = + // typename cutlass::epilogue::collective::CollectiveBuilder< + // cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + // ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + // ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + // EpilogueSchedule, EVTCompute>::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementAcc, + ElementC, LayoutTagC, AlignmentC, + ElementD, LayoutTagC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + + // static constexpr size_t CEStorageSize = + // sizeof(typename CollectiveEpilogue::SharedStorage); + // using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + // static_cast(CEStorageSize)>; + + // using CollectiveMainloop = + // typename cutlass::gemm::collective::CollectiveBuilder< + // cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + // ElementAB, cutlass::layout::RowMajor, 16, + // ElementAB, cutlass::layout::ColumnMajor, 16, + // ElementAcc, TileShape, ClusterShape, + // Stages, + // KernelSchedule>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + // cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, AlignmentAB, + ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, + ElementAcc, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + + // using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + // cutlass::gemm::PersistentScheduler>>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue +>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename 
Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + // typename GemmKernel::EpilogueArguments epilogue_args{ + // Gemm::Epilogue::prepare_args( + // std::forward(epilogue_params)...), + // c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +template +struct sm90_fp8_config_default { + // M in (128, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using Cutlass3xGemm = + cutlass_3x_sparse_gemm; +}; + +} // namespace + +template +void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm90_fp8_config_default::Cutlass3xGemm; + + return cutlass_gemm_caller(out, a, b); +} + +void cutlass_semi_structured_mm_sm90(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b) { + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + if (out.dtype() == torch::kBFloat16) { + return cutlass_gemm_sm90_fp8_dispatch( + out, a, b); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_gemm_sm90_fp8_dispatch( + out, a, b); + } + // TODO: Add other data types +} + +#endif diff --git a/csrc/semi_structured/cutlass/semi_structured_mm_entry.cu b/csrc/semi_structured/cutlass/semi_structured_mm_entry.cu new file mode 100644 index 0000000000000..0d570a48b39ac --- /dev/null +++ b/csrc/semi_structured/cutlass/semi_structured_mm_entry.cu @@ -0,0 +1,54 @@ +#include + +#include +#include + +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X +void cutlass_semi_structured_mm_sm90(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias); +#endif + +int32_t get_sm_version_num() { + int32_t major_capability, minor_capability; + 
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+                         0);
+  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+                         0);
+  int32_t version_num = major_capability * 10 + minor_capability;
+  return version_num;
+}
+
+void cutlass_semi_structured_mm(torch::Tensor& c, torch::Tensor const& a,
+                                torch::Tensor const& b) {
+  // Checks for conformality
+  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
+  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
+              b.size(1) == c.size(1));
+
+  // Check for strides and alignment
+  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
+  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
+  TORCH_CHECK(c.stride(0) % 16 == 0 &&
+              b.stride(1) % 16 == 0);  // 16 Byte Alignment
+
+  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
+  int32_t version_num = get_sm_version_num();
+  // Hopper
+
+  // TODO: Guard against compilation issues for sm90 kernels
+// #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+  if (version_num >= 90) {
+    cutlass_semi_structured_mm_sm90(c, a, b);
+    return;
+  }
+// #endif
+
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled cutlass_semi_structured_mm for a compute capability less than "
+      "CUDA device capability: ",
+      version_num);
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index b999028fe06a9..ff14f7fb97b52 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -264,6 +264,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
   ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
+  // CUTLASS sparse GEMM, supporting semi-structured sparsity
+  ops.def(
+      "cutlass_semi_structured_mm(Tensor! out, Tensor a,"
+      "                           Tensor b) -> ()");
+  ops.impl("cutlass_semi_structured_mm", torch::kCUDA,
+           &cutlass_semi_structured_mm);
+
   // Mamba selective scan kernel
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! 
delta," diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a25f7abca5498..e8efd5f339ced 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -509,6 +509,21 @@ def cutlass_scaled_mm_azp(a: torch.Tensor, return out +def cutlass_semi_structured_mm(a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype) -> torch.Tensor: + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + torch.ops._C.cutlass_semi_structured_mm(out, a, b) + + return out + + # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, From 17f5b963d30eff3d74b1627b5489f3a35e9cfcb2 Mon Sep 17 00:00:00 2001 From: Faraz Shahsavan Date: Mon, 28 Oct 2024 02:39:35 +0000 Subject: [PATCH 02/92] Update with test code --- CMakeLists.txt | 93 +- .../semi_structured_benchmarks.py | 171 +- .../cutlass_benchmarks/test_benchmarks.py | 367 +++ csrc/ops.h | 11 + .../broadcast_load_epilogue_c3x.hpp | 447 ++++ csrc/quantization/cutlass_test/common.hpp | 27 + .../quantization/cutlass_test/common_gemm.cuh | 568 +++++ .../quantization/cutlass_test/device_memory.h | 377 +++ .../example/62_hopper_sparse_gemm.cu | 596 +++++ .../cutlass_test/example/Makefile | 68 + .../cutlass_test/example/util/command_line.h | 313 +++ .../cutlass_test/example/util/distribution.h | 154 ++ .../example/util/gather_tensor.hpp | 215 ++ .../cutlass_test/example/util/helper.h | 108 + .../cutlass_test/example/util/host_tensor.h | 541 +++++ .../example/util/packed_stride.hpp | 570 +++++ .../util/reference/detail/inner_product.h | 135 ++ .../reference/detail/linear_to_coordinate.h | 94 + .../util/reference/device/convolution.h | 1549 ++++++++++++ .../example/util/reference/device/gemm.h | 385 +++ .../util/reference/device/gemm_complex.h | 350 +++ .../reference/device/gemm_planar_complex.h | 311 +++ .../example/util/reference/device/gett.hpp | 146 ++ .../util/reference/device/kernel/gemm.h | 162 ++ .../device/kernel/tensor_elementwise.h | 168 ++ .../reference/device/kernel/tensor_foreach.h | 159 ++ .../util/reference/device/rank_2k_complex.h | 355 +++ .../util/reference/device/tensor_compare.h | 246 ++ .../util/reference/device/tensor_fill.h | 2077 +++++++++++++++++ .../util/reference/device/tensor_foreach.h | 144 ++ .../util/reference/device/tensor_reduce.h | 510 ++++ .../util/reference/device/tensor_relu.h | 141 ++ .../util/reference/device/thread/gemm.h | 186 ++ .../example/util/reference/host/conv.hpp | 698 ++++++ .../example/util/reference/host/convolution.h | 802 +++++++ .../util/reference/host/error_metrics.h | 66 + .../example/util/reference/host/gemm.h | 531 +++++ .../util/reference/host/gemm_complex.h | 210 ++ .../util/reference/host/gemm_planar_complex.h | 228 ++ .../example/util/reference/host/gett.hpp | 538 +++++ .../example/util/reference/host/rank_2k.h | 261 +++ .../util/reference/host/rank_2k_complex.h | 318 +++ .../util/reference/host/rank_k_complex.h | 234 ++ .../example/util/reference/host/symm.h | 285 +++ .../util/reference/host/symm_complex.h | 319 +++ .../util/reference/host/tensor_compare.h | 423 ++++ .../util/reference/host/tensor_compare.hpp | 101 + .../example/util/reference/host/tensor_copy.h | 256 ++ .../util/reference/host/tensor_elementwise.h | 341 +++ .../example/util/reference/host/tensor_fill.h | 1718 ++++++++++++++ .../util/reference/host/tensor_fill.hpp | 432 ++++ 
.../util/reference/host/tensor_foreach.h | 134 ++ .../example/util/reference/host/tensor_norm.h | 42 + .../util/reference/host/tensor_reduce.h | 203 ++ .../util/reference/host/tensor_reduce.hpp | 203 ++ .../example/util/reference/host/trmm.h | 215 ++ .../util/reference/host/trmm_complex.h | 262 +++ .../example/util/tensor_view_io.h | 270 +++ csrc/quantization/cutlass_test/exceptions.h | 69 + csrc/quantization/cutlass_test/helper.h | 94 + csrc/quantization/cutlass_test/host_tensor.h | 541 +++++ .../cutlass_test/packed_stride.hpp | 570 +++++ csrc/quantization/cutlass_test/test_mm_c3x.cu | 205 ++ .../cutlass_test/test_mm_entry.cu | 82 + csrc/quantization/cutlass_test/test_util.cu | 199 ++ .../cutlass_w8a8/scaled_mm_c3x.cu | 4 +- .../cutlass/semi_structured_mm_c3x.cu | 231 +- .../cutlass/semi_structured_mm_entry.cu | 16 +- csrc/torch_bindings.cpp | 17 + sane_cute_errors.py | 119 + vllm/_custom_ops.py | 42 +- 71 files changed, 22778 insertions(+), 245 deletions(-) create mode 100644 benchmarks/cutlass_benchmarks/test_benchmarks.py create mode 100644 csrc/quantization/cutlass_test/broadcast_load_epilogue_c3x.hpp create mode 100644 csrc/quantization/cutlass_test/common.hpp create mode 100644 csrc/quantization/cutlass_test/common_gemm.cuh create mode 100644 csrc/quantization/cutlass_test/device_memory.h create mode 100644 csrc/quantization/cutlass_test/example/62_hopper_sparse_gemm.cu create mode 100644 csrc/quantization/cutlass_test/example/Makefile create mode 100644 csrc/quantization/cutlass_test/example/util/command_line.h create mode 100644 csrc/quantization/cutlass_test/example/util/distribution.h create mode 100644 csrc/quantization/cutlass_test/example/util/gather_tensor.hpp create mode 100644 csrc/quantization/cutlass_test/example/util/helper.h create mode 100644 csrc/quantization/cutlass_test/example/util/host_tensor.h create mode 100644 csrc/quantization/cutlass_test/example/util/packed_stride.hpp create mode 100644 csrc/quantization/cutlass_test/example/util/reference/detail/inner_product.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/detail/linear_to_coordinate.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/convolution.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/gemm.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/gemm_complex.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/gemm_planar_complex.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/gett.hpp create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/kernel/gemm.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/kernel/tensor_elementwise.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/kernel/tensor_foreach.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/rank_2k_complex.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/tensor_compare.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/tensor_fill.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/tensor_foreach.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/tensor_reduce.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/device/tensor_relu.h create mode 100644 
csrc/quantization/cutlass_test/example/util/reference/device/thread/gemm.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/conv.hpp create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/convolution.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/error_metrics.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/gemm.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/gemm_complex.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/gemm_planar_complex.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/gett.hpp create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/rank_2k.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/rank_2k_complex.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/rank_k_complex.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/symm.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/symm_complex.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_compare.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_compare.hpp create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_copy.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_elementwise.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_fill.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_fill.hpp create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_foreach.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_norm.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_reduce.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/tensor_reduce.hpp create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/trmm.h create mode 100644 csrc/quantization/cutlass_test/example/util/reference/host/trmm_complex.h create mode 100644 csrc/quantization/cutlass_test/example/util/tensor_view_io.h create mode 100644 csrc/quantization/cutlass_test/exceptions.h create mode 100644 csrc/quantization/cutlass_test/helper.h create mode 100644 csrc/quantization/cutlass_test/host_tensor.h create mode 100644 csrc/quantization/cutlass_test/packed_stride.hpp create mode 100644 csrc/quantization/cutlass_test/test_mm_c3x.cu create mode 100644 csrc/quantization/cutlass_test/test_mm_entry.cu create mode 100644 csrc/quantization/cutlass_test/test_util.cu create mode 100644 sane_cute_errors.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f6d1c66b2cf7..a13a1e8065e21 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -203,12 +203,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. 
- set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + GIT_TAG be692b48b01620eedabeef8325df5d4eeed6c2ae GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. @@ -226,7 +226,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gguf/gguf_kernel.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/quantization/cutlass_test/test_mm_entry.cu" + "csrc/quantization/cutlass_test/test_util.cu" + "csrc/semi_structured/cutlass/semi_structured_mm_entry.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -283,6 +286,90 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(SCALED_MM_3X_ARCHS) endif() + # + # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). + cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + set(SRCS "csrc/quantization/cutlass_test/test_util.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1") + message(STATUS "Building test_util for archs: ${SCALED_MM_3X_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + message(STATUS "Not building test_util as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running FP8 quantized models on " + "Hopper.") + else() + message(STATUS "Not building test_util as no compatible archs found " + "in CUDA target architectures") + endif() + + # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't + # build any 3x kernels + set(SCALED_MM_3X_ARCHS) + endif() + + # + # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). + cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + set(SRCS "csrc/quantization/cutlass_test/test_mm_c3x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1") + message(STATUS "Building test_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + message(STATUS "Not building test_mm_c3x as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running FP8 quantized models on " + "Hopper.") + else() + message(STATUS "Not building test_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() + + # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't + # build any 3x kernels + set(SCALED_MM_3X_ARCHS) + endif() + + # + # The cutlass_semi_structured_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). 
+ cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + set(SRCS "csrc/semi_structured/cutlass/semi_structured_mm_c3x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1") + message(STATUS "Building semi_structured_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + message(STATUS "Not building semi_structured_mm_c3x as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running FP8 quantized models on " + "Hopper.") + else() + message(STATUS "Not building scaled_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() + + # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't + # build any 3x kernels + set(SCALED_MM_3X_ARCHS) + endif() + # # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. diff --git a/benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py b/benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py index 61eed3da41458..ebe6668a89e43 100644 --- a/benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/semi_structured_benchmarks.py @@ -30,6 +30,18 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float16) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float16) + + +def to_fp32(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float) + + def make_rand_tensors(dtype: torch.dtype, m: int, n: int, k: int) -> Tuple[torch.Tensor, torch.Tensor]: a = torch.randn((m, k), device='cuda') * 5 @@ -39,6 +51,10 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int, return to_int8(a), to_int8(b) if dtype == torch.float8_e4m3fn: return to_fp8(a), to_fp8(b) + if dtype == torch.float16: + return to_fp16(a), to_fp16(b) + if dtype == torch.float: + return to_fp32(a), to_fp32(b) raise ValueError("unsupported dtype") @@ -61,150 +77,35 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, description=description, ).blocked_autorange(min_run_time=min_run_time) - -def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: - assert dtype == torch.int8 - a, b = make_rand_tensors(torch.int8, m, n, k) - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) - azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) - azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) - - timers = [] - # pytorch impl - bfloat16 - timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul", - torch.mm, a.to(dtype=torch.bfloat16), - b.to(dtype=torch.bfloat16))) - - # pytorch impl - float16 - timers.append( - bench_fn(label, sub_label, - "pytorch_fp16_fp16_fp16_matmul", torch.mm, - a.to(dtype=torch.float16), 
b.to(dtype=torch.float16))) - - # cutlass impl - bfloat16 - timers.append( - bench_fn(label, sub_label, "cutlass_bf16_bf16_bf16_semi_structured_mm", - torch.mm, a.to(dtype=torch.bfloat16), - b.to(dtype=torch.bfloat16))) - - # cutlass impl - float16 - timers.append( - bench_fn(label, sub_label, - "cutlass_fp16_fp16_fp16_semi_structured_mm", - torch.mm, a.to(dtype=torch.float16), - b.to(dtype=torch.float16))) - - # cutlass impl - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_semi_structured_mm", - ops.cutlass_semi_structured_mm, a, b, scale_a, scale_b, - torch.bfloat16)) - - return timers - - -def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, +def bench_fp32(dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str) -> Iterable[TMeasurement]: - assert dtype == torch.float8_e4m3fn - a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + assert dtype == torch.float + a, b = make_rand_tensors(torch.float, m, n, k) timers = [] - # pytorch impl w. bf16 + # pytorch impl w. fp32 timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"))) - - # # pytorch impl: bf16 output, without fp8 fast accum - # timers.append( - # bench_fn(label, - # sub_label, - # "pytorch_fp8_fp8_bf16_semi_structured_mm", - # torch._semi_structured_mm, - # a, - # b, - # scale_a=scale_a, - # scale_b=scale_b, - # out_dtype=torch.bfloat16)) - - # # pytorch impl: bf16 output, with fp8 fast accum - # timers.append( - # bench_fn(label, - # sub_label, - # "pytorch_fp8_fp8_bf16_semi_structured_mm_fast_accum", - # torch._semi_structured_mm, - # a, - # b, - # scale_a=scale_a, - # scale_b=scale_b, - # out_dtype=torch.bfloat16, - # use_fast_accum=True)) - - # # pytorch impl: fp16 output, without fp8 fast accum - # timers.append( - # bench_fn(label, - # sub_label, - # "pytorch_fp8_fp8_fp16_semi_structured_mm", - # torch._semi_structured_mm, - # a, - # b, - # scale_a=scale_a, - # scale_b=scale_b, - # out_dtype=torch.float16)) - - # # pytorch impl: fp16 output, with fp8 fast accum - # timers.append( - # bench_fn(label, - # sub_label, - # "pytorch_fp8_fp8_fp16_semi_structured_mm_fast_accum", - # torch._semi_structured_mm, - # a, - # b, - # scale_a=scale_a, - # scale_b=scale_b, - # out_dtype=torch.float16, - # use_fast_accum=True)) - - # cutlass impl: bf16 output - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_semi_structured_mm", - ops.cutlass_semi_structured_mm, a, b, - torch.bfloat16)) - # cutlass impl: fp16 output + bench_fn(label, sub_label, "pytorch_f32_f32_f32_matmul-no-scales", + torch.mm, a.to(dtype=torch.float, device="cuda"), + b.to(dtype=torch.float, device="cuda"))) + + # cutlass impl: fp32 timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_semi_structured_mm", + bench_fn(label, sub_label, "cutlass_fp32_fp32_fp32_semi_structured_mm", ops.cutlass_semi_structured_mm, a, b, - torch.float16)) - - # # cutlass impl: bf16 output, with bias - # timers.append( - # bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_semi_structured_mm_bias", - # ops.cutlass_semi_structured_mm, a, b, scale_a, scale_b, - # torch.bfloat16, bias)) - - # # cutlass impl: fp16 output, with bias - # timers.append( - # bench_fn(label, sub_label, 
"cutlass_fp8_fp8_fp16_semi_structured_mm_bias", - # ops.cutlass_semi_structured_mm, a, b, scale_a, scale_b, - # torch.float16, bias.to(dtype=torch.float16))) - + torch.float)) + return timers def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str) -> Iterable[TMeasurement]: - if dtype == torch.int8: - return bench_int8(dtype, m, k, n, label, sub_label) - if dtype == torch.float8_e4m3fn: - return bench_fp8(dtype, m, k, n, label, sub_label) + return bench_fp32(torch.float, m, k, n, label, sub_label) + # if dtype == torch.int8: + # return bench_int8(dtype, m, k, n, label, sub_label) + # if dtype == torch.float8_e4m3fn: + # return bench_fp8(dtype, m, k, n, label, sub_label) raise ValueError("unsupported type") @@ -312,6 +213,10 @@ def to_torch_dtype(dt): return torch.int8 if dt == "fp8": return torch.float8_e4m3fn + if dt == "fp16": + return torch.float16 + if dt == "fp32": + return torch.float raise ValueError("unsupported dtype") parser = FlexibleArgumentParser( @@ -335,7 +240,7 @@ def to_torch_dtype(dt): parser.add_argument("--dtype", type=to_torch_dtype, required=True, - help="Available options are ['int8', 'fp8']") + help="Available options are ['int8', 'fp8', 'fp16', 'fp32']") subparsers = parser.add_subparsers(dest="cmd") square_parser = subparsers.add_parser("square_bench") diff --git a/benchmarks/cutlass_benchmarks/test_benchmarks.py b/benchmarks/cutlass_benchmarks/test_benchmarks.py new file mode 100644 index 0000000000000..4d1884dcd2135 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/test_benchmarks.py @@ -0,0 +1,367 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + +# helpers + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + a, b = make_rand_tensors(torch.int8, m, n, k) + a_compressed, e = cutlass_sparsify_and_compress_entry(a) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = 
torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_test_mm, a_compressed, e, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_test_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16, + bias)) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) + a_compressed, e = cutlass_sparsify_and_compress_entry(a) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + timers = [] + + # pytorch impl w. bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_test_mm, a_compressed, e, b, scale_a, scale_b, + torch.bfloat16)) + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm", + ops.cutlass_scaled_test_mm, a_compressed, e, b, scale_a, scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias", + ops.cutlass_scaled_test_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias", + ops.cutlass_scaled_test_mm, a_compressed, e, b, scale_a, 
scale_b, torch.float16, + bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. 
+ + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/test_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/test_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/test_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/csrc/ops.h b/csrc/ops.h index c0b4fa7f5d15e..e5d798cc832dd 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -116,6 +116,17 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, c10::optional const& azp, c10::optional const& bias); +bool cutlass_scaled_test_mm_supports_fp8(int64_t cuda_device_capability); + +void cutlass_scaled_test_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& e, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias); + +bool cutlass_sparsify_and_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e, + torch::Tensor const& a); + void cutlass_semi_structured_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b); #endif diff --git a/csrc/quantization/cutlass_test/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_test/broadcast_load_epilogue_c3x.hpp new file mode 100644 index 0000000000000..58b1e8ff159fb --- /dev/null +++ b/csrc/quantization/cutlass_test/broadcast_load_epilogue_c3x.hpp @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. 
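+//
+// Concretely, a per-tensor scale is passed as a one-element device tensor
+// (the broadcast flag is derived from numel() != 1), while per-token scales
+// use a column vector of length m and per-channel scales use a row vector of
+// length n, so a single compiled kernel serves all three layouts.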
+// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcast { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + Params const& params; + + CUTLASS_DEVICE void 
+ begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row)); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcast { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + 
(cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col)); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + params + ); + } +}; + +} diff --git a/csrc/quantization/cutlass_test/common.hpp b/csrc/quantization/cutlass_test/common.hpp new file mode 100644 index 0000000000000..bf04bb400790f --- /dev/null +++ b/csrc/quantization/cutlass_test/common.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + cutlassGetStatusString(status)) \ + } + +inline uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); + return max_shared_mem_per_block_opt_in; +} + diff --git a/csrc/quantization/cutlass_test/common_gemm.cuh b/csrc/quantization/cutlass_test/common_gemm.cuh new file mode 100644 index 0000000000000..b0298a6bf5971 --- /dev/null +++ b/csrc/quantization/cutlass_test/common_gemm.cuh @@ -0,0 +1,568 @@ +using namespace cute; + +/* + This file defines quantized GEMM operations using the CUTLASS 3.x API, for + NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogue functions can be defined to post-process the output before it is + written to GPU memory. + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace { + +// A wrapper for the GEMM kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef +// into code that will be executed on the device where it is defined. +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { + #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); + #endif + } +}; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] 
classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, + Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, + Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. 
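+
+   For example, with per-token a_scales of shape (m, 1) and per-channel
+   b_scales of shape (1, n), the output is
+     D[i, j] = a_scales[i] * b_scales[j] * sum_k A[i, k] * B[k, j],
+   and a per-tensor scale is the degenerate case where the vector collapses
+   to a single scalar.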
+*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
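+ *
+ * Written out, the computation is
+ *   D = scale_a * (scale_b * (Acc - azp_adj)) + bias,
+ * where Acc is the int32 accumulator of the quantized GEMM and azp_adj is
+ * broadcast across the rows of the output.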
+ */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
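+ *
+ * In other words, the correction for element (i, j) is azp[i] * azp_adj[j],
+ * i.e. the outer product of the two vectors. The epilogue computes
+ * accum[i, j] - azp[i] * azp_adj[j] on the fly (ComputeAzp forms the product,
+ * ComputeAcc subtracts it from the accumulator), so only the (m,1) and (1,n)
+ * vectors ever need to be stored.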
+ */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + using Epilogue = Epilogue_; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // 
clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, + ElementAB, cutlass::layout::RowMajor, 32, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_test_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& e, torch::Tensor const& b, + EpilogueArgs&&... epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA; + using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE; + + using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; + using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig; + + LayoutA a_layout = SparseConfig::fill_layoutA(prob_shape); + LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto e_ptr = static_cast(e.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_layout, + b_ptr, b_stride, + e_ptr, e_layout}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + // Launch the CUTLASS GEMM kernel. 
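+  // The steps below follow the usual CUTLASS 3.x host-side flow: wrap the
+  // kernel in a device-level GemmUniversalAdapter, validate the arguments
+  // with can_implement(), allocate the workspace as a uint8 torch tensor on
+  // A's device, and run() on the current CUDA stream.
+  //
+  // A hypothetical call site for an fp8 instantiation could look like the
+  // sketch below; the template parameter list and the tensor names
+  // (a_compressed, e_metadata, a_scales, b_scales) are assumptions, since
+  // they are elided in this hunk:
+  //
+  //   using Cfg = sm90_fp8_config_default<cutlass::float_e4m3_t,
+  //                                       cutlass::bfloat16_t, ScaledEpilogue>;
+  //   cutlass_test_gemm_caller<typename Cfg::Cutlass3xGemm>(
+  //       out, a_compressed, e_metadata, b, a_scales, b_scales);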
+ using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +template typename Epilogue> +struct sm90_fp8_config_default { + // M in (128, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M128 { + // M in (64, 128] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M64 { + // M in [1, 64] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_default { + // For M > 128 and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M128 { + // For M in (64, 128] and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M64 { + // For M in (32, 64] and any N + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NBig { + // For M in [1, 32] and N >= 8192 + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NSmall { + // For M in [1, 32] and N < 8192 + 
static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _8, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +} // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_test/device_memory.h b/csrc/quantization/cutlass_test/device_memory.h new file mode 100644 index 0000000000000..7d3fa73f62df8 --- /dev/null +++ b/csrc/quantization/cutlass_test/device_memory.h @@ -0,0 +1,377 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ interface to CUDA device memory management functions. 
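+ *
+ * The header provides free functions for raw allocation (allocate / free),
+ * typed copy helpers (copy_to_device, copy_to_host, copy_device_to_device,
+ * copy_host_to_host, insert_to_host, insert_to_device), and a
+ * DeviceAllocation smart-pointer wrapper, also exposed as
+ * device_memory::allocation, that owns a device buffer and frees it on
+ * destruction.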
+ */ + +#include +#include + +#include "cutlass/platform/platform.h" +#include "cutlass/numeric_types.h" +#include "cutlass/trace.h" +#include "exceptions.h" + +namespace cutlass { +namespace device_memory { + +/****************************************************************************** + * Allocation lifetime + ******************************************************************************/ + +/// Allocate a buffer of \p count elements of type \p T on the current CUDA device +template +T* allocate(size_t count = 1) { + + T* ptr = 0; + size_t bytes = 0; + + bytes = count * sizeof(T); + + cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); + + if (cuda_error != cudaSuccess) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 0) + std::ostringstream os; + os << "cutlass::device_memory::allocate: cudaMalloc failed: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); +#endif + throw cuda_exception("Failed to allocate memory", cuda_error); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + std::ostringstream os; + os << "cutlass::device_memory::allocate: Successful cudaMalloc: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); + } +#endif + + return ptr; +} + +/// Free the buffer pointed to by \p ptr +template +void free(T* ptr) { + if (ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + throw cuda_exception("Failed to free device memory", cuda_error); + } + } +} + +/****************************************************************************** + * Data movement + ******************************************************************************/ + +template +void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) { + size_t bytes = count * sizeof_bits::value / 8; + if (bytes == 0 && count > 0) { + bytes = 1; + } + cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind)); + if (cuda_error != cudaSuccess) { + std::ostringstream os; + os << "cutlass::device_memory::copy: cudaMemcpy() failed: " + << "dst=" << dst << ", src=" << src + << ", bytes=" << bytes << ", count=" << count; + if (kind == cudaMemcpyHostToDevice) { + os << ", kind=cudaMemcpyHostToDevice"; + } + else if (kind == cudaMemcpyDeviceToHost) { + os << ", kind=cudaMemcpyDeviceToHost"; + } + else if (kind == cudaMemcpyDeviceToDevice) { + os << ", kind=cudaMemcpyDeviceToDevice"; + } + else if (kind == cudaMemcpyHostToHost) { + os << ", kind=cudaMemcpyHostToHost"; + } + else if (kind == cudaMemcpyDefault) { + os << ", kind=cudaMemcpyDefault"; + } + else { + os << ", kind=Unknown"; + } + os << ", error: " << cudaGetErrorString(cuda_error); + + throw cuda_exception(os.str().c_str(), cuda_error); + } +} + +template +void copy_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToDevice); +} + +template +void copy_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyDeviceToHost); +} + +template +void copy_device_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyDeviceToDevice); +} + +template +void copy_host_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToHost); +} + +/// Copies elements from device memory to host-side range +template +void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) { + size_t elements = end - begin; + copy_to_host(&*begin, device_begin, elements); +} + +/// Copies elements to device memory from host-side range +template +void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { + 
size_t elements = end - begin; + copy_to_device(device_begin, &*begin, elements); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DeviceAllocation { +public: + + /// Delete functor for CUDA device memory + struct deleter { + void operator()(T* ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + // noexcept + // throw cuda_exception("cudaFree() failed", cuda_error); + return; + } + } + }; + +public: + // + // Data members + // + + /// Number of elements of T allocated on the current CUDA device + size_t capacity; + + /// Smart pointer + platform::unique_ptr smart_ptr; + +public: + + // + // Static methods + // + + /// Static member to compute the number of bytes needed for a given number of elements + static size_t bytes(size_t elements) { + if (sizeof_bits::value < 8) { + size_t const kElementsPerByte = 8 / sizeof_bits::value; + return elements / kElementsPerByte; + } + else { + size_t const kBytesPerElement = sizeof_bits::value / 8; + return elements * kBytesPerElement; + } + } + +public: + + // + // Methods + // + + /// Constructor: allocates no memory + DeviceAllocation() : capacity(0) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device + DeviceAllocation(size_t _capacity) : + smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation + DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} + + /// Copy constructor + DeviceAllocation(DeviceAllocation const &p): + smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { + + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + } + + /// Move constructor + DeviceAllocation(DeviceAllocation &&p): capacity(0) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + } + + /// Destructor + ~DeviceAllocation() { reset(); } + + /// Returns a pointer to the managed object + T* get() const { return smart_ptr.get(); } + + /// Releases the ownership of the managed object (without deleting) and resets capacity to zero + T* release() { + capacity = 0; + return smart_ptr.release(); + } + + /// Deletes the managed object and resets capacity to zero + void reset() { + capacity = 0; + smart_ptr.reset(); + } + + /// Deletes managed object, if owned, and allocates a new object + void reset(size_t _capacity) { + reset(device_memory::allocate(_capacity), _capacity); + } + + /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity + void reset(T* _ptr, size_t _capacity) { + smart_ptr.reset(_ptr); + capacity = _capacity; + } + + /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. 
+ void reallocate(size_t new_capacity) { + + platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); + + device_memory::copy_device_to_device( + new_allocation.get(), + smart_ptr.get(), + std::min(new_capacity, capacity)); + + std::swap(smart_ptr, new_allocation); + std::swap(new_capacity, capacity); + } + + /// Returns the number of elements + size_t size() const { + return capacity; + } + + /// Returns the number of bytes needed to store the allocation + size_t bytes() const { + return bytes(capacity); + } + + /// Returns a pointer to the object owned by *this + T* operator->() const { return smart_ptr.get(); } + + /// Returns the deleter object which would be used for destruction of the managed object. + deleter& get_deleter() { return smart_ptr.get_deleter(); } + + /// Returns the deleter object which would be used for destruction of the managed object (const) + const deleter& get_deleter() const { return smart_ptr.get_deleter(); } + + /// Copies a device-side memory allocation + DeviceAllocation & operator=(DeviceAllocation const &p) { + if (capacity != p.capacity) { + smart_ptr.reset(device_memory::allocate(p.capacity)); + capacity = p.capacity; + } + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + return *this; + } + + /// Move assignment + DeviceAllocation & operator=(DeviceAllocation && p) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + return *this; + } + + /// Copies the entire allocation from another location in device memory. + void copy_from_device(T const *ptr) const { + copy_from_device(ptr, capacity); + } + + /// Copies a given number of elements from device memory + void copy_from_device(T const *ptr, size_t elements) const { + device_memory::copy_device_to_device(get(), ptr, elements); + } + + void copy_to_device(T *ptr) const { + copy_to_device(ptr, capacity); + } + + void copy_to_device(T *ptr, size_t elements) const { + device_memory::copy_device_to_device(ptr, get(), elements); + } + + void copy_from_host(T const *ptr) const { + copy_from_host(ptr, capacity); + } + + void copy_from_host(T const *ptr, size_t elements) const { + device_memory::copy_to_device(get(), ptr, elements); + } + + void copy_to_host(T *ptr) const { + copy_to_host(ptr, capacity); + } + + void copy_to_host(T *ptr, size_t elements) const { + device_memory::copy_to_host(ptr, get(), elements); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace device_memory { + +/// Device allocation abstraction that tracks size and capacity +template +using allocation = cutlass::DeviceAllocation; + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/62_hopper_sparse_gemm.cu b/csrc/quantization/cutlass_test/example/62_hopper_sparse_gemm.cu new file mode 100644 index 0000000000000..5b7361f805098 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/62_hopper_sparse_gemm.cu @@ -0,0 +1,596 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper Sparse GEMM example. + + This example demonstrates how to construct and run a structured sparse GEMM kernel + on NVIDIA Hopper architecture. 
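+
+    In outline, the example parses the problem size and epilogue scalars from
+    the command line, fills A, B and C with random data, sparsifies A on the
+    host and compresses it into packed values plus E metadata with the
+    structured-sparse compressor, then runs the sparse kernel and a dense
+    reference kernel, compares the two D tensors, and benchmarks both paths.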
+ +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" + +#include "util/command_line.h" +#include "util/distribution.h" +#include "util/host_tensor.h" +#include "util/packed_stride.hpp" +#include "util/tensor_view_io.h" +#include "util/reference/device/gemm.h" +#include "util/reference/device/tensor_compare.h" +#include "util/reference/device/tensor_fill.h" + +#include "util/helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size for sparse kernel +using TileShapeRef = Shape<_128,_128, _64>; // Threadblock-level tile size for reference (dense) kernel +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; // Kernel schedule policy +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue schedule policy + +using ProblemShape = Shape; + +// Sparse kernel setup + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + 
TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference (dense) kernel setup + +using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShapeRef, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShapeRef, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopRef, + CollectiveEpilogue +>; + +using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + +// Layouts +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// Layouts for reference (non-sparse) tensors +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; + +using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +// Offline compressor kernel +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + cutlass::arch::Sm90>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +ProblemShape problem_shape; + +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +LayoutA layout_A; +LayoutE layout_E; + +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_A_compressed; +cutlass::DeviceAllocation block_E; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_D_ref; + +#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(5120), n(4096), k(16384), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + 
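+    // Flags recognized below: --help, --m, --n, --k, --l, --alpha, --beta
+    // and --iterations; anything else is silently ignored (see print_usage()).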
cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "62_hopper_sparse_gemm\n\n" + << " Hopper Sparse GEMM example.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM (batch size)\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "62_hopper_sparse_gemm" << " --m=4096 --n=5120 --k=8192 --l=1 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Make A structured sparse by replacing elements with 0 and compress it +bool sparsify_and_compress() +{ + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + block_A_compressed.reset(M * KC * L); + block_E.reset(ME * KE * L); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + + // Random sparsification is performed on host + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast(seed + 2024)); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments { + problem_shape, + { 
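+      // Dense input: pointer to A and its stride; outputs: the buffer that
+      // receives the compressed A values and the buffer for the E metadata.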
block_A.get(), + stride_A, + block_A_compressed.get(), + block_E.get() }, + {hw_info} }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(Options const& options) { + + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + auto [M, N, K, L] = problem_shape; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // Allocate memory for tensors + block_A.reset(M * K * L); + block_B.reset(N * K * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_D_ref.reset(M * N * L); + + // Fill input tensors with data + initialize_block(block_A, seed + 2021); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2023); + + // Replace 0 in A with 1 to avoid metadata changes + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + for (size_t i = 0; i < block_A.size(); ++i) if (block_A_host[i] == ElementA(0)) block_A_host[i] = ElementA(1.0); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + if (!sparsify_and_compress()) { + return false; + }; + + // Build the compressed/metadata layouts + layout_A = SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + + return true; +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments make_args(Options const& options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E }, + { { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) }, + block_C.get(), stride_C, block_D.get(), stride_D } + }; + + return arguments; +} + +typename GemmRef::Arguments make_args_ref(Options const& options) +{ + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { block_A.get(), stride_A, block_B.get(), stride_B }, + { { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) }, + block_C.get(), stride_C, block_D_ref.get(), stride_D } + }; + + return arguments; +} + +template +void print_device_tensor(cute::Tensor const& t) +{ + // Assumes size = cosize, i.e. 
compact tensor + std::vector data_host(t.size()); + cutlass::device_memory::copy_to_host(data_host.data(), t.data(), t.size()); + auto t_host = cute::make_tensor(data_host.data(), t.layout()); + cute::print_tensor(t_host); +} + +bool verify(Options const& options) { + CUDA_CHECK(cudaDeviceSynchronize()); + + bool passed = cutlass::reference::device::BlockCompareEqual(block_D_ref.get(), block_D.get(), block_D.size()); + +#if 0 + if (!passed) { + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + cute::print("A (original): "); print_device_tensor(make_tensor(block_A.get(), make_shape(M, K, L), stride_A)); + cute::print("A (compressed): "); print_device_tensor(make_tensor(block_A_compressed.get(), make_shape(M, KC, L), stride_A_compressed)); + cute::print("E (physical): "); print_device_tensor(make_tensor(block_E.get(), make_shape(ME, KE, L), stride_E)); + cute::print("E (logical): "); print_device_tensor(make_tensor(block_E.get(), upcast(layout_E))); + cute::print("B: "); print_device_tensor(make_tensor(block_B.get(), make_shape(N, K, L), stride_B)); + cute::print("C: "); print_device_tensor(make_tensor(block_C.get(), make_shape(M, N, L), stride_C)); + cute::print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get(), make_shape(M, N, L), stride_D)); + cute::print("D computed: "); print_device_tensor(make_tensor(block_D.get(), make_shape(M, N, L), stride_D)); + } +#endif + + return passed; +} + +template +struct Runner +{ + using Arguments = typename Gemm::Arguments; + + Runner(Arguments args): arguments(args) { + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + workspace.reset(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + } + + void run() { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + + void benchmark(Options const& options) { + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + run(); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + double avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double gflops = options.gflops(avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << gflops << std::endl; + } + } + + Gemm gemm; + Arguments arguments; + cutlass::device_memory::allocation workspace; +}; + +/// Execute the example (verification and timing) +void run(Options &options) { + bool init = initialize(options); + if (!init) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + Runner gemm(make_args(options)); + Runner gemm_ref(make_args_ref(options)); + + gemm.run(); + gemm_ref.run(); + + bool passed = verify(options); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if (!passed) { + exit(EXIT_FAILURE); + } + + std::cout << "Sparse GEMM:" << std::endl; + gemm.benchmark(options); + + std::cout << "Dense GEMM:" << std::endl; + gemm_ref.benchmark(options); +} + +#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.2 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)) { + std::cerr << "This example requires CUDA 12.2 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + run(options); +#endif + + return EXIT_SUCCESS; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/csrc/quantization/cutlass_test/example/Makefile b/csrc/quantization/cutlass_test/example/Makefile new file mode 100644 index 0000000000000..7e5eac250d2e3 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/Makefile @@ -0,0 +1,68 @@ +# Copyright 2023 The FLash-LLM Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# host compiler +HOST_COMPILER ?= g++ +CUDA_PATH ?= /usr/local/cuda/ +#below is the path for Narval +#CUDA_PATH ?= /cvmfs/soft.computecanada.ca/easybuild/software/2020/Core/cudacore/11.7.0/ +# CUDA_PATH ?= /cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Core/cudacore/12.2.2/ +NVCC := /usr/local/cuda/bin/nvcc -ccbin $(HOST_COMPILER) + +# internal flags +NVCCFLAGS := -m$(shell getconf LONG_BIT) +CCFLAGS := -fPIC +LDFLAGS := + +ALL_CCFLAGS := +ALL_CCFLAGS += $(NVCCFLAGS) +ALL_CCFLAGS += $(addprefix -Xcompiler ,$(CCFLAGS)) + +ALL_LDFLAGS := +ALL_LDFLAGS += $(ALL_CCFLAGS) +ALL_LDFLAGS += $(addprefix -Xlinker ,$(LDFLAGS)) + +# Common includes and paths for CUDA +INCLUDES := -I/usr/local/cuda/include/ -I /home/ferrar/vllm/.deps/cutlass-src/include +LIBRARIES := -lcublas -lcusparse + +################################################################################ + +# Gencode arguments +SMS ?= 90 +# Generate SASS code for each SM architecture listed in $(SMS) +$(foreach sm,$(SMS),$(eval GENCODE_FLAGS += -gencode arch=compute_$(sm),code=sm_$(sm))) + +ALL_CCFLAGS += --threads 0 --std=c++11 -lineinfo -O3 + +FLASHLLM_CCFLAGS := -maxrregcount=255 +ALL_CCFLAGS += --use_fast_math +ALL_CCFLAGS += --ptxas-options=-v,-warn-lmem-usage,--warn-on-spills +################################################################################ + +HEAD_FILES = ./util/command_line.h \ + ./util/distribution.h \ + ./util/host_tensor.h \ + ./util/packed_stride.hpp \ + ./util/tensor_view_io.h \ + ./util/reference/device/gemm.h \ + ./util/reference/device/tensor_compare.h \ + ./util/reference/device/tensor_fill.h + + +# Target rules +all: example + +example: 62_hopper_sparse_gemm.cu $(HEAD_FILES) + $(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(OUR_CCFLAGS) $(GENCODE_FLAGS) $< -o $@ + +clean: + rm -f example \ No newline at end of file diff --git a/csrc/quantization/cutlass_test/example/util/command_line.h b/csrc/quantization/cutlass_test/example/util/command_line.h new file mode 100644 index 0000000000000..9dc3a1174067a --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/command_line.h @@ -0,0 +1,313 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * Utility for parsing command line arguments + */ + +#include +#include +#include +#include +#include + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass { + +/****************************************************************************** + * command_line + ******************************************************************************/ + +/** + * Utility for parsing command line arguments + */ +struct CommandLine { + std::vector keys; + std::vector values; + std::vector args; + + /** + * Constructor + */ + CommandLine(int argc, const char** argv) { + using namespace std; + + for (int i = 1; i < argc; i++) { + string arg = argv[i]; + + if ((arg[0] != '-') || (arg[1] != '-')) { + args.push_back(arg); + continue; + } + + string::size_type pos; + string key, val; + if ((pos = arg.find('=')) == string::npos) { + key = string(arg, 2, arg.length() - 2); + val = ""; + } else { + key = string(arg, 2, pos - 2); + val = string(arg, pos + 1, arg.length() - 1); + } + + keys.push_back(key); + values.push_back(val); + } + } + + /** + * Checks whether a flag "--" is present in the commandline + */ + bool check_cmd_line_flag(const char* arg_name) const { + using namespace std; + + for (int i = 0; i < int(keys.size()); ++i) { + if (keys[i] == string(arg_name)) return true; + } + return false; + } + + /** + * Returns number of naked (non-flag and non-key-value) commandline parameters + */ + size_t num_naked_args() const { + return args.size(); + } + + /** + * Print naked (non-flag and non-key-value) commandline parameters + */ + void print_naked_args(std::ostream &out) const { + for (auto arg : args) { + out << " " << arg <<"\n"; + } + } + + /** + * Returns the commandline parameter for a given index (not including flags) + */ + template + void get_cmd_line_argument(size_t index, value_t& val) const { + using namespace std; + if (index < args.size()) { + istringstream str_stream(args[index]); + str_stream >> val; + } + } + + /** + * Obtains the boolean value specified for a given commandline parameter --= + */ + void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const { + val = _default; + if (check_cmd_line_flag(arg_name)) { + std::string value; + get_cmd_line_argument(arg_name, value); + + val = !(value == "0" || value == "false"); + } + } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void get_cmd_line_argument(const char* arg_name, + value_t& val) const { + + get_cmd_line_argument(arg_name, val, val); + } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void get_cmd_line_argument(const char* arg_name, + value_t& val, + value_t const& _default) const { + using namespace std; + + val = _default; + + for (int i = 0; i < int(keys.size()); ++i) { + if (keys[i] == string(arg_name)) { + istringstream str_stream(values[i]); + 
str_stream >> val; + } + } + } + + /** + * Returns the values specified for a given commandline parameter --=,* + */ + template + void get_cmd_line_arguments(const char* arg_name, + std::vector& vals, + char sep = ',') const { + using namespace std; + + if (check_cmd_line_flag(arg_name)) { + // Clear any default values + vals.clear(); + + // Recover from multi-value string + for (size_t i = 0; i < keys.size(); ++i) { + if (keys[i] == string(arg_name)) { + string val_string(values[i]); + separate_string(val_string, vals, sep); + } + } + } + } + + /** + * Returns the values specified for a given commandline parameter + * --=,* + */ + void get_cmd_line_argument_pairs(const char* arg_name, + std::vector >& tokens, + char delim = ',', + char sep = ':') const { + if (check_cmd_line_flag(arg_name)) { + std::string value; + get_cmd_line_argument(arg_name, value); + + tokenize(tokens, value, delim, sep); + } + } + + /** + * Returns a list of ranges specified for a given commandline parameter + * --=,* + */ + void get_cmd_line_argument_ranges(const char* arg_name, + std::vector >& vals, + char delim = ',', + char sep = ':') const { + std::vector ranges; + get_cmd_line_arguments(arg_name, ranges, delim); + + for (std::vector::const_iterator range = ranges.begin(); + range != ranges.end(); ++range) { + + std::vector range_vals; + separate_string(*range, range_vals, sep); + vals.push_back(range_vals); + } + } + + /** + * The number of pairs parsed + */ + int parsed_argc() const { return (int)keys.size(); } + + //------------------------------------------------------------------------- + // Utility functions + //------------------------------------------------------------------------- + + /// Tokenizes a comma-delimited list of string pairs delimited by ':' + static void tokenize(std::vector >& tokens, + std::string const& str, + char delim = ',', + char sep = ':') { + // Home-built to avoid Boost dependency + size_t s_idx = 0; + size_t d_idx = std::string::npos; + while (s_idx < str.size()) { + d_idx = str.find_first_of(delim, s_idx); + + size_t end_idx = (d_idx != std::string::npos ? 
d_idx : str.size()); + size_t sep_idx = str.find_first_of(sep, s_idx); + size_t offset = 1; + if (sep_idx == std::string::npos || sep_idx >= end_idx) { + sep_idx = end_idx; + offset = 0; + } + + std::pair item( + str.substr(s_idx, sep_idx - s_idx), + str.substr(sep_idx + offset, end_idx - sep_idx - offset)); + + tokens.push_back(item); + s_idx = end_idx + 1; + } + } + + /// Tokenizes a comma-delimited list of string pairs delimited by ':' + static void tokenize(std::vector& tokens, + std::string const& str, + char delim = ',', + char sep = ':') { + typedef std::vector > TokenVector; + typedef TokenVector::const_iterator token_iterator; + + std::vector > token_pairs; + tokenize(token_pairs, str, delim, sep); + for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) { + tokens.push_back(tok->first); + } + } + + template + static void separate_string(std::string const& str, + std::vector& vals, + char sep = ',') { + std::istringstream str_stream(str); + std::string::size_type old_pos = 0; + std::string::size_type new_pos = 0; + + // Iterate -delimited values + value_t val; + while ((new_pos = str.find(sep, old_pos)) != std::string::npos) { + if (new_pos != old_pos) { + str_stream.width(new_pos - old_pos); + str_stream >> val; + vals.push_back(val); + } + + // skip over delimiter + str_stream.ignore(1); + old_pos = new_pos + 1; + } + + // Read last value + str_stream >> val; + vals.push_back(val); + } +}; + +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/distribution.h b/csrc/quantization/cutlass_test/example/util/distribution.h new file mode 100644 index 0000000000000..649a573603ff5 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/distribution.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief This header contains a class to parametrize a statistical distribution function. +*/ + +#include + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Distribution type +struct Distribution { + /// Variant types + enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes }; + + /// Distribution state + union { + /// Uniform distribution + struct { + double min; + double max; + // Percent elements set to NaN + double pnan; + } uniform; + + /// Gaussian distribution + struct { + double mean; + double stddev; + double pnz; + double pnzA; + double pnzB; + double pnzC; + } gaussian; + + /// Elements are linear combination of row and column index + struct { + double start; + double delta; + } sequential; + }; + + /// Active variant kind + Kind kind; + + /// Random values are cast to integer after scaling by this power of two + int int_scale; + + // + // Methods + // + + Distribution() : kind(Invalid), int_scale(0) {} + +/// Configures distribution as uniform random + Distribution &set_uniform(double _min, double _max, int _int_scale = 0, double _pnan = 0) { + kind = Uniform; + uniform.min = _min; + uniform.max = _max; + int_scale = _int_scale; + uniform.pnan = _pnan; + return *this; + } + + /// Configures distribution as Gaussian distribution + Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0, double _pnz = 1.0) { + kind = Gaussian; + gaussian.mean = _mean; + gaussian.stddev = _stddev; + gaussian.pnz = _pnz; + int_scale = _int_scale; + return *this; + } + + /// Sets identity + Distribution &set_identity() { + kind = Identity; + return *this; + } + + /// Sets sequential + Distribution &set_sequential(double start, double delta, int _int_scale = 0) { + kind = Sequential; + sequential.start = start; + sequential.delta = delta; + int_scale = _int_scale; + return *this; + } +}; + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints a Distribution to ostream +inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) { + switch (dist.kind) { + case cutlass::Distribution::Uniform: + out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max + << ", pnan: " << dist.uniform.pnan; + break; + case cutlass::Distribution::Gaussian: + out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev + << ", pnzA: " << dist.gaussian.pnzA << ", pnzB: " + << dist.gaussian.pnzB << ", pnzC: " << dist.gaussian.pnzC; + break; + case cutlass::Distribution::Identity: + out << "identity"; + break; + case cutlass::Distribution::Sequential: + out << "sequential"; + break; + default: + out << "unknown"; + } + + out << ", int_scale: " << dist.int_scale; + + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/gather_tensor.hpp b/csrc/quantization/cutlass_test/example/util/gather_tensor.hpp new file mode 100644 index 0000000000000..62616e00c7357 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/gather_tensor.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +namespace example { + +using namespace cute; + +// Empty type used to disable gather/scatter for a GEMM argument +struct NoGather +{ + template + NoGather(Ts...) 
{}; +}; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr + IndexedGather(Index const *indices = {}): indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr + Index + operator()(I i) const { return indices_[i]; } + + CUTE_HOST_DEVICE friend + void + print(IndexedGather const &s) { + cute::print("Indexed"); + } + + Index const *indices_; +}; + +/// Function object that applies a stride to its argument +/// Example: StridedFunc gathers every other row/column +template +struct StridedGather +{ + CUTE_HOST_DEVICE constexpr + StridedGather(Stride stride = {}): stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(I i) const { return i * stride_; } + + CUTE_HOST_DEVICE friend + void + print(StridedGather const &s) { + cute::print("Strided{"); + print(s.stride_); + cute::print("}"); + } + + Stride stride_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr + CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; } + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; } + + CUTE_HOST_DEVICE friend + void + print(CustomStride const & s) { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend + auto + safe_div(CustomStride const &s, Div const &div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend + auto + make_layout(Shape const &shape, CustomStride const &stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE +auto +make_custom_stride_layout(Stride const &stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit stride with a custom gather stride + auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout(repeat_like(stride, _1{}), + replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE +auto +make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func) +{ + if constexpr (not cutlass::platform::is_same, NoGather>::value) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); + } else { + return make_tensor(iter, shape, stride); + } +} + +} // namespace example + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); + } else { + return make_layout(shape, 
stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Offset,Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace example diff --git a/csrc/quantization/cutlass_test/example/util/helper.h b/csrc/quantization/cutlass_test/example/util/helper.h new file mode 100644 index 0000000000000..a7a81e7479022 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/helper.h @@ -0,0 +1,108 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cuda_runtime.h" +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream + */ +struct GpuTimer +{ + cudaStream_t _stream_id; + cudaEvent_t _start; + cudaEvent_t _stop; + + /// Constructor + GpuTimer() : _stream_id(0) + { + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); + } + + /// Destructor + ~GpuTimer() + { + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) + { + _stream_id = stream_id; + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); + } + + /// Stop the timer + void stop() + { + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() + { + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; + } +}; diff --git a/csrc/quantization/cutlass_test/example/util/host_tensor.h b/csrc/quantization/cutlass_test/example/util/host_tensor.h new file mode 100644 index 0000000000000..3f061875b48dc --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/host_tensor.h @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensor { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRef; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorView; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + +private: + using StorageUnit = typename platform::conditional_t, uint8_t, // Avoid the std::vector specialization + typename platform::conditional_t::value % 8 == 0, // Handle subbyte types + Element, uint8_t>>; + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; + static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits; + static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements; + static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes; + static constexpr int kContainerTypeNumStorageUnit = 
StorageContainerCalculator::kContainerTypeNumStorageUnit; + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + /// number of containers + size_t count_to_container_storage_unit_count(size_t count) { + return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit; + } + +public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensor() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensor( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensor( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensor() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")"); +#endif + + device_.reset(); + host_.clear(); + + size_t count_container = count_to_container_storage_unit_count(count); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")"); +#endif + host_.resize(count_container); + + // Allocate memory + StorageUnit* device_memory = nullptr; + if (device_backed_) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")"); +#endif + device_memory = device_memory::allocate(count_container); + } + device_.reset(device_memory, device_backed_ ? count_container : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. 
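+    // Unlike reset(), resize() reallocates only when the new container count
+    // exceeds the current host allocation; otherwise the existing host and
+    // device buffers are kept and simply reinterpreted with the new extent/layout.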
+ + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_))); + + if (static_cast(new_size_container) > host_.size()) { + reserve(new_size, device_backed_); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the logical number of elements stored in the host tensor + size_t size() const { + return layout_.capacity(extent_); + } + + /// Returns the logical capacity in terms of number of elements. May be larger than the size(). + LongIndex capacity() const { + return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements; + } + + /// Gets pointer to host data + Element * host_data() { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to device data + Element * device_data() { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data + Element const * device_data() const { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return 
TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? false : true; + } + + + /// Returns the layout object + Layout & layout() { + return layout_; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_.data(), device_.get(), device_.size()); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_.get(), host_.data(), host_.size()); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + host_.data(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
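+    // Actual behavior: copies from the caller-supplied device pointer into this
+    // tensor's device allocation (device_memory::copy_device_to_device); the
+    // host buffer is not touched.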
+ + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + device_.get(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + device_.get(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + host_.data(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + reinterpret_cast(ptr_host), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + reinterpret_cast(ptr_device), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_device( + Element * ptr_device, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + reinterpret_cast(ptr_device), host_.data(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
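+    // Actual behavior: copies this tensor's host data out to the caller-supplied
+    // host pointer (device_memory::copy_host_to_host); device memory is not involved.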
+ + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + reinterpret_cast(ptr_host), host_.data(), container_count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/packed_stride.hpp b/csrc/quantization/cutlass_test/example/util/packed_stride.hpp new file mode 100644 index 0000000000000..e9a243a1322cc --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/packed_stride.hpp @@ -0,0 +1,570 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params. +*/ + +#pragma once + +#include "cute/layout.hpp" +#include "cute/container/array.hpp" // cute::array +#include "cutlass/conv/convolution.h" // cutlass::conv::Operator + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides without batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride> +make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, int64_t> +make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, int64_t> +make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with group mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, StrideIntT, cute::Int<0>> +make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides for convolutions + +// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0) +// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order +// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout +// right in KTRSC order and can be coalesced to just k. +// We enforce this condition here with asserts. +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, cute::Int<0>> s, + cute::array shape_output, + cute::array stride_output, + cutlass::conv::Operator conv_op) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + static_assert(RankT_ >= 3u); + constexpr static int RankT = static_cast(RankT_); + + assert(stride_output[RankT-1] == 1); + cute::for_each(cute::make_seq{}, [&](auto i) { + assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]); + }); + + auto s_copy = s; + cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ? + stride_output[0] : + stride_output[RankT-2]; + return s_copy; +} + +// +// Activation tensor ((w, h, d, n), _1) for fprop kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nwc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_nwc[1]; + cute::get<0,1>(s_copy) = stride_nwc[0]; + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nhwc[3] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_nhwc[2-i]; + }); + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_ndhwc[3-i]; + }); + return s_copy; +} + +// +// Filter tensor (k, (_1, s, r, t)) for fprop kernel +// + +// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>> +make_cute_packed_stride( + cute::Stride, IntT>> s, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>> s, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>> s, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + +// +// Activation tensor (_1, (w, h, d, n)) for wgrad kernel +// +// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad +// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nwc[2] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::get<1,0>(s_copy) = stride_nwc[1]; + cute::get<1,1>(s_copy) = stride_nwc[0]; + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nwc in dgrad is ksc. + cute::get<1,0>(s_copy) = stride_nwc[0]; + cute::get<1,1>(s_copy) = stride_nwc[1]; + } + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad +// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nhwc[3] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_nhwc[2-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nhwc in dgrad is krsc. + cute::get<1,0>(s_copy) = stride_nhwc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_nhwc[i+1]; + }); + } + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_ndhwc[3-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_ndhwc in dgrad is ktrsc. + cute::get<1,0>(s_copy) = stride_ndhwc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1]; + }); + } + return s_copy; +} + +// +// NZPQ tensor (_1, nzpq) for wgrad kernel +// + +// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nqk[2] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nqk[1]; + return s_copy; +} + +// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_npqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_npqk[3] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_npqk[2]; + return s_copy; +} + +// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nzpqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nzpqk[4] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nzpqk[3]; + return s_copy; +} + + + +// +// Wgrad output tensor (k, (_1, s, r, t), _0) +// + +// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + + +// +// Wgrad output tensor ((_1, s, r, t), k, _0) +// + +// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ksc[0]; + cute::get<0,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<0,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/detail/inner_product.h b/csrc/quantization/cutlass_test/example/util/reference/detail/inner_product.h new file mode 100644 index 0000000000000..2bce60b1390c0 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/detail/inner_product.h @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +namespace cutlass { +namespace reference { +namespace detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template function to compute an inner product. 
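+/// The generic overload simply promotes both operands to Ctype and accumulates,
+/// i.e. inner_product(a, b, c) == Ctype(a) * Ctype(b) + c; for example,
+/// inner_product(int8_t(3), int8_t(-2), 0.0f) evaluates to -6.0f.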
+#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a + // host-only type +template +CUTLASS_HOST_DEVICE +Ctype inner_product(Atype a, Btype b, Ctype c) { + return Ctype(a) * Ctype(b) + c; +} + +/// Specialization for matrix multiplication with binary operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int bit = 0; bit < 32; bit++) { + accum += a[bit] ^ b[bit]; + } + return accum + c; +} + +/* +/// Specialization for matrix multiplication with signed 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} + +/// Specialization for matrix multiplication with unsigned 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} +*/ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Cast { + // Default behavior: convert to the destination type +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + static DstType apply(SrcType src) { return static_cast(src); }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static int8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. + return static_cast(fmaxf(-128.f, fminf(127.f, src))); + }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static uint8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. + return static_cast(fmaxf(0.f, fminf(255.f, src))); + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + diff --git a/csrc/quantization/cutlass_test/example/util/reference/detail/linear_to_coordinate.h b/csrc/quantization/cutlass_test/example/util/reference/detail/linear_to_coordinate.h new file mode 100644 index 0000000000000..1f784c46f6eb9 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/detail/linear_to_coordinate.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + + int64_t prod = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - Index; i < Rank; ++i) { + prod *= int64_t(extent[i]); + } + + coord[Rank - Index - 1] = int(idx / prod); + + int64_t residual = idx % prod; + LinearToCoordinateHelper()(coord, residual, extent); + } +}; + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &) const { + coord[Rank - 1] = int(idx); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinate { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + LinearToCoordinateHelper()(coord, idx, extent); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/convolution.h b/csrc/quantization/cutlass_test/example/util/reference/device/convolution.h new file mode 100644 index 0000000000000..c91cd0e229bdd --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/convolution.h @@ -0,0 +1,1549 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv2d device reference kernel +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, P, Q coordinates 
for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t npq = npq_start + m; + + thread_n[m] = int(npq / PQ); + + int64_t residual = npq % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + int c_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + + // Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Get group id of currnet channel + int c_group_idx = C / c_per_group; + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + int k_group_idx = thread_k / k_per_group; + + if (thread_k < problem_size.K && k_group_idx == c_group_idx) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> 
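// Note on the blocking above: each thread computes a kThreadM x kThreadN (2 x 4 by
// default) fragment of the implicit GEMM, so a 16 x 8 thread block covers a
// 32 x 32 tile of the (N*Z*P*Q) x K output. The host-side dispatcher later in this
// header overrides kThreadM to 4, widening that coverage to 64 x 32 per block.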
+__global__ void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_z[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, Z, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + int64_t ZPQ = PQ * problem_size.Z; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nzpq = nzpq_start + m; + + thread_n[m] = int(nzpq / ZPQ); + + int64_t residual = nzpq % ZPQ; + thread_z[m] = int(residual / PQ); + + residual = residual % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < problem_size.N && + d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + + if (thread_k < problem_size.K) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_z[m] < problem_size.Z && + thread_p[m] < problem_size.P && + thread_q[m] < problem_size.Q) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta 
!= ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } // for (n) + + } + } // for (m) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nhw = nhw_start + m; + + thread_n[m] = int(nhw / HW); + + int64_t residual = nhw % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K < problem_size.K; ++K) { + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K})); + } + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c})); + } + else { + element_B[n] = 
ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_d[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + int64_t DHW = HW * problem_size.D; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t ndhw = ndhw_start + m; + + thread_n[m] = int(ndhw / DHW); + + int64_t residual = ndhw % DHW; + thread_d[m] = int(residual / HW); + + residual = residual % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K < problem_size.K; ++K) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int z = thread_d[m] + problem_size.pad_d - 
filter_t * problem_size.dilation_d; + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (z >= 0 && !(z % problem_size.stride_d) && + p >= 0 && !(p % problem_size.stride_h) && + q >= 0 && !(q % problem_size.stride_w)) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K})); + } + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_d[m] < problem_size.D && + thread_h[m] < problem_size.H && + thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + + CUTLASS_PRAGMA_UNROLL + for (int n 
= 0; n < kThreadN; ++n) { + + int64_t rsc = rsc_start + n; + int64_t residual = rsc % SC; + + thread_r[n] = int(rsc / SC); + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]})); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator 
accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_t[kThreadN]; + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + int64_t RSC = SC * problem_size.R; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + int64_t trsc = trsc_start + n; + + thread_t[n] = int(trsc / RSC); + + int64_t residual = trsc % RSC; + thread_r[n] = int(residual / SC); + + residual = residual % SC; + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int Z = 0; Z < problem_size.Z; ++Z) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_t = thread_t[n]; + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - filter_t; + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int d = Z * problem_size.stride_d - problem_size.pad_w + filter_t * problem_size.dilation_d; + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W && + thread_c[n] < problem_size.C) { + + element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]})); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (Q) + } // for (P) + } // for (Z) + } // for (N) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_t[n] < problem_size.T && + thread_r[n] < problem_size.R && + thread_s[n] < problem_size.S && + thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + 
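// The kernels above are launched through the host-side dispatchers defined after this
// namespace. A minimal forward-pass call sketch (illustrative only; `problem`,
// `tensor_x`, `tensor_w` and `tensor_y` are hypothetical objects the caller has
// already built, all float NHWC):
//
//   cutlass::Status status = cutlass::reference::device::Conv2dFprop<
//       float, cutlass::layout::TensorNHWC,   // activations
//       float, cutlass::layout::TensorNHWC,   // filters
//       float, cutlass::layout::TensorNHWC,   // output
//       float>(                               // ElementCompute
//       problem, tensor_x, tensor_w, tensor_y, tensor_y, 1.0f, 0.0f);
//
// With beta == 0 the tensor passed as tensor_y_in is never read, so the output tensor
// can safely be supplied for both the y_in and y_out arguments.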
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Conv2d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; + int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q; + int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + 
problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W; + int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W; + int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 
0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C; + int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv2dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C; + int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv3dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN 
+ ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2d( + conv::Operator convolutional_operator, + conv::Conv2dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + return Conv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + case conv::Operator::kDgrad: + return Conv2dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + case conv::Operator::kWgrad: + return Conv2dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + default: break; + } + + return Status::kErrorNotSupported; +} + +/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad. 
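///
/// Like the 2D dispatcher above, this switches on conv::Operator and forwards to the
/// matching reference kernel. A call sketch (illustrative only; `problem` and the
/// TensorRef operands are hypothetical, float NDHWC):
///
///   auto status = cutlass::reference::device::Conv3d<
///       float, cutlass::layout::TensorNDHWC,
///       float, cutlass::layout::TensorNDHWC,
///       float, cutlass::layout::TensorNDHWC,
///       float>(
///       cutlass::conv::Operator::kFprop,
///       problem, tensor_A, tensor_B, tensor_C, tensor_D, 1.0f, 0.0f);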
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + return Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kDgrad: + return Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kWgrad: + return Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + default: break; + } + + return Status::kErrorNotSupported; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/gemm.h b/csrc/quantization/cutlass_test/example/util/reference/device/gemm.h new file mode 100644 index 0000000000000..1a1bd3751801a --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/gemm.h @@ -0,0 +1,385 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/kernel/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + // Launch a GEMM kernel + kernel::Gemm< + TensorRef, + TensorRef, + TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
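///
/// Call sketch (illustrative only; `problem`, `A`, `B` and `C` are hypothetical
/// TensorRef<float, cutlass::layout::ColumnMajor> operands built by the caller):
///
///   cutlass::reference::device::compute_gemm<
///       float, cutlass::layout::ColumnMajor,
///       float, cutlass::layout::ColumnMajor,
///       float, cutlass::layout::ColumnMajor,
///       float, float>(
///       problem, 1.0f, A, B, 0.0f, C, 0.0f);   // last argument is initial_accum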
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum) { + + compute_gemm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Gemm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, 
tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp, + typename ConvertOp +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + static_assert( + TensorRefCollectionA::kRank == 2 && + TensorRefCollectionB::kRank == 2 && + TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2"); + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn), + batch_count + ); + + // Launch a GEMM kernel + kernel::BatchedGemm< + TensorRefCollectionA, + TensorRefCollectionB, + TensorRefCollectionC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + initial_accum + ); +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. 
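// For the blocking used above (a 16 x 8 thread block, 4 x 4 outputs per thread), a
// batch of 8 GEMMs with M = N = 128 is launched as
//
//   dim3 grid((128 + 63) / 64, (128 + 31) / 32, 8);   // = (2, 4, 8)
//
// and blockIdx.z selects which TensorRef of each collection a given block reads.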
+// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/gemm_complex.h b/csrc/quantization/cutlass_test/example/util/reference/device/gemm_complex.h new file mode 100644 index 0000000000000..b4d41bd28efb5 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/gemm_complex.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
+/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); + tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); + tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); + tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); + + } // for (batch_idx) +} + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a 
general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + if (grid.y <= std::numeric_limits::max()) { + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } else { + // Using bigger thread tile size + int const kBigMblock = 4; + int const kBigNblock = 16; + + dim3 Bigblock(16, 8); + dim3 Biggrid( + (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock), + (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kBigMblock, + kBigNblock + ><<< Biggrid, Bigblock >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
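+
+// A minimal host-side usage sketch of the GemmComplex launcher above. The
+// d_A/d_B/d_C/d_D device pointers and the leading dimensions are assumed to
+// exist; ScalarType and ComputeType are both taken to be cutlass::complex<float>.
+//
+//   using Element = cutlass::complex<float>;
+//   using Layout  = cutlass::layout::ColumnMajor;
+//
+//   cutlass::gemm::GemmCoord problem(M, N, K);
+//
+//   cutlass::TensorRef<Element, Layout> ref_A(d_A, Layout(lda));
+//   cutlass::TensorRef<Element, Layout> ref_B(d_B, Layout(ldb));
+//   cutlass::TensorRef<Element, Layout> ref_C(d_C, Layout(ldc));
+//   cutlass::TensorRef<Element, Layout> ref_D(d_D, Layout(ldd));
+//
+//   cutlass::reference::device::GemmComplex(
+//       problem, Element(1.0f),
+//       ref_A, cutlass::ComplexTransform::kNone,
+//       ref_B, cutlass::ComplexTransform::kNone,
+//       Element(0.0f), ref_C, ref_D,
+//       Element(0.0f));   // explicit initial_accum lets every type be deduced
+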
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/gemm_planar_complex.h b/csrc/quantization/cutlass_test/example/util/reference/device/gemm_planar_complex.h new file mode 100644 index 0000000000000..37c103c3fcb45 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/gemm_planar_complex.h @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device code. 
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static int const kGemmPlanarComplexBlockSize = 4; + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +__global__ void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + int const kMblock = kGemmPlanarComplexBlockSize; + int const kNblock = kGemmPlanarComplexBlockSize; + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + complex accum[kMblock][kNblock]; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block < K; ++k_block) { + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC c_ij = ComplexC(); + + if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { + c_ij = tensor_c.at(coord); + } + + complex src{ + ScalarType(c_ij.real()), + ScalarType(c_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + 
ComplexC d_ij; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = kernel::kGemmPlanarComplexBlockSize; + int const kNblock = kernel::kGemmPlanarComplexBlockSize; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + 1); + + kernel::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
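+
+// Storage convention assumed by the planar-complex reference above (names are
+// illustrative): rather than interleaving (real, imag) pairs, each operand
+// keeps all real parts in one plane and all imaginary parts in a second plane
+// offset by a fixed element count, so for a column-major M x N operand with
+// leading dimension ld:
+//
+//   real(i, j) = base[i + j * ld]
+//   imag(i, j) = base[i + j * ld + imaginary_stride]
+//
+// TensorRefPlanarComplex::at() reassembles a cutlass::complex<> from the two
+// planes, which is what the kernel reads before its conjugation and
+// multiply-add steps.
+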
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/gett.hpp b/csrc/quantization/cutlass_test/example/util/reference/device/gett.hpp new file mode 100644 index 0000000000000..78586ad62dc18 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/gett.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief GETT device reference code +*/ +#pragma once + +#include + +namespace cutlass::reference::device { + +template < + class ATensor, + class BTensor, + class CTensor, + class DTensor, + class ElementAccumulator, + class ElementEpilogue> +__global__ static +void +gett_kernel( + DTensor D, + ATensor const A, + BTensor const B, + CTensor const C, + ElementEpilogue alpha, ElementEpilogue beta, + ElementAccumulator acc_init) +{ + using namespace cute; + + static_assert(DTensor::rank == 3, "(M,N,L)"); + static_assert(ATensor::rank == 3, "(M,K,L)"); + static_assert(BTensor::rank == 3, "(N,K,L)"); + static_assert(CTensor::rank == 3, "(M,N,L)"); + + assert(size<0>(A) == size<0>(D)); // M + assert(size<0>(C) == size<0>(D)); // M + assert(size<0>(B) == size<1>(D)); // N + assert(size<1>(C) == size<1>(D)); // N + assert(size<1>(A) == size<1>(B)); // K + assert(size<2>(A) == size<2>(D)); // L + assert(size<2>(B) == size<2>(D)); // L + assert(size<2>(C) == size<2>(D)); // L + + NumericConverter a_converter; + NumericConverter b_converter; + NumericConverter acc_converter; + NumericConverter source_converter; + NumericConverter output_converter; + + // Thread id to each element of D + for (int tid = threadIdx.x + blockDim.x * blockIdx.x; + tid < size(D); + tid += blockDim.x * gridDim.x) { + // (m,n,l) coordinate + auto mnl_coord = idx2crd(tid, product_each(shape(D))); + auto m = get<0>(mnl_coord); + auto n = get<1>(mnl_coord); + auto l = get<2>(mnl_coord); + + auto A_ml = A(m,_,l); + auto B_nl = B(n,_,l); + + ElementAccumulator accum = ElementAccumulator(0); + for (int k = 0; k < size<1>(A); ++k) { + ElementAccumulator a = a_converter(A_ml(k)); + ElementAccumulator b = b_converter(B_nl(k)); + accum += a * b; + } + + ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l))); + D(m,n,l) = output_converter(scaled_output); + } +} + +// Most general version +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +void +gett( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4); + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto K = get<2>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full tensors + auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L) + auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L) + auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L) + auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L) + + dim3 dimBlock(256); + dim3 dimGrid(240); + gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0)); +} + +} // namespace cutlass::reference::device diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/kernel/gemm.h b/csrc/quantization/cutlass_test/example/util/reference/device/kernel/gemm.h new file mode 100644 index 0000000000000..f7731213013d5 --- /dev/null +++ 
b/csrc/quantization/cutlass_test/example/util/reference/device/kernel/gemm.h @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/thread/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
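+
+// How the reference kernel below divides the work (summary; the 4 x 4 tile
+// matches the OutputTile used by the launchers in device/gemm.h): each thread
+// owns one OutputTile::kRow x OutputTile::kColumn patch of the output at
+//
+//   row0 = (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow
+//   col0 = (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn
+//
+// accumulates the full K extent for that patch through thread::Gemm, then
+// applies the epilogue D = alpha * accum + beta * C element-wise, with
+// out-of-range rows and columns masked inside thread::Gemm.
+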
+template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void Gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefA tensor_a, + TensorRefB tensor_b, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + AccumulatorType initial_accum) { + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), + MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) + ); + + // Compute the general matrix product + thread::Gemm< + TensorRefA, + TensorRefB, + TensorRefC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void BatchedGemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefCollectionA tensor_collection_a, + TensorRefCollectionB tensor_collection_b, + ScalarType beta, + TensorRefCollectionC tensor_collection_c, + AccumulatorType initial_accum) { + + // Obtain batch ID + int batch_id = blockIdx.z; + + // Dereference based on batch_id + typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id); + typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id); + typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id); + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn, + (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow + ); + + // Compute the general matrix product + thread::Gemm< + typename TensorRefCollectionA::TensorRef, + typename TensorRefCollectionB::TensorRef, + typename TensorRefCollectionC::TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/kernel/tensor_elementwise.h b/csrc/quantization/cutlass_test/example/util/reference/device/kernel/tensor_elementwise.h new file mode 100644 index 0000000000000..c703f07f78a24 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/kernel/tensor_elementwise.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform random distribution +template +__global__ void TensorInitializeUniform( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + double range = dist.uniform.max - dist.uniform.min; + + double rnd = curand_uniform(&rng_state[threadIdx.x]); + + rnd = dist.uniform.min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + + tensor += ldm; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform distribution +template +__global__ void TensorInitializeGaussian( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, 
&rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + double rnd = curand_normal(&rng_state[threadIdx.x]); + + rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd; + + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + } + } +} + +/// Kernel to initialize tensor to an identity matrix +template +__global__ void TensorInitializeLinear( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = + dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx; + } + } +} + +/// Kernel to initialize tensor to an identity matrix +template +__global__ void TensorInitializeIdentity( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = (c_idx == s_idx ? T(1) : T(0)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/kernel/tensor_foreach.h b/csrc/quantization/cutlass_test/example/util/reference/device/kernel/tensor_foreach.h new file mode 100644 index 0000000000000..a64a419d8a193 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/kernel/tensor_foreach.h @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/fast_math.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for general rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + int64_t product = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - RankRemaining; i < Rank; ++i) { + product *= size[i]; + } + + coord[Rank - 1 - RankRemaining] = index / product; + int64_t remaining = index % product; + + TensorForEachHelper(func, size, coord, remaining); + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for fastest changing rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + coord[Rank - 1] = index; + + if (coord < size) { + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element in a tensor's index space +template +__global__ void TensorForEach(Coord size, Params params = Params()) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + int64_t max_index = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + max_index *= size[i]; + } + + CUTLASS_PRAGMA_NO_UNROLL + while (index < max_index) { + Coord coord; + + detail::TensorForEachHelper(func, size, coord, index); + index += blockDim.x * gridDim.x; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element along a tensor's diagonal +template +__global__ void TensorDiagonalForEach(Coord size, Params params, int start, int end) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start; + + if (index < end) { + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] = index; + } + + func(coord); + } +} + 
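+
+// Sketch of the index -> coordinate decomposition performed by
+// TensorForEachHelper above (host-side pseudocode with illustrative extents):
+// for a rank-3 tensor of extent (S0, S1, S2), a linear index decomposes as
+//
+//   coord[0] = index / (S1 * S2);   remaining = index % (S1 * S2);
+//   coord[1] = remaining / S2;      coord[2]  = remaining % S2;
+//
+// which is what the recursive helper computes one rank at a time, with the
+// fastest-changing rank handled by its terminal specialization. TensorForEach
+// then strides the grid over the flattened extent, so any launch size covers
+// the whole index space.
+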
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params) { + + Func func(params); + + size_t index = threadIdx.x + blockIdx.x * blockDim.x; + + for (; index < capacity; index += blockDim.x * gridDim.x) { + ReferenceFactory::get(ptr, index) = func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass + diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/rank_2k_complex.h b/csrc/quantization/cutlass_test/example/util/reference/device/rank_2k_complex.h new file mode 100644 index 0000000000000..d5892457ca942 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/rank_2k_complex.h @@ -0,0 +1,355 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
+/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + assert(M=N); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x B^T (Symmetric) or A x B^H (Hermitian) + // complex conjugation on operandB (b_t) is function of blas3 computation + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_b.at(MatrixCoord(col, k_block))) : + tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(b_t); + + // complex conjugation is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + + // B x A^T (Symmetric) or B x A^H (Hermitian) + // complex conjugation on operandB (a_t) is function of blas3 computation + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? 
+                            conj(tensor_a.at(MatrixCoord(col, k_block))) :
+                            tensor_a.at(MatrixCoord(col, k_block));
+
+            ComputeType b_ik = ComputeType(b);
+            ComputeType a_jk = ComputeType(a_t);
+
+            // complex conjugation here is a function of operand layouts
+            if (transform_b == ComplexTransform::kConjugate) {
+              b_ik = conj(b_ik);
+            }
+            // complex conjugation here is a function of operand layouts
+            if (transform_a == ComplexTransform::kConjugate) {
+              a_jk = conj(a_jk);
+            }
+
+            accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]);
+          }
+        }
+      }
+    }
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int j = 0; j < kNblock; j++) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kMblock; i++) {
+        int row = row_block + i;
+        int col = col_block + j;
+
+        MatrixCoord coord = MatrixCoord(row, col);
+
+        if (row < M && col < N &&
+            ((fill_mode_c == FillMode::kLower && row >= col) ||
+             (fill_mode_c == FillMode::kUpper && row <= col))
+           ) {
+
+          ScalarType c = tensor_c.at(coord);
+          // The imaginary parts of the diagonal elements of
+          // a complex data type are assumed and set to zero
+          if (blas_mode == BlasMode::kHermitian) {
+            c = (row == col) ? real(c) : c;
+          }
+
+          tensor_d.at(coord) = convert_op(
+            alpha * ScalarType(accum[i][j]) +
+            beta * c);
+        }
+      }
+    }
+
+    tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
+    tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
+    tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
+    tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
+
+  } // for (batch_idx)
+}
+
+} // namespace kernel
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
+/// objects.
+///
+/// Explicitly naming types needed by this template can be cumbersome, particularly for the
+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
+/// AccumulatorType(0) as the last function argument can be easier than naming all template
+/// arguments explicitly.
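+
+// What kernel::Rank2KComplex above computes, restated for reference (notation
+// is illustrative): a rank-2k update of an N x N symmetric or Hermitian matrix,
+//
+//   symmetric:  D = alpha * (A * B^T + B * A^T) + beta * C
+//   Hermitian:  D = alpha * (A * B^H + B * A^H) + beta * C
+//
+// where only one triangle of D is written, as selected by fill_mode_c:
+//
+//   write(row, col) = (fill_mode_c == FillMode::kLower && row >= col) ||
+//                     (fill_mode_c == FillMode::kUpper && row <= col)
+//
+// In the Hermitian case the diagonal of C is treated as real, matching the
+// real(c) adjustment in the epilogue.
+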
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::Rank2KComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + fill_mode_c, + blas_mode, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/tensor_compare.h b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_compare.h new file mode 100644 index 0000000000000..e6b36990f0f1a --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_compare.h @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" + +#include "cutlass/util/distribution.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template +__global__ void BlockCompareEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (a != b) { + *equal = 0; + + return; + } + } +} + +template +__global__ void BlockCompareRelativelyEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (!relatively_equal(a, b, epsilon, nonzero_floor)) { + *equal = 0; + return; + } + } +} + +} // namespace kernel + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + int grid_size = 0, + int block_size = 0) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if (!grid_size || !block_size) { + + // if grid_size or 
block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareEqual<<< grid, block >>>(device_equal_flag, ptr_A, ptr_B, capacity); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareRelativelyEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor, + int grid_size = 0, + int block_size = 0) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareRelativelyEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareRelativelyEqual<<< grid, block >>>( + device_equal_flag, + ptr_A, + ptr_B, + capacity, + epsilon, + nonzero_floor + ); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // reference +} // cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/tensor_fill.h b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_fill.h new file mode 100644 index 0000000000000..13aedf14d113f --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_fill.h @@ -0,0 +1,2077 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. 
+*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) + +// Standard Library includes +#include +#include +#include +#include +#include + +#endif + +// CUDA includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_view.h" +#include "cutlass/blas3.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/layout/vector.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" +#include "cutlass/util/distribution.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +CUTLASS_DEVICE +FloatType random_normal_float(curandState_t *state) { + return curand_normal(state); +} + +template <> +CUTLASS_DEVICE +double random_normal_float(curandState_t *state) { + return curand_normal_double(state); +} + +template +CUTLASS_DEVICE +FloatType random_uniform_float(curandState_t *state) { + return curand_uniform(state); +} + +template <> +CUTLASS_DEVICE +double random_uniform_float(curandState_t *state) { + return curand_uniform_double(state); +} + +template +struct RandomGaussianFunc { + + using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + uint64_t seed_ = 0, + Element mean_ = 0, + Element stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd = random_normal_float(&rng_state); + rnd = params.mean + params.stddev * rnd; + + Element result; + if (params.int_scale >= 0) { + rnd = FloatType(IntType(std::llround(rnd * params.float_scale_up))); + result = Element(IntType(rnd * params.float_scale_down)); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd += FloatType(1); + } else { + rnd -= FloatType(1); + } + result = Element(rnd); + } + + return result; + } +}; + + +template +struct RandomGaussianFunc> { + + using Element = complex; + using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + uint64_t seed_ = 0, + Real mean_ = 0, + Real stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_up += FloatType(0.5) * float_scale_up; + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd_r = random_normal_float(&rng_state); + FloatType rnd_i = random_normal_float(&rng_state); + rnd_r = params.mean + params.stddev * rnd_r; + rnd_i = params.mean + params.stddev * rnd_i; + + Element result; + if (params.int_scale >= 0) { + rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); + rnd_i = FloatType(IntType(rnd_i * params.float_scale_down)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r += FloatType(1); + } else { + rnd_r -= FloatType(1); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomGaussianFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomGaussianFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = typename RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomGaussianFunc(Params const ¶ms): params(params), random(params.random) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. 
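// [Editor's note: not part of the original patch] A minimal usage sketch for the
// TensorFillRandomGaussian helper declared below. It assumes the caller includes
// "cutlass/util/host_tensor.h"; the shape, seed, and distribution parameters are
// illustrative only.
//
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor({128, 256});
//   // Fill the device-side view with N(0, 1) samples, then copy back to the host.
//   TensorFillRandomGaussian(tensor.device_view(), /*seed=*/2024ULL,
//                            /*mean=*/0.0f, /*stddev=*/1.0f);
//   tensor.sync_host();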
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type mean = Element(0), ///< Gaussian distribution's mean + typename RealType::Type stddev = Element(1), ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + using Func = detail::TensorFillRandomGaussianFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template ///< Element type +void BlockFillRandomGaussian( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type mean, ///< Gaussian distribution's mean + typename RealType::Type stddev, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + + typename RandomFunc::Params params(seed, mean, stddev, bits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random uniform distribution +template ///< Element type +struct RandomUniformFunc { + + using FloatType = typename std::conditional< + (sizeof(Element) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Element) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType max; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Element max_ = 1, + Element min = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max_) - static_cast(min)), + max(static_cast(max_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + range = (min == Element(0)) ? range - FloatType(1): range; + max = (max_ == Element(0)) ? 
max - FloatType(1): max; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(NAN); + } + } + + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.max - params.range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd = FloatType(IntType(std::llround(rnd * params.float_scale_up))); + result = Element(IntType(rnd * params.float_scale_down)); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd = std::min(params.max, rnd + FloatType(1)); + } else { + rnd = std::max((params.max - params.range), rnd - FloatType(1)); + } + result = Element(rnd); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template +struct RandomUniformFunc> { + + using Element = complex; + + using FloatType = typename std::conditional< + (sizeof(Real) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Real) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType min; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + FloatType max = 1, + FloatType min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max - min_)), + min(static_cast(min_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_up += FloatType(0.5) * float_scale_up; + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + min = (min == FloatType(0)) ? min + FloatType(1): min; + range = (max == FloatType(0)) ? range - FloatType(1): range; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. 
+ if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(Real(NAN), Real(NAN)); + } + } + + FloatType rnd_r = random_uniform_float(&rng_state); + FloatType rnd_i = random_uniform_float(&rng_state); + + rnd_r = params.min + params.range * rnd_r; + rnd_i = params.min + params.range * rnd_i; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); + rnd_i = FloatType(IntType(rnd_i * params.float_scale_up)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1)); + } else { + rnd_r = std::max((params.min), rnd_r - FloatType(1)); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomUniformFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type max = Element(1), ///< upper bound of distribution + typename RealType::Type min = Element(0), ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. 
+ int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type max, ///< upper bound of distribution + typename RealType::Type min, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + + typename RandomFunc::Params params(seed, max, min, bits, pnan); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random sparse meta +template ///< Element type +struct RandomSparseMetaFunc { + + using FloatType = float; + + using IntType = int32_t; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + int MetaSizeInBits; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), + MetaSizeInBits(MetaSizeInBits_) { + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomSparseMetaFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element *MetaArray = + (params.MetaSizeInBits == 2) ? 
FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.range * rnd; + Element meta = MetaArray[(int)rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomSparseMetaFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, MetaSizeInBits); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + + typename RandomFunc::Params params(seed, MetaSizeInBits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal. 
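// [Editor's note: not part of the original patch] Before the diagonal-fill functor
// below, a brief note on the sparse-metadata helpers above: RandomSparseMetaFunc packs
// one 4-bit nibble per group of four source elements, drawn from
// {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}, i.e. the six valid index pairs of a 2:4 pattern.
// A minimal sketch of BlockFillRandomSparseMeta, assuming "cutlass/util/device_memory.h"
// is available and that the metadata element type is uint16_t (a common choice for
// Ampere sparse GEMM; illustrative only):
//
//   size_t meta_elements = 4096;
//   cutlass::device_memory::allocation<uint16_t> meta(meta_elements);
//   // Each uint16_t receives four random nibbles (2-bit metadata, MetaSizeInBits == 2).
//   BlockFillRandomSparseMeta(meta.get(), meta_elements, /*seed=*/2024ULL,
//                             /*MetaSizeInBits=*/2);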
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Updates the tensor + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + params.view.at(coord) = (is_diag ? params.diag : params.other); + } +}; + +// Overwrites the elements of a tensor with a uniform value depending on fill mode +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element element; + FillMode fill_mode; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params(): fill_mode(FillMode::kNone) { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, + Element element_, + FillMode fill_mode_ + ): + view(view_), element(element_), fill_mode(fill_mode_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorFillPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. 
+ CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + case FillMode::kFull: + predicate = true; + break; + + case FillMode::kLower: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] < coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kUpper: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] > coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kDiagonal: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] != coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorClearPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// + static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices"); + + /// Parameters structure + struct Params { + TensorView view{}; + Element element{}; + FillMode fill_mode{FillMode::kNone}; + int alignment{0}; + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorClearPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + + case FillMode::kLower: + if ((coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kUpper: + if ((coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0), ///< value to write off the diagonal + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are +/// not written. 
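// [Editor's note: not part of the original patch] A minimal usage sketch for the
// partial-fill helper declared below, assuming the caller includes
// "cutlass/util/host_tensor.h" (shape and fill value are illustrative only):
//
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor({64, 64});
//   // Overwrite only the lower-triangular part (including the diagonal) with 1.0f;
//   // elements outside FillMode::kLower are left untouched.
//   TensorFillPartial(tensor.device_view(), 1.0f, FillMode::kLower);
//   tensor.sync_host();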
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPartial( + TensorView view, ///< destination tensor + Element element, + FillMode fill_mode, + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillPartialFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, element, fill_mode), + stream + ); +} + +/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side +/// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros) +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorClearPartial( + TensorView view, ///< destination tensor + Element element, + FillMode fill_mode, + int alignment, + cudaStream_t stream = nullptr) { + + typedef detail::TensorClearPartialFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params{view, element, fill_mode, alignment}, + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorView view, ///< destination tensor + Element val = Element(0), ///< value to uniformly fill it with + cudaStream_t stream = nullptr) { + + TensorFillDiagonal(view, val, val, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor's diagonal with 1 and 0 everywhere else. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView view, ///< destination tensor + cudaStream_t stream = nullptr) { + + TensorFillDiagonal(view, Element(1), Element(0), stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1) + ): + view(view_), diag(diag_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorUpdateDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (is_diag) { + params.view.at(coord) = params.diag; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorUpdateOffDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + params.view.at(coord) = params.other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. 
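// [Editor's note: not part of the original patch] A minimal sketch contrasting the two
// update helpers: TensorUpdateDiagonal (above) touches only the diagonal, while
// TensorUpdateOffDiagonal (below) touches everything except the diagonal. Assumes the
// caller includes "cutlass/util/host_tensor.h"; values are illustrative only.
//
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor({64, 64});
//   TensorUpdateDiagonal(tensor.device_view(), 1.0f);     // diagonal := 1, rest unchanged
//   TensorUpdateOffDiagonal(tensor.device_view(), 0.0f);  // rest := 0, diagonal unchanged
//   tensor.sync_host();  // the result is now the 64x64 identity matrix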
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView view, ///< destination tensor + Element other = Element(1), + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateOffDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Array const & v_, + Element s_ = Element(0) + ): + view(view_), v(v_), s(s_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillLinearFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element sum = params.s; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + if constexpr (is_complex::value) { + if constexpr (sizeof_bits::value <= 32) { + sum = Element(static_cast>(sum) + + static_cast>(params.v[i]) * static_cast>(coord[i])); + } + } + else if constexpr (sizeof_bits::value <= 32) { + if constexpr (std::numeric_limits::is_integer) { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + else { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + } + else { + sum += params.v[i] * coord[i]; + } + } + + params.view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView view, ///< destination tensor + Array const & v, + Element s = Element(0), + cudaStream_t stream = nullptr) { + + using Func = detail::TensorFillLinearFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, v, s), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values from a distribution. 
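// [Editor's note: not part of the original patch] A minimal sketch for the
// distribution-dispatching fill declared below, assuming cutlass::Distribution from
// "cutlass/util/distribution.h" and "cutlass/util/host_tensor.h" are available
// (bounds and seed are illustrative only):
//
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor({128, 256});
//   cutlass::Distribution dist;
//   dist.set_uniform(/*min=*/-4.0, /*max=*/4.0);
//   dist.int_scale = -1;  // -1: do not truncate fractional bits (see the helpers above)
//   // Distribution::Uniform dispatches to TensorFillRandomUniform;
//   // Distribution::Gaussian dispatches to TensorFillRandomGaussian.
//   TensorFillRandom(tensor.device_view(), /*seed=*/2024ULL, dist);
//   tensor.sync_host();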
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr, + int exclude_zero = -1 ///< If non-negative, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. + ) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + exclude_zero, + stream); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + exclude_zero, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + + using Layout = layout::PackedVectorLayout; + Layout::TensorCoord size(static_cast(capacity)); // -Wconversion + Layout layout = Layout::packed(size); + TensorView view(ptr, layout, size); + + Array c{}; + c[0] = v; + + TensorFillLinear(view, c, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + stream); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalInFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element const *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + TensorView view_, ///< destination tensor + Element const *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalInFunc(Params const ¶ms): params(params) { + + } + + /// Only update the diagonal element + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.view.at(coord) = params.ptr[coord[0]]; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView view, ///< destination tensor + Element const *ptr, ///< dense buffer of elements + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalInFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalOutFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Element *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalOutFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.ptr[coord[0]] = params.view.at(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. 
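// [Editor's note: not part of the original patch] A minimal sketch round-tripping a
// diagonal with the copy helpers (TensorCopyDiagonalIn above, TensorCopyDiagonalOut
// below). Although the doc comments mention host memory, the functors dereference the
// raw pointer from device code, so a device-visible buffer is used here. Assumes the
// caller includes "cutlass/util/host_tensor.h" and "cutlass/util/device_memory.h".
//
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor({64, 64});
//   cutlass::device_memory::allocation<float> diag(64);
//   // Extract the 64 diagonal elements of `tensor` into `diag` ...
//   TensorCopyDiagonalOut(diag.get(), tensor.device_view());
//   // ... and write them back without modifying off-diagonal elements.
//   TensorCopyDiagonalIn(tensor.device_view(), diag.get());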
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView view, ///< source tensor + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalOutFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/tensor_foreach.h b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_foreach.h new file mode 100644 index 0000000000000..3911b0240c6d2 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_foreach.h @@ -0,0 +1,144 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/reference/device/kernel/tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element in a tensor's index space. +template +struct TensorForEach { + + /// Constructor performs the operation. 
+ TensorForEach( + Coord size, Params params = Params(), + int grid_size = 0, int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::TensorForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::TensorForEach<<< grid, block, 0, stream >>>(size, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element along a tensor's diagonal +template +struct TensorDiagonalForEach { + + /// Constructor performs the operation + TensorDiagonalForEach( + Coord size, Params params = Params(), + int start = 0, int end = -1, + int block_size = 128, cudaStream_t stream = nullptr) { + + if (end < 0) { + end = size.min(); + } + + dim3 block(block_size, 1, 1); + dim3 grid((end - start + block_size - 1) / block_size, 1, 1); + + kernel::TensorDiagonalForEach<<< grid, block, 0, stream >>>( + size, params, start, end); + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. + BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params(), + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockForEach<<< grid, block, 0, stream >>>(ptr, capacity, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/tensor_reduce.h b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_reduce.h new file mode 100644 index 0000000000000..47b898b4fd161 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_reduce.h @@ -0,0 +1,510 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/detail/linear_to_coordinate.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t size = view.size(); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + + // Fetch element + Element x = view.at(coord); + + // Transform + identity = reduce(identity, transform(x)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. 
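(Aside, not part of the patch: the serial per-block finalization below could be replaced by the tree reduction the comment mentions. A hedged sketch, assuming every thread has already stored its partial value into scratchpad and passed the barrier, and that kBlockSize is a power of two, as with the 128-thread launch used here.)

// Shared-memory tree reduction: halve the number of active threads each step;
// each surviving thread folds in the value held by its partner `offset` away.
template <typename ComputeType, typename ReduceOp, int kBlockSize>
__device__ ComputeType block_tree_reduce(ComputeType (&scratchpad)[kBlockSize],
                                         ReduceOp reduce) {
  for (int offset = kBlockSize / 2; offset > 0; offset >>= 1) {
    if (threadIdx.x < offset) {
      scratchpad[threadIdx.x] =
          reduce(scratchpad[threadIdx.x], scratchpad[threadIdx.x + offset]);
    }
    __syncthreads();
  }
  // After the final barrier every thread can read the reduced value.
  return scratchpad[0];
}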
+ if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + auto size = static_cast(view_A.size()); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + + // Fetch element + Element a = view_A.at(coord); + Element b = view_B.at(coord); + + // Transform + identity = reduce(identity, transform(a, b)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + + +template < + typename ComputeType, + typename ReduceOp, + int kBlockSize = 32 +> +__global__ void TensorTransformReduceFinalize( + ComputeType *workspace, + ComputeType identity, + int workspace_size, + ReduceOp reduce) { + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) { + identity = reduce(identity, workspace[idx]); + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[0] = identity; + } +} + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. 
+) { + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of two tensors, zipped together +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. +) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Extents must be equal."); + } + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view_A, view_B, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. 
+ if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. + if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view_A, + view_B, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the norm of the elements of a tensor. 
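(Aside, not part of the patch: a usage sketch for the reduction helpers defined above. The include paths follow the cutlass/util layout these headers use internally; the element type, layout, and function name are illustrative assumptions.)

#include <cmath>
#include <utility>
#include "cutlass/layout/matrix.h"
#include "cutlass/util/reference/device/tensor_reduce.h"

// Mean and RMS of a device-resident tensor, accumulating in double; both
// helpers size their own workspace from the SM count as shown above.
std::pair<double, double> mean_and_rms(
    cutlass::TensorView<float, cutlass::layout::RowMajor> view) {
  double n = static_cast<double>(view.size());
  double sum = cutlass::reference::device::TensorSum<
      float, cutlass::layout::RowMajor, double>(view);
  double sum_sq = cutlass::reference::device::TensorSumSq<
      float, cutlass::layout::RowMajor, double>(view);
  return {sum / n, std::sqrt(sum_sq / n)};
}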
+template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSq(view, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform, stream, workspace_size); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/tensor_relu.h b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_relu.h new file mode 100644 index 0000000000000..4e5a50403cf8d --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/tensor_relu.h @@ -0,0 +1,141 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorReLuFunc { + + /// View type + using TensorView = TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element threshold; + + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element threshold_ = Element(0) + ): + view(view_), threshold(threshold_) { + + } + }; + + // + // Data members + // + + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorReLuFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element const & value = params.view.at(coord); + params.view.at(coord) = (value < params.threshold) ? 
params.threshold : value; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Apply ReLu on a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorReLu( + TensorView view, ///< destination tensor + Element threshold = Element(0)) { ///< ReLu threshold + + using Func = detail::TensorReLuFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, threshold) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/device/thread/gemm.h b/csrc/quantization/cutlass_test/example/util/reference/device/thread/gemm.h new file mode 100644 index 0000000000000..04775a746ad16 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/device/thread/gemm.h @@ -0,0 +1,186 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace thread { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level blocked general matrix product. +// +// Note, this is a reference implementation. Performance is not expected to approach peak. 
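(Aside, not part of the patch: the arithmetic that the thread-level reference below applies per output tile, written out as a plain host loop for clarity; the matrix sizes, row-major storage, and function name are illustrative assumptions.)

#include <vector>

// Reference GEMM: accumulate the inner product over K, then apply the
// linear-scaling epilogue D = alpha * accum + beta * C, as the device-side
// multiply_add() and epilogue() members do per tile.
void reference_gemm(int M, int N, int K, float alpha,
                    const std::vector<float>& A,  // M x K, row-major
                    const std::vector<float>& B,  // K x N, row-major
                    float beta,
                    const std::vector<float>& C,  // M x N, row-major
                    std::vector<float>& D) {      // M x N, row-major
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float accum = 0.0f;
      for (int k = 0; k < K; ++k) {
        accum += A[i * K + k] * B[k * N + j];
      }
      D[i * N + j] = alpha * accum + beta * C[i * N + j];
    }
  }
}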
+// +template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +struct Gemm { + + using ElementA = typename TensorRefA::Element; + using ElementB = typename TensorRefB::Element; + using ElementC = typename TensorRefC::Element; + + // + // Data members + // + + /// Tile for A operand + ElementA A_tile[OutputTile::kColumn]; + + /// Tile for B operand + ElementB B_tile[OutputTile::kRow]; + + /// Tile for Accumulator + AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow]; + + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { + + // Clear fetch registers + for (int i = 0; i < OutputTile::kColumn; ++i) { + A_tile[i] = ElementA(0); + } + + for (int j = 0; j < OutputTile::kRow; ++j) { + B_tile[j] = ElementB(0); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kColumn; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kRow; ++i) { + accum[j][i] = initial_accum; + } + } + } + + /// Computes a matrix product + CUTLASS_HOST_DEVICE + Gemm & multiply_add( + gemm::GemmCoord problem_size, + TensorRefA tensor_a, + TensorRefB tensor_b, + MatrixCoord output_coord = MatrixCoord()) { + + InnerProductOp inner_product_op; + + // Loop over the GEMM K dimension + CUTLASS_PRAGMA_NO_UNROLL + for (int k = 0; k < problem_size.k(); ++k) { + + // Fetch a slice of the A matrix + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + if (output_coord.row() + i < problem_size.m()) { + A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); + } + } + + // Fetch a slice of the B matrix + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + if (output_coord.column() + j < problem_size.n()) { + B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); + } + } + + // Compute an accumulated matrix product + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]); + } + } + } + + return *this; + } + + /// Performs linear scaling of matrix product and updates output tensor + CUTLASS_HOST_DEVICE + Gemm & epilogue( + gemm::GemmCoord problem_size, + ScalarType alpha, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + MatrixCoord output_coord = MatrixCoord()) { + + ConvertOp convert_op; + + // Update the output tensor + for (int j = 0; j < OutputTile::kRow; ++j) { + for (int i = 0; i < OutputTile::kColumn; ++i) { + MatrixCoord coord = output_coord + MatrixCoord(i, j); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[j][i]) + + beta * ScalarType(tensor_c.at(coord)) + ); + } + } + } + + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/conv.hpp b/csrc/quantization/cutlass_test/example/util/reference/host/conv.hpp new file mode 100644 index 0000000000000..545dbba9a4e89 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/conv.hpp @@ -0,0 +1,698 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for CONV in host-side code. 
+*/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +#include "cute/tensor.hpp" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_) { + return ((n_ >= 0 && n_ < size<4>(activation)) && + (d_ >= 0 && d_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t h_, int32_t w_, int32_t c_) { + return ((n_ >= 0 && n_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t w_, int32_t c_) { + return ((n_ >= 0 && n_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +} // namespace detail + +template< + class ElementAcc_, + class ElementScalar_, + class ElementCompute_, + class ElementC_, + class ElementOut_, + class TensorAlpha_, + class TensorBeta_, + class TensorBias_, + class ActivationFunctor_ = cutlass::epilogue::thread::Identity +> +struct ConvEpilogueFusionParams { + using ElementAcc = ElementAcc_; + using ElementScalar = ElementScalar_; + using ElementCompute = ElementCompute_; + using ElementC = ElementC_; + using ElementOut = ElementOut_; + using TensorAlpha = TensorAlpha_; + using TensorBeta = TensorBeta_; + using TensorBias = TensorBias_; + using ActivationFunctor = ActivationFunctor_; + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorAlpha tensor_alpha{}; + TensorBeta tensor_beta{}; + TensorBias tensor_bias{}; +}; + +template< + cutlass::conv::Operator ConvOp, + int NumSpatialDims, + class TensorA, + class TensorB, + class TensorC, + class TensorD, + class ShapePadding, + class StrideTraversal, + class ShapeDilation, + class EpilogueFusionParams +> +struct ConvReferenceImpl { + // Hard code accumlulator type to float to avoid data lost in accumulating add. 
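(Aside, not part of the patch: the index arithmetic that the fprop loops below repeat for every spatial rank, pulled out as a standalone helper; the function name and parameter names are illustrative assumptions.)

#include <cstdint>

// For output position q and filter tap s, the contributing input column is
// w = q * stride - pad + s * dilation; taps that fall outside the activation
// extent W contribute nothing, mirroring detail::is_activation_in_bounds.
inline bool fprop_input_coord(int32_t q, int32_t s, int32_t stride,
                              int32_t pad, int32_t dilation, int32_t W,
                              int32_t& w) {
  w = q * stride - pad + s * dilation;
  return w >= 0 && w < W;
}

The dgrad specializations invert this mapping, computing q = w + pad - s * dilation and additionally skipping taps where q is not divisible by the stride.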
+ using ElementAcc = cutlass::platform::conditional_t, double, float>; + using ElementC = typename EpilogueFusionParams::ElementC; + using ElementOut = typename EpilogueFusionParams::ElementOut; + using ElementScalar = typename EpilogueFusionParams::ElementScalar; + using ElementCompute = typename EpilogueFusionParams::ElementCompute; + using ElementBias = typename EpilogueFusionParams::TensorBias::value_type; + using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor; + + // Input related converter + NumericConverter acc_converter; + NumericConverter residual_converter; + NumericConverter bias_converter; + // Scale related converter + NumericConverter scale_converter; + // Output related converter + NumericConverter output_converter; + + EpilogueFusionParams& epi_fusion_params_; + TensorA const& tensor_a_; + TensorB const& tensor_b_; + TensorC const& tensor_c_; + TensorD& tensor_d_; + + ShapePadding const& padding_; + StrideTraversal const& tstride_; + ShapeDilation const& dilation_; + + // Epilogue activation operation + ActivationFunctor epi_activation; + + ConvReferenceImpl( + TensorA const& tensor_a, + TensorB const& tensor_b, + TensorC const& tensor_c, + TensorD& tensor_d, + ShapePadding const& padding, + StrideTraversal const& tstride, + ShapeDilation const& dilation, + EpilogueFusionParams& epi_fusion_params) + : tensor_a_(tensor_a), + tensor_b_(tensor_b), + tensor_c_(tensor_c), + tensor_d_(tensor_d), + padding_(padding), + tstride_(tstride), + dilation_(dilation), + epi_fusion_params_(epi_fusion_params) + { + static_assert(rank(ShapePadding{}) == rank(ShapeDilation{})); + static_assert(rank(ShapePadding{}) == rank(StrideTraversal{})); + } + + void compute_reference() { + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + fprop_reference(cute::Int{}); + } + else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + dgrad_reference(cute::Int{}); + } + else { + wgrad_reference(cute::Int{}); + } + } + +private: + // Specialization for 1D fprop kernel + void fprop_reference(cute::Int<1> spatial_dims) { + int32_t N = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, w, c)) { + auto a = tensor_a_(c, w, n); + auto b = tensor_b_(c, s, k); + accumulator += ElementAcc(a * b); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + + scale_converter(beta) * residual_converter(tensor_c_(k, q, n)); + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + tensor_d_(k, q, n) = output_converter(output); + } + } + } + + } + + // Specialization for 2D fprop kernel + void fprop_reference(cute::Int<2> spatial_dims) { + int32_t N = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c)) { + auto a = tensor_a_(c, w, h, n); + auto b = tensor_b_(c, s, r, k); + accumulator += ElementAcc(a * b); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + + scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n)); + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + tensor_d_(k, q, p, n) = output_converter(output); + } + } + } + } + + } + + // Specialization for 3D fprop kernel + void fprop_reference(cute::Int<3> spatial_dims) { + int32_t N = size<4>(tensor_d_); + int32_t Z = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c)) { + auto a = tensor_a_(c, w, h, d, n); + auto b = tensor_b_(c, s, r, t, k); + accumulator += ElementAcc(a * b); + } + } + } + } + } + ElementScalar alpha = 
raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + + scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n)); + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + tensor_d_(k, q, p, z, n) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 1D dgrad kernel + void dgrad_reference(cute::Int<1> spatial_dims) { + int32_t N = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t n = 0; n < N; ++n) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, q, k)) { + accumulator += ElementAcc(tensor_a_(k, q, n) * tensor_b_(c, s, k)); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + + scale_converter(beta) * residual_converter(tensor_c_(c, w, n)); + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + tensor_d_(c, w, n) = output_converter(output); + } + } + } + + } + + // Specialization for 2D dgrad kernel + void dgrad_reference(cute::Int<2> spatial_dims) { + int32_t N = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t n = 0; n < N; ++n) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k)) { + accumulator += ElementAcc(tensor_a_(k, q, p, n) * tensor_b_(c, s, r, k)); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? 
epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + + scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n)); + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + + tensor_d_(c, w, h, n) = output_converter(output); + } + } + } + } + + } + + // Specialization for 3D dgrad kernel + void dgrad_reference(cute::Int<3> spatial_dims) { + int32_t N = size<4>(tensor_d_); + int32_t D = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<4>(tensor_b_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t n = 0; n < N; ++n) { + for (int32_t d = 0; d < D; ++d) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } else { + continue; + } + + if (z % cute::get<2>(tstride_) == 0) { + z /= cute::get<2>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k)) { + accumulator += ElementAcc(tensor_a_(k, q, p, z, n) * tensor_b_(c, s, r, t, k)); + } + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + + scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n)); + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + tensor_d_(c, w, h, d, n) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 1D wgrad kernel + void wgrad_reference(cute::Int<1> spatial_dims) { + int32_t N = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t k = 0; k < K; ++k) { + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, w, c); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, n); + auto xformed_act = + tensor_a_(k, q, n); + accumulator += ElementAcc(act * xformed_act); + } + } + } + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + + scale_converter(beta) * residual_converter(tensor_c_(c, s, k)); + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + tensor_d_(c, s, k) = output_converter(output); + } + } + } + } + + // Specialization for 2D wgrad kernel + void wgrad_reference(cute::Int<2> spatial_dims) { + int32_t N = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t k = 0; k < K; ++k) { + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, h, w, c); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, n); + auto xformed_act = + tensor_a_(k, q, p, n); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + + scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k)); + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + tensor_d_(c, s, r, k) = output_converter(output); + } + } + } + } + } + + // Specialization for 3D wgrad kernel + void wgrad_reference(cute::Int<3> spatial_dims) { + int32_t N = + size<4>(tensor_a_); + int32_t Z = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t T = size<3>(tensor_d_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t k = 0; k < K; ++k) { + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, d, n); + auto xformed_act = + tensor_a_(k, q, p, z, n); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + } + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + + scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k)); + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + tensor_d_(c, s, r, t, k) = output_converter(output); + } + } + } + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/convolution.h b/csrc/quantization/cutlass_test/example/util/reference/host/convolution.h new file mode 100644 index 0000000000000..f28b4a658a388 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/convolution.h @@ -0,0 +1,802 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Forward propagation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv2d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + int group_idx = k / (problem_size.K / problem_size.groups); + int channels_per_group = problem_size.C / problem_size.groups; + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < channels_per_group; ++c) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group}); + ElementB b = tensor_w.at({k, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) = + convert_op(alpha * 
ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +/// Depthwise-separable convolution +template , + typename InnerProductOp = multiply_add> +void Depsep_Fprop(cutlass::TensorView tensor_A, + cutlass::TensorView tensor_B, + cutlass::TensorView tensor_C, + cutlass::TensorView tensor_D, + ElementCompute alpha, + ElementCompute beta, + cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), + cutlass::Coord<2> conv_stride = cutlass::Coord<2>(), + cutlass::Coord<2> dilation = cutlass::Coord<2>(), + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < tensor_C.extent().n(); ++n) { + for (int p = 0; p < tensor_C.extent().h(); ++p) { + for (int q = 0; q < tensor_C.extent().w(); ++q) { + for (int g = 0; g < tensor_C.extent().c(); ++g) { + ElementAccumulator acc = ElementAccumulator(); + for (int r = 0; r < tensor_B.extent().h(); ++r) { + for (int s = 0; s < tensor_B.extent().w(); ++s) { + + // input activation H and W + int h = p * conv_stride[0] - padding[0] + r * dilation[0]; + int w = q * conv_stride[1] - padding[2] + s * dilation[1]; + + if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) { + ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g)); + + ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation) + ? tensor_B.at(cutlass::make_Coord(g, r, s, 0)) + : tensor_B.at(cutlass::make_Coord( + g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g)); + tensor_D.at(cutlass::make_Coord(n, p, q, g)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dDgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + 
problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % problem_size.stride_w) == 0) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; +#if 0 + std::cout << "row:" + << n * problem_size.H * problem_size.W + + h * problem_size.W + + w << " " + << "n, p, q: (" + << n << ", " + << p << ", " + << q << ") * " + << "r, s: (" + << r << ", " + << s << ") [" + << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]" + << std::endl; +#endif + if (p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); + ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, r, s, c)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dWgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + cutlass::Tensor4DCoord b_coord; + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + b_coord = make_Coord( + n, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + acc = inner_product_op(a, b, acc); + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c)); + } + + 
tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (K) +} + +/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2d( + conv::Operator convolutional_operator, + conv::Conv2dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv2dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv2dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// 3D convolution +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv3d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + 
int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (d >= 0 && d < problem_size.D && + h >=0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, d, h, w, c}); + ElementB b = tensor_w.at({k, t, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dDgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int d = 0; d < problem_size.D; ++d) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d; + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (z >= 0 && (z % problem_size.stride_d) == 0 && + p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % problem_size.stride_w) == 0) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)); + ElementB b = is_deconv ? 
tensor_w.at(cutlass::make_Coord(c, t, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + } // for (T) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (D) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dWgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + Tensor5DCoord b_coord = make_Coord( + n, + z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.d() < problem_size.D && b_coord.d() >= 0 && + b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + + acc = inner_product_op(a, b, acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c)); + } + + tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + } // for (K) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. 
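+// Editorial sketch (not part of the upstream CUTLASS header): the generic
+// dispatcher below routes to the Conv3dFprop / Conv3dDgrad / Conv3dWgrad
+// reference kernels above and is typically used to produce a host-side
+// reference result for verification. A minimal, illustrative call, where
+// problem_size and the tensor names are placeholders for a
+// conv::Conv3dProblemSize and HostTensor instances, might look like:
+//
+//   cutlass::reference::host::Conv3d<
+//       float, cutlass::layout::TensorNDHWC,   // ElementA, LayoutA
+//       float, cutlass::layout::TensorNDHWC,   // ElementB, LayoutB
+//       float, cutlass::layout::TensorNDHWC,   // ElementC, LayoutC
+//       float>(                                // ElementCompute
+//     cutlass::conv::Operator::kFprop, problem_size,
+//     tensor_x.host_ref(), tensor_w.host_ref(),
+//     tensor_y.host_ref(), tensor_y_ref.host_ref(),
+//     1.0f, 0.0f);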
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/error_metrics.h b/csrc/quantization/cutlass_test/example/util/reference/host/error_metrics.h new file mode 100644 index 0000000000000..86db65ccc441e --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/error_metrics.h @@ -0,0 +1,66 @@ + +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include <cmath>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/complex.h"
+#include "cutlass/util/reference/host/tensor_reduce.h"
+#include "cutlass/core_io.h"
+
+namespace cutlass {
+namespace reference {
+namespace host {
+
+/// Helper to compute the relative error metric for tensor A_computed w.r.t. tensor B_reference
+template <
+  typename Element,
+  typename Layout,
+  typename ComputeType = double
+>
+ComputeType TensorRelativeErrorMetric(
+  TensorView<Element, Layout> view_A_computed,
+  TensorView<Element, Layout> view_B_reference,
+  ComputeType identity = ComputeType()
+) {
+
+  return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) /
+         cutlass::reference::host::TensorNorm(view_B_reference, identity);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace host
+} // namespace reference
+} // namespace cutlass
diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/gemm.h b/csrc/quantization/cutlass_test/example/util/reference/host/gemm.h
new file mode 100644
index 0000000000000..03888131095fc
--- /dev/null
+++ b/csrc/quantization/cutlass_test/example/util/reference/host/gemm.h
@@ -0,0 +1,531 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +namespace cutlass { +namespace reference { +namespace host { + +template +struct CastIfScalar { + static Out cast(In in) { + return Out(in); + } +}; + +template +struct CastIfScalar, In> { + typedef cutlass::complex Out; + static Out cast(In in) { + return Out(static_cast(in)); + } +}; + +template +struct CastIfScalar, cutlass::complex> { + typedef cutlass::complex Out; + typedef cutlass::complex In; + static Out cast(In in) { + return Out(in); + } +}; + +template +Out cast_if_scalar(In in) { + return CastIfScalar::cast(in); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + + // Note: batch is ignored. 
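+  //
+  // Editorial note (not part of the upstream header): the loops below compute
+  // D = alpha * (A * B) + beta * C one Mblock x Nblock output tile at a time.
+  // Accumulators for a tile are held in a local ComputeType array, the full K
+  // extent is reduced into them, and only then is the epilogue (scale by
+  // alpha/beta, add C, convert to the output element type) applied and written
+  // to tensor_d. The blocking only speeds up this host reference; it does not
+  // change the numerical result relative to a straightforward triple loop.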
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_gemm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Gemm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, 
tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +/// Partial specialization for AND-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && 
LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin(); + typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin(); + typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin(); + + for (int batch = 0; + batch < batch_count; + ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) { + + Gemm + gemm; + + gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it, + initial_accum); + } +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. 
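+// Editorial sketch (not part of the upstream header): the overload below
+// forwards to the BatchedGemm above with a zero initial accumulator, so a
+// typical call site only needs, e.g.
+//
+//   cutlass::reference::host::BatchedGemm(
+//       problem_size, batch_count, alpha,
+//       collection_A, collection_B, beta, collection_C);
+//
+// where collection_A/B/C are placeholder names for objects satisfying the
+// TensorRefCollection concept described above.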
+// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/gemm_complex.h b/csrc/quantization/cutlass_test/example/util/reference/host/gemm_complex.h new file mode 100644 index 0000000000000..92da343a9c222 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/gemm_complex.h @@ -0,0 +1,210 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass/tensor_view.h" + +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
+/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
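+//
+// Editorial sketch (not part of the upstream header): a typical call to the
+// convenience overload below, with placeholder tensor names, is
+//
+//   cutlass::reference::host::GemmComplex(
+//       problem_size, alpha,
+//       tensor_a.host_ref(), cutlass::ComplexTransform::kNone,
+//       tensor_b.host_ref(), cutlass::ComplexTransform::kConjugate,
+//       beta, tensor_c.host_ref(), tensor_d.host_ref());
+//
+// Element types are typically cutlass::complex<> instantiations, and the
+// accumulator type equals the scalar type, as noted above.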
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/gemm_planar_complex.h b/csrc/quantization/cutlass_test/example/util/reference/host/gemm_planar_complex.h new file mode 100644 index 0000000000000..094af8b37b695 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/gemm_planar_complex.h @@ -0,0 +1,228 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. 
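+
+           (Planar-complex variant: the real and imaginary parts of each
+           operand are stored in separate planes and addressed through
+           TensorRefPlanarComplex.)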
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. 
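+  //
+  // Editorial note (not part of the upstream header): in the planar-complex
+  // representation the real and imaginary parts of each operand live in two
+  // separate planes addressed through TensorRefPlanarComplex; ComplexA/B/C
+  // above are the per-element complex views read from those planes. The loop
+  // structure mirrors the real-valued reference GEMM: 16x16 output tiles and
+  // a full-K reduction into complex accumulators, followed by the alpha/beta
+  // epilogue. Optional conjugation (transform_a / transform_b) is applied to
+  // the loaded A and B values before accumulation.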
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + complex accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC d_ij = tensor_c.at(coord); + + complex src{ + ScalarType(d_ij.real()), + ScalarType(d_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/gett.hpp b/csrc/quantization/cutlass_test/example/util/reference/host/gett.hpp new file mode 100644 index 0000000000000..f6984fb2ba9c5 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/gett.hpp @@ -0,0 +1,538 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GETT in host-side code. +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/gemm.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/relatively_equal.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +template +struct ElementTraits { + using type = T; +}; + +template +struct ElementTraits().get()), void> > > { + using type = decltype(std::declval().get()); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorB_ // (N, K, L) +> +struct GettMainloopParams { + using ElementAccumulator = ElementAccumulator_; + using TensorA = TensorA_; + using TensorB = TensorB_; + using EngineA = typename TensorA::engine_type; + using LayoutA = typename TensorA::layout_type; + using EngineB = typename TensorB::engine_type; + using LayoutB = typename TensorB::layout_type; + + TensorA A{}; + TensorB B{}; + + ComplexTransform transform_A = ComplexTransform::kNone; + ComplexTransform transform_B = ComplexTransform::kNone; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template< + class ElementScalar_, + class ElementScalingFactor_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, // (M, N, L) + class TensorD_, // (M, N, L) + class VectorBias_ = TensorD_, // (M, 1) + class TensorAux_ = TensorD_, // (M, N, L) + class VectorAlpha_ = TensorD_, // (M, 1) + class VectorBeta_ = VectorAlpha_, // (M, 1) + class 
ActivationFunctor_ = cutlass::epilogue::thread::Identity, + class BiasBinaryOp_ = cutlass::plus, + bool PerColumnBias_ = false +> +struct GettEpilogueParams { + using ElementScalar = ElementScalar_; + using ElementScalingFactor = ElementScalingFactor_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using TensorC = TensorC_; + using TensorD = TensorD_; + using TensorAux = TensorAux_; + using VectorBias = VectorBias_; + using VectorAlpha = VectorAlpha_; + using VectorBeta = VectorBeta_; + using ActivationFunctor = ActivationFunctor_; + using BiasBinaryOp = BiasBinaryOp_; + + using EngineC = typename TensorC::engine_type; + using LayoutC = typename TensorC::layout_type; + using EngineD = typename TensorD::engine_type; + using LayoutD = typename TensorD::layout_type; + static constexpr bool PerColumnBias = PerColumnBias_; + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorC C{}; + TensorD D{}; + VectorBias Bias{}; + TensorAux Aux{}; + VectorAlpha Valpha{}; + VectorBeta Vbeta{}; + ElementCompute st = ElementCompute(1); + + ElementAccumulator* abs_max_D = nullptr; + ElementAccumulator* abs_max_Aux = nullptr; + + ElementScalingFactor scale_a = ElementScalingFactor(1); + ElementScalingFactor scale_b = ElementScalingFactor(1); + ElementScalingFactor scale_c = ElementScalingFactor(1); + ElementScalingFactor scale_d = ElementScalingFactor(1); + ElementScalingFactor scale_aux = ElementScalingFactor(1); + + bool beta_per_channel_scaling = false; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - General Tensor-Tensor contraction reference kernel +template < + class MainloopParams, + class EpilogueParams +> +void Gett( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + gett_epilogue(epilogue_params, m, n, l, acc); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Mainloop +template +void gett_mainloop( + MainloopParams const& mainloop_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); + static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementA = typename ElementTraits::type; + using ElementB = typename ElementTraits::type; + + using RingOp = multiply_add; + RingOp fma_op; + + // Zero out accumulators + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Compute on this k-block + for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { + // Load A + ElementAccumulator a_frag[kBlockM]; + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { + // Perform 
reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); + + if (mainloop_params.transform_A == ComplexTransform::kConjugate) { + a_frag[m_b] = conj(a_frag[m_b]); + } + } else { + a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Load B + ElementAccumulator b_frag[kBlockN]; + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); + + if (mainloop_params.transform_B == ComplexTransform::kConjugate) { + b_frag[n_b] = conj(b_frag[n_b]); + } + } else { + b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // do compute + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]); + } + } + + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Epilogue +template +void gett_epilogue( + EpilogueParams const& epilogue_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); + static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementCompute = typename EpilogueParams::ElementCompute; + using ElementC = typename EpilogueParams::TensorC::value_type; + using ElementD = typename EpilogueParams::TensorD::value_type; + using ElementAux = typename EpilogueParams::TensorAux::value_type; + using ElementBias = typename EpilogueParams::VectorBias::value_type; + using ElementScalar = typename EpilogueParams::ElementScalar; + using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; + using ActivationFunctor = typename EpilogueParams::ActivationFunctor; + using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; + + constexpr bool PerColBias = EpilogueParams::PerColumnBias; + constexpr bool IsScalingAndAmaxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsScalingAndAmaxAuxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsReLUAuxNeeded = + (cute::is_same_v> or + cute::is_same_v>) and + cute::is_same_v; + constexpr bool IsClamp = + cute::is_same_v>; + + constexpr bool IsBackpropFusion = + cute::is_same_v> or + cute::is_same_v>; + + // Input related converter + NumericConverter accumulator_converter; + NumericConverter source_converter; + NumericConverter bias_converter; + [[maybe_unused]] NumericConverter aux_source_converter; + + // Scale related converter + NumericConverter scale_converter; + NumericConverter scaling_factor_converter; + + // Abs max converter + [[maybe_unused]] NumericConverter abs_max_output_converter; + + // Output related converter + NumericConverter destination_converter; + [[maybe_unused]] NumericConverter aux_destination_converter; + NumericConverter dBias_converter; + + // Epilogue operations + multiply_add epilogue_fma; + multiplies mul; + plus add; + + // Activation operation + ActivationFunctor activation; + + // Bias binary operation + BiasBinaryOp bias_op; + + // Do conversion + ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); + ElementCompute 
converted_beta = scale_converter(epilogue_params.beta); + ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); + ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); + ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); + ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); + ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); + + // Init local var + [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); + [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); + + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + converted_beta = mul(converted_beta, converted_scale_c); + + ElementCompute inter_accum[kBlockM][kBlockN]; + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + ElementCompute local_dBias = ElementCompute(0); + + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + // Convert every type to ElementCompute first, do compute, convert to output type, write it out + ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); + // per-row alpha + if (raw_pointer_cast(epilogue_params.Valpha.data())) { + converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l)); + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + } + ElementCompute output = mul(converted_alpha, converted_acc); + + if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { + ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b)); + output = bias_op(output, converted_bias); + } + + if (raw_pointer_cast(epilogue_params.C.data())) { + ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + // per-row beta + if (epilogue_params.Vbeta.data()) { + converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l)); + converted_beta = mul(converted_beta, converted_scale_c); + } + output = epilogue_fma(converted_beta, converted_src, output); + } + + if constexpr (IsBackpropFusion) { + ElementAux aux_input = ElementAux(0); + if (raw_pointer_cast(epilogue_params.Aux.data())) { + aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); + } + + output = activation(output, aux_source_converter(aux_input)); + local_dBias = add(local_dBias, output); + } + else { + if (raw_pointer_cast(epilogue_params.Aux.data())) { + auto aux_output = output; + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); + aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + } + + if constexpr (IsReLUAuxNeeded) { + epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? 
uint1b_t(1) : uint1b_t(0); + } else { + epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); + } + } + + if constexpr (IsClamp) { // Treat Clamp as ReLU + output = activation(output, {0, std::numeric_limits::max()}); + } + else { + output = activation(output); + } + } + + if constexpr (IsScalingAndAmaxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_output = amax_op(local_abs_max_output, output); + output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); + } + + inter_accum[m_b][n_b] = ElementCompute(output); + } + } // n_b + + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { + if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { + ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); + local_dBias = add(local_dBias, converted_dBias); + epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); + } + } + } // m_b + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); + } + } + } + +#if defined(_OPENMP) + #pragma omp critical(Abs_Max_Data_Update) +#endif + { + if constexpr (IsScalingAndAmaxOutputNeeded) { + if (epilogue_params.abs_max_D) { + *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); + } + } + + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + if (epilogue_params.abs_max_Aux) { + *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +auto make_layout_rank3(const TensorType& tensor) { + // append a batch mode of size 1 if we do not have tensors that are rank 3 + return make_layout( + make_shape(cute::get<0>(tensor.shape()), cute::get<1>(tensor.shape()), cute::Int<1>{}), + make_stride(cute::get<0>(tensor.stride()), cute::get<1>(tensor.stride()), int64_t(cosize(tensor.layout())))); +} + +/// GEMM - General Matrix-Matrix contraction without conjugation options +template < + class MainloopParams, + class EpilogueParams +> +void Gemm3x( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + using namespace cute; + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); + + if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) { + cute::Layout layout_A = make_layout_rank3(mainloop_params.A); + cute::Layout layout_B = make_layout_rank3(mainloop_params.B); + cute::Layout layout_C = make_layout_rank3(epilogue_params.C); + cute::Layout layout_D = make_layout_rank3(epilogue_params.D); + cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux); + cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); + cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha); + cute::Layout layout_Vbeta = 
make_layout_rank3(epilogue_params.Vbeta); + + auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); + auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); + auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); + auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); + auto TensorAux = make_tensor(epilogue_params.Aux.data(), layout_Aux); + auto VectorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias); + auto VectorAlpha = make_tensor(epilogue_params.Valpha.data(), layout_Valpha); + auto VectorBeta = make_tensor(epilogue_params.Vbeta.data(), layout_Vbeta); + + // Reconstruct mainloop params + GettMainloopParams + mainloop_params_converted{TensorA, + TensorB, + mainloop_params.transform_A, + mainloop_params.transform_B}; + + // Reconstruct epilogue params + GettEpilogueParams + epilogue_params_converted{epilogue_params.alpha, + epilogue_params.beta, + TensorC, + TensorD, + VectorBias, + TensorAux, + VectorAlpha, + VectorBeta, + epilogue_params.abs_amax_D, + epilogue_params.abs_amax_Aux, + epilogue_params.scale_a, + epilogue_params.scale_b, + epilogue_params.scale_c, + epilogue_params.scale_d, + epilogue_params.scale_aux + }; + + Gett(mainloop_params_converted, epilogue_params_converted); + } + else { + // if we already have a batch mode, just pass it through + Gett(mainloop_params, epilogue_params); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/rank_2k.h b/csrc/quantization/cutlass_test/example/util/reference/host/rank_2k.h new file mode 100644 index 0000000000000..2a99bc03a35ba --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/rank_2k.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for Rank 2k update in host-side code. + + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert( + FillModeC == FillMode::kLower || + FillModeC == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower), + std::greater_equal, + std::less_equal>::type; + + // Note: batch is ignored. 
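+  // The blocked loops below implement the symmetric rank-2k update
+  //
+  //     D = alpha * (A * B^T + B * A^T) + beta * C
+  //
+  // restricted to the triangle selected by FillModeC; CompareOp keeps only the
+  // coordinates on or below (kLower) / on or above (kUpper) the diagonal.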
+ // Note: M is same as N for Rank 2k update + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < N; row_block += Nblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Nblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < N && col < N && compare_op(row, col)) + { + + // A x B^T + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b_t(cast_if_scalar(b_t)); + + accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]); + + // B x A^T + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType compute_b(cast_if_scalar(b)); + ComputeType compute_a_t(cast_if_scalar(a_t)); + + accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < N && col < N && + ( (FillModeC == FillMode::kLower && row >= col) || + (FillModeC == FillMode::kUpper && row <= col) ) + ) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef +/// objects. 
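+/// This overload updates tensor_c in place (tensor_d aliases tensor_c).
+///
+/// Minimal usage sketch; the element types, layouts, extents and HostTensor
+/// variables below are illustrative assumptions, not part of this header:
+///
+///   cutlass::HostTensor<float, cutlass::layout::ColumnMajor> A({n, k}), B({n, k}), C({n, n});
+///   cutlass::gemm::GemmCoord problem(n, n, k);   // M == N for a rank-2k update
+///   cutlass::reference::host::compute_rank2k<
+///       float, cutlass::layout::ColumnMajor,     // A
+///       float, cutlass::layout::ColumnMajor,     // B
+///       float, cutlass::layout::ColumnMajor,     // C (updated in place)
+///       cutlass::FillMode::kLower, float, float>(
+///       problem, 1.0f, A.host_ref(), B.host_ref(), 0.0f, C.host_ref(), 0.0f);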
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_rank2k( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Rank2K; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Rank2K { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/rank_2k_complex.h b/csrc/quantization/cutlass_test/example/util/reference/host/rank_2k_complex.h new file mode 100644 index 0000000000000..090019c100396 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/rank_2k_complex.h @@ -0,0 +1,318 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank 2K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. 
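+  // For BlasMode::kSymmetric the two k-loops below accumulate A * B^T and B * A^T
+  // with the same alpha.  For BlasMode::kHermitian (HER2K) the update is split into
+  // two epilogues so that
+  //
+  //     D = alpha * A * B^H + conj(alpha) * B * A^H + beta * C
+  //
+  // with the diagonal of the output forced to be real-valued.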
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank2K update operates on A=NxK, B=NxK, and C=NxN + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x B^T (Symmetric) or A x B^H (Hermitian) + // complex conjugation on operandB (b_t) is function of blas3 computation + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_b.at(MatrixCoord(col, k_block))) : + tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(b_t); + + // complex conjugation is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + } + } + } + } + + /* HER2K need two epilogues to handle complex alpha value */ + if ( blas_mode == BlasMode::kHermitian ) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + tensor_d.at(coord) = convert_op(alpha * + ScalarType(accum[i][j]) + + beta * c); + } + } + } + + /* Zeoring out accum for second HERK */ + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // B x A^T (Symmetric) or B x A^H (Hermitian) + // complex conjugation on operandB (a_t) is function of blas3 computation + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? 
+ conj(tensor_a.at(MatrixCoord(col, k_block))): + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType b_ik = ComputeType(b); + ComputeType a_jk = ComputeType(a_t); + + // complex conjugation here is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_ik = conj(b_ik); + } + // complex conjugation here is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_jk = conj(a_jk); + } + + accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]); + } + } + } + } + + ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ? + conj(alpha) : alpha; + ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ? + 1 : beta; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType d = (blas_mode == BlasMode::kHermitian) ? + tensor_d.at(coord) : tensor_c.at(coord); + + ScalarType tmp_d = convert_op( + alpha_hermitian * ScalarType(accum[i][j]) + + beta_hermitian * d); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/rank_k_complex.h b/csrc/quantization/cutlass_test/example/util/reference/host/rank_k_complex.h new file mode 100644 index 0000000000000..ef44270a314a4 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/rank_k_complex.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank 2K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. 
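+  // The loops below implement a rank-k update that uses a single operand A:
+  //
+  //     D = alpha * A * A^T + beta * C    (BlasMode::kSymmetric)
+  //     D = alpha * A * A^H + beta * C    (BlasMode::kHermitian, real-valued diagonal)
+  //
+  // restricted to the triangle selected by fill_mode_c.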
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank2K update operates on A=NxK, B=NxK, and C=NxN + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x A^T (Symmetric) or A x A^H (Hermitian) + // complex conjugation on operandB (a_t) (function of blas3 computation) + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))) : + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(a_t); + + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + ScalarType tmp_d = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
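+/// Minimal usage sketch; problem, alpha, beta and the HostTensor variables are
+/// illustrative assumptions, not part of this header:
+///
+///   cutlass::reference::host::RankKComplex(
+///       problem, alpha,
+///       A.host_ref(), cutlass::ComplexTransform::kNone,
+///       beta, C.host_ref(), D.host_ref(),
+///       cutlass::FillMode::kLower, cutlass::BlasMode::kHermitian);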
+template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void RankKComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/symm.h b/csrc/quantization/cutlass_test/example/util/reference/host/symm.h new file mode 100644 index 0000000000000..a585caf73f64f --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/symm.h @@ -0,0 +1,285 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for SYMM update in host-side code. + + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
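+///
+/// For SideMode::kLeft this computes D = alpha * A * B + beta * C, where A is symmetric
+/// and only its FillModeA triangle is stored; for SideMode::kRight it computes
+/// D = alpha * B * A + beta * C.  The k-loop below therefore reads the stored triangle
+/// once directly (including the diagonal) and once transposed (excluding the diagonal).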
+template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + FillModeA == FillMode::kLower || + FillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (SideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + + ComputeType compute_a_1(cast_if_scalar(a_1)); + ComputeType compute_b_1(cast_if_scalar(b_1)); + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (SideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? 
+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + } + + ComputeType compute_a_2(cast_if_scalar(a_2)); + ComputeType compute_b_2(cast_if_scalar(b_2)); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_symm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Symm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Symm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/symm_complex.h b/csrc/quantization/cutlass_test/example/util/reference/host/symm_complex.h new file mode 100644 index 0000000000000..2618feaa70cee --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/symm_complex.h @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued SYMM update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. 
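+///
+/// Behaves like compute_symm, with two complex-specific differences visible in the
+/// loop below: for BlasMode::kHermitian the imaginary part of A's diagonal elements is
+/// treated as zero, and the transposed (off-diagonal) contribution of A is conjugated.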
+template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = BlasMode::kSymmetric, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static SideMode const kSideModeA = SideModeA; + static FillMode const kFillModeA = FillModeA; + static BlasMode const kBlasMode = BlasMode_; + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(kSideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + kFillModeA == FillMode::kLower || + kFillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) + { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (kSideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (kSideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? 
+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + ComputeType compute_a_1 = ComputeType(a_1); + ComputeType compute_b_1 = ComputeType(b_1); + + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) { + compute_a_1 = real(compute_a_1); + } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) { + compute_b_1 = real(compute_b_1); + } + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (kSideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + if (kBlasMode == BlasMode::kHermitian) + a_2 = conj(a_2); + } else if (kSideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? + tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + if (kBlasMode == BlasMode::kHermitian) + b_2 = conj(b_2); + } + + ComputeType compute_a_2 = ComputeType(a_2); + ComputeType compute_b_2 = ComputeType(b_2); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + ScalarType c = tensor_c.at(coord); + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct SymmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + 
compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_compare.h b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_compare.h new file mode 100644 index 0000000000000..df164a37e9297 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_compare.h @@ -0,0 +1,423 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. 
+*/ + +#pragma once + +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "cutlass/util/distribution.h" +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorEqualsFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + bool result; + + /// Ctor + TensorEqualsFunc(): result(true) { } + + /// Ctor + TensorEqualsFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), rhs(rhs_), result(true) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + if (lhs_ != rhs_) { + result = false; + } + } + + /// Returns true if equal + operator bool() const { + return result; + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorRelativelyEqualsFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + Element epsilon; + Element nonzero_floor; + bool result; + + /// Ctor + TensorRelativelyEqualsFunc( + TensorView const &lhs_, + TensorView const &rhs_, + Element epsilon_, + Element nonzero_floor_ + ) : + lhs(lhs_), + rhs(rhs_), + epsilon(epsilon_), + nonzero_floor(nonzero_floor_), + result(true) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + if (!relatively_equal(lhs_, rhs_, epsilon, nonzero_floor)) { + result = false; + } + } + + /// Returns true if equal + operator bool() const { + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return bool(func); +} + +/// Returns true if two tensor views are equal. 
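+/// The planar-complex overload below compares the real and imaginary planes separately.
+///
+/// Minimal usage sketch for the TensorView overload above; the HostTensor variables
+/// are illustrative assumptions:
+///
+///   bool same = cutlass::reference::host::TensorEquals(
+///       reference.host_view(), computed.host_view());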
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorView const &lhs, + TensorView const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc func(lhs, rhs, epsilon, nonzero_floor); + TensorForEach( + lhs.extent(), + func + ); + + return bool(func); +} + +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are NOT equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return true; + } + + detail::TensorEqualsFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return !bool(func); +} + +/// Returns true if two tensor views are equal. 
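//
// Illustrative sketch (not part of the original files): TensorRelativelyEquals is the
// tolerance-based variant used when exact equality is too strict, e.g. for half-precision
// accumulations. The epsilon / nonzero_floor pair is forwarded to cutlass::relatively_equal
// at every coordinate. The tensors and tolerance values below are assumptions.
//
//   cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> ref(cutlass::MatrixCoord(64, 64));
//   cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> out(cutlass::MatrixCoord(64, 64));
//
//   cutlass::half_t epsilon(0.001f);                 // relative tolerance
//   cutlass::half_t nonzero_floor(1.0f / 1024.0f);   // magnitudes below this are treated as zero
//
//   bool close_enough = cutlass::reference::host::TensorRelativelyEquals(
//       ref.host_view(), out.host_view(), epsilon, nonzero_floor);
//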
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + return !TensorEquals(lhs, rhs); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorContainsFunc { + + // + // Data members + // + + TensorView view; + Element value; + bool contains; + Coord location; + + // + // Methods + // + + /// Ctor + TensorContainsFunc(): contains(false) { } + + /// Ctor + TensorContainsFunc( + TensorView const &view_, + Element value_ + ) : + view(view_), value(value_), contains(false) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + if (view.at(coord) == value) { + if (!contains) { + location = coord; + } + contains = true; + } + } + + /// Returns true if equal + operator bool() const { + return contains; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if a value is present in a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorContains( + TensorView const & view, + Element value) { + + detail::TensorContainsFunc func( + view, + value + ); + + TensorForEach( + view.extent(), + func + ); + + return bool(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns a pair containing a boolean of whether a value exists in a tensor and the location of +/// of the first occurrence. If the value is not contained in the tensor, the second element of the +/// pair is undefined. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +std::pair > TensorFind( + TensorView const & view, + Element value) { + + detail::TensorContainsFunc func( + view, + value + ); + + TensorForEach( + view.extent(), + func + ); + + return std::make_pair(bool(func), func.location); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_compare.hpp b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_compare.hpp new file mode 100644 index 0000000000000..a1f3f5b14e6f0 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_compare.hpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/* \file
+  \brief Provides several functions for comparing tensors.
+*/
+
+#pragma once
+
+// Standard Library includes
+#include
+#include
+#include
+
+// Cute includes
+#include "cute/tensor.hpp"
+
+// Cutlass includes
+#include "cutlass/cutlass.h"
+#include "cutlass/complex.h"
+#include "cutlass/quaternion.h"
+#include "cutlass/array.h"
+#include "cutlass/numeric_types.h"
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace reference {
+namespace host {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns true if two tensor views are equal.
+template <
+  typename TensorL,
+  typename TensorR
+>
+bool TensorEquals(
+  TensorL lhs,
+  TensorR rhs) {
+
+  // Extents must be identical
+  if (cute::size(lhs) != cute::size(rhs)) {
+    return false;
+  }
+
+  for (int64_t idx = 0; idx < cute::size(lhs); ++idx) {
+    if (lhs(idx) != rhs(idx)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+/// Returns true if two tensor views are NOT equal.
+template <
+  typename TensorL,
+  typename TensorR
+>
+bool TensorNotEquals(
+  TensorL lhs,
+  TensorR rhs) {
+
+  return !TensorEquals(lhs, rhs);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace host
+} // namespace reference
+} // namespace cutlass
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_copy.h b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_copy.h
new file mode 100644
index 0000000000000..0b963b72e9152
--- /dev/null
+++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_copy.h
@@ -0,0 +1,256 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once + +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Helper to convert between types +template < + typename DstElement, + typename SrcElement +> +struct TrivialConvert { + + TrivialConvert() { } + + DstElement operator()(SrcElement src) const { + return DstElement(src); + } +}; + +/// Helper to conditionally copy between tensor views. +template < + typename DstElement, + typename DstLayout, + typename SrcElement, + typename SrcLayout, + typename F +> +struct TensorCopyIf { + + using DstTensorView = TensorView; + using SrcTensorView = TensorView; + + // + // Data members + // + + DstTensorView dst; + SrcTensorView src; + F convert; + + // + // Methods + // + + TensorCopyIf() { } + + TensorCopyIf( + DstTensorView const &dst_, + SrcTensorView const &src_, + F const &convert_): dst(dst_), src(src_), convert(convert_) {} + + /// Copies based on destination and source bounds + void operator()(Coord const &coord) { + if (dst.contains(coord) && src.contains(coord)) { + dst.at(coord) = convert(src.at(coord)); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. 
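//
// Illustrative sketch (not part of the original files): TensorCopyIf only writes
// coordinates contained in *both* views, so copying between views of different extents
// silently clips to their intersection, with TrivialConvert handling the element
// conversion. The tensors below are assumptions for illustration.
//
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> src(cutlass::MatrixCoord(128, 128));
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> dst(cutlass::MatrixCoord(64, 64));
//
//   // Copies only the top-left 64x64 block; out-of-bounds coordinates are skipped.
//   cutlass::reference::host::TensorCopy(dst.host_view(), src.host_view());
//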
+template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorView src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + CopyIf copy_if(dst, src, transform); + + TensorForEach(dst.extent(), copy_if); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent +/// to avoid out of bounds accesses. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorRef src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + TensorView src_view(src, dst.extent()); + + CopyIf copy_if(dst, src_view, transform); + + TensorForEach(dst.extent(), copy_if); +} + +/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent +/// to avoid out of bounds accesses. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorRef dst, + TensorView src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + TensorView dst_view(dst, src.extent()); + + CopyIf copy_if(dst_view, src, transform); + + TensorForEach(src.extent(), copy_if); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout /// Source tensor's layout +> +void TensorCopy( + TensorView dst, + TensorView src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. 
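//
// Illustrative sketch (not part of the original files): the transform-functor overloads
// above allow converting element types while copying, e.g. widening a half-precision
// operand to float before comparison. Using cutlass::NumericConverter here is an
// assumption about a typical call site, not something this patch prescribes.
//
//   #include "cutlass/numeric_conversion.h"
//
//   cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> a_fp16(cutlass::MatrixCoord(32, 32));
//   cutlass::HostTensor<float, cutlass::layout::RowMajor>           a_fp32(cutlass::MatrixCoord(32, 32));
//
//   cutlass::NumericConverter<float, cutlass::half_t> convert;
//   cutlass::reference::host::TensorCopy(a_fp32.host_view(), a_fp16.host_view(), convert);
//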
+template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorRef src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout /// Source tensor's layout +> +void TensorCopy( + TensorRef dst, + TensorView src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_elementwise.h b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_elementwise.h new file mode 100644 index 0000000000000..42ce2183b6a24 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_elementwise.h @@ -0,0 +1,341 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. 
+*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to apply a binary operator in place +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementD, + typename LayoutD, + typename BinaryFunc> +struct TensorFuncBinaryOp { + + // + // Data members + // + + /// View of left-hand-side tensor + TensorView view_d; + TensorRef view_a; + TensorRef view_b; + BinaryFunc func; + + // + // Methods + // + + /// Constructor + TensorFuncBinaryOp() { } + + /// Constructor + TensorFuncBinaryOp( + TensorView const & view_d_, + TensorRef const & view_a_, + TensorRef const & view_b_, + BinaryFunc func = BinaryFunc() + ): + view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { } + + /// Equality check + void operator()(Coord const &coord) const { + view_d.at(coord) = func( + ElementD(view_a.at(coord)), + ElementD(view_b.at(coord)) + ); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Adds two tensors and stores in the destination tensor: d = a + b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::plus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Adds a tensor in place: d = d .+ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorAdd(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Subtracts two tensors and stores in the destination tensor: d = a - b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference + ) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::minus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Subtracts two tensors in place: d = d .- a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference + ) { + + TensorSub(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Multiplies two tensors and stores in the destination 
tensor: d = a .* b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::multiplies + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Multiplies tensors in place: d = d .* a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorMul(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorDiv(d, d, a); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorDiv(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_fill.h b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_fill.h new file mode 100644 index 0000000000000..b9f0c84d9a2a9 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_fill.h @@ -0,0 +1,1718 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include +#include +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/distribution.h" +#include "tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element value; + + // + // Methods + // + + TensorFillFunc( + TensorView const &view_ = TensorView(), + Element value_ = Element(0) + ): view(view_), value(value_) { } + + void operator()(Coord const & coord) const { + view.at(coord) = value; + } +}; + +/// Returns a pair of values of the Gaussian distribution generated by the Box Muller method +struct BoxMullerFunc { + + BoxMullerFunc() {} + + void operator()( + double* rnd, ///< Size-2 vector to be filled with random values + double mean = 0, ///< Mean of the Gaussian distribution + double stddev = 1, ///< Standard deviation of the Gaussian distribution + double pi = std::acos(-1)) const { + + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + rnd[0] = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd[1] = std::sqrt(-2 * std::log(u1)) * std::sin(2 * pi * u2); + rnd[0] = mean + 
stddev * rnd[0]; + rnd[1] = mean + stddev * rnd[1]; + } +}; +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorView dst, ///< destination tensor + Element val = Element(0)) { ///< value to uniformly fill it with + + detail::TensorFillFunc func(dst, val); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorViewPlanarComplex dst, ///< destination tensor + cutlass::complex val = cutlass::complex(0)) { ///< value to uniformly fill it with + + TensorFill(dst.view_real(), val.real()); + TensorFill(dst.view_imag(), val.imag()); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + } + else { + result = static_cast(0); + } + + // Note that exclude_zero = true will disable the bernoulli_result above by unsetting zeros + if (exclude_zero && result == Element(0)) { + if (rnd > 0) { + rnd += 1; + } else { + rnd -= 1; + } + result = Element(rnd); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. 
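//
// Worked sketch of the Box-Muller step used above (illustration only, not part of the
// original files): two uniform samples are mapped to one N(mean, stddev) sample. This
// mirrors the arithmetic in BoxMullerFunc / RandomGaussianFunc using only the standard
// library.
//
//   #include <cmath>
//   #include <cstdlib>
//
//   double box_muller_sample(double mean, double stddev) {
//     double pi = std::acos(-1);
//     double u1 = double(std::rand()) / double(RAND_MAX);
//     double u2 = double(std::rand()) / double(RAND_MAX);
//     // sqrt(-2 ln u1) * cos(2 pi u2) is standard normal; scale and shift it.
//     return mean + stddev * std::sqrt(-2.0 * std::log(u1)) * std::cos(2.0 * pi * u2);
//   }
//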
+template +struct RandomGaussianFunc > { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + double rnd[2]; + detail::BoxMullerFunc func; + func(rnd, mean, stddev, pi); + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd[0] = double(int(rnd[0] * double(1 << int_scale))); + rnd[1] = double(int(rnd[1] * double(1 << int_scale))); + reals[0] = from_real(rnd[0] / double(1 << int_scale)); + reals[1] = from_real(rnd[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd[0]); + reals[1] = from_real(rnd[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); + } + + // Note that this will invalidate the above else statement because it unsets zero elements + if (exclude_zero && + reals[0] == from_real(0.0) && + reals[1] == from_real(0.0)) { + + if (rnd[0] > 0.0) { + rnd[0] += 1.0; + } else { + rnd[0] -= 1.0; + } + reals[0] = from_real(rnd[0]); + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a complex value. 
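//
// Note on the int_scale parameter used by these RandomGaussianFunc specializations
// (illustration, not from the original files): when int_scale >= 0, each random value is
// snapped to a multiple of 2^-int_scale so that reference and device results can match
// exactly in low-precision error testing. For example, with int_scale = 2:
//
//   double rnd = 0.6180;
//   int int_scale = 2;
//   rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale);
//   // rnd == 0.5, i.e. rounded to the nearest multiple of 0.25
//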
+template +struct RandomGaussianFunc > { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + double rnd1[2]; + double rnd2[2]; + detail::BoxMullerFunc func; + func(rnd1, mean, stddev, pi); + func(rnd2, mean, stddev, pi); + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd1[0] = double(int(rnd1[0] * double(1 << int_scale))); + rnd1[1] = double(int(rnd1[1] * double(1 << int_scale))); + rnd2[0] = double(int(rnd2[0] * double(1 << int_scale))); + rnd2[1] = double(int(rnd2[1] * double(1 << int_scale))); + + reals[0] = from_real(rnd1[0] / double(1 << int_scale)); + reals[1] = from_real(rnd1[1] / double(1 << int_scale)); + reals[2] = from_real(rnd2[0] / double(1 << int_scale)); + reals[3] = from_real(rnd2[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd1[0]); + reals[1] = from_real(rnd1[1]); + reals[2] = from_real(rnd2[0]); + reals[3] = from_real(rnd2[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); + reals[2] = from_real(0); + reals[3] = from_real(0); + } + + // Note that this will invalidate the above else statement because it unsets zero elements + if (exclude_zero && + reals[0] == from_real(0) && + reals[1] == from_real(0) && + reals[2] == from_real(0) && + reals[3] == from_real(0)) { + + if (rnd1[0] > 0.0) { + rnd1[0] += 1.0; + } else { + rnd1[0] -= 1.0; + } + reals[0] = from_real(rnd1[0]); + } + + return Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillGaussianFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomGaussianFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillGaussianFunc( + TensorView view_ = TensorView(), + RandomGaussianFunc func_ = RandomGaussianFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + view.at(coord) = func(); + } +}; + +/// Computes a random Gaussian distribution for a rank-2 tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillSymmetricGaussianFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomGaussianFunc func; + cutlass::FillMode fill_mode; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ TensorFillSymmetricGaussianFunc( + TensorView view_ = TensorView(), + RandomGaussianFunc func_ = RandomGaussianFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid + ): + view(view_), func(func_), fill_mode(fill_mode_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kLower && + coord[0] >= coord[1]) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + coord[0] <= coord[1]) { + view.at(coord) = func(); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of + /// data. + bool exclude_zero = false) { ///< Exclude zeros from tensor init. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz, exclude_zero); + + detail::TensorFillGaussianFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of + /// data. + bool exclude_zero = false) { ///< Exclude zeros from tensor init. + + TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits, pnz); + TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits, pnz); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomGaussian( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. 
+ + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + detail::TensorFillSymmetricGaussianFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. + + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) + { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + Element operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + if (exclude_zero && result == Element(0)) { + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. 
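//
// Illustrative sketch (not part of the original files): the Block* variants above operate
// on raw buffers rather than TensorView objects, which is convenient when the data is
// later reinterpreted by the kernel under test. The buffer below is an assumption.
//
//   #include <vector>
//
//   std::vector<cutlass::half_t> block(1024);
//   cutlass::reference::host::BlockFillRandomGaussian(
//       block.data(), block.size(), /*seed=*/2024ull, /*mean=*/0.0, /*stddev=*/1.0, /*bits=*/-1);
//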
+template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + complex operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + + if (exclude_zero && + i == 0 && + reals[0] == from_real(0.0)) { + + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + reals[0] = from_real(Real(rnd)); + } + + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_), + bernoulli_rnd{static_cast(seed_)}, + bernoulli_dist(pnan_) + { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() { + + // Sample from NaN distribution. 
+ if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + + view.at(coord) = func(); + } +}; + +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillSymmetricRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillSymmetricRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid + ): + view(view_), func(func_), fill_mode(fill_mode_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kLower && + coord[0] >= coord[1]) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + coord[0] <= coord[1]) { + view.at(coord) = func(); + } + } +}; + +/// Computes a random Uniform distribution and pads diagonal with zeros +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPadDiagonalRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + int alignment; + + // + // Methods + // + + /// Construction of uniform RNG functor. 
+ TensorFillPadDiagonalRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid, + int alignment_ = 1 + ): + view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + (fill_mode == cutlass::FillMode::kLower) && + (coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= alignment)) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + (coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= alignment)) { + view.at(coord) = func(); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + detail::RandomUniformFunc random_func(seed, max, min, bits, pnan, exclude_zero); + + detail::TensorFillRandomUniformFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values of a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + + TensorFillRandomUniform(dst.view_real(), seed, max, min, bits, pnan, exclude_zero); + TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits, pnan, exclude_zero); +} + + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView, Layout> dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + detail::RandomUniformFunc> random_func(seed, max, min, bits); + + detail::TensorFillRandomUniformFunc, Layout> func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. 
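//
// Illustrative sketch (not part of the original files): for integer operands such as int8
// GEMM inputs, TensorFillRandomUniform above is usually given an explicit [min, max] range
// and bits = 0 so that every generated value is an exact integer. The tensor below is an
// assumption.
//
//   cutlass::HostTensor<int8_t, cutlass::layout::RowMajor> a(cutlass::MatrixCoord(128, 256));
//
//   cutlass::reference::host::TensorFillRandomUniform(
//       a.host_view(), /*seed=*/2024ull, /*max=*/2, /*min=*/-2, /*bits=*/0);
//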
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillSymmetricRandomUniformFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPadDiagonalRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + int alignment = 1 +) { + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillPadDiagonalRandomUniformFunc func( + dst, + random_func, + fill_mode, + alignment + ); + + TensorForEach( + dst.extent(), + func + ); +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element ///< Element type +> +void BlockFill( + Element *ptr, + size_t capacity, + Element val + ) { + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = val; + } +} + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0) { ///< Percentage of NaN elements. 
+ detail::RandomUniformFunc random_func(seed, max, min, bits, pnan); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + // + // Methods + // + + TensorFillDiagonalFunc( + TensorView const &view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + view.at(coord) = (is_diag ? diag : other); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView dst, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0)) { ///< value to write off the diagonal + + detail::TensorFillDiagonalFunc func( + dst, + diag, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to fill a tensor's diagonal with 1 and 0 everywhere else. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView dst) { ///< destination tensor + + TensorFillDiagonal(dst, Element(1), Element(0)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. 
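//
// Illustrative sketch (not part of the original files): the diagonal helpers above are
// handy for constructing known-answer inputs, e.g. multiplying by an identity matrix to
// sanity-check an epilogue. The tensor below is an assumption.
//
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> eye(cutlass::MatrixCoord(64, 64));
//
//   cutlass::reference::host::TensorFillIdentity(eye.host_view());   // 1 on the diagonal, 0 elsewhere
//   cutlass::reference::host::TensorFillDiagonal(eye.host_view(),
//                                                /*diag=*/2.0f, /*other=*/0.0f);
//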
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView dst, ///< destination tensor + Element val = Element(1)) { + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = val; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element other; + + // + // Methods + // + + TensorUpdateOffDiagonalFunc( + TensorView const &view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + view.at(coord) = other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView dst, ///< destination tensor + Element other = Element(1)) { + + detail::TensorUpdateOffDiagonalFunc func( + dst, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + // + // Methods + // + + TensorFillLinearFunc() { } + + /// Constructs functor + TensorFillLinearFunc( + TensorView const &view_, + Array const & v_, + Element s_ = Element(0) + ): + view(view_), v(v_), s(s_) { } + + /// Updates the tensor + void operator()(Coord const & coord) const { + + Element sum(s); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + sum += Element(coord[i]) * v[i]; + } + + view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView dst, ///< destination tensor + Array const & v, + Element s = Element(0)) { + + detail::TensorFillLinearFunc func( + dst, + v, + s + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSequential( + TensorView dst, ///< 
destination tensor + Element s = Element(0)) { + + Array stride; + + stride[0] = Element(1); + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]); + } + + TensorFillLinear(dst, stride, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + bool exclude_zero = false ///< If true, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. +) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz, + exclude_zero); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan, + exclude_zero); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = s; + + s = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = Element(s); + + s = int64_t(s + v) % mod; + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t seed, + Distribution dist) { + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomSparseMetaFunc { + + uint64_t seed; + int range; + int MetaSizeInBits; + + // + // Methods + // + + RandomSparseMetaFunc( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), MetaSizeInBits(MetaSizeInBits_) { + 
std::srand((unsigned)seed); + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + + /// Compute random value and update RNG state + Element operator()() const { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + int rnd = std::rand() % range; + Element meta = MetaArray[rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random sparse meta +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomSparseMetaFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillRandomSparseMetaFunc( + TensorView view_ = TensorView(), + RandomSparseMetaFunc func_ = RandomSparseMetaFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + + view.at(coord) = func(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4 bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + detail::TensorFillRandomSparseMetaFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a ell block index matrix with random values with a uniform random distribution. 
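// Note on the metadata produced by RandomSparseMetaFunc / BlockFillRandomSparseMeta
// above (an informal illustration, not part of the CUTLASS sources): with
// MetaSizeInBits == 2, every 4-bit nibble packs two 2-bit indices naming which 2 of 4
// elements of a group survive 2:4 pruning. The six legal nibbles are exactly the six
// ways of choosing 2 positions out of 4:
//
//   0x4 = 0b0100 -> keep positions 0,1      0xc = 0b1100 -> keep positions 0,3
//   0x8 = 0b1000 -> keep positions 0,2      0xd = 0b1101 -> keep positions 1,3
//   0x9 = 0b1001 -> keep positions 1,2      0xe = 0b1110 -> keep positions 2,3
//
// A 16-bit metadata element therefore covers four consecutive groups of four values,
// which is why operator() above emits sizeof_bits<Element>::value / 4 nibbles.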
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomEllIdx( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int rows, int ell_cols, int cols) { ///< dimension of the matrix + + std::srand((unsigned)seed); + + for (int i = 0; i < rows; ++i) { + int col_idx = std::rand() % cols; + + for (int j = 0; j < ell_cols; ++j) { + dst.at({i, j}) = col_idx; + + if (col_idx != -1) { + if (col_idx == (cols - 1)) { + col_idx = -1; + } else { + col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1; + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView dst, ///< destination tensor + Element const *ptr) { ///< dense buffer of elements + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = ReferenceFactory::get(ptr, i); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView src) { ///< source tensor + + typename Layout::Index extent = src.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + ReferenceFactory::get(ptr, i) = src.at(coord); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_fill.hpp b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_fill.hpp new file mode 100644 index 0000000000000..86a54e2ee06b7 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_fill.hpp @@ -0,0 +1,432 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Uniform and procedural tensor fills +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a scalar element +template +void TensorFill(Tensor dst, typename Tensor::value_type element) { + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = element; + } +} + +/// Fills a tensor with the contents of its layout +template +void TensorFillSequential(Tensor dst) { + + auto layout = dst.layout(); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = layout(idx); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random uniform values +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Element operator()() const { + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. 
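// A minimal usage sketch for the CuTe-based fill overloads above (illustrative only;
// assumes a host buffer wrapped with cute::make_tensor, which builds a compact
// column-major layout from a plain shape):
//
//   std::vector<float> buf(64 * 64);
//   auto view = cute::make_tensor(buf.data(), cute::make_shape(64, 64));
//   cutlass::reference::host::TensorFill(view, 0.5f);      // every element becomes 0.5f
//   cutlass::reference::host::TensorFillSequential(view);  // each element is set to the offset its layout assigns to that linear index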
+template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template ///< Tensor object +void TensorFillRandomUniform( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
+ detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random Gaussian +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1 + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Tensor +> +void TensorFillRandomGaussian( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
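  // (Both Gaussian helpers draw samples with the Box-Muller transform implemented in
  //  RandomGaussianFunc above: z = sqrt(-2 * ln(u1)) * cos(2 * pi * u2) with u1, u2
  //  uniform in [0, 1], and the emitted value is mean + stddev * z.)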
+ + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = static_cast(int32_t(int64_t(s + v) % mod)); + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_foreach.h b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_foreach.h new file mode 100644 index 0000000000000..43ff17362c21b --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_foreach.h @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - RankRemaining - 1; + + /// Constructor for general rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + TensorForEachHelper(func, extent, coord); + } + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - 1; + + /// Constructor for fastest changing rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEach(Coord extent, Func & func) { + Coord coord; + detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor and calls a C++ lambda +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEachLambda(Coord extent, Func func) { + Coord coord; + detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. + BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params()) { + + Func func(params); + + for (size_t index = 0; index < capacity; ++index) { + ptr[index] = func(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_norm.h b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_norm.h new file mode 100644 index 0000000000000..8a7240665550d --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_norm.h @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" + +// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions. + +#include "cutlass/util/reference/host/tensor_reduce.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_reduce.h b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_reduce.h new file mode 100644 index 0000000000000..048352ae29514 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_reduce.h @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/reference/detail/linear_to_coordinate.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < int64_t(view.size()); ++idx) { + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + Element x = view.at(coord); + identity = reduce(identity, transform(x)); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Tensor extents must match."); + } + + for (int64_t idx = 0; idx < int64_t(view_A.size()); ++idx) { + + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + Element a = view_A.at(coord); + Element b = view_B.at(coord); + identity = reduce(identity, transform(a, b)); + } + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. 
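// A typical verification pattern built from these reduction helpers and the
// TensorNorm / TensorNormDiff functions defined next (an illustrative sketch, not
// part of the CUTLASS sources; `reference` and `computed` are assumed HostTensor
// objects holding the reference and kernel outputs):
//
//   double ref_norm  = cutlass::reference::host::TensorNorm(reference.host_view());
//   double diff_norm = cutlass::reference::host::TensorNormDiff(reference.host_view(),
//                                                               computed.host_view());
//   bool passed = diff_norm <= 1e-5 * ref_norm;   // relative-error check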
+template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/tensor_reduce.hpp b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_reduce.hpp new file mode 100644 index 0000000000000..5ea5154107fcb --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/tensor_reduce.hpp @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tensor reductions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Tensor, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + Tensor view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < cute::size(view); ++idx) { + identity = reduce(identity, transform(view(idx))); + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename TensorA, + typename TensorB, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorA view_A, + TensorB view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (cute::size(view_A) != cute::size(view_B)) { + throw std::runtime_error("Tensor sizes must match."); + } + + for (int64_t idx = 0; idx < cute::size(view_A); ++idx) { + identity = reduce(identity, transform(view_A(idx), view_B(idx))); + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSum( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSumSq( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. 
+template < + typename Tensor, + typename ComputeType = double +> +ComputeType TensorNorm( + Tensor view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/trmm.h b/csrc/quantization/cutlass_test/example/util/reference/host/trmm.h new file mode 100644 index 0000000000000..08b979254278c --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/trmm.h @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for TRMM in host-side code. 
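   Informally: for SideMode::kLeft, compute_trmm below evaluates D = alpha * A * B with
   A an M-by-M triangular matrix (lower or upper per FillModeA); for SideMode::kRight it
   evaluates D = alpha * B * A with A N-by-N. DiagType::kUnit treats A's diagonal as all
   ones regardless of the stored values.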
+ + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? 
+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Trmm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Trmm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/reference/host/trmm_complex.h b/csrc/quantization/cutlass_test/example/util/reference/host/trmm_complex.h new file mode 100644 index 0000000000000..86e58a035b481 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/reference/host/trmm_complex.h @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued TRMM in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? 
+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + // Conjugate, and hence hermitian, is only allowed for the triangular matrix + if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct TrmmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/example/util/tensor_view_io.h b/csrc/quantization/cutlass_test/example/util/tensor_view_io.h new file mode 100644 index 0000000000000..4f6bdd686b8f0 --- /dev/null +++ b/csrc/quantization/cutlass_test/example/util/tensor_view_io.h @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ +#pragma once + +#include "cutlass/core_io.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/complex.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteLeastSignificantRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + out << ScalarIO(view.at(coord)); + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by "\n" + if (idx) { + out << ",\n"; + } + TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << ",\n\n"; + } + TensorView_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < 
view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + + complex x = view.at(coord); + out << x; + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by ";\n" + if (idx) { + out << ";\n"; + } + TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << "\n"; + } + TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorView const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorView_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorView const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. 
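+  // Illustrative usage (comment-only sketch, not part of the upstream header):
+  // given a tensor whose host_view() yields a TensorView or a
+  // TensorViewPlanarComplex, either form below prints it with the conventions
+  // described above. The tensor name and element/layout types are assumptions
+  // made for this example.
+  //
+  //   std::cout << tensor.host_view() << "\n";                  // operator<<
+  //   cutlass::TensorViewWrite(std::cout, tensor.host_view());  // explicit call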
+ + return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/exceptions.h b/csrc/quantization/cutlass_test/exceptions.h new file mode 100644 index 0000000000000..54c62fdbb6f5d --- /dev/null +++ b/csrc/quantization/cutlass_test/exceptions.h @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ exception semantics for CUDA error codes + */ + +#include +#include +#include + +#include "cutlass/platform/platform.h" + +namespace cutlass { + +/// C++ exception wrapper for CUDA \p cudaError_t +class cuda_exception : public std::exception { + public: + /// Constructor + cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {} + + /// Returns the underlying CUDA \p cudaError_t + cudaError_t cudaError() const { return err; } + + protected: + /// Explanatory string + const char* msg; + + /// Underlying CUDA \p cudaError_t + cudaError_t err; +}; + +/// Writes a cuda_exception instance to an output stream +inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) { + return out << e.what() << ": " << cudaGetErrorString(e.cudaError()); +} + +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/helper.h b/csrc/quantization/cutlass_test/helper.h new file mode 100644 index 0000000000000..f333fab9cac53 --- /dev/null +++ b/csrc/quantization/cutlass_test/helper.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cuda_runtime.h" +#include + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream + */ +struct GpuTimer +{ + cudaStream_t _stream_id; + cudaEvent_t _start; + cudaEvent_t _stop; + + /// Constructor + GpuTimer() : _stream_id(0) + { + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); + } + + /// Destructor + ~GpuTimer() + { + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) + { + _stream_id = stream_id; + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); + } + + /// Stop the timer + void stop() + { + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() + { + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; + } +}; diff --git a/csrc/quantization/cutlass_test/host_tensor.h b/csrc/quantization/cutlass_test/host_tensor.h new file mode 100644 index 0000000000000..3f061875b48dc --- /dev/null +++ b/csrc/quantization/cutlass_test/host_tensor.h @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensor { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRef; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorView; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + +private: + using StorageUnit = typename platform::conditional_t, uint8_t, // Avoid the std::vector specialization + typename platform::conditional_t::value % 8 == 0, // Handle subbyte types + Element, uint8_t>>; + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; + static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits; + static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements; + static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes; + static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit; + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + /// number of containers + size_t count_to_container_storage_unit_count(size_t count) { + return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit; + } + +public: + // + // Device and Host Methods + // + + 
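+  // Typical lifecycle (illustrative sketch only; the element type, layout, and
+  // extent below are placeholders chosen for the example, not requirements of
+  // this class):
+  //
+  //   cutlass::HostTensor<float, cutlass::layout::RowMajor> A({rows, cols});
+  //   A.at({0, 0}) = 1.0f;   // element-wise writes land in host memory
+  //   A.sync_device();       // explicit host -> device copy before a kernel
+  //   /* ... launch a kernel that consumes A.device_ref() ... */
+  //   A.sync_host();         // explicit device -> host copy to read results
+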
/// Default constructor + HostTensor() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensor( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensor( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensor() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")"); +#endif + + device_.reset(); + host_.clear(); + + size_t count_container = count_to_container_storage_unit_count(count); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")"); +#endif + host_.resize(count_container); + + // Allocate memory + StorageUnit* device_memory = nullptr; + if (device_backed_) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")"); +#endif + device_memory = device_memory::allocate(count_container); + } + device_.reset(device_memory, device_backed_ ? count_container : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_))); + + if (static_cast(new_size_container) > host_.size()) { + reserve(new_size, device_backed_); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. 
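+  // Descriptive note: unlike reset(), which always releases and reallocates
+  // both host and device storage, the resize() overloads only call reserve()
+  // when the new packed capacity exceeds the storage already held, so
+  // shrinking or same-size changes reuse the existing allocations.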
+ void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the logical number of elements stored in the host tensor + size_t size() const { + return layout_.capacity(extent_); + } + + /// Returns the logical capacity in terms of number of elements. May be larger than the size(). + LongIndex capacity() const { + return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements; + } + + /// Gets pointer to host data + Element * host_data() { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to device data + Element * device_data() { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data + Element const * device_data() const { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + 
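+  // Descriptive note: the *_ref() accessors return TensorRef objects (pointer
+  // plus layout, no extent), while the *_view() accessors additionally carry
+  // the logical extent, which is what the printing helpers in
+  // tensor_view_io.h expect.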
ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? false : true; + } + + + /// Returns the layout object + Layout & layout() { + return layout_; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_.data(), device_.get(), device_.size()); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_.get(), host_.data(), host_.size()); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + host_.data(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + device_.get(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
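+    // Descriptive note: the logical element count is first clamped to
+    // capacity() and then converted to a count of StorageUnit containers, so
+    // the raw copy below moves the correct number of bytes even for sub-byte
+    // element types.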
+ + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + device_.get(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + host_.data(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + reinterpret_cast(ptr_host), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + reinterpret_cast(ptr_device), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_device( + Element * ptr_device, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + reinterpret_cast(ptr_device), host_.data(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + reinterpret_cast(ptr_host), host_.data(), container_count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/packed_stride.hpp b/csrc/quantization/cutlass_test/packed_stride.hpp new file mode 100644 index 0000000000000..e9a243a1322cc --- /dev/null +++ b/csrc/quantization/cutlass_test/packed_stride.hpp @@ -0,0 +1,570 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params. +*/ + +#pragma once + +#include "cute/layout.hpp" +#include "cute/container/array.hpp" // cute::array +#include "cutlass/conv/convolution.h" // cutlass::conv::Operator + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides without batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride> +make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, int64_t> +make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, int64_t> +make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with group mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, StrideIntT, cute::Int<0>> +make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides for convolutions + +// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0) +// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order +// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout +// right in KTRSC order and can be coalesced to just k. +// We enforce this condition here with asserts. +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, cute::Int<0>> s, + cute::array shape_output, + cute::array stride_output, + cutlass::conv::Operator conv_op) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + static_assert(RankT_ >= 3u); + constexpr static int RankT = static_cast(RankT_); + + assert(stride_output[RankT-1] == 1); + cute::for_each(cute::make_seq{}, [&](auto i) { + assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]); + }); + + auto s_copy = s; + cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ? + stride_output[0] : + stride_output[RankT-2]; + return s_copy; +} + +// +// Activation tensor ((w, h, d, n), _1) for fprop kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + assert(stride_nwc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_nwc[1]; + cute::get<0,1>(s_copy) = stride_nwc[0]; + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nhwc[3] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_nhwc[2-i]; + }); + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_ndhwc[3-i]; + }); + return s_copy; +} + +// +// Filter tensor (k, (_1, s, r, t)) for fprop kernel +// + +// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>> +make_cute_packed_stride( + cute::Stride, IntT>> s, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>> s, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>> s, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + +// +// Activation tensor (_1, (w, h, d, n)) for wgrad kernel +// +// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad +// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nwc[2] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::get<1,0>(s_copy) = stride_nwc[1]; + cute::get<1,1>(s_copy) = stride_nwc[0]; + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nwc in dgrad is ksc. + cute::get<1,0>(s_copy) = stride_nwc[0]; + cute::get<1,1>(s_copy) = stride_nwc[1]; + } + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad +// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nhwc[3] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_nhwc[2-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nhwc in dgrad is krsc. + cute::get<1,0>(s_copy) = stride_nhwc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_nhwc[i+1]; + }); + } + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_ndhwc[3-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_ndhwc in dgrad is ktrsc. + cute::get<1,0>(s_copy) = stride_ndhwc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1]; + }); + } + return s_copy; +} + +// +// NZPQ tensor (_1, nzpq) for wgrad kernel +// + +// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_nqk[2] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nqk[1]; + return s_copy; +} + +// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_npqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_npqk[3] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_npqk[2]; + return s_copy; +} + +// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nzpqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nzpqk[4] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nzpqk[3]; + return s_copy; +} + + + +// +// Wgrad output tensor (k, (_1, s, r, t), _0) +// + +// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + + +// +// Wgrad output tensor ((_1, s, r, t), k, _0) +// + +// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ksc[0]; + cute::get<0,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<0,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/csrc/quantization/cutlass_test/test_mm_c3x.cu b/csrc/quantization/cutlass_test/test_mm_c3x.cu new file mode 100644 index 0000000000000..b544e01a2913a --- /dev/null +++ b/csrc/quantization/cutlass_test/test_mm_c3x.cu @@ -0,0 +1,205 @@ +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + +#include + +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "broadcast_load_epilogue_c3x.hpp" +#include "common.hpp" +// clang-format on + +#include "common_gemm.cuh" + +template typename Epilogue, + typename... 
EpilogueArgs> +void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& e, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(e.dtype() == torch::kUInt8); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm90_fp8_config_default::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_fp8_config_M128::Cutlass3xGemm; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + // m in [1, 64] + return cutlass_test_gemm_caller( + out, a, e, b, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_test_gemm_caller( + out, a, e, b, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_test_gemm_caller( + out, a, e, b, std::forward(args)...); + } +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& e, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(e.dtype() == torch::kUInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + + using Cutlass3xGemmDefault = + typename sm90_int8_config_default::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_int8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_int8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM32NBig = + typename sm90_int8_config_M32_NBig::Cutlass3xGemm; + using Cutlass3xGemmM32NSmall = + typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; + + uint32_t const n = out.size(1); + bool const is_small_n = n < 8192; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(32), next_pow_2(m)); // next power of 2 + + if (mp2 <= 32) { + // m in [1, 32] + if (is_small_n) { + return cutlass_test_gemm_caller( + out, a, e, b, std::forward(args)...); + } else { + return cutlass_test_gemm_caller( + out, a, e, b, std::forward(args)...); + } + } else if (mp2 <= 64) { + // m in (32, 64] + return cutlass_test_gemm_caller( + out, a, e, b, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_test_gemm_caller( + out, a, e, b, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_test_gemm_caller( + out, a, e, b, std::forward(args)...); + } +} + +template