diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConstants.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConstants.java index 2028d43c9fa96e..089c9fe3bc9031 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConstants.java +++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConstants.java @@ -17,7 +17,7 @@ /** Constants used in Proto rules. */ public final class ProtoConstants { /** Default label for proto compiler. */ - static final String DEFAULT_PROTOC_LABEL = "@bazel_tools//tools/proto:protoc"; + public static final String DEFAULT_PROTOC_LABEL = "@bazel_tools//tools/proto:protoc"; /** Default label for java proto toolchains. */ static final String DEFAULT_JAVA_PROTO_LABEL = "@bazel_tools//tools/proto:java_toolchain"; diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl index 872e7af0641c67..72eaafda28e0af 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl @@ -32,6 +32,13 @@ def _rule_impl(ctx): if ctx.attr.plugin != None: plugin = ctx.attr.plugin[DefaultInfo].files_to_run + if semantics.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: + proto_compiler = ctx.toolchains[semantics.PROTO_TOOLCHAIN_TYPE].proto.proto_compiler + protoc_opts = ctx.toolchains[semantics.PROTO_TOOLCHAIN_TYPE].proto.protoc_opts + else: + proto_compiler = ctx.attr._proto_compiler.files_to_run + protoc_opts = ctx.fragments.proto.experimental_protoc_opts + return [ DefaultInfo( files = depset(), @@ -44,8 +51,8 @@ def _rule_impl(ctx): plugin = plugin, runtime = ctx.attr.runtime, provided_proto_sources = provided_proto_sources, - proto_compiler = ctx.attr._proto_compiler.files_to_run, - protoc_opts = ctx.fragments.proto.experimental_protoc_opts, + proto_compiler = proto_compiler, + protoc_opts = protoc_opts, progress_message = ctx.attr.progress_message, mnemonic = ctx.attr.mnemonic, allowlist_different_package = ctx.attr.allowlist_different_package, @@ -74,13 +81,15 @@ proto_lang_toolchain = rule( cfg = "exec", providers = [PackageSpecificationInfo], ), + } | ({} if semantics.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else { "_proto_compiler": attr.label( cfg = "exec", executable = True, allow_files = True, default = configuration_field("proto", "proto_compiler"), ), - }, + }), provides = [ProtoLangToolchainInfo], fragments = ["proto"], + toolchains = semantics.PROTO_TOOLCHAIN, # Used to obtain protoc ) diff --git a/src/test/java/com/google/devtools/build/lib/packages/BUILD b/src/test/java/com/google/devtools/build/lib/packages/BUILD index b265ae2788c1fb..6b4a1feb953f74 100644 --- a/src/test/java/com/google/devtools/build/lib/packages/BUILD +++ b/src/test/java/com/google/devtools/build/lib/packages/BUILD @@ -163,6 +163,7 @@ java_library( "//src/main/java/com/google/devtools/build/lib/pkgcache", "//src/main/java/com/google/devtools/build/lib/rules:repository/repository_function", "//src/main/java/com/google/devtools/build/lib/rules/cpp", + "//src/main/java/com/google/devtools/build/lib/rules/proto", "//src/main/java/com/google/devtools/build/lib/rules/python", "//src/main/java/com/google/devtools/build/lib/skyframe:precomputed_value", "//src/main/java/com/google/devtools/build/lib/skyframe:skyframe_cluster", diff --git a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java index bd3d56de0e12b1..71344ac45f0222 100644 --- a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java +++ b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java @@ -14,6 +14,7 @@ package com.google.devtools.build.lib.packages.util; +import com.google.devtools.build.lib.rules.proto.ProtoConstants; import com.google.devtools.build.lib.testutil.TestConstants; import java.io.IOException; @@ -43,7 +44,9 @@ private static void registerProtoToolchain(MockToolsConfig config) throws IOExce "tools/proto/toolchains/BUILD", TestConstants.LOAD_PROTO_TOOLCHAIN, "proto_toolchain(name = 'protoc_sources'," - + "proto_compiler = '//net/proto2/compiler/public:protocol_compiler')"); + + "proto_compiler = '" + + ProtoConstants.DEFAULT_PROTOC_LABEL + + "')"); } /** Create a dummy "net/proto2 compiler and proto APIs for all languages and versions. */ diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java index 6536328afe1e67..3d7b8620290cbc 100644 --- a/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java +++ b/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java @@ -99,6 +99,39 @@ public void protoToolchain() throws Exception { validateProtoCompiler(toolchain, ProtoConstants.DEFAULT_PROTOC_LABEL); } + @Test + public void protoToolchainResolution_enabled() throws Exception { + setBuildLanguageOptions("--incompatible_enable_proto_toolchain_resolution"); + scratch.file( + "third_party/x/BUILD", + "licenses(['unencumbered'])", + "cc_binary(name = 'plugin', srcs = ['plugin.cc'])", + "cc_library(name = 'runtime', srcs = ['runtime.cc'])", + "filegroup(name = 'descriptors', srcs = ['metadata.proto', 'descriptor.proto'])", + "filegroup(name = 'any', srcs = ['any.proto'])", + "proto_library(name = 'denied', srcs = [':descriptors', ':any'])"); + scratch.file( + "foo/BUILD", + TestConstants.LOAD_PROTO_LANG_TOOLCHAIN, + "licenses(['unencumbered'])", + "proto_lang_toolchain(", + " name = 'toolchain',", + " command_line = 'cmd-line:$(OUT)',", + " plugin_format_flag = '--plugin=%s',", + " plugin = '//third_party/x:plugin',", + " runtime = '//third_party/x:runtime',", + " progress_message = 'Progress Message %{label}',", + " mnemonic = 'MyMnemonic',", + ")"); + + update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus()); + ProtoLangToolchainProvider toolchain = + ProtoLangToolchainProvider.get(getConfiguredTarget("//foo:toolchain")); + + validateProtoLangToolchain(toolchain); + validateProtoCompiler(toolchain, ProtoConstants.DEFAULT_PROTOC_LABEL); + } + @Test public void protoToolchainBlacklistProtoLibraries() throws Exception { scratch.file(