Skip to content

Commit f4405a3

Browse files
committed
增加对mudnn的支持
1 parent 7b9a469 commit f4405a3

File tree

4 files changed

+164
-2
lines changed

4 files changed

+164
-2
lines changed

src/devices/musa/musa_handle.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,37 @@ infiniopStatus_t createMusaHandle(MusaHandle_t* handle_ptr, int device_id) {
2020
musaDeviceProp prop;
2121
musaGetDeviceProperties(&prop, device_id);
2222

23-
23+
// create a mublas handle pool
2424
auto mublas_pool = std::make_shared<Pool<mublasHandle_t>>();
2525
mublasHandle_t *mublas_handle = new mublasHandle_t;
2626
mublasCreate(mublas_handle);
2727
mublas_pool->push(mublas_handle);
2828

29-
*handle_ptr = new MusaContext{DevMtGpu, device_id, std::move(mublas_pool), std::move(prop)};
29+
// create a mudnn handle pool
30+
auto mudnn_pool = std::make_shared<Pool<musa::dnn::Handle>>();
31+
musa::dnn::Handle *mudnn_handle = new musa::dnn::Handle;
32+
mudnn_pool->push(mudnn_handle);
33+
34+
int capability_major;
35+
int capability_minor;
36+
musaDeviceGetAttribute(&capability_major, musaDevAttrComputeCapabilityMajor, device_id);
37+
musaDeviceGetAttribute(&capability_minor, musaDevAttrComputeCapabilityMinor, device_id);
38+
39+
*handle_ptr = new MusaContext{
40+
DevMtGpu,
41+
device_id,
42+
std::move(mublas_pool),
43+
std::move(mudnn_pool),
44+
std::move(prop),
45+
capability_major,
46+
capability_minor,};
3047

3148
return STATUS_SUCCESS;
3249
}
3350

3451
infiniopStatus_t deleteMusaHandle(MusaHandle_t handle_ptr) {
3552
handle_ptr->mublas_handles_t = nullptr;
53+
handle_ptr->mudnn_handles_t = nullptr;
3654
delete handle_ptr;
3755

3856
return STATUS_SUCCESS;

src/devices/musa/musa_handle.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ struct MusaContext {
1515
Device device;
1616
int device_id;
1717
std::shared_ptr<Pool<mublasHandle_t>> mublas_handles_t;
18+
std::shared_ptr<Pool<musa::dnn::Handle>> mudnn_handles_t;
1819
musaDeviceProp prop;
20+
int compute_capability_major;
21+
int compute_capability_minor;
1922
};
2023
typedef struct MusaContext *MusaHandle_t;
2124

@@ -40,4 +43,22 @@ void use_mublas(std::shared_ptr<Pool<mublasHandle_t>> mublas_handles_t, int devi
4043
mublas_handles_t->push(handle);
4144
}
4245

46+
template<typename T>
47+
void use_mudnn(std::shared_ptr<Pool<musa::dnn::Handle>> mudnn_handles_t, int device_id, musaStream_t stream, T const &f) {
48+
musa::dnn::Handle* handle = mudnn_handles_t->pop();
49+
if (!handle) {
50+
int current_device;
51+
musaGetDevice(&current_device);
52+
if (current_device != device_id) {
53+
musaSetDevice(device_id);
54+
}
55+
handle = new musa::dnn::Handle(device_id);
56+
// mudnnCreate(handle);
57+
}
58+
// mudnnSetStream(*handle, (MUstream) stream);
59+
handle->SetStream(stream);
60+
f(handle);
61+
mudnn_handles_t->push(handle);
62+
}
63+
4364
#endif // __MUSA_HANDLE_H__

src/devices/musa/tensor_desc.cc

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
2+
#include "tensor_desc.h"
3+
#include <iostream>
4+
#include <vector>
5+
6+
// void mudnnSqueezeTensorDim(mudnnTensorDesc_t &ldesc, mudnnTensorDesc_t &rdesc, mudnnTensorDesc_t &outdesc) {
7+
// if (outdesc->ndims > 2) {
8+
// if (ldesc->ndims > 2 && *ldesc->dim == 1) {
9+
// ldesc->ndims -= 1;
10+
// ldesc->dim = ldesc->dim+1;
11+
// }
12+
// if (rdesc->ndims > 2 && *rdesc->dim == 1) {
13+
// rdesc->ndims -= 1;
14+
// rdesc->dim = rdesc->dim+1;
15+
// }
16+
// }
17+
// }
18+
19+
// void mudnnCreateTensorDescriptor(mudnnTensorDesc_t *desc) {
20+
// *desc = new mudnnTensorDesc;
21+
// (*desc)->type = Type::FLOAT;
22+
// (*desc)->format = Format::UNKNOWN;
23+
// (*desc)->ndims = 0;
24+
// (*desc)->dim = nullptr;
25+
// (*desc)->stride = nullptr;
26+
// (*desc)->scales = nullptr;
27+
// (*desc)->addr = nullptr;
28+
// }
29+
30+
31+
// void mudnnSetTensorDescriptor(mudnnTensorDesc_t &desc, int64_t *shape, int64_t *stride, int64_t ndim,
32+
// int64_t offset, Type type, Format format) {
33+
// desc->type = type;
34+
// desc->format = format;
35+
// desc->ndims = ndim;
36+
// desc->dim = shape;
37+
// if (stride) {
38+
// desc->stride = stride;
39+
// } else {
40+
// std::vector<int64_t> stride_v(ndim, 1);
41+
// for (int64_t i = ndim - 2; i >= 0; i--) {
42+
// stride_v[i] = shape[i + 1] * stride_v[i + 1];
43+
// }
44+
// desc->stride = stride_v.data();
45+
// }
46+
// }
47+
48+
// void mudnnSetTensorDescriptorFromTensorLayout(mudnnTensorDesc_t &desc, const TensorLayout *layout) {
49+
// auto dims = new int64_t(layout->ndim);
50+
// for (uint64_t i = 0; i < layout->ndim; i++) {
51+
// dims[i] = static_cast<int64_t>(layout->shape[i]);
52+
// }
53+
// // Cast bytes stride to element stride
54+
// auto strides = new int64_t(layout->ndim);
55+
// for (uint64_t i = 0; i < layout->ndim; i++) {
56+
// strides[i] = layout->strides[i] / (layout->dt).size;
57+
// }
58+
59+
// Type type = Type::HALF;
60+
// Format format = Format::NCHW;
61+
62+
// mudnnSetTensorDescriptor(desc, dims, strides, layout->ndim, 0, type, format);
63+
// }
64+
65+
// void mudnnDestroyTensorDescriptor(mudnnTensorDesc_t &desc) {
66+
// if (desc) {
67+
// delete desc;
68+
// desc = nullptr;
69+
// }
70+
// }
71+
72+
// int mudnnCreateTensor(TensorDescriptor desc, void *data, musa::dnn::Tensor **tensor) {
73+
// *tensor = new musa::dnn::Tensor();
74+
75+
// (*tensor)->SetAddr(data);
76+
// // (*tensor)->SetType(musa::dnn::Tensor::Type(desc->type));
77+
// (*tensor)->SetFormat(musa::dnn::Tensor::Format(desc->format));
78+
// // (*tensor)->SetNdInfo(desc->ndims, desc->dim, desc->stride);
79+
// (*tensor)->SetNdInfo(desc->ndims, desc->dim);
80+
// return 0;
81+
// }

src/devices/musa/tensor_desc.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#ifndef __TENSOR_DESC_H__
2+
#define __TENSOR_DESC_H__
3+
4+
#include "tensor.h"
5+
#include "common_musa.h"
6+
#include <musa.h>
7+
#include <musa_runtime.h>
8+
#include <mudnn.h>
9+
#include <mudnn_base.h>
10+
11+
// using namespace musa::dnn;
12+
13+
// struct mudnnTensorDesc {
14+
// Type type;
15+
// Format format;
16+
// int64_t ndims;
17+
// int64_t *dim;
18+
// int64_t *stride;
19+
// int64_t *scales;
20+
// int64_t *addr;
21+
// };
22+
23+
// typedef mudnnTensorDesc *mudnnTensorDesc_t;
24+
25+
// void mudnnCreateTensorDescriptor(mudnnTensorDesc_t *desc);
26+
27+
// void mudnnSetTensorDescriptor(mudnnTensorDesc_t &desc, int64_t *shape,
28+
// int64_t *stride, int64_t ndim, int64_t offset,
29+
// Type type, Format format);
30+
31+
// void mudnnSetTensorDescriptorFromTensorLayout(mudnnTensorDesc_t &desc, const TensorLayout *layout);
32+
33+
// void mudnnDestroyTensorDescriptor(mudnnTensorDesc_t &desc);
34+
35+
// Declaration of a helper taking a tensor descriptor, a data pointer and an
// out-parameter for a musa::dnn::Tensor pointer.
// NOTE(review): the only implementation visible next to this header (in
// tensor_desc.cc) is commented out; that version allocates a new
// musa::dnn::Tensor into *tensor, sets its address, format and dims from
// `desc`, and returns 0. Confirm a live definition exists before relying on
// this declaration, otherwise callers will fail at link time.
int mudnnCreateTensor(TensorDescriptor desc, void *data, musa::dnn::Tensor **tensor);
36+
37+
// void mudnnSetTensorDescriptorFromTensorLayout(mudnnTensorDesc_t &desc, const TensorLayout *layout);
38+
39+
// void mudnnSqueezeTensorDim(mudnnTensorDesc_t &ldesc, mudnnTensorDesc_t &rdesc, mudnnTensorDesc_t &outdesc);
40+
41+
42+
#endif // __TENSOR_DESC_H__

0 commit comments

Comments
 (0)