diff --git a/BUILD.bazel b/BUILD.bazel index 5467e221b..e281ea890 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -9,8 +9,10 @@ load("//:def.bzl", "gazelle", "gazelle_binary") # gazelle:exclude .bazelci # gazelle:exclude .bcr # gazelle:exclude .idea +# gazelle:exclude .ijwb # gazelle:exclude .github # gazelle:exclude .vscode +# gazelle:exclude internal/module/testdata # gazelle:go_naming_convention import_alias gazelle( name = "gazelle", diff --git a/internal/module/BUILD.bazel b/internal/module/BUILD.bazel index f5ae5e583..0adffa699 100644 --- a/internal/module/BUILD.bazel +++ b/internal/module/BUILD.bazel @@ -1,11 +1,14 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "module", srcs = ["module.go"], importpath = "github.com/bazelbuild/bazel-gazelle/internal/module", visibility = ["//:__subpackages__"], - deps = ["@com_github_bazelbuild_buildtools//build"], + deps = [ + "//label", + "@com_github_bazelbuild_buildtools//build", + ], ) filegroup( @@ -14,6 +17,7 @@ filegroup( srcs = [ "BUILD.bazel", "module.go", + "module_test.go", ], visibility = ["//visibility:public"], ) @@ -23,3 +27,17 @@ alias( actual = ":module", visibility = ["//:__subpackages__"], ) + +go_test( + name = "module_test", + srcs = ["module_test.go"], + data = glob( + ["testdata/**"], + allow_empty = True, + ), + embed = [":module"], + deps = [ + "@com_github_google_go_cmp//cmp", + "@io_bazel_rules_go//go/runfiles:go_default_library", + ], +) diff --git a/internal/module/module.go b/internal/module/module.go index c4411e0cd..04e2e612a 100644 --- a/internal/module/module.go +++ b/internal/module/module.go @@ -18,9 +18,11 @@ limitations under the License. package module import ( + "fmt" "os" "path/filepath" + "github.com/bazelbuild/bazel-gazelle/label" "github.com/bazelbuild/buildtools/build" ) @@ -29,30 +31,20 @@ import ( // See https://bazel.build/external/module#repository_names_and_strict_deps for more information on // apparent names. func ExtractModuleToApparentNameMapping(repoRoot string) (func(string) string, error) { - moduleFile, err := parseModuleFile(repoRoot) + moduleToApparentName, err := collectApparentNames(repoRoot, "MODULE.bazel") if err != nil { return nil, err } - var moduleToApparentName map[string]string - if moduleFile != nil { - moduleToApparentName = collectApparentNames(moduleFile) - } else { - // If there is no MODULE.bazel file, return a function that always returns the empty string. - // Languages will know to fall back to the WORKSPACE names of repos. - moduleToApparentName = make(map[string]string) - } return func(moduleName string) string { return moduleToApparentName[moduleName] }, nil } -func parseModuleFile(repoRoot string) (*build.File, error) { - path := filepath.Join(repoRoot, "MODULE.bazel") +func parseModuleSegment(repoRoot, relPath string) (*build.File, error) { + path := filepath.Join(repoRoot, relPath) bytes, err := os.ReadFile(path) - if os.IsNotExist(err) { - return nil, nil - } else if err != nil { + if err != nil { return nil, err } return build.ParseModule(path, bytes) @@ -61,11 +53,59 @@ func parseModuleFile(repoRoot string) (*build.File, error) { // Collects the mapping of module names (e.g. "rules_go") to user-configured apparent names (e.g. // "my_rules_go"). See https://bazel.build/external/module#repository_names_and_strict_deps for more // information on apparent names. -func collectApparentNames(m *build.File) map[string]string { +func collectApparentNames(repoRoot, relPath string) (map[string]string, error) { + apparentNames := make(map[string]string) + seenFiles := make(map[string]struct{}) + filesToProcess := []string{relPath} + + for len(filesToProcess) > 0 { + f := filesToProcess[0] + filesToProcess = filesToProcess[1:] + if _, seen := seenFiles[f]; seen { + continue + } + seenFiles[f] = struct{}{} + bf, err := parseModuleSegment(repoRoot, f) + if err != nil { + if f == "MODULE.bazel" && os.IsNotExist(err) { + // If there is no MODULE.bazel file, return an empty map but no error. + // Languages will know to fall back to the WORKSPACE names of repos. + return nil, nil + } + return nil, err + } + names, includeLabels := collectApparentNamesAndIncludes(bf) + for name, apparentName := range names { + apparentNames[name] = apparentName + } + for _, includeLabel := range includeLabels { + l, err := label.Parse(includeLabel) + if err != nil { + return nil, fmt.Errorf("failed to parse include label %q: %v", includeLabel, err) + } + p := filepath.Join(filepath.FromSlash(l.Pkg), filepath.FromSlash(l.Name)) + filesToProcess = append(filesToProcess, p) + } + } + + return apparentNames, nil +} + +func collectApparentNamesAndIncludes(f *build.File) (map[string]string, []string) { apparentNames := make(map[string]string) + var includeLabels []string - for _, dep := range m.Rules("") { - if dep.Name() == "" { + for _, dep := range f.Rules("") { + if dep.ExplicitName() == "" { + if ident, ok := dep.Call.X.(*build.Ident); !ok || ident.Name != "include" { + continue + } + if len(dep.Call.List) != 1 { + continue + } + if str, ok := dep.Call.List[0].(*build.StringExpr); ok { + includeLabels = append(includeLabels, str.Value) + } continue } if dep.Kind() != "module" && dep.Kind() != "bazel_dep" { @@ -82,5 +122,5 @@ func collectApparentNames(m *build.File) map[string]string { } } - return apparentNames + return apparentNames, includeLabels } diff --git a/internal/module/module_test.go b/internal/module/module_test.go new file mode 100644 index 000000000..a9cee370b --- /dev/null +++ b/internal/module/module_test.go @@ -0,0 +1,44 @@ +package module + +import ( + "path/filepath" + "testing" + + "github.com/bazelbuild/rules_go/go/runfiles" + "github.com/google/go-cmp/cmp" +) + +func TestCollectApparent(t *testing.T) { + moduleFile, err := runfiles.Rlocation("bazel_gazelle/internal/module/testdata/MODULE.bazel") + if err != nil { + t.Fatal(err) + } + + apparentNames, err := collectApparentNames(filepath.Dir(moduleFile), "MODULE.bazel") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := map[string]string{ + "rules_bar": "rules_bar", + "rules_baz": "rules_baz", + "rules_foo": "my_rules_foo", + "rules_lang": "my_rules_lang", + "rules_quz": "rules_quz", + "test_module": "my_test_module", + } + if diff := cmp.Diff(expected, apparentNames); diff != "" { + t.Errorf("unexpected apparent names (-want +got):\n%s", diff) + } +} + +func TestCollectApparent_fileDoesNotExist(t *testing.T) { + _, err := collectApparentNames(t.TempDir(), "MODULE.bazel") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = collectApparentNames(t.TempDir(), "segment.MODULE.bazel") + if err == nil { + t.Fatalf("expected error, got nil") + } +} diff --git a/internal/module/testdata/MODULE.bazel b/internal/module/testdata/MODULE.bazel new file mode 100644 index 000000000..fd5273061 --- /dev/null +++ b/internal/module/testdata/MODULE.bazel @@ -0,0 +1,10 @@ +module( + name = "test_module", + repo_name = "my_test_module", +) + +bazel_dep(name = "rules_bar", version = "1.2.3") + +include("//bazel:lang.MODULE.bazel") + +bazel_dep(name = "rules_foo", version = "1.2.3", repo_name = "my_rules_foo") diff --git a/internal/module/testdata/bazel/dir/deps.MODULE.bazel b/internal/module/testdata/bazel/dir/deps.MODULE.bazel new file mode 100644 index 000000000..ccafd9582 --- /dev/null +++ b/internal/module/testdata/bazel/dir/deps.MODULE.bazel @@ -0,0 +1 @@ +bazel_dep(name = "rules_quz", version = "0.0.1") diff --git a/internal/module/testdata/bazel/lang.MODULE.bazel b/internal/module/testdata/bazel/lang.MODULE.bazel new file mode 100644 index 000000000..6d8489e4d --- /dev/null +++ b/internal/module/testdata/bazel/lang.MODULE.bazel @@ -0,0 +1,4 @@ +bazel_dep(name = "rules_lang", version = "1.2.3", repo_name = "my_rules_lang") + +include("//:deps.MODULE.bazel") +include("//bazel:dir/deps.MODULE.bazel") diff --git a/internal/module/testdata/deps.MODULE.bazel b/internal/module/testdata/deps.MODULE.bazel new file mode 100644 index 000000000..3c23d33e4 --- /dev/null +++ b/internal/module/testdata/deps.MODULE.bazel @@ -0,0 +1 @@ +bazel_dep(name = "rules_baz", version = "0.0.1")