diff --git a/src/backend/prelisp.py b/src/backend/prelisp.py new file mode 100644 index 0000000..f67358b --- /dev/null +++ b/src/backend/prelisp.py @@ -0,0 +1,70 @@ +import sys +import json +import importlib +import os + +sys.path.append(os.getcwd()) + + +def prelisp(expr, module_name): + assert isinstance(module_name, str) and module_name.endswith(".py") + module = importlib.import_module(module_name[: -len(".py")]) + return preprocess(expr, module)[0] + + +def preprocess(expr, env): + if isinstance(expr, list): + if expr[0] == "unquote": + assert len(expr) == 2 + return expand_macro(expr[1], env), "append" + elif expr[0] == 'unquote-splicing': + assert len(expr) == 2 + out = expand_macro(expr[1], env) + assert isinstance(out, list) + return out, "extend" + else: + res = [] + for x in expr: + out, mode = preprocess(x, env) + if mode == 'append': + res.append(out) + else: + res.extend(out) + + return res, "append" + else: + return expr, "append" + + +def expand_macro(expr, env): + if isinstance(expr, list): + fn_name, fn_args = expr[0], expr[1:] + fn = getattr(env, fn_name) + expr_out = fn(*fn_args) + return clean_python_output(expr_out) + elif isinstance(expr, str): + var = expr + expr_out = getattr(env, var) + return clean_python_output(expr_out) + else: + raise Exception("Unknown macro") + + +def clean_python_output(x): + return json.loads(json.dumps(x)) + + +def main(mod_name): + expr = json.load(sys.stdin) + print(json.dumps(prelisp(expr, mod_name))) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="What the program does") + parser.add_argument( + "macro_module", help="python file with entry for macro functions" + ) + args = parser.parse_args() + main(args.macro_module) diff --git a/src/backend/tests/prelisp/first.out b/src/backend/tests/prelisp/first.out new file mode 100644 index 0000000..618f927 --- /dev/null +++ b/src/backend/tests/prelisp/first.out @@ -0,0 +1,9 @@ +(c-lisp + (define ((error_check void) (res int)) + (if (gt res 0) (call exit 1)) + (ret res)) + (define ((main int)) + (declare res int) + (declare res float) + (set res (call error (call print 3))) + (ret 0))) diff --git a/src/backend/tests/prelisp/first.sexp b/src/backend/tests/prelisp/first.sexp new file mode 100644 index 0000000..a81ff88 --- /dev/null +++ b/src/backend/tests/prelisp/first.sexp @@ -0,0 +1,11 @@ +(c-lisp + (define ((error_check void) (res int)) + (if (gt res 0) + (call exit 1)) + (ret res)) + + (define ((main int)) + (declare res ,cudevice) + (declare res ,(cudevice2)) + (set res ,(myError (call print 5))) + (ret 0))) diff --git a/src/backend/tests/prelisp/includes.out b/src/backend/tests/prelisp/includes.out new file mode 100644 index 0000000..4017b49 --- /dev/null +++ b/src/backend/tests/prelisp/includes.out @@ -0,0 +1,7 @@ +(c-lisp + (define ((print int) (n int))) + (define ((fprint float) (n float))) + (define ((main int)) + (if #t ((call print 1) (call print 2))) + (if #f ((call print 0)) ((call print 3))) + (ret 0))) diff --git a/src/backend/tests/prelisp/includes.sexp b/src/backend/tests/prelisp/includes.sexp new file mode 100644 index 0000000..bd7a037 --- /dev/null +++ b/src/backend/tests/prelisp/includes.sexp @@ -0,0 +1,12 @@ +(c-lisp + ,@(add_print_includes) + + (define ((main int)) + (if #t + ((call print 1) + (call print 2))) + + (if #f + ((call print 0)) + ((call print 3))) + (ret 0))) diff --git a/src/backend/tests/prelisp/macros.py b/src/backend/tests/prelisp/macros.py new file mode 100644 index 0000000..814b306 --- /dev/null +++ b/src/backend/tests/prelisp/macros.py @@ -0,0 +1,40 @@ +cudevice = "int" +cuptr = ["ptr", "int"] + + +def cudevice2(): + return "float" + + +def myError(expr): + expr[-1] = 3 + return ("call", "error", expr) + + +def add_print_includes(): + return [ + ("define", (("print", "int"), ("n", "int"))), + ("define", (("fprint", "float"), ("n", "float"))), + ] + + +def init_zeros(c, ldc): + out = [] + for i in range(4): + for j in range(4): + ln = ("store", ("ptradd", c, ("add", i, ("mul", j, ldc))), 0.0) + out.append(ln) + return out + + +def store_vals(p, a, b, c, lda, ldb, ldc): + out = [] + for i in range(4): + for j in range(4): + c_ij = ("ptradd", "c", ("add", i, ("mul", j, "ldc"))) + a_ip = ("ptradd", "a", ("add", i, ("mul", p, "lda"))) + b_pj = ("ptradd", "b", ("add", p, ("mul", j, "ldb"))) + ln = ("store", c_ij, ("fadd", ("load", c_ij), ("fmul", ("load", a_ip), ("load", b_pj)))) + out.append(ln) + + return out diff --git a/src/backend/tests/prelisp/mmult.out b/src/backend/tests/prelisp/mmult.out new file mode 100644 index 0000000..9b73086 --- /dev/null +++ b/src/backend/tests/prelisp/mmult.out @@ -0,0 +1,117 @@ +(c-lisp + (define ((__add_dot4x4 void) + (k int) + (a (ptr float)) + (lda int) + (b (ptr float)) + (ldb int) + (c (ptr float)) + (ldc int)) + (store (ptradd c (add 0 (mul 0 ldc))) 0) + (store (ptradd c (add 0 (mul 1 ldc))) 0) + (store (ptradd c (add 0 (mul 2 ldc))) 0) + (store (ptradd c (add 0 (mul 3 ldc))) 0) + (store (ptradd c (add 1 (mul 0 ldc))) 0) + (store (ptradd c (add 1 (mul 1 ldc))) 0) + (store (ptradd c (add 1 (mul 2 ldc))) 0) + (store (ptradd c (add 1 (mul 3 ldc))) 0) + (store (ptradd c (add 2 (mul 0 ldc))) 0) + (store (ptradd c (add 2 (mul 1 ldc))) 0) + (store (ptradd c (add 2 (mul 2 ldc))) 0) + (store (ptradd c (add 2 (mul 3 ldc))) 0) + (store (ptradd c (add 3 (mul 0 ldc))) 0) + (store (ptradd c (add 3 (mul 1 ldc))) 0) + (store (ptradd c (add 3 (mul 2 ldc))) 0) + (store (ptradd c (add 3 (mul 3 ldc))) 0) + (for ((set p 0) (lt p k) (set p (add p 1))) + (store (ptradd c (add 0 (mul 0 ldc))) + (fadd (load (ptradd c (add 0 (mul 0 ldc)))) + (fmul (load (ptradd a (add 0 (mul p lda)))) + (load (ptradd b (add p (mul 0 ldb))))))) + (store (ptradd c (add 0 (mul 1 ldc))) + (fadd (load (ptradd c (add 0 (mul 1 ldc)))) + (fmul (load (ptradd a (add 0 (mul p lda)))) + (load (ptradd b (add p (mul 1 ldb))))))) + (store (ptradd c (add 0 (mul 2 ldc))) + (fadd (load (ptradd c (add 0 (mul 2 ldc)))) + (fmul (load (ptradd a (add 0 (mul p lda)))) + (load (ptradd b (add p (mul 2 ldb))))))) + (store (ptradd c (add 0 (mul 3 ldc))) + (fadd (load (ptradd c (add 0 (mul 3 ldc)))) + (fmul (load (ptradd a (add 0 (mul p lda)))) + (load (ptradd b (add p (mul 3 ldb))))))) + (store (ptradd c (add 1 (mul 0 ldc))) + (fadd (load (ptradd c (add 1 (mul 0 ldc)))) + (fmul (load (ptradd a (add 1 (mul p lda)))) + (load (ptradd b (add p (mul 0 ldb))))))) + (store (ptradd c (add 1 (mul 1 ldc))) + (fadd (load (ptradd c (add 1 (mul 1 ldc)))) + (fmul (load (ptradd a (add 1 (mul p lda)))) + (load (ptradd b (add p (mul 1 ldb))))))) + (store (ptradd c (add 1 (mul 2 ldc))) + (fadd (load (ptradd c (add 1 (mul 2 ldc)))) + (fmul (load (ptradd a (add 1 (mul p lda)))) + (load (ptradd b (add p (mul 2 ldb))))))) + (store (ptradd c (add 1 (mul 3 ldc))) + (fadd (load (ptradd c (add 1 (mul 3 ldc)))) + (fmul (load (ptradd a (add 1 (mul p lda)))) + (load (ptradd b (add p (mul 3 ldb))))))) + (store (ptradd c (add 2 (mul 0 ldc))) + (fadd (load (ptradd c (add 2 (mul 0 ldc)))) + (fmul (load (ptradd a (add 2 (mul p lda)))) + (load (ptradd b (add p (mul 0 ldb))))))) + (store (ptradd c (add 2 (mul 1 ldc))) + (fadd (load (ptradd c (add 2 (mul 1 ldc)))) + (fmul (load (ptradd a (add 2 (mul p lda)))) + (load (ptradd b (add p (mul 1 ldb))))))) + (store (ptradd c (add 2 (mul 2 ldc))) + (fadd (load (ptradd c (add 2 (mul 2 ldc)))) + (fmul (load (ptradd a (add 2 (mul p lda)))) + (load (ptradd b (add p (mul 2 ldb))))))) + (store (ptradd c (add 2 (mul 3 ldc))) + (fadd (load (ptradd c (add 2 (mul 3 ldc)))) + (fmul (load (ptradd a (add 2 (mul p lda)))) + (load (ptradd b (add p (mul 3 ldb))))))) + (store (ptradd c (add 3 (mul 0 ldc))) + (fadd (load (ptradd c (add 3 (mul 0 ldc)))) + (fmul (load (ptradd a (add 3 (mul p lda)))) + (load (ptradd b (add p (mul 0 ldb))))))) + (store (ptradd c (add 3 (mul 1 ldc))) + (fadd (load (ptradd c (add 3 (mul 1 ldc)))) + (fmul (load (ptradd a (add 3 (mul p lda)))) + (load (ptradd b (add p (mul 1 ldb))))))) + (store (ptradd c (add 3 (mul 2 ldc))) + (fadd (load (ptradd c (add 3 (mul 2 ldc)))) + (fmul (load (ptradd a (add 3 (mul p lda)))) + (load (ptradd b (add p (mul 2 ldb))))))) + (store (ptradd c (add 3 (mul 3 ldc))) + (fadd (load (ptradd c (add 3 (mul 3 ldc)))) + (fmul (load (ptradd a (add 3 (mul p lda)))) + (load (ptradd b (add p (mul 3 ldb)))))))) + (ret)) + (define ((__kernel void) + (a (ptr float)) + (b (ptr float)) + (c (ptr float)) + (m int) + (n int) + (k int)) + (declare i int) + (declare j int) + (declare lda int) + (declare ldb int) + (declare ldc int) + (set lda m) + (set ldb n) + (set ldc k) + (for ((set j 0) (lt j n) (set j (add j 4))) + (for ((set i 0) (lt i m) (set i (add i 4))) + (call __add_dot4x4 + k + (ptradd a (add i (mul 0 lda))) + lda + (ptradd b (mul j ldb)) + ldb + (ptradd c (add i (mul j ldc))) + ldc))) + (ret))) diff --git a/src/backend/tests/prelisp/mmult.sexp b/src/backend/tests/prelisp/mmult.sexp new file mode 100644 index 0000000..e77ed81 --- /dev/null +++ b/src/backend/tests/prelisp/mmult.sexp @@ -0,0 +1,45 @@ +(c-lisp + ;; Function to add dot products for a 4x4 block + (define ((__add_dot4x4 void) + (k int) + (a (ptr float)) + (lda int) + (b (ptr float)) + (ldb int) + (c (ptr float)) + (ldc int)) + + ,@(init_zeros c ldc) + + (for ((set p 0) (lt p k) (set p (add p 1))) + ,@(store_vals p a b c lda ldb ldc)) + + (ret)) + + + (define ((__kernel void) + (a (ptr float)) + (b (ptr float)) + (c (ptr float)) + (m int) + (n int) + (k int)) + + (declare i int) + (declare j int) + + (declare lda int) + (declare ldb int) + (declare ldc int) + + (set lda m) + (set ldb n) + (set ldc k) + + (for ((set j 0) (lt j n) (set j (add j 4))) + (for ((set i 0) (lt i m) (set i (add i 4))) + (call __add_dot4x4 k + (ptradd a (add i (mul 0 lda))) lda + (ptradd b (mul j ldb)) ldb + (ptradd c (add i (mul j ldc))) ldc))) + (ret))) diff --git a/src/backend/tests/prelisp/turnt.toml b/src/backend/tests/prelisp/turnt.toml new file mode 100644 index 0000000..a9a42c9 --- /dev/null +++ b/src/backend/tests/prelisp/turnt.toml @@ -0,0 +1 @@ +command = "guile ../../utils/sexp-json.scm < {filename} | python ../../prelisp.py macros.py | guile ../../utils/json-sexp.scm" \ No newline at end of file