Skip to content

Commit

Permalink
Allow setting wrapper targets based on annotations
Browse files Browse the repository at this point in the history
This allows for registering that, e.g.

```java
@RequiresNetwork
class SomeTest {}
```

should be generated as:

```starlark
requires_network(
    java_test,
    ...
)
```

instead of:

```starlark
java_test(
    ...
)
```

In suite mode, separate targets will be generated for these special
targets.
  • Loading branch information
illicitonion committed Jan 22, 2024
1 parent 8b833d8 commit bbb48f6
Show file tree
Hide file tree
Showing 15 changed files with 580 additions and 19 deletions.
1 change: 1 addition & 0 deletions java/gazelle/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ go_library(
"//java/gazelle/private/javaparser",
"//java/gazelle/private/logconfig",
"//java/gazelle/private/maven",
"//java/gazelle/private/sorted_multiset",
"//java/gazelle/private/sorted_set",
"//java/gazelle/private/types",
"@bazel_gazelle//config:go_default_library",
Expand Down
50 changes: 50 additions & 0 deletions java/gazelle/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ import (
type Configurer struct {
lang *javaLang
annotationToAttribute annotationToAttribute
annotationToWrapper annotationToWrapper
mavenInstallFile string
}

func NewConfigurer(lang *javaLang) *Configurer {
return &Configurer{
lang: lang,
annotationToAttribute: make(annotationToAttribute),
annotationToWrapper: make(annotationToWrapper),
}
}

func (jc *Configurer) RegisterFlags(fs *flag.FlagSet, cmd string, c *config.Config) {
fs.Var(&jc.annotationToAttribute, "java-annotation-to-attribute", "Mapping of annotations (on test classes) to attributes which should be set for that test rule. Examples: com.example.annotations.FlakyTest=flaky=True com.example.annotations.SlowTest=timeout=\"long\"")
fs.Var(&jc.annotationToWrapper, "java-annotation-to-wrapper", "Mapping of annotations (on test classes) to wrapper rules which should be used around the test rule. Example: com.example.annotations.RequiresNetwork=@some//wrapper:file.bzl=requires_network")
fs.StringVar(&jc.mavenInstallFile, "java-maven-install-file", "", "Path of the maven_install.json file. Defaults to \"maven_install.json\".")
}

Expand All @@ -42,6 +45,9 @@ func (jc *Configurer) CheckFlags(fs *flag.FlagSet, c *config.Config) error {
cfgs[""].MapAnnotationToAttribute(annotation, k, v)
}
}
for annotation, wrapper := range jc.annotationToWrapper {
cfgs[""].MapAnnotationToWrapper(annotation, wrapper.symbol)
}
if jc.mavenInstallFile != "" {
cfgs[""].SetMavenInstallFile(jc.mavenInstallFile)
}
Expand Down Expand Up @@ -192,3 +198,47 @@ func (f *annotationToAttribute) Set(value string) error {
(*f)[annotationClassName][key] = parsedValue
return nil
}

type loadInfo struct {
from string
symbol string
}

type annotationToWrapper map[string]loadInfo

func (f *annotationToWrapper) String() string {
s := "annotationToWrapper{"
for a, li := range *f {
s += a + ": "
s += fmt.Sprintf(`load("%s", "%s")`, li.from, li.symbol)
}
s += "}"
return s
}

func (f *annotationToWrapper) Set(value string) error {
parts := strings.Split(value, "=")
if len(parts) != 2 {
return fmt.Errorf("want --java-annotation-to-wrapper to have format com.example.RequiresNetwork=@some_repo//has:wrapper.bzl,wrapper_rule but didn't see exactly one equals sign")
}
annotation := parts[0]

if _, ok := (*f)[annotation]; ok {
return fmt.Errorf("saw conflicting values for --java-annotation-to-wrapper flag for annotation %v", annotation)
}

vParts := strings.Split(parts[1], ",")
if len(vParts) != 2 {
return fmt.Errorf("want --java-annotation-to-wrapper to have format com.example.RequiresNetwork=@some_repo//has:wrapper.bzl,wrapper_rule but didn't see exactly one comma after equals sign")
}

from := vParts[0]
symbol := vParts[1]

(*f)[annotation] = loadInfo{
from: from,
symbol: symbol,
}

return nil
}
48 changes: 37 additions & 11 deletions java/gazelle/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ func javaFileLess(l, r javaFile) bool {
return l.pathRelativeToBazelWorkspaceRoot < r.pathRelativeToBazelWorkspaceRoot
}

type separateJavaTestReasons struct {
attributes map[string]bzl.Expr
wrapper string
}

// GenerateRules extracts build metadata from source files in a directory.
//
// See language.GenerateRules for more information.
Expand Down Expand Up @@ -114,7 +119,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
testJavaImports := sorted_set.NewSortedSetFn([]types.PackageName{}, types.PackageNameLess)

// Java Test files which need to be generated separately from any others because they have explicit attribute overrides.
separateTestJavaFiles := make(map[javaFile]map[string]bzl.Expr)
separateTestJavaFiles := make(map[javaFile]separateJavaTestReasons)

// Files which are used by non-test classes in test java packages.
testHelperJavaFiles := sorted_set.NewSortedSetFn([]javaFile{}, javaFileLess)
Expand Down Expand Up @@ -234,8 +239,8 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
switch cfg.TestMode() {
case "file":
for _, tf := range testJavaFiles.SortedSlice() {
extraAttributes := separateTestJavaFiles[tf]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, nil, extraAttributes, &res)
separateJavaTestReasons := separateTestJavaFiles[tf]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, nil, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
}

case "suite":
Expand Down Expand Up @@ -278,7 +283,8 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
if testHelperJavaFiles.Len() > 0 {
testHelperDep = ptr(testHelperLibname(suiteName))
}
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, testHelperDep, separateTestJavaFiles[src], &res)
separateJavaTestReasons := separateTestJavaFiles[src]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, testHelperDep, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
}
}
}
Expand Down Expand Up @@ -407,10 +413,11 @@ func addFilteringOutOwnPackage(to *sorted_set.SortedSet[types.PackageName], from
}
}

