Skip to content

Commit

Permalink
Simplify some code
Browse files Browse the repository at this point in the history
  • Loading branch information
upamanyus committed Jul 10, 2024
1 parent 1a57515 commit fd5f78f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 195 deletions.
154 changes: 30 additions & 124 deletions goose.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,14 @@ func (ctx Ctx) field(f *ast.Field) glang.FieldDecl {
}
return glang.FieldDecl{
Name: f.Names[0].Name,
Type: ctx.coqType(f.Type),
Type: ctx.glangTypeFromExpr(f.Type),
}
}

func (ctx Ctx) paramList(fs *ast.FieldList) []glang.FieldDecl {
var decls []glang.FieldDecl
for _, f := range fs.List {
ty := ctx.coqType(f.Type)
ty := ctx.glangTypeFromExpr(f.Type)
for _, name := range f.Names {
decls = append(decls, glang.FieldDecl{
Name: name.Name,
Expand Down Expand Up @@ -199,7 +199,7 @@ func (ctx Ctx) structFields(fs *ast.FieldList) []glang.FieldDecl {
ctx.unsupported(f, "unnamed (embedded) field")
return nil
}
ty := ctx.coqType(f.Type)
ty := ctx.glangTypeFromExpr(f.Type)
decls = append(decls, glang.FieldDecl{
Name: f.Names[0].Name,
Type: ty,
Expand Down Expand Up @@ -247,7 +247,7 @@ func (ctx Ctx) typeDecl(spec *ast.TypeSpec) glang.Decl {
default:
return glang.TypeDecl{
Name: spec.Name.Name,
Body: ctx.coqType(spec.Type),
Body: ctx.glangType(spec.Type, ctx.typeOf(spec.Type)),
}
}
}
Expand Down Expand Up @@ -344,7 +344,7 @@ func (ctx Ctx) makeSliceExpr(elt glang.Type, args []ast.Expr) glang.CallExpr {
func (ctx Ctx) makeExpr(args []ast.Expr) glang.Expr {
switch ty := ctx.typeOf(args[0]).Underlying().(type) {
case *types.Slice:
elt := ctx.coqTypeOfType(args[0], ty.Elem())
elt := ctx.glangType(args[0], ty.Elem())
if len(args) == 2 {
return glang.NewCallExpr(glang.GallinaIdent("slice.make2"), elt, ctx.expr(args[1]))
} else if len(args) == 3 {
Expand All @@ -355,8 +355,8 @@ func (ctx Ctx) makeExpr(args []ast.Expr) glang.Expr {
}
case *types.Map:
return glang.NewCallExpr(glang.GallinaIdent("map.make"),
ctx.coqTypeOfType(args[0], ty.Key()),
ctx.coqTypeOfType(args[0], ty.Elem()),
ctx.glangType(args[0], ty.Key()),
ctx.glangType(args[0], ty.Elem()),
glang.UnitLiteral{})
default:
ctx.unsupported(args[0],
Expand All @@ -368,8 +368,8 @@ func (ctx Ctx) makeExpr(args []ast.Expr) glang.Expr {
// newExpr parses a call to new() into an appropriate allocation
func (ctx Ctx) newExpr(ty ast.Expr) glang.Expr {
return glang.RefExpr{
X: glang.NewCallExpr(glang.GallinaIdent("zero_val"), ctx.coqType(ty)),
Ty: ctx.coqType(ty),
X: glang.NewCallExpr(glang.GallinaIdent("zero_val"), ctx.glangTypeFromExpr(ty)),
Ty: ctx.glangTypeFromExpr(ty),
}
}

Expand All @@ -396,7 +396,7 @@ func (ctx Ctx) integerConversion(s ast.Node, x ast.Expr, width int) glang.Expr {
func (ctx Ctx) copyExpr(n ast.Node, dst ast.Expr, src ast.Expr) glang.Expr {
e := sliceElem(ctx.typeOf(dst))
return glang.NewCallExpr(glang.GallinaIdent("slice.copy"),
ctx.coqTypeOfType(n, e),
ctx.glangType(n, e),
ctx.expr(dst), ctx.expr(src))
}

Expand Down Expand Up @@ -460,15 +460,15 @@ func (ctx Ctx) builtinCallExpr(s *ast.CallExpr) glang.Expr {
}
xExpr = glang.NewCallExpr(glang.GallinaIdent("slice.literal"),
// FIXME: get the type of the vararg
ctx.coqTypeOfType(s.Args[1], ctx.typeOf(s.Args[1])),
ctx.glangType(s.Args[1], ctx.typeOf(s.Args[1])),
glang.ListExpr(exprs))
}
} else {
// append(s1, s2...)
xExpr = ctx.expr(s.Args[1])
}
return glang.NewCallExpr(glang.GallinaIdent("slice.append"),
ctx.coqTypeOfType(s, elemTy),
ctx.glangType(s, elemTy),
ctx.expr(s.Args[0]),
xExpr,
)
Expand Down Expand Up @@ -543,7 +543,7 @@ func (ctx Ctx) selectorExpr(e *ast.SelectorExpr) glang.Expr {
if isField {
return glang.DerefExpr{
X: ctx.exprAddr(e),
Ty: ctx.coqTypeOfType(e, ctx.typeOf(e)),
Ty: ctx.glangType(e, ctx.typeOf(e)),
}
}
}
Expand All @@ -560,7 +560,7 @@ func (ctx Ctx) compositeLiteral(e *ast.CompositeLit) glang.Expr {
args = append(args, ctx.expr(e))
}
return glang.NewCallExpr(glang.GallinaIdent("slice.literal"),
ctx.coqTypeOfType(e, t.Elem()),
ctx.glangType(e, t.Elem()),
args)
}
info, ok := ctx.getStructInfo(ctx.typeOf(e))
Expand Down Expand Up @@ -730,7 +730,7 @@ func (ctx Ctx) sliceExpr(e *ast.SliceExpr) glang.Expr {
Names: []string{"$s"},
ValExpr: x,
Cont: glang.NewCallExpr(glang.GallinaIdent("slice.slice"),
ctx.coqTypeOfType(e, sliceElem(ctx.typeOf(e.X))),
ctx.glangType(e, sliceElem(ctx.typeOf(e.X))),
glang.IdentExpr("$s"), lowExpr, highExpr),
}
}
Expand Down Expand Up @@ -766,7 +766,7 @@ func (ctx Ctx) unaryExpr(e *ast.UnaryExpr) glang.Expr {
// e is &a[b] where x is a.b
if xTy, ok := ctx.typeOf(x.X).(*types.Slice); ok {
return glang.NewCallExpr(glang.GallinaIdent("SliceRef"),
ctx.coqTypeOfType(e, xTy.Elem()),
ctx.glangType(e, xTy.Elem()),
ctx.expr(x.X), ctx.expr(x.Index))
}
}
Expand All @@ -777,7 +777,7 @@ func (ctx Ctx) unaryExpr(e *ast.UnaryExpr) glang.Expr {
sl := ctx.structLiteral(info, structLit)
return glang.RefExpr{
X: sl,
Ty: ctx.coqTypeOfType(e.X, ctx.typeOf(e.X)),
Ty: ctx.glangType(e.X, ctx.typeOf(e.X)),
}
}
}
Expand All @@ -793,7 +793,7 @@ func (ctx Ctx) variable(s *ast.Ident) glang.Expr {
ctx.dep.addDep(s.Name)
return glang.GallinaIdent(s.Name)
}
return glang.DerefExpr{X: glang.IdentExpr(s.Name), Ty: ctx.coqTypeOfType(s, ctx.typeOf(s))}
return glang.DerefExpr{X: glang.IdentExpr(s.Name), Ty: ctx.glangType(s, ctx.typeOf(s))}
}

func (ctx Ctx) function(s *ast.Ident) glang.Expr {
Expand Down Expand Up @@ -856,7 +856,7 @@ func (ctx Ctx) indexExpr(e *ast.IndexExpr, isSpecial bool) glang.Expr {
case *types.Slice:
return glang.DerefExpr{
X: ctx.exprAddr(e),
Ty: ctx.coqTypeOfType(e, ctx.typeOf(e)),
Ty: ctx.glangType(e, ctx.typeOf(e)),
}
case *types.Signature:
ctx.unsupported(e, "generic function %v", xTy)
Expand All @@ -868,7 +868,7 @@ func (ctx Ctx) indexExpr(e *ast.IndexExpr, isSpecial bool) glang.Expr {
func (ctx Ctx) derefExpr(e ast.Expr) glang.Expr {
return glang.DerefExpr{
X: ctx.expr(e),
Ty: ctx.coqTypeOfType(e, ptrElem(ctx.typeOf(e))),
Ty: ctx.glangType(e, ptrElem(ctx.typeOf(e))),
}
}

Expand All @@ -889,8 +889,6 @@ func (ctx Ctx) exprSpecial(e ast.Expr, isSpecial bool) glang.Expr {
switch e := e.(type) {
case *ast.CallExpr:
return ctx.callExpr(e)
case *ast.MapType:
return ctx.mapType(e)
case *ast.Ident:
return ctx.identExpr(e)
case *ast.SelectorExpr:
Expand Down Expand Up @@ -1051,7 +1049,7 @@ func (ctx Ctx) sliceRangeStmt(s *ast.RangeStmt) glang.Expr {
Key: ctx.identBinder(key),
Val: valExpr,
Slice: glang.IdentExpr("$range"),
Ty: ctx.coqTypeOfType(s.X, sliceElem(ctx.typeOf(s.X))),
Ty: ctx.glangType(s.X, sliceElem(ctx.typeOf(s.X))),
Body: ctx.blockStmt(s.Body),
}
return glang.LetExpr{
Expand All @@ -1078,7 +1076,7 @@ func (ctx Ctx) rangeStmt(s *ast.RangeStmt) glang.Expr {
func (ctx Ctx) referenceTo(rhs ast.Expr) glang.Expr {
return glang.RefExpr{
X: ctx.expr(rhs),
Ty: ctx.coqTypeOfType(rhs, ctx.typeOf(rhs)),
Ty: ctx.glangType(rhs, ctx.typeOf(rhs)),
}
}

Expand All @@ -1089,7 +1087,7 @@ func (ctx Ctx) defineStmt(s *ast.AssignStmt, cont glang.Expr) glang.Expr {
for _, lhsExpr := range s.Lhs {
if ident, ok := lhsExpr.(*ast.Ident); ok {
if _, ok := ctx.info.Defs[ident]; ok { // if this identifier is defining something
t := ctx.coqTypeOfType(ident, ctx.info.TypeOf(ident))
t := ctx.glangType(ident, ctx.info.TypeOf(ident))
e = glang.LetExpr{
Names: []string{ident.Name},
ValExpr: glang.RefExpr{
Expand All @@ -1114,7 +1112,7 @@ func (ctx Ctx) varSpec(s *ast.ValueSpec, cont glang.Expr) glang.Expr {
lhs := s.Names[0]
var rhs glang.Expr
if len(s.Values) == 0 {
ty := ctx.coqType(lhs)
ty := ctx.glangTypeFromExpr(lhs)
rhs = glang.NewCallExpr(glang.GallinaIdent("ref_ty"), ty,
glang.NewCallExpr(glang.GallinaIdent("zero_val"), ty))
} else {
Expand Down Expand Up @@ -1159,7 +1157,7 @@ func (ctx Ctx) exprAddr(e ast.Expr) glang.Expr {
switch targetTy := targetTy.(type) {
case *types.Slice:
return glang.NewCallExpr(glang.GallinaIdent("slice.elem_ref"),
ctx.coqTypeOfType(e, targetTy.Elem()),
ctx.glangType(e, targetTy.Elem()),
ctx.expr(e.X),
ctx.expr(e.Index))
case *types.Map:
Expand Down Expand Up @@ -1215,7 +1213,7 @@ func (ctx Ctx) assignFromTo(s ast.Node, lhs ast.Expr, rhs glang.Expr, cont glang
return glang.NewDoSeq(glang.StoreStmt{
Dst: ctx.exprAddr(lhs),
X: rhs,
Ty: ctx.coqTypeOfType(lhs, ctx.typeOf(lhs)),
Ty: ctx.glangType(lhs, ctx.typeOf(lhs)),
}, cont)
}

Expand Down Expand Up @@ -1405,7 +1403,7 @@ func (ctx Ctx) returnType(results *ast.FieldList) glang.Type {
ctx.unsupported(r, "named returned value")
return glang.TypeIdent("<invalid>")
}
ts = append(ts, ctx.coqType(r.Type))
ts = append(ts, ctx.glangTypeFromExpr(r.Type))
}
return glang.NewTupleType(ts)
}
Expand Down Expand Up @@ -1471,9 +1469,9 @@ func (ctx Ctx) constSpec(spec *ast.ValueSpec) glang.ConstDecl {
val := spec.Values[0]
cd.Val = ctx.expr(val)
if spec.Type == nil {
cd.Type = ctx.coqTypeOfType(spec, ctx.typeOf(val))
cd.Type = ctx.glangType(spec, ctx.typeOf(val))
} else {
cd.Type = ctx.coqType(spec.Type)
cd.Type = ctx.glangTypeFromExpr(spec.Type)
}
cd.Val = ctx.expr(spec.Values[0])
return cd
Expand Down Expand Up @@ -1537,60 +1535,6 @@ func (ctx Ctx) imports(d []ast.Spec) []glang.Decl {
return decls
}

func (ctx Ctx) exprInterface(cvs []glang.Decl, expr ast.Expr, d *ast.FuncDecl) []glang.Decl {
switch f := expr.(type) {
case *ast.UnaryExpr:
if left, ok := f.X.(*ast.BinaryExpr); ok {
if call, ok := left.X.(*ast.CallExpr); ok {
cvs = ctx.callExprInterface(cvs, call, d)
}
}
case *ast.BinaryExpr:
if left, ok := f.X.(*ast.BinaryExpr); ok {
if call, ok := left.X.(*ast.CallExpr); ok {
cvs = ctx.callExprInterface(cvs, call, d)
}
}
if right, ok := f.Y.(*ast.BinaryExpr); ok {
if call, ok := right.X.(*ast.CallExpr); ok {
cvs = ctx.callExprInterface(cvs, call, d)
}
}
case *ast.CallExpr:
cvs = ctx.callExprInterface(cvs, f, d)
}
return cvs
}

func (ctx Ctx) stmtInterface(cvs []glang.Decl, stmt ast.Stmt, d *ast.FuncDecl) []glang.Decl {
switch f := stmt.(type) {
case *ast.ReturnStmt:
for _, result := range f.Results {
cvs = ctx.exprInterface(cvs, result, d)
}
if len(f.Results) > 0 {
if results, ok := f.Results[0].(*ast.BinaryExpr); ok {
if call, ok := results.X.(*ast.CallExpr); ok {
cvs = ctx.callExprInterface(cvs, call, d)
}
}
}
case *ast.IfStmt:
if call, ok := f.Cond.(*ast.CallExpr); ok {
cvs = ctx.callExprInterface(cvs, call, d)
}
case *ast.ExprStmt:
if call, ok := f.X.(*ast.CallExpr); ok {
cvs = ctx.callExprInterface(cvs, call, d)
}
case *ast.AssignStmt:
if call, ok := f.Rhs[0].(*ast.CallExpr); ok {
cvs = ctx.callExprInterface(cvs, call, d)
}
}
return cvs
}

// TODO: this is a hack, should have a better scheme for putting
// interface/implementation types into the conversion name
func unqualifyName(name string) string {
Expand All @@ -1599,49 +1543,11 @@ func unqualifyName(name string) string {
return components[len(components)-1]
}

func (ctx Ctx) callExprInterface(cvs []glang.Decl, r *ast.CallExpr, d *ast.FuncDecl) []glang.Decl {
interfaceName := ""
methods := []string{}
if signature, ok := ctx.typeOf(r.Fun).(*types.Signature); ok {
params := signature.Params()
for j := 0; j < params.Len(); j++ {
interfaceName = params.At(j).Type().String()
interfaceName = unqualifyName(interfaceName)
if v, ok := params.At(j).Type().Underlying().(*types.Interface); ok {
for m := 0; m < v.NumMethods(); m++ {
methods = append(methods, v.Method(m).Name())
}
}
}
for _, arg := range r.Args {
structName := ctx.typeOf(arg).String()
structName = unqualifyName(structName)
if _, ok := ctx.typeOf(arg).Underlying().(*types.Struct); ok {
cv := glang.StructToInterface{Struct: structName, Interface: interfaceName, Methods: methods}
if len(cv.Coq(true)) > 1 && len(cv.MethodList()) > 0 {
cvs = append(cvs, cv)
}
}
}
}
return cvs
}

func (ctx Ctx) maybeDecls(d ast.Decl) []glang.Decl {
switch d := d.(type) {
case *ast.FuncDecl:
cvs := []glang.Decl{}
for _, stmt := range d.Body.List {
cvs = ctx.stmtInterface(cvs, stmt, d)
}
fd := ctx.funcDecl(d)
results := []glang.Decl{}
if len(cvs) > 0 {
results = append(cvs, fd)
} else {
results = []glang.Decl{fd}
}
return results
return []glang.Decl{fd}
case *ast.GenDecl:
switch d.Tok {
case token.IMPORT:
Expand Down
Loading

0 comments on commit fd5f78f

Please sign in to comment.