Skip to content

Commit

Permalink
wip: generate kernel on IR generation for CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
hidetatz committed Jan 24, 2025
1 parent adbb2c4 commit 4759b35
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 64 deletions.
148 changes: 106 additions & 42 deletions tensor2/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,46 +51,63 @@ func generateIR(t *Tensor) (irs []*instruction, err error) {
}
}()

irs = []*instruction{}
inst := func(m mnemonic) *instruction {
return &instruction{id: newInstid(), mnemonic: m}
}

kernels := []*instruction{}
pushK := func(ir *instruction) instid {
kernels = append(kernels, ir)
return ir.id
}

push := func(ir *instruction) instid {
irs = append(irs, ir)
entries := []*instruction{}
pushE := func(ir *instruction) instid {
entries = append(entries, ir)
return ir.id
}

pushE(inst(&mnEntry{}))

var dfs func(t *Tensor) instid
dfs = func(t *Tensor) instid {
switch t.op {
case op_const:
return push(inst(&mnParam{typ: t_floats, val: t.data}))
return pushE(inst(&mnParam{typ: t_floats, val: t.data}))

case op_recip:
input := t.inputs[0]
size := input.Size()

inputid := dfs(input)

// define result to store
result := push(inst(&mnDecl{typ: t_floats, length: size}))

// start loop
loop := push(inst(&mnLoop{countImm: size}))

load := push(inst(&mnInit{from: inputid, idx: loop}))

/*
* define kernel
*/
paramIdx := pushK(inst(&mnKernParam{typ: t_int}))
paramx := pushK(inst(&mnKernParam{typ: t_floats}))
paramresult := pushK(inst(&mnKernParam{typ: t_floats}))
kern := pushK(inst(&mnKernel{params: []instid{paramIdx, paramx, paramresult}}))
target := pushK(inst(&mnInit{from: paramx, idx: paramIdx}))
var op alu1op
if t.op == op_recip {
op = alu1_recip
}
alu1 := pushK(inst(&mnALU1{val: target, op: op}))
pushK(inst(&mnAssign{left: paramresult, lidx: paramIdx, right: alu1}))
pushK(inst(&mnEndKernel{}))

// do compute
alu1 := push(inst(&mnALU1{val: load, op: op}))
/*
* call kernel from entry
*/

// assign computed to result
push(inst(&mnAssign{left: result, lidx: loop, right: alu1}))
// define result to store
result := pushE(inst(&mnDecl{typ: t_floats, length: size}))

// finish loop
push(inst(&mnEndLoop{}))
// start loop and invokes kernel
loop := pushE(inst(&mnLoop{countImm: size}))
pushE(inst(&mnInvokeKernel{kernel: kern, args: []instid{loop, inputid, result}}))
pushE(inst(&mnEndLoop{}))

return result

Expand All @@ -107,25 +124,25 @@ func generateIR(t *Tensor) (irs []*instruction, err error) {
lid, rid := dfs(l), dfs(r)

// define result to store
result := push(inst(&mnDecl{typ: t_floats, length: sizel}))
result := pushE(inst(&mnDecl{typ: t_floats, length: sizel}))

// start loop
loop := push(inst(&mnLoop{countImm: sizel}))
loop := pushE(inst(&mnLoop{countImm: sizel}))

// assume vector
// todo: support 2 or more dimensions

// compute stride, considering broadcast
lstride := push(inst(&mnInitImm{typ: t_int, val: l.dim.strides[0]}))
rstride := push(inst(&mnInitImm{typ: t_int, val: r.dim.strides[0]}))
lstride := pushE(inst(&mnInitImm{typ: t_int, val: l.dim.strides[0]}))
rstride := pushE(inst(&mnInitImm{typ: t_int, val: r.dim.strides[0]}))

// define index
lidx := push(inst(&mnALU2{left: loop, op: alu2_mul, right: lstride}))
ridx := push(inst(&mnALU2{left: loop, op: alu2_mul, right: rstride}))
lidx := pushE(inst(&mnALU2{left: loop, op: alu2_mul, right: lstride}))
ridx := pushE(inst(&mnALU2{left: loop, op: alu2_mul, right: rstride}))

// load value to be computed from left and right
loadl := push(inst(&mnInit{from: lid, idx: lidx}))
loadr := push(inst(&mnInit{from: rid, idx: ridx}))
loadl := pushE(inst(&mnInit{from: lid, idx: lidx}))
loadr := pushE(inst(&mnInit{from: rid, idx: ridx}))

var op alu2op

Expand All @@ -136,13 +153,13 @@ func generateIR(t *Tensor) (irs []*instruction, err error) {
}

// do compute
alu2 := push(inst(&mnALU2{left: loadl, op: op, right: loadr}))
alu2 := pushE(inst(&mnALU2{left: loadl, op: op, right: loadr}))

// assign computed to result
push(inst(&mnAssign{left: result, lidx: loop, right: alu2}))
pushE(inst(&mnAssign{left: result, lidx: loop, right: alu2}))

// finish loop
push(inst(&mnEndLoop{}))
pushE(inst(&mnEndLoop{}))

return result

Expand All @@ -155,9 +172,10 @@ func generateIR(t *Tensor) (irs []*instruction, err error) {
}

result := dfs(t)
push(inst(&mnReturn{val: result}))
pushE(inst(&mnReturn{val: result}))
pushE(inst(&mnEndEntry{}))

return irs, nil
return slices.Concat(kernels, entries), nil
}

/*
Expand Down Expand Up @@ -193,10 +211,15 @@ type cLikeLangRenderer interface {

varname(id int) string

entrypoint() string
kernelName(id int) string
kernel(kernname string, params []string, typs []typ) string
endKernel() string

kernel(entry string) string
endkernel() string
invokeKernel(kernname string, args []string) string

entrypointName() string
entrypoint(entryname string) string
endEntrypoint() string

// ir
global(varname string, typ typ) string
Expand Down Expand Up @@ -238,6 +261,11 @@ func (r *cLikeRenderer) render(irs []*instruction) (*dll, error) {
return ok
})

irmap := map[instid]*instruction{}
for _, ir := range irs {
irmap[ir.id] = ir
}

/*
* start rendering
*/
Expand Down Expand Up @@ -268,12 +296,6 @@ func (r *cLikeRenderer) render(irs []*instruction) (*dll, error) {

write("")

// render main kernel

entry := r.lang.entrypoint()
write(r.lang.kernel(entry))
r.depth++

toidx := func(id instid) *idx {
if !id.valid() {
return nil
Expand All @@ -282,10 +304,53 @@ func (r *cLikeRenderer) render(irs []*instruction) (*dll, error) {
return &idx{val: varname(id)}
}

var entry string

for _, ir := range irs {
v := varname(ir.id)

switch mn := ir.mnemonic.(type) {
case *mnEntry:
entry = r.lang.entrypointName()
write(r.lang.entrypoint(entry))
r.depth++

case *mnEndEntry:
r.depth--
write(r.lang.endEntrypoint())
write("")

case *mnKernParam:
// do nothing

case *mnKernel:
kernname := r.lang.kernelName(int(ir.id))

// extrace kernel parameter
params := make([]string, len(mn.params))
typs := make([]typ, len(mn.params))
for i, paramID := range mn.params {
param := irmap[paramID]
params[i] = varname(param.id)
typs[i] = param.mnemonic.(*mnKernParam).typ // assume param type is *mnKernParam
}

write(r.lang.kernel(kernname, params, typs))
r.depth++

case *mnEndKernel:
r.depth--
write(r.lang.endKernel())
write("")

case *mnInvokeKernel:
kernname := r.lang.kernelName(int(mn.kernel))
args := make([]string, len(mn.args))
for i, arg := range mn.args {
args[i] = varname(arg)
}
write(r.lang.invokeKernel(kernname, args))

case *mnReturn:
write(r.lang.return_(varname(mn.val)))

Expand All @@ -308,6 +373,7 @@ func (r *cLikeRenderer) render(irs []*instruction) (*dll, error) {
case *mnEndLoop:
r.depth--
write(r.lang.endloop())
write("")

case *mnALU1:
write(r.lang.alu1(v, varname(mn.val), mn.op))
Expand All @@ -320,8 +386,6 @@ func (r *cLikeRenderer) render(irs []*instruction) (*dll, error) {
}
}

r.depth--
write(r.lang.endkernel())
write("")

write(r.lang.footer())
Expand Down
28 changes: 24 additions & 4 deletions tensor2/backend_golang.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,35 @@ func (r *gorenderer) varname(id int) string {
return fmt.Sprintf("D%v", id)
}

func (r *gorenderer) entrypoint() string {
return "F" // must be exported
func (r *gorenderer) kernelName(id int) string {
return fmt.Sprintf("kern%d", id)
}

func (r *gorenderer) kernel(entry string) string {
func (r *gorenderer) kernel(kernname string, _params []string, typs []typ) string {
params := make([]string, len(_params))
for i, p := range _params {
params[i] = fmt.Sprintf("%v %v", p, r.encodeType(typs[i]))
}
return fmt.Sprintf("func %v(%v) {", kernname, strings.Join(params, ", "))
}

func (r *gorenderer) endKernel() string {
return "}"
}

func (r *gorenderer) invokeKernel(kernname string, args []string) string {
return fmt.Sprintf("%v(%v)", kernname, strings.Join(args, ", "))
}

func (r *gorenderer) entrypointName() string {
return "Entry" // must be exported
}

func (r *gorenderer) entrypoint(entry string) string {
return fmt.Sprintf("func %v() []float32 {", entry)
}

func (r *gorenderer) endkernel() string {
func (r *gorenderer) endEntrypoint() string {
return "}"
}

Expand Down
Loading

0 comments on commit 4759b35

Please sign in to comment.