@@ -182,7 +182,7 @@ Status DecoderQkvToContext(
182
182
const hipDeviceProp_t& prop,
183
183
RocmTuningContext* tuning_ctx,
184
184
Stream* ort_stream,
185
- rocblas_handle& rocblas ,
185
+ hipblasHandle_t& hipblas ,
186
186
const size_t element_size,
187
187
const int batch_size,
188
188
const int sequence_length,
@@ -284,7 +284,7 @@ Status DecoderQkvToContext(
284
284
const int strideB = sequence_length * head_size;
285
285
if (use_past && static_kv) {
286
286
ORT_RETURN_IF_ERROR (blas::column_major::StridedBatchedGemm (
287
- tuning_ctx, ort_stream, rocblas ,
287
+ tuning_ctx, ort_stream, hipblas ,
288
288
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
289
289
kv_sequence_length, sequence_length, head_size,
290
290
/* alpha=*/ rsqrt_head_size,
@@ -295,7 +295,7 @@ Status DecoderQkvToContext(
295
295
BN));
296
296
} else {
297
297
ORT_RETURN_IF_ERROR (blas::column_major::StridedBatchedGemm (
298
- tuning_ctx, ort_stream, rocblas ,
298
+ tuning_ctx, ort_stream, hipblas ,
299
299
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
300
300
kv_sequence_length, sequence_length, head_size,
301
301
/* alpha=*/ rsqrt_head_size,
@@ -320,7 +320,7 @@ Status DecoderQkvToContext(
320
320
// compute P*V (as V*P), and store in scratch3: BxNxSxH
321
321
if (use_past && static_kv) {
322
322
ORT_RETURN_IF_ERROR (blas::column_major::StridedBatchedGemm (
323
- tuning_ctx, ort_stream, rocblas ,
323
+ tuning_ctx, ort_stream, hipblas ,
324
324
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
325
325
head_size, sequence_length, kv_sequence_length,
326
326
/* alpha=*/ 1 .0f ,
@@ -331,7 +331,7 @@ Status DecoderQkvToContext(
331
331
BN));
332
332
} else {
333
333
ORT_RETURN_IF_ERROR (blas::column_major::StridedBatchedGemm (
334
- tuning_ctx, ort_stream, rocblas ,
334
+ tuning_ctx, ort_stream, hipblas ,
335
335
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
336
336
head_size, sequence_length, kv_sequence_length,
337
337
/* alpha=*/ 1 .0f ,
@@ -351,7 +351,7 @@ Status LaunchDecoderAttentionKernel(
351
351
const hipDeviceProp_t& prop,
352
352
RocmTuningContext* tuning_ctx,
353
353
Stream* stream,
354
- rocblas_handle& rocblas ,
354
+ hipblasHandle_t& hipblas ,
355
355
const size_t element_size,
356
356
const int batch_size,
357
357
const int sequence_length,
@@ -378,7 +378,7 @@ Status LaunchDecoderAttentionKernel(
378
378
prop,
379
379
tuning_ctx,
380
380
stream,
381
- rocblas ,
381
+ hipblas ,
382
382
element_size,
383
383
batch_size,
384
384
sequence_length,
@@ -405,7 +405,7 @@ Status LaunchDecoderAttentionKernel(
405
405
prop,
406
406
tuning_ctx,
407
407
stream,
408
- rocblas ,
408
+ hipblas ,
409
409
element_size,
410
410
batch_size,
411
411
sequence_length,
0 commit comments