Skip to content

Commit

Permalink
Merge pull request #83 from chsasank/sasank/prelisp
Browse files Browse the repository at this point in the history
Python is our macroprocessor
  • Loading branch information
chsasank authored Jul 22, 2024
2 parents 473274a + 91ceb33 commit 8e94ecc
Show file tree
Hide file tree
Showing 9 changed files with 312 additions and 0 deletions.
70 changes: 70 additions & 0 deletions src/backend/prelisp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import sys
import json
import importlib
import os

sys.path.append(os.getcwd())


def prelisp(expr, module_name):
assert isinstance(module_name, str) and module_name.endswith(".py")
module = importlib.import_module(module_name[: -len(".py")])
return preprocess(expr, module)[0]


def preprocess(expr, env):
if isinstance(expr, list):
if expr[0] == "unquote":
assert len(expr) == 2
return expand_macro(expr[1], env), "append"
elif expr[0] == 'unquote-splicing':
assert len(expr) == 2
out = expand_macro(expr[1], env)
assert isinstance(out, list)
return out, "extend"
else:
res = []
for x in expr:
out, mode = preprocess(x, env)
if mode == 'append':
res.append(out)
else:
res.extend(out)

return res, "append"
else:
return expr, "append"


def expand_macro(expr, env):
if isinstance(expr, list):
fn_name, fn_args = expr[0], expr[1:]
fn = getattr(env, fn_name)
expr_out = fn(*fn_args)
return clean_python_output(expr_out)
elif isinstance(expr, str):
var = expr
expr_out = getattr(env, var)
return clean_python_output(expr_out)
else:
raise Exception("Unknown macro")


def clean_python_output(x):
return json.loads(json.dumps(x))


def main(mod_name):
expr = json.load(sys.stdin)
print(json.dumps(prelisp(expr, mod_name)))


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="What the program does")
parser.add_argument(
"macro_module", help="python file with entry for macro functions"
)
args = parser.parse_args()
main(args.macro_module)
9 changes: 9 additions & 0 deletions src/backend/tests/prelisp/first.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
(c-lisp
(define ((error_check void) (res int))
(if (gt res 0) (call exit 1))
(ret res))
(define ((main int))
(declare res int)
(declare res float)
(set res (call error (call print 3)))
(ret 0)))
11 changes: 11 additions & 0 deletions src/backend/tests/prelisp/first.sexp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
(c-lisp
(define ((error_check void) (res int))
(if (gt res 0)
(call exit 1))
(ret res))

(define ((main int))
(declare res ,cudevice)
(declare res ,(cudevice2))
(set res ,(myError (call print 5)))
(ret 0)))
7 changes: 7 additions & 0 deletions src/backend/tests/prelisp/includes.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(c-lisp
(define ((print int) (n int)))
(define ((fprint float) (n float)))
(define ((main int))
(if #t ((call print 1) (call print 2)))
(if #f ((call print 0)) ((call print 3)))
(ret 0)))
12 changes: 12 additions & 0 deletions src/backend/tests/prelisp/includes.sexp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(c-lisp
,@(add_print_includes)

(define ((main int))
(if #t
((call print 1)
(call print 2)))

(if #f
((call print 0))
((call print 3)))
(ret 0)))
40 changes: 40 additions & 0 deletions src/backend/tests/prelisp/macros.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
cudevice = "int"
cuptr = ["ptr", "int"]


def cudevice2():
return "float"


def myError(expr):
expr[-1] = 3
return ("call", "error", expr)


def add_print_includes():
return [
("define", (("print", "int"), ("n", "int"))),
("define", (("fprint", "float"), ("n", "float"))),
]


def init_zeros(c, ldc):
out = []
for i in range(4):
for j in range(4):
ln = ("store", ("ptradd", c, ("add", i, ("mul", j, ldc))), 0.0)
out.append(ln)
return out


def store_vals(p, a, b, c, lda, ldb, ldc):
out = []
for i in range(4):
for j in range(4):
c_ij = ("ptradd", "c", ("add", i, ("mul", j, "ldc")))
a_ip = ("ptradd", "a", ("add", i, ("mul", p, "lda")))
b_pj = ("ptradd", "b", ("add", p, ("mul", j, "ldb")))
ln = ("store", c_ij, ("fadd", ("load", c_ij), ("fmul", ("load", a_ip), ("load", b_pj))))
out.append(ln)

return out
117 changes: 117 additions & 0 deletions src/backend/tests/prelisp/mmult.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
(c-lisp
(define ((__add_dot4x4 void)
(k int)
(a (ptr float))
(lda int)
(b (ptr float))
(ldb int)
(c (ptr float))
(ldc int))
(store (ptradd c (add 0 (mul 0 ldc))) 0)
(store (ptradd c (add 0 (mul 1 ldc))) 0)
(store (ptradd c (add 0 (mul 2 ldc))) 0)
(store (ptradd c (add 0 (mul 3 ldc))) 0)
(store (ptradd c (add 1 (mul 0 ldc))) 0)
(store (ptradd c (add 1 (mul 1 ldc))) 0)
(store (ptradd c (add 1 (mul 2 ldc))) 0)
(store (ptradd c (add 1 (mul 3 ldc))) 0)
(store (ptradd c (add 2 (mul 0 ldc))) 0)
(store (ptradd c (add 2 (mul 1 ldc))) 0)
(store (ptradd c (add 2 (mul 2 ldc))) 0)
(store (ptradd c (add 2 (mul 3 ldc))) 0)
(store (ptradd c (add 3 (mul 0 ldc))) 0)
(store (ptradd c (add 3 (mul 1 ldc))) 0)
(store (ptradd c (add 3 (mul 2 ldc))) 0)
(store (ptradd c (add 3 (mul 3 ldc))) 0)
(for ((set p 0) (lt p k) (set p (add p 1)))
(store (ptradd c (add 0 (mul 0 ldc)))
(fadd (load (ptradd c (add 0 (mul 0 ldc))))
(fmul (load (ptradd a (add 0 (mul p lda))))
(load (ptradd b (add p (mul 0 ldb)))))))
(store (ptradd c (add 0 (mul 1 ldc)))
(fadd (load (ptradd c (add 0 (mul 1 ldc))))
(fmul (load (ptradd a (add 0 (mul p lda))))
(load (ptradd b (add p (mul 1 ldb)))))))
(store (ptradd c (add 0 (mul 2 ldc)))
(fadd (load (ptradd c (add 0 (mul 2 ldc))))
(fmul (load (ptradd a (add 0 (mul p lda))))
(load (ptradd b (add p (mul 2 ldb)))))))
(store (ptradd c (add 0 (mul 3 ldc)))
(fadd (load (ptradd c (add 0 (mul 3 ldc))))
(fmul (load (ptradd a (add 0 (mul p lda))))
(load (ptradd b (add p (mul 3 ldb)))))))
(store (ptradd c (add 1 (mul 0 ldc)))
(fadd (load (ptradd c (add 1 (mul 0 ldc))))
(fmul (load (ptradd a (add 1 (mul p lda))))
(load (ptradd b (add p (mul 0 ldb)))))))
(store (ptradd c (add 1 (mul 1 ldc)))
(fadd (load (ptradd c (add 1 (mul 1 ldc))))
(fmul (load (ptradd a (add 1 (mul p lda))))
(load (ptradd b (add p (mul 1 ldb)))))))
(store (ptradd c (add 1 (mul 2 ldc)))
(fadd (load (ptradd c (add 1 (mul 2 ldc))))
(fmul (load (ptradd a (add 1 (mul p lda))))
(load (ptradd b (add p (mul 2 ldb)))))))
(store (ptradd c (add 1 (mul 3 ldc)))
(fadd (load (ptradd c (add 1 (mul 3 ldc))))
(fmul (load (ptradd a (add 1 (mul p lda))))
(load (ptradd b (add p (mul 3 ldb)))))))
(store (ptradd c (add 2 (mul 0 ldc)))
(fadd (load (ptradd c (add 2 (mul 0 ldc))))
(fmul (load (ptradd a (add 2 (mul p lda))))
(load (ptradd b (add p (mul 0 ldb)))))))
(store (ptradd c (add 2 (mul 1 ldc)))
(fadd (load (ptradd c (add 2 (mul 1 ldc))))
(fmul (load (ptradd a (add 2 (mul p lda))))
(load (ptradd b (add p (mul 1 ldb)))))))
(store (ptradd c (add 2 (mul 2 ldc)))
(fadd (load (ptradd c (add 2 (mul 2 ldc))))
(fmul (load (ptradd a (add 2 (mul p lda))))
(load (ptradd b (add p (mul 2 ldb)))))))
(store (ptradd c (add 2 (mul 3 ldc)))
(fadd (load (ptradd c (add 2 (mul 3 ldc))))
(fmul (load (ptradd a (add 2 (mul p lda))))
(load (ptradd b (add p (mul 3 ldb)))))))
(store (ptradd c (add 3 (mul 0 ldc)))
(fadd (load (ptradd c (add 3 (mul 0 ldc))))
(fmul (load (ptradd a (add 3 (mul p lda))))
(load (ptradd b (add p (mul 0 ldb)))))))
(store (ptradd c (add 3 (mul 1 ldc)))
(fadd (load (ptradd c (add 3 (mul 1 ldc))))
(fmul (load (ptradd a (add 3 (mul p lda))))
(load (ptradd b (add p (mul 1 ldb)))))))
(store (ptradd c (add 3 (mul 2 ldc)))
(fadd (load (ptradd c (add 3 (mul 2 ldc))))
(fmul (load (ptradd a (add 3 (mul p lda))))
(load (ptradd b (add p (mul 2 ldb)))))))
(store (ptradd c (add 3 (mul 3 ldc)))
(fadd (load (ptradd c (add 3 (mul 3 ldc))))
(fmul (load (ptradd a (add 3 (mul p lda))))
(load (ptradd b (add p (mul 3 ldb))))))))
(ret))
(define ((__kernel void)
(a (ptr float))
(b (ptr float))
(c (ptr float))
(m int)
(n int)
(k int))
(declare i int)
(declare j int)
(declare lda int)
(declare ldb int)
(declare ldc int)
(set lda m)
(set ldb n)
(set ldc k)
(for ((set j 0) (lt j n) (set j (add j 4)))
(for ((set i 0) (lt i m) (set i (add i 4)))
(call __add_dot4x4
k
(ptradd a (add i (mul 0 lda)))
lda
(ptradd b (mul j ldb))
ldb
(ptradd c (add i (mul j ldc)))
ldc)))
(ret)))
45 changes: 45 additions & 0 deletions src/backend/tests/prelisp/mmult.sexp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
(c-lisp
;; Function to add dot products for a 4x4 block
(define ((__add_dot4x4 void)
(k int)
(a (ptr float))
(lda int)
(b (ptr float))
(ldb int)
(c (ptr float))
(ldc int))

,@(init_zeros c ldc)

(for ((set p 0) (lt p k) (set p (add p 1)))
,@(store_vals p a b c lda ldb ldc))

(ret))


(define ((__kernel void)
(a (ptr float))
(b (ptr float))
(c (ptr float))
(m int)
(n int)
(k int))

(declare i int)
(declare j int)

(declare lda int)
(declare ldb int)
(declare ldc int)

(set lda m)
(set ldb n)
(set ldc k)

(for ((set j 0) (lt j n) (set j (add j 4)))
(for ((set i 0) (lt i m) (set i (add i 4)))
(call __add_dot4x4 k
(ptradd a (add i (mul 0 lda))) lda
(ptradd b (mul j ldb)) ldb
(ptradd c (add i (mul j ldc))) ldc)))
(ret)))
1 change: 1 addition & 0 deletions src/backend/tests/prelisp/turnt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
command = "guile ../../utils/sexp-json.scm < {filename} | python ../../prelisp.py macros.py | guile ../../utils/json-sexp.scm"

0 comments on commit 8e94ecc

Please sign in to comment.