func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFiles *sorted_set.SortedSet[javaFile], separateTestJavaFiles map[javaFile]map[string]bzl.Expr, file javaFile, perClassMetadata map[string]java.PerClassMetadata, log zerolog.Logger) {
func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFiles *sorted_set.SortedSet[javaFile], separateTestJavaFiles map[javaFile]separateJavaTestReasons, file javaFile, perClassMetadata map[string]java.PerClassMetadata, log zerolog.Logger) {
if cfg.IsJavaTestFile(filepath.Base(file.pathRelativeToBazelWorkspaceRoot)) {
annotationClassNames := perClassMetadata[file.ClassName().FullyQualifiedClassName()].AnnotationClassNames
perFileAttrs := make(map[string]bzl.Expr)
wrapper := ""
for _, annotationClassName := range annotationClassNames.SortedSlice() {
if attrs, ok := cfg.AttributesForAnnotation(annotationClassName); ok {
for k, v := range attrs {
Expand All @@ -420,10 +427,20 @@ func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFil
perFileAttrs[k] = v
}
}
newWrapper, ok := cfg.WrapperForAnnotation(annotationClassName)
if ok {
if wrapper != "" {
log.Error().Str("file", file.pathRelativeToBazelWorkspaceRoot).Msgf("Saw conflicting wrappers from annotations: %v and %v. Picking one at random.", wrapper, newWrapper)
}
wrapper = newWrapper
}
}
testJavaFiles.Add(file)
if len(perFileAttrs) > 0 {
separateTestJavaFiles[file] = perFileAttrs
if len(perFileAttrs) > 0 || wrapper != "" {
separateTestJavaFiles[file] = separateJavaTestReasons{
attributes: perFileAttrs,
wrapper: wrapper,
}
}
} else {
testHelperJavaFiles.Add(file)
Expand Down Expand Up @@ -488,7 +505,7 @@ func (l javaLang) generateJavaBinary(file *rule.File, m types.ClassName, libName
})
}

func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], depOnTestHelpers *string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], depOnTestHelpers *string, wrapper string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
className := f.ClassName()
fullyQualifiedTestClass := className.FullyQualifiedClassName()
var testName string
Expand All @@ -498,12 +515,12 @@ func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazel
testName = className.BareOuterClassName()
}

ruleKind := "java_test"
javaRuleKind := "java_test"
if importsJunit5(imports) {
ruleKind = "java_junit5_test"
javaRuleKind = "java_junit5_test"
}

runtimeDeps := l.collectRuntimeDeps(ruleKind, testName, file)
runtimeDeps := l.collectRuntimeDeps(javaRuleKind, testName, file)
if importsJunit5(imports) {
// This should probably register imports here, and then allow the
// resolver to resolve this to an artifact, but we don't currently wire
Expand All @@ -514,7 +531,16 @@ func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazel
}
}

ruleKind := javaRuleKind
if wrapper != "" {
ruleKind = wrapper
}

r := rule.NewRule(ruleKind, testName)
if wrapper != "" {
r.AddArg(&bzl.Ident{Name: javaRuleKind})
}

path := strings.TrimPrefix(f.pathRelativeToBazelWorkspaceRoot, pathToPackageRelativeToBazelWorkspace+"/")
r.SetAttr("srcs", []string{path})
r.SetAttr("test_class", fullyQualifiedTestClass)
Expand Down
27 changes: 26 additions & 1 deletion java/gazelle/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/sorted_set"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/types"
"github.com/bazelbuild/bazel-gazelle/language"
bzl "github.com/bazelbuild/buildtools/build"
"github.com/google/go-cmp/cmp"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
Expand All @@ -19,10 +20,12 @@ func TestSingleJavaTestFile(t *testing.T) {
type testCase struct {
includePackageInName bool
importedPackages []string
wrapper string
wantRuleKind string
wantImports []string
wantDeps []string
wantRuntimeDeps []string
wantArgs []bzl.Expr
}

for name, tc := range map[string]testCase{
Expand Down Expand Up @@ -86,6 +89,14 @@ func TestSingleJavaTestFile(t *testing.T) {
wantRuleKind: "java_test",
wantImports: []string{"com.example", "org.junit"},
},
"wrapper junit4": {
includePackageInName: false,
importedPackages: []string{"org.junit"},
wrapper: "some_wrapper",
wantRuleKind: "some_wrapper",
wantImports: []string{"com.example", "org.junit"},
wantArgs: []bzl.Expr{&bzl.Ident{Name: "java_test"}},
},
"explicit junit5": {
includePackageInName: false,
importedPackages: []string{"org.junit.jupiter.api"},
Expand Down Expand Up @@ -119,6 +130,19 @@ func TestSingleJavaTestFile(t *testing.T) {
"@maven//:org_junit_platform_junit_platform_reporting",
},
},
"wrapper junit5": {
includePackageInName: false,
importedPackages: []string{"org.junit.jupiter.api"},
wrapper: "some_wrapper",
wantRuleKind: "some_wrapper",
wantImports: []string{"com.example", "org.junit.jupiter.api"},
wantRuntimeDeps: []string{
"@maven//:org_junit_jupiter_junit_jupiter_engine",
"@maven//:org_junit_platform_junit_platform_launcher",
"@maven//:org_junit_platform_junit_platform_reporting",
},
wantArgs: []bzl.Expr{&bzl.Ident{Name: "java_junit5_test"}},
},
"explicit both junit4 and junit5": {
includePackageInName: false,
importedPackages: []string{"org.junit", "org.junit.jupiter.api"},
Expand All @@ -135,7 +159,7 @@ func TestSingleJavaTestFile(t *testing.T) {
var res language.GenerateResult

l := newTestJavaLang(t)
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, nil, &res)
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, tc.wrapper, nil, &res)

require.Len(t, res.Gen, 1, "want 1 generated rule")

Expand All @@ -154,6 +178,7 @@ func TestSingleJavaTestFile(t *testing.T) {
wantAttrs = append(wantAttrs, "runtime_deps")
}
require.ElementsMatch(t, wantAttrs, rule.AttrKeys())
require.ElementsMatch(t, tc.wantArgs, rule.Args())

require.Len(t, res.Imports, 1, "want 1 generated importedPackages")
wantImports := sorted_set.NewSortedSetFn([]types.PackageName{}, types.PackageNameLess)
Expand Down
24 changes: 24 additions & 0 deletions java/gazelle/javaconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (c *Config) NewChild() *Config {
testMode: c.testMode,
customTestFileSuffixes: c.customTestFileSuffixes,
annotationToAttribute: c.annotationToAttribute,
annotationToWrapper: c.annotationToWrapper,
excludedArtifacts: clonedExcludedArtifacts,
mavenRepositoryName: c.mavenRepositoryName,
}
Expand Down Expand Up @@ -99,6 +100,7 @@ type Config struct {
customTestFileSuffixes *[]string
excludedArtifacts map[string]struct{}
annotationToAttribute map[string]map[string]bzl.Expr
annotationToWrapper map[string]string
mavenRepositoryName string
}

Expand All @@ -120,6 +122,7 @@ func New(repoRoot string) *Config {
customTestFileSuffixes: nil,
excludedArtifacts: make(map[string]struct{}),
annotationToAttribute: make(map[string]map[string]bzl.Expr),
annotationToWrapper: make(map[string]string),
mavenRepositoryName: "maven",
}
}
Expand Down Expand Up @@ -244,6 +247,27 @@ func (c *Config) AttributesForAnnotation(annotation string) (map[string]bzl.Expr
return m, ok
}

func (c *Config) MapAnnotationToWrapper(annotation string, wrapper string) {
c.annotationToWrapper[annotation] = wrapper
}

func (c *Config) WrapperForAnnotation(annotation string) (string, bool) {
s, ok := c.annotationToWrapper[annotation]
return s, ok
}

func (c *Config) IsTestRule(ruleKind string) bool {
if ruleKind == "java_junit5_test" || ruleKind == "java_test" || ruleKind == "java_test_suite" {
return true
}
for _, wrapper := range c.annotationToWrapper {
if ruleKind == wrapper {
return true
}
}
return false
}

func equalStringSlices(l, r []string) bool {
if len(l) != len(r) {
return false
Expand Down
Loading

0 comments on commit bbb48f6

Please sign in to comment.