diff --git a/arguments/parser.go b/arguments/parser.go index 25ba1271..042bfe70 100644 --- a/arguments/parser.go +++ b/arguments/parser.go @@ -1,9 +1,14 @@ package arguments import ( + "fmt" "path/filepath" "regexp" + "strconv" "strings" + + "github.com/maxbrunsfeld/counterfeiter/locator" + "github.com/maxbrunsfeld/counterfeiter/terminal" ) type ArgumentParser interface { @@ -15,32 +20,78 @@ func NewArgumentParser( currentWorkingDir CurrentWorkingDir, symlinkEvaler SymlinkEvaler, fileStatReader FileStatReader, + ui terminal.UI, + interfaceLocator locator.InterfaceLocator, ) ArgumentParser { - return argumentParser{ + return &argumentParser{ + ui: ui, failHandler: failHandler, currentWorkingDir: currentWorkingDir, symlinkEvaler: symlinkEvaler, fileStatReader: fileStatReader, + interfaceLocator: interfaceLocator, } } -func (argParser argumentParser) ParseArguments(args ...string) ParsedArguments { +func (argParser *argumentParser) ParseArguments(args ...string) ParsedArguments { sourcePackageDir := argParser.getSourceDir(args[0]) - fakeImplName := getFakeName(args[1], *fakeNameFlag) - outputPath := argParser.getOutputPath(sourcePackageDir, fakeImplName, *outputPathFlag) + + var interfaceName string + + if len(args) > 1 { + interfaceName = args[1] + } else { + interfaceName = argParser.PromptUserForInterfaceName(sourcePackageDir) + } + + fakeImplName := getFakeName(interfaceName, *fakeNameFlag) + + outputPath := argParser.getOutputPath( + sourcePackageDir, + fakeImplName, + *outputPathFlag, + ) return ParsedArguments{ SourcePackageDir: sourcePackageDir, OutputPath: outputPath, - InterfaceName: args[1], + InterfaceName: interfaceName, FakeImplName: fakeImplName, - PrintToStdOut: len(args) == 3 && args[2] == "-", + PrintToStdOut: any(args, "-"), } } +func (parser *argumentParser) PromptUserForInterfaceName(filepath string) string { + parser.ui.WriteLine("Which interface to counterfeit?") + + interfacesInPackage := parser.interfaceLocator.GetInterfacesFromFilePath(filepath) + + for i, interfaceName := range interfacesInPackage { + parser.ui.WriteLine(fmt.Sprintf("%d. %s", i+1, interfaceName)) + } + parser.ui.WriteLine("") + + response := parser.ui.ReadLineFromStdin() + parsedResponse, err := strconv.ParseInt(response, 10, 64) + if err != nil { + parser.failHandler("Unknown option '%s'", response) + return "" + } + + option := int(parsedResponse - 1) + if option < 0 || option >= len(interfacesInPackage) { + parser.failHandler("Unknown option '%s'", response) + return "" + } + + return interfacesInPackage[option] +} + type argumentParser struct { + ui terminal.UI + interfaceLocator locator.InterfaceLocator failHandler FailHandler currentWorkingDir CurrentWorkingDir symlinkEvaler SymlinkEvaler @@ -67,7 +118,7 @@ func getFakeName(interfaceName, arg string) string { var camelRegexp = regexp.MustCompile("([a-z])([A-Z])") -func (argParser argumentParser) getOutputPath(sourceDir, fakeName, arg string) string { +func (argParser *argumentParser) getOutputPath(sourceDir, fakeName, arg string) string { if arg == "" { snakeCaseName := strings.ToLower(camelRegexp.ReplaceAllString(fakeName, "${1}_${2}")) return filepath.Join(sourceDir, "fakes", snakeCaseName+".go") @@ -79,7 +130,7 @@ func (argParser argumentParser) getOutputPath(sourceDir, fakeName, arg string) s } } -func (argParser argumentParser) getSourceDir(arg string) string { +func (argParser *argumentParser) getSourceDir(arg string) string { if !filepath.IsAbs(arg) { arg = filepath.Join(argParser.currentWorkingDir(), arg) } @@ -97,4 +148,14 @@ func (argParser argumentParser) getSourceDir(arg string) string { } } +func any(slice []string, needle string) bool { + for _, str := range slice { + if str == needle { + return true + } + } + + return false +} + type FailHandler func(string, ...interface{}) diff --git a/arguments/parser_test.go b/arguments/parser_test.go index db0f0e0e..0ec0b39e 100644 --- a/arguments/parser_test.go +++ b/arguments/parser_test.go @@ -7,6 +7,9 @@ import ( "testing" "time" + locatorFakes "github.com/maxbrunsfeld/counterfeiter/locator/fakes" + terminalFakes "github.com/maxbrunsfeld/counterfeiter/terminal/fakes" + . "github.com/maxbrunsfeld/counterfeiter/arguments" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -22,17 +25,42 @@ var _ = Describe("parsing arguments", func() { var symlinkEvaler SymlinkEvaler var fileStatReader FileStatReader + var ui *terminalFakes.FakeUI + var interfaceLocator *locatorFakes.FakeInterfaceLocator + + var failWasCalled bool + // fake UI helper + + var fakeUIBuffer = func() string { + var output string + for i := 0; i < ui.WriteLineCallCount(); i++ { + output = output + ui.WriteLineArgsForCall(i) + } + return output + } + JustBeforeEach(func() { - subject = NewArgumentParser(fail, cwd, symlinkEvaler, fileStatReader) + subject = NewArgumentParser( + fail, + cwd, + symlinkEvaler, + fileStatReader, + ui, + interfaceLocator, + ) parsedArgs = subject.ParseArguments(args...) }) BeforeEach(func() { - fail = func(_ string, _ ...interface{}) {} + failWasCalled = false + fail = func(_ string, _ ...interface{}) { failWasCalled = true } cwd = func() string { return "/home/test-user/workspace" } + ui = new(terminalFakes.FakeUI) + interfaceLocator = new(locatorFakes.FakeInterfaceLocator) + symlinkEvaler = func(input string) (string, error) { return input, nil } @@ -41,6 +69,43 @@ var _ = Describe("parsing arguments", func() { } }) + Describe("when a single argument is provided", func() { + BeforeEach(func() { + args = []string{"some/path"} + + interfaceLocator.GetInterfacesFromFilePathReturns([]string{"Foo", "Bar"}) + ui.ReadLineFromStdinReturns("1") + }) + + It("prompts the user for which interface they want", func() { + Expect(fakeUIBuffer()).To(ContainSubstring("Which interface to counterfeit?")) + }) + + It("shows the user each interface found in the given filepath", func() { + Expect(fakeUIBuffer()).To(ContainSubstring("1. Foo")) + Expect(fakeUIBuffer()).To(ContainSubstring("2. Bar")) + }) + + It("asks its interface locator for valid interfaces", func() { + Expect(interfaceLocator.GetInterfacesFromFilePathCallCount()).To(Equal(1)) + Expect(interfaceLocator.GetInterfacesFromFilePathArgsForCall(0)).To(Equal("/home/test-user/workspace/some/path")) + }) + + It("yields the interface name the user chose", func() { + Expect(parsedArgs.InterfaceName).To(Equal("Foo")) + }) + + Describe("when the user types an invalid option", func() { + BeforeEach(func() { + ui.ReadLineFromStdinReturns("garbage") + }) + + It("invokes its fail handler", func() { + Expect(failWasCalled).To(BeTrue()) + }) + }) + }) + Describe("when two arguments are provided", func() { BeforeEach(func() { args = []string{"some/path", "MySpecialInterface"} @@ -86,10 +151,7 @@ var _ = Describe("parsing arguments", func() { }) Context("when the file stat cannot be read", func() { - var failWasCalled bool - BeforeEach(func() { - fail = func(_ string, _ ...interface{}) { failWasCalled = true } fileStatReader = func(_ string) (os.FileInfo, error) { return fakeFileInfo("", false), errors.New("submarine-shoutout") } diff --git a/fixtures/multiple_interfaces.go b/fixtures/multiple_interfaces.go new file mode 100644 index 00000000..87894963 --- /dev/null +++ b/fixtures/multiple_interfaces.go @@ -0,0 +1,11 @@ +package fixtures + +type FirstInterface interface { + DoThings() +} + +type SecondInterface interface { + EmbeddedMethod() string +} + +type unexportedInterface interface{} diff --git a/integration/integration_test.go b/integration/integration_test.go index 4c358150..c78e3d73 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -18,10 +18,28 @@ var _ = Describe("The counterfeiter CLI", func() { BeforeEach(func() { pathToCLI = tmpPath("counterfeiter") - copyIn("something.go", pathToCLI) + }) + + Describe("when given a single argument", func() { + It("interactively prompts the user for the interface they want to counterfeit", func() { + reader, writer := io.Pipe() + + copyIn("multiple_interfaces.go", pathToCLI) + session := startCounterfeiterWithStdinPipe(pathToCLI, reader, "multiple_interfaces.go") + + writer.Write([]byte("1\n")) + writer.Close() + + Eventually(session).Should(gexec.Exit(0)) + Expect(string(session.Out.Contents())).To(ContainSubstring("Wrote `FakeFirstInterface`")) + }) }) Describe("when given two arguments", func() { + BeforeEach(func() { + copyIn("something.go", pathToCLI) + }) + It("writes a fake for the given interface from the provided file", func() { session := startCounterfeiter(pathToCLI, "something.go", "Something") @@ -33,6 +51,10 @@ var _ = Describe("The counterfeiter CLI", func() { }) Describe("when provided three arguments", func() { + BeforeEach(func() { + copyIn("something.go", pathToCLI) + }) + It("writes the fake to stdout", func() { session := startCounterfeiter(pathToCLI, "something.go", "Something", "-") @@ -87,6 +109,21 @@ func startCounterfeiter(workingDir string, args ...string) *gexec.Session { return session } +func startCounterfeiterWithStdinPipe(workingDir string, stdin io.Reader, args ...string) *gexec.Session { + fakeGoPathDir := filepath.Dir(filepath.Dir(workingDir)) + absPath, _ := filepath.Abs(fakeGoPathDir) + absPathWithSymlinks, _ := filepath.EvalSymlinks(absPath) + + cmd := exec.Command(pathToCounterfeiter, args...) + cmd.Stdin = stdin + cmd.Dir = workingDir + cmd.Env = []string{"GOPATH=" + absPathWithSymlinks} + + session, err := gexec.Start(cmd, GinkgoWriter, GinkgoWriter) + Expect(err).ToNot(HaveOccurred()) + return session +} + // gexec setup var tmpDir string diff --git a/integration/suite_test.go b/integration/suite_test.go index a39e117d..4f64192f 100644 --- a/integration/suite_test.go +++ b/integration/suite_test.go @@ -7,7 +7,7 @@ import ( "testing" ) -func TestCounterfeiterCLI(t *testing.T) { +func TestCounterfeiterCLIIntegration(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "Counterfeiter CLI Suite") + RunSpecs(t, "Counterfeiter CLI Integration Suite") } diff --git a/locator/fakes/fake_interface_locator.go b/locator/fakes/fake_interface_locator.go new file mode 100644 index 00000000..9f7381fd --- /dev/null +++ b/locator/fakes/fake_interface_locator.go @@ -0,0 +1,53 @@ +// This file was generated by counterfeiter +package fakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/locator" +) + +type FakeInterfaceLocator struct { + GetInterfacesFromFilePathStub func(string) []string + getInterfacesFromFilePathMutex sync.RWMutex + getInterfacesFromFilePathArgsForCall []struct { + arg1 string + } + getInterfacesFromFilePathReturns struct { + result1 []string + } +} + +func (fake *FakeInterfaceLocator) GetInterfacesFromFilePath(arg1 string) []string { + fake.getInterfacesFromFilePathMutex.Lock() + fake.getInterfacesFromFilePathArgsForCall = append(fake.getInterfacesFromFilePathArgsForCall, struct { + arg1 string + }{arg1}) + fake.getInterfacesFromFilePathMutex.Unlock() + if fake.GetInterfacesFromFilePathStub != nil { + return fake.GetInterfacesFromFilePathStub(arg1) + } else { + return fake.getInterfacesFromFilePathReturns.result1 + } +} + +func (fake *FakeInterfaceLocator) GetInterfacesFromFilePathCallCount() int { + fake.getInterfacesFromFilePathMutex.RLock() + defer fake.getInterfacesFromFilePathMutex.RUnlock() + return len(fake.getInterfacesFromFilePathArgsForCall) +} + +func (fake *FakeInterfaceLocator) GetInterfacesFromFilePathArgsForCall(i int) string { + fake.getInterfacesFromFilePathMutex.RLock() + defer fake.getInterfacesFromFilePathMutex.RUnlock() + return fake.getInterfacesFromFilePathArgsForCall[i].arg1 +} + +func (fake *FakeInterfaceLocator) GetInterfacesFromFilePathReturns(result1 []string) { + fake.GetInterfacesFromFilePathStub = nil + fake.getInterfacesFromFilePathReturns = struct { + result1 []string + }{result1} +} + +var _ locator.InterfaceLocator = new(FakeInterfaceLocator) diff --git a/locator/locator.go b/locator/locator.go index 05e63688..7c9cdbc5 100644 --- a/locator/locator.go +++ b/locator/locator.go @@ -10,10 +10,67 @@ import ( "path/filepath" "runtime" "strings" + "unicode" "github.com/maxbrunsfeld/counterfeiter/model" ) +type InterfaceLocator interface { + GetInterfacesFromFilePath(string) []string +} + +func NewInterfaceLocator() InterfaceLocator { + return interfaceLocator{} +} + +type interfaceLocator struct{} + +func (locator interfaceLocator) GetInterfacesFromFilePath(path string) []string { + dir, err := getDir(path) + if err != nil { + panic(err) + } + + importPath, err := importPathForDirPath(dir) + if err != nil { + panic(err) + } + + dirPath, err := dirPathForImportPath(importPath) + if err != nil { + panic(err) + } + + packages, err := packagesForDirPath(dirPath) + if err != nil { + panic(err) + } + + interfacesInPackage := []string{} + for _, pkg := range packages { + + for _, f := range pkg.Files { + ast.Inspect(f, func(node ast.Node) bool { + if typeSpec, ok := node.(*ast.TypeSpec); ok { + if _, ok := typeSpec.Type.(*ast.InterfaceType); ok { + firstRune := rune(typeSpec.Name.Name[0]) + + if !unicode.IsUpper(firstRune) { + return true + } + + interfacesInPackage = append(interfacesInPackage, typeSpec.Name.Name) + } + } + + return true + }) + } + } + + return interfacesInPackage +} + func GetInterfaceFromFilePath(interfaceName, filePath string) (*model.InterfaceToFake, error) { dirPath, err := getDir(filePath) if err != nil { diff --git a/main.go b/main.go index 2fd61a59..f83fa1dc 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "github.com/maxbrunsfeld/counterfeiter/arguments" "github.com/maxbrunsfeld/counterfeiter/generator" "github.com/maxbrunsfeld/counterfeiter/locator" + "github.com/maxbrunsfeld/counterfeiter/terminal" ) var usage = ` @@ -43,7 +44,7 @@ func main() { flag.Parse() args := flag.Args() - if len(args) < 2 { + if len(args) < 1 { fail("%s", usage) return } @@ -53,6 +54,8 @@ func main() { cwd, filepath.EvalSymlinks, os.Stat, + terminal.NewUI(), + locator.NewInterfaceLocator(), ) parsedArgs := argumentParser.ParseArguments(args...) diff --git a/scripts/test.sh b/scripts/test.sh index 77b3cea5..33777e99 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -8,18 +8,19 @@ ln -fs $(pwd)/fixtures /tmp/symlinked_fixtures go build -o $counterfeiter -$counterfeiter fixtures Something -$counterfeiter fixtures HasVarArgs -$counterfeiter fixtures HasVarArgsWithLocalTypes -$counterfeiter fixtures HasImports -$counterfeiter fixtures HasOtherTypes -$counterfeiter fixtures ReusesArgTypes -$counterfeiter fixtures EmbedsInterfaces -$counterfeiter fixtures/aliased_package InAliasedPackage -$counterfeiter /tmp/symlinked_fixtures Something +$counterfeiter fixtures Something >/dev/null +$counterfeiter fixtures HasVarArgs >/dev/null +$counterfeiter fixtures HasVarArgsWithLocalTypes >/dev/null +$counterfeiter fixtures HasImports >/dev/null +$counterfeiter fixtures HasOtherTypes >/dev/null +$counterfeiter fixtures ReusesArgTypes >/dev/null +$counterfeiter fixtures EmbedsInterfaces >/dev/null +$counterfeiter fixtures/aliased_package InAliasedPackage >/dev/null +$counterfeiter /tmp/symlinked_fixtures Something >/dev/null + go build ./fixtures/... -go test -race -v . +go test -race -v ./... rm /tmp/symlinked_fixtures diff --git a/terminal/fakes/fake_ui.go b/terminal/fakes/fake_ui.go new file mode 100644 index 00000000..96586769 --- /dev/null +++ b/terminal/fakes/fake_ui.go @@ -0,0 +1,129 @@ +// This file was generated by counterfeiter +package fakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/terminal" +) + +type FakeUI struct { + TerminalIsTTYStub func() bool + terminalIsTTYMutex sync.RWMutex + terminalIsTTYArgsForCall []struct{} + terminalIsTTYReturns struct { + result1 bool + } + ReadLineFromStdinStub func() string + readLineFromStdinMutex sync.RWMutex + readLineFromStdinArgsForCall []struct{} + readLineFromStdinReturns struct { + result1 string + } + WriteStub func(string) + writeMutex sync.RWMutex + writeArgsForCall []struct { + arg1 string + } + WriteLineStub func(string) + writeLineMutex sync.RWMutex + writeLineArgsForCall []struct { + arg1 string + } +} + +func (fake *FakeUI) TerminalIsTTY() bool { + fake.terminalIsTTYMutex.Lock() + fake.terminalIsTTYArgsForCall = append(fake.terminalIsTTYArgsForCall, struct{}{}) + fake.terminalIsTTYMutex.Unlock() + if fake.TerminalIsTTYStub != nil { + return fake.TerminalIsTTYStub() + } else { + return fake.terminalIsTTYReturns.result1 + } +} + +func (fake *FakeUI) TerminalIsTTYCallCount() int { + fake.terminalIsTTYMutex.RLock() + defer fake.terminalIsTTYMutex.RUnlock() + return len(fake.terminalIsTTYArgsForCall) +} + +func (fake *FakeUI) TerminalIsTTYReturns(result1 bool) { + fake.TerminalIsTTYStub = nil + fake.terminalIsTTYReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeUI) ReadLineFromStdin() string { + fake.readLineFromStdinMutex.Lock() + fake.readLineFromStdinArgsForCall = append(fake.readLineFromStdinArgsForCall, struct{}{}) + fake.readLineFromStdinMutex.Unlock() + if fake.ReadLineFromStdinStub != nil { + return fake.ReadLineFromStdinStub() + } else { + return fake.readLineFromStdinReturns.result1 + } +} + +func (fake *FakeUI) ReadLineFromStdinCallCount() int { + fake.readLineFromStdinMutex.RLock() + defer fake.readLineFromStdinMutex.RUnlock() + return len(fake.readLineFromStdinArgsForCall) +} + +func (fake *FakeUI) ReadLineFromStdinReturns(result1 string) { + fake.ReadLineFromStdinStub = nil + fake.readLineFromStdinReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeUI) Write(arg1 string) { + fake.writeMutex.Lock() + fake.writeArgsForCall = append(fake.writeArgsForCall, struct { + arg1 string + }{arg1}) + fake.writeMutex.Unlock() + if fake.WriteStub != nil { + fake.WriteStub(arg1) + } +} + +func (fake *FakeUI) WriteCallCount() int { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return len(fake.writeArgsForCall) +} + +func (fake *FakeUI) WriteArgsForCall(i int) string { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return fake.writeArgsForCall[i].arg1 +} + +func (fake *FakeUI) WriteLine(arg1 string) { + fake.writeLineMutex.Lock() + fake.writeLineArgsForCall = append(fake.writeLineArgsForCall, struct { + arg1 string + }{arg1}) + fake.writeLineMutex.Unlock() + if fake.WriteLineStub != nil { + fake.WriteLineStub(arg1) + } +} + +func (fake *FakeUI) WriteLineCallCount() int { + fake.writeLineMutex.RLock() + defer fake.writeLineMutex.RUnlock() + return len(fake.writeLineArgsForCall) +} + +func (fake *FakeUI) WriteLineArgsForCall(i int) string { + fake.writeLineMutex.RLock() + defer fake.writeLineMutex.RUnlock() + return fake.writeLineArgsForCall[i].arg1 +} + +var _ terminal.UI = new(FakeUI) diff --git a/terminal/ui.go b/terminal/ui.go new file mode 100644 index 00000000..edec9de4 --- /dev/null +++ b/terminal/ui.go @@ -0,0 +1,48 @@ +package terminal + +import ( + "bufio" + "os" + + "golang.org/x/crypto/ssh/terminal" +) + +type UI interface { + TerminalIsTTY() bool + ReadLineFromStdin() string + + Write(string) + WriteLine(string) +} + +func NewUI() UI { + return &ui{} +} + +type ui struct{} + +func (ui *ui) TerminalIsTTY() bool { + return terminal.IsTerminal(int(os.Stdin.Fd())) +} + +func (ui *ui) ReadLineFromStdin() string { + bio := bufio.NewReader(os.Stdin) + bytes, hasMoreInLine, _ := bio.ReadLine() + line := string(bytes) + + var continuation []byte + for hasMoreInLine { + continuation, hasMoreInLine, _ = bio.ReadLine() + line = line + string(continuation) + } + + return line +} + +func (ui *ui) WriteLine(line string) { + println(line) +} + +func (ui *ui) Write(output string) { + print(output) +}