Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow setting wrapper targets based on annotations #79

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions java/gazelle/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ go_library(
"//java/gazelle/private/javaparser",
"//java/gazelle/private/logconfig",
"//java/gazelle/private/maven",
"//java/gazelle/private/sorted_multiset",
"//java/gazelle/private/sorted_set",
"//java/gazelle/private/types",
"@bazel_gazelle//config:go_default_library",
Expand Down
50 changes: 50 additions & 0 deletions java/gazelle/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ import (
type Configurer struct {
lang *javaLang
annotationToAttribute annotationToAttribute
annotationToWrapper annotationToWrapper
mavenInstallFile string
}

func NewConfigurer(lang *javaLang) *Configurer {
return &Configurer{
lang: lang,
annotationToAttribute: make(annotationToAttribute),
annotationToWrapper: make(annotationToWrapper),
}
}

func (jc *Configurer) RegisterFlags(fs *flag.FlagSet, cmd string, c *config.Config) {
fs.Var(&jc.annotationToAttribute, "java-annotation-to-attribute", "Mapping of annotations (on test classes) to attributes which should be set for that test rule. Examples: com.example.annotations.FlakyTest=flaky=True com.example.annotations.SlowTest=timeout=\"long\"")
fs.Var(&jc.annotationToWrapper, "java-annotation-to-wrapper", "Mapping of annotations (on test classes) to wrapper rules which should be used around the test rule. Example: com.example.annotations.RequiresNetwork=@some//wrapper:file.bzl=requires_network")
fs.StringVar(&jc.mavenInstallFile, "java-maven-install-file", "", "Path of the maven_install.json file. Defaults to \"maven_install.json\".")
}

Expand All @@ -42,6 +45,9 @@ func (jc *Configurer) CheckFlags(fs *flag.FlagSet, c *config.Config) error {
cfgs[""].MapAnnotationToAttribute(annotation, k, v)
}
}
for annotation, wrapper := range jc.annotationToWrapper {
cfgs[""].MapAnnotationToWrapper(annotation, wrapper.symbol)
}
if jc.mavenInstallFile != "" {
cfgs[""].SetMavenInstallFile(jc.mavenInstallFile)
}
Expand Down Expand Up @@ -192,3 +198,47 @@ func (f *annotationToAttribute) Set(value string) error {
(*f)[annotationClassName][key] = parsedValue
return nil
}

type loadInfo struct {
from string
symbol string
}

type annotationToWrapper map[string]loadInfo

func (f *annotationToWrapper) String() string {
s := "annotationToWrapper{"
for a, li := range *f {
s += a + ": "
s += fmt.Sprintf(`load("%s", "%s")`, li.from, li.symbol)
}
s += "}"
return s
}

func (f *annotationToWrapper) Set(value string) error {
parts := strings.Split(value, "=")
if len(parts) != 2 {
return fmt.Errorf("want --java-annotation-to-wrapper to have format com.example.RequiresNetwork=@some_repo//has:wrapper.bzl,wrapper_rule but didn't see exactly one equals sign")
}
annotation := parts[0]

if _, ok := (*f)[annotation]; ok {
return fmt.Errorf("saw conflicting values for --java-annotation-to-wrapper flag for annotation %v", annotation)
}

vParts := strings.Split(parts[1], ",")
if len(vParts) != 2 {
return fmt.Errorf("want --java-annotation-to-wrapper to have format com.example.RequiresNetwork=@some_repo//has:wrapper.bzl,wrapper_rule but didn't see exactly one comma after equals sign")
}

from := vParts[0]
symbol := vParts[1]

(*f)[annotation] = loadInfo{
from: from,
symbol: symbol,
}

return nil
}
48 changes: 37 additions & 11 deletions java/gazelle/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ func javaFileLess(l, r javaFile) bool {
return l.pathRelativeToBazelWorkspaceRoot < r.pathRelativeToBazelWorkspaceRoot
}

type separateJavaTestReasons struct {
attributes map[string]bzl.Expr
wrapper string
}

// GenerateRules extracts build metadata from source files in a directory.
//
// See language.GenerateRules for more information.
Expand Down Expand Up @@ -114,7 +119,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
testJavaImports := sorted_set.NewSortedSetFn([]types.PackageName{}, types.PackageNameLess)

// Java Test files which need to be generated separately from any others because they have explicit attribute overrides.
separateTestJavaFiles := make(map[javaFile]map[string]bzl.Expr)
separateTestJavaFiles := make(map[javaFile]separateJavaTestReasons)

// Files which are used by non-test classes in test java packages.
testHelperJavaFiles := sorted_set.NewSortedSetFn([]javaFile{}, javaFileLess)
Expand Down Expand Up @@ -234,8 +239,8 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
switch cfg.TestMode() {
case "file":
for _, tf := range testJavaFiles.SortedSlice() {
extraAttributes := separateTestJavaFiles[tf]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, nil, extraAttributes, &res)
separateJavaTestReasons := separateTestJavaFiles[tf]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, nil, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
}

case "suite":
Expand Down Expand Up @@ -278,7 +283,8 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
if testHelperJavaFiles.Len() > 0 {
testHelperDep = ptr(testHelperLibname(suiteName))
}
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, testHelperDep, separateTestJavaFiles[src], &res)
separateJavaTestReasons := separateTestJavaFiles[src]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, testHelperDep, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
}
}
}
Expand Down Expand Up @@ -407,10 +413,11 @@ func addFilteringOutOwnPackage(to *sorted_set.SortedSet[types.PackageName], from
}
}

func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFiles *sorted_set.SortedSet[javaFile], separateTestJavaFiles map[javaFile]map[string]bzl.Expr, file javaFile, perClassMetadata map[string]java.PerClassMetadata, log zerolog.Logger) {
func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFiles *sorted_set.SortedSet[javaFile], separateTestJavaFiles map[javaFile]separateJavaTestReasons, file javaFile, perClassMetadata map[string]java.PerClassMetadata, log zerolog.Logger) {
if cfg.IsJavaTestFile(filepath.Base(file.pathRelativeToBazelWorkspaceRoot)) {
annotationClassNames := perClassMetadata[file.ClassName().FullyQualifiedClassName()].AnnotationClassNames
perFileAttrs := make(map[string]bzl.Expr)
wrapper := ""
for _, annotationClassName := range annotationClassNames.SortedSlice() {
if attrs, ok := cfg.AttributesForAnnotation(annotationClassName); ok {
for k, v := range attrs {
Expand All @@ -420,10 +427,20 @@ func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFil
perFileAttrs[k] = v
}
}
newWrapper, ok := cfg.WrapperForAnnotation(annotationClassName)
if ok {
if wrapper != "" {
log.Error().Str("file", file.pathRelativeToBazelWorkspaceRoot).Msgf("Saw conflicting wrappers from annotations: %v and %v. Picking one at random.", wrapper, newWrapper)
}
wrapper = newWrapper
}
}
testJavaFiles.Add(file)
if len(perFileAttrs) > 0 {
separateTestJavaFiles[file] = perFileAttrs
if len(perFileAttrs) > 0 || wrapper != "" {
separateTestJavaFiles[file] = separateJavaTestReasons{
attributes: perFileAttrs,
wrapper: wrapper,
}
}
} else {
testHelperJavaFiles.Add(file)
Expand Down Expand Up @@ -488,7 +505,7 @@ func (l javaLang) generateJavaBinary(file *rule.File, m types.ClassName, libName
})
}

func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], depOnTestHelpers *string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], depOnTestHelpers *string, wrapper string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
className := f.ClassName()
fullyQualifiedTestClass := className.FullyQualifiedClassName()
var testName string
Expand All @@ -498,12 +515,12 @@ func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazel
testName = className.BareOuterClassName()
}

ruleKind := "java_test"
javaRuleKind := "java_test"
if importsJunit5(imports) {
ruleKind = "java_junit5_test"
javaRuleKind = "java_junit5_test"
}

runtimeDeps := l.collectRuntimeDeps(ruleKind, testName, file)
runtimeDeps := l.collectRuntimeDeps(javaRuleKind, testName, file)
if importsJunit5(imports) {
// This should probably register imports here, and then allow the
// resolver to resolve this to an artifact, but we don't currently wire
Expand All @@ -514,7 +531,16 @@ func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazel
}
}

ruleKind := javaRuleKind
if wrapper != "" {
ruleKind = wrapper
}

r := rule.NewRule(ruleKind, testName)
if wrapper != "" {
r.AddArg(&bzl.Ident{Name: javaRuleKind})
}

