Skip to content

Commit

Permalink
Fix prost proto packages not sanitizing to valid module names (#2044)
Browse files Browse the repository at this point in the history
* Fix prost proto packages not sanitizing to lowercase

* Add support for snake casing to match prost-build
  • Loading branch information
UebelAndre authored Jul 5, 2023
1 parent c080d7b commit 6a7872a
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 11 deletions.
20 changes: 20 additions & 0 deletions proto/prost/private/3rdparty/BUILD.heck.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
load("@rules_rust//rust:defs.bzl", "rust_library")

package(default_visibility = ["//visibility:public"])

rust_library(
name = "heck",
srcs = glob(["**/*.rs"]),
crate_features = [
"default",
],
crate_root = "src/lib.rs",
edition = "2018",
rustc_flags = ["--cap-lints=allow"],
tags = [
"manual",
"noclippy",
"norustfmt",
],
version = "0.4.1",
)
5 changes: 4 additions & 1 deletion proto/prost/private/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ rust_binary(
srcs = ["protoc_wrapper.rs"],
edition = RUST_EDITION,
visibility = ["//visibility:public"],
deps = [":current_prost_runtime"],
deps = [
":current_prost_runtime",
"@rules_rust_prost__heck//:heck",
],
)

rust_test(
Expand Down
43 changes: 33 additions & 10 deletions proto/prost/private/protoc_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::path::PathBuf;
use std::process;
use std::{env, fmt};

use heck::ToSnakeCase;
use prost::Message;
use prost_types::{
DescriptorProto, EnumDescriptorProto, FileDescriptorProto, FileDescriptorSet,
Expand Down Expand Up @@ -44,6 +45,18 @@ fn find_generated_rust_files(out_dir: &Path) -> BTreeSet<PathBuf> {
all_rs_files
}

fn snake_cased_package_name(package: &str) -> String {
if package == "_" {
return package.to_owned();
}

package
.split('.')
.map(|s| s.to_snake_case())
.collect::<Vec<_>>()
.join(".")
}

/// Rust module definition.
#[derive(Debug, Default)]
struct Module {
Expand Down Expand Up @@ -121,20 +134,25 @@ fn generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String
.to_string()
};

let module_name = package.to_lowercase().to_string();

if module_name.is_empty() {
if package.is_empty() {
continue;
}

let mut name = module_name.clone();
if module_name.contains('.') {
name = module_name
let name = if package == "_" {
package.clone()
} else if package.contains('.') {
package
.rsplit_once('.')
.expect("Failed to split on '.'")
.1
.to_string();
}
.to_snake_case()
.to_string()
} else {
package.to_snake_case()
};

// Avoid a stack overflow by skipping a known bad package name
let module_name = snake_cased_package_name(&package);

module_info.insert(
module_name.clone(),
Expand All @@ -145,7 +163,7 @@ fn generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String
},
);

let module_parts = module_name.split('.').collect::<Vec<&str>>();
let module_parts = module_name.split('.').collect::<Vec<_>>();
for parent_module_index in 0..module_parts.len() {
let child_module_index = parent_module_index + 1;
if child_module_index >= module_parts.len() {
Expand Down Expand Up @@ -307,7 +325,7 @@ fn descriptor_set_file_to_extern_paths(
file: &FileDescriptorProto,
) {
let package = file.package.clone().unwrap_or_default();
let rust_path = rust_path.join(&package.replace('.', "::"));
let rust_path = rust_path.join(&snake_cased_package_name(&package).replace('.', "::"));
let proto_path = ProtoPath(package);

for message_type in file.message_type.iter() {
Expand Down Expand Up @@ -1043,6 +1061,11 @@ mod test {
assert_eq!(proto_path.to_string(), "foo.bar");
assert_eq!(proto_path.join("baz"), ProtoPath::from("foo.bar.baz"));
}
{
let proto_path = ProtoPath::from("Foo.baR");
assert_eq!(proto_path.to_string(), "Foo.baR");
assert_eq!(proto_path.join("baz"), ProtoPath::from("Foo.baR.baz"));
}
}

#[test]
Expand Down
33 changes: 33 additions & 0 deletions proto/prost/private/tests/sanitized_modules/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_rust//rust:defs.bzl", "rust_test")
load("//proto/prost:defs.bzl", "rust_prost_library")

proto_library(
name = "foo_proto",
srcs = [
"foo.proto",
],
strip_import_prefix = "/proto/prost/private/tests/sanitized_modules",
)

proto_library(
name = "bar_proto",
srcs = [
"bar.proto",
],
deps = [
"foo_proto",
],
)

rust_prost_library(
name = "bar_proto_rs",
proto = ":bar_proto",
)

rust_test(
name = "sanitized_modules_test",
srcs = ["sanitized_modules_test.rs"],
edition = "2021",
deps = [":bar_proto_rs"],
)
17 changes: 17 additions & 0 deletions proto/prost/private/tests/sanitized_modules/bar.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
syntax = "proto3";

import "foo.proto";

package bAR.bAz.QAZ.QuX;

message Bar {
string name = 1;

Foo.QuuX.CoRgE.GRAULT.gaRply.Foo foo = 2;

Foo.QuuX.CoRgE.GRAULT.gaRply.Foo.NestedFoo nested_foo = 3;

message Baz {
string name = 4;
}
}
11 changes: 11 additions & 0 deletions proto/prost/private/tests/sanitized_modules/foo.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

package Foo.QuuX.CoRgE.GRAULT.gaRply;

message Foo {
string name = 1;

message NestedFoo {
string name = 2;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//! Tests protos with various capitalizations in their package names are
//! consumable in an expected way.

use bar_proto::b_ar::b_az::qaz::qu_x::bar::Baz as BazMessage;
use bar_proto::b_ar::b_az::qaz::qu_x::Bar as BarMessage;
use foo_proto::foo::quu_x::co_rg_e::grault::ga_rply::foo::NestedFoo as NestedFooMessage;
use foo_proto::foo::quu_x::co_rg_e::grault::ga_rply::Foo as FooMessage;

#[test]
fn test_packages() {
let bar_message = BarMessage {
name: "bar".to_string(),
foo: Some(FooMessage {
name: "foo".to_string(),
}),
nested_foo: Some(NestedFooMessage {
name: "nested_foo".to_string(),
}),
};
let baz_message = BazMessage {
name: "baz".to_string(),
};

assert_eq!(bar_message.name, "bar");
assert_eq!(baz_message.name, "baz");
}
10 changes: 10 additions & 0 deletions proto/prost/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,13 @@ def rust_prost_dependencies():
strip_prefix = "protobuf-3.18.0",
urls = ["https://github.com/protocolbuffers/protobuf/releases/download/v3.18.0/protobuf-all-3.18.0.tar.gz"],
)

maybe(
http_archive,
name = "rules_rust_prost__heck",
sha256 = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8",
type = "tar.gz",
urls = ["https://crates.io/api/v1/crates/heck/0.4.1/download"],
strip_prefix = "heck-0.4.1",
build_file = Label("@rules_rust//proto/prost/private/3rdparty/crates:BUILD.heck-0.4.1.bazel"),
)

0 comments on commit 6a7872a

Please sign in to comment.