From 6a7872ae3efb4044f646187e72a6e92981b8df05 Mon Sep 17 00:00:00 2001 From: UebelAndre Date: Wed, 5 Jul 2023 09:00:33 -0700 Subject: [PATCH] Fix prost proto packages not sanitizing to valid module names (#2044) * Fix prost proto packages not sanitizing to lowercase * Add support for snake casing to match prost-build --- proto/prost/private/3rdparty/BUILD.heck.bazel | 20 +++++++++ proto/prost/private/BUILD.bazel | 5 ++- proto/prost/private/protoc_wrapper.rs | 43 ++++++++++++++----- .../tests/sanitized_modules/BUILD.bazel | 33 ++++++++++++++ .../private/tests/sanitized_modules/bar.proto | 17 ++++++++ .../private/tests/sanitized_modules/foo.proto | 11 +++++ .../sanitized_modules_test.rs | 26 +++++++++++ proto/prost/repositories.bzl | 10 +++++ 8 files changed, 154 insertions(+), 11 deletions(-) create mode 100644 proto/prost/private/3rdparty/BUILD.heck.bazel create mode 100644 proto/prost/private/tests/sanitized_modules/BUILD.bazel create mode 100644 proto/prost/private/tests/sanitized_modules/bar.proto create mode 100644 proto/prost/private/tests/sanitized_modules/foo.proto create mode 100644 proto/prost/private/tests/sanitized_modules/sanitized_modules_test.rs diff --git a/proto/prost/private/3rdparty/BUILD.heck.bazel b/proto/prost/private/3rdparty/BUILD.heck.bazel new file mode 100644 index 0000000000..c6723799c5 --- /dev/null +++ b/proto/prost/private/3rdparty/BUILD.heck.bazel @@ -0,0 +1,20 @@ +load("@rules_rust//rust:defs.bzl", "rust_library") + +package(default_visibility = ["//visibility:public"]) + +rust_library( + name = "heck", + srcs = glob(["**/*.rs"]), + crate_features = [ + "default", + ], + crate_root = "src/lib.rs", + edition = "2018", + rustc_flags = ["--cap-lints=allow"], + tags = [ + "manual", + "noclippy", + "norustfmt", + ], + version = "0.4.1", +) diff --git a/proto/prost/private/BUILD.bazel b/proto/prost/private/BUILD.bazel index eea28b7dcb..4e261c95a3 100644 --- a/proto/prost/private/BUILD.bazel +++ b/proto/prost/private/BUILD.bazel @@ -13,7 +13,10 @@ rust_binary( srcs = ["protoc_wrapper.rs"], edition = RUST_EDITION, visibility = ["//visibility:public"], - deps = [":current_prost_runtime"], + deps = [ + ":current_prost_runtime", + "@rules_rust_prost__heck//:heck", + ], ) rust_test( diff --git a/proto/prost/private/protoc_wrapper.rs b/proto/prost/private/protoc_wrapper.rs index abb1e5fad5..2a0b5553a4 100644 --- a/proto/prost/private/protoc_wrapper.rs +++ b/proto/prost/private/protoc_wrapper.rs @@ -9,6 +9,7 @@ use std::path::PathBuf; use std::process; use std::{env, fmt}; +use heck::ToSnakeCase; use prost::Message; use prost_types::{ DescriptorProto, EnumDescriptorProto, FileDescriptorProto, FileDescriptorSet, @@ -44,6 +45,18 @@ fn find_generated_rust_files(out_dir: &Path) -> BTreeSet { all_rs_files } +fn snake_cased_package_name(package: &str) -> String { + if package == "_" { + return package.to_owned(); + } + + package + .split('.') + .map(|s| s.to_snake_case()) + .collect::>() + .join(".") +} + /// Rust module definition. #[derive(Debug, Default)] struct Module { @@ -121,20 +134,25 @@ fn generate_lib_rs(prost_outputs: &BTreeSet, is_tonic: bool) -> String .to_string() }; - let module_name = package.to_lowercase().to_string(); - - if module_name.is_empty() { + if package.is_empty() { continue; } - let mut name = module_name.clone(); - if module_name.contains('.') { - name = module_name + let name = if package == "_" { + package.clone() + } else if package.contains('.') { + package .rsplit_once('.') .expect("Failed to split on '.'") .1 - .to_string(); - } + .to_snake_case() + .to_string() + } else { + package.to_snake_case() + }; + + // Avoid a stack overflow by skipping a known bad package name + let module_name = snake_cased_package_name(&package); module_info.insert( module_name.clone(), @@ -145,7 +163,7 @@ fn generate_lib_rs(prost_outputs: &BTreeSet, is_tonic: bool) -> String }, ); - let module_parts = module_name.split('.').collect::>(); + let module_parts = module_name.split('.').collect::>(); for parent_module_index in 0..module_parts.len() { let child_module_index = parent_module_index + 1; if child_module_index >= module_parts.len() { @@ -307,7 +325,7 @@ fn descriptor_set_file_to_extern_paths( file: &FileDescriptorProto, ) { let package = file.package.clone().unwrap_or_default(); - let rust_path = rust_path.join(&package.replace('.', "::")); + let rust_path = rust_path.join(&snake_cased_package_name(&package).replace('.', "::")); let proto_path = ProtoPath(package); for message_type in file.message_type.iter() { @@ -1043,6 +1061,11 @@ mod test { assert_eq!(proto_path.to_string(), "foo.bar"); assert_eq!(proto_path.join("baz"), ProtoPath::from("foo.bar.baz")); } + { + let proto_path = ProtoPath::from("Foo.baR"); + assert_eq!(proto_path.to_string(), "Foo.baR"); + assert_eq!(proto_path.join("baz"), ProtoPath::from("Foo.baR.baz")); + } } #[test] diff --git a/proto/prost/private/tests/sanitized_modules/BUILD.bazel b/proto/prost/private/tests/sanitized_modules/BUILD.bazel new file mode 100644 index 0000000000..1df12d4021 --- /dev/null +++ b/proto/prost/private/tests/sanitized_modules/BUILD.bazel @@ -0,0 +1,33 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_rust//rust:defs.bzl", "rust_test") +load("//proto/prost:defs.bzl", "rust_prost_library") + +proto_library( + name = "foo_proto", + srcs = [ + "foo.proto", + ], + strip_import_prefix = "/proto/prost/private/tests/sanitized_modules", +) + +proto_library( + name = "bar_proto", + srcs = [ + "bar.proto", + ], + deps = [ + "foo_proto", + ], +) + +rust_prost_library( + name = "bar_proto_rs", + proto = ":bar_proto", +) + +rust_test( + name = "sanitized_modules_test", + srcs = ["sanitized_modules_test.rs"], + edition = "2021", + deps = [":bar_proto_rs"], +) diff --git a/proto/prost/private/tests/sanitized_modules/bar.proto b/proto/prost/private/tests/sanitized_modules/bar.proto new file mode 100644 index 0000000000..e90e00e7fb --- /dev/null +++ b/proto/prost/private/tests/sanitized_modules/bar.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +import "foo.proto"; + +package bAR.bAz.QAZ.QuX; + +message Bar { + string name = 1; + + Foo.QuuX.CoRgE.GRAULT.gaRply.Foo foo = 2; + + Foo.QuuX.CoRgE.GRAULT.gaRply.Foo.NestedFoo nested_foo = 3; + + message Baz { + string name = 4; + } +} diff --git a/proto/prost/private/tests/sanitized_modules/foo.proto b/proto/prost/private/tests/sanitized_modules/foo.proto new file mode 100644 index 0000000000..d0cc78537d --- /dev/null +++ b/proto/prost/private/tests/sanitized_modules/foo.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package Foo.QuuX.CoRgE.GRAULT.gaRply; + +message Foo { + string name = 1; + + message NestedFoo { + string name = 2; + } +} diff --git a/proto/prost/private/tests/sanitized_modules/sanitized_modules_test.rs b/proto/prost/private/tests/sanitized_modules/sanitized_modules_test.rs new file mode 100644 index 0000000000..2b852a37e0 --- /dev/null +++ b/proto/prost/private/tests/sanitized_modules/sanitized_modules_test.rs @@ -0,0 +1,26 @@ +//! Tests protos with various capitalizations in their package names are +//! consumable in an expected way. + +use bar_proto::b_ar::b_az::qaz::qu_x::bar::Baz as BazMessage; +use bar_proto::b_ar::b_az::qaz::qu_x::Bar as BarMessage; +use foo_proto::foo::quu_x::co_rg_e::grault::ga_rply::foo::NestedFoo as NestedFooMessage; +use foo_proto::foo::quu_x::co_rg_e::grault::ga_rply::Foo as FooMessage; + +#[test] +fn test_packages() { + let bar_message = BarMessage { + name: "bar".to_string(), + foo: Some(FooMessage { + name: "foo".to_string(), + }), + nested_foo: Some(NestedFooMessage { + name: "nested_foo".to_string(), + }), + }; + let baz_message = BazMessage { + name: "baz".to_string(), + }; + + assert_eq!(bar_message.name, "bar"); + assert_eq!(baz_message.name, "baz"); +} diff --git a/proto/prost/repositories.bzl b/proto/prost/repositories.bzl index 821d9eab7f..d0ae717d48 100644 --- a/proto/prost/repositories.bzl +++ b/proto/prost/repositories.bzl @@ -22,3 +22,13 @@ def rust_prost_dependencies(): strip_prefix = "protobuf-3.18.0", urls = ["https://github.com/protocolbuffers/protobuf/releases/download/v3.18.0/protobuf-all-3.18.0.tar.gz"], ) + + maybe( + http_archive, + name = "rules_rust_prost__heck", + sha256 = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8", + type = "tar.gz", + urls = ["https://crates.io/api/v1/crates/heck/0.4.1/download"], + strip_prefix = "heck-0.4.1", + build_file = Label("@rules_rust//proto/prost/private/3rdparty/crates:BUILD.heck-0.4.1.bazel"), + )