Skip to content

Commit

Permalink
Kitchen sink
Browse files Browse the repository at this point in the history
  • Loading branch information
myshkin5 committed Feb 29, 2024
1 parent 91815f4 commit 68c7d68
Show file tree
Hide file tree
Showing 24 changed files with 21,866 additions and 7,913 deletions.
13 changes: 13 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ linters:
# Usually disabled but useful for checking everything has godoc
- golint

run:
skip-dirs:
- ^testit.*

linters-settings:
gci:
sections:
Expand Down Expand Up @@ -70,6 +74,15 @@ issues:
- revive
- stylecheck
- unused
- linters:
- unused
path: generator/testmoqs/fnadaptors_test.go
- linters:
- unused
path: generator/testmoqs/usualadaptors_test.go
- linters:
- inamedparam
path: .*/testmoqs/.*
include:
# disable excluding of issues about comments from golint.
- EXC0002
194 changes: 170 additions & 24 deletions ast/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package ast
import (
"errors"
"fmt"
"go/token"
"io/fs"
"os"
"path/filepath"
Expand All @@ -23,9 +24,11 @@ import (
const (
builtinPkg = "builtin"

genTypeSuffix = "_genType"
starGenTypeSuffix = "_starGenType"
testPkgSuffix = "_test"
genTypeSuffix = "_genType"
starGenTypeSuffix = "_starGenType"
indexGenTypeSuffix = "_indexGenType"
indexListGenTypeSuffix = "_indexListGenType"
testPkgSuffix = "_test"
)

//go:generate moqueries LoadFn
Expand Down Expand Up @@ -197,7 +200,7 @@ func (c *Cache) Type(id dst.Ident, contextPkg string, testImport bool) (TypeInfo
// IsComparable determines if an expression is comparable. The optional
// parentType can be used to supply type parameters.
func (c *Cache) IsComparable(expr dst.Expr, parentType TypeInfo) (bool, error) {
return c.isDefaultComparable(expr, &parentType, true)
return c.isDefaultComparable(expr, &parentType, true, false)
}

// IsDefaultComparable determines if an expression is comparable. Returns the
Expand All @@ -206,7 +209,7 @@ func (c *Cache) IsComparable(expr dst.Expr, parentType TypeInfo) (bool, error) {
// map key will panic at runtime and by default pointers use a deep hash to be
// comparable).
func (c *Cache) IsDefaultComparable(expr dst.Expr, parentType TypeInfo) (bool, error) {
return c.isDefaultComparable(expr, &parentType, false)
return c.isDefaultComparable(expr, &parentType, false, false)
}

// FindPackage finds the package for a given directory
Expand Down Expand Up @@ -368,32 +371,75 @@ func (c *Cache) isDefaultComparable(
expr dst.Expr,
parentType *TypeInfo,
interfacePointerDefault bool,
genericType bool,
) (bool, error) {
subInterfaceDefault := interfacePointerDefault
if genericType {
subInterfaceDefault = false
}
switch e := expr.(type) {
case *dst.ArrayType:
if e.Len == nil {
return false, nil
}
return c.isDefaultComparable(e.Elt, parentType, interfacePointerDefault)

return c.isDefaultComparable(e.Elt, parentType, interfacePointerDefault, genericType)
case *dst.BinaryExpr:
comp, err := c.isDefaultComparable(e.X, parentType, interfacePointerDefault, genericType)
if err != nil || !comp {
return comp, err
}

return c.isDefaultComparable(e.Y, parentType, interfacePointerDefault, genericType)
case *dst.Ellipsis:
return false, nil
case *dst.FuncType:
return false, nil
case *dst.InterfaceType:
return interfacePointerDefault, nil
case *dst.Ident:
if e.Obj != nil {
typ, ok := e.Obj.Decl.(*dst.TypeSpec)
if !ok {
return false, fmt.Errorf("%q: %w", e.String(), ErrInvalidType)
if e.Methods == nil || len(e.Methods.List) == 0 {
// Basically an "any" interface
return subInterfaceDefault, nil
}
hasTypeConstraints := false
for _, m := range e.Methods.List {
if _, ok := m.Type.(*dst.FuncType); ok {
// Skip methods as they don't change whether something is
// comparable
continue
}

if typ.Name.Name == "string" && typ.Name.Path == "" {
return true, nil
hasTypeConstraints = true

comp, err := c.isDefaultComparable(m.Type, parentType, subInterfaceDefault, genericType)
if err != nil || !comp {
return comp, err
}
}

return c.isDefaultComparable(typ.Type, parentType, interfacePointerDefault)
if hasTypeConstraints {
// If an interface has type constraints and none of them were not
// comparable (none were because we would have returned early
// above), then it is always comparable
return true, nil
}

return subInterfaceDefault, nil
case *dst.Ident:
// if e.Obj != nil {
// var tExpr dst.Expr
// switch typ := e.Obj.Decl.(type) {
// case *dst.TypeSpec:
// tExpr = typ.Type
// case *dst.Field:
// tExpr = typ.Type
// default:
// return false, fmt.Errorf("identity expression %q: %w", e.String(), ErrInvalidType)
// }
//
// return c.isDefaultComparable(tExpr, parentType, "", interfacePointerDefault, false)
// }
// TODO: Generic type parameters should trump types in the cache (call
// findGenericType first)
pkgPath := e.Path
typ, ok := c.typesByIdent[e.String()]
if !ok && e.Path == "" && parentType != nil {
Expand All @@ -407,15 +453,27 @@ func (c *Cache) isDefaultComparable(
Exported: isExported(e.Name, pkgPath),
Fabricated: false,
}
return c.isDefaultComparable(typ.typ.Type, tInfo, interfacePointerDefault)
return c.isDefaultComparable(
typ.typ.Type, tInfo, interfacePointerDefault, genericType)
}

// Builtin type?
if e.Path == "" {
// error is the one builtin type that may not be comparable (it's
// Builtin or generic type?
if e.Path == "" || (parentType != nil && parentType.Type != nil && e.Path == parentType.Type.Name.Path) {
// Precedence is given to a generic type
gType := c.findGenericType(parentType, e.Name)
if gType != nil {
return c.isDefaultComparable(gType, parentType, interfacePointerDefault, true)
}

// error is a builtin type that may not be comparable (it's
// an interface so return the same result as an interface)
if e.Name == "error" {
return interfacePointerDefault, nil
return subInterfaceDefault, nil
}

// any is an alias for interface{}, so again the default
if e.Name == "any" {
return subInterfaceDefault, nil
}

return true, nil
Expand All @@ -434,7 +492,7 @@ func (c *Cache) isDefaultComparable(
Exported: isExported(e.Name, e.Path),
Fabricated: false,
}
return c.isDefaultComparable(typ.typ.Type, tInfo, interfacePointerDefault)
return c.isDefaultComparable(typ.typ.Type, tInfo, interfacePointerDefault, genericType)
}

return true, nil
Expand All @@ -443,7 +501,7 @@ func (c *Cache) isDefaultComparable(
case *dst.SelectorExpr:
ex, ok := e.X.(*dst.Ident)
if !ok {
return false, fmt.Errorf("%q: %w", e.X, ErrInvalidType)
return false, fmt.Errorf("selector expression %q: %w", e.X, ErrInvalidType)
}
path := ex.Name
_, err := c.loadPackage(path, false)
Expand All @@ -453,7 +511,7 @@ func (c *Cache) isDefaultComparable(

typ, ok := c.typesByIdent[IdPath(e.Sel.Name, path).String()]
if ok {
return c.isDefaultComparable(typ.typ.Type, parentType, interfacePointerDefault)
return c.isDefaultComparable(typ.typ.Type, nil, interfacePointerDefault, genericType)
}

// Builtin type?
Expand All @@ -462,16 +520,96 @@ func (c *Cache) isDefaultComparable(
return interfacePointerDefault, nil
case *dst.StructType:
for _, f := range e.Fields.List {
comp, err := c.isDefaultComparable(f.Type, parentType, interfacePointerDefault)
comp, err := c.isDefaultComparable(f.Type, parentType, interfacePointerDefault, genericType)
if err != nil || !comp {
return false, err
}
}
case *dst.UnaryExpr:
if e.Op != token.TILDE {
return false, fmt.Errorf(
"unexpected unary operator %s: %w", e.Op.String(), ErrInvalidType)
}
// This is a type constraint and for determining comparability, we
// don't care if the constraint is for a type or underlying types
return c.isDefaultComparable(e.X, parentType, interfacePointerDefault, genericType)
}

return true, nil
}

func (c *Cache) findGenericType(parentType *TypeInfo, paramTypeName string) dst.Expr {
if parentType == nil || parentType.Type == nil || parentType.Type.TypeParams == nil {
return nil
}

for _, p := range parentType.Type.TypeParams.List {
for _, n := range p.Names {
if n.Name == paramTypeName {
return p.Type
}
}
}

return nil
}

// func (c *Cache) findMethodGenericType(fn *dst.FuncDecl, paramTypeName string) (dst.Expr, error) {
// // Only handle methods here. Functions and structs have their Obj's intact
// // and don't need to be looked up in another declaration
// for _, r := range fn.Recv.List {
// switch idxType := r.Type.(type) {
// case *dst.IndexListExpr:
// for n, iExpr := range idxType.Indices {
// xId, ok := idxType.X.(*dst.Ident)
// if !ok {
// return nil, fmt.Errorf(
// "expecting *dst.Ident in IndexListExpr.X: %w", ErrInvalidType)
// }
// gType, err := c.findIndexedGenericType(iExpr, paramTypeName, xId, n)
// if err != nil || gType != nil {
// return gType, err
// }
// }
// case *dst.IndexExpr:
// xId, ok := idxType.X.(*dst.Ident)
// if !ok {
// return nil, fmt.Errorf(
// "expecting *dst.Ident in IndexExpr.X: %w", ErrInvalidType)
// }
// return c.findIndexedGenericType(idxType.Index, paramTypeName, xId, 0)
// default:
// return nil, fmt.Errorf(
// "unexpected index type %#v: %w", idxType, ErrInvalidType)
// }
// }
//
// return nil, nil
// }

// func (c *Cache) findIndexedGenericType(
// iExpr dst.Expr, paramTypeName string, xId *dst.Ident, idx int,
// ) (dst.Expr, error) {
// if id, ok := iExpr.(*dst.Ident); ok && id.Name != paramTypeName {
// return nil, nil
// }
//
// if xId.Obj == nil {
// return nil, fmt.Errorf(
// "expecting Obj: %w", ErrInvalidType)
// }
// tSpec, ok := xId.Obj.Decl.(*dst.TypeSpec)
// if !ok {
// return nil, fmt.Errorf(
// "expecting *dst.TypeSpec: %w", ErrInvalidType)
// }
// if tSpec.TypeParams == nil || len(tSpec.TypeParams.List) <= idx {
// return nil, fmt.Errorf(
// "base type to method type param mismatch: %w", ErrInvalidType)
// }
// return tSpec.TypeParams.List[idx].Type, nil
// }

func (c *Cache) loadPackage(path string, testImport bool) (string, error) {
indexPath := path
if strings.HasPrefix(path, ".") {
Expand Down Expand Up @@ -706,6 +844,14 @@ func (c *Cache) storeFuncDecl(decl *dst.FuncDecl, pkg *pkgInfo) {
suffix = starGenTypeSuffix
expr = sExpr.X
}
if iExpr, ok := expr.(*dst.IndexExpr); ok {
suffix = indexGenTypeSuffix
expr = iExpr.X
}
if ilExpr, ok := expr.(*dst.IndexListExpr); ok {
suffix = indexListGenTypeSuffix
expr = ilExpr.X
}
exprId, ok := expr.(*dst.Ident)
if !ok {
logs.Panicf("%s has a non-Ident (or StarExpr/Ident) receiver: %#v",
Expand Down
Loading

0 comments on commit 68c7d68

Please sign in to comment.