diff --git a/context/context_test.go b/context/context_test.go index 8f11aa2..943b984 100644 --- a/context/context_test.go +++ b/context/context_test.go @@ -431,6 +431,51 @@ func TestNoDep(t *testing.T) { verifyChecksum(g, c, "new") } +func TestMainTest(t *testing.T) { + g := gt.New(t) + defer g.Clean() + + // This test relies on the file list being sorted. Normally not required. + testNeedsSortOrder = true + + g.Setup("co1/pk1", + gt.File("a.go", "strings"), + ) + // Include some non-go files that sort before and after the go file. + // Triggers an edge case where the package wasn't added. + g.Setup("co2/pk1", + gt.FilePkgBuild("a.go", "main_test", "", "testing"), + gt.FilePkgBuild("b.go", "main", "", "fmt"), + ) + g.In("co1") + c := ctx(g) + list(g, c, "before", ` + l co1/pk1 < [] + s strings < ["co1/pk1"] +`) + g.Check(c.ModifyImport(pkg("co2/pk1"), AddUpdate)) + g.Check(c.Alter()) + g.Check(c.WriteVendorFile()) + + list(g, c, "after", ` +pv co1/vendor/co2/pk1 [co2/pk1] < [] + l co1/pk1 < [] + s fmt < ["co1/vendor/co2/pk1"] + s strings < ["co1/pk1"] + s testing < ["co1/vendor/co2/pk1"] +`) + + c.IgnoreBuild("test") + + list(g, c, "ignore test", ` +pv co1/vendor/co2/pk1 [co2/pk1] < [] + l co1/pk1 < [] + s fmt < ["co1/vendor/co2/pk1"] + s strings < ["co1/pk1"] +`) + +} + func TestUpdate(t *testing.T) { g := gt.New(t) defer g.Clean() diff --git a/context/resolve.go b/context/resolve.go index f484a61..8b2ebd9 100644 --- a/context/resolve.go +++ b/context/resolve.go @@ -72,15 +72,20 @@ func (ctx *Context) getFileTags(pathname string, f *ast.File) (tags, imports []s return nil, nil, nil } } + tags = make([]string, 0, 6) + if strings.HasSuffix(f.Name.Name, "_test") { + tags = append(tags, "test") + } + pkgNameNormalized := strings.TrimSuffix(f.Name.Name, "_test") + // Files with package name "documentation" should be ignored, per go build tool. - if f.Name.Name == "documentation" { + if pkgNameNormalized == "documentation" { return nil, nil, nil } filename := filenameExt[:len(filenameExt)-3] l := strings.Split(filename, "_") - tags = make([]string, 0, 6) if n := len(l); n > 1 && l[n-1] == "test" { l = l[:n-1] @@ -156,9 +161,10 @@ func (ctx *Context) addFileImports(pathname, gopath string) (*Package, error) { if f == nil { return nil, nil } + pkgNameNormalized := strings.TrimSuffix(f.Name.Name, "_test") // Files with package name "documentation" should be ignored, per go build tool. - if f.Name.Name == "documentation" { + if pkgNameNormalized == "documentation" { return nil, nil } @@ -185,7 +191,7 @@ func (ctx *Context) addFileImports(pathname, gopath string) (*Package, error) { Location: LocationUnknown, Presence: PresenceFound, } - if f.Name.Name == "main" { + if pkgNameNormalized == "main" { status.Type = TypeProgram } pkg = ctx.setPackage(dir, importPath, importPath, gopath, status) diff --git a/context/rewrite.go b/context/rewrite.go index d58879a..3ace9f2 100644 --- a/context/rewrite.go +++ b/context/rewrite.go @@ -111,8 +111,9 @@ func (ctx *Context) rewrite() error { if f == nil { return nil } + pkgNameNormalized := strings.TrimSuffix(f.Name.Name, "_test") // Files with package name "documentation" should be ignored, per go build tool. - if f.Name.Name == "documentation" { + if pkgNameNormalized == "documentation" { return nil } diff --git a/internal/gt/gopath.go b/internal/gt/gopath.go index 736168a..f0cf92b 100644 --- a/internal/gt/gopath.go +++ b/internal/gt/gopath.go @@ -116,23 +116,20 @@ func (g *GopathTest) Check(err error) { type FileSpec struct { Pkg string + PkgName string Name string Imports []string Build string } var fileSpecFile = template.Must(template.New("").Funcs(map[string]interface{}{ - "pkg": func(s string) string { - _, pkg := path.Split(s) - return pkg - }, "imp": func(s string) string { return "`" + s + "`" }, }).Parse(` {{if .Build}} // +build {{.Build}} {{end}} -package {{.Pkg|pkg}} +package {{.PkgName}} import ( {{range .Imports}} {{.|imp}} @@ -151,12 +148,18 @@ func File(name string, imports ...string) FileSpec { func FileBuild(name, build string, imports ...string) FileSpec { return FileSpec{Name: name, Build: build, Imports: imports} } +func FilePkgBuild(name, pkgName, build string, imports ...string) FileSpec { + return FileSpec{Name: name, PkgName: pkgName, Build: build, Imports: imports} +} func (g *GopathTest) Setup(at string, files ...FileSpec) { var err error pkg := g.mksrc(at) for _, f := range files { f.Pkg = at + if len(f.PkgName) == 0 { + _, f.PkgName = path.Split(f.Pkg) + } p := filepath.Join(pkg, f.Name) err = ioutil.WriteFile(p, f.Bytes(), 0600) if err != nil {