diff --git a/README.md b/README.md index eba715b..a680c35 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,23 @@ go install go.uber.org/mock/mockgen@latest ## Running mockgen -`mockgen` has two modes of operation: source and reflect. +`mockgen` has three modes of operation: archive, source and reflect. + +### Archive mode + +Archive mode generates mock interfaces from a package archive +file (.a). It is enabled by using the -archive flag. An import +path and a comma-separated list of symbols should be provided +as a non-flag argument to the command. + +Example: + +```bash +# Build the package to a archive. +go build -o pkg.a database/sql/driver + +mockgen -archive=pkg.a database/sql/driver Conn,Driver +``` ### Source mode @@ -66,6 +82,8 @@ The `mockgen` command is used to generate source code for a mock class given a Go source file containing interfaces to be mocked. It supports the following flags: +- `-archive`: A package archive file containing interfaces to be mocked. + - `-source`: A file containing interfaces to be mocked. - `-destination`: A file to which to write the resulting source code. If you diff --git a/mockgen/archive.go b/mockgen/archive.go new file mode 100644 index 0000000..bc80ad1 --- /dev/null +++ b/mockgen/archive.go @@ -0,0 +1,55 @@ +package main + +import ( + "fmt" + "go/token" + "go/types" + "os" + + "go.uber.org/mock/mockgen/model" + + "golang.org/x/tools/go/gcexportdata" +) + +func archiveMode(importPath string, symbols []string, archive string) (*model.Package, error) { + f, err := os.Open(archive) + if err != nil { + return nil, err + } + defer f.Close() + r, err := gcexportdata.NewReader(f) + if err != nil { + return nil, fmt.Errorf("read export data %q: %v", archive, err) + } + + fset := token.NewFileSet() + imports := make(map[string]*types.Package) + tp, err := gcexportdata.Read(r, fset, imports, importPath) + if err != nil { + return nil, err + } + + pkg := &model.Package{ + Name: tp.Name(), + PkgPath: tp.Path(), + Interfaces: make([]*model.Interface, 0, len(symbols)), + } + for _, name := range symbols { + m := tp.Scope().Lookup(name) + tn, ok := m.(*types.TypeName) + if !ok { + continue + } + ti, ok := tn.Type().Underlying().(*types.Interface) + if !ok { + continue + } + it, err := model.InterfaceFromGoTypesType(ti) + if err != nil { + return nil, err + } + it.Name = m.Name() + pkg.Interfaces = append(pkg.Interfaces, it) + } + return pkg, nil +} diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index feb747f..d17fea3 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -53,6 +53,7 @@ var ( ) var ( + archive = flag.String("archive", "", "(archive mode) Input Go archive file; enables archive mode.") source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.") destination = flag.String("destination", "", "Output file; defaults to stdout.") mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.") @@ -66,9 +67,8 @@ var ( imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.") auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") excludeInterfaces = flag.String("exclude_interfaces", "", "Comma-separated names of interfaces to be excluded") - - debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") - showVersion = flag.Bool("version", false, "Print version.") + debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") + showVersion = flag.Bool("version", false, "Print version.") ) func main() { @@ -83,15 +83,22 @@ func main() { var pkg *model.Package var err error var packageName string - if *source != "" { + + // Switch between modes + switch { + case *source != "": // source mode pkg, err = sourceMode(*source) - } else { - if flag.NArg() != 2 { - usage() - log.Fatal("Expected exactly two arguments") - } + case *archive != "": // archive mode + checkArgs() + packageName = flag.Arg(0) + interfaces := strings.Split(flag.Arg(1), ",") + pkg, err = archiveMode(packageName, interfaces, *archive) + + default: // reflect mode + checkArgs() packageName = flag.Arg(0) interfaces := strings.Split(flag.Arg(1), ",") + if packageName == "." { dir, err := os.Getwd() if err != nil { @@ -104,6 +111,7 @@ func main() { } pkg, err = reflectMode(packageName, interfaces) } + if err != nil { log.Fatalf("Loading input failed: %v", err) } @@ -144,6 +152,8 @@ func main() { g := new(generator) if *source != "" { g.filename = *source + } else if *archive != "" { + g.filename = *archive } else { g.srcPackage = packageName g.srcInterfaces = flag.Arg(1) @@ -219,12 +229,19 @@ func parseExcludeInterfaces(names string) map[string]struct{} { return namesSet } +func checkArgs() { + if flag.NArg() != 2 { + usage() + log.Fatal("Expected exactly two arguments") + } +} + func usage() { _, _ = io.WriteString(os.Stderr, usageText) flag.PrintDefaults() } -const usageText = `mockgen has two modes of operation: source and reflect. +const usageText = `mockgen has three modes of operation: archive, source and reflect. Source mode generates mock interfaces from a source file. It is enabled by using the -source flag. Other flags that @@ -239,6 +256,13 @@ comma-separated list of symbols. Example: mockgen database/sql/driver Conn,Driver +Archive mode generates mock interfaces from a package archive +file (.a). It is enabled by using the -archive flag and two +non-flag arguments: an import path, and a comma-separated +list of symbols. +Example: + mockgen -archive=pkg.a database/sql/driver Conn,Driver + ` type generator struct { diff --git a/mockgen/model/model_gotypes.go b/mockgen/model/model_gotypes.go new file mode 100644 index 0000000..4596c3d --- /dev/null +++ b/mockgen/model/model_gotypes.go @@ -0,0 +1,160 @@ +package model + +import ( + "fmt" + "go/types" +) + +// InterfaceFromGoTypesType returns a pointer to an interface for the +// given interface type loaded from archive. +func InterfaceFromGoTypesType(it *types.Interface) (*Interface, error) { + intf := &Interface{} + + for i := 0; i < it.NumMethods(); i++ { + mt := it.Method(i) + // Skip unexported methods. + if !mt.Exported() { + continue + } + m := &Method{ + Name: mt.Name(), + } + + var err error + m.In, m.Variadic, m.Out, err = funcArgsFromGoTypesType(mt.Type().(*types.Signature)) + if err != nil { + return nil, fmt.Errorf("method %q: %w", mt.Name(), err) + } + + intf.AddMethod(m) + } + + return intf, nil +} + +func funcArgsFromGoTypesType(t *types.Signature) (in []*Parameter, variadic *Parameter, out []*Parameter, err error) { + nin := t.Params().Len() + if t.Variadic() { + nin-- + } + for i := 0; i < nin; i++ { + p, err := parameterFromGoTypesType(t.Params().At(i), false) + if err != nil { + return nil, nil, nil, err + } + in = append(in, p) + } + if t.Variadic() { + p, err := parameterFromGoTypesType(t.Params().At(nin), true) + if err != nil { + return nil, nil, nil, err + } + variadic = p + } + for i := 0; i < t.Results().Len(); i++ { + p, err := parameterFromGoTypesType(t.Results().At(i), false) + if err != nil { + return nil, nil, nil, err + } + out = append(out, p) + } + return +} + +func parameterFromGoTypesType(v *types.Var, variadic bool) (*Parameter, error) { + t := v.Type() + if variadic { + t = t.(*types.Slice).Elem() + } + tt, err := typeFromGoTypesType(t) + if err != nil { + return nil, err + } + return &Parameter{Name: v.Name(), Type: tt}, nil +} + +func typeFromGoTypesType(t types.Type) (Type, error) { + if t, ok := t.(*types.Named); ok { + tn := t.Obj() + if tn.Pkg() == nil { + return PredeclaredType(tn.Name()), nil + } + return &NamedType{ + Package: tn.Pkg().Path(), + Type: tn.Name(), + }, nil + } + + // only unnamed or predeclared types after here + + // Lots of types have element types. Let's do the parsing and error checking for all of them. + var elemType Type + if t, ok := t.(interface{ Elem() types.Type }); ok { + var err error + elemType, err = typeFromGoTypesType(t.Elem()) + if err != nil { + return nil, err + } + } + + switch t := t.(type) { + case *types.Array: + return &ArrayType{ + Len: int(t.Len()), + Type: elemType, + }, nil + case *types.Basic: + return PredeclaredType(t.String()), nil + case *types.Chan: + var dir ChanDir + switch t.Dir() { + case types.RecvOnly: + dir = RecvDir + case types.SendOnly: + dir = SendDir + } + return &ChanType{ + Dir: dir, + Type: elemType, + }, nil + case *types.Signature: + in, variadic, out, err := funcArgsFromGoTypesType(t) + if err != nil { + return nil, err + } + return &FuncType{ + In: in, + Out: out, + Variadic: variadic, + }, nil + case *types.Interface: + if t.NumMethods() == 0 { + return PredeclaredType("interface{}"), nil + } + case *types.Map: + kt, err := typeFromGoTypesType(t.Key()) + if err != nil { + return nil, err + } + return &MapType{ + Key: kt, + Value: elemType, + }, nil + case *types.Pointer: + return &PointerType{ + Type: elemType, + }, nil + case *types.Slice: + return &ArrayType{ + Len: -1, + Type: elemType, + }, nil + case *types.Struct: + if t.NumFields() == 0 { + return PredeclaredType("struct{}"), nil + } + // TODO: UnsafePointer + } + + return nil, fmt.Errorf("can't yet turn %v (%T) into a model.Type", t.String(), t) +}