Skip to content

Commit 51b87eb

Browse files
authored
correct geam API move beta before B (#151)
1 parent 996f5c3 commit 51b87eb

File tree

6 files changed

+69
-69
lines changed

6 files changed

+69
-69
lines changed

clients/common/rocblas_template_specialization.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -544,10 +544,10 @@
544544
rocblas_int m, rocblas_int n,
545545
const float *alpha,
546546
const float *A, rocblas_int lda,
547-
const float *B, rocblas_int ldb,
548547
const float *beta,
548+
const float *B, rocblas_int ldb,
549549
float *C, rocblas_int ldc){
550-
return rocblas_sgeam(handle, transA, transB, m, n, alpha, A, lda, B, ldb, beta, C, ldc);
550+
return rocblas_sgeam(handle, transA, transB, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
551551
}
552552

553553
template<>
@@ -556,10 +556,10 @@
556556
rocblas_int m, rocblas_int n,
557557
const double *alpha,
558558
const double *A, rocblas_int lda,
559-
const double *B, rocblas_int ldb,
560559
const double *beta,
560+
const double *B, rocblas_int ldb,
561561
double *C, rocblas_int ldc){
562-
return rocblas_dgeam(handle, transA, transB, m, n, alpha, A, lda, B, ldb, beta, C, ldc);
562+
return rocblas_dgeam(handle, transA, transB, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
563563
}
564564

565565

clients/include/rocblas.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@
126126
rocblas_int m, rocblas_int n,
127127
const T *alpha,
128128
const T *A, rocblas_int lda,
129-
const T *B, rocblas_int ldb,
130129
const T *beta,
130+
const T *B, rocblas_int ldb,
131131
T *C, rocblas_int ldc);
132132

133133
template<typename T>

clients/include/testing_geam.hpp

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ void testing_geam_bad_arg()
7979
status = rocblas_geam<T>(handle, transA, transB,
8080
M, N,
8181
&h_alpha, dA_null, lda,
82-
dB, ldb,
83-
&h_beta, dC, ldc);
82+
&h_beta, dB, ldb,
83+
dC, ldc);
8484

8585
verify_rocblas_status_invalid_pointer(status, "ERROR: A is nullptr");
8686
}
@@ -89,8 +89,8 @@ void testing_geam_bad_arg()
8989
status = rocblas_geam<T>(handle, transA, transB,
9090
M, N,
9191
&h_alpha, dA, lda,
92-
dB_null, ldb,
93-
&h_beta, dC, ldc);
92+
&h_beta, dB_null, ldb,
93+
dC, ldc);
9494

9595
verify_rocblas_status_invalid_pointer(status, "ERROR: B is nullptr");
9696
}
@@ -99,8 +99,8 @@ void testing_geam_bad_arg()
9999
status = rocblas_geam<T>(handle, transA, transB,
100100
M, N,
101101
&h_alpha, dA, lda,
102-
dB, ldb,
103-
&h_beta, dC_null, ldc);
102+
&h_beta, dB, ldb,
103+
dC_null, ldc);
104104

105105
verify_rocblas_status_invalid_pointer(status, "ERROR: C is nullptr");
106106
}
@@ -109,8 +109,8 @@ void testing_geam_bad_arg()
109109
status = rocblas_geam<T>(handle, transA, transB,
110110
M, N,
111111
h_alpha_null, dA, lda,
112-
dB, ldb,
113-
&h_beta, dC, ldc);
112+
&h_beta, dB, ldb,
113+
dC, ldc);
114114

115115
verify_rocblas_status_invalid_pointer(status, "ERROR: h_alpha is nullptr");
116116
}
@@ -119,8 +119,8 @@ void testing_geam_bad_arg()
119119
status = rocblas_geam<T>(handle, transA, transB,
120120
M, N,
121121
&h_alpha, dA, lda,
122-
dB, ldb,
123-
h_beta_null, dC, ldc);
122+
h_beta_null, dB, ldb,
123+
dC, ldc);
124124

125125
verify_rocblas_status_invalid_pointer(status, "ERROR: h_beta is nullptr");
126126
}
@@ -129,8 +129,8 @@ void testing_geam_bad_arg()
129129
status = rocblas_geam<T>(handle_null, transA, transB,
130130
M, N,
131131
&h_alpha, dA, lda,
132-
dB, ldb,
133-
&h_beta, dC, ldc);
132+
&h_beta, dB, ldb,
133+
dC, ldc);
134134

135135
verify_rocblas_status_invalid_handle(status);
136136
}
@@ -213,8 +213,8 @@ rocblas_status testing_geam(Arguments argus)
213213
status = rocblas_geam<T>(handle, transA, transB,
214214
M, N,
215215
&h_alpha, dA, lda,
216-
dB, ldb,
217-
&h_beta, dC, ldc);
216+
&h_beta, dB, ldb,
217+
dC, ldc);
218218

219219
geam_arg_check(status, M, N, lda, ldb, ldc);
220220

@@ -239,8 +239,8 @@ rocblas_status testing_geam(Arguments argus)
239239
status = rocblas_geam<T>(handle, transA, transB,
240240
M, N,
241241
&h_alpha, dA, lda,
242-
dB, ldb,
243-
&h_beta, dC, ldc);
242+
&h_beta, dB, ldb,
243+
dC, ldc);
244244

245245
verify_rocblas_status_invalid_pointer(status, "ERROR: A or B or C is nullptr");
246246

@@ -257,8 +257,8 @@ rocblas_status testing_geam(Arguments argus)
257257
status = rocblas_geam<T>(handle, transA, transB,
258258
M, N,
259259
&h_alpha, dA, lda,
260-
dB, ldb,
261-
&h_beta, dC, ldc);
260+
&h_beta, dB, ldb,
261+
dC, ldc);
262262

