Commit e28c0b8

cuda : implement bf16 cpy ops and enable bf16 cont (#14763)
* implement bf16 cpy ops and enable bf16 cont
* deduplicate copy functions
* deduplicate checks
1 parent 8e6f8bc commit e28c0b8

File tree: 4 files changed (+49, -124 lines)


ggml/src/ggml-cuda/cpy-utils.cuh (10 additions, 36 deletions)

```diff
@@ -2,24 +2,13 @@

 #include "ggml-common.h"

-static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
-    *dst = *src;
-}
-
-static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
-    *dst = __float2half(*src);
-}
-
-static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
-    *dst = *src;
-}
-
-static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
-    *dst = *src;
-}
-
-static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
-    *dst = *src;
+template<typename src_t, typename dst_t>
+static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
+    if constexpr (std::is_same_v<src_t, dst_t>) {
+        *dst = *src;
+    } else {
+        *dst = float(*src);
+    }
 }

 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
@@ -230,22 +219,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
     quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
 }

-static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    convert_f32_f32((const float *)cxi, (float *)cdsti);
-}
-
-static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    convert_f32_f16((const float *)cxi, (half *)cdsti);
-}
-
-static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
-    convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
-}
-
-static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    convert_f16_f16((const half *)cxi, (half *)cdsti);
-}
-
-static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
-    convert_f16_f32((const half *)cxi, (float *)cdsti);
+template<typename src_t, typename dst_t>
+static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
+    convert_flt((const src_t *)cxi, (dst_t *)cdsti);
 }
```
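The `if constexpr` branch resolves at compile time: when `src_t` and `dst_t` match, the copy is a plain assignment; otherwise the value is widened to `float` and narrowed by the destination type's converting constructor, which is how f16-to-bf16 (and the reverse) round-trips through f32. A standalone sketch of the same pattern, as a hypothetical demo (not code from this commit; `cpy_flt_demo` is an invented name, and it assumes a CUDA toolkit with `cuda_bf16.h`, e.g. `nvcc -arch=sm_80 demo.cu`):

```cuda
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>
#include <cstdio>

// Same dispatch as convert_flt: identity copy for matching types,
// otherwise a round-trip through f32.
template<typename src_t, typename dst_t>
__global__ void cpy_flt_demo(const src_t * src, dst_t * dst, int n) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    if constexpr (std::is_same_v<src_t, dst_t>) {
        dst[i] = src[i];
    } else {
        dst[i] = float(src[i]);
    }
}

int main() {
    const int n = 4;
    half        * src;
    nv_bfloat16 * dst;
    cudaMallocManaged(&src, n*sizeof(*src));
    cudaMallocManaged(&dst, n*sizeof(*dst));
    for (int i = 0; i < n; i++) {
        src[i] = __float2half(0.5f*i);
    }
    cpy_flt_demo<<<1, 32>>>(src, dst, n);  // f16 -> bf16, via f32
    cudaDeviceSynchronize();
    for (int i = 0; i < n; i++) {
        printf("%g\n", __bfloat162float(dst[i]));
    }
    cudaFree(src);
    cudaFree(dst);
    return 0;
}
```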

ggml/src/ggml-cuda/cpy.cu (33 additions, 56 deletions)

```diff
@@ -8,10 +8,10 @@
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

     if (i >= ne) {
@@ -139,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
 #endif
 }

-static void ggml_cpy_f16_f32_cuda(
+template<typename src_t, typename dst_t>
+static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_bf16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

@@ -307,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

-static void ggml_cpy_f16_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -372,11 +333,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -403,9 +364,17 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -430,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+        return (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
+        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -458,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_flt<cpy_1_flt<half, half>>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+        return (void*) cpy_flt<cpy_1_flt<half, float>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
```
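One detail worth noting in this file: `cpy_1_flt<src_t, dst_t>` names a concrete `__device__` function, and `cpy_flt` accepts it as a compile-time function-pointer parameter (`template <cpy_kernel_t cpy_1>`), so the per-element conversion is inlined into each kernel instantiation rather than dispatched through a runtime pointer. A reduced, self-contained sketch of that composition (hypothetical demo names and stand-in types, not code from this commit):

```cuda
#include <cstdio>

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

// Per-element op, addressed by raw byte pointers as in cpy.cu.
template<typename src_t, typename dst_t>
__device__ void cpy_1_demo(const char * cxi, char * cdsti) {
    *(dst_t *)cdsti = (dst_t)(float)(*(const src_t *)cxi);
}

// The element op is a template parameter, so the call is resolved
// and inlined at compile time.
template <cpy_kernel_t cpy_1>
__global__ void cpy_demo(const char * cx, char * cdst, int ne,
                         int ts_src, int ts_dst) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }
    cpy_1(cx + i*ts_src, cdst + i*ts_dst);
}

int main() {
    const int ne = 3;
    float  * src;
    double * dst;  // stand-in pair; the commit instantiates f32/f16/bf16
    cudaMallocManaged(&src, ne*sizeof(*src));
    cudaMallocManaged(&dst, ne*sizeof(*dst));
    for (int i = 0; i < ne; i++) {
        src[i] = i + 0.25f;
    }
    cpy_demo<cpy_1_demo<float, double>><<<1, 32>>>(
        (const char *)src, (char *)dst, ne, sizeof(float), sizeof(double));
    cudaDeviceSynchronize();
    for (int i = 0; i < ne; i++) {
        printf("%f\n", dst[i]);
    }
    cudaFree(src);
    cudaFree(dst);
    return 0;
}
```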

ggml/src/ggml-cuda/ggml-cuda.cu (4 additions, 14 deletions)

```diff
@@ -3242,13 +3242,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         {
             ggml_type src0_type = op->src[0]->type;
             ggml_type src1_type = op->src[1]->type;
-            if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-                return true;
-            }
-            if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
-                return true;
-            }
-            if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+            if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
+                (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
+            ) {
                 return true;
             }
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
@@ -3284,12 +3280,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                 return true;
             }
-            if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
-                return true;
-            }
             if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
                 return true;
             }
@@ -3370,7 +3360,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[0]->ne[1] % 128 == 0;
         }
         case GGML_OP_CONT:
-            return op->src[0]->type != GGML_TYPE_BF16;
+            return true;
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
```
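The separate per-pair checks collapse into one membership test: the copy is supported whenever both operand types are in {F32, F16, BF16}, which is exactly the set `ggml_cpy_flt_cuda` can instantiate. And since `GGML_OP_CONT` lowers to a copy, the bf16 exclusion can be dropped. A hypothetical helper expressing the same predicate (`is_cpy_flt_type` is an invented name; assumes the `ggml_type` enum from ggml.h):

```cuda
#include "ggml.h"

// Hypothetical predicate: the float types the templated copy covers.
static bool is_cpy_flt_type(ggml_type t) {
    return t == GGML_TYPE_F32 || t == GGML_TYPE_F16 || t == GGML_TYPE_BF16;
}

// Equivalent to the consolidated condition in
// ggml_backend_cuda_device_supports_op:
//     if (is_cpy_flt_type(src0_type) && is_cpy_flt_type(src1_type)) {
//         return true;
//     }
```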

ggml/src/ggml-cuda/set-rows.cu (2 additions, 18 deletions)

```diff
@@ -4,24 +4,8 @@
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);

 template<typename src_t, typename dst_t>
-__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
-    GGML_UNUSED(src_f);
-    GGML_UNUSED(dst_f);
-}
-
-template<>
-__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
-    convert_f32_f16(src_f, dst_h);
-}
-
-template<>
-__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
-    convert_f32_bf16(src_f, dst_b);
-}
-
-template<>
-__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
-    convert_f32_f32(src_f, dst_f);
+__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
+    convert_flt(src_f, dst_f);
 }

 // Generic quantized set_rows kernel template
```
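A side effect of routing `set_rows_1` through `convert_flt`: the old generic template compiled to a silent no-op for any `<src_t, dst_t>` pair without an explicit specialization, whereas the unified definition either converts or fails to compile. A hypothetical usage sketch (assumes the includes of set-rows.cu are in scope):

```cuda
// With the unified template, any float pair convert_flt accepts now
// works without a dedicated specialization.
__device__ void set_rows_demo(const float * src, nv_bfloat16 * dst) {
    set_rows_1(src, dst);  // f32 -> bf16, narrowed by convert_flt
}
```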
