diff --git a/common/internal/BUILD b/common/internal/BUILD index 10084b685..ec39e689d 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -143,14 +143,16 @@ cc_library( srcs = ["signature.cc"], hdrs = ["signature.h"], deps = [ + "//common:ast", "//common:type", "//common:type_kind", + "//common:type_spec_resolver", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", ], ) @@ -159,11 +161,14 @@ cc_test( srcs = ["signature_test.cc"], deps = [ ":signature", + "//common:ast", "//common:type", + "//common:type_kind", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], diff --git a/common/internal/signature.cc b/common/internal/signature.cc index f63049878..5cd0d5f5f 100644 --- a/common/internal/signature.cc +++ b/common/internal/signature.cc @@ -15,20 +15,29 @@ #include "common/internal/signature.h" #include +#include +#include #include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" +#include "common/type_spec_resolver.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { +// Signature generator helper functions. namespace { void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { @@ -208,4 +217,327 @@ absl::StatusOr MakeOverloadSignature( return result; } + +// Signature parser helper functions. +namespace { + +std::string Unescape(std::string_view str) { + size_t first_escape = str.find('\\'); + // NOMUTANTS -- equivalent mutant + if (first_escape == std::string_view::npos) { + return std::string(str); + } + std::string result; + result.reserve(str.size()); + result.append(str.substr(0, first_escape)); + bool escaped = false; + for (size_t i = first_escape; i < str.size(); ++i) { + char c = str[i]; + if (escaped) { + result.push_back(c); + escaped = false; + } else if (c == '\\') { + escaped = true; + } else { + result.push_back(c); + } + } + if (escaped) { + result.push_back('\\'); + } + return result; +} + +class SignatureScanner { + public: + explicit SignatureScanner(std::string_view input, + std::string_view error_prefix = "Invalid signature") + : input_(input), error_prefix_(error_prefix) {} + + absl::StatusOr FindTopLevelChar(char target, bool find_last = false) { + size_t found_idx = std::string_view::npos; + int nesting = 0; + bool escaped = false; + // Scanning str for delimiter boundaries while ensuring + // brackets are balanced and escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == target && nesting == 0) { + if (find_last || found_idx == std::string_view::npos) { + found_idx = i; + } + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + return found_idx; + } + + absl::StatusOr> SplitTopLevel(char delimiter) { + std::vector result; + int nesting = 0; + bool escaped = false; + size_t start = 0; + // Scanning str for delimiter while ensuring brackets are balanced and + // escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == delimiter && nesting == 0) { + result.push_back(input_.substr(start, i - start)); + start = i + 1; + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + result.push_back(input_.substr(start)); + return result; + } + + private: + std::string_view input_; + std::string_view error_prefix_; +}; + +absl::StatusOr> SplitTypeList( + std::string_view params) { + return SignatureScanner(params, "Invalid type signature").SplitTopLevel(','); +} + +absl::StatusOr ParseTypeSignature(std::string_view signature) { + if (signature.empty()) { + return absl::InvalidArgumentError("Empty type signature"); + } + + if (signature[0] == '~') { + std::string_view param_name = signature.substr(1); + if (param_name.empty()) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(param_name) + .FindTopLevelChar('<', /*find_last=*/false)); + CEL_ASSIGN_OR_RETURN(size_t comma_idx, + SignatureScanner(param_name) + .FindTopLevelChar(',', /*find_last=*/false)); + if (less_idx != std::string_view::npos || + comma_idx != std::string_view::npos) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + return TypeSpec(ParamTypeSpec(Unescape(param_name))); + } + + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(signature, "Invalid type signature") + .FindTopLevelChar('<', /*find_last=*/false)); + + std::string name_str; + std::vector params; + + if (less_idx != std::string_view::npos) { + // If the signature contains a '<', it must also contain a matching '>'. + if (signature.back() != '>') { + return absl::InvalidArgumentError( + "Invalid type signature: missing closing >"); + } + name_str = Unescape(signature.substr(0, less_idx)); + std::string_view params_str = + signature.substr(less_idx + 1, signature.size() - less_idx - 2); + CEL_ASSIGN_OR_RETURN(auto param_list, SplitTypeList(params_str)); + for (std::string_view param_str : param_list) { + CEL_ASSIGN_OR_RETURN(auto param, ParseTypeSignature(param_str)); + params.push_back(std::move(param)); + } + } else { + name_str = Unescape(signature); + } + + auto read_param_or_dyn = [&](size_t index) { + auto spec = std::make_unique(DynTypeSpec()); + if (params.size() > index) { + *spec = std::move(params[index]); + } + return spec; + }; + + if (name_str == "null") return TypeSpec(NullTypeSpec()); + if (name_str == "bool") return TypeSpec(PrimitiveType::kBool); + if (name_str == "int") return TypeSpec(PrimitiveType::kInt64); + if (name_str == "uint") return TypeSpec(PrimitiveType::kUint64); + if (name_str == "double") return TypeSpec(PrimitiveType::kDouble); + if (name_str == "string") return TypeSpec(PrimitiveType::kString); + if (name_str == "bytes") return TypeSpec(PrimitiveType::kBytes); + if (name_str == "any") return TypeSpec(WellKnownTypeSpec::kAny); + if (name_str == "timestamp") return TypeSpec(WellKnownTypeSpec::kTimestamp); + if (name_str == "duration") return TypeSpec(WellKnownTypeSpec::kDuration); + if (name_str == "dyn") return TypeSpec(DynTypeSpec()); + + // Handle standard Protobuf well-known wrapper types to preserve + // backward compatibility for users migrating YAML configuration files. + if (name_str == "bool_wrapper" || name_str == "google.protobuf.BoolValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + if (name_str == "int_wrapper" || name_str == "google.protobuf.Int64Value" || + name_str == "google.protobuf.Int32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + if (name_str == "uint_wrapper" || name_str == "google.protobuf.UInt64Value" || + name_str == "google.protobuf.UInt32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + if (name_str == "double_wrapper" || + name_str == "google.protobuf.DoubleValue" || + name_str == "google.protobuf.FloatValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + if (name_str == "string_wrapper" || name_str == "google.protobuf.StringValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + if (name_str == "bytes_wrapper" || name_str == "google.protobuf.BytesValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + + if (name_str == "type") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: type expects at most 1 parameter"); + } + return TypeSpec(read_param_or_dyn(0)); + } + + if (name_str == "list") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: list expects at most 1 parameter"); + } + return TypeSpec(ListTypeSpec(read_param_or_dyn(0))); + } + + if (name_str == "map") { + if (!params.empty() && params.size() != 2) { + return absl::InvalidArgumentError( + "Invalid type signature: map expects 0 or 2 parameters"); + } + auto key = read_param_or_dyn(0); + auto value = read_param_or_dyn(1); + return TypeSpec(MapTypeSpec(std::move(key), std::move(value))); + } + + if (name_str == "function") { + auto result_type = read_param_or_dyn(0); + std::vector arg_types; + for (size_t i = 1; i < params.size(); ++i) { + arg_types.push_back(std::move(params[i])); + } + return TypeSpec( + FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + } + + if (name_str.empty() || absl::StrContains(name_str, "..")) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid identifier"); + } + + return TypeSpec(AbstractType(name_str, std::move(params))); +} + +} // namespace + +absl::StatusOr ParseFunctionSignature( + std::string_view signature) { + if (signature.empty()) { + return absl::InvalidArgumentError("Empty function signature"); + } + + CEL_ASSIGN_OR_RETURN(size_t paren_idx, + SignatureScanner(signature, "Invalid function signature") + .FindTopLevelChar('(', /*find_last=*/false)); + + if (paren_idx == std::string_view::npos || signature.back() != ')') { + return absl::InvalidArgumentError("Invalid function signature"); + } + + std::string_view prefix = signature.substr(0, paren_idx); + std::string_view args_str = + signature.substr(paren_idx + 1, signature.size() - paren_idx - 2); + + std::vector arg_types; + ParsedFunctionOverload out; + + CEL_ASSIGN_OR_RETURN(size_t dot_idx, + SignatureScanner(prefix, "Invalid function signature") + .FindTopLevelChar('.', /*find_last=*/true)); + + if (dot_idx != std::string_view::npos) { + out.is_member = true; + std::string_view receiver_str = prefix.substr(0, dot_idx); + std::string_view func_str = prefix.substr(dot_idx + 1); + + CEL_ASSIGN_OR_RETURN(auto receiver_param, ParseTypeSignature(receiver_str)); + arg_types.push_back(std::move(receiver_param)); + out.function_name = Unescape(func_str); + } else { + out.is_member = false; + out.function_name = Unescape(prefix); + } + + if (out.function_name.empty()) { + return absl::InvalidArgumentError( + "Invalid function signature: empty function name"); + } + + if (!args_str.empty()) { + CEL_ASSIGN_OR_RETURN(auto arg_list, SplitTypeList(args_str)); + for (std::string_view arg_str : arg_list) { + CEL_ASSIGN_OR_RETURN(auto arg_param, ParseTypeSignature(arg_str)); + arg_types.push_back(std::move(arg_param)); + } + } + + auto result_type = std::make_unique(DynTypeSpec()); + out.signature_type = + TypeSpec(FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + + return out; +} + +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(signature)); + return cel::ConvertTypeSpecToType(type_spec, arena, pool); +} + } // namespace cel::common_internal diff --git a/common/internal/signature.h b/common/internal/signature.h index 3f31d8fd1..5b91f0b5c 100644 --- a/common/internal/signature.h +++ b/common/internal/signature.h @@ -20,7 +20,10 @@ #include #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { @@ -56,6 +59,29 @@ absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); +// Parses a string type signature directly into a `cel::Type`. +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +// A parsed function overload signature with the function name, flag for member +// function, and the function signature type. +struct ParsedFunctionOverload { + std::string function_name; + bool is_member = false; + TypeSpec signature_type; + + bool operator==(const ParsedFunctionOverload& other) const { + return function_name == other.function_name && + is_member == other.is_member && + signature_type == other.signature_type; + } +}; + +// Parses a string function overload signature directly into a +// `cel::TypeSpec` configured as a `FunctionTypeSpec`. +absl::StatusOr ParseFunctionSignature( + std::string_view signature); + } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc index 8e41c70fb..754f15de7 100644 --- a/common/internal/signature_test.cc +++ b/common/internal/signature_test.cc @@ -13,13 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include "absl/base/no_destructor.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "common/type_kind.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" @@ -38,6 +42,100 @@ google::protobuf::Arena* GetTestArena() { return &*arena; } +void VerifyParsedMatchesType(const TypeSpec& parsed, const Type& original) { + switch (original.kind()) { + case TypeKind::kDyn: + EXPECT_TRUE(parsed.has_dyn()); + break; + case TypeKind::kNull: + EXPECT_TRUE(parsed.has_null()); + break; + case TypeKind::kBool: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBool); + break; + case TypeKind::kInt: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kInt64); + break; + case TypeKind::kUint: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kUint64); + break; + case TypeKind::kDouble: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kDouble); + break; + case TypeKind::kString: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kString); + break; + case TypeKind::kBytes: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBytes); + break; + case TypeKind::kAny: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kAny); + break; + case TypeKind::kTimestamp: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kTimestamp); + break; + case TypeKind::kDuration: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kDuration); + break; + case TypeKind::kList: + EXPECT_TRUE(parsed.has_list_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.list_type().elem_type(), + original.GetParameters()[0]); + } + break; + case TypeKind::kMap: + EXPECT_TRUE(parsed.has_map_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.map_type().key_type(), + original.GetParameters()[0]); + } + if (original.GetParameters().size() > 1) { + VerifyParsedMatchesType(parsed.map_type().value_type(), + original.GetParameters()[1]); + } + break; + case TypeKind::kBoolWrapper: + case TypeKind::kIntWrapper: + case TypeKind::kUintWrapper: + case TypeKind::kDoubleWrapper: + case TypeKind::kStringWrapper: + case TypeKind::kBytesWrapper: + EXPECT_TRUE(parsed.has_wrapper()); + break; + case TypeKind::kType: + EXPECT_TRUE(parsed.has_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.type(), original.GetParameters()[0]); + } + break; + case TypeKind::kTypeParam: + EXPECT_TRUE(parsed.has_type_param()); + break; + default: + EXPECT_TRUE(parsed.has_abstract_type()); + break; + } +} + +void VerifyTypesEqual(const Type& lhs, const Type& rhs) { + EXPECT_EQ(lhs.kind(), rhs.kind()); + if (lhs.kind() != rhs.kind()) return; + + if (lhs.kind() == TypeKind::kOpaque || lhs.kind() == TypeKind::kTypeParam) { + EXPECT_EQ(lhs.name(), rhs.name()); + } + + const auto& lhs_params = lhs.GetParameters(); + const auto& rhs_params = rhs.GetParameters(); + EXPECT_EQ(lhs_params.size(), rhs_params.size()); + if (lhs_params.size() == rhs_params.size()) { + for (size_t i = 0; i < lhs_params.size(); ++i) { + VerifyTypesEqual(lhs_params[i], rhs_params[i]); + } + } +} + struct TypeSignatureTestCase { Type type; std::string expected_signature; @@ -73,10 +171,18 @@ std::vector GetTypeSignatureTestCases() { .type = ListType(GetTestArena(), StringType{}), .expected_signature = "list", }, + { + .type = TypeType(GetTestArena(), IntType{}), + .expected_signature = "type", + }, { .type = ListType(GetTestArena(), TypeParamType("A")), .expected_signature = "list<~A>", }, + { + .type = ListType(GetTestArena(), TypeParamType("A GetTypeSignatureTestCases() { .expected_signature = "map<~B,~C>", }, { - .type = OpaqueType( - GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), {})}), - .expected_signature = "bar>", + .type = OpaqueType(GetTestArena(), "bar", + {FunctionType(GetTestArena(), TypeParamType("D"), + {StringType{}, BoolType{}})}), + .expected_signature = "bar>", }, { .type = AnyType{}, @@ -104,10 +210,18 @@ std::vector GetTypeSignatureTestCases() { .type = TimestampType{}, .expected_signature = "timestamp", }, + { + .type = BoolWrapperType{}, + .expected_signature = "bool_wrapper", + }, { .type = IntWrapperType{}, .expected_signature = "int_wrapper", }, + { + .type = UintWrapperType{}, + .expected_signature = "uint_wrapper", + }, { .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes")), @@ -117,22 +231,32 @@ std::vector GetTypeSignatureTestCases() { .type = ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)")), .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", }, - { - .type = UnknownType{}, - .expected_error = - "Type kind: *unknown* is not supported in CEL declarations", - }, - { - .type = ErrorType{}, - .expected_error = - "Type kind: *error* is not supported in CEL declarations", - }, }; } +TEST(TypeSignatureTest, UnsupportedTypes) { + EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type kind: *unknown* is not supported"))); + + EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type kind: *error* is not supported"))); +} + INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, ValuesIn(GetTypeSignatureTestCases())); +TEST_P(TypeSignatureTest, ParseTypeCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty() && param.expected_error.empty()) { + auto parsed = ParseType(param.expected_signature, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + VerifyTypesEqual(*parsed, param.type); + } +} + struct OverloadSignatureTestCase { std::string function_name = "hello"; std::vector args; @@ -202,10 +326,18 @@ std::vector GetOverloadSignatureTestCases() { .args = {TimestampType{}}, .expected_signature = "hello(timestamp)", }, + { + .args = {BoolWrapperType{}}, + .expected_signature = "hello(bool_wrapper)", + }, { .args = {IntWrapperType{}}, .expected_signature = "hello(int_wrapper)", }, + { + .args = {UintWrapperType{}}, + .expected_signature = "hello(uint_wrapper)", + }, { .args = {MessageType( GetTestingDescriptorPool()->FindMessageTypeByName( @@ -213,9 +345,6 @@ std::vector GetOverloadSignatureTestCases() { .expected_signature = "hello(cel.expr.conformance.proto3.TestAllTypes)", }, - {.args = {}, - .is_member = true, - .expected_error = "Member function with no receiver"}, { .args = {StringType{}}, .is_member = true, @@ -231,6 +360,18 @@ std::vector GetOverloadSignatureTestCases() { .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, + { + .function_name = "hello", + .args = {OpaqueType(GetTestArena(), "bar", + {TypeParamType("dummy.type")})}, + .is_member = true, + .expected_signature = R"(bar<~dummy\.type>.hello())", + }, + { + .function_name = "inspect", + .args = {Type(TypeType(GetTestArena(), StringType{}))}, + .expected_signature = "inspect(type)", + }, { .function_name = R"(h.(e),l\o)", .args = {StringType{}, @@ -242,8 +383,263 @@ std::vector GetOverloadSignatureTestCases() { }; } +TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { + auto signature = common_internal::MakeOverloadSignature("hello", {}, true); + EXPECT_THAT(signature, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Member function with no receiver"))); +} + INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, ValuesIn(GetOverloadSignatureTestCases())); +TEST_P(OverloadSignatureTest, ExhaustiveFunctionParseCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty()) { + auto parsed = ParseFunctionSignature(param.expected_signature); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + EXPECT_EQ(parsed->function_name, param.function_name); + EXPECT_EQ(parsed->is_member, param.is_member); + EXPECT_TRUE(parsed->signature_type.has_function()); + const auto& func = parsed->signature_type.function(); + for (size_t i = 0; i < param.args.size(); ++i) { + VerifyParsedMatchesType(func.arg_types()[i], param.args[i]); + } + } +} + +TEST(ParseSignatureTest, ProtoParsing) { + ASSERT_OK_AND_ASSIGN( + auto t1, ParseType("int", GetTestArena(), *GetTestingDescriptorPool())); + EXPECT_TRUE(t1.IsInt()); + + ASSERT_OK_AND_ASSIGN(auto t2, ParseType("list<~A>", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t2.IsList()); + + ASSERT_OK_AND_ASSIGN(auto t3, ParseType(R"(~abc\)", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t3.IsTypeParam()); + EXPECT_EQ(t3.GetTypeParam().name(), R"(abc\)"); + + ASSERT_OK_AND_ASSIGN(auto w1, + ParseType("google.protobuf.BoolValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w1.IsBoolWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w2, + ParseType("google.protobuf.Int64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w2.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w3, + ParseType("google.protobuf.Int32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w3.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w4, + ParseType("google.protobuf.UInt64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w4.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w5, + ParseType("google.protobuf.UInt32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w5.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w6, + ParseType("google.protobuf.DoubleValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w6.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w7, + ParseType("google.protobuf.FloatValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w7.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w8, + ParseType("google.protobuf.StringValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w8.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w9, + ParseType("google.protobuf.BytesValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w9.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w10, ParseType("string_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w10.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w11, ParseType("bytes_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w11.IsBytesWrapper()); +} + +TEST(ParseSignatureTest, FunctionParsing) { + ASSERT_OK_AND_ASSIGN(auto f1, ParseFunctionSignature("hello(string)")); + EXPECT_TRUE(f1.signature_type.has_function()); + EXPECT_EQ(f1.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto f2, ParseFunctionSignature("a.b.c()")); + EXPECT_TRUE(f2.is_member); + EXPECT_EQ(f2.function_name, "c"); +} + +TEST(ParseSignatureTest, ParsingErrors) { + // Mismatched template brackets and parentheses. + EXPECT_THAT( + ParseType("list>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseType("list><", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list>)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("foo", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("list expects at most 1 parameter"))); + EXPECT_THAT( + ParseType("map", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + + // Enforcing valid function and identifier names. + EXPECT_THAT(ParseFunctionSignature("()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + EXPECT_THAT(ParseFunctionSignature("string.()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + + // Missing closing operators and boundary checks. + EXPECT_THAT( + ParseType("listfoo", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("missing closing >"))); + + EXPECT_THAT(ParseFunctionSignature("hello>(string)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list<", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map int, string>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("list", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + EXPECT_THAT(ParseFunctionSignature("a..b.c()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + + EXPECT_THAT( + ParseType("~list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); +} + +TEST(ParseSignatureTest, MessageTypeWithParamsError) { + EXPECT_THAT(ParseType("cel.expr.conformance.proto3.TestAllTypes", + GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MissingClosingParenthesisError) { + EXPECT_THAT(ParseFunctionSignature("hello(string"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT(ParseFunctionSignature("hello)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); +} + +TEST(ParseSignatureTest, NestedDotsNonMember) { + auto f1 = ParseFunctionSignature( + "my_opaque()"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_FALSE(f1->is_member); + EXPECT_EQ(f1->function_name, + "my_opaque"); +} + +TEST(ParseSignatureTest, OverlyComplexSignatures) { + auto t1 = ParseType("map>,map>>", + GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t1, ::absl_testing::IsOk()); + EXPECT_TRUE(t1->IsMap()); + + auto t2 = ParseType(R"(~abc\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t2, ::absl_testing::IsOk()); + EXPECT_TRUE(t2->IsTypeParam()); + EXPECT_EQ(t2->GetTypeParam().name(), R"(abc\)"); + + auto t3 = + ParseType(R"(~abc\\\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t3, ::absl_testing::IsOk()); + EXPECT_TRUE(t3->IsTypeParam()); + EXPECT_EQ(t3->GetTypeParam().name(), R"(abc\\)"); + + auto f1 = ParseFunctionSignature( + "bar>,map>.func(string)"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_TRUE(f1->is_member); + EXPECT_EQ(f1->function_name, "func"); + EXPECT_TRUE(f1->signature_type.has_function()); + EXPECT_EQ(f1->signature_type.function().arg_types().size(), 2); +} + +TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { + EXPECT_THAT(ParseType("", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + EXPECT_THAT(ParseFunctionSignature(""), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty function signature"))); + EXPECT_THAT(ParseType("list>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); +} + } // namespace } // namespace cel::common_internal