typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

template <cpy_kernel_t cpy_1>
- static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
-                                    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
+                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {

    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
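
Note: the per-element functor cpy_1_flt<src_t, dst_t> that the renamed kernel is instantiated with is defined elsewhere in cpy.cu and does not appear in this diff. A minimal sketch consistent with the cpy_kernel_t signature above (the real helper may route the conversion through a dedicated cast utility rather than the plain casts shown here):

template <typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
    const src_t * xi   = (const src_t *) cxi; // reinterpret raw source bytes
    dst_t       * dsti = (dst_t *) cdsti;     // reinterpret raw destination bytes
    *dsti = (dst_t) (float) *xi;              // widen to float, then narrow, so every src/dst pairing is valid
}
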
@@ -139,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
#endif
}

- static void ggml_cpy_f16_f32_cuda(
+ template <typename src_t, typename dst_t>
+ static void ggml_cpy_flt_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
- }
-
- static void ggml_cpy_f32_f32_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
- }
-
- static void ggml_cpy_f32_bf16_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
- }
-
- static void ggml_cpy_f32_f16_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+     cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
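
With this template in place, each of the former per-type-pair launchers (including ggml_cpy_f16_f16_cuda, removed in a later hunk) reduces to a one-line instantiation. For example, the equivalent of the removed ggml_cpy_f32_f16_cuda is:

ggml_cpy_flt_cuda<float, half>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
                               nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
                               stream, cdst_indirect, graph_cpynode_index);
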
@@ -307,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}

- static void ggml_cpy_f16_f16_cuda(
-     const char * cx, char * cdst, const int ne,
-     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
- }
-
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));
@@ -372,11 +333,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-         ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+         ggml_cpy_flt_cuda<float, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-         ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+         ggml_cpy_flt_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+         ggml_cpy_flt_cuda<float, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -403,9 +364,17 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-         ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+         ggml_cpy_flt_cuda<half, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+         ggml_cpy_flt_cuda<half, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-         ggml_cpy_f16_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+         ggml_cpy_flt_cuda<half, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+         ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+         ggml_cpy_flt_cuda<nv_bfloat16, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+         ggml_cpy_flt_cuda<nv_bfloat16, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
    } else {
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                   ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -430,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
        return nullptr;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-         return (void *) cpy_f32_f16<cpy_1_f32_f32>;
+         return (void *) cpy_flt<cpy_1_flt<float, float>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-         return (void *) cpy_f32_f16<cpy_1_f32_bf16>;
+         return (void *) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-         return (void *) cpy_f32_f16<cpy_1_f32_f16>;
+         return (void *) cpy_flt<cpy_1_flt<float, half>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -458,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        return (void *) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-         return (void *) cpy_f32_f16<cpy_1_f32_f16>;
+         return (void *) cpy_flt<cpy_1_flt<half, half>>;
+     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+         return (void *) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-         return (void *) cpy_f32_f16<cpy_1_f16_f32>;
+         return (void *) cpy_flt<cpy_1_flt<half, float>>;
+     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+         return (void *) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
+     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+         return (void *) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
+     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+         return (void *) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
    } else {
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                   ggml_type_name(src0->type), ggml_type_name(src1->type));
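
For reference, the convert-on-copy pattern that both dispatch tables above select between amounts to a cast performed during an element-wise copy. The following standalone illustration (hypothetical names, contiguous case only, independent of ggml) shows the idea end to end:

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdio>

// Minimal analogue of cpy_flt/cpy_1_flt for contiguous buffers:
// one thread per element, converting src_t to dst_t on the fly.
template <typename src_t, typename dst_t>
__global__ void cast_copy(const src_t * x, dst_t * dst, const int n) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i < n) {
        dst[i] = (dst_t) (float) x[i]; // convert through float so any float-type pairing works
    }
}

int main() {
    const int n = 4;
    const float hx[n] = {1.0f, 2.5f, -3.0f, 0.125f};

    float * dx; __nv_bfloat16 * dy;
    cudaMalloc((void **) &dx, n*sizeof(*dx));
    cudaMalloc((void **) &dy, n*sizeof(*dy));
    cudaMemcpy(dx, hx, n*sizeof(*dx), cudaMemcpyHostToDevice);

    cast_copy<float, __nv_bfloat16><<<1, 32>>>(dx, dy, n); // F32 -> BF16, like the F32/BF16 branch above

    __nv_bfloat16 hy[n];
    cudaMemcpy(hy, dy, n*sizeof(*dy), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        printf("%g -> %g\n", hx[i], __bfloat162float(hy[i]));
    }
    cudaFree(dx);
    cudaFree(dy);
    return 0;
}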