Skip to content

Commit f018c2c

Browse files
committed
merge conflict
2 parents 732f981 + 925a91f commit f018c2c

File tree

16 files changed

+3039
-72
lines changed

16 files changed

+3039
-72
lines changed

behavior_tests/src/query-api-mapping/do_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,18 @@ def migrate_test():
6262
[],
6363
[" *(f) = 0;"],
6464
],
65+
[
66+
"cublasSgemm",
67+
[
68+
" cublasSgemm(handle /*cublasHandle_t*/, transa /*cublasOperation_t*/,",
69+
" transb /*cublasOperation_t*/, m /*int*/, n /*int*/, k /*int*/,",
70+
" alpha /*const float **/, a /*const float **/, lda /*int*/,",
71+
" b /*const float **/, ldb /*int*/, beta /*const float **/,",
72+
" c /*float **/, ldc /*int*/);",
73+
],
74+
[],
75+
[" oneapi::mkl::blas::column_major::gemm(*handle, transa, transb, m, n, k, dpct::get_value(alpha, *handle), a, lda, b, ldb, dpct::get_value(beta, *handle), c, ldc);"],
76+
],
6577
]
6678

6779
res = True
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// ====------ cub_device_no_trivial_runs.cu ---------------- *- CUDA -* ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//
8+
// ===---------------------------------------------------------------------===//
9+
10+
#include <cub/cub.cuh>
11+
12+
#include <algorithm>
13+
#include <cstdarg>
14+
#include <cstddef>
15+
#include <cstdint>
16+
#include <cstdio>
17+
#include <initializer_list>
18+
#include <iostream>
19+
#include <sstream>
20+
21+
/// Terminates the program if a CUDA runtime call reported an error.
///
/// \param E   Status returned by the CUDA call being checked.
/// \param Fmt printf-style format string describing the failure; it is
///            followed by the values it consumes.
///
/// When \p E equals cudaSuccess nothing is printed and the function returns
/// normally. Otherwise the formatted message is written to stderr and
/// std::terminate() is called, so the function does not return.
void CanFail(cudaError_t E, const char *Fmt, ...) {
  if (E != cudaSuccess) {
    va_list Args;
    va_start(Args, Fmt);
    vfprintf(stderr, Fmt, Args);
    va_end(Args);
    std::terminate();
  }
}
30+
31+
// Checks the status produced by the CUDA expression E: the status and the
// expression's source text are handed to CanFail, which prints
// "<expression> Failed" to stderr and terminates unless the status is
// cudaSuccess.
//
// The expression text is passed as a "%s" argument instead of being spliced
// into the format string, so a '%' inside E can never be misread as a
// conversion specifier. The do/while(0) wrapper makes the expansion a single
// statement, which keeps `if (c) CANTFAIL(x); else ...` well-formed.
#define CANTFAIL(E)                                                            \
  do {                                                                         \
    CanFail((E), "%s Failed\n", #E);                                           \
  } while (0)
33+
34+
/// Creates a managed-memory array holding a copy of the elements of \p List.
///
/// The buffer is obtained with cudaMallocManaged and filled with cudaMemcpy;
/// either call failing terminates the program through CANTFAIL. The returned
/// pointer is not released here, so the caller is responsible for it.
///
/// \param List Values to copy into the new buffer, in order.
/// \returns Pointer to the first element of the newly allocated buffer.
template <typename T> T *Init(std::initializer_list<T> List) {
  // Total payload size in bytes; used for both the allocation and the copy.
  const size_t Bytes = List.size() * sizeof(T);
  T *Ptr = nullptr;
  // NOTE: the argument text below is stringified into the failure message by
  // CANTFAIL, so the expressions are intentionally left as they were.
  CANTFAIL(cudaMallocManaged(&Ptr, Bytes));
  CANTFAIL(cudaMemcpy(Ptr, List.begin(), Bytes, cudaMemcpyHostToDevice));
  return Ptr;
}
41+
42+
/// Formats the half-open range [Begin, End) as a bracketed, comma-separated
/// list, e.g. "[1, 4]". An empty range produces "[]".
///
/// \param Begin Pointer to the first element to print.
/// \param End   Pointer one past the last element to print.
/// \returns The formatted text.
template <typename T> std::string Join(T *Begin, T *End) {
  std::stringstream Out;
  Out << "[";
  // The separator is empty before the first element and ", " afterwards, so
  // no trailing separator is ever written.
  const char *Separator = "";
  for (T *Cur = Begin; Cur != End; ++Cur) {
    Out << Separator << *Cur;
    Separator = ", ";
  }
  Out << "]";
  return Out.str();
}
51+
52+
/// Convenience overload: formats every element of the fixed-size array
/// \p Arr by forwarding its first and one-past-last element pointers to the
/// pointer-range Join overload.
///
/// \param Arr Array whose N elements are printed in order.
/// \returns The formatted text produced by the pointer-range overload.
template <typename T, size_t N> std::string Join(T (&Arr)[N]) {
  T *First = Arr;
  return Join(First, First + N);
}
55+
56+
// Exercises cub::DeviceRunLengthEncode::NonTrivialRuns on a fixed 8-element
// input and compares the reported run offsets, run lengths and run count
// against hard-coded expectations.
//
// Returns 0 and prints a PASS line on success; on the first mismatch a
// diagnostic is written to stderr and 1 is returned. Any failing CUDA/CUB
// call terminates the program through CANTFAIL. All device allocations are
// released before returning on both the pass and the fail path.
int main() {

  int num_items = 8;
  // Init() allocates with cudaMallocManaged, so these buffers can be read
  // directly from the host once the device work has been synchronized.
  int *d_in = Init({0, 2, 2, 9, 5, 5, 5, 8});
  int *d_offsets_out = Init({0, 0, 0, 0, 0, 0, 0, 0});
  int *d_lengths_out = Init({0, 0, 0, 0, 0, 0, 0, 0});
  int *d_num_runs_out = Init({0});

  // Expected result: the runs longer than one element are {2, 2} starting at
  // index 1 and {5, 5, 5} starting at index 4.
  int offsets[] = {1, 4};
  int lengths[] = {2, 3};
  int runs_out = 2;

  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // First call: with a null d_temp_storage CUB only reports the required
  // scratch size in temp_storage_bytes and does no encoding work.
  CANTFAIL(cub::DeviceRunLengthEncode::NonTrivialRuns(
      d_temp_storage, temp_storage_bytes, d_in, d_offsets_out, d_lengths_out,
      d_num_runs_out, num_items));
  // A failed allocation would leave d_temp_storage null and silently turn the
  // second call into another size query, so the status must be checked.
  CANTFAIL(cudaMalloc(&d_temp_storage, temp_storage_bytes));
  // Second call: performs the actual run-length encoding.
  CANTFAIL(cub::DeviceRunLengthEncode::NonTrivialRuns(
      d_temp_storage, temp_storage_bytes, d_in, d_offsets_out, d_lengths_out,
      d_num_runs_out, num_items));
  // Wait for the device and surface any asynchronous execution error before
  // the host dereferences the output buffers.
  CANTFAIL(cudaDeviceSynchronize());

  // Only the first mismatch is reported, in the order count, offsets, lengths.
  bool passed = true;
  if (*d_num_runs_out != runs_out) {
    std::cerr << "Expected d_num_runs_out = " << runs_out << ", but got "
              << *d_num_runs_out << "\n";
    passed = false;
  } else if (!std::equal(offsets, offsets + runs_out, d_offsets_out)) {
    std::cerr << "Expected d_offsets_out = " << Join(offsets) << ", but got "
              << Join(d_offsets_out, d_offsets_out + runs_out) << "\n";
    passed = false;
  } else if (!std::equal(lengths, lengths + runs_out, d_lengths_out)) {
    std::cerr << "Expected d_lengths_out = " << Join(lengths) << ", but got "
              << Join(d_lengths_out, d_lengths_out + runs_out) << "\n";
    passed = false;
  }

  if (passed)
    std::cout << "cub::DeviceRunLengthEncode::NonTrivialRuns PASS\n";

  // Release the scratch buffer and every buffer obtained from Init().
  CANTFAIL(cudaFree(d_temp_storage));
  CANTFAIL(cudaFree(d_num_runs_out));
  CANTFAIL(cudaFree(d_lengths_out));
  CANTFAIL(cudaFree(d_offsets_out));
  CANTFAIL(cudaFree(d_in));

  return passed ? 0 : 1;
}

0 commit comments

Comments
 (0)