263263
verify_rocblas_status_invalid_handle(status);
264264

@@ -301,8 +301,8 @@ rocblas_status testing_geam(Arguments argus)
301301
status_h = rocblas_geam<T>(handle, transA, transB,
302302
M, N,
303303
&h_alpha, dA, lda,
304-
dB, ldb,
305-
&h_beta, dC, ldc);
304+
&h_beta, dB, ldb,
305+
dC, ldc);
306306

307307
CHECK_HIP_ERROR(hipMemcpy(hC_h.data(), dC, sizeof(T) * C_size, hipMemcpyDeviceToHost));
308308

@@ -311,8 +311,8 @@ rocblas_status testing_geam(Arguments argus)
311311
status_d = rocblas_geam<T>(handle, transA, transB,
312312
M, N,
313313
d_alpha, dA, lda,
314-
dB, ldb,
315-
d_beta, dC, ldc);
314+
d_beta, dB, ldb,
315+
dC, ldc);
316316

317317
CHECK_HIP_ERROR(hipMemcpy(hC_d.data(), dC, sizeof(T) * C_size, hipMemcpyDeviceToHost));
318318

@@ -373,8 +373,8 @@ rocblas_status testing_geam(Arguments argus)
373373
status_h = rocblas_geam<T>(handle, transA, transB,
374374
M, N,
375375
&h_alpha, dA, lda,
376-
dB, ldb,
377-
&h_beta, dC_in_place, ldc);
376+
&h_beta, dB, ldb,
377+
dC_in_place, ldc);
378378

379379
if (lda != ldc || transA != rocblas_operation_none)
380380
{
@@ -428,8 +428,8 @@ rocblas_status testing_geam(Arguments argus)
428428
status_h = rocblas_geam<T>(handle, transA, transB,
429429
M, N,
430430
&h_alpha, dA, lda,
431-
dB, ldb,
432-
&h_beta, dC_in_place, ldc);
431+
&h_beta, dB, ldb,
432+
dC_in_place, ldc);
433433

434434
if (ldb != ldc || transB != rocblas_operation_none)
435435
{
@@ -489,8 +489,8 @@ rocblas_status testing_geam(Arguments argus)
489489
status = rocblas_geam<T>(handle, transA, transB,
490490
M, N,
491491
&h_alpha, dA, lda,
492-
dB, ldb,
493-
&h_beta, dC, ldc);
492+
&h_beta, dB, ldb,
493+
dC, ldc);
494494
}
495495

496496
gpu_time_used = get_time_us(); // in microseconds
@@ -499,8 +499,8 @@ rocblas_status testing_geam(Arguments argus)
499499
status = rocblas_geam<T>(handle, transA, transB,
500500
M, N,
501501
&h_alpha, dA, lda,
502-
dB, ldb,
503-
&h_beta, dC, ldc);
502+
&h_beta, dB, ldb,
503+
dC, ldc);
504504
}
505505
gpu_time_used = get_time_us() - gpu_time_used;
506506
rocblas_gflops = geam_gflop_count<T> (M, N) * number_hot_calls / gpu_time_used * 1e6;

library/include/rocblas-functions.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,12 +1207,12 @@ rocblas_zgemm_strided_batched(
12071207
lda rocblas_int
12081208
specifies the leading dimension of A.
12091209
@param[in]
1210+
beta specifies the scalar beta.
1211+
@param[in]
12101212
B pointer storing matrix B on the GPU.
12111213
@param[in]
12121214
ldb rocblas_int
12131215
specifies the leading dimension of B.
1214-
@param[in]
1215-
beta specifies the scalar beta.
12161216
@param[in, out]
12171217
C pointer storing matrix C on the GPU.
12181218
@param[in]
@@ -1228,8 +1228,8 @@ rocblas_sgeam(
12281228
rocblas_int m, rocblas_int n,
12291229
const float *alpha,
12301230
const float *A, rocblas_int lda,
1231-
const float *B, rocblas_int ldb,
12321231
const float *beta,
1232+
const float *B, rocblas_int ldb,
12331233
float *C, rocblas_int ldc);
12341234

12351235
ROCBLAS_EXPORT rocblas_status
@@ -1239,8 +1239,8 @@ rocblas_dgeam(
12391239
rocblas_int m, rocblas_int n,
12401240
const double *alpha,
12411241
const double *A, rocblas_int lda,
1242-
const double *B, rocblas_int ldb,
12431242
const double *beta,
1243+
const double *B, rocblas_int ldb,
12441244
double *C, rocblas_int ldc);
12451245

12461246

library/src/blas3/geam_device.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ geam_device(
1010
rocblas_int m, rocblas_int n,
1111
T alpha,
1212
const T * __restrict__ A, rocblas_int lda,
13-
const T * __restrict__ B, rocblas_int ldb,
1413
T beta,
14+
const T * __restrict__ B, rocblas_int ldb,
1515
T * C, rocblas_int ldc)
1616
{
1717
rocblas_int tx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
@@ -92,8 +92,8 @@ geam_1D_device(
9292
rocblas_int size,
9393
T alpha,
9494
const T * __restrict__ A,
95-
const T * __restrict__ B,
9695
T beta,
96+
const T * __restrict__ B,
9797
T * C)
9898
{
9999
rocblas_int tx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
@@ -144,8 +144,8 @@ geam_inplace_device(
144144
rocblas_operation transB,
145145
rocblas_int m, rocblas_int n,
146146
T alpha,
147-
const T * __restrict__ B, rocblas_int ldb,
148147
T beta,
148+
const T * __restrict__ B, rocblas_int ldb,
149149
T * C, rocblas_int ldc)
150150
{
151151
rocblas_int tx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;

0 commit comments

Comments
 (0)