gemm_metal.py
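# Hand-written Metal simdgroup-matrix GEMM: verifies the result against numpy,
# then compares timings with torch (MPS backend) and tinygrad.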
import os
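# select tinygrad's Metal backend before the rest of tinygrad is imported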
os.environ["METAL"] = "1"
import time
import numpy as np
from tinygrad.helpers import dtypes, getenv, flat_mv
from tinygrad import Device
from tinygrad.runtime.ops_metal import MetalAllocator, MetalDevice, MetalProgram, compile_metal
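# N: square matrix dimension (override with the N env var); LID: simdgroups per threadgroup along y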
N = getenv("N", 2048)
LID = 2
device = MetalDevice("METAL")
metalalloc = MetalAllocator(device)
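# three raw N*N float32 device buffers: a receives the result, b and c hold the inputs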
a = metalalloc.alloc(N*N*4)
b = metalalloc.alloc(N*N*4)
c = metalalloc.alloc(N*N*4)
na = np.zeros((N,N),dtype=np.float32)
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
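# upload the randomly initialized inputs to the GPU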
metalalloc.copyin(b,nb.tobytes())
metalalloc.copyin(c,nc.tobytes())
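# a matmul does 2*N^3 FLOPs and, at minimum, moves 3*N^2 float32 values (two reads, one write)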
FLOPS = N*N*N*2
BW = N*N*3*4
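# Each threadgroup runs LID simdgroups of 32 threads. A simdgroup computes one 32x32 tile
# of the output, kept in a 4x4 grid of 8x8 simdgroup matrices; the k loop walks the shared
# dimension 8 elements at a time.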
prog = MetalProgram(device,"test", compile_metal(f"""
#include <metal_stdlib>
#include <metal_simdgroup_matrix> // Available from Metal version 2.3, released with macOS 11.0+
using namespace metal;
kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
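  // advance a, data1, data2 to this simdgroup's 32x32 output tile and its input rows/columns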
  a += gid.x * 32 * {N} + (gid.y * {LID} + lid.y) * 32;
  data1 += gid.x * 32 * {N};
  data2 += (gid.y * {LID} + lid.y) * 32;
  simdgroup_float8x8 acc[4][4];
  for (uint i = 0; i < 4; i++) {{
    for (uint j = 0; j < 4; j++) {{
      acc[i][j] = simdgroup_float8x8(0);
    }}
  }}
  simdgroup_float8x8 A[4];
  simdgroup_float8x8 B[4];
  for (uint k = 0; k < {N}; k+=8) {{
    threadgroup_barrier(mem_flags::mem_threadgroup);
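    // load four 8x8 tiles from each input: a 32x8 column slab of data1 and an 8x32 row slab of data2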
    simdgroup_load(A[0], data1+k+{0*N}, {N}, ulong2(0, 0));
    simdgroup_load(A[1], data1+k+{8*N}, {N}, ulong2(0, 0));
    simdgroup_load(A[2], data1+k+{16*N}, {N}, ulong2(0, 0));
    simdgroup_load(A[3], data1+k+{24*N}, {N}, ulong2(0, 0));
    simdgroup_load(B[0], data2+0+k*{N}, {N}, ulong2(0, 0));
    simdgroup_load(B[1], data2+8+k*{N}, {N}, ulong2(0, 0));
    simdgroup_load(B[2], data2+16+k*{N}, {N}, ulong2(0, 0));
    simdgroup_load(B[3], data2+24+k*{N}, {N}, ulong2(0, 0));
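    // rank-8 update of the 32x32 tile: acc[j][i] += A[i] * B[j]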
    simdgroup_multiply_accumulate(acc[0][0], A[0], B[0], acc[0][0]);
    simdgroup_multiply_accumulate(acc[0][1], A[1], B[0], acc[0][1]);
    simdgroup_multiply_accumulate(acc[0][2], A[2], B[0], acc[0][2]);
    simdgroup_multiply_accumulate(acc[0][3], A[3], B[0], acc[0][3]);
    simdgroup_multiply_accumulate(acc[1][0], A[0], B[1], acc[1][0]);
    simdgroup_multiply_accumulate(acc[1][1], A[1], B[1], acc[1][1]);
    simdgroup_multiply_accumulate(acc[1][2], A[2], B[1], acc[1][2]);
    simdgroup_multiply_accumulate(acc[1][3], A[3], B[1], acc[1][3]);
    simdgroup_multiply_accumulate(acc[2][0], A[0], B[2], acc[2][0]);
    simdgroup_multiply_accumulate(acc[2][1], A[1], B[2], acc[2][1]);
    simdgroup_multiply_accumulate(acc[2][2], A[2], B[2], acc[2][2]);
    simdgroup_multiply_accumulate(acc[2][3], A[3], B[2], acc[2][3]);
    simdgroup_multiply_accumulate(acc[3][0], A[0], B[3], acc[3][0]);
    simdgroup_multiply_accumulate(acc[3][1], A[1], B[3], acc[3][1]);
    simdgroup_multiply_accumulate(acc[3][2], A[2], B[3], acc[3][2]);
    simdgroup_multiply_accumulate(acc[3][3], A[3], B[3], acc[3][3]);
  }}
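  // write the finished 32x32 tile back: acc[j][i] lands at row i*8, column j*8 of this tile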
  simdgroup_store(acc[0][0], a+{0+0*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[1][0], a+{8+0*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[2][0], a+{16+0*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[3][0], a+{24+0*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[0][1], a+{0+8*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[1][1], a+{8+8*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[2][1], a+{16+8*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[3][1], a+{24+8*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[0][2], a+{0+16*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[1][2], a+{8+16*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[2][2], a+{16+16*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[3][2], a+{24+16*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[0][3], a+{0+24*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
}}"""))
def timeit(fxn):
  st = time.perf_counter()
  et = fxn()
  # NOTE: et doesn't contain the launch overhead
  return time.perf_counter() - st
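# best of 20 launches; each threadgroup (32*LID threads) covers a 32x(32*LID) block of the output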
tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
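# copy the result off the GPU and check it against a numpy matmul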
comp = nb@nc
metalalloc.copyout(flat_mv(na.data), a)
if N <= 32:
  print(na)
  print(comp)
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
np.testing.assert_allclose(na, comp, atol=1e-3)
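# same measurement with torch on the MPS backend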
import torch, torch.mps
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
def torch_prog(b, c):
  st = time.perf_counter()
  a = b@c
  torch.mps.synchronize()
  return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch")
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
  return (b@c).realize()
def tiny_prog(b, c):
  st = time.perf_counter()
  a = tiny_jit(b, c)
  Device["METAL"].synchronize()
  return time.perf_counter() - st
tm = min([tiny_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad")