Skip to content

Commit f856213

Browse files
committed
add an array3d example
1 parent 631e6cb commit f856213

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

examples/array3d.spy

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from operator import OpImpl, OpArg
2+
from unsafe import gc_alloc, ptr
3+
4+
@blue.generic
5+
def array3d(DTYPE):
6+
7+
@struct
8+
class ArrayData:
9+
h: i32
10+
w: i32
11+
d: i32
12+
items: ptr[DTYPE]
13+
14+
@typelift
15+
class ndarray:
16+
__ll__: ptr[ArrayData]
17+
18+
# use this __new__ to allocate a buffer
19+
## def __new__(h: i32, w: i32, d: i32) -> ndarray:
20+
## data = gc_alloc(ArrayData)(1)
21+
## data.h = h
22+
## data.w = w
23+
## data.d = d
24+
## data.items = gc_alloc(DTYPE)(h*w*d)
25+
## i = 0
26+
## while i < h*w*d:
27+
## data.items[i] = 0
28+
## i = i + 1
29+
## return ndarray.__lift__(data)
30+
31+
# use this __new__ to create an array out of an existing buffer
32+
def __new__(buf: ptr[DTYPE], h: i32, w: i32, d: i32) -> ndarray:
33+
data = gc_alloc(ArrayData)(1)
34+
data.h = h
35+
data.w = w
36+
data.d = d
37+
data.items = buf
38+
return ndarray.__lift__(data)
39+
40+
@blue
41+
def __GETITEM__(v_arr: OpArg, v_i: OpArg, v_j: OpArg, v_k: OpArg) -> OpImpl:
42+
def getitem(arr: ndarray, i: i32, j: i32, k: i32) -> DTYPE:
43+
ll = arr.__ll__
44+
if i >= ll.h:
45+
raise IndexError
46+
if j >= ll.w:
47+
raise IndexError
48+
if k >= ll.d:
49+
raise IndexError
50+
idx = (i * ll.w * ll.d) + (j * ll.d) + k
51+
return ll.items[idx]
52+
return OpImpl(getitem)
53+
54+
@blue
55+
def __SETITEM__(v_arr: OpArg, v_i: OpArg, v_j: OpArg, v_k: OpArg,
56+
v_v: OpArg) -> OpImpl:
57+
def setitem(arr: ndarray, i: i32, j: i32, k: i32, v: DTYPE) -> void:
58+
ll = arr.__ll__
59+
if i >= ll.h:
60+
raise IndexError
61+
if j >= ll.w:
62+
raise IndexError
63+
if k >= ll.d:
64+
raise IndexError
65+
idx = (i * ll.w * ll.d) + (j * ll.d) + k
66+
ll.items[idx] = v
67+
return OpImpl(setitem)
68+
69+
def print_flatten(self: ndarray) -> void:
70+
ll = self.__ll__
71+
i = 0
72+
while i < ll.h * ll.w * ll.d:
73+
print(ll.items[i])
74+
i = i + 1
75+
76+
return ndarray
77+
78+
def main() -> void:
79+
buf = gc_alloc(i32)(4*3*2)
80+
i = 0
81+
while i < 4*3*2:
82+
buf[i] = i
83+
i = i + 1
84+
a = array3d[i32](buf, 4, 3, 2)
85+
#a.print_flatten()
86+
print(a[2, 1, 0])

0 commit comments

Comments
 (0)