From 88b52ac0b425b03ddfc9be659957503879353675 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 26 Aug 2024 12:55:02 -0700 Subject: [PATCH] `TypePool` and proto conversion utilities PiperOrigin-RevId: 667682555 --- common/BUILD | 113 ++++++++ common/type.h | 14 + common/type_pool.cc | 62 +++++ common/type_pool.h | 115 ++++++++ common/type_pool_test.cc | 113 ++++++++ common/type_proto.cc | 353 ++++++++++++++++++++++++ common/type_proto.h | 37 +++ common/type_proto_test.cc | 388 +++++++++++++++++++++++++++ common/type_proto_v1alpha1.cc | 354 +++++++++++++++++++++++++ common/type_proto_v1alpha1.h | 40 +++ common/type_proto_v1alpha1_test.cc | 413 +++++++++++++++++++++++++++++ common/types/map_type_pool.h | 8 +- common/types/opaque_type_pool.h | 2 +- common/types/type_pool.cc | 96 ------- common/types/type_pool.h | 99 ------- common/types/type_pool_test.cc | 94 ------- 16 files changed, 2006 insertions(+), 295 deletions(-) create mode 100644 common/type_pool.cc create mode 100644 common/type_pool.h create mode 100644 common/type_pool_test.cc create mode 100644 common/type_proto.cc create mode 100644 common/type_proto.h create mode 100644 common/type_proto_test.cc create mode 100644 common/type_proto_v1alpha1.cc create mode 100644 common/type_proto_v1alpha1.h create mode 100644 common/type_proto_v1alpha1_test.cc delete mode 100644 common/types/type_pool.cc delete mode 100644 common/types/type_pool.h delete mode 100644 common/types/type_pool_test.cc diff --git a/common/BUILD b/common/BUILD index 8983d3b4c..4b29aa879 100644 --- a/common/BUILD +++ b/common/BUILD @@ -960,3 +960,116 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "type_pool", + srcs = ["type_pool.cc"], + hdrs = ["type_pool.h"], + deps = [ + ":arena_string", + ":arena_string_pool", + ":type", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_pool_test", + srcs = ["type_pool_test.cc"], + deps = [ + ":arena_string_pool", + ":type", + ":type_pool", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_proto", + srcs = ["type_proto.cc"], + hdrs = ["type_proto.h"], + deps = [ + ":type", + ":type_kind", + ":type_pool", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_proto_test", + srcs = ["type_proto_test.cc"], + deps = [ + ":arena_string_pool", + ":type", + ":type_pool", + ":type_proto", + "//internal:proto_matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_proto_v1alpha1", + srcs = ["type_proto_v1alpha1.cc"], + hdrs = ["type_proto_v1alpha1.h"], + deps = [ + ":type", + ":type_kind", + ":type_pool", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_proto_v1alpha1_test", + srcs = ["type_proto_v1alpha1_test.cc"], + deps = [ + ":arena_string_pool", + ":type", + ":type_pool", + ":type_proto_v1alpha1", + "//internal:proto_matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/common/type.h b/common/type.h index 9d849dc60..3652a89d4 100644 --- a/common/type.h +++ b/common/type.h @@ -860,6 +860,20 @@ class TypeParameters final { }; }; +inline bool operator==(const TypeParameters& lhs, const TypeParameters& rhs) { + return absl::c_equal(lhs, rhs); +} + +inline bool operator!=(const TypeParameters& lhs, const TypeParameters& rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const TypeParameters& parameters) { + return H::combine_contiguous(std::move(state), parameters.data(), + parameters.size()); +} + // Now that TypeParameters is defined, we can define `GetParameters()` for most // types. diff --git a/common/type_pool.cc b/common/type_pool.cc new file mode 100644 index 000000000..df98d8bbd --- /dev/null +++ b/common/type_pool.cc @@ -0,0 +1,62 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_pool.h" + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +ListType TypePool::MakeListType(const Type& element) { + return list_type_pool_.InternListType(element); +} + +MapType TypePool::MakeMapType(const Type& key, const Type& value) { + return map_type_pool_.InternMapType(key, value); +} + +StructType TypePool::MakeStructType(absl::string_view name) { + if (descriptor_pool_ != nullptr) { + const google::protobuf::Descriptor* descriptor = + descriptor_pool_->FindMessageTypeByName(name); + if (descriptor != nullptr) { + return MessageType(descriptor); + } + } + return common_internal::MakeBasicStructType(string_pool_->InternString(name)); +} + +FunctionType TypePool::MakeFunctionType(const Type& result, + absl::Span args) { + return function_type_pool_.InternFunctionType(result, args); +} + +OpaqueType TypePool::MakeOpaqueType(absl::string_view name, + absl::Span params) { + return opaque_type_pool_.InternOpaqueType(string_pool_->InternString(name), + params); +} + +TypeParamType TypePool::MakeTypeParamType(absl::string_view name) { + return TypeParamType(string_pool_->InternString(name)); +} + +TypeType TypePool::MakeTypeType(const Type& type) { + return type_type_pool_.InternTypeType(type); +} + +} // namespace cel diff --git a/common/type_pool.h b/common/type_pool.h new file mode 100644 index 000000000..e20915988 --- /dev/null +++ b/common/type_pool.h @@ -0,0 +1,115 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_POOL_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/arena_string.h" +#include "common/arena_string_pool.h" +#include "common/type.h" +#include "common/types/function_type_pool.h" +#include "common/types/list_type_pool.h" +#include "common/types/map_type_pool.h" +#include "common/types/opaque_type_pool.h" +#include "common/types/type_type_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +class TypePool; + +absl::Nonnull> NewTypePool( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nonnull string_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nullable descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class TypePool final { + public: + TypePool(const TypePool&) = delete; + TypePool(TypePool&&) = delete; + TypePool& operator=(const TypePool&) = delete; + TypePool& operator=(TypePool&&) = delete; + + ListType MakeListType(const Type& element); + + MapType MakeMapType(const Type& key, const Type& value); + + StructType MakeStructType(absl::string_view name); + + StructType MakeStructType(ArenaString) = delete; + + FunctionType MakeFunctionType(const Type& result, + absl::Span args); + + OpaqueType MakeOpaqueType(absl::string_view name, + absl::Span params); + + OpaqueType MakeOpaqueType(ArenaString, absl::Span) = delete; + + OptionalType MakeOptionalType(const Type& param) { + return static_cast( + MakeOpaqueType(OptionalType::kName, absl::MakeConstSpan(¶m, 1))); + } + + TypeParamType MakeTypeParamType(absl::string_view name); + + TypeParamType MakeTypeParamType(ArenaString) = delete; + + TypeType MakeTypeType(const Type& type); + + private: + friend absl::Nonnull> NewTypePool( + absl::Nonnull, absl::Nonnull, + absl::Nullable); + + TypePool(absl::Nonnull arena, + absl::Nonnull string_pool, + absl::Nullable descriptor_pool) + : string_pool_(string_pool), + descriptor_pool_(descriptor_pool), + function_type_pool_(arena), + list_type_pool_(arena), + map_type_pool_(arena), + opaque_type_pool_(arena), + type_type_pool_(arena) {} + + absl::Nonnull const string_pool_; + absl::Nullable const descriptor_pool_; + common_internal::FunctionTypePool function_type_pool_; + common_internal::ListTypePool list_type_pool_; + common_internal::MapTypePool map_type_pool_; + common_internal::OpaqueTypePool opaque_type_pool_; + common_internal::TypeTypePool type_type_pool_; +}; + +inline absl::Nonnull> NewTypePool( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nonnull string_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nullable descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return std::unique_ptr( + new TypePool(arena, string_pool, descriptor_pool)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_POOL_H_ diff --git a/common/type_pool_test.cc b/common/type_pool_test.cc new file mode 100644 index 000000000..5830f427a --- /dev/null +++ b/common/type_pool_test.cc @@ -0,0 +1,113 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_pool.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/types/optional.h" +#include "common/arena_string_pool.h" +#include "common/type.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::_; +using ::testing::Test; + +class TypePoolTest : public Test { + public: + void SetUp() override { + arena_.emplace(); + string_pool_ = NewArenaStringPool(arena()); + type_pool_ = + NewTypePool(arena(), string_pool(), GetTestingDescriptorPool()); + } + + void TearDown() override { + type_pool_.reset(); + string_pool_.reset(); + arena_.reset(); + } + + absl::Nonnull arena() { return &*arena_; } + + absl::Nonnull string_pool() { return string_pool_.get(); } + + absl::Nonnull type_pool() { return type_pool_.get(); } + + private: + absl::optional arena_; + std::unique_ptr string_pool_; + std::unique_ptr type_pool_; +}; + +TEST_F(TypePoolTest, MakeStructType) { + EXPECT_EQ(type_pool()->MakeStructType("foo.Bar"), + common_internal::MakeBasicStructType("foo.Bar")); + EXPECT_TRUE( + type_pool() + ->MakeStructType("google.api.expr.test.v1.proto3.TestAllTypes") + .IsMessage()); + EXPECT_DEBUG_DEATH(static_cast(type_pool()->MakeStructType( + "google.protobuf.BoolValue")), + _); +} + +TEST_F(TypePoolTest, MakeFunctionType) { + EXPECT_EQ(type_pool()->MakeFunctionType(BoolType(), {IntType(), IntType()}), + FunctionType(arena(), BoolType(), {IntType(), IntType()})); +} + +TEST_F(TypePoolTest, MakeListType) { + EXPECT_EQ(type_pool()->MakeListType(DynType()), ListType()); + EXPECT_EQ(type_pool()->MakeListType(DynType()), JsonListType()); + EXPECT_EQ(type_pool()->MakeListType(StringType()), + ListType(arena(), StringType())); +} + +TEST_F(TypePoolTest, MakeMapType) { + EXPECT_EQ(type_pool()->MakeMapType(DynType(), DynType()), MapType()); + EXPECT_EQ(type_pool()->MakeMapType(StringType(), DynType()), JsonMapType()); + EXPECT_EQ(type_pool()->MakeMapType(StringType(), StringType()), + MapType(arena(), StringType(), StringType())); +} + +TEST_F(TypePoolTest, MakeOpaqueType) { + EXPECT_EQ(type_pool()->MakeOpaqueType("custom_type", {DynType(), DynType()}), + OpaqueType(arena(), "custom_type", {DynType(), DynType()})); +} + +TEST_F(TypePoolTest, MakeOptionalType) { + EXPECT_EQ(type_pool()->MakeOptionalType(DynType()), OptionalType()); + EXPECT_EQ(type_pool()->MakeOptionalType(StringType()), + OptionalType(arena(), StringType())); +} + +TEST_F(TypePoolTest, MakeTypeParamType) { + EXPECT_EQ(type_pool()->MakeTypeParamType("T"), TypeParamType("T")); +} + +TEST_F(TypePoolTest, MakeTypeType) { + EXPECT_EQ(type_pool()->MakeTypeType(BoolType()), + TypeType(arena(), BoolType())); +} + +} // namespace +} // namespace cel diff --git a/common/type_proto.cc b/common/type_proto.cc new file mode 100644 index 000000000..c6afe869c --- /dev/null +++ b/common/type_proto.cc @@ -0,0 +1,353 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto.h" + +#include + +#include "cel/expr/checked.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "common/type_pool.h" + +namespace cel { + +namespace { + +using TypeProto = ::cel::expr::Type; +using ListTypeProto = typename TypeProto::ListType; +using MapTypeProto = typename TypeProto::MapType; +using FunctionTypeProto = typename TypeProto::FunctionType; +using OpaqueTypeProto = typename TypeProto::AbstractType; +using PrimitiveTypeProto = typename TypeProto::PrimitiveType; +using WellKnownTypeProto = typename TypeProto::WellKnownType; + +struct TypeFromProtoConverter final { + explicit TypeFromProtoConverter(absl::Nonnull type_pool) + : type_pool(type_pool) {} + + absl::optional FromType(const TypeProto& proto) { + switch (proto.type_kind_case()) { + case TypeProto::TYPE_KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case TypeProto::kDyn: + return DynType(); + case TypeProto::kNull: + return NullType(); + case TypeProto::kPrimitive: + switch (proto.primitive()) { + case TypeProto::BOOL: + return BoolType(); + case TypeProto::INT64: + return IntType(); + case TypeProto::UINT64: + return UintType(); + case TypeProto::DOUBLE: + return DoubleType(); + case TypeProto::STRING: + return StringType(); + case TypeProto::BYTES: + return BytesType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected primitive type kind: ", proto.primitive())); + return absl::nullopt; + } + case TypeProto::kWrapper: + switch (proto.wrapper()) { + case TypeProto::BOOL: + return BoolWrapperType(); + case TypeProto::INT64: + return IntWrapperType(); + case TypeProto::UINT64: + return UintWrapperType(); + case TypeProto::DOUBLE: + return DoubleWrapperType(); + case TypeProto::STRING: + return StringWrapperType(); + case TypeProto::BYTES: + return BytesWrapperType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected wrapper type kind: ", proto.wrapper())); + return absl::nullopt; + } + case TypeProto::kWellKnown: + switch (proto.well_known()) { + case TypeProto::ANY: + return AnyType(); + case TypeProto::DURATION: + return DurationType(); + case TypeProto::TIMESTAMP: + return TimestampType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected well known type kind: ", proto.well_known())); + return absl::nullopt; + } + case TypeProto::kListType: { + auto elem = FromType(proto.list_type().elem_type()); + if (ABSL_PREDICT_FALSE(!elem.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeListType(*elem); + } + case TypeProto::kMapType: { + auto key = FromType(proto.map_type().key_type()); + if (ABSL_PREDICT_FALSE(!key.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + auto value = FromType(proto.map_type().value_type()); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeMapType(*key, *value); + } + case TypeProto::kFunction: { + auto result = FromType(proto.function().result_type()); + if (ABSL_PREDICT_FALSE(!result.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + absl::InlinedVector args; + args.reserve(static_cast(proto.function().arg_types().size())); + for (const auto& arg_proto : proto.function().arg_types()) { + auto arg = FromType(arg_proto); + if (ABSL_PREDICT_FALSE(!arg.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + args.push_back(*arg); + } + return type_pool->MakeFunctionType(*result, args); + } + case TypeProto::kMessageType: + if (ABSL_PREDICT_FALSE(proto.message_type().empty())) { + status = + absl::InvalidArgumentError("unexpected empty message type name"); + return absl::nullopt; + } + if (ABSL_PREDICT_FALSE(IsWellKnownMessageType(proto.message_type()))) { + status = absl::InvalidArgumentError( + absl::StrCat("well known type masquerading as message type: ", + proto.message_type())); + return absl::nullopt; + } + return type_pool->MakeStructType(proto.message_type()); + case TypeProto::kTypeParam: + if (ABSL_PREDICT_FALSE(proto.type_param().empty())) { + status = + absl::InvalidArgumentError("unexpected empty type param name"); + return absl::nullopt; + } + return type_pool->MakeTypeParamType(proto.type_param()); + case TypeProto::kType: { + auto type = FromType(proto.type()); + if (ABSL_PREDICT_FALSE(!type.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeTypeType(*type); + } + case TypeProto::kError: + return ErrorType(); + case TypeProto::kAbstractType: { + if (proto.abstract_type().name().empty()) { + status = + absl::InvalidArgumentError("unexpected empty opaque type name"); + return absl::nullopt; + } + absl::InlinedVector params; + params.reserve(static_cast( + proto.abstract_type().parameter_types().size())); + for (const auto& param_proto : + proto.abstract_type().parameter_types()) { + auto param = FromType(param_proto); + if (ABSL_PREDICT_FALSE(!param.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + params.push_back(*param); + } + return type_pool->MakeOpaqueType(proto.abstract_type().name(), params); + } + default: + status = absl::DataLossError(absl::StrCat("unexpected type kind case: ", + proto.type_kind_case())); + return absl::nullopt; + } + } + + absl::Nonnull const type_pool; + absl::Status status; +}; + +} // namespace + +absl::StatusOr TypeFromProto(absl::Nonnull type_pool, + const TypeProto& proto) { + TypeFromProtoConverter converter(type_pool); + auto type = converter.FromType(proto); + if (ABSL_PREDICT_FALSE(!type.has_value())) { + ABSL_DCHECK(!converter.status.ok()); + return converter.status; + } + return *type; +} + +namespace { + +struct TypeToProtoConverter final { + bool FromType(const Type& type, absl::Nonnull proto) { + switch (type.kind()) { + case TypeKind::kDyn: + proto->mutable_dyn(); + return true; + case TypeKind::kNull: + proto->set_null(google::protobuf::NULL_VALUE); + return true; + case TypeKind::kBool: + proto->set_primitive(TypeProto::BOOL); + return true; + case TypeKind::kInt: + proto->set_primitive(TypeProto::INT64); + return true; + case TypeKind::kUint: + proto->set_primitive(TypeProto::UINT64); + return true; + case TypeKind::kDouble: + proto->set_primitive(TypeProto::DOUBLE); + return true; + case TypeKind::kBytes: + proto->set_primitive(TypeProto::BYTES); + return true; + case TypeKind::kString: + proto->set_primitive(TypeProto::STRING); + return true; + case TypeKind::kBoolWrapper: + proto->set_wrapper(TypeProto::BOOL); + return true; + case TypeKind::kIntWrapper: + proto->set_wrapper(TypeProto::INT64); + return true; + case TypeKind::kUintWrapper: + proto->set_wrapper(TypeProto::UINT64); + return true; + case TypeKind::kDoubleWrapper: + proto->set_wrapper(TypeProto::DOUBLE); + return true; + case TypeKind::kBytesWrapper: + proto->set_wrapper(TypeProto::BYTES); + return true; + case TypeKind::kStringWrapper: + proto->set_wrapper(TypeProto::STRING); + return true; + case TypeKind::kAny: + proto->set_well_known(TypeProto::ANY); + return true; + case TypeKind::kDuration: + proto->set_well_known(TypeProto::DURATION); + return true; + case TypeKind::kTimestamp: + proto->set_well_known(TypeProto::TIMESTAMP); + return true; + case TypeKind::kList: + return FromType(static_cast(type).GetElement(), + proto->mutable_list_type()->mutable_elem_type()); + case TypeKind::kMap: + return FromType(static_cast(type).GetKey(), + proto->mutable_map_type()->mutable_key_type()) && + FromType(static_cast(type).GetValue(), + proto->mutable_map_type()->mutable_value_type()); + case TypeKind::kStruct: + proto->set_message_type(static_cast(type).name()); + return true; + case TypeKind::kOpaque: { + auto opaque_type = static_cast(type); + auto* opaque_type_proto = proto->mutable_abstract_type(); + opaque_type_proto->set_name(opaque_type.name()); + auto opaque_type_params = opaque_type.GetParameters(); + opaque_type_proto->mutable_parameter_types()->Reserve( + static_cast(opaque_type_params.size())); + for (const auto& param : opaque_type_params) { + if (ABSL_PREDICT_FALSE( + !FromType(param, opaque_type_proto->add_parameter_types()))) { + return false; + } + } + return true; + } + case TypeKind::kTypeParam: + proto->set_type_param(static_cast(type).name()); + return true; + case TypeKind::kType: + return FromType(static_cast(type).GetType(), + proto->mutable_type()); + case TypeKind::kFunction: { + auto function_type = static_cast(type); + auto* function_type_proto = proto->mutable_function(); + if (ABSL_PREDICT_FALSE( + !FromType(function_type.result(), + function_type_proto->mutable_result_type()))) { + return false; + } + auto function_type_args = function_type.args(); + function_type_proto->mutable_arg_types()->Reserve( + static_cast(function_type_args.size())); + for (const auto& arg : function_type_args) { + if (ABSL_PREDICT_FALSE( + !FromType(arg, function_type_proto->add_arg_types()))) { + return false; + } + } + return true; + } + case TypeKind::kError: + proto->mutable_error(); + return true; + default: + status = absl::DataLossError( + absl::StrCat("unexpected type kind: ", type.kind())); + return false; + } + } + + absl::Status status; +}; + +} // namespace + +absl::Status TypeToProto(const Type& type, absl::Nonnull proto) { + TypeToProtoConverter converter; + if (ABSL_PREDICT_FALSE(!converter.FromType(type, proto))) { + ABSL_DCHECK(!converter.status.ok()); + return converter.status; + } + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/type_proto.h b/common/type_proto.h new file mode 100644 index 000000000..8b32e64a2 --- /dev/null +++ b/common/type_proto.h @@ -0,0 +1,37 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ + +#include "cel/expr/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/type.h" +#include "common/type_pool.h" + +namespace cel { + +// TypeFromProto converts `cel::expr::Type` to `cel::Type`. +absl::StatusOr TypeFromProto(absl::Nonnull type_pool, + const cel::expr::Type& proto); + +// TypeToProto converts `cel::Type` to `cel::expr::Type`. +absl::Status TypeToProto(const Type& type, + absl::Nonnull proto); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ diff --git a/common/type_proto_test.cc b/common/type_proto_test.cc new file mode 100644 index 000000000..260de4213 --- /dev/null +++ b/common/type_proto_test.cc @@ -0,0 +1,388 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto.h" + +#include + +#include "cel/expr/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/types/optional.h" +#include "common/arena_string_pool.h" +#include "common/type.h" +#include "common/type_pool.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::cel::internal::test::EqualsProto; +using ::testing::Eq; +using ::testing::Test; + +using TypeProto = ::cel::expr::Type; + +class TypeProtoTest : public Test { + public: + void SetUp() override { + arena_.emplace(); + string_pool_ = NewArenaStringPool(arena()); + type_pool_ = + NewTypePool(arena(), string_pool(), GetTestingDescriptorPool()); + } + + void TearDown() override { + type_pool_.reset(); + string_pool_.reset(); + arena_.reset(); + } + + absl::Nonnull arena() { return &*arena_; } + + absl::Nonnull string_pool() { return string_pool_.get(); } + + absl::Nonnull type_pool() { return type_pool_.get(); } + + private: + absl::optional arena_; + std::unique_ptr string_pool_; + std::unique_ptr type_pool_; +}; + +TEST_F(TypeProtoTest, Dyn) { + TypeProto expected_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb(dyn: {})pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DynType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Null) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(null: NULL_VALUE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(NullType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Bool) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: BOOL)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BoolType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Int) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: INT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(IntType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Uint) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: UINT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(UintType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Double) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: DOUBLE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DoubleType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, String) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: STRING)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(StringType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Bytes) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: BYTES)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BytesType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BoolWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: BOOL)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BoolWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, IntWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: INT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(IntWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, UintWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: UINT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(UintWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, DoubleWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: DOUBLE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DoubleWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, StringWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: STRING)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(StringWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BytesWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: BYTES)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BytesWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Any) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(well_known: ANY)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(AnyType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Duration) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(well_known: DURATION)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DurationType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Timestamp) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(well_known: TIMESTAMP)pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TimestampType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, List) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(list_type: { elem_type: { primitive: BOOL } })pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(ListType(arena(), BoolType()))); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Map) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(map_type: { + key_type: { primitive: INT64 } + value_type: { primitive: STRING } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(MapType(arena(), IntType(), StringType()))); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Function) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(function: { + result_type: { primitive: INT64 } + arg_types { primitive: STRING } + arg_types { primitive: INT64 } + arg_types { primitive: UINT64 } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(FunctionType(arena(), IntType(), + {StringType(), IntType(), UintType()}))); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Struct) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(message_type: "google.protobuf.Empty")pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT( + got, Eq(common_internal::MakeBasicStructType("google.protobuf.Empty"))); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BadStruct) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(message_type: "")pb", + &expected_proto)); + EXPECT_THAT(TypeFromProto(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(TypeProtoTest, TypeParam) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type_param: "T")pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TypeParamType("T"))); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BadTypeParam) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type_param: "")pb", + &expected_proto)); + EXPECT_THAT(TypeFromProto(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(TypeProtoTest, Type) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type: { dyn: {} })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TypeType(arena(), DynType()))); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Error) { + TypeProto expected_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb(error: {})pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(ErrorType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, Opaque) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(abstract_type: { + name: "optional_type" + parameter_types { primitive: STRING } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, TypeFromProto(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(OptionalType(arena(), StringType()))); + TypeProto got_proto; + EXPECT_THAT(TypeToProto(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoTest, BadOpaque) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(abstract_type: { name: "" })pb", &expected_proto)); + EXPECT_THAT(TypeFromProto(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel diff --git a/common/type_proto_v1alpha1.cc b/common/type_proto_v1alpha1.cc new file mode 100644 index 000000000..39bec96e1 --- /dev/null +++ b/common/type_proto_v1alpha1.cc @@ -0,0 +1,354 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto_v1alpha1.h" + +#include + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "common/type_pool.h" + +namespace cel { + +namespace { + +using TypeProto = ::google::api::expr::v1alpha1::Type; +using ListTypeProto = typename TypeProto::ListType; +using MapTypeProto = typename TypeProto::MapType; +using FunctionTypeProto = typename TypeProto::FunctionType; +using OpaqueTypeProto = typename TypeProto::AbstractType; +using PrimitiveTypeProto = typename TypeProto::PrimitiveType; +using WellKnownTypeProto = typename TypeProto::WellKnownType; + +struct TypeFromProtoConverter final { + explicit TypeFromProtoConverter(absl::Nonnull type_pool) + : type_pool(type_pool) {} + + absl::optional FromType(const TypeProto& proto) { + switch (proto.type_kind_case()) { + case TypeProto::TYPE_KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case TypeProto::kDyn: + return DynType(); + case TypeProto::kNull: + return NullType(); + case TypeProto::kPrimitive: + switch (proto.primitive()) { + case TypeProto::BOOL: + return BoolType(); + case TypeProto::INT64: + return IntType(); + case TypeProto::UINT64: + return UintType(); + case TypeProto::DOUBLE: + return DoubleType(); + case TypeProto::STRING: + return StringType(); + case TypeProto::BYTES: + return BytesType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected primitive type kind: ", proto.primitive())); + return absl::nullopt; + } + case TypeProto::kWrapper: + switch (proto.wrapper()) { + case TypeProto::BOOL: + return BoolWrapperType(); + case TypeProto::INT64: + return IntWrapperType(); + case TypeProto::UINT64: + return UintWrapperType(); + case TypeProto::DOUBLE: + return DoubleWrapperType(); + case TypeProto::STRING: + return StringWrapperType(); + case TypeProto::BYTES: + return BytesWrapperType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected wrapper type kind: ", proto.wrapper())); + return absl::nullopt; + } + case TypeProto::kWellKnown: + switch (proto.well_known()) { + case TypeProto::ANY: + return AnyType(); + case TypeProto::DURATION: + return DurationType(); + case TypeProto::TIMESTAMP: + return TimestampType(); + default: + status = absl::DataLossError(absl::StrCat( + "unexpected well known type kind: ", proto.well_known())); + return absl::nullopt; + } + case TypeProto::kListType: { + auto elem = FromType(proto.list_type().elem_type()); + if (ABSL_PREDICT_FALSE(!elem.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeListType(*elem); + } + case TypeProto::kMapType: { + auto key = FromType(proto.map_type().key_type()); + if (ABSL_PREDICT_FALSE(!key.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + auto value = FromType(proto.map_type().value_type()); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeMapType(*key, *value); + } + case TypeProto::kFunction: { + auto result = FromType(proto.function().result_type()); + if (ABSL_PREDICT_FALSE(!result.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + absl::InlinedVector args; + args.reserve(static_cast(proto.function().arg_types().size())); + for (const auto& arg_proto : proto.function().arg_types()) { + auto arg = FromType(arg_proto); + if (ABSL_PREDICT_FALSE(!arg.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + args.push_back(*arg); + } + return type_pool->MakeFunctionType(*result, args); + } + case TypeProto::kMessageType: + if (ABSL_PREDICT_FALSE(proto.message_type().empty())) { + status = + absl::InvalidArgumentError("unexpected empty message type name"); + return absl::nullopt; + } + if (ABSL_PREDICT_FALSE(IsWellKnownMessageType(proto.message_type()))) { + status = absl::InvalidArgumentError( + absl::StrCat("well known type masquerading as message type: ", + proto.message_type())); + return absl::nullopt; + } + return type_pool->MakeStructType(proto.message_type()); + case TypeProto::kTypeParam: + if (ABSL_PREDICT_FALSE(proto.type_param().empty())) { + status = + absl::InvalidArgumentError("unexpected empty type param name"); + return absl::nullopt; + } + return type_pool->MakeTypeParamType(proto.type_param()); + case TypeProto::kType: { + auto type = FromType(proto.type()); + if (ABSL_PREDICT_FALSE(!type.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + return type_pool->MakeTypeType(*type); + } + case TypeProto::kError: + return ErrorType(); + case TypeProto::kAbstractType: { + if (proto.abstract_type().name().empty()) { + status = + absl::InvalidArgumentError("unexpected empty opaque type name"); + return absl::nullopt; + } + absl::InlinedVector params; + params.reserve(static_cast( + proto.abstract_type().parameter_types().size())); + for (const auto& param_proto : + proto.abstract_type().parameter_types()) { + auto param = FromType(param_proto); + if (ABSL_PREDICT_FALSE(!param.has_value())) { + ABSL_DCHECK(!status.ok()); + return absl::nullopt; + } + params.push_back(*param); + } + return type_pool->MakeOpaqueType(proto.abstract_type().name(), params); + } + default: + status = absl::DataLossError(absl::StrCat("unexpected type kind case: ", + proto.type_kind_case())); + return absl::nullopt; + } + } + + absl::Nonnull const type_pool; + absl::Status status; +}; + +} // namespace + +absl::StatusOr TypeFromProtoV1Alpha1(absl::Nonnull type_pool, + const TypeProto& proto) { + TypeFromProtoConverter converter(type_pool); + auto type = converter.FromType(proto); + if (ABSL_PREDICT_FALSE(!type.has_value())) { + ABSL_DCHECK(!converter.status.ok()); + return converter.status; + } + return *type; +} + +namespace { + +struct TypeToProtoConverter final { + bool FromType(const Type& type, absl::Nonnull proto) { + switch (type.kind()) { + case TypeKind::kDyn: + proto->mutable_dyn(); + return true; + case TypeKind::kNull: + proto->set_null(google::protobuf::NULL_VALUE); + return true; + case TypeKind::kBool: + proto->set_primitive(TypeProto::BOOL); + return true; + case TypeKind::kInt: + proto->set_primitive(TypeProto::INT64); + return true; + case TypeKind::kUint: + proto->set_primitive(TypeProto::UINT64); + return true; + case TypeKind::kDouble: + proto->set_primitive(TypeProto::DOUBLE); + return true; + case TypeKind::kBytes: + proto->set_primitive(TypeProto::BYTES); + return true; + case TypeKind::kString: + proto->set_primitive(TypeProto::STRING); + return true; + case TypeKind::kBoolWrapper: + proto->set_wrapper(TypeProto::BOOL); + return true; + case TypeKind::kIntWrapper: + proto->set_wrapper(TypeProto::INT64); + return true; + case TypeKind::kUintWrapper: + proto->set_wrapper(TypeProto::UINT64); + return true; + case TypeKind::kDoubleWrapper: + proto->set_wrapper(TypeProto::DOUBLE); + return true; + case TypeKind::kBytesWrapper: + proto->set_wrapper(TypeProto::BYTES); + return true; + case TypeKind::kStringWrapper: + proto->set_wrapper(TypeProto::STRING); + return true; + case TypeKind::kAny: + proto->set_well_known(TypeProto::ANY); + return true; + case TypeKind::kDuration: + proto->set_well_known(TypeProto::DURATION); + return true; + case TypeKind::kTimestamp: + proto->set_well_known(TypeProto::TIMESTAMP); + return true; + case TypeKind::kList: + return FromType(static_cast(type).GetElement(), + proto->mutable_list_type()->mutable_elem_type()); + case TypeKind::kMap: + return FromType(static_cast(type).GetKey(), + proto->mutable_map_type()->mutable_key_type()) && + FromType(static_cast(type).GetValue(), + proto->mutable_map_type()->mutable_value_type()); + case TypeKind::kStruct: + proto->set_message_type(static_cast(type).name()); + return true; + case TypeKind::kOpaque: { + auto opaque_type = static_cast(type); + auto* opaque_type_proto = proto->mutable_abstract_type(); + opaque_type_proto->set_name(opaque_type.name()); + auto opaque_type_params = opaque_type.GetParameters(); + opaque_type_proto->mutable_parameter_types()->Reserve( + static_cast(opaque_type_params.size())); + for (const auto& param : opaque_type_params) { + if (ABSL_PREDICT_FALSE( + !FromType(param, opaque_type_proto->add_parameter_types()))) { + return false; + } + } + return true; + } + case TypeKind::kTypeParam: + proto->set_type_param(static_cast(type).name()); + return true; + case TypeKind::kType: + return FromType(static_cast(type).GetType(), + proto->mutable_type()); + case TypeKind::kFunction: { + auto function_type = static_cast(type); + auto* function_type_proto = proto->mutable_function(); + if (ABSL_PREDICT_FALSE( + !FromType(function_type.result(), + function_type_proto->mutable_result_type()))) { + return false; + } + auto function_type_args = function_type.args(); + function_type_proto->mutable_arg_types()->Reserve( + static_cast(function_type_args.size())); + for (const auto& arg : function_type_args) { + if (ABSL_PREDICT_FALSE( + !FromType(arg, function_type_proto->add_arg_types()))) { + return false; + } + } + return true; + } + case TypeKind::kError: + proto->mutable_error(); + return true; + default: + status = absl::DataLossError( + absl::StrCat("unexpected type kind: ", type.kind())); + return false; + } + } + + absl::Status status; +}; + +} // namespace + +absl::Status TypeToProtoV1Alpha1(const Type& type, + absl::Nonnull proto) { + TypeToProtoConverter converter; + if (ABSL_PREDICT_FALSE(!converter.FromType(type, proto))) { + ABSL_DCHECK(!converter.status.ok()); + return converter.status; + } + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/type_proto_v1alpha1.h b/common/type_proto_v1alpha1.h new file mode 100644 index 000000000..b9cc92600 --- /dev/null +++ b/common/type_proto_v1alpha1.h @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_V1ALPHA1_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_V1ALPHA1_H_ + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/type.h" +#include "common/type_pool.h" + +namespace cel { + +// TypeFromProtoV1Alpha1 converts `google::api::expr::v1alpha1::Type` to +// `cel::Type`. +absl::StatusOr TypeFromProtoV1Alpha1( + absl::Nonnull type_pool, + const google::api::expr::v1alpha1::Type& proto); + +// TypeToProtoV1Alpha1 converts `cel::Type` to +// `google::api::expr::v1alpha1::Type`. +absl::Status TypeToProtoV1Alpha1( + const Type& type, absl::Nonnull proto); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_V1ALPHA1_H_ diff --git a/common/type_proto_v1alpha1_test.cc b/common/type_proto_v1alpha1_test.cc new file mode 100644 index 000000000..958edcf4c --- /dev/null +++ b/common/type_proto_v1alpha1_test.cc @@ -0,0 +1,413 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto_v1alpha1.h" + +#include + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/types/optional.h" +#include "common/arena_string_pool.h" +#include "common/type.h" +#include "common/type_pool.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::cel::internal::test::EqualsProto; +using ::testing::Eq; +using ::testing::Test; + +using TypeProto = ::google::api::expr::v1alpha1::Type; + +class TypeProtoV1Alpha1Test : public Test { + public: + void SetUp() override { + arena_.emplace(); + string_pool_ = NewArenaStringPool(arena()); + type_pool_ = + NewTypePool(arena(), string_pool(), GetTestingDescriptorPool()); + } + + void TearDown() override { + type_pool_.reset(); + string_pool_.reset(); + arena_.reset(); + } + + absl::Nonnull arena() { return &*arena_; } + + absl::Nonnull string_pool() { return string_pool_.get(); } + + absl::Nonnull type_pool() { return type_pool_.get(); } + + private: + absl::optional arena_; + std::unique_ptr string_pool_; + std::unique_ptr type_pool_; +}; + +TEST_F(TypeProtoV1Alpha1Test, Dyn) { + TypeProto expected_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb(dyn: {})pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DynType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Null) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(null: NULL_VALUE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(NullType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Bool) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: BOOL)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BoolType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Int) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: INT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(IntType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Uint) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: UINT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(UintType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Double) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: DOUBLE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DoubleType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, String) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: STRING)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(StringType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Bytes) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(primitive: BYTES)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BytesType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BoolWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: BOOL)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BoolWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, IntWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: INT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(IntWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, UintWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: UINT64)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(UintWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, DoubleWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: DOUBLE)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DoubleWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, StringWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: STRING)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(StringWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BytesWrapper) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(wrapper: BYTES)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(BytesWrapperType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Any) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(well_known: ANY)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(AnyType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Duration) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(well_known: DURATION)pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(DurationType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Timestamp) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(well_known: TIMESTAMP)pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TimestampType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, List) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(list_type: { elem_type: { primitive: BOOL } })pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(ListType(arena(), BoolType()))); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Map) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(map_type: { + key_type: { primitive: INT64 } + value_type: { primitive: STRING } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(MapType(arena(), IntType(), StringType()))); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Function) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(function: { + result_type: { primitive: INT64 } + arg_types { primitive: STRING } + arg_types { primitive: INT64 } + arg_types { primitive: UINT64 } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(FunctionType(arena(), IntType(), + {StringType(), IntType(), UintType()}))); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Struct) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(message_type: "google.protobuf.Empty")pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT( + got, Eq(common_internal::MakeBasicStructType("google.protobuf.Empty"))); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BadStruct) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(message_type: "")pb", + &expected_proto)); + EXPECT_THAT(TypeFromProtoV1Alpha1(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(TypeProtoV1Alpha1Test, TypeParam) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type_param: "T")pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TypeParamType("T"))); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BadTypeParam) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type_param: "")pb", + &expected_proto)); + EXPECT_THAT(TypeFromProtoV1Alpha1(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(TypeProtoV1Alpha1Test, Type) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb(type: { dyn: {} })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(TypeType(arena(), DynType()))); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Error) { + TypeProto expected_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb(error: {})pb", &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(ErrorType())); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, Opaque) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(abstract_type: { + name: "optional_type" + parameter_types { primitive: STRING } + })pb", + &expected_proto)); + ASSERT_OK_AND_ASSIGN(auto got, + TypeFromProtoV1Alpha1(type_pool(), expected_proto)); + EXPECT_THAT(got, Eq(OptionalType(arena(), StringType()))); + TypeProto got_proto; + EXPECT_THAT(TypeToProtoV1Alpha1(got, &got_proto), IsOk()); + EXPECT_THAT(got_proto, EqualsProto(expected_proto)); +} + +TEST_F(TypeProtoV1Alpha1Test, BadOpaque) { + TypeProto expected_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb(abstract_type: { name: "" })pb", &expected_proto)); + EXPECT_THAT(TypeFromProtoV1Alpha1(type_pool(), expected_proto), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel diff --git a/common/types/map_type_pool.h b/common/types/map_type_pool.h index d86ddb2e9..29b21f154 100644 --- a/common/types/map_type_pool.h +++ b/common/types/map_type_pool.h @@ -18,7 +18,6 @@ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_POOL_H_ #include -#include #include #include "absl/base/nullability.h" @@ -41,15 +40,14 @@ class MapTypePool final { MapType InternMapType(const Type& key, const Type& value); private: - using MapTypeTuple = std::tuple, - std::reference_wrapper>; + using MapTypeTuple = std::tuple; static MapTypeTuple AsTuple(const MapType& map_type) { - return AsTuple(map_type.key(), map_type.value()); + return AsTuple(map_type.GetKey(), map_type.GetValue()); } static MapTypeTuple AsTuple(const Type& key, const Type& value) { - return MapTypeTuple{std::cref(key), std::cref(value)}; + return MapTypeTuple{key, value}; } struct Hasher { diff --git a/common/types/opaque_type_pool.h b/common/types/opaque_type_pool.h index 60b2b3c39..fe079febd 100644 --- a/common/types/opaque_type_pool.h +++ b/common/types/opaque_type_pool.h @@ -45,7 +45,7 @@ class OpaqueTypePool final { absl::Span parameters); private: - using OpaqueTypeTuple = std::tuple>; + using OpaqueTypeTuple = std::tuple; static OpaqueTypeTuple AsTuple(const OpaqueType& opaque_type) { return AsTuple(opaque_type.name(), opaque_type.GetParameters()); diff --git a/common/types/type_pool.cc b/common/types/type_pool.cc deleted file mode 100644 index fdbae2418..000000000 --- a/common/types/type_pool.cc +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "common/types/type_pool.h" - -#include "absl/base/optimization.h" -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "common/type.h" - -namespace cel::common_internal { - -StructType TypePool::MakeStructType(absl::string_view name) { - ABSL_DCHECK(!IsWellKnownMessageType(name)) << name; - if (ABSL_PREDICT_FALSE(name.empty())) { - return StructType(); - } - if (const auto* descriptor = descriptors_->FindMessageTypeByName(name); - descriptor != nullptr) { - return MessageType(descriptor); - } - return MakeBasicStructType(InternString(name)); -} - -FunctionType TypePool::MakeFunctionType(const Type& result, - absl::Span args) { - absl::MutexLock lock(&functions_mutex_); - return functions_.InternFunctionType(result, args); -} - -ListType TypePool::MakeListType(const Type& element) { - if (element.IsDyn()) { - return ListType(); - } - absl::MutexLock lock(&lists_mutex_); - return lists_.InternListType(element); -} - -MapType TypePool::MakeMapType(const Type& key, const Type& value) { - if (key.IsDyn() && value.IsDyn()) { - return MapType(); - } - if (key.IsString() && value.IsDyn()) { - return JsonMapType(); - } - absl::MutexLock lock(&maps_mutex_); - return maps_.InternMapType(key, value); -} - -OpaqueType TypePool::MakeOpaqueType(absl::string_view name, - absl::Span parameters) { - if (name == OptionalType::kName) { - if (parameters.size() == 1 && parameters.front().IsDyn()) { - return OptionalType(); - } - name = OptionalType::kName; - } else { - name = InternString(name); - } - absl::MutexLock lock(&opaques_mutex_); - return opaques_.InternOpaqueType(name, parameters); -} - -OptionalType TypePool::MakeOptionalType(const Type& parameter) { - return static_cast( - MakeOpaqueType(OptionalType::kName, absl::MakeConstSpan(¶meter, 1))); -} - -TypeParamType TypePool::MakeTypeParamType(absl::string_view name) { - return TypeParamType(InternString(name)); -} - -TypeType TypePool::MakeTypeType(const Type& type) { - absl::MutexLock lock(&types_mutex_); - return types_.InternTypeType(type); -} - -absl::string_view TypePool::InternString(absl::string_view string) { - absl::MutexLock lock(&strings_mutex_); - return strings_.InternString(string); -} - -} // namespace cel::common_internal diff --git a/common/types/type_pool.h b/common/types/type_pool.h deleted file mode 100644 index 37f3ff662..000000000 --- a/common/types/type_pool.h +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_ - -#include "absl/base/attributes.h" -#include "absl/base/nullability.h" -#include "absl/base/thread_annotations.h" -#include "absl/log/die_if_null.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "common/type.h" -#include "common/types/function_type_pool.h" -#include "common/types/list_type_pool.h" -#include "common/types/map_type_pool.h" -#include "common/types/opaque_type_pool.h" -#include "common/types/type_type_pool.h" -#include "internal/string_pool.h" -#include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" - -namespace cel::common_internal { - -// `TypePool` is a thread safe interning factory for complex types. All types -// are allocated using the provided `google::protobuf::Arena`. -class TypePool final { - public: - TypePool(absl::Nonnull descriptors - ABSL_ATTRIBUTE_LIFETIME_BOUND, - absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) - : descriptors_(ABSL_DIE_IF_NULL(descriptors)), // Crash OK - arena_(ABSL_DIE_IF_NULL(arena)), // Crash OK - strings_(arena_), - functions_(arena_), - lists_(arena_), - maps_(arena_), - opaques_(arena_), - types_(arena_) {} - - TypePool(const TypePool&) = delete; - TypePool(TypePool&&) = delete; - TypePool& operator=(const TypePool&) = delete; - TypePool& operator=(TypePool&&) = delete; - - StructType MakeStructType(absl::string_view name); - - FunctionType MakeFunctionType(const Type& result, - absl::Span args); - - ListType MakeListType(const Type& element); - - MapType MakeMapType(const Type& key, const Type& value); - - OpaqueType MakeOpaqueType(absl::string_view name, - absl::Span parameters); - - OptionalType MakeOptionalType(const Type& parameter); - - TypeParamType MakeTypeParamType(absl::string_view name); - - TypeType MakeTypeType(const Type& type); - - private: - absl::string_view InternString(absl::string_view string); - - absl::Nonnull const descriptors_; - absl::Nonnull const arena_; - absl::Mutex strings_mutex_; - internal::StringPool strings_ ABSL_GUARDED_BY(strings_mutex_); - absl::Mutex functions_mutex_; - FunctionTypePool functions_ ABSL_GUARDED_BY(functions_mutex_); - absl::Mutex lists_mutex_; - ListTypePool lists_ ABSL_GUARDED_BY(lists_mutex_); - absl::Mutex maps_mutex_; - MapTypePool maps_ ABSL_GUARDED_BY(maps_mutex_); - absl::Mutex opaques_mutex_; - OpaqueTypePool opaques_ ABSL_GUARDED_BY(opaques_mutex_); - absl::Mutex types_mutex_; - TypeTypePool types_ ABSL_GUARDED_BY(types_mutex_); -}; - -} // namespace cel::common_internal - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_ diff --git a/common/types/type_pool_test.cc b/common/types/type_pool_test.cc deleted file mode 100644 index 2f36121be..000000000 --- a/common/types/type_pool_test.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "common/types/type_pool.h" - -#include "common/type.h" -#include "internal/testing.h" -#include "internal/testing_descriptor_pool.h" -#include "google/protobuf/arena.h" - -namespace cel::common_internal { -namespace { - -using ::cel::internal::GetTestingDescriptorPool; -using ::testing::_; - -TEST(TypePool, MakeStructType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeStructType("foo.Bar"), - MakeBasicStructType("foo.Bar")); - EXPECT_TRUE( - type_pool.MakeStructType("google.api.expr.test.v1.proto3.TestAllTypes") - .IsMessage()); - EXPECT_DEBUG_DEATH( - static_cast(type_pool.MakeStructType("google.protobuf.BoolValue")), - _); -} - -TEST(TypePool, MakeFunctionType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeFunctionType(BoolType(), {IntType(), IntType()}), - FunctionType(&arena, BoolType(), {IntType(), IntType()})); -} - -TEST(TypePool, MakeListType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeListType(DynType()), ListType()); - EXPECT_EQ(type_pool.MakeListType(DynType()), JsonListType()); - EXPECT_EQ(type_pool.MakeListType(StringType()), - ListType(&arena, StringType())); -} - -TEST(TypePool, MakeMapType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeMapType(DynType(), DynType()), MapType()); - EXPECT_EQ(type_pool.MakeMapType(StringType(), DynType()), JsonMapType()); - EXPECT_EQ(type_pool.MakeMapType(StringType(), StringType()), - MapType(&arena, StringType(), StringType())); -} - -TEST(TypePool, MakeOpaqueType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeOpaqueType("custom_type", {DynType(), DynType()}), - OpaqueType(&arena, "custom_type", {DynType(), DynType()})); -} - -TEST(TypePool, MakeOptionalType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeOptionalType(DynType()), OptionalType()); - EXPECT_EQ(type_pool.MakeOptionalType(StringType()), - OptionalType(&arena, StringType())); -} - -TEST(TypePool, MakeTypeParamType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeTypeParamType("T"), TypeParamType("T")); -} - -TEST(TypePool, MakeTypeType) { - google::protobuf::Arena arena; - TypePool type_pool(GetTestingDescriptorPool(), &arena); - EXPECT_EQ(type_pool.MakeTypeType(BoolType()), TypeType(&arena, BoolType())); -} - -} // namespace -} // namespace cel::common_internal