diff --git a/arguments/files.go b/arguments/files.go new file mode 100644 index 0000000..5566837 --- /dev/null +++ b/arguments/files.go @@ -0,0 +1,7 @@ +package arguments + +import "os" + +type CurrentWorkingDir func() string +type SymlinkEvaler func(string) (string, error) +type FileStatReader func(string) (os.FileInfo, error) diff --git a/arguments/flags.go b/arguments/flags.go new file mode 100644 index 0000000..1297f06 --- /dev/null +++ b/arguments/flags.go @@ -0,0 +1,17 @@ +package arguments + +import "flag" + +var ( + fakeNameFlag = flag.String( + "fake-name", + "", + "The name of the fake struct", + ) + + outputPathFlag = flag.String( + "o", + "", + "The file or directory to which the generated fake will be written", + ) +) diff --git a/arguments/parser.go b/arguments/parser.go new file mode 100644 index 0000000..25ba127 --- /dev/null +++ b/arguments/parser.go @@ -0,0 +1,100 @@ +package arguments + +import ( + "path/filepath" + "regexp" + "strings" +) + +type ArgumentParser interface { + ParseArguments(...string) ParsedArguments +} + +func NewArgumentParser( + failHandler FailHandler, + currentWorkingDir CurrentWorkingDir, + symlinkEvaler SymlinkEvaler, + fileStatReader FileStatReader, +) ArgumentParser { + return argumentParser{ + failHandler: failHandler, + currentWorkingDir: currentWorkingDir, + symlinkEvaler: symlinkEvaler, + fileStatReader: fileStatReader, + } +} + +func (argParser argumentParser) ParseArguments(args ...string) ParsedArguments { + sourcePackageDir := argParser.getSourceDir(args[0]) + fakeImplName := getFakeName(args[1], *fakeNameFlag) + outputPath := argParser.getOutputPath(sourcePackageDir, fakeImplName, *outputPathFlag) + + return ParsedArguments{ + SourcePackageDir: sourcePackageDir, + OutputPath: outputPath, + + InterfaceName: args[1], + FakeImplName: fakeImplName, + + PrintToStdOut: len(args) == 3 && args[2] == "-", + } +} + +type argumentParser struct { + failHandler FailHandler + currentWorkingDir CurrentWorkingDir + symlinkEvaler SymlinkEvaler + fileStatReader FileStatReader +} + +type ParsedArguments struct { + SourcePackageDir string // abs path to the dir containing the interface to fake + OutputPath string // path to write the fake file to + + InterfaceName string // the interface to counterfeit + FakeImplName string // the name of the struct implementing the given interface + + PrintToStdOut bool +} + +func getFakeName(interfaceName, arg string) string { + if arg == "" { + return "Fake" + interfaceName + } else { + return arg + } +} + +var camelRegexp = regexp.MustCompile("([a-z])([A-Z])") + +func (argParser argumentParser) getOutputPath(sourceDir, fakeName, arg string) string { + if arg == "" { + snakeCaseName := strings.ToLower(camelRegexp.ReplaceAllString(fakeName, "${1}_${2}")) + return filepath.Join(sourceDir, "fakes", snakeCaseName+".go") + } else { + if !filepath.IsAbs(arg) { + arg = filepath.Join(argParser.currentWorkingDir(), arg) + } + return arg + } +} + +func (argParser argumentParser) getSourceDir(arg string) string { + if !filepath.IsAbs(arg) { + arg = filepath.Join(argParser.currentWorkingDir(), arg) + } + + arg, _ = argParser.symlinkEvaler(arg) + stat, err := argParser.fileStatReader(arg) + if err != nil { + argParser.failHandler("No such file or directory '%s'", arg) + } + + if !stat.IsDir() { + return filepath.Dir(arg) + } else { + return arg + } +} + +type FailHandler func(string, ...interface{}) diff --git a/arguments/parser_test.go b/arguments/parser_test.go new file mode 100644 index 0000000..db0f0e0 --- /dev/null +++ b/arguments/parser_test.go @@ -0,0 +1,200 @@ +package arguments_test + +import ( + "errors" + "os" + "path/filepath" + "testing" + "time" + + . "github.com/maxbrunsfeld/counterfeiter/arguments" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("parsing arguments", func() { + var subject ArgumentParser + var parsedArgs ParsedArguments + var args []string + + var fail FailHandler + var cwd CurrentWorkingDir + var symlinkEvaler SymlinkEvaler + var fileStatReader FileStatReader + + JustBeforeEach(func() { + subject = NewArgumentParser(fail, cwd, symlinkEvaler, fileStatReader) + parsedArgs = subject.ParseArguments(args...) + }) + + BeforeEach(func() { + fail = func(_ string, _ ...interface{}) {} + cwd = func() string { + return "/home/test-user/workspace" + } + + symlinkEvaler = func(input string) (string, error) { + return input, nil + } + fileStatReader = func(filename string) (os.FileInfo, error) { + return fakeFileInfo(filename, true), nil + } + }) + + Describe("when two arguments are provided", func() { + BeforeEach(func() { + args = []string{"some/path", "MySpecialInterface"} + }) + + It("indicates to not print to stdout", func() { + Expect(parsedArgs.PrintToStdOut).To(BeFalse()) + }) + + It("provides a name for the fake implementing the interface", func() { + Expect(parsedArgs.FakeImplName).To(Equal("FakeMySpecialInterface")) + }) + + It("treats the second argument as the interface to counterfeit", func() { + Expect(parsedArgs.InterfaceName).To(Equal("MySpecialInterface")) + }) + + It("snake cases the filename for the output directory", func() { + Expect(parsedArgs.OutputPath).To(Equal( + filepath.Join( + parsedArgs.SourcePackageDir, + "fakes", + "fake_my_special_interface.go", + ), + )) + }) + + Describe("the source directory", func() { + It("should be an absolute path", func() { + Expect(filepath.IsAbs(parsedArgs.SourcePackageDir)).To(BeTrue()) + }) + + Context("when the first arg is a path to a file", func() { + BeforeEach(func() { + fileStatReader = func(filename string) (os.FileInfo, error) { + return fakeFileInfo(filename, false), nil + } + }) + + It("should be the directory containing the file", func() { + Expect(parsedArgs.SourcePackageDir).ToNot(ContainSubstring("something.go")) + }) + }) + + Context("when the file stat cannot be read", func() { + var failWasCalled bool + + BeforeEach(func() { + fail = func(_ string, _ ...interface{}) { failWasCalled = true } + fileStatReader = func(_ string) (os.FileInfo, error) { + return fakeFileInfo("", false), errors.New("submarine-shoutout") + } + }) + + It("should call its fail handler", func() { + Expect(failWasCalled).To(BeTrue()) + }) + }) + }) + }) + + Describe("when three arguments are provided", func() { + Context("and the third one is '-'", func() { + BeforeEach(func() { + args = []string{"some/path", "MySpecialInterface", "-"} + }) + + It("treats the second argument as the interface to counterfeit", func() { + Expect(parsedArgs.InterfaceName).To(Equal("MySpecialInterface")) + }) + + It("provides a name for the fake implementing the interface", func() { + Expect(parsedArgs.FakeImplName).To(Equal("FakeMySpecialInterface")) + }) + + It("indicates that the fake should be printed to stdout", func() { + Expect(parsedArgs.PrintToStdOut).To(BeTrue()) + }) + + It("snake cases the filename for the output directory", func() { + Expect(parsedArgs.OutputPath).To(Equal( + filepath.Join( + parsedArgs.SourcePackageDir, + "fakes", + "fake_my_special_interface.go", + ), + )) + }) + + Describe("the source directory", func() { + It("should be an absolute path", func() { + Expect(filepath.IsAbs(parsedArgs.SourcePackageDir)).To(BeTrue()) + }) + + Context("when the first arg is a path to a file", func() { + BeforeEach(func() { + fileStatReader = func(filename string) (os.FileInfo, error) { + return fakeFileInfo(filename, false), nil + } + }) + + It("should be the directory containing the file", func() { + Expect(parsedArgs.SourcePackageDir).ToNot(ContainSubstring("something.go")) + }) + }) + }) + }) + + Context("and the third one is some random input", func() { + BeforeEach(func() { + args = []string{"some/path", "MySpecialInterface", "WHOOPS"} + }) + + It("indicates to not print to stdout", func() { + Expect(parsedArgs.PrintToStdOut).To(BeFalse()) + }) + }) + }) +}) + +func TestCounterfeiterCLI(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Argument Parser Suite") +} + +func fakeFileInfo(filename string, isDir bool) os.FileInfo { + return testFileInfo{name: filename, isDir: isDir} +} + +type testFileInfo struct { + name string + isDir bool +} + +func (testFileInfo testFileInfo) Name() string { + return testFileInfo.name +} + +func (testFileInfo testFileInfo) IsDir() bool { + return testFileInfo.isDir +} + +func (testFileInfo testFileInfo) Size() int64 { + return 0 +} + +func (testFileInfo testFileInfo) Mode() os.FileMode { + return 0 +} + +func (testFileInfo testFileInfo) ModTime() time.Time { + return time.Now() +} + +func (testFileInfo testFileInfo) Sys() interface{} { + return nil +} diff --git a/main.go b/main.go index 0fca180..2fd61a5 100644 --- a/main.go +++ b/main.go @@ -5,9 +5,8 @@ import ( "fmt" "os" "path/filepath" - "regexp" - "strings" + "github.com/maxbrunsfeld/counterfeiter/arguments" "github.com/maxbrunsfeld/counterfeiter/generator" "github.com/maxbrunsfeld/counterfeiter/locator" ) @@ -40,33 +39,30 @@ OPTIONS be prepended to the name of the original interface. ` -var outputPathFlag = flag.String( - "o", - "", - "The file or directory to which the generated fake will be written", -) - -var fakeNameFlag = flag.String( - "fake-name", - "", - "The name of the fake struct", -) - func main() { flag.Parse() args := flag.Args() if len(args) < 2 { fail("%s", usage) + return } - sourceDir := getSourceDir(args[0]) - interfaceName := args[1] - fakeName := getFakeName(interfaceName, *fakeNameFlag) - outputPath := getOutputPath(sourceDir, fakeName, *outputPathFlag) + argumentParser := arguments.NewArgumentParser( + fail, + cwd, + filepath.EvalSymlinks, + os.Stat, + ) + parsedArgs := argumentParser.ParseArguments(args...) + + interfaceName := parsedArgs.InterfaceName + fakeName := parsedArgs.FakeImplName + sourceDir := parsedArgs.SourcePackageDir + outputPath := parsedArgs.OutputPath + outputDir := filepath.Dir(outputPath) fakePackageName := filepath.Base(outputDir) - shouldPrintToStdout := len(args) >= 3 && args[2] == "-" iface, err := locator.GetInterfaceFromFilePath(interfaceName, sourceDir) if err != nil { @@ -83,7 +79,7 @@ func main() { fail("%v", err) } - if shouldPrintToStdout { + if parsedArgs.PrintToStdOut { fmt.Println(code) } else { os.MkdirAll(outputDir, 0777) @@ -106,49 +102,6 @@ func main() { } } -func getSourceDir(arg string) string { - if !filepath.IsAbs(arg) { - arg = filepath.Join(cwd(), arg) - } - - arg, err := filepath.EvalSymlinks(arg) - - stat, err := os.Stat(arg) - if err != nil { - fail("No such file or directory '%s'", arg) - } - - if !stat.IsDir() { - return filepath.Dir(arg) - } else { - return arg - } -} - -func getOutputPath(sourceDir, fakeName, arg string) string { - if arg == "" { - return filepath.Join(sourceDir, "fakes", snakeCase(fakeName)+".go") - } else { - if !filepath.IsAbs(arg) { - arg = filepath.Join(cwd(), arg) - } - return arg - } -} - -func getFakeName(interfaceName, arg string) string { - if arg == "" { - return "Fake" + interfaceName - } else { - return arg - } -} - -func snakeCase(input string) string { - camelRegexp := regexp.MustCompile("([a-z])([A-Z])") - return strings.ToLower(camelRegexp.ReplaceAllString(input, "${1}_${2}")) -} - func cwd() string { dir, err := os.Getwd() if err != nil {