path := strings.TrimPrefix(f.pathRelativeToBazelWorkspaceRoot, pathToPackageRelativeToBazelWorkspace+"/")
r.SetAttr("srcs", []string{path})
r.SetAttr("test_class", fullyQualifiedTestClass)
Expand Down
27 changes: 26 additions & 1 deletion java/gazelle/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/sorted_set"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/types"
"github.com/bazelbuild/bazel-gazelle/language"
bzl "github.com/bazelbuild/buildtools/build"
"github.com/google/go-cmp/cmp"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
Expand All @@ -19,10 +20,12 @@ func TestSingleJavaTestFile(t *testing.T) {
type testCase struct {
includePackageInName bool
importedPackages []string
wrapper string
wantRuleKind string
wantImports []string
wantDeps []string
wantRuntimeDeps []string
wantArgs []bzl.Expr
}

for name, tc := range map[string]testCase{
Expand Down Expand Up @@ -86,6 +89,14 @@ func TestSingleJavaTestFile(t *testing.T) {
wantRuleKind: "java_test",
wantImports: []string{"com.example", "org.junit"},
},
"wrapper junit4": {
includePackageInName: false,
importedPackages: []string{"org.junit"},
wrapper: "some_wrapper",
wantRuleKind: "some_wrapper",
wantImports: []string{"com.example", "org.junit"},
wantArgs: []bzl.Expr{&bzl.Ident{Name: "java_test"}},
},
"explicit junit5": {
includePackageInName: false,
importedPackages: []string{"org.junit.jupiter.api"},
Expand Down Expand Up @@ -119,6 +130,19 @@ func TestSingleJavaTestFile(t *testing.T) {
"@maven//:org_junit_platform_junit_platform_reporting",
},
},
"wrapper junit5": {
includePackageInName: false,
importedPackages: []string{"org.junit.jupiter.api"},
wrapper: "some_wrapper",
wantRuleKind: "some_wrapper",
wantImports: []string{"com.example", "org.junit.jupiter.api"},
wantRuntimeDeps: []string{
"@maven//:org_junit_jupiter_junit_jupiter_engine",
"@maven//:org_junit_platform_junit_platform_launcher",
"@maven//:org_junit_platform_junit_platform_reporting",
},
wantArgs: []bzl.Expr{&bzl.Ident{Name: "java_junit5_test"}},
},
"explicit both junit4 and junit5": {
includePackageInName: false,
importedPackages: []string{"org.junit", "org.junit.jupiter.api"},
Expand All @@ -135,7 +159,7 @@ func TestSingleJavaTestFile(t *testing.T) {
var res language.GenerateResult

l := newTestJavaLang(t)
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, nil, &res)
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, tc.wrapper, nil, &res)

require.Len(t, res.Gen, 1, "want 1 generated rule")

Expand All @@ -154,6 +178,7 @@ func TestSingleJavaTestFile(t *testing.T) {
wantAttrs = append(wantAttrs, "runtime_deps")
}
require.ElementsMatch(t, wantAttrs, rule.AttrKeys())
require.ElementsMatch(t, tc.wantArgs, rule.Args())

require.Len(t, res.Imports, 1, "want 1 generated importedPackages")
wantImports := sorted_set.NewSortedSetFn([]types.PackageName{}, types.PackageNameLess)
Expand Down
24 changes: 24 additions & 0 deletions java/gazelle/javaconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (c *Config) NewChild() *Config {
testMode: c.testMode,
customTestFileSuffixes: c.customTestFileSuffixes,
annotationToAttribute: c.annotationToAttribute,
annotationToWrapper: c.annotationToWrapper,
excludedArtifacts: clonedExcludedArtifacts,
mavenRepositoryName: c.mavenRepositoryName,
}
Expand Down Expand Up @@ -99,6 +100,7 @@ type Config struct {
customTestFileSuffixes *[]string
excludedArtifacts map[string]struct{}
annotationToAttribute map[string]map[string]bzl.Expr
annotationToWrapper map[string]string
mavenRepositoryName string
}

Expand All @@ -120,6 +122,7 @@ func New(repoRoot string) *Config {
customTestFileSuffixes: nil,
excludedArtifacts: make(map[string]struct{}),
annotationToAttribute: make(map[string]map[string]bzl.Expr),
annotationToWrapper: make(map[string]string),
mavenRepositoryName: "maven",
}
}
Expand Down Expand Up @@ -244,6 +247,27 @@ func (c *Config) AttributesForAnnotation(annotation string) (map[string]bzl.Expr
return m, ok
}

func (c *Config) MapAnnotationToWrapper(annotation string, wrapper string) {
c.annotationToWrapper[annotation] = wrapper
}

func (c *Config) WrapperForAnnotation(annotation string) (string, bool) {
s, ok := c.annotationToWrapper[annotation]
return s, ok
}

func (c *Config) IsTestRule(ruleKind string) bool {
if ruleKind == "java_junit5_test" || ruleKind == "java_test" || ruleKind == "java_test_suite" {
return true
}
for _, wrapper := range c.annotationToWrapper {
if ruleKind == wrapper {
return true
}
}
return false
}

func equalStringSlices(l, r []string) bool {
if len(l) != len(r) {
return false
Expand Down
Loading
Loading