diff --git a/lib/nnc/cmd/pad/mps/ccv_nnc_pad_mps.m b/lib/nnc/cmd/pad/mps/ccv_nnc_pad_mps.m index 11abcd198..22d03f8df 100644 --- a/lib/nnc/cmd/pad/mps/ccv_nnc_pad_mps.m +++ b/lib/nnc/cmd/pad/mps/ccv_nnc_pad_mps.m @@ -73,7 +73,6 @@ static int _ccv_nnc_pad_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int* const end = cmd.info.pad.end; const int g_nd = ccv_nnc_tensor_nd(g->info.dim); assert(g_nd == ccv_nnc_tensor_nd(a->info.dim)); - const int type = cmd.info.pad.type; MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(a_key, indices, ^void (MPSGraph* graph, NSMutableArray* inputTensors, NSMutableArray* inputShapedTypes, NSMutableArray* resultTensors) { MPSGraphTensor* mps_input_g; MPSGraphTensor* mps_g = ccv_nnc_mps_graph_tensor_input(graph, g, g->info.dim, g->stride, &mps_input_g); diff --git a/lib/nnc/mfa/ccv_nnc_mfa.cpp b/lib/nnc/mfa/ccv_nnc_mfa.cpp index d14828c2a..13ae4b798 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa.cpp @@ -108,12 +108,6 @@ void mfa::cache::prepare(mfa::context* con _mfa_cache_prepare(&map, context, hash); } -template <> -void mfa::cache::prepare(mfa::context* context, mfa::cmul::hash hash) -{ - _mfa_cache_prepare(&map, context, hash); -} - template <> void mfa::cache::prepare(mfa::context* context, mfa::gemv::hash hash) { diff --git a/lib/nnc/mfa/ccv_nnc_mfa.hpp b/lib/nnc/mfa/ccv_nnc_mfa.hpp index 1ba6f733a..61dfafbf8 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa.hpp +++ b/lib/nnc/mfa/ccv_nnc_mfa.hpp @@ -52,7 +52,6 @@ class context { cache normalization_cache; cache depalettize_cache; cache adam_cache; - cache cmul_cache; cache gemv_cache; cache cast_cache; cache add_cache; diff --git a/lib/nnc/mfa/ccv_nnc_mfa_cmul.cpp b/lib/nnc/mfa/ccv_nnc_mfa_cmul.cpp index c3891aeaf..171a5e20a 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_cmul.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_cmul.cpp @@ -1,5 +1,7 @@ #include "ccv_nnc_mfa.hpp" #include "ccv_nnc_mfa_hash.hpp" +#include "v2/CMulDescriptor.hpp" +#include "v2/CMulKernel.hpp" #include using namespace ccv::nnc; @@ -9,30 +11,57 @@ using namespace ccv::nnc; void ccv_nnc_mfa_prepare_cmul(mfa::context* context, ccv_nnc_mfa_cmul_params_t params) { - context->cmul_cache.prepare(context, mfa::cmul::hash(params)); + // Do nothing now. } void ccv_nnc_mfa_encode_cmul(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_cmul_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets) { - mfa::cmul::hash hash(params); - auto iterator = context->cmul_cache.map.find(hash); - if (iterator == context->cmul_cache.map.end()) { - mfa::precondition_failure("cmul hash not cached.", __LINE__, __FILE__, __FUNCTION__); - } - - auto* pipeline = iterator->second; auto encoder = command_batch->startCommand(); - + int num_tensors = 0; while (tensors[num_tensors] != nullptr) { encoder->setBuffer(tensors[num_tensors], tensor_offsets[num_tensors], NS::UInteger(num_tensors)); num_tensors += 1; } CCV_NNC_MFA_PRECONDITION(num_tensors == 3); - - encoder->setComputePipelineState(pipeline->cmul_pso.get()); - if (tensors[0] == tensors[2]) - { + + CMulDescriptor descriptor; + descriptor.memoryPrecision = (params.data_type == MTL::DataTypeFloat) ? GEMMOperandPrecision::FP32 : GEMMOperandPrecision::FP16; + descriptor.stridesA[0] = params.astride[0]; + descriptor.stridesA[1] = params.astride[1]; + descriptor.stridesA[2] = params.astride[2]; + descriptor.stridesB[0] = params.bstride[0]; + descriptor.stridesB[1] = params.bstride[1]; + descriptor.stridesB[2] = params.bstride[2]; + descriptor.stridesC[0] = params.cstride[0]; + descriptor.stridesC[1] = params.cstride[1]; + descriptor.stridesC[2] = params.cstride[2]; + + descriptor.dimensions[0] = params.dim[0]; + descriptor.dimensions[1] = params.dim[1]; + descriptor.dimensions[2] = params.dim[2]; + + if (params.dim[3] == 0 && params.dim[2] == 0 && params.dim[1] == 0) { + descriptor.value = 0; + } else if (params.dim[3] == 0 && params.dim[2] == 0) { + descriptor.value = 1; + } else if (params.dim[3] == 0) { + descriptor.value = 2; + } else { + descriptor.value = 3; + } + + auto pool = NS::AutoreleasePool::alloc()->init(); + auto &shaderCache = context->v2_cache; + DeviceProperties dprops = DeviceProperties(); + auto pipelineValue = shaderCache.findKernel(descriptor, context->device.get(), dprops); + pool->drain(); + auto kernel = pipelineValue->kernel; + auto pipeline = pipelineValue->pipeline; + + encoder->setComputePipelineState(pipeline.get()); + + if (tensors[0] == tensors[2]) { encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite); encoder->useResource(tensors[1], MTL::ResourceUsageRead); } else if (tensors[1] == tensors[2]) { @@ -44,288 +73,20 @@ void ccv_nnc_mfa_encode_cmul(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_cmul_pa encoder->useResource(tensors[2], MTL::ResourceUsageWrite); } - auto grid_size = pipeline->grid_size; - CCV_NNC_MFA_PRECONDITION(grid_size.depth > 0); - encoder->dispatchThreadgroups(grid_size, pipeline->group_size); - command_batch->finishCommand(encoder); -} - -// MARK: - C++ - -mfa::cmul::hash::hash(ccv_nnc_mfa_cmul_params_t params) { - data_type = params.data_type; - memcpy(astride, params.astride, sizeof(params.astride)); - memcpy(bstride, params.bstride, sizeof(params.bstride)); - memcpy(cstride, params.cstride, sizeof(params.cstride)); - memcpy(dim, params.dim, sizeof(params.dim)); -} - -bool mfa::cmul::hash::operator==(const mfa::cmul::hash& hash) const { - return - (data_type == hash.data_type) && - (astride[0] == hash.astride[0]) && - (astride[1] == hash.astride[1]) && - (astride[2] == hash.astride[2]) && - (bstride[0] == hash.bstride[0]) && - (bstride[1] == hash.bstride[1]) && - (bstride[2] == hash.bstride[2]) && - (cstride[0] == hash.cstride[0]) && - (cstride[1] == hash.cstride[1]) && - (cstride[2] == hash.cstride[2]) && - (dim[0] == hash.dim[0]) && - (dim[1] == hash.dim[1]) && - (dim[2] == hash.dim[2]) && - (dim[3] == hash.dim[3]); -} - -std::ostream& operator<<(std::ostream& os, const mfa::cmul::hash& hash) { - os << "mfa::cmul::hash {"; - os << " .data_type = " << hash.data_type << ','; - os << " .astride[0] = " << hash.astride[0] << ','; - os << " .astride[1] = " << hash.astride[1] << ','; - os << " .astride[2] = " << hash.astride[2] << ','; - os << " .bstride[0] = " << hash.bstride[0] << ','; - os << " .bstride[1] = " << hash.bstride[1] << ','; - os << " .bstride[2] = " << hash.bstride[2] << ','; - os << " .cstride[0] = " << hash.cstride[0] << ','; - os << " .cstride[1] = " << hash.cstride[1] << ','; - os << " .cstride[2] = " << hash.cstride[2] << ','; - os << " .dim[0] = " << hash.dim[0] << ','; - os << " .dim[1] = " << hash.dim[1] << ','; - os << " .dim[2] = " << hash.dim[2] << ','; - os << " .dim[3] = " << hash.dim[3] << " "; - os << "}"; - return os; -} - -std::size_t std::hash::operator()(const mfa::cmul::hash& hash) const noexcept { - std::size_t seed = 0; - using namespace mfa::hash; - combine_64(seed, hash.data_type); - combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.astride[0], (unsigned int)hash.astride[1] })); - combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.astride[2], (unsigned int)hash.bstride[0] })); - combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.bstride[1], (unsigned int)hash.bstride[2] })); - combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.cstride[0], (unsigned int)hash.cstride[1] })); - combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.cstride[2], (unsigned int)hash.dim[0] })); - combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.dim[1], (unsigned int)hash.dim[2] })); - combine_32(seed, (unsigned int)hash.dim[3]); - return seed; -} - -mfa::cmul::pipeline::pipeline(mfa::context* context, mfa::cmul::hash hash) { - // FlashNorm not supported for group cmul yet. - CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf)) - - auto* pool = NS::AutoreleasePool::alloc()->init(); - - std::string shader; - if (hash.dim[3] == 0 && hash.dim[2] == 0 && hash.dim[1] == 0) { - shader = R"( -#include -using namespace metal; - -kernel void cmul( - device real *src0 [[buffer(0)]], - device real *src1 [[buffer(1)]], - device real *destination [[buffer(2)]], - - uint3 tpig [[thread_position_in_grid]] -) { - const uint idx = tpig.x; - if (idx >= dim0) - return; - const float a0 = (float)src0[idx * 2]; - const float a1 = (float)src0[idx * 2 + 1]; - const float b0 = (float)src1[idx * 2]; - const float b1 = (float)src1[idx * 2 + 1]; - destination[idx * 2] = (real)(a0 * b0 - a1 * b1); - destination[idx * 2 + 1] = (real)(a0 * b1 + a1 * b0); -} - )"; - } else if (hash.dim[3] == 0 && hash.dim[2] == 0) { - shader = R"( -#include -using namespace metal; - -kernel void cmul( - device real *src0 [[buffer(0)]], - device real *src1 [[buffer(1)]], - device real *destination [[buffer(2)]], - - uint3 tpig [[thread_position_in_grid]] -) { - const uint x = tpig.x; - const uint y = tpig.y; - if (y >= dim1 || x >= dim0) - return; - const uint ida = y * astride0 + x * 2; - const uint idb = y * bstride0 + x * 2; - const uint idc = y * cstride0 + x * 2; - const float a0 = (float)src0[ida]; - const float a1 = (float)src0[ida + 1]; - const float b0 = (float)src1[idb]; - const float b1 = (float)src1[idb + 1]; - destination[idc] = (real)(a0 * b0 - a1 * b1); - destination[idc + 1] = (real)(a0 * b1 + a1 * b0); -} - )"; - } else if (hash.dim[3] == 0) { - shader = R"( -#include -using namespace metal; - -kernel void cmul( - device real *src0 [[buffer(0)]], - device real *src1 [[buffer(1)]], - device real *destination [[buffer(2)]], - - uint3 tpig [[thread_position_in_grid]] -) { - const uint x = tpig.x; - const uint y = tpig.y; - const uint z = tpig.z; - if (y >= dim1 || x >= dim0) - return; - const uint ida = z * astride1 + y * astride0 + x * 2; - const uint idb = z * bstride1 + y * bstride0 + x * 2; - const uint idc = z * cstride1 + y * cstride0 + x * 2; - const float a0 = (float)src0[ida]; - const float a1 = (float)src0[ida + 1]; - const float b0 = (float)src1[idb]; - const float b1 = (float)src1[idb + 1]; - destination[idc] = (real)(a0 * b0 - a1 * b1); - destination[idc + 1] = (real)(a0 * b1 + a1 * b0); -} - )"; + MTL::Size gridSize; + if (params.dim[3] == 0 && params.dim[2] == 0 && params.dim[1] == 0) { + const int num_blocks = (params.dim[0] / 2 + 255) / 256; + gridSize = MTL::Size(num_blocks, 1, 1); + } else if (params.dim[3] == 0 && params.dim[2] == 0) { + gridSize = MTL::Size((params.dim[0] / 2 + 31) / 32, (params.dim[1] + 7) / 8, 1); + } else if (params.dim[3] == 0) { + gridSize = MTL::Size((params.dim[0] / 2 + 31) / 32, (params.dim[1] + 7) / 8, params.dim[2]); } else { - shader = R"( -#include -using namespace metal; - -kernel void cmul( - device real *src0 [[buffer(0)]], - device real *src1 [[buffer(1)]], - device real *destination [[buffer(2)]], - - uint3 tpig [[thread_position_in_grid]] -) { - const uint x = tpig.x; - const uint y = tpig.y; - const uint z = tpig.z; - if (y >= dim1 || x >= dim0) - return; - const int u = z % dim2; - const int v = z / dim2; - const uint ida = v * astride2 + u * astride1 + y * astride0 + x * 2; - const uint idb = v * bstride2 + u * bstride1 + y * bstride0 + x * 2; - const uint idc = v * cstride2 + u * cstride1 + y * cstride0 + x * 2; - const float a0 = (float)src0[ida]; - const float a1 = (float)src0[ida + 1]; - const float b0 = (float)src1[idb]; - const float b1 = (float)src1[idb + 1]; - destination[idc] = (real)(a0 * b0 - a1 * b1); - destination[idc + 1] = (real)(a0 * b1 + a1 * b0); -} - )"; + gridSize = MTL::Size((params.dim[0] / 2 + 31) / 32, (params.dim[1] + 7) / 8, params.dim[2] * params.dim[3]); } + CCV_NNC_MFA_PRECONDITION(gridSize.depth > 0); + encoder->dispatchThreadgroups(gridSize, kernel->threadgroupSize); - std::string defines = ""; - if (hash.data_type == MTL::DataTypeFloat) { - defines += std::string("typedef float real;"); - defines += "\n"; - } else { - defines += std::string("typedef half real;"); - defines += "\n"; - } - - defines += "constant uint dim0 = "; - defines += std::to_string(hash.dim[0] / 2) + ";"; - defines += "\n"; - if (hash.dim[1] > 0) - { - defines += "constant uint dim1 = "; - defines += std::to_string(hash.dim[1]) + ";"; - defines += "\n"; - defines += "constant uint astride0 = "; - defines += std::to_string(hash.astride[0]) + ";"; - defines += "\n"; - defines += "constant uint bstride0 = "; - defines += std::to_string(hash.bstride[0]) + ";"; - defines += "\n"; - defines += "constant uint cstride0 = "; - defines += std::to_string(hash.cstride[0]) + ";"; - defines += "\n"; - } - if (hash.dim[2] > 0) - { - defines += "constant uint astride1 = "; - defines += std::to_string(hash.astride[1]) + ";"; - defines += "\n"; - defines += "constant uint bstride1 = "; - defines += std::to_string(hash.bstride[1]) + ";"; - defines += "\n"; - defines += "constant uint cstride1 = "; - defines += std::to_string(hash.cstride[1]) + ";"; - defines += "\n"; - } - if (hash.dim[3] > 0 && hash.dim[2] > 0) - { - defines += "constant uint dim2 = "; - defines += std::to_string(hash.dim[2]) + ";"; - defines += "\n"; - defines += "constant uint astride2 = "; - defines += std::to_string(hash.astride[2]) + ";"; - defines += "\n"; - defines += "constant uint bstride2 = "; - defines += std::to_string(hash.bstride[2]) + ";"; - defines += "\n"; - defines += "constant uint cstride2 = "; - defines += std::to_string(hash.cstride[2]) + ";"; - defines += "\n"; - } - if (hash.dim[3] == 0 && hash.dim[2] == 0 && hash.dim[1] == 0) - { - this->group_size = MTL::Size(256, 1, 1); - const int num_blocks = (hash.dim[0] / 2 + 255) / 256; - this->grid_size = MTL::Size(num_blocks, 1, 1); - } else if (hash.dim[3] == 0 && hash.dim[2] == 0) { - this->group_size = MTL::Size(32, 8, 1); - this->grid_size = MTL::Size((hash.dim[0] / 2 + 31) / 32, (hash.dim[1] + 7) / 8, 1); - } else if (hash.dim[3] == 0) { - this->group_size = MTL::Size(32, 8, 1); - this->grid_size = MTL::Size((hash.dim[0] / 2 + 31) / 32, (hash.dim[1] + 7) / 8, hash.dim[2]); - } else { - this->group_size = MTL::Size(32, 8, 1); - this->grid_size = MTL::Size((hash.dim[0] / 2 + 31) / 32, (hash.dim[1] + 7) / 8, hash.dim[2] * hash.dim[3]); - } - - auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init()); - NS::SharedPtr* pso = &cmul_pso; - - std::string source = defines; - if (METAL_LOG_LEVEL(context) >= 4) { - std::cerr << source << std::endl; - } - source += shader; - - NS::Error *error = nullptr; - auto swift_source = NS::String::string(source.c_str(), - NS::UTF8StringEncoding); - auto library = NS::TransferPtr(context->device->newLibrary(swift_source, nullptr, &error)); - if (!library) { - CCV_NNC_MFA_CHECK_ERROR(error) - } - - auto swift_name = NS::String::string("cmul", NS::UTF8StringEncoding); - auto function = NS::TransferPtr(library->newFunction(swift_name, constants.get(), &error)); - if (!function) { - CCV_NNC_MFA_CHECK_ERROR(error) - } - - *pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error)); - if (!*pso) { - CCV_NNC_MFA_CHECK_ERROR(error) - } - - pool->drain(); + command_batch->finishCommand(encoder); } + diff --git a/lib/nnc/mfa/ccv_nnc_mfa_cmul.hpp b/lib/nnc/mfa/ccv_nnc_mfa_cmul.hpp index 31bd93b09..f65810f29 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_cmul.hpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_cmul.hpp @@ -28,8 +28,6 @@ class hash { uint32_t dim[4]; hash(ccv_nnc_mfa_cmul_params_t); - - bool operator==(const hash& rhs) const; }; class pipeline { @@ -47,14 +45,6 @@ class pipeline { } // namespace nnc } // namespace ccv -std::ostream& operator<<(std::ostream& os, const ccv::nnc::mfa::cmul::hash& hash); - -template<> -struct std::hash -{ - std::size_t operator()(const ccv::nnc::mfa::cmul::hash& hash) const noexcept; -}; - extern "C" { #endif // __cplusplus diff --git a/lib/nnc/mfa/makefile b/lib/nnc/mfa/makefile index 51d123741..f28902119 100644 --- a/lib/nnc/mfa/makefile +++ b/lib/nnc/mfa/makefile @@ -2,7 +2,7 @@ include ../../config.mk CFLAGS := -std=c++17 -O3 -Wall -I"../../" $(CFLAGS) -SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp v2/CodeWriter.cpp v2/GEMMDescriptor.cpp v2/GEMMKernelDescriptor.cpp v2/GEMMHeaders.cpp v2/GEMMKernel.cpp v2/AttentionDescriptor.cpp v2/AttentionKernelDescriptor.cpp v2/AttentionKernel.cpp +SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp v2/CodeWriter.cpp v2/GEMMDescriptor.cpp v2/GEMMKernelDescriptor.cpp v2/GEMMHeaders.cpp v2/GEMMKernel.cpp v2/AttentionDescriptor.cpp v2/AttentionKernelDescriptor.cpp v2/AttentionKernel.cpp v2/CMulDescriptor.cpp v2/CMulKernel.cpp SRC_OBJS := $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(SRCS))) diff --git a/lib/nnc/mfa/v2/CMulDescriptor.cpp b/lib/nnc/mfa/v2/CMulDescriptor.cpp new file mode 100644 index 000000000..f54fcd61d --- /dev/null +++ b/lib/nnc/mfa/v2/CMulDescriptor.cpp @@ -0,0 +1,108 @@ +#include "CMulDescriptor.hpp" +#include "CMulKernel.hpp" +#include "../ccv_nnc_mfa_hash.hpp" +#include "../ccv_nnc_mfa_error.hpp" + +bool CMulDescriptor::operator==(const CMulDescriptor& rhs) const { + return + memoryPrecision == rhs.memoryPrecision && + value == rhs.value && + simd_all(stridesA == rhs.stridesA) && + simd_all(stridesB == rhs.stridesB) && + simd_all(stridesC == rhs.stridesC) && + simd_all(dimensions == rhs.dimensions); +} + +std::size_t std::hash::operator()(const CMulDescriptor& hash) const noexcept { + using namespace ccv::nnc::mfa::hash; + std::size_t seed = 0; + combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.memoryPrecision.value, (unsigned int)hash.value })); + combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.stridesA[0], (unsigned int)hash.stridesA[1] })); + combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.stridesA[2], (unsigned int)hash.stridesB[0] })); + combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.stridesB[1], (unsigned int)hash.stridesB[2] })); + combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.stridesC[0], (unsigned int)hash.stridesC[1] })); + combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.stridesC[2], (unsigned int)hash.dimensions[0] })); + combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.dimensions[1], (unsigned int)hash.dimensions[2] })); + return seed; +} + +std::pair *> CMulDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map> *const libraryCache) const noexcept { + // The caller is not responsible for calling 'delete' on this pointer. The + // reference is saved in the 'libraryCache'. It will be deallocated whenever + // the shader cache itself is cleaned up. + auto createKernel = + [=](CMulKernelDescriptor descriptor) -> CMulKernel* { + auto iterator = libraryCache->find(descriptor); + if (iterator != libraryCache->end()) { + return iterator->second.get(); + } else { + CMulKernel* kernel = new CMulKernel(descriptor, device); + (*libraryCache)[descriptor] = std::unique_ptr(kernel); + return kernel; + } + }; + + CMulKernelDescriptor kernelDesc; + kernelDesc.memoryPrecision = memoryPrecision; + kernelDesc.value = value; + + // WARNING: The owner must explicitly retain the compute pipeline. + auto createPipeline = + [=](MTL::Library* library) -> MTL::ComputePipelineState* { + // Set the function constants. + auto constants = NS::TransferPtr + (MTL::FunctionConstantValues::alloc()->init()); + uint32_t dim0 = dimensions[0] / 2; + constants->setConstantValue(&dim0, MTL::DataTypeUInt, NS::UInteger(0)); + + if (value != 0) { + uint32_t dim1 = dimensions[1]; + uint32_t astride0 = stridesA[0]; + uint32_t bstride0 = stridesB[0]; + uint32_t cstride0 = stridesC[0]; + constants->setConstantValue(&dim1, MTL::DataTypeUInt, 1); + constants->setConstantValue(&astride0, MTL::DataTypeUInt, 2); + constants->setConstantValue(&bstride0, MTL::DataTypeUInt, 3); + constants->setConstantValue(&cstride0, MTL::DataTypeUInt, 4); + } + + if (value != 0 && value != 1) { + uint32_t astride1 = stridesA[1]; + uint32_t bstride1 = stridesB[1]; + uint32_t cstride1 = stridesC[1]; + constants->setConstantValue(&astride1, MTL::DataTypeUInt, 5); + constants->setConstantValue(&bstride1, MTL::DataTypeUInt, 6); + constants->setConstantValue(&cstride1, MTL::DataTypeUInt, 7); + } + + if (value != 0 && value != 1 && value != 2) { + uint32_t dim2 = dimensions[2]; + uint32_t astride2 = stridesA[2]; + uint32_t bstride2 = stridesB[2]; + uint32_t cstride2 = stridesC[2]; + constants->setConstantValue(&dim2, MTL::DataTypeUInt, 8); + constants->setConstantValue(&astride2, MTL::DataTypeUInt, 9); + constants->setConstantValue(&bstride2, MTL::DataTypeUInt, 10); + constants->setConstantValue(&cstride2, MTL::DataTypeUInt, 11); + } + + NS::String* swiftName = NS::String::string("cmul", NS::UTF8StringEncoding); + NS::Error* error = nil; + + auto function = NS::TransferPtr + (library->newFunction(swiftName, constants.get(), &error)); + CCV_NNC_MFA_CHECK_ERROR(error); + + auto pipeline = device->newComputePipelineState(function.get(), &error); + CCV_NNC_MFA_CHECK_ERROR(error); + return pipeline; + }; + CMulKernel* kernel = createKernel(kernelDesc); + auto pipeline = NS::TransferPtr(createPipeline(kernel->library.get())); + + // Force the user to retrieve the return value from the cache. We ensure + // the cache takes ownership, and the pointer doesn't become a zombie + // object. + PipelineValue* output = new PipelineValue { kernel, pipeline }; + return std::make_pair(kernelDesc, output); +} diff --git a/lib/nnc/mfa/v2/CMulDescriptor.hpp b/lib/nnc/mfa/v2/CMulDescriptor.hpp new file mode 100644 index 000000000..e82937c59 --- /dev/null +++ b/lib/nnc/mfa/v2/CMulDescriptor.hpp @@ -0,0 +1,49 @@ +#ifndef MFA_CMULDESCRIPTOR_HPP_ +#define MFA_CMULDESCRIPTOR_HPP_ + +#include +#include +#include "PipelineValue.hpp" +#include "DeviceProperties.hpp" +#include "GEMMOperandPrecision.hpp" + +struct CMulKernelDescriptor { + GEMMOperandPrecision memoryPrecision; + unsigned int value; + constexpr bool operator==(const CMulKernelDescriptor &rhs) const { return value == rhs.value && memoryPrecision == rhs.memoryPrecision; } +}; + +template<> +struct std::hash +{ + std::size_t operator()(const CMulKernelDescriptor& hash) const noexcept { return (size_t)hash.value; } +}; + +struct CMulKernel; + +struct CMulDescriptor { + unsigned int value; + + GEMMOperandPrecision memoryPrecision; + + simd::uint3 stridesA; + + simd::uint3 stridesB; + + simd::uint3 stridesC; + + simd::uint3 dimensions; + + bool operator==(const CMulDescriptor& rhs) const; + + std::pair *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, std::unordered_map> *const libraryCache) const noexcept; +}; + +template<> +struct std::hash +{ + std::size_t operator()(const CMulDescriptor& hash) const noexcept; +}; + +#endif + diff --git a/lib/nnc/mfa/v2/CMulKernel.cpp b/lib/nnc/mfa/v2/CMulKernel.cpp new file mode 100644 index 000000000..b84b6e2fd --- /dev/null +++ b/lib/nnc/mfa/v2/CMulKernel.cpp @@ -0,0 +1,201 @@ +#include "CMulKernel.hpp" +#include "../ccv_nnc_mfa.hpp" + +#include + +CMulKernel::CMulKernel(CMulKernelDescriptor descriptor, MTL::Device *const device) { + + memoryPrecision = descriptor.memoryPrecision; + + value = descriptor.value; + + source = createSource(); + + threadgroupMemoryAllocation = createThreadgroupMemoryAllocation(); + + if (value == 0) { + threadgroupSize = MTL::Size(256, 1, 1); + } else if (value == 1) { + threadgroupSize = MTL::Size(32, 8, 1); + } else if (value == 2) { + threadgroupSize = MTL::Size(32, 8, 1); + } else { + threadgroupSize = MTL::Size(32, 8, 1); + } + + // Compile the shader source. + { + auto string = NS::String::string(source.c_str(), NS::UTF8StringEncoding); + NS::Error* error = nil; + library = NS::TransferPtr(device->newLibrary(string, nil, &error)); + CCV_NNC_MFA_CHECK_ERROR(error); + } +} + +unsigned short CMulKernel::createThreadgroupMemoryAllocation() const noexcept { + return 0; +} + +std::string CMulKernel::createSource() const noexcept { + std::string shader = createConstants() + "\n"; + if (value == 0) { + shader += R"( +#include +using namespace metal; + +kernel void cmul( + device real *src0 [[buffer(0)]], + device real *src1 [[buffer(1)]], + device real *destination [[buffer(2)]], + + uint3 tpig [[thread_position_in_grid]] +) { + const uint idx = tpig.x; + if (idx >= dim0) + return; + const float a0 = (float)src0[idx * 2]; + const float a1 = (float)src0[idx * 2 + 1]; + const float b0 = (float)src1[idx * 2]; + const float b1 = (float)src1[idx * 2 + 1]; + destination[idx * 2] = (real)(a0 * b0 - a1 * b1); + destination[idx * 2 + 1] = (real)(a0 * b1 + a1 * b0); +} + )"; + } else if (value == 1) { + shader += R"( +#include +using namespace metal; + +kernel void cmul( + device real *src0 [[buffer(0)]], + device real *src1 [[buffer(1)]], + device real *destination [[buffer(2)]], + + uint3 tpig [[thread_position_in_grid]] +) { + const uint x = tpig.x; + const uint y = tpig.y; + if (y >= dim1 || x >= dim0) + return; + const uint ida = y * astride0 + x * 2; + const uint idb = y * bstride0 + x * 2; + const uint idc = y * cstride0 + x * 2; + const float a0 = (float)src0[ida]; + const float a1 = (float)src0[ida + 1]; + const float b0 = (float)src1[idb]; + const float b1 = (float)src1[idb + 1]; + destination[idc] = (real)(a0 * b0 - a1 * b1); + destination[idc + 1] = (real)(a0 * b1 + a1 * b0); +} + )"; + } else if (value == 2) { + shader += R"( +#include +using namespace metal; + +kernel void cmul( + device real *src0 [[buffer(0)]], + device real *src1 [[buffer(1)]], + device real *destination [[buffer(2)]], + + uint3 tpig [[thread_position_in_grid]] +) { + const uint x = tpig.x; + const uint y = tpig.y; + const uint z = tpig.z; + if (y >= dim1 || x >= dim0) + return; + const uint ida = z * astride1 + y * astride0 + x * 2; + const uint idb = z * bstride1 + y * bstride0 + x * 2; + const uint idc = z * cstride1 + y * cstride0 + x * 2; + const float a0 = (float)src0[ida]; + const float a1 = (float)src0[ida + 1]; + const float b0 = (float)src1[idb]; + const float b1 = (float)src1[idb + 1]; + destination[idc] = (real)(a0 * b0 - a1 * b1); + destination[idc + 1] = (real)(a0 * b1 + a1 * b0); +} + )"; + } else { + shader += R"( +#include +using namespace metal; + +kernel void cmul( + device real *src0 [[buffer(0)]], + device real *src1 [[buffer(1)]], + device real *destination [[buffer(2)]], + + uint3 tpig [[thread_position_in_grid]] +) { + const uint x = tpig.x; + const uint y = tpig.y; + const uint z = tpig.z; + if (y >= dim1 || x >= dim0) + return; + const int u = z % dim2; + const int v = z / dim2; + const uint ida = v * astride2 + u * astride1 + y * astride0 + x * 2; + const uint idb = v * bstride2 + u * bstride1 + y * bstride0 + x * 2; + const uint idc = v * cstride2 + u * cstride1 + y * cstride0 + x * 2; + const float a0 = (float)src0[ida]; + const float a1 = (float)src0[ida + 1]; + const float b0 = (float)src1[idb]; + const float b1 = (float)src1[idb + 1]; + destination[idc] = (real)(a0 * b0 - a1 * b1); + destination[idc + 1] = (real)(a0 * b1 + a1 * b0); +} + )"; + } + return shader; +} + +std::string CMulKernel::createConstants() const noexcept { + + std::string defines = ""; + if (memoryPrecision == GEMMOperandPrecision::FP32) { + defines += std::string("typedef float real;"); + defines += "\n"; + } else if (memoryPrecision == GEMMOperandPrecision::BF16) { + defines += std::string("typedef bfloat real;"); + defines += "\n"; + } else { + defines += std::string("typedef half real;"); + defines += "\n"; + } + + defines += "constant uint dim0 [[function_constant(0)]];"; + defines += "\n"; + if (value != 0) + { + defines += "constant uint dim1 [[function_constant(1)]];"; + defines += "\n"; + defines += "constant uint astride0 [[function_constant(2)]];"; + defines += "\n"; + defines += "constant uint bstride0 [[function_constant(3)]];"; + defines += "\n"; + defines += "constant uint cstride0 [[function_constant(4)]];"; + defines += "\n"; + } + if (value != 0 && value != 1) + { + defines += "constant uint astride1 [[function_constant(5)]];"; + defines += "\n"; + defines += "constant uint bstride1 [[function_constant(6)]];"; + defines += "\n"; + defines += "constant uint cstride1 [[function_constant(7)]];"; + defines += "\n"; + } + if (value != 0 && value != 1 && value != 2) + { + defines += "constant uint dim2 [[function_constant(8)]];"; + defines += "\n"; + defines += "constant uint astride2 [[function_constant(9)]];"; + defines += "\n"; + defines += "constant uint bstride2 [[function_constant(10)]];"; + defines += "\n"; + defines += "constant uint cstride2 [[function_constant(11)]];"; + defines += "\n"; + } + return defines; +} diff --git a/lib/nnc/mfa/v2/CMulKernel.hpp b/lib/nnc/mfa/v2/CMulKernel.hpp new file mode 100644 index 000000000..9ce015117 --- /dev/null +++ b/lib/nnc/mfa/v2/CMulKernel.hpp @@ -0,0 +1,31 @@ +#ifndef CMulKernel_hpp +#define CMulKernel_hpp + +#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp" +#include +#include "CMulDescriptor.hpp" + +struct CMulKernel { + NS::SharedPtr library; + + std::string source; + + unsigned short threadgroupMemoryAllocation; + + /// The number of threads per group. + MTL::Size threadgroupSize; + + GEMMOperandPrecision memoryPrecision; + + unsigned int value; + + CMulKernel(CMulKernelDescriptor descriptor, MTL::Device *const device); + +private: + unsigned short createThreadgroupMemoryAllocation() const noexcept; + std::string createSource() const noexcept; + std::string createConstants() const noexcept; +}; + +#endif /* CMulKernel_hpp */ +