This repository has been archived by the owner on Oct 28, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpMSpM-Multiply-COO.cl
105 lines (84 loc) · 2.52 KB
/
SpMSpM-Multiply-COO.cl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// Template placeholders: %%AROW%% and %%BCOL%% must be replaced with integer
// literals (text substitution) before this program is built; the file does
// not compile as-is.
// MAXROW bounds get_global_id(0) (rows, indexes ArowPtr) and MAXCOL bounds
// get_global_id(1) (output columns, indexes BrowPtr); work-items outside
// [0, MAXROW) x [0, MAXCOL) return immediately.
#define MAXROW %%AROW%%
#define MAXCOL %%BCOL%%
/*
 * Naive sparse x sparse product, one work-item per output element, emitting
 * the nonzero results as COO triplets.
 *
 * Work-item (get_global_id(0), get_global_id(1)) = (row, col) of the output.
 *   ArowPtr/Acols/Adata : A in compressed-row form (row `row` spans
 *                         [ArowPtr[row], ArowPtr[row+1])).
 *   BrowPtr/Bcols/Bdata : indexed by the OUTPUT column, so B's columns must be
 *                         laid out as compressed rows (i.e. B transposed /
 *                         CSC) -- NOTE(review): confirm against host code.
 *   counter             : global slot counter, advanced with atomic_add; the
 *                         host presumably zeroes it before launch.
 *   cooVal              : output, 3 floats per entry: (row, col, value).
 *
 * The two-pointer merge assumes column indices inside each compressed row are
 * sorted ascending. Entries are appended in scheduling order, so output order
 * is nondeterministic. No capacity check is performed on cooVal: it must hold
 * 3 * nnz(result) floats. Row/col are stored as float and are exact only for
 * indices below 2^24.
 */
__kernel void spmm_coo_kernel_naive(
    __global const uint * restrict ArowPtr, __global const uint * restrict Acols,
    __global const float * restrict Adata,
    __global const uint * restrict BrowPtr, __global const uint * restrict Bcols,
    __global const float * restrict Bdata,
    __global int * counter, __global float * cooVal)
{
    const int currRow = get_global_id(0);
    const int currCol = get_global_id(1);

    /* The NDRange may be padded beyond the matrix extents. */
    if (currRow >= MAXROW || currCol >= MAXCOL) {
        return;
    }

    /* Keep pointers/indices unsigned, matching the buffers they come from. */
    uint aPos = ArowPtr[currRow];
    const uint aEnd = ArowPtr[currRow + 1];
    uint bPos = BrowPtr[currCol];
    const uint bEnd = BrowPtr[currCol + 1];

    /* Sorted-merge dot product: advance whichever side has the smaller
       column index; multiply-accumulate only on a match. */
    float localSum = 0.0f;
    while (aPos < aEnd && bPos < bEnd) {
        const uint aCol = Acols[aPos];
        const uint bCol = Bcols[bPos];
        if (aCol == bCol) {
            localSum += Adata[aPos] * Bdata[bPos];
            aPos++;
            bPos++;
        } else if (aCol < bCol) {
            aPos++;
        } else {
            bPos++;
        }
    }

    /* Emit every nonzero result. A `> 0` test here would silently drop
       negative entries of the product. */
    if (localSum != 0.0f) {
        const int localIndex = atomic_add(counter, 1);
        cooVal[localIndex * 3 + 0] = (float)currRow;
        cooVal[localIndex * 3 + 1] = (float)currCol;
        cooVal[localIndex * 3 + 2] = localSum;
    }
}
/*
 * Pattern-only (binary) variant of the naive sparse x sparse product: the
 * value of output element (row, col) is the number of column indices shared
 * by compressed row `row` of A and compressed row `col` of the B structure
 * (B presumably supplied transposed / CSC -- confirm with host code).
 *
 * Work-item (get_global_id(0), get_global_id(1)) handles one output element.
 * The merge assumes each compressed row's column indices are sorted
 * ascending. Elements with at least one match are appended to cooVal as three
 * floats (row, col, count) at a slot taken via atomic_add on *counter; output
 * order therefore depends on scheduling, and cooVal capacity is not checked.
 */
__kernel void spmm_coo_binary_kernel_naive(
    __global const uint * restrict ArowPtr, __global const uint * restrict Acols,
    __global const uint * restrict BrowPtr, __global const uint * restrict Bcols,
    __global int * counter, __global float * cooVal)
{
    const int row = get_global_id(0);
    const int col = get_global_id(1);

    /* Skip padding work-items outside the output extents. */
    if (row >= MAXROW || col >= MAXCOL) {
        return;
    }

    int a = ArowPtr[row];
    const int aStop = ArowPtr[row + 1];
    int b = BrowPtr[col];
    const int bStop = BrowPtr[col + 1];

    /* Count coinciding column indices with a two-pointer walk. */
    float matches = 0;
    while (a < aStop && b < bStop) {
        const int ka = Acols[a];
        const int kb = Bcols[b];
        if (ka < kb) {
            a++;
        } else if (ka > kb) {
            b++;
        } else {
            matches += 1;
            a++;
            b++;
        }
    }

    if (matches > 0) {
        const int idx = atomic_add(counter, 1);
        __global float *slot = cooVal + idx * 3;
        slot[0] = (float)row;
        slot[1] = (float)col;
        slot[2] = matches;
    }
}