diff --git a/operatorspy/tests/add.py b/operatorspy/tests/add.py index 455014cc..da9c58c9 100644 --- a/operatorspy/tests/add.py +++ b/operatorspy/tests/add.py @@ -115,6 +115,16 @@ def test_bang(lib, test_cases): test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace) destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for c_shape, a_shape, b_shape, inplace in test_cases: + test(lib, handle, "musa", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace) + test(lib, handle, "musa", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace) + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ @@ -163,6 +173,8 @@ def test_bang(lib, test_cases): test_cuda(lib, test_cases) if args.bang: test_bang(lib, test_cases) - if not (args.cpu or args.cuda or args.bang): + if args.musa: + test_musa(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/causal_softmax.py b/operatorspy/tests/causal_softmax.py index 623c0fac..b7cabc4a 100644 --- a/operatorspy/tests/causal_softmax.py +++ b/operatorspy/tests/causal_softmax.py @@ -119,6 +119,16 @@ def test_maca(lib, test_cases): destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + device = DeviceEnum.DEVICE_MUSA + + handle = create_handle(lib, device) + for x_shape, x_stride in test_cases: + test(lib, handle, "musa", x_shape, x_stride) + + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ # x_shape, x_stride @@ -161,6 +171,8 @@ def test_maca(lib, test_cases): test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) - if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): + if args.musa: + test_musa(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/expand.py b/operatorspy/tests/expand.py index e060ad73..87365c05 100644 --- a/operatorspy/tests/expand.py +++ b/operatorspy/tests/expand.py @@ -133,6 +133,16 @@ def test_bang(lib, test_cases): test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for y_shape, x_shape, y_stride, x_stride in test_cases: + test(lib, handle, "musa", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16) + test(lib, handle, "musa", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ @@ -174,6 +184,8 @@ def test_bang(lib, test_cases): test_cuda(lib, test_cases) if args.bang: test_bang(lib, test_cases) - if not (args.cpu or args.cuda or args.bang): + if args.musa: + test_musa(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/matmul.py b/operatorspy/tests/matmul.py index ba590447..31076fb5 100644 --- a/operatorspy/tests/matmul.py +++ b/operatorspy/tests/matmul.py @@ -325,6 +325,37 @@ def test_maca(lib, test_cases): destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + + 
device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for (
+        alpha,
+        beta,
+        a_shape,
+        b_shape,
+        c_shape,
+        a_stride,
+        b_stride,
+        c_stride,
+        dtype,
+    ) in test_cases:
+        test(
+            lib,
+            handle,
+            "musa",
+            alpha,
+            beta,
+            a_shape,
+            b_shape,
+            c_shape,
+            a_stride,
+            b_stride,
+            c_stride,
+            dtype,
+        )
+
 if __name__ == "__main__":
     test_cases = [
         # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype
@@ -387,6 +418,8 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py
index 4b0c2a10..85a3c681 100644
--- a/operatorspy/tests/random_sample.py
+++ b/operatorspy/tests/random_sample.py
@@ -94,7 +94,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
     if(torch_device == 'maca'):
         indices = torch.zeros([1], dtype = torch.int64).to('cuda')
     else:
-        indices = torch.zeros([1], dtype = torch.uint64).to(torch_device)
+        indices = torch.zeros([1], dtype = torch.int64).to(torch_device)
     x_tensor = to_tensor(data, lib)
     indices_tensor = to_tensor(indices, lib)
     indices_tensor.descriptor.contents.dt = U64  # treat int64 as uint64
@@ -170,7 +170,7 @@ def test_ascend(lib, test_cases):
     for (voc, random_val, topp, topk, temperature) in test_cases:
         test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
     destroy_handle(lib, handle)
-
+
 def test_maca(lib, test_cases):
     device = DeviceEnum.DEVICE_MACA
     handle = create_handle(lib, device)
@@ -179,6 +179,13 @@ def test_maca(lib, test_cases):
     destroy_handle(lib, handle)
+
+def test_musa(lib, test_cases):
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for (voc, random_val, topp, topk, temperature) in test_cases:
+        test(lib, handle, "musa", voc, random_val, topp, topk, temperature)
+    destroy_handle(lib, handle)
 if __name__ == "__main__":
     test_cases = [
@@ -236,6 +243,8 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/rearrange.py b/operatorspy/tests/rearrange.py
index 124fe552..9709e6b3 100644
--- a/operatorspy/tests/rearrange.py
+++ b/operatorspy/tests/rearrange.py
@@ -117,6 +117,16 @@ def test_maca(lib, test_cases):
         test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
     destroy_handle(lib, handle)
+def test_musa(lib, test_cases):
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for test_case in test_cases:
+        x_shape, x_stride = test_case[0]
+        y_shape, y_stride = test_case[1]
+        test(lib, handle, "musa", x_shape, x_stride, y_shape, y_stride)
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     args = get_args()
     test_cases = [
@@ -156,4 +166,6 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
+    if args.musa:
+        test_musa(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/relu.py b/operatorspy/tests/relu.py
index b7f76627..b99706ff 100644
--- a/operatorspy/tests/relu.py
+++ b/operatorspy/tests/relu.py
@@ -132,6 +132,16 @@ def test_bang(lib, test_cases):
         test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
     destroy_handle(lib, handle)
+def test_musa(lib, test_cases):
+    import torch_musa
+
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for tensor_shape, inplace in test_cases:
+        test(lib, handle, "musa", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
+        test(lib, handle, "musa", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     test_cases = [
@@ -172,6 +182,8 @@ def test_bang(lib, test_cases):
         test_cuda(lib, test_cases)
     if args.bang:
         test_bang(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/rms_norm.py b/operatorspy/tests/rms_norm.py
index 8176af64..46b1d0f3 100644
--- a/operatorspy/tests/rms_norm.py
+++ b/operatorspy/tests/rms_norm.py
@@ -125,6 +125,14 @@ def test_maca(lib, test_cases):
     destroy_handle(lib, handle)
+def test_musa(lib, test_cases):
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
+        test(lib, handle, "musa", y_shape, x_shape, w_shape, dtype, w_dtype)
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     test_cases = [
         # y_shape, x_shape, w_shape, dtype, w_dtype
@@ -174,6 +182,8 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/rotary_embedding.py b/operatorspy/tests/rotary_embedding.py
index b7123052..1c1122a6 100644
--- a/operatorspy/tests/rotary_embedding.py
+++ b/operatorspy/tests/rotary_embedding.py
@@ -77,7 +77,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
         pos[2 * i] = posTmp[i]
         pos[2 * i + 1] = 0
     theta = 1e4
-    if torch_device == 'mlu' or torch_device == 'npu':
+    if torch_device == 'mlu' or torch_device == 'npu' or torch_device == 'musa':
         ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
         pos = pos.to(torch_device)
         t = t.to(torch_device)
@@ -181,6 +181,14 @@ def test_maca(lib, test_cases) :
         test(lib, handle, "maca", shape, strides, dtype)
     destroy_handle(lib, handle)
+def test_musa(lib, test_cases) :
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for shape, strides, dtype in test_cases:
+        test(lib, handle, "musa", shape, strides, dtype)
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     test_cases = [
         ((1, 32, 128), None, torch.float16),
@@ -233,6 +241,8 @@ def test_maca(lib, test_cases):
test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) - if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): + if args.musa: + test_musa(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/swiglu.py b/operatorspy/tests/swiglu.py index fcd044f1..9ca07c14 100644 --- a/operatorspy/tests/swiglu.py +++ b/operatorspy/tests/swiglu.py @@ -262,6 +262,20 @@ def test_maca(lib, test_cases): destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + + for shape, a_stride, b_stride, c_stride, dtype in test_cases: + test_out_of_place( + lib, handle, "musa", shape, a_stride, b_stride, c_stride, dtype + ) + test_in_place1(lib, handle, "musa", shape, a_stride, b_stride, dtype) + test_in_place2(lib, handle, "musa", shape, a_stride, b_stride, dtype) + + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ @@ -307,4 +321,6 @@ def test_maca(lib, test_cases): test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) + if args.musa: + test_musa(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/test_utils.py b/operatorspy/tests/test_utils.py index 68b71bc4..6e4960d5 100644 --- a/operatorspy/tests/test_utils.py +++ b/operatorspy/tests/test_utils.py @@ -32,6 +32,11 @@ def get_args(): action="store_true", help="Run ASCEND NPU test", ) + parser.add_argument( + "--musa", + action="store_true", + help="Run MUSA test", + ) return parser.parse_args() diff --git a/src/devices/handle.cc b/src/devices/handle.cc index 45779776..6b7f54a8 100644 --- a/src/devices/handle.cc +++ b/src/devices/handle.cc @@ -14,6 +14,9 @@ #ifdef ENABLE_METAX_GPU #include "./maca/maca_handle.h" #endif +#ifdef ENABLE_MTHREADS_GPU +#include "./musa/musa_handle.h" +#endif __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device device, int device_id) { @@ -48,6 +51,11 @@ __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device d case DevMetaxGpu: { return createMacaHandle((MacaHandle_t *) handle_ptr, device_id); } +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + return createMusaHandle((MusaHandle_t *) handle_ptr, device_id); + } #endif } return STATUS_BAD_DEVICE; @@ -81,6 +89,12 @@ __C infiniopStatus_t infiniopDestroyHandle(infiniopHandle_t handle) { case DevMetaxGpu: { return deleteMacaHandle((MacaHandle_t) handle); } +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + deleteMusaHandle((MusaHandle_t) handle); + return STATUS_SUCCESS; + } #endif } return STATUS_BAD_DEVICE; diff --git a/src/devices/musa/common_musa.h b/src/devices/musa/common_musa.h new file mode 100644 index 00000000..c42b5197 --- /dev/null +++ b/src/devices/musa/common_musa.h @@ -0,0 +1,77 @@ +#ifndef __COMMON_MUSA_H__ +#define __COMMON_MUSA_H__ + +#define MAX_THREADS_PER_BLOCK 1024 +#define MAX_WARP_PER_BLOCK 32 +#define WARP_SIZE 32 + +#include +#include "data_type.h" +#include +#include +#include + +enum class Type { + QINT4, + QINT8, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + HALF, + BFLOAT16, + FLOAT, + DOUBLE, + BOOL, +}; + +enum class Format { + UNKNOWN, + NCW, + NWC, + NCHW, + NHWC, + HWCN, + NCDHW, + NDHWC, + DHWCN, +}; + +#define checkMusaErrorWithCode(call, errorCode) \ + do { \ + if (auto status = call; status != musaSuccess) { \ + 
std::cerr << "MUSA error: " << musaGetErrorString(status) \
+                      << " in file " << __FILE__ \
+                      << ", function " << __func__ \
+                      << ", line " << __LINE__ << std::endl; \
+            return errorCode; \
+        } \
+    } while (0)
+
+#define checkMusaError(call) checkMusaErrorWithCode(call, STATUS_BAD_DEVICE)
+
+// get the corresponding offset in the destination given the flat index of the source (for element mapping in shape broadcast)
+inline __device__ uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) {
+    uint64_t res = 0;
+    for (uint64_t i = 0; i < ndim; ++i) {
+        res += flat_index / src_strides[i] * dst_strides[i];
+        flat_index %= src_strides[i];
+    }
+    return res;
+}
+
+// get the memory offset of the given element in a tensor given its flat index
+inline __device__ uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) {
+    uint64_t res = 0;
+    for (long i = ndim - 1; i >= 0; --i) {
+        res += (flat_index % shape[i]) * strides[i];
+        flat_index /= shape[i];
+    }
+    return res;
+}
+
+#endif // __COMMON_MUSA_H__
diff --git a/src/devices/musa/musa_handle.cc b/src/devices/musa/musa_handle.cc
new file mode 100644
index 00000000..3a7f8174
--- /dev/null
+++ b/src/devices/musa/musa_handle.cc
@@ -0,0 +1,57 @@
+#include "musa_handle.h"
+#include
+
+infiniopStatus_t createMusaHandle(MusaHandle_t* handle_ptr, int device_id) {
+    int device_count;
+    musaGetDeviceCount(&device_count);
+    if (device_id >= device_count) {
+        return STATUS_BAD_DEVICE;
+    }
+
+    int current_device;
+    if (musaGetDevice(&current_device) != musaSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    if (current_device != device_id && musaSetDevice(device_id) != musaSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+
+    // set MUSA device property
+    musaDeviceProp prop;
+    musaGetDeviceProperties(&prop, device_id);
+
+    // create a mublas handle pool
+    auto mublas_pool = std::make_shared<Pool<mublasHandle_t>>();
+    mublasHandle_t *mublas_handle = new mublasHandle_t;
+    mublasCreate(mublas_handle);
+    mublas_pool->push(mublas_handle);
+
+    // create a mudnn handle pool
+    auto mudnn_pool = std::make_shared<Pool<musa::dnn::Handle>>();
+    musa::dnn::Handle *mudnn_handle = new musa::dnn::Handle;
+    mudnn_pool->push(mudnn_handle);
+
+    int capability_major;
+    int capability_minor;
+    musaDeviceGetAttribute(&capability_major, musaDevAttrComputeCapabilityMajor, device_id);
+    musaDeviceGetAttribute(&capability_minor, musaDevAttrComputeCapabilityMinor, device_id);
+
+    *handle_ptr = new MusaContext{
+        DevMthreadsGpu,
+        device_id,
+        std::move(mublas_pool),
+        std::move(mudnn_pool),
+        std::move(prop),
+        capability_major,
+        capability_minor,};
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t deleteMusaHandle(MusaHandle_t handle_ptr) {
+    handle_ptr->mublas_handles_t = nullptr;
+    handle_ptr->mudnn_handles_t = nullptr;
+    delete handle_ptr;
+
+    return STATUS_SUCCESS;
+}
diff --git a/src/devices/musa/musa_handle.h b/src/devices/musa/musa_handle.h
new file mode 100644
index 00000000..6de2c2d3
--- /dev/null
+++ b/src/devices/musa/musa_handle.h
@@ -0,0 +1,64 @@
+#ifndef __MUSA_HANDLE_H__
+#define __MUSA_HANDLE_H__
+
+#include "pool.h"
+#include "device.h"
+#include "status.h"
+#include "ops/matmul/matmul.h"
+#include
+#include
+#include
+#include
+#include
+
+struct MusaContext {
+    Device device;
+    int device_id;
+    std::shared_ptr<Pool<mublasHandle_t>> mublas_handles_t;
+    std::shared_ptr<Pool<musa::dnn::Handle>> mudnn_handles_t;
+    musaDeviceProp prop;
+    int compute_capability_major;
+    int compute_capability_minor;
+};
+typedef struct MusaContext *MusaHandle_t;
+
+infiniopStatus_t createMusaHandle(MusaHandle_t *handle_ptr, int device_id);
+
+infiniopStatus_t deleteMusaHandle(MusaHandle_t handle_ptr);
+
+template <typename T>
+void use_mublas(std::shared_ptr<Pool<mublasHandle_t>> mublas_handles_t, int device_id, MUstream stream, T const &f) {
+    mublasHandle_t *handle = mublas_handles_t->pop();
+    if (!handle) {
+        int current_device;
+        musaGetDevice(&current_device);
+        if (current_device != device_id) {
+            musaSetDevice(device_id);
+        }
+        handle = new mublasHandle_t;
+        mublasCreate(handle);
+    }
+    mublasSetStream(*handle, (MUstream) stream);
+    f(*handle);
+    mublas_handles_t->push(handle);
+}
+
+template <typename T>
+void use_mudnn(std::shared_ptr<Pool<musa::dnn::Handle>> mudnn_handles_t, int device_id, musaStream_t stream, T const &f) {
+    musa::dnn::Handle* handle = mudnn_handles_t->pop();
+    if (!handle) {
+        int current_device;
+        musaGetDevice(&current_device);
+        if (current_device != device_id) {
+            musaSetDevice(device_id);
+        }
+        handle = new musa::dnn::Handle(device_id);
+        // mudnnCreate(handle);
+    }
+    // mudnnSetStream(*handle, (MUstream) stream);
+    handle->SetStream(stream);
+    f(handle);
+    mudnn_handles_t->push(handle);
+}
+
+#endif // __MUSA_HANDLE_H__
diff --git a/src/devices/musa/pool.h b/src/devices/musa/pool.h
new file mode 100644
index 00000000..2cfb5e32
--- /dev/null
+++ b/src/devices/musa/pool.h
@@ -0,0 +1,50 @@
+#ifndef __POOL_MUSA_H__
+#define __POOL_MUSA_H__
+
+#include
+#include
+#include
+
+template <typename T>
+class Pool {
+public:
+    Pool() : _head(nullptr) {}
+
+    Pool(const Pool &) = delete;
+
+    Pool(Pool &&pool) noexcept : _head(pool._head.exchange(nullptr)) {}
+
+    ~Pool() {
+        while (this->pop()) {}
+    }
+
+    void push(T *val) const {
+        Node<T> *new_node = new Node<T>(val);
+        new_node->next = _head.load();
+        while (!_head.compare_exchange_weak(new_node->next, new_node));
+    }
+
+    T* pop() const {
+        Node<T> *top = _head.load();
+        Node<T> *new_head = nullptr;
+        do {
+            if (!top) {
+                return nullptr;
+            }
+            new_head = top->next;
+        } while (!_head.compare_exchange_weak(top, new_head));
+        return top->data;
+    }
+
+private:
+    template <typename U>
+    struct Node {
+        U *data;
+        Node *next;
+        Node(U *data) : data(data), next(nullptr) {}
+    };
+
+    mutable std::atomic<Node<T> *> _head;
+};
+
+#endif // __POOL_MUSA_H__
diff --git a/src/ops/add/musa/add_musa.cc b/src/ops/add/musa/add_musa.cc
new file mode 100644
index 00000000..8c4475fe
--- /dev/null
+++ b/src/ops/add/musa/add_musa.cc
@@ -0,0 +1,81 @@
+#include "add_musa.h"
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+
+infiniopStatus_t musaCreateAddDescriptor(MusaHandle_t handle,
+                                         AddMusaDescriptor_t *desc_ptr,
+                                         infiniopTensorDescriptor_t c,
+                                         infiniopTensorDescriptor_t a,
+                                         infiniopTensorDescriptor_t b) {
+    uint64_t ndim = c->ndim;
+    if (!isValidBroadcastShape(a, b, c)) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    if (!is_contiguous(a) || !is_contiguous(b) || !is_contiguous(c)) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+    if (c->dt != F16 && c->dt != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (c->dt != a->dt || c->dt != b->dt) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    bool broadcasted = false;
+    if (ndim != a->ndim || ndim != b->ndim) {
+        broadcasted = true;
+    } else {
+        for (uint64_t i = 0; i < ndim; ++i) {
+            if (c->shape[i] != a->shape[i] || c->shape[i] != b->shape[i]) {
+                broadcasted = true;
+                break;
+            }
+        }
+    }
+
+    uint64_t c_data_size = std::accumulate(c->shape, c->shape + c->ndim, 1ULL, std::multiplies<uint64_t>());
+
+    // get the adjusted strides for a and b
+    int64_t *a_strides = new int64_t[ndim];
+    int64_t *b_strides = new int64_t[ndim];
+    for (size_t i = 0; i < ndim; ++i) {
+        a_strides[i] = (i
< ndim - a->ndim || c->shape[i] != a->shape[i + a->ndim - ndim]) ? 0 : a->strides[i + a->ndim - ndim]; + b_strides[i] = (i < ndim - b->ndim || c->shape[i] != b->shape[i + b->ndim - ndim]) ? 0 : b->strides[i + b->ndim - ndim]; + } + + musaDeviceProp prop; + musaGetDeviceProperties(&prop, handle->device_id); + + int64_t *a_strides_d, *b_strides_d, *c_strides_d; + checkMusaErrorWithCode(musaMalloc(&a_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkMusaErrorWithCode(musaMalloc(&b_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkMusaErrorWithCode(musaMalloc(&c_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkMusaErrorWithCode(musaMemcpy(a_strides_d, a_strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaMemcpy(b_strides_d, b_strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaMemcpy(c_strides_d, c->strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + + *desc_ptr = new AddMusaDescriptor{ + DevMthreadsGpu, + c->dt, + handle->device_id, + ndim, + c_data_size, + static_cast(prop.maxGridSize[0]), + a_strides_d, + b_strides_d, + c_strides_d, + broadcasted, + }; + + delete[] a_strides; + delete[] b_strides; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyAddDescriptor(AddMusaDescriptor_t desc) { + checkMusaErrorWithCode(musaFree((void *) desc->a_strides), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaFree((void *) desc->b_strides), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaFree((void *) desc->c_strides), STATUS_EXECUTION_FAILED); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/add/musa/add_musa.h b/src/ops/add/musa/add_musa.h new file mode 100644 index 00000000..c492c45c --- /dev/null +++ b/src/ops/add/musa/add_musa.h @@ -0,0 +1,37 @@ +#ifndef __MUSA_ADD_H__ +#define __MUSA_ADD_H__ + +#include "../../../devices/musa/common_musa.h" +#include "../../../devices/musa/musa_handle.h" +#include "operators.h" +#include +#include + +struct AddMusaDescriptor { + Device device; + DT dtype; + int device_id; + uint64_t ndim; + uint64_t c_data_size; + uint64_t max_grid_size; + int64_t const *a_strides; + int64_t const *b_strides; + int64_t const *c_strides; + bool broadcasted; +}; + +typedef struct AddMusaDescriptor *AddMusaDescriptor_t; + +infiniopStatus_t musaCreateAddDescriptor(MusaHandle_t, + AddMusaDescriptor_t *, + infiniopTensorDescriptor_t c, + infiniopTensorDescriptor_t a, + infiniopTensorDescriptor_t b); + +infiniopStatus_t musaAdd(AddMusaDescriptor_t desc, + void *c, void const *a, void const *b, + void *stream); + +infiniopStatus_t musaDestroyAddDescriptor(AddMusaDescriptor_t desc); + +#endif diff --git a/src/ops/add/musa/add_musa.mu b/src/ops/add/musa/add_musa.mu new file mode 100644 index 00000000..0766aa7c --- /dev/null +++ b/src/ops/add/musa/add_musa.mu @@ -0,0 +1,116 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "add_musa.h" + +/** + * @brief A templated vector struct that supports element-wise addition on arrays. + * + * @tparam T - The access data type for elements in the vector. + * @tparam TComp - The computation data type used for arithmetic operations. + * @tparam N - The number of elements of type T in the vector for a single access. 
+ */ +template +struct vecN { + T data[N]; + + __device__ __forceinline__ vecN operator+(const vecN &other) const { + vecN result; + + for (int i = 0; i < N; ++i) { + if constexpr (std::is_same::value) { + result.data[i] = data[i] + other.data[i]; + } else { + constexpr static size_t pack_size = sizeof(T) / sizeof(TComp); + auto data_ = reinterpret_cast *>(result.data); + data_[i] = std::move(reinterpret_cast const *>(data)[i] + + reinterpret_cast const *>(other.data)[i]); + } + } + + return result; + } + + __device__ __forceinline__ const T &operator[](size_t i) const { + return data[i]; + } +}; + +template +__global__ void add( + Tdata *c, + const Tdata *a, + const Tdata *b, + const int64_t *a_strides, + const int64_t *b_strides, + const int64_t *c_strides, + uint64_t data_size, + uint64_t ndim, + uint64_t offset, + bool broadcasted, + unsigned pack_size) { + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < data_size) { + if (broadcasted) { + idx *= pack_size; + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); +#pragma unroll + for (size_t i = 0; i < pack_size; ++i) { + auto a_idx = getDstOffset(idx + i, ndim, c_strides, a_strides); + auto b_idx = getDstOffset(idx + i, ndim, c_strides, b_strides); + c_[idx + i] = a_[a_idx] + b_[b_idx]; + } + return; + } + c[idx] = a[idx] + b[idx]; + } +} + +template +void _add_nv_gpu(AddMusaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const *b, uint64_t data_size, uint64_t pack_size, uint64_t offset, void *stream) { + if (data_size == 0) { + return; + } + dim3 blockDims = dim3(std::min(static_cast(256), data_size)); + dim3 gridDims = dim3(std::min(ROUND_UP_DIV(data_size, blockDims.x), desc->max_grid_size)); + uint64_t step = gridDims.x * blockDims.x; + + musaStream_t musa_stream = reinterpret_cast(stream); + +#pragma unroll + for (uint64_t i = 0; i < data_size; i += step) { + add<<>>( + c, a, b, desc->a_strides, desc->b_strides, desc->c_strides, offset + data_size, desc->ndim, offset + i, desc->broadcasted, pack_size); + } +} + +template +infiniopStatus_t add_mt_gpu(AddMusaDescriptor_t desc, void *c, void const *a, void const *b, void *stream, uint64_t pack_size) { + const auto data_size = desc->c_data_size / pack_size; + const auto a_vec = reinterpret_cast(a); + const auto b_vec = reinterpret_cast(b); + const auto c_vec = reinterpret_cast(c); + _add_nv_gpu(desc, c_vec, a_vec, b_vec, data_size, pack_size, 0, stream); + + const auto remainder = desc->c_data_size % pack_size; + const auto a_ = reinterpret_cast(a); + const auto b_ = reinterpret_cast(b); + const auto c_ = reinterpret_cast(c); + _add_nv_gpu(desc, c_, a_, b_, remainder, 1, data_size * pack_size, stream); + return STATUS_SUCCESS; +} + +infiniopStatus_t musaAdd(AddMusaDescriptor_t desc, + void *c, void const *a, void const *b, + void *stream) { + checkMusaError(musaSetDevice(desc->device_id)); + if (desc->dtype == F16) { + return add_mt_gpu, half>(desc, c, a, b, stream, 8); + } + if (desc->dtype == F32) { + return add_mt_gpu, float>(desc, c, a, b, stream, 4); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/add/operator.cc b/src/ops/add/operator.cc index c2a30ea8..de97dc94 100644 --- a/src/ops/add/operator.cc +++ b/src/ops/add/operator.cc @@ -9,6 +9,9 @@ #include "../../devices/cuda/cuda_handle.h" #include "cuda/add.cuh" #endif +#ifdef ENABLE_MTHREADS_GPU +#include "musa/add_musa.h" +#endif __C infiniopStatus_t infiniopCreateAddDescriptor( infiniopHandle_t handle, @@ -29,6 +32,11 @@ __C 
infiniopStatus_t infiniopCreateAddDescriptor( #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + return musaCreateAddDescriptor((MusaHandle_t) handle, (AddMusaDescriptor_t *) desc_ptr, c, a, b); + } #endif } return STATUS_BAD_DEVICE; @@ -48,6 +56,11 @@ __C infiniopStatus_t infiniopAdd(infiniopAddDescriptor_t desc, void *c, void con #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + return musaAdd((AddMusaDescriptor_t) desc, c, a, b, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -67,6 +80,11 @@ __C infiniopStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + return musaDestroyAddDescriptor((AddMusaDescriptor_t) desc); + } #endif } return STATUS_BAD_DEVICE; diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.cc b/src/ops/causal_softmax/musa/causal_softmax_musa.cc new file mode 100644 index 00000000..6ff55d65 --- /dev/null +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.cc @@ -0,0 +1,55 @@ +#include "causal_softmax_musa.h" +#include "../../utils.h" +#include "../../../devices/musa/common_musa.h" + +infiniopStatus_t musaCreateCausalSoftmaxDescriptor(MusaHandle_t handle, + CausalSoftmaxMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y) { + uint64_t ndim = y->ndim; + // TODO: only support 2d or 3d tensor + if (ndim != 2 && ndim != 3) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!dtype_eq(y->dt, F16)) { + return STATUS_BAD_TENSOR_DTYPE; + } + uint64_t total_seq_len = y->shape[ndim - 1]; + uint64_t seq_len = y->shape[ndim - 2]; + uint64_t batch_size = 1; + uint64_t stride_b = 0; + uint64_t stride_i = y->strides[ndim - 2]; + uint64_t stride_j = y->strides[ndim - 1]; + if (stride_j != 1) { + return STATUS_BAD_TENSOR_STRIDES; + } + for (uint64_t i = 0; i < ndim - 2; i++) { + batch_size *= y->shape[i]; + } + if (ndim == 3) + stride_b = y->strides[ndim - 3]; + unsigned int max_items_per_thread = ROUND_UP_DIV(total_seq_len, MAX_THREADS_PER_BLOCK); + + *desc_ptr = new CausalSoftmaxMusaDescriptor{ + handle->device, + handle->device_id, + y->dt, + batch_size, + stride_b, + seq_len, + stride_i, + total_seq_len, + stride_j, + max_items_per_thread}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMusaDescriptor_t desc, uint64_t *size) { + *size = 0; + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyCausalSoftmaxDescriptor(CausalSoftmaxMusaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.h b/src/ops/causal_softmax/musa/causal_softmax_musa.h new file mode 100644 index 00000000..c6f81afc --- /dev/null +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.h @@ -0,0 +1,35 @@ +#ifndef __MUSA_CAUSAL_SOFTMAX_H__ +#define __MUSA_CAUSAL_SOFTMAX_H__ + +#include "operators.h" +#include "../../../devices/musa/musa_handle.h" + +struct CausalSoftmaxMusaDescriptor { + Device device; + int device_id; + DT dtype; + uint64_t batch_size; + uint64_t stride_b; + uint64_t seq_len; + uint64_t stride_i; + uint64_t total_seq_len; + uint64_t stride_j; + uint64_t max_items_per_thread; +}; + +typedef struct CausalSoftmaxMusaDescriptor *CausalSoftmaxMusaDescriptor_t; + +infiniopStatus_t musaCreateCausalSoftmaxDescriptor(MusaHandle_t handle, + CausalSoftmaxMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc); + +infiniopStatus_t 
musaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMusaDescriptor_t desc, uint64_t *size); + +infiniopStatus_t musaCausalSoftmax(CausalSoftmaxMusaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *data, + void *stream); + +infiniopStatus_t musaDestroyCausalSoftmaxDescriptor(CausalSoftmaxMusaDescriptor_t desc); +#endif diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.mu b/src/ops/causal_softmax/musa/causal_softmax_musa.mu new file mode 100644 index 00000000..5eb5c8d9 --- /dev/null +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.mu @@ -0,0 +1,262 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "causal_softmax_musa.h" +#include + +struct AttentionCausualMask { + __forceinline__ __device__ bool + operator()(int tok_id, int seq_len, + int pos_id, int total_seq_len) { + // tok_id ↓ |<-total_seq_len->| + // 0 | * * * ... * | + // 1 | * * * ... * * | + // 2 | * * * ... * * * | + // seq_len: 3 pos_id-> + return total_seq_len + tok_id >= pos_id + seq_len; + } +}; + +template +static __device__ void block_padding( + Tdata *__restrict__ att, + Tmask mask, + unsigned int const token_idx, + unsigned int const seq_len) { + auto att_idx = threadIdx.x, total_seq_len = blockDim.x; + auto thread_data = mask(token_idx, seq_len, att_idx, total_seq_len) + ? float(att[att_idx]) + : -__FLT_MAX__; + + using BlockOp = cub::BlockReduce; + __shared__ typename BlockOp::TempStorage temp_storage; + auto block_op = BlockOp(temp_storage); + + __shared__ float max; + { + auto acc = block_op.Reduce(thread_data, cub::Max(), total_seq_len); + if (threadIdx.x == 0) { max = acc; } + } + __syncthreads(); + + __shared__ float mean; + { + auto acc = block_op.Sum(thread_data = expf(thread_data - max), total_seq_len); + if (threadIdx.x == 0) { mean = fdividef(1, acc); } + } + __syncthreads(); + + att[att_idx] = Tdata(thread_data * mean); +} + +template +static __device__ void block_folding( + Tdata *__restrict__ att, + Tmask mask, + unsigned int const token_idx, + unsigned int const seq_len, + unsigned int const total_seq_len) { + + auto local = (total_seq_len + blockDim.x - 1) / blockDim.x; + + auto thread_offset = threadIdx.x * local; + att += thread_offset; + + float thread_data[ITEMS_PER_THREAD], thread_max = -__FLT_MAX__, thread_sum = 0; + for (unsigned int i = 0; i < local; ++i) { + auto att_idx = thread_offset + i; + thread_data[i] = att_idx < total_seq_len && mask(token_idx, seq_len, att_idx, total_seq_len) + ? 
float(att[i]) + : -__FLT_MAX__; + thread_max = cub::Max()(thread_max, thread_data[i]); + } + + using BlockOp = cub::BlockReduce; + __shared__ typename BlockOp::TempStorage temp_storage; + auto block_op = BlockOp(temp_storage); + + __shared__ float max; + { + auto acc = block_op.Reduce(thread_max, cub::Max()); + if (threadIdx.x == 0) { max = acc; } + } + __syncthreads(); + + __shared__ float mean; + { + for (unsigned int i = 0; i < local; ++i) { + thread_data[i] = expf(thread_data[i] - max); + thread_sum += thread_data[i]; + } + auto acc = block_op.Sum(thread_sum); + if (threadIdx.x == 0) { mean = fdividef(1, acc); } + } + __syncthreads(); + + for (unsigned int i = 0; i < local; ++i) { + if (auto att_idx = thread_offset + i; att_idx < total_seq_len) { + att[i] = Tdata(thread_data[i] * mean); + } + } +} + +// assert BLOCK_SIZE >= blockDim.x +template +static __forceinline__ __device__ void padding( + Tdata *__restrict__ att, + Tmask mask, + int const stride_x, + int const stride_y, + int const stride_z) { + auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y, + token_idx = blockIdx.y, + seq_len = gridDim.y; + block_padding( + att + offset, mask, token_idx, seq_len); +} + +template +static __forceinline__ __device__ void folding( + Tdata *__restrict__ att, + Tmask mask, + unsigned int const total_seq_len, + int const stride_x, + int const stride_y, + int const stride_z) { + auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y, + token_idx = blockIdx.y, + seq_len = gridDim.y; + block_folding( + att + offset, mask, token_idx, seq_len, total_seq_len); +} + +template +__global__ void fused_softmax_padding( + Tdata *__restrict__ att, + unsigned int const stride_x, + unsigned int const stride_y, + unsigned int const stride_z) { + + padding(att, AttentionCausualMask(), stride_x, stride_y, stride_z); +} + +template +__global__ void fused_softmax_folding( + Tdata *__restrict__ att, + unsigned int const stride_x, + unsigned int const stride_y, + unsigned int const stride_z, + unsigned int const total_seq_len) { + { + folding(att, AttentionCausualMask(), total_seq_len, stride_x, stride_y, stride_z); + } +} + +template +__global__ void fused_softmax_standard( + Tdata *__restrict__ att_, + unsigned int const stride_x, + unsigned int const stride_y, + unsigned int const stride_z, + unsigned int const total_seq_len) { + { + auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y, + token_idx = blockIdx.y, + seq_len = gridDim.y; + + auto att = att_ + offset; + auto att_idx = threadIdx.x; + + float partial; + __shared__ float max_; + __shared__ float sum_; + using BlockOp = cub::BlockReduce; + __shared__ typename BlockOp::TempStorage temp_storage; + auto block_op = BlockOp(temp_storage); + + // Partial max + partial = -__FLT_MAX__; + for (unsigned int i = att_idx; i < total_seq_len; i += BLOCK_SIZE) { + if (i <= total_seq_len - seq_len + token_idx) { + partial = max(partial, float(att[i])); + } + } + __syncthreads(); + // Block reduce max + { + auto acc = block_op.Reduce(partial, cub::Max()); + if (threadIdx.x == 0) { max_ = acc; } + } + __syncthreads(); + + // Partial sum + partial = 0.; + for (unsigned int i = att_idx; i < total_seq_len; i += BLOCK_SIZE) { + if (i <= total_seq_len - seq_len + token_idx) { + float e = expf(float(att[i]) - max_); + partial += e; + } + } + __syncthreads(); + + // Block reduce sum + { + auto acc = block_op.Reduce(partial, cub::Sum()); + if (threadIdx.x == 0) { sum_ = acc; } + } + __syncthreads(); + + // Softmax + for (unsigned int i = att_idx; i < 
total_seq_len; i += BLOCK_SIZE) { + if (i <= total_seq_len - seq_len + token_idx) { + float e = expf(float(att[i]) - max_); + att[i] = e / sum_; + } else { + att[i] = half(0); + } + } + } +} + + +void causal_softmax_mt_gpu_f16(CausalSoftmaxMusaDescriptor_t desc, void* y, void *stream) { + uint64_t total_seq_len = desc->total_seq_len; + uint64_t seq_len = desc->seq_len; + uint64_t batch_size = desc->batch_size; + uint64_t stride_x = desc->stride_b; + uint64_t stride_y = desc->stride_i; + uint64_t stride_z = desc->stride_j;// covert byte strides to element strides + unsigned int max_items_per_thread = desc->max_items_per_thread; + + dim3 grid(batch_size, seq_len); + + if (max_items_per_thread == 1) { + fused_softmax_padding + <<>>((half *) (y), stride_x, stride_y, stride_z); + } else if (max_items_per_thread <= 16) { + fused_softmax_folding + <<>>((half *) (y), stride_x, stride_y, stride_z, total_seq_len); + } else { + fused_softmax_standard + <<>>((half *) (y), stride_x, stride_y, stride_z, total_seq_len); + } +} + +infiniopStatus_t musaCausalSoftmax(CausalSoftmaxMusaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *data, + void *stream) { + int current_device; + if (musaGetDevice(¤t_device) != musaSuccess) { + return STATUS_BAD_DEVICE; + } + if (current_device != desc->device_id && musaSetDevice(desc->device_id) != musaSuccess) { + return STATUS_BAD_DEVICE; + } + if (dtype_eq(desc->dtype, F16)) { + causal_softmax_mt_gpu_f16(desc, data, stream); + return STATUS_SUCCESS; + } + + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/causal_softmax/operator.cc b/src/ops/causal_softmax/operator.cc index c9d87dda..92498dca 100644 --- a/src/ops/causal_softmax/operator.cc +++ b/src/ops/causal_softmax/operator.cc @@ -21,6 +21,10 @@ #ifdef ENABLE_METAX_GPU #include "maca/causal_softmax_maca.h" #endif +#ifdef ENABLE_MTHREADS_GPU +#include "musa/causal_softmax_musa.h" +#include "../../devices/musa/common_musa.h" +#endif __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor( infiniopHandle_t handle, @@ -52,6 +56,11 @@ __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor( case DevMetaxGpu: { return macaCreateCausalSoftmaxDescriptor((MacaHandle_t) handle, (CausalSoftmaxMacaDescriptor_t *) desc_ptr, y_desc); } +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + return musaCreateCausalSoftmaxDescriptor((MusaHandle_t) handle, (CausalSoftmaxMusaDescriptor_t *) desc_ptr, y_desc); + } #endif } return STATUS_BAD_DEVICE; @@ -85,6 +94,11 @@ __C infiniopStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmax case DevMetaxGpu: { return macaGetCausalSoftmaxWorkspaceSize((CausalSoftmaxMacaDescriptor_t) desc, size); } +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + return musaGetCausalSoftmaxWorkspaceSize((CausalSoftmaxMusaDescriptor_t) desc, size); + } #endif } return STATUS_BAD_DEVICE; @@ -117,6 +131,11 @@ __C infiniopStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t des case DevMetaxGpu: { return macaCausalSoftmax((CausalSoftmaxMacaDescriptor_t) desc, workspace, workspace_size, data, stream); } +#endif +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { + return musaCausalSoftmax((CausalSoftmaxMusaDescriptor_t) desc, workspace, workspace_size, data, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -149,6 +168,10 @@ __C infiniopStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftma case DevMetaxGpu: { return macaDestroyCausalSoftmaxDescriptor((CausalSoftmaxMacaDescriptor_t) desc); } +#endif +#ifdef 
ENABLE_MTHREADS_GPU + case DevMthreadsGpu: + return musaDestroyCausalSoftmaxDescriptor((CausalSoftmaxMusaDescriptor_t) desc); #endif } return STATUS_BAD_DEVICE; diff --git a/src/ops/expand/musa/expand_musa.cc b/src/ops/expand/musa/expand_musa.cc new file mode 100644 index 00000000..0e2e4581 --- /dev/null +++ b/src/ops/expand/musa/expand_musa.cc @@ -0,0 +1,51 @@ +#include "expand_musa.h" +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" + +infiniopStatus_t musaCreateExpandDescriptor(MusaHandle_t handle, + ExpandMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x) { + uint64_t ndim = y->ndim; + if (!isValidBroadcastShape(y, x)) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + + uint64_t y_data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies()); + + // get the adjusted strides for x in terms of y + int64_t *x_strides = new int64_t[ndim]; + for (size_t i = 0; i < ndim; ++i) { + x_strides[i] = (i < ndim - x->ndim || y->shape[i] != x->shape[i + x->ndim - ndim]) ? 0 : x->strides[i + x->ndim - ndim]; + } + + int64_t *x_strides_d, *y_strides_d; + char *strides_and_shape_d; + checkMusaErrorWithCode(musaMalloc(&strides_and_shape_d, ndim * (2 * sizeof(int64_t) + sizeof(uint64_t))), STATUS_MEMORY_NOT_ALLOCATED); + checkMusaErrorWithCode(musaMemcpy(strides_and_shape_d, x_strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaMemcpy(strides_and_shape_d + ndim * sizeof(int64_t), y->strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaMemcpy(strides_and_shape_d + 2 * ndim * sizeof(int64_t), y->shape, ndim * sizeof(uint64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + + *desc_ptr = new ExpandMusaDescriptor{ + DevMthreadsGpu, + y->dt, + handle->device_id, + ndim, + y_data_size, + static_cast(handle->prop.maxGridSize[0]), + strides_and_shape_d, + }; + + delete[] x_strides; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyExpandDescriptor(ExpandMusaDescriptor_t desc) { + checkMusaErrorWithCode(musaFree((void *) desc->strides_and_shape_d), STATUS_EXECUTION_FAILED); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/expand/musa/expand_musa.h b/src/ops/expand/musa/expand_musa.h new file mode 100644 index 00000000..8e4651e1 --- /dev/null +++ b/src/ops/expand/musa/expand_musa.h @@ -0,0 +1,33 @@ +#ifndef __MUSA_EXPAND_H__ +#define __MUSA_EXPAND_H__ + +#include "../../../devices/musa/common_musa.h" +#include "../../../devices/musa/musa_handle.h" +#include "operators.h" +#include +#include + +struct ExpandMusaDescriptor { + Device device; + DT dtype; + int device_id; + uint64_t ndim; + uint64_t y_data_size; + uint64_t max_grid_size; + char const *strides_and_shape_d; +}; + +typedef struct ExpandMusaDescriptor *ExpandMusaDescriptor_t; + +infiniopStatus_t musaCreateExpandDescriptor(MusaHandle_t, + ExpandMusaDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +infiniopStatus_t musaExpand(ExpandMusaDescriptor_t desc, + void *y, void const *x, + void *stream); + +infiniopStatus_t musaDestroyExpandDescriptor(ExpandMusaDescriptor_t desc); + +#endif diff --git a/src/ops/expand/musa/expand_musa.mu b/src/ops/expand/musa/expand_musa.mu new file mode 100644 index 00000000..4b549541 --- /dev/null +++ b/src/ops/expand/musa/expand_musa.mu @@ -0,0 +1,58 @@ +#include "../../../devices/musa/common_musa.h" 
+#include "../../utils.h"
+#include "expand_musa.h"
+
+template <typename Tdata>
+__global__ void expand(
+    Tdata *y,
+    const Tdata *x,
+    const int64_t *y_strides,
+    const int64_t *x_strides,
+    const uint64_t *y_shape,
+    uint64_t y_data_size,
+    uint64_t ndim,
+    uint64_t offset) {
+    uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
+
+    if (idx < y_data_size) {
+        uint64_t y_idx = getOffset(idx, ndim, y_shape, y_strides);
+        y[y_idx] = x[getDstOffset(y_idx, ndim, y_strides, x_strides)];
+    }
+}
+
+template <typename Tdata>
+infiniopStatus_t expand_mt_gpu(ExpandMusaDescriptor_t desc, void *y, void const *x, void *stream) {
+    if (desc->y_data_size == 0) {
+        return STATUS_SUCCESS;
+    }
+    dim3 blockDims = dim3(std::min(static_cast<uint64_t>(256), desc->y_data_size));
+    dim3 gridDims = dim3(std::min(ROUND_UP_DIV(desc->y_data_size, blockDims.x), desc->max_grid_size));
+    uint64_t step = gridDims.x * blockDims.x;
+
+    const auto x_ = reinterpret_cast<const Tdata *>(x);
+    const auto y_ = reinterpret_cast<Tdata *>(y);
+    const auto x_strides = reinterpret_cast<const int64_t *>(desc->strides_and_shape_d);
+    const auto y_strides = reinterpret_cast<const int64_t *>(desc->strides_and_shape_d + desc->ndim * sizeof(int64_t));
+    const auto y_shape = reinterpret_cast<const uint64_t *>(desc->strides_and_shape_d + 2 * desc->ndim * sizeof(int64_t));
+    musaStream_t musa_stream = reinterpret_cast<musaStream_t>(stream);
+
+#pragma unroll
+    for (uint64_t i = 0; i < desc->y_data_size; i += step) {
+        expand<<<gridDims, blockDims, 0, musa_stream>>>(
+            y_, x_, y_strides, x_strides, y_shape, desc->y_data_size, desc->ndim, i);
+    }
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaExpand(ExpandMusaDescriptor_t desc,
+                            void *y, void const *x,
+                            void *stream) {
+    checkMusaError(musaSetDevice(desc->device_id));
+    if (desc->dtype == F16) {
+        return expand_mt_gpu<half>(desc, y, x, stream);
+    }
+    if (desc->dtype == F32) {
+        return expand_mt_gpu<float>(desc, y, x, stream);
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/expand/operator.cc b/src/ops/expand/operator.cc
index 0572acd0..b0374645 100644
--- a/src/ops/expand/operator.cc
+++ b/src/ops/expand/operator.cc
@@ -9,6 +9,10 @@
 #include "../../devices/cuda/cuda_handle.h"
 #include "cuda/expand.cuh"
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+#include "musa/expand_musa.h"
+#endif
+
 __C infiniopStatus_t infiniopCreateExpandDescriptor(
     infiniopHandle_t handle,
@@ -28,6 +32,11 @@ __C infiniopStatus_t infiniopCreateExpandDescriptor(
 #endif
 #ifdef ENABLE_CAMBRICON_MLU
         // TODO
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaCreateExpandDescriptor((MusaHandle_t) handle, (ExpandMusaDescriptor_t *) desc_ptr, y, x);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -47,6 +56,11 @@ __C infiniopStatus_t infiniopExpand(infiniopExpandDescriptor_t desc, void *y, vo
 #endif
 #ifdef ENABLE_CAMBRICON_MLU
         // TODO
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaExpand((ExpandMusaDescriptor_t) desc, y, x, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -66,6 +80,11 @@ __C infiniopStatus_t infiniopDestroyExpandDescriptor(infiniopExpandDescriptor_t
 #endif
 #ifdef ENABLE_CAMBRICON_MLU
         // TODO
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaDestroyExpandDescriptor((ExpandMusaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/matmul/musa/matmul_musa.cc b/src/ops/matmul/musa/matmul_musa.cc
new file mode 100644
index 00000000..3256dca6
--- /dev/null
+++ b/src/ops/matmul/musa/matmul_musa.cc
@@ -0,0 +1,48 @@
+#include "matmul_musa.h"
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+#include
+#include
+
+#include
+
+infiniopStatus_t
musaCreateMatmulDescriptor(MusaHandle_t handle,
+                                            MatmulMusaDescriptor_t *desc_ptr,
+                                            infiniopTensorDescriptor_t c_desc,
+                                            float alpha,
+                                            infiniopTensorDescriptor_t a_desc,
+                                            infiniopTensorDescriptor_t b_desc,
+                                            float beta) {
+    DT dtype = c_desc->dt;
+
+    if (dtype != F16 && dtype != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    infiniopStatus_t status = STATUS_EXECUTION_FAILED;
+    auto info = MatmulInfo(c_desc, a_desc, b_desc, &status);
+    if (status != STATUS_SUCCESS) {
+        return status;
+    }
+
+    *desc_ptr = new MatmulMusaDescriptor{
+        DevMthreadsGpu,
+        dtype,
+        handle->device_id,
+        info,
+        alpha,
+        beta,
+        handle->mublas_handles_t};
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaGetMatmulWorkspaceSize(MatmulMusaDescriptor_t desc, uint64_t *size) {
+    *size = 0;
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaDestroyMatmulDescriptor(MatmulMusaDescriptor_t desc) {
+    desc->mublas_handles_t = nullptr;
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/matmul/musa/matmul_musa.h b/src/ops/matmul/musa/matmul_musa.h
new file mode 100644
index 00000000..b086a494
--- /dev/null
+++ b/src/ops/matmul/musa/matmul_musa.h
@@ -0,0 +1,45 @@
+#ifndef __MUSA_MATMUL_H__
+#define __MUSA_MATMUL_H__
+
+#include
+#include
+#include
+#include
+#include
+#include "../blas.h"
+#include "operators.h"
+#include "../../../devices/musa/musa_handle.h"
+
+typedef struct MatmulMusaDescriptor {
+    Device device;
+    DT dtype;
+    int device_id;
+    MatmulInfo info;
+    float alpha;
+    float beta;
+    std::shared_ptr<Pool<mublasHandle_t>> mublas_handles_t;
+} MatmulMusaDescriptor;
+
+typedef struct MatmulMusaDescriptor *MatmulMusaDescriptor_t;
+
+infiniopStatus_t musaCreateMatmulDescriptor(MusaHandle_t handle,
+                                            MatmulMusaDescriptor_t *desc_ptr,
+                                            infiniopTensorDescriptor_t c_desc,
+                                            float alpha,
+                                            infiniopTensorDescriptor_t a_desc,
+                                            infiniopTensorDescriptor_t b_desc,
+                                            float beta);
+
+infiniopStatus_t musaGetMatmulWorkspaceSize(MatmulMusaDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t musaMatmul(MatmulMusaDescriptor_t desc,
+                            void *workspace,
+                            uint64_t workspace_size,
+                            void *c,
+                            void const *a,
+                            void const *b,
+                            void *stream);
+
+infiniopStatus_t musaDestroyMatmulDescriptor(MatmulMusaDescriptor_t desc);
+
+#endif // __MUSA_MATMUL_H__
diff --git a/src/ops/matmul/musa/matmul_musa.mu b/src/ops/matmul/musa/matmul_musa.mu
new file mode 100644
index 00000000..b445a7b3
--- /dev/null
+++ b/src/ops/matmul/musa/matmul_musa.mu
@@ -0,0 +1,77 @@
+#include "../../../devices/musa/musa_handle.h"
+#include "../../utils.h"
+#include "../blas.h"
+#include "matmul_musa.h"
+#include
+#include
+
+template <typename Tdata>
+infiniopStatus_t matmul_musa(MatmulMusaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream) {
+    auto info = desc->info;
+
+    if (info.is_transed) {
+        std::swap(a, b);
+    }
+
+    Tdata alpha_, beta_;
+    musaDataType_t a_type, b_type, c_type;
+    mublasComputeType_t compute_type;
+
+    if constexpr (std::is_same<Tdata, half>::value) {
+        alpha_ = __float2half(alpha);
+        beta_ = __float2half(beta);
+        a_type = b_type = c_type = MUSA_R_16F;
+        compute_type = MUBLAS_COMPUTE_16F;
+    } else {
+        alpha_ = alpha;
+        beta_ = beta;
+        a_type = b_type = c_type = MUSA_R_32F;
+        compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
+    }
+
+    auto op_a = info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
+    auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
+
+    use_mublas(desc->mublas_handles_t, desc->device_id, (MUstream) stream,
+               [&](mublasHandle_t handle) { mublasGemmStridedBatchedEx(
+                   handle,
+                   op_a,
+                   op_b,
+                   info.m,
+                   info.n,
+                   info.k,
+                   &alpha_,
+                   a,
+                   a_type,
+                   info.a_matrix.ld(),
+                   info.a_matrix.stride,
+                   b,
+                   b_type,
+                   info.b_matrix.ld(),
+                   info.b_matrix.stride,
+                   &beta_,
+                   c,
+                   c_type,
+                   info.c_matrix.ld(),
+                   info.c_matrix.stride,
+                   info.batch,
+                   compute_type,
+                   MUBLAS_GEMM_DEFAULT);});
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaMatmul(MatmulMusaDescriptor_t desc,
+                            void *workspace,
+                            uint64_t workspace_size,
+                            void *c,
+                            void const *a,
+                            void const *b,
+                            void *stream) {
+    if (desc->dtype == F16) {
+        return matmul_musa<half>(desc, c, desc->beta, a, b, desc->alpha, stream);
+    }
+    if (desc->dtype == F32) {
+        return matmul_musa<float>(desc, c, desc->beta, a, b, desc->alpha, stream);
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/matmul/operator.cc b/src/ops/matmul/operator.cc
index 14748b99..5fa766eb 100644
--- a/src/ops/matmul/operator.cc
+++ b/src/ops/matmul/operator.cc
@@ -17,6 +17,9 @@
 #ifdef ENABLE_METAX_GPU
 #include "maca/matmul_maca.h"
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+#include "musa/matmul_musa.h"
+#endif
 __C infiniopStatus_t infiniopCreateMatmulDescriptor(infiniopHandle_t handle,
                                                     infiniopMatmulDescriptor_t *desc_ptr,
@@ -56,6 +59,11 @@ __C infiniopStatus_t infiniopCreateMatmulDescriptor(infiniopHandle_t handle,
         case DevMetaxGpu: {
             return macaCreateMatmulDescriptor((MacaHandle_t) handle, (MatmulMacaDescriptor_t *) desc_ptr, c_desc, alpha, a_desc, b_desc, beta);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaCreateMatmulDescriptor((MusaHandle_t) handle, (MatmulMusaDescriptor_t *) desc_ptr, c_desc, alpha, a_desc, b_desc, beta);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -88,6 +96,11 @@ __C infiniopStatus_t infiniopGetMatmulWorkspaceSize(infiniopMatmulDescriptor_t d
         case DevMetaxGpu: {
             return macaGetMatmulWorkspaceSize((MatmulMacaDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaGetMatmulWorkspaceSize((MatmulMusaDescriptor_t) desc, size);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -122,6 +135,11 @@ __C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc, void *works
         case DevMetaxGpu: {
             return macaMatmul((MatmulMacaDescriptor_t) desc, workspace, workspace_size, c, a, b, stream);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaMatmul((MatmulMusaDescriptor_t) desc, workspace, workspace_size, c, a, b, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -153,6 +171,11 @@ __C infiniopStatus_t infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t
         case DevMetaxGpu: {
             return macaDestroyMatmulDescriptor((MatmulMacaDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaDestroyMatmulDescriptor((MatmulMusaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/random_sample/musa/random_sample_musa.cc b/src/ops/random_sample/musa/random_sample_musa.cc
new file mode 100644
index 00000000..70ff941c
--- /dev/null
+++ b/src/ops/random_sample/musa/random_sample_musa.cc
@@ -0,0 +1,37 @@
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+#include "random_sample_musa.h"
+
+infiniopStatus_t musaCreateRandomSampleDescriptor(MusaHandle_t handle,
+                                                  RandomSampleMusaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result,
+                                                  infiniopTensorDescriptor_t probs) {
+    if (probs->ndim != 1) {
+        return
STATUS_BAD_TENSOR_SHAPE; + } + if (!dtype_eq(result->dt, U64)) + return STATUS_BAD_TENSOR_DTYPE; + int voc = probs->shape[0]; + int rLength = result->shape[0]; + if (result->ndim != 1 && rLength != 1) { + return STATUS_BAD_TENSOR_SHAPE; + } + *desc_ptr = new RandomSampleMusaDescriptor{ + handle->device, + handle->device_id, + probs->dt, + voc, + result->dt, + rLength}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaGetRandomSampleWorkspaceSize(RandomSampleMusaDescriptor_t desc, uint64_t *size) { + *size = desc->voc * (2 * sizeof(uint64_t) + sizeof(desc->dtype)); + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyRandomSampleDescriptor(RandomSampleMusaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/random_sample/musa/random_sample_musa.h b/src/ops/random_sample/musa/random_sample_musa.h new file mode 100644 index 00000000..d8839ff1 --- /dev/null +++ b/src/ops/random_sample/musa/random_sample_musa.h @@ -0,0 +1,38 @@ +#ifndef __MUSA_RANDOM_SAMPLE_H__ +#define __MUSA_RANDOM_SAMPLE_H__ + +#include "../../../devices/musa/musa_handle.h" +#include "operators.h" + +struct RandomSampleMusaDescriptor { + Device device; + int device_id; + DT dtype; + int voc; + DT rDtype; + int rLength; +}; + +typedef struct RandomSampleMusaDescriptor *RandomSampleMusaDescriptor_t; + +infiniopStatus_t musaCreateRandomSampleDescriptor(MusaHandle_t handle, + RandomSampleMusaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, + infiniopTensorDescriptor_t probs); + +infiniopStatus_t musaGetRandomSampleWorkspaceSize(RandomSampleMusaDescriptor_t desc, uint64_t *size); + +infiniopStatus_t musaRandomSample(RandomSampleMusaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *result, + void const *probs, + float random_val, + float topp, + int topk, + float temperature, + void *stream); + +infiniopStatus_t musaDestroyRandomSampleDescriptor(RandomSampleMusaDescriptor_t desc); + + +#endif diff --git a/src/ops/random_sample/musa/random_sample_musa.mu b/src/ops/random_sample/musa/random_sample_musa.mu new file mode 100644 index 00000000..55dbdd0a --- /dev/null +++ b/src/ops/random_sample/musa/random_sample_musa.mu @@ -0,0 +1,184 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "random_sample_musa.h" +#include +#include + +template +__global__ void softmax( + T *val_out, + int topk, + float temperature, int voc) { + float sum_s = 0.0f; + for (int i = threadIdx.x; i < topk; i += BLOCK_DIM) { + sum_s += __expf(static_cast(val_out[i] - val_out[0]) / temperature); + } + __shared__ float sum_inverse_total; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float block_sum = BlockReduce(temp_storage).Reduce(sum_s, cub::Sum()); + if (threadIdx.x == 0) { + sum_inverse_total = __fdividef(1.0F, block_sum);//高精度除法 + } + + __syncthreads(); + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < topk) { + val_out[tid] = static_cast(__expf(static_cast(val_out[tid] - val_out[0]) / temperature) * sum_inverse_total); + } +} + +__global__ void index(uint64_t *key_in, int voc) { + int ind = threadIdx.x + blockIdx.x * blockDim.x; + if (ind < voc) { + key_in[ind] = static_cast(ind); + } +} +template +__global__ void random_sample_kernel(uint64_t *result, + T *val_out, + float random_val, + float topp, + int topk, + uint64_t *key_out) { + int end = 0; + for (end = 0; end < topk; end++) { + if (val_out[end] >= static_cast(topp)) { + break; + } + } + if (end < topk - 1) { + end += 
diff --git a/src/ops/random_sample/musa/random_sample_musa.mu b/src/ops/random_sample/musa/random_sample_musa.mu
new file mode 100644
index 00000000..55dbdd0a
--- /dev/null
+++ b/src/ops/random_sample/musa/random_sample_musa.mu
@@ -0,0 +1,184 @@
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+#include "random_sample_musa.h"
+#include <cub/cub.cuh>
+#include <musa_fp16.h>
+
+template<class T, int BLOCK_DIM>
+__global__ void softmax(
+    T *val_out,
+    int topk,
+    float temperature, int voc) {
+    float sum_s = 0.0f;
+    for (int i = threadIdx.x; i < topk; i += BLOCK_DIM) {
+        sum_s += __expf(static_cast<float>(val_out[i] - val_out[0]) / temperature);
+    }
+    __shared__ float sum_inverse_total;
+
+    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+    float block_sum = BlockReduce(temp_storage).Reduce(sum_s, cub::Sum());
+    if (threadIdx.x == 0) {
+        sum_inverse_total = __fdividef(1.0F, block_sum);// fast reciprocal of the block-wide sum
+    }
+
+    __syncthreads();
+    int tid = threadIdx.x + blockIdx.x * blockDim.x;
+    if (tid < topk) {
+        val_out[tid] = static_cast<T>(__expf(static_cast<float>(val_out[tid] - val_out[0]) / temperature) * sum_inverse_total);
+    }
+}
+
+__global__ void index(uint64_t *key_in, int voc) {
+    int ind = threadIdx.x + blockIdx.x * blockDim.x;
+    if (ind < voc) {
+        key_in[ind] = static_cast<uint64_t>(ind);
+    }
+}
+template<class T>
+__global__ void random_sample_kernel(uint64_t *result,
+                                     T *val_out,
+                                     float random_val,
+                                     float topp,
+                                     int topk,
+                                     uint64_t *key_out) {
+    int end = 0;
+    for (end = 0; end < topk; end++) {
+        if (val_out[end] >= static_cast<T>(topp)) {
+            break;
+        }
+    }
+    if (end < topk - 1) {
+        end += 1;
+    } else {
+        end = topk;
+    }
+
+    random_val *= static_cast<float>(val_out[end - 1]);
+    for (int i = 0; i < end; i++) {
+        if (random_val < static_cast<float>(val_out[i])) {
+            result[0] = key_out[i];
+            break;
+        }
+    }
+}
+template<class T, class I>
+void sort_pairs_descending(
+    void *workspace, size_t &size_radix_sort,
+    T const *val_in, T *val_out,
+    I *key_in, I *key_out,
+    int voc, musaStream_t stream) {
+    cub::DeviceRadixSort::SortPairsDescending(
+        workspace, size_radix_sort,
+        val_in, val_out,
+        key_in, key_out,
+        voc, 0, sizeof(T) * 8, stream);
+}
+template<class T>
+void inclusive_sum(
+    void *workspace, size_t &size_scan,
+    T *data, int voc,
+    musaStream_t stream) {
+    cub::DeviceScan::InclusiveSum(
+        workspace, size_scan,
+        data, data, voc,
+        stream);
+}
+template<class T, class I>
+void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
+                             int voc, musaStream_t stream) {
+
+    sort_pairs_descending<T, I>(nullptr, size_radix_sort,
+                                nullptr, nullptr,
+                                nullptr, nullptr,
+                                voc, stream);
+
+    inclusive_sum<T>(
+        nullptr, size_scan,
+        nullptr, voc,
+        stream);
+}
+__global__ void random_sample_kernel(uint64_t *result,
+                                     uint64_t *key_out) {
+    result[0] = key_out[0];
+}
+void random_sample_mt_gpu_f16(RandomSampleMusaDescriptor_t desc, void *workspace, void *result,
+                              void const *probs,
+                              float random_val,
+                              float topp,
+                              int topk,
+                              float temperature,
+                              void *stream) {
+    int voc = desc->voc;
+    // carve the workspace and sort the probabilities
+    char *origin = reinterpret_cast<char *>(workspace);
+    char *keyTmp = origin + voc * sizeof(half);
+    half *val_out = (half *) origin;
+
+    uint64_t *key_in = (uint64_t *) keyTmp;
+    uint64_t *key_out = key_in + voc;
+
+    index<<<(voc + 1023) / 1024, 1024, 0, (musaStream_t) stream>>>(key_in, voc);
+    // query the extra workspace the cub primitives need
+    size_t size_radix_sort;
+    size_t size_scan;
+    random_sample_workspace<half, uint64_t>(size_radix_sort, size_scan,
+                                            voc, (musaStream_t) stream);
+    void *workspace_extra;
+    musaMalloc(&workspace_extra, size_radix_sort + size_scan);
+    sort_pairs_descending(
+        workspace_extra, size_radix_sort,
+        (half *) probs, val_out,
+        key_in, key_out,
+        voc, (musaStream_t) stream);// sorted values land in val_out, their indices in key_out
+    // sorting done; apply the softmax transform
+    if (topp > 0 && topk > 1) {
+        int BLOCK_DIM = 1024;
+        int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
+        softmax<half, 1024><<<num_blocks, BLOCK_DIM, 0, (musaStream_t) stream>>>(val_out, topk,
+                                                                                 temperature, voc);
+
+        inclusive_sum(
+            workspace_extra, size_scan,
+            val_out, voc,
+            (musaStream_t) stream);// inclusive scan: running prefix sums over val_out
+        random_sample_kernel<<<1, 1, 0, (musaStream_t) stream>>>((uint64_t *) result,
+                                                                 val_out,
+                                                                 random_val,
+                                                                 topp,
+                                                                 topk,
+                                                                 key_out);
+
+    } else {
+        random_sample_kernel<<<1, 1, 0, (musaStream_t) stream>>>((uint64_t *) result,
+                                                                 key_out);
+    }
+    musaFree(workspace_extra);
+}
+
+infiniopStatus_t musaRandomSample(RandomSampleMusaDescriptor_t desc,
+                                  void *workspace,
+                                  uint64_t workspace_size,
+                                  void *result,
+                                  void const *probs,
+                                  float random_val,
+                                  float topp,
+                                  int topk,
+                                  float temperature,
+                                  void *stream) {
+    int current_device;
+    if (musaGetDevice(&current_device) != musaSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    if (current_device != desc->device_id && musaSetDevice(desc->device_id) != musaSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    if (dtype_eq(desc->dtype, F16)) {
+        random_sample_mt_gpu_f16(desc, workspace, result, probs, random_val, topp, topk, temperature, stream);
+        return STATUS_SUCCESS;
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
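
For reference, the device pipeline above (radix sort descending, softmax over the top-k prefix, inclusive scan, cut at topp, then a draw scaled by random_val) can be restated on the CPU. This is a sketch for checking intent, not part of the library; the final fallback return is likewise illustrative:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // CPU reference of the top-k/top-p sampling step (topk <= probs.size()).
    uint64_t random_sample_ref(const std::vector<float> &probs,
                               float random_val, float topp,
                               int topk, float temperature) {
        const int voc = (int) probs.size();
        std::vector<uint64_t> idx(voc);
        std::iota(idx.begin(), idx.end(), 0);
        // Sort indices by probability, descending (DeviceRadixSort on the GPU).
        std::stable_sort(idx.begin(), idx.end(),
                         [&](uint64_t a, uint64_t b) { return probs[a] > probs[b]; });

        // Softmax over the top-k prefix, then inclusive prefix sum (DeviceScan).
        float sum = 0.f;
        for (int i = 0; i < topk; ++i)
            sum += std::exp((probs[idx[i]] - probs[idx[0]]) / temperature);
        std::vector<float> val(topk);
        float cum = 0.f;
        for (int i = 0; i < topk; ++i) {
            cum += std::exp((probs[idx[i]] - probs[idx[0]]) / temperature) / sum;
            val[i] = cum;
        }

        // Cut the prefix at cumulative mass >= topp, then sample inside it.
        int end = 0;
        while (end < topk && val[end] < topp) ++end;
        if (end < topk - 1) ++end; else end = topk;
        float r = random_val * val[end - 1];
        for (int i = 0; i < end; ++i)
            if (r < val[i]) return idx[i];
        return idx[0];
    }
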
diff --git a/src/ops/random_sample/operator.cc b/src/ops/random_sample/operator.cc
index b9cf3ded..40a8ec03 100644
--- a/src/ops/random_sample/operator.cc
+++ b/src/ops/random_sample/operator.cc
@@ -17,6 +17,9 @@
 #ifdef ENABLE_METAX_GPU
 #include "maca/random_sample_maca.h"
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+#include "musa/random_sample_musa.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handle, infiniopRandomSampleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, infiniopTensorDescriptor_t probs) {
     switch (handle->device) {
@@ -47,6 +50,10 @@ __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandl
                                                    (RandomSampleMacaDescriptor_t *) desc_ptr, result, probs);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu:
+            return musaCreateRandomSampleDescriptor((MusaHandle_t) handle, (RandomSampleMusaDescriptor_t *) desc_ptr, result, probs);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -79,6 +86,11 @@ __C infiniopStatus_t infiniopGetRandomSampleWorkspaceSize(infiniopRandomSampleDe
         case DevMetaxGpu: {
             return macaGetRandomSampleWorkspaceSize((RandomSampleMacaDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaGetRandomSampleWorkspaceSize((RandomSampleMusaDescriptor_t) desc, size);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -117,6 +129,10 @@ __C infiniopStatus_t infiniopRandomSample(infiniopRandomSampleDescriptor_t desc,
         case DevMetaxGpu: {
             return macaRandomSample((RandomSampleMacaDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu:
+            return musaRandomSample((RandomSampleMusaDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -146,6 +162,10 @@ __C infiniopStatus_t infiniopDestroyRandomSampleDescriptor(infiniopRandomSampleD
         case DevMetaxGpu: {
             return macaDestroyRandomSampleDescriptor((RandomSampleMacaDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu:
+            return musaDestroyRandomSampleDescriptor((RandomSampleMusaDescriptor_t) desc);
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/rearrange/musa/rearrange_musa.cc b/src/ops/rearrange/musa/rearrange_musa.cc
new file mode 100644
index 00000000..5fa2e768
--- /dev/null
+++ b/src/ops/rearrange/musa/rearrange_musa.cc
@@ -0,0 +1,70 @@
+#include "rearrange_musa.h"
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+#include <cstdint>
+
+infiniopStatus_t musaCreateRearrangeDescriptor(MusaHandle_t handle,
+                                               RearrangeMusaDescriptor_t *desc_ptr,
+                                               infiniopTensorDescriptor_t dst,
+                                               infiniopTensorDescriptor_t src) {
+    auto dt = dst->dt;
+    if (!dtype_eq(src->dt, dt)) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    auto ndim = dst->ndim;
+    if (src->ndim != ndim || ndim == 0) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    for (uint64_t i = 0; i < ndim; ++i) {
+        if (dst->shape[i] != src->shape[i]) {
+            return STATUS_BAD_TENSOR_SHAPE;
+        }
+    }
+    if (dst->strides[ndim - 1] != 1 || src->strides[ndim - 1] != 1) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    switch (ndim) {
+        case 1:
+            *desc_ptr = new RearrangeMusaDescriptor{
+                handle->device,
+                handle->device_id,
+                dt.size * dst->shape[0],
+                1, 1,
+                0, 0,
+                0, 0};
+            break;
+        case 2:
+            *desc_ptr = new RearrangeMusaDescriptor{
+                handle->device,
+                handle->device_id,
+                dt.size * dst->shape[1],
+                1, dst->shape[0],
+                0, dst->strides[0],
+                0, src->strides[0]};
+            break;
+        case 3:
+            *desc_ptr = new RearrangeMusaDescriptor{
+                handle->device,
+                handle->device_id,
+                dt.size * dst->shape[2],
+                dst->shape[0], dst->shape[1],
+                dst->strides[0], dst->strides[1],
+                src->strides[0], src->strides[1]};
+            break;
+        default:
+            return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    (*desc_ptr)->dst_rs *= dt.size;
+    (*desc_ptr)->dst_cs *= dt.size;
+    (*desc_ptr)->src_rs *= dt.size;
+    (*desc_ptr)->src_cs *= dt.size;
+
+    return STATUS_SUCCESS;
+}
+infiniopStatus_t musaDestroyRearrangeDescriptor(RearrangeMusaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/rearrange/musa/rearrange_musa.h b/src/ops/rearrange/musa/rearrange_musa.h
new file mode 100644
index 00000000..df6ade12
--- /dev/null
+++ b/src/ops/rearrange/musa/rearrange_musa.h
@@ -0,0 +1,30 @@
+#ifndef __MUSA_REARRANGE_H__
+#define __MUSA_REARRANGE_H__
+
+#include "operators.h"
+#include "../../../devices/musa/musa_handle.h"
+
+struct RearrangeMusaDescriptor {
+    Device device;
+    int device_id;
+    uint64_t unit, r, c;
+    int64_t dst_rs, dst_cs, src_rs, src_cs;
+};
+
+typedef struct RearrangeMusaDescriptor *RearrangeMusaDescriptor_t;
+
+infiniopStatus_t musaCreateRearrangeDescriptor(MusaHandle_t handle,
+                                               RearrangeMusaDescriptor_t *desc_ptr,
+                                               infiniopTensorDescriptor_t dst,
+                                               infiniopTensorDescriptor_t src);
+
+infiniopStatus_t musaRearrange(RearrangeMusaDescriptor_t desc,
+                               void *dst,
+                               void const *src,
+                               void *stream);
+
+infiniopStatus_t musaDestroyRearrangeDescriptor(RearrangeMusaDescriptor_t desc);
+
+void rearrange_mt_gpu(RearrangeMusaDescriptor *, void *y, void const *x, void *stream);
+#endif // __MUSA_REARRANGE_H__
+
diff --git a/src/ops/rearrange/musa/rearrange_musa.mu b/src/ops/rearrange/musa/rearrange_musa.mu
new file mode 100644
index 00000000..887923b3
--- /dev/null
+++ b/src/ops/rearrange/musa/rearrange_musa.mu
@@ -0,0 +1,81 @@
+#include "../../../devices/musa/common_musa.h"
+#include "rearrange_musa.h"
+
+template<class Tmem>
+static __global__ void rearrange(
+    void *__restrict__ dst,
+    int const rsa,
+    int const csa,
+    void const *__restrict__ src,
+    int const rsb,
+    int const csb,
+    unsigned int const ncols) {
+
+    auto row = blockIdx.y,
+         col = blockIdx.x * blockDim.y + threadIdx.y;
+    if (col >= ncols) return;
+
+    auto thread = threadIdx.x,
+         warp_size = blockDim.x;
+    auto i = (row * rsa + col * csa) * warp_size + thread;
+    auto j = (row * rsb + col * csb) * warp_size + thread;
+
+    reinterpret_cast<Tmem *>(dst)[i] = reinterpret_cast<Tmem const *>(src)[j];
+}
+
+
+void rearrange_mt_gpu(RearrangeMusaDescriptor_t desc, void *y, void const *x, void *stream) {
+    auto musa_stream = reinterpret_cast<musaStream_t>(stream);
+    auto unit = desc->unit,
+         r = desc->r, c = desc->c;
+    auto dst_rs = desc->dst_rs, dst_cs = desc->dst_cs,
+         src_rs = desc->src_rs, src_cs = desc->src_cs;
+
+    if (r == 1 && c == 1) {
+        musaMemcpyAsync(y, x, unit, musaMemcpyDeviceToDevice, musa_stream);
+        return;
+    }
+
+    auto warps = 1024 / WARP_SIZE;
+    auto grid = dim3((c + warps - 1) / warps, r);
+    auto block = dim3(WARP_SIZE, (c + grid.x - 1) / grid.x);
+    dst_rs /= unit;
+    dst_cs /= unit;
+    src_rs /= unit;
+    src_cs /= unit;
+
+    switch (unit / WARP_SIZE) {
+        case 1:
+            rearrange<uchar1><<<grid, block, 0, musa_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 2:
+            rearrange<uchar2><<<grid, block, 0, musa_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 4:
+            rearrange<float1><<<grid, block, 0, musa_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 8:
+            rearrange<float2><<<grid, block, 0, musa_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 16:
+            rearrange<float4><<<grid, block, 0, musa_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 32:
+            rearrange<double4><<<grid, block, 0, musa_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        default:
+            break;
+    }
+}
+infiniopStatus_t musaRearrange(RearrangeMusaDescriptor_t desc,
+                               void *dst, void const *src, void *stream) {
+    int current_device;
+    if (musaGetDevice(&current_device) != musaSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    if (current_device != desc->device_id && musaSetDevice(desc->device_id) != musaSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    rearrange_mt_gpu(desc, dst, src, stream);
+    return STATUS_SUCCESS;
+}
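
To make the descriptor's unit/r/c encoding concrete before the dispatch code that consumes it: the copy is r*c blocks of unit contiguous bytes, addressed by byte strides, which a plain CPU loop can mirror (illustrative names only):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // CPU model of RearrangeMusaDescriptor{unit, r, c, dst_rs, dst_cs, src_rs, src_cs}:
    // r*c blocks of 'unit' contiguous bytes, addressed by byte strides.
    void rearrange_ref(uint8_t *dst, const uint8_t *src,
                       uint64_t unit, uint64_t r, uint64_t c,
                       int64_t dst_rs, int64_t dst_cs,
                       int64_t src_rs, int64_t src_cs) {
        for (uint64_t i = 0; i < r; ++i)
            for (uint64_t j = 0; j < c; ++j)
                std::memcpy(dst + i * dst_rs + j * dst_cs,
                            src + i * src_rs + j * src_cs, unit);
    }

    int main() {
        // ndim == 2 case from musaCreateRearrangeDescriptor: shape [3, 4] of
        // 2-byte elements, source row stride 8 elements -> unit = 4*2 bytes,
        // r = 1, c = 3, dst_cs = 4*2, src_cs = 8*2 (strides are scaled by
        // dt.size after the switch).
        std::vector<uint8_t> src(3 * 8 * 2, 0), dst(3 * 4 * 2, 0);
        rearrange_ref(dst.data(), src.data(), /*unit=*/8, /*r=*/1, /*c=*/3,
                      0, 8, 0, 16);
    }
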
diff --git a/src/ops/rearrange/operator.cc b/src/ops/rearrange/operator.cc
index 752211e5..4a922dc7 100644
--- a/src/ops/rearrange/operator.cc
+++ b/src/ops/rearrange/operator.cc
@@ -20,6 +20,9 @@
 #ifdef ENABLE_METAX_GPU
 #include "maca/rearrange_maca.h"
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+#include "musa/rearrange_musa.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateRearrangeDescriptor(
     infiniopHandle_t handle,
@@ -54,6 +57,11 @@ __C infiniopStatus_t infiniopCreateRearrangeDescriptor(
         case DevMetaxGpu: {
             return macaCreateRearrangeDescriptor((MacaHandle_t) handle, (RearrangeMacaDescriptor_t *) desc_ptr, dst, src);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaCreateRearrangeDescriptor((MusaHandle_t) handle, (RearrangeMusaDescriptor_t *) desc_ptr, dst, src);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -88,6 +96,11 @@ __C infiniopStatus_t infiniopRearrange(infiniopRearrangeDescriptor_t desc, void
         case DevMetaxGpu: {
             return macaRearrange((RearrangeMacaDescriptor_t) desc, dst, src, stream);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaRearrange((RearrangeMusaDescriptor_t) desc, dst, src, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -119,6 +132,11 @@ __C infiniopStatus_t infiniopDestroyRearrangeDescriptor(infiniopRearrangeDescrip
         case DevMetaxGpu: {
             return macaDestroyRearrangeDescriptor((RearrangeMacaDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaDestroyRearrangeDescriptor((RearrangeMusaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/relu/musa/relu_musa.cc b/src/ops/relu/musa/relu_musa.cc
new file mode 100644
index 00000000..6baaef18
--- /dev/null
+++ b/src/ops/relu/musa/relu_musa.cc
@@ -0,0 +1,45 @@
+#include "relu_musa.h"
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+
+infiniopStatus_t musaCreateReluDescriptor(MusaHandle_t handle,
+                                          ReluMusaDescriptor_t *desc_ptr,
+                                          infiniopTensorDescriptor_t y,
+                                          infiniopTensorDescriptor_t x) {
+    uint64_t ndim = y->ndim;
+    if (ndim != x->ndim) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    for (size_t i = 0; i < ndim; ++i) {
+        if (y->shape[i] != x->shape[i]) {
+            return STATUS_BAD_TENSOR_SHAPE;
+        }
+    }
+    if (!is_contiguous(y) || !is_contiguous(x)) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+    if (y->dt != F16 && y->dt != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (y->dt != x->dt) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    uint64_t data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies<uint64_t>());
+
+    *desc_ptr = new ReluMusaDescriptor{
+        DevMthreadsGpu,
+        y->dt,
+        handle->device_id,
+        ndim,
+        data_size,
+        static_cast<uint64_t>(handle->prop.maxGridSize[0]),
+    };
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaDestroyReluDescriptor(ReluMusaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/relu/musa/relu_musa.h b/src/ops/relu/musa/relu_musa.h
new file mode 100644
index 00000000..84276369
--- /dev/null
+++ b/src/ops/relu/musa/relu_musa.h
@@ -0,0 +1,32 @@
+#ifndef __MUSA_RELU_H__
+#define __MUSA_RELU_H__
+
+#include "../../../devices/musa/common_musa.h"
+#include "../../../devices/musa/musa_handle.h"
+#include "operators.h"
+#include <musa_fp16.h>
+#include <numeric>
+
+struct ReluMusaDescriptor {
+    Device device;
+    DT dtype;
+    int device_id;
+    uint64_t ndim;
+    uint64_t data_size;
+    uint64_t max_grid_size;
+};
+
+typedef struct ReluMusaDescriptor *ReluMusaDescriptor_t;
+
+infiniopStatus_t musaCreateReluDescriptor(MusaHandle_t,
+                                          ReluMusaDescriptor_t *,
+                                          infiniopTensorDescriptor_t y,
+                                          infiniopTensorDescriptor_t x);
+
+infiniopStatus_t musaRelu(ReluMusaDescriptor_t desc,
+                          void *y, void const *x,
+                          void *stream);
+
+infiniopStatus_t musaDestroyReluDescriptor(ReluMusaDescriptor_t desc);
+
+#endif
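
The kernel file that follows processes the bulk of the tensor through a packed vector type and the tail as scalars. The split arithmetic in isolation (pack_size = 4 matches the vecN instantiations used there):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Mirrors relu_mt_gpu<Tdata, Tscalar>(..., pack_size) in relu_musa.mu:
        uint64_t data_size = 1030;   // total elements
        uint64_t pack_size = 4;      // elements per vecN access

        uint64_t vec_count = data_size / pack_size;  // vector-kernel launch size
        uint64_t remainder = data_size % pack_size;  // scalar tail length
        uint64_t tail_off  = vec_count * pack_size;  // offset of the scalar pass

        printf("%llu vectors + %llu scalars starting at %llu\n",
               (unsigned long long) vec_count,
               (unsigned long long) remainder,
               (unsigned long long) tail_off);  // 257 vectors + 2 scalars at 1028
    }
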
diff --git a/src/ops/relu/musa/relu_musa.mu b/src/ops/relu/musa/relu_musa.mu
new file mode 100644
index 00000000..3d91b4e2
--- /dev/null
+++ b/src/ops/relu/musa/relu_musa.mu
@@ -0,0 +1,111 @@
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+#include "relu_musa.h"
+
+/**
+ * @brief A templated vector struct that supports applying relu on arrays.
+ *
+ * @tparam T - The access data type for elements in the vector.
+ * @tparam TComp - The computation data type used for arithmetic operations. sizeof(T) should
+ * be >= sizeof(TComp)
+ * @tparam N - The number of elements of type T in the vector for a single access.
+ */
+template<class T, class TComp, size_t N>
+struct vecN {
+    T data[N];
+    constexpr static size_t pack_size = sizeof(T) / sizeof(TComp);
+
+    // Constructor that initializes the data array with type TComp
+    __device__ __forceinline__ constexpr vecN(const TComp &val) {
+        const auto data_ = reinterpret_cast<TComp *>(data);
+        const auto size = N * pack_size;
+#pragma unroll
+        for (size_t i = 0; i < size; ++i) {
+            data_[i] = 0;
+        }
+    }
+
+    // Assignment operator with relu assignment logic
+    __device__ __forceinline__ vecN &operator=(const vecN &other) {
+        if constexpr (std::is_same<T, TComp>::value) {
+#pragma unroll
+            for (int i = 0; i < N; ++i) {
+                data[i] = other.data[i] < TComp(0) ? TComp(0) : other.data[i];
+            }
+        } else {
+            auto *data_this = reinterpret_cast<vecN<TComp, TComp, pack_size> *>(data);
+            auto *data_other = reinterpret_cast<vecN<TComp, TComp, pack_size> const *>(other.data);
+#pragma unroll
+            for (int i = 0; i < N; ++i) {
+                data_this[i] = data_other[i];
+            }
+        }
+        return *this;
+    }
+
+    // Always returns false since the actual relu logic is in the assignment process
+    __device__ __forceinline__ bool operator<(const vecN &other) const {
+        return false;
+    }
+
+    __device__ __forceinline__ const T &operator[](size_t i) const {
+        return data[i];
+    }
+};
+
+template<class Tdata>
+__global__ void relu(
+    Tdata *y,
+    const Tdata *x,
+    uint64_t data_size,
+    uint64_t offset) {
+    uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
+
+    if (idx < data_size) {
+        y[idx] = x[idx] < Tdata(0) ? Tdata(0) : x[idx];
+    }
+}
+
+template<class Tdata>
+void relu_mt_gpu(ReluMusaDescriptor_t desc, Tdata *y, Tdata const *x, uint64_t data_size, uint64_t offset, void *stream) {
+    if (data_size == 0) {
+        return;
+    }
+    dim3 blockDims = dim3(std::min(static_cast<uint64_t>(256), data_size));
+    dim3 gridDims = dim3(std::min(ROUND_UP_DIV(data_size, blockDims.x), desc->max_grid_size));
+    uint64_t step = gridDims.x * blockDims.x;
+
+    musaStream_t musa_stream = reinterpret_cast<musaStream_t>(stream);
+
+#pragma unroll
+    for (uint64_t i = 0; i < data_size; i += step) {
+        relu<<<gridDims, blockDims, 0, musa_stream>>>(y, x, offset + data_size, offset + i);
+    }
+}
+
+template<class Tdata, class Tscalar>
+infiniopStatus_t relu_mt_gpu(ReluMusaDescriptor_t desc, void *y, void const *x, void *stream, uint64_t pack_size) {
+    const auto data_size = desc->data_size / pack_size;
+    const auto x_vec = reinterpret_cast<const Tdata *>(x);
+    const auto y_vec = reinterpret_cast<Tdata *>(y);
+    relu_mt_gpu(desc, y_vec, x_vec, data_size, 0, stream);
+
+    const auto remainder = desc->data_size % pack_size;
+    const auto x_ = reinterpret_cast<const Tscalar *>(x);
+    const auto y_ = reinterpret_cast<Tscalar *>(y);
+    relu_mt_gpu(desc, y_, x_, remainder, data_size * pack_size, stream);
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaRelu(ReluMusaDescriptor_t desc,
+                          void *y, void const *x,
+                          void *stream) {
+    checkMusaError(musaSetDevice(desc->device_id));
+    if (desc->dtype == F16) {
+        return relu_mt_gpu<vecN<half2, half, 2>, half>(desc, y, x, stream, 4);
+    }
+    if (desc->dtype == F32) {
+        return relu_mt_gpu<vecN<float2, float, 2>, float>(desc, y, x, stream, 4);
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/relu/operator.cc b/src/ops/relu/operator.cc
index 89122915..7a3a2e2f 100644
--- a/src/ops/relu/operator.cc
+++ b/src/ops/relu/operator.cc
@@ -9,6 +9,10 @@
 #include "../../devices/cuda/cuda_handle.h"
 #include "cuda/relu.cuh"
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+#include "musa/relu_musa.h"
+#endif
+
 __C infiniopStatus_t infiniopCreateReluDescriptor(
     infiniopHandle_t handle,
@@ -28,6 +32,11 @@ __C infiniopStatus_t infiniopCreateReluDescriptor(
 #endif
 #ifdef ENABLE_CAMBRICON_MLU
     // TODO
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaCreateReluDescriptor((MusaHandle_t) handle, (ReluMusaDescriptor_t *) desc_ptr, y, x);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -47,6 +56,11 @@ __C infiniopStatus_t infiniopRelu(infiniopReluDescriptor_t desc, void *y, void c
 #endif
 #ifdef ENABLE_CAMBRICON_MLU
     // TODO
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaRelu((ReluMusaDescriptor_t) desc, y, x, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -66,6 +80,11 @@ __C infiniopStatus_t infiniopDestroyReluDescriptor(infiniopReluDescriptor_t desc
 #endif
 #ifdef ENABLE_CAMBRICON_MLU
     // TODO
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaDestroyReluDescriptor((ReluMusaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
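
Next come the RMSNorm files. As a correctness reference for the three kernels they add: each of the n rows of x is scaled by the reciprocal RMS of that row and then by the weight vector w. A CPU sketch with float throughout:

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // CPU reference for the MUSA RMSNorm kernels (one row per block on the GPU).
    void rms_norm_ref(std::vector<float> &y, const std::vector<float> &x,
                      const std::vector<float> &w,
                      uint64_t n, uint64_t d,
                      uint64_t stride_y, uint64_t stride_x, float epsilon) {
        for (uint64_t i = 0; i < n; ++i) {
            const float *row_x = x.data() + i * stride_x;
            float *row_y = y.data() + i * stride_y;
            float ss = 0.f;
            for (uint64_t j = 0; j < d; ++j) ss += row_x[j] * row_x[j];
            float rms = 1.f / std::sqrt(ss / (float) d + epsilon);
            for (uint64_t j = 0; j < d; ++j) row_y[j] = rms * row_x[j] * w[j];
        }
    }
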
diff --git a/src/ops/rms_norm/musa/rms_norm_musa.cc b/src/ops/rms_norm/musa/rms_norm_musa.cc
new file mode 100644
index 00000000..99c22c6e
--- /dev/null
+++ b/src/ops/rms_norm/musa/rms_norm_musa.cc
@@ -0,0 +1,46 @@
+#include "rms_norm_musa.h"
+#include "../../utils.h"
+#include "../../../devices/musa/common_musa.h"
+
+infiniopStatus_t musaCreateRMSNormDescriptor(MusaHandle_t handle, RMSNormMusaDescriptor_t *desc_ptr,
+                                             infiniopTensorDescriptor_t y_desc,
+                                             infiniopTensorDescriptor_t x_desc,
+                                             infiniopTensorDescriptor_t w_desc,
+                                             float epsilon) {
+    if (y_desc->ndim != 2 || x_desc->ndim != 2 || w_desc->ndim != 1) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    auto n = y_desc->shape[0],
+         d = y_desc->shape[1];
+
+    if (x_desc->shape[0] != n || x_desc->shape[1] != d || w_desc->shape[0] != d) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    uint64_t stride_y = y_desc->strides[0];
+    uint64_t stride_x = x_desc->strides[0];
+    auto w_datatype = w_desc->dt;
+    *desc_ptr = new RMSNormMusaDescriptor{
+        handle->device,
+        handle->device_id,
+        y_desc->dt,
+        n,
+        d,
+        stride_y,
+        stride_x,
+        w_datatype,
+        epsilon};
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaGetRMSNormWorkspaceSize(RMSNormMusaDescriptor_t desc, uint64_t *size) {
+    *size = 0;
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaDestroyRMSNormDescriptor(RMSNormMusaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/rms_norm/musa/rms_norm_musa.h b/src/ops/rms_norm/musa/rms_norm_musa.h
new file mode 100644
index 00000000..ee8dfb72
--- /dev/null
+++ b/src/ops/rms_norm/musa/rms_norm_musa.h
@@ -0,0 +1,40 @@
+#ifndef __MUSA_RMS_NORM_H__
+#define __MUSA_RMS_NORM_H__
+
+#include "operators.h"
+#include "../../../devices/musa/musa_handle.h"
+
+struct RMSNormMusaDescriptor {
+    Device device;
+    int device_id;
+    DT dtype;
+    uint64_t n;
+    uint64_t d;
+    uint64_t stride_y;
+    uint64_t stride_x;
+    DT w_datatype;
+    float epsilon;
+};
+
+typedef struct RMSNormMusaDescriptor *RMSNormMusaDescriptor_t;
+
+infiniopStatus_t musaCreateRMSNormDescriptor(MusaHandle_t handle,
+                                             RMSNormMusaDescriptor_t *desc_ptr,
+                                             infiniopTensorDescriptor_t y_desc,
+                                             infiniopTensorDescriptor_t x_desc,
+                                             infiniopTensorDescriptor_t w_desc,
+                                             float epsilon);
+
+infiniopStatus_t musaGetRMSNormWorkspaceSize(RMSNormMusaDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t musaRMSNorm(RMSNormMusaDescriptor_t desc,
+                             void *workspace,
+                             uint64_t workspace_size,
+                             void *y, void const *x, void const *w,
+                             void *stream);
+
+infiniopStatus_t musaDestroyRMSNormDescriptor(RMSNormMusaDescriptor_t desc);
+
+void rms_norm_mt_gpu_f16(RMSNormMusaDescriptor_t desc, void *y, void const *x, void const *w, void *stream);
+
+#endif// __MUSA_RMS_NORM_H__
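
One note before the kernel file: its launcher selects among three kernels by items_per_thread = ROUND_UP_DIV(d, MAX_THREADS_PER_BLOCK). A worked instance, assuming MAX_THREADS_PER_BLOCK is 1024 (the launch sizes below suggest this, but the macro lives in common_musa.h and is not shown in the diff):

    #include <cstdio>

    // Dispatch rule of rms_norm_mt_gpu_f16 (assuming MAX_THREADS_PER_BLOCK == 1024):
    //   1 item/thread       -> rms_norm_padding  (one element per thread)
    //   2..16 items/thread  -> rms_norm_folding  (cub::BlockLoad + array reduce)
    //   more                -> rms_norm_standard (strided loop + shared-mem tree)
    int main() {
        const unsigned max_threads = 1024;
        for (unsigned d : {768u, 4096u, 32000u}) {
            unsigned ipt = (d + max_threads - 1) / max_threads; // ROUND_UP_DIV
            const char *k = ipt == 1 ? "padding" : ipt <= 16 ? "folding" : "standard";
            printf("d=%5u -> items_per_thread=%2u -> %s\n", d, ipt, k);
        }
    }
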
diff --git a/src/ops/rms_norm/musa/rms_norm_musa.mu b/src/ops/rms_norm/musa/rms_norm_musa.mu
new file mode 100644
index 00000000..d80bdac9
--- /dev/null
+++ b/src/ops/rms_norm/musa/rms_norm_musa.mu
@@ -0,0 +1,177 @@
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+#include "rms_norm_musa.h"
+#include <cub/cub.cuh>
+#include <musa_fp16.h>
+
+// assert BLOCK_SIZE >= blockDim.x
+template<unsigned int BLOCK_SIZE, class Tdata, class Wdata>
+static __global__ void rms_norm_padding(
+    Tdata *__restrict__ o_,
+    unsigned int const stride_y,
+    Tdata const *__restrict__ x_,
+    unsigned int const stride_x,
+    Wdata const *__restrict__ w_,
+    float const epsilon) {
+    auto y = o_ + blockIdx.x * stride_y + threadIdx.x;
+    auto x = x_[blockIdx.x * stride_x + threadIdx.x];
+    auto w = w_[threadIdx.x];
+
+    using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+    __shared__ typename BlockOp::TempStorage temp_storage;
+    auto acc = BlockOp(temp_storage).Reduce(x * x, cub::Sum());
+
+    __shared__ Tdata rms;
+    if (threadIdx.x == 0) {
+        rms = Tdata(rsqrtf(acc / float(blockDim.x) + epsilon));
+    }
+    __syncthreads();
+
+    *y = rms * x * (Tdata) w;
+}
+
+template<unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD, class Tdata, class Wdata>
+static __global__ void rms_norm_folding(
+    Tdata *__restrict__ y,
+    unsigned int const stride_y,
+    Tdata const *__restrict__ x,
+    unsigned int const stride_x,
+    Wdata const *__restrict__ w,
+    float const epsilon,
+    unsigned int const items_size) {
+    y += blockIdx.x * stride_y;
+    x += blockIdx.x * stride_x;
+
+    float thread_data[ITEMS_PER_THREAD];
+    {
+        using BlockOp = cub::BlockLoad<float, BLOCK_SIZE, ITEMS_PER_THREAD>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        BlockOp(temp_storage).Load(x, thread_data, items_size, 0.f);
+    }
+
+    float squared[ITEMS_PER_THREAD];
+#pragma unroll
+    for (unsigned int i = 0; i < ITEMS_PER_THREAD; ++i) {
+        squared[i] = thread_data[i] * thread_data[i];
+    }
+
+    float acc;
+    {
+        using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        acc = BlockOp(temp_storage).Reduce(squared, cub::Sum());
+    }
+
+    __shared__ Tdata rms;
+    if (threadIdx.x == 0) {
+        rms = Tdata(rsqrtf(acc / float(items_size) + epsilon));
+    }
+    __syncthreads();
+
+#pragma unroll
+    for (unsigned int i = 0; i < ITEMS_PER_THREAD; ++i) {
+        if (auto j = i + threadIdx.x * ITEMS_PER_THREAD; j < items_size) {
+            y[j] = Tdata(float(rms) * float(thread_data[i]) * float(w[j]));
+        }
+    }
+}
+
+template<unsigned int BLOCK_SIZE, class Tdata, class Wdata>
+static __global__ void rms_norm_standard(
+    Tdata *__restrict__ y_,
+    unsigned int const stride_y,
+    Tdata const *__restrict__ x_,
+    unsigned int const stride_x,
+    Wdata const *__restrict__ w,
+    float const epsilon,
+    unsigned int const d) {
+    auto y = y_ + blockIdx.x * stride_y;
+    auto x = x_ + blockIdx.x * stride_x;
+
+    __shared__ float partial_sum[BLOCK_SIZE];
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < d; i += BLOCK_SIZE) {
+        sum += float(x[i]) * float(x[i]);
+    }
+
+    partial_sum[threadIdx.x] = sum;
+    __syncthreads();
+    for (int stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) {
+        if (threadIdx.x < stride) {
+            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + stride];
+        }
+        __syncthreads();
+    }
+
+    __shared__ Tdata rms;
+    if (threadIdx.x == 0) {
+        float row_sum = partial_sum[0];
+        rms = Tdata(rsqrtf(row_sum / float(d) + epsilon));
+    }
+    __syncthreads();
+
+    for (int i = threadIdx.x; i < d; i += BLOCK_SIZE) {
+        y[i] = rms * x[i] * (Tdata) w[i];
+    }
+}
+
+void rms_norm_mt_gpu_f16(RMSNormMusaDescriptor_t desc, void *y, void const *x, void const *w, void *stream) {
+    auto n = desc->n, d = desc->d;
+    auto y_ = reinterpret_cast<half *>(y);
+    auto x_ = reinterpret_cast<half const *>(x);
+    auto epsilon = desc->epsilon;
+
+    // Get strides in terms of elements
+    auto stride_y = desc->stride_y;
+    auto stride_x = desc->stride_x;
+
+    auto musa_stream = reinterpret_cast<musaStream_t>(stream);
+    unsigned int items_per_thread = ROUND_UP_DIV(d, MAX_THREADS_PER_BLOCK);
+    auto w_datatype = desc->w_datatype;
+    if (dtype_eq(w_datatype, F16)) {
+        auto w_ = reinterpret_cast<half const *>(w);
+        if (items_per_thread == 1) {
+            rms_norm_padding<MAX_THREADS_PER_BLOCK>
+                <<<n, d, 0, musa_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon);
+        } else if (items_per_thread <= 16) {
+            rms_norm_folding<MAX_THREADS_PER_BLOCK, 16>
+                <<<n, MAX_THREADS_PER_BLOCK, 0, musa_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon, d);
+        } else {
+            rms_norm_standard<MAX_THREADS_PER_BLOCK>
+                <<<n, MAX_THREADS_PER_BLOCK, 0, musa_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon, d);
+        }
+    } else {
+        auto w_ = reinterpret_cast<float const *>(w);
+        if (items_per_thread == 1) {
+            rms_norm_padding<MAX_THREADS_PER_BLOCK>
+                <<<n, d, 0, musa_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon);
+        } else if (items_per_thread <= 16) {
+            rms_norm_folding<MAX_THREADS_PER_BLOCK, 16>
+                <<<n, MAX_THREADS_PER_BLOCK, 0, musa_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon, d);
+        } else {
+            rms_norm_standard<MAX_THREADS_PER_BLOCK>
+                <<<n, MAX_THREADS_PER_BLOCK, 0, musa_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon, d);
+        }
+    }
+}
+
+infiniopStatus_t musaRMSNorm(RMSNormMusaDescriptor_t desc,
+                             void *workspace,
+                             uint64_t workspace_size,
+                             void *y, void const *x, void const *w,
+                             void *stream) {
+    int current_device;
+    if (musaGetDevice(&current_device) != musaSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    if (current_device != desc->device_id && musaSetDevice(desc->device_id) != musaSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    if (dtype_eq(desc->dtype, F16)) {
+        rms_norm_mt_gpu_f16(desc, y, x, w, stream);
+        return STATUS_SUCCESS;
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/rms_norm/operator.cc b/src/ops/rms_norm/operator.cc
index dff9573b..317e7ef2 100644
--- a/src/ops/rms_norm/operator.cc
+++ b/src/ops/rms_norm/operator.cc
@@ -20,6 +20,9 @@
 #ifdef ENABLE_METAX_GPU
 #include "maca/rms_norm_maca.h"
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+#include "musa/rms_norm_musa.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateRMSNormDescriptor(
     infiniopHandle_t handle,
@@ -57,6 +60,11 @@ __C infiniopStatus_t infiniopCreateRMSNormDescriptor(
         case DevMetaxGpu: {
             return macaCreateRMSNormDescriptor((MacaHandle_t) handle, (RMSNormMacaDescriptor_t *) desc_ptr, y_desc, x_desc, w_desc, epsilon);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaCreateRMSNormDescriptor((MusaHandle_t) handle, (RMSNormMusaDescriptor_t *) desc_ptr, y_desc, x_desc, w_desc, epsilon);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -89,6 +97,11 @@ __C infiniopStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t
         case DevMetaxGpu: {
             return macaGetRMSNormWorkspaceSize((RMSNormMacaDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaGetRMSNormWorkspaceSize((RMSNormMusaDescriptor_t) desc, size);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -127,6 +140,11 @@ __C infiniopStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *wor
         case DevMetaxGpu: {
             return macaRMSNorm((RMSNormMacaDescriptor_t) desc, workspace, workspace_size, y, x, w, stream);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaRMSNorm((RMSNormMusaDescriptor_t) desc, workspace, workspace_size, y, x, w, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -153,12 +171,16 @@ __C infiniopStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_
     case DevAscendNpu: {
         return aclnnDestroyRMSNormDescriptor((RMSNormAclnnDescriptor_t) desc);
     }
-
 #endif
 #ifdef ENABLE_METAX_GPU
     case DevMetaxGpu: {
         return macaDestroyRMSNormDescriptor((RMSNormMacaDescriptor_t) desc);
     }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+    case DevMthreadsGpu: {
+        return musaDestroyRMSNormDescriptor((RMSNormMusaDescriptor_t) desc);
+    }
 #endif
     }
     return STATUS_BAD_DEVICE;
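
The rotary-embedding port follows. Its kernel rotates consecutive (even, odd) element pairs of each head by a position-dependent angle read from the sin/cos tables; a scalar CPU restatement of the same update (float instead of half, illustrative only):

    #include <cstdint>
    #include <vector>

    // CPU reference of padding_f16's per-pair update: for pair k of a head at
    // position pos[s], with dim-wide sin/cos rows holding one value per column.
    void rope_pair_ref(std::vector<float> &t,              // [seq_len, nhead, dim]
                       const std::vector<uint64_t> &pos,   // [seq_len]
                       const std::vector<float> &sin_t,    // [total_seq_len, dim]
                       const std::vector<float> &cos_t,
                       uint64_t seq_len, uint64_t nhead, uint64_t dim) {
        for (uint64_t s = 0; s < seq_len; ++s)
            for (uint64_t h = 0; h < nhead; ++h)
                for (uint64_t k = 0; k < dim / 2; ++k) {
                    uint64_t off = (s * nhead + h) * dim + 2 * k;
                    uint64_t tab = pos[s] * dim + 2 * k;
                    float x0 = t[off], x1 = t[off + 1];
                    t[off]     = x0 * cos_t[tab]     - x1 * sin_t[tab];
                    t[off + 1] = x1 * cos_t[tab + 1] + x0 * sin_t[tab + 1];
                }
    }
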
diff --git a/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc b/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc
new file mode 100644
index 00000000..9ba0547d
--- /dev/null
+++ b/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc
@@ -0,0 +1,76 @@
+#include "rotary_embedding_musa.h"
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+
+infiniopStatus_t musaCreateRoPEDescriptor(MusaHandle_t handle,
+                                          RoPEMusaDescriptor_t *desc_ptr,
+                                          infiniopTensorDescriptor_t t,
+                                          infiniopTensorDescriptor_t pos_ids,
+                                          infiniopTensorDescriptor_t sin_table,
+                                          infiniopTensorDescriptor_t cos_table) {
+    if (desc_ptr == nullptr)
+        return STATUS_MEMORY_NOT_ALLOCATED;
+
+    if (t->ndim != 3 ||
+        pos_ids->ndim != 1 ||
+        sin_table->ndim != 2 ||
+        cos_table->ndim != 2)
+        return STATUS_BAD_TENSOR_SHAPE;
+
+    auto seq_len = t->shape[0];
+    auto nhead = t->shape[1];
+    auto dim = t->shape[2];
+    auto total_seq_len = sin_table->shape[0];
+
+    if (dim % 2 != 0)
+        return STATUS_BAD_TENSOR_SHAPE;
+
+    if (pos_ids->shape[0] != seq_len ||
+        sin_table->shape[1] != dim ||
+        cos_table->shape[1] != dim ||
+        sin_table->shape[0] != cos_table->shape[0])
+        return STATUS_BAD_TENSOR_SHAPE;
+
+    // TODO: support larger dim in the future
+    if (dim / 2 > MAX_THREADS_PER_BLOCK) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    if (t->strides[2] != 1 ||
+        pos_ids->strides[0] != 1 ||
+        sin_table->strides[1] != 1 ||
+        cos_table->strides[1] != 1)
+        return STATUS_BAD_TENSOR_STRIDES;
+
+    if (!dtype_eq(t->dt, F16))
+        return STATUS_BAD_TENSOR_DTYPE;
+
+    if (!dtype_eq(sin_table->dt, F32) || !dtype_eq(cos_table->dt, F32))
+        return STATUS_BAD_TENSOR_DTYPE;
+
+    if (!dtype_eq(pos_ids->dt, U64))
+        return STATUS_BAD_TENSOR_DTYPE;
+
+    *desc_ptr = new RoPEMusaDescriptor{
+        handle->device,
+        handle->device_id,
+        t->dt,
+        seq_len,
+        nhead,
+        dim,
+        total_seq_len,
+        {t->strides[0], t->strides[1]}};
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaGetRoPEWorkspaceSize(RoPEMusaDescriptor_t desc, uint64_t *size) {
+    *size = 0;
+    return STATUS_SUCCESS;
+}
+
+
+infiniopStatus_t musaDestroyRoPEDescriptor(RoPEMusaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/rotary_embedding/musa/rotary_embedding_musa.h b/src/ops/rotary_embedding/musa/rotary_embedding_musa.h
new file mode 100644
index 00000000..7a14daea
--- /dev/null
+++ b/src/ops/rotary_embedding/musa/rotary_embedding_musa.h
@@ -0,0 +1,40 @@
+#ifndef __MUSA_ROTARY_EMBEDDING_H__
+#define __MUSA_ROTARY_EMBEDDING_H__
+
+#include "../../../devices/musa/musa_handle.h"
+#include "operators.h"
+
+struct RoPEMusaDescriptor {
+    Device device;
+    int device_id;
+    DT dtype;
+    uint64_t seq_len;
+    uint64_t nhead;
+    uint64_t dim;
+    uint64_t total_seq_len;
+    int64_t strides[2];
+};
+
+typedef struct RoPEMusaDescriptor *RoPEMusaDescriptor_t;
+
+infiniopStatus_t musaCreateRoPEDescriptor(MusaHandle_t handle,
+                                          RoPEMusaDescriptor_t *desc_ptr,
+                                          infiniopTensorDescriptor_t t,
+                                          infiniopTensorDescriptor_t pos_ids,
+                                          infiniopTensorDescriptor_t sin_table,
+                                          infiniopTensorDescriptor_t cos_table);
+
+infiniopStatus_t musaGetRoPEWorkspaceSize(RoPEMusaDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t musaRoPE(RoPEMusaDescriptor_t desc,
+                          void *workspace,
+                          uint64_t workspace_size,
+                          void *t,
+                          void const *pos_ids,
+                          void const *sin_table,
+                          void const *cos_table,
+                          void *stream);
+
+infiniopStatus_t musaDestroyRoPEDescriptor(RoPEMusaDescriptor_t desc);
+
+#endif// __MUSA_ROTARY_EMBEDDING_H__
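
A note on the table layout the kernel below assumes: row pos of each [total_seq_len, dim] table holds the angle of pair k duplicated at columns 2k and 2k+1, since the kernel reads sin/cos at both offsets. A host-side construction sketch; the theta = 10000 convention is an assumption of this example, not something the diff fixes:

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Builds [total_seq_len, dim] sin/cos tables matching padding_f16's indexing:
    // columns 2k and 2k+1 both hold the angle for pair k (assumed convention).
    void build_tables(std::vector<float> &sin_t, std::vector<float> &cos_t,
                      uint64_t total_seq_len, uint64_t dim, float theta = 1e4f) {
        sin_t.assign(total_seq_len * dim, 0.f);
        cos_t.assign(total_seq_len * dim, 0.f);
        for (uint64_t pos = 0; pos < total_seq_len; ++pos)
            for (uint64_t k = 0; k < dim / 2; ++k) {
                float angle = pos / std::pow(theta, 2.f * k / (float) dim);
                for (int dup = 0; dup < 2; ++dup) {
                    sin_t[pos * dim + 2 * k + dup] = std::sin(angle);
                    cos_t[pos * dim + 2 * k + dup] = std::cos(angle);
                }
            }
    }
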
diff --git a/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu b/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu
new file mode 100644
index 00000000..bac7ad47
--- /dev/null
+++ b/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu
@@ -0,0 +1,68 @@
+#include "../../utils.h"
+#include "rotary_embedding_musa.h"
+#include <musa_fp16.h>
+
+static __global__ void padding_f16(
+    half *__restrict__ x_,
+    uint64_t const *__restrict__ pos_,
+    float const *__restrict__ sin_,
+    float const *__restrict__ cos_,
+    long const stride0,
+    long const stride1) {
+    auto dk = blockDim.x;
+    auto k = threadIdx.x;
+    auto offset = blockIdx.x * stride0 + blockIdx.y * stride1 + k * 2;
+    auto &x = reinterpret_cast<half2 &>(x_[offset]);
+    auto pos = pos_[blockIdx.x];
+    auto sincos_offset = pos * dk * 2 + k * 2;
+
+    float sin0 = sin_[sincos_offset], cos0 = cos_[sincos_offset],
+          sin1 = sin_[sincos_offset + 1], cos1 = cos_[sincos_offset + 1];
+    float x0 = __half2float(x.x) * cos0 - __half2float(x.y) * sin0;
+    float x1 = __half2float(x.y) * cos1 + __half2float(x.x) * sin1;
+    x = half2(x0, x1);
+}
+
+
+void rotary_embedding_mt_gpu_f16(
+    RoPEMusaDescriptor_t desc,
+    half *t,
+    uint64_t const *pos,
+    float const *sin_, float const *cos_,
+    void *stream) {
+    auto nt = desc->seq_len,
+         nh = desc->nhead,
+         dh = desc->dim;
+
+    // batching 2 half together
+    auto stride0 = desc->strides[0],
+         stride1 = desc->strides[1];
+
+    auto musa_stream = reinterpret_cast<musaStream_t>(stream);
+    padding_f16<<<dim3(nt, nh), dh / 2, 0, musa_stream>>>(t, pos, sin_, cos_, stride0, stride1);
+}
+
+infiniopStatus_t musaRoPE(RoPEMusaDescriptor_t desc,
+                          void *workspace,
+                          uint64_t workspace_size,
+                          void *t,
+                          void const *pos_ids,
+                          void const *sin_table,
+                          void const *cos_table,
+                          void *stream) {
+    if (t == nullptr || pos_ids == nullptr || sin_table == nullptr || cos_table == nullptr)
+        return STATUS_BAD_PARAM;
+
+    if (dtype_eq(desc->dtype, F16)) {
+        rotary_embedding_mt_gpu_f16(desc,
+                                    reinterpret_cast<half *>(t),
+                                    reinterpret_cast<uint64_t const *>(pos_ids),
+                                    reinterpret_cast<float const *>(sin_table),
+                                    reinterpret_cast<float const *>(cos_table),
+                                    stream);
+    } else {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/rotary_embedding/operator.cc b/src/ops/rotary_embedding/operator.cc
index 5c1d4aec..bc2dbc09 100644
--- a/src/ops/rotary_embedding/operator.cc
+++ b/src/ops/rotary_embedding/operator.cc
@@ -18,6 +18,9 @@
 #ifdef ENABLE_METAX_GPU
 #include "maca/rotary_embedding_maca.h"
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+#include "musa/rotary_embedding_musa.h"
+#endif
 
 struct RoPEDescriptor {
     Device device;
@@ -65,6 +68,11 @@ __C infiniopStatus_t infiniopCreateRoPEDescriptor(infiniopHandle_t handle,
                                               sin_table, cos_table);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaCreateRoPEDescriptor((MusaHandle_t) handle, (RoPEMusaDescriptor_t *) desc_ptr, t, pos_ids, sin_table, cos_table);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -98,6 +106,11 @@ __C infiniopStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
             return macaGetRoPEWorkspaceSize((RoPEMacaDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaGetRoPEWorkspaceSize((RoPEMusaDescriptor_t) desc, size);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -150,6 +163,11 @@ __C infiniopStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
                             cos_table, stream);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaRoPE((RoPEMusaDescriptor_t) desc, workspace, workspace_size, t, pos_ids, sin_table, cos_table, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -181,6 +199,11 @@ __C infiniopStatus_t infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc
         case DevMetaxGpu: {
            return macaDestroyRoPEDescriptor((RoPEMacaDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return musaDestroyRoPEDescriptor((RoPEMusaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
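
The SwiGLU launcher in the file that follows sizes its blocks as gcd(MAX_THREADS_PER_BLOCK, di), which guarantees di % block == 0, so grid.x = di / block tiles each row exactly and the kernel needs no bounds check. A quick self-contained check (1024 is an assumed value for the macro):

    #include <cstdio>

    static int gcd(int a, int b) {        // same helper as in swiglu.mu
        while (b != 0) { int r = a % b; a = b; b = r; }
        return a;
    }

    int main() {
        const int max_threads = 1024;     // MAX_THREADS_PER_BLOCK (assumed value)
        for (int di : {4096, 11008, 1000}) {
            int block = gcd(max_threads, di);
            // di is always an exact multiple of gcd(max_threads, di).
            printf("di=%5d block=%4d grid.x=%5d\n", di, block, di / block);
        }
    }
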
diff --git a/src/ops/swiglu/musa/swiglu.mu b/src/ops/swiglu/musa/swiglu.mu
new file mode 100644
index 00000000..259e5c6f
--- /dev/null
+++ b/src/ops/swiglu/musa/swiglu.mu
@@ -0,0 +1,68 @@
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+#include "swiglu_musa.h"
+#include <musa_fp16.h>
+
+static __forceinline__ __device__ float silu(float x) {
+    return x * fdividef(1, 1 + expf(-x));
+}
+
+inline int gcd(int a, int b) {
+    while (b != 0) {
+        int rem = a % b;
+        a = b;
+        b = rem;
+    }
+    return a;
+}
+
+template<class Tdata>
+static __global__ void swiglu(
+    Tdata *__restrict__ c,
+    int const stride_c,
+    Tdata const *__restrict__ a,
+    int const stride_a,
+    Tdata const *__restrict__ b,
+    int const stride_b) {
+    auto i = blockIdx.y * stride_b + blockIdx.x * blockDim.x + threadIdx.x,
+         j = blockIdx.y * stride_a + blockIdx.x * blockDim.x + threadIdx.x,
+         k = blockIdx.y * stride_c + blockIdx.x * blockDim.x + threadIdx.x;
+    auto x = float(b[i]),
+         y = float(a[j]);
+    c[k] = Tdata(silu(x) * y);
+}
+
+void swiglu_mt_gpu_f16(SwiGLUMusaDescriptor_t desc, void *c, void const *a, void const *b, void *stream) {
+
+    auto seq_len = desc->seq_len,
+         di = desc->di;
+
+    auto stride_a = desc->stride_a,
+         stride_b = desc->stride_b,
+         stride_c = desc->stride_c;
+
+    dim3 block_dims = gcd(MAX_THREADS_PER_BLOCK, di);
+    dim3 grid_dims = dim3(di / block_dims.x, seq_len);
+
+    auto a_ptr = reinterpret_cast<const half *>(a);
+    auto b_ptr = reinterpret_cast<const half *>(b);
+    auto c_ptr = reinterpret_cast<half *>(c);
+
+    auto musa_stream = reinterpret_cast<musaStream_t>(stream);
+
+    swiglu<<<grid_dims, block_dims, 0, musa_stream>>>(
+        c_ptr, stride_c, a_ptr, stride_a, b_ptr, stride_b);
+}
+
+infiniopStatus_t musaSwiGLU(SwiGLUMusaDescriptor_t desc,
+                            void *c,
+                            void const *a,
+                            void const *b,
+                            void *stream) {
+    if (dtype_eq(desc->dtype, F16)) {
+        swiglu_mt_gpu_f16(desc, c, a, b, stream);
+        return STATUS_SUCCESS;
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/swiglu/musa/swiglu_musa.cc b/src/ops/swiglu/musa/swiglu_musa.cc
new file mode 100644
index 00000000..a1d5719b
--- /dev/null
+++ b/src/ops/swiglu/musa/swiglu_musa.cc
@@ -0,0 +1,50 @@
+#include "../../../devices/musa/common_musa.h"
+#include "../../utils.h"
+#include "swiglu_musa.h"
+
+infiniopStatus_t musaCreateSwiGLUDescriptor(infiniopHandle_t handle,
+                                            SwiGLUMusaDescriptor_t *desc_ptr,
+                                            infiniopTensorDescriptor_t c_desc,
+                                            infiniopTensorDescriptor_t a_desc,
+                                            infiniopTensorDescriptor_t b_desc) {
+    if (c_desc->ndim != 2 || a_desc->ndim != 2 || b_desc->ndim != 2) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    DT dtype = c_desc->dt;
+
+    if (!dtype_eq(dtype, F16)) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    if (a_desc->strides[1] != 1 || b_desc->strides[1] != 1 || c_desc->strides[1] != 1) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    uint64_t seq_len = c_desc->shape[0],
+             di = c_desc->shape[1];
+
+    uint64_t stride_a = a_desc->strides[0],
+             stride_b = b_desc->strides[0],
+             stride_c = c_desc->strides[0];
+
+
+    if (a_desc->shape[0] != seq_len || a_desc->shape[1] != di || !dtype_eq(a_desc->dt, dtype) ||
+        b_desc->shape[0] != seq_len || b_desc->shape[1] != di || !dtype_eq(b_desc->dt, dtype)) {
+        return STATUS_BAD_PARAM;
+    }
+
+    *desc_ptr = new SwiGLUMusaDescriptor{DevMthreadsGpu,
+                                         dtype,
+                                         seq_len,
+                                         di,
+                                         stride_a,
+                                         stride_b,
+                                         stride_c};
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t musaDestroySwiGLUDescriptor(SwiGLUMusaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/swiglu/musa/swiglu_musa.h b/src/ops/swiglu/musa/swiglu_musa.h
new file mode 100644
index 00000000..00ae1155
--- /dev/null
+++ b/src/ops/swiglu/musa/swiglu_musa.h
@@ -0,0 +1,34 @@
+#ifndef __MUSA_SWIGLU_H__
+#define __MUSA_SWIGLU_H__
+
+#include "operators.h"
+
+struct SwiGLUMusaDescriptor {
+    Device device;
+    DT dtype;
+    uint64_t seq_len;
+    uint64_t di;
+    uint64_t stride_a;
+    uint64_t stride_b;
+    uint64_t stride_c;
+};
+
+typedef struct SwiGLUMusaDescriptor *SwiGLUMusaDescriptor_t;
+
+infiniopStatus_t musaCreateSwiGLUDescriptor(infiniopHandle_t handle,
+                                            SwiGLUMusaDescriptor_t *desc_ptr,
+                                            infiniopTensorDescriptor_t c_desc,
+                                            infiniopTensorDescriptor_t a_desc,
+                                            infiniopTensorDescriptor_t b_desc);
+
+infiniopStatus_t musaSwiGLU(SwiGLUMusaDescriptor_t desc,
+                            void *c,
+                            void const *a,
+                            void const *b,
+                            void *stream);
+
+infiniopStatus_t musaDestroySwiGLUDescriptor(SwiGLUMusaDescriptor_t desc);
+
+void swiglu_mt_gpu_f16(SwiGLUMusaDescriptor_t desc, void *c, void const *a, void const *b, void *stream);
+
+#endif// __MUSA_SWIGLU_H__
diff --git a/src/ops/swiglu/operator.cc b/src/ops/swiglu/operator.cc
index 3eb68a97..3ea0bedc 100644
--- a/src/ops/swiglu/operator.cc
+++ b/src/ops/swiglu/operator.cc
@@ -17,6 +17,9 @@
 #ifdef ENABLE_METAX_GPU
 #include "maca/swiglu_maca.h"
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+#include "musa/swiglu_musa.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle,
                                                     infiniopSwiGLUDescriptor_t *desc_ptr,
@@ -57,6 +60,10 @@ __C infiniopStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle,
                                                a_desc, b_desc);
         }
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu:
+            return musaCreateSwiGLUDescriptor(handle, (SwiGLUMusaDescriptor_t *) desc_ptr, c_desc, a_desc, b_desc);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -88,6 +95,10 @@ __C infiniopStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
 #ifdef ENABLE_METAX_GPU
         case DevMetaxGpu:
             return macaSwiGLU((SwiGLUMacaDescriptor_t) desc, c, a, b, stream);
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu:
+            return musaSwiGLU((SwiGLUMusaDescriptor_t) desc, c, a, b, stream);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -115,6 +126,10 @@ __C infiniopStatus_t infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t
 #ifdef ENABLE_METAX_GPU
         case DevMetaxGpu:
             return macaDestroySwiGLUDescriptor((SwiGLUMacaDescriptor_t) desc);
+#endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu:
+            return musaDestroySwiGLUDescriptor((SwiGLUMusaDescriptor_t) desc);
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/xmake.lua b/xmake.lua
index ce8f065a..f9e6f3dc 100644
--- a/xmake.lua
+++ b/xmake.lua
@@ -48,6 +48,13 @@ option("metax-gpu")
 option_end()
 
+option("mthreads-gpu")
+    set_default(false)
+    set_showmenu(true)
+    set_description("Enable or disable MThreads GPU kernel")
+    add_defines("ENABLE_MTHREADS_GPU")
+option_end()
+
 option("sugon-dcu")
     set_default(false)
     set_showmenu(true)
@@ -172,6 +179,51 @@ if has_config("cambricon-mlu") then
 end
 
+if has_config("mthreads-gpu") then
+
+    add_defines("ENABLE_MTHREADS_GPU")
+    local musa_home = os.getenv("MUSA_INSTALL_PATH")
+    -- Add include dirs
+    add_includedirs(musa_home .. "/include")
+    -- Add shared lib
+    add_linkdirs(musa_home .. "/lib")
+    add_links("libmusa.so")
+    add_links("libmusart.so")
+    add_links("libmudnn.so")
+    add_links("libmublas.so")
+
+    rule("mu")
+        set_extensions(".mu")
+        on_load(function (target)
+            target:add("includedirs", "include")
+        end)
+
+        on_build_file(function (target, sourcefile)
+            local objectfile = target:objectfile(sourcefile)
+            os.mkdir(path.directory(objectfile))
+
+            local musa_home = os.getenv("MUSA_INSTALL_PATH")
+            local mcc = musa_home .. "/bin/mcc"
+            local args = {"-c", sourcefile, "-o", objectfile, "-I" .. musa_home .. "/include", "-O3", "-fPIC", "-Wall", "-std=c++17", "-pthread"}
+            for _, includedir in ipairs(target:get("includedirs")) do
+                table.insert(args, "-I" .. includedir)
+            end
+
+            os.execv(mcc, args)
+            table.insert(target:objectfiles(), objectfile)
+        end)
+    rule_end()
+
+    target("mthreads-gpu")
+        set_kind("static")
+        set_languages("cxx17")
+        add_files("src/devices/musa/*.cc", "src/ops/*/musa/*.cc")
+        add_files("src/ops/*/musa/*.mu", {rule = "mu"})
+        add_cxflags("-lstdc++ -Wall -fPIC")
+    target_end()
+
+end
+
 if has_config("ascend-npu") then
 
     add_defines("ENABLE_ASCEND_NPU")
@@ -315,6 +367,9 @@ target("infiniop")
     if has_config("metax-gpu") then
         add_deps("metax-gpu")
    end
+    if has_config("mthreads-gpu") then
+        add_deps("mthreads-gpu")
+    end
     set_languages("cxx17")
     add_files("src/devices/handle.cc")
     add_files("src/ops/*/operator.cc")