From 8a772df19bd57cbf3dd5e4151eb295887b3ac6da Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 14 Mar 2024 23:23:06 -0700 Subject: [PATCH] Implement unwrapping utilities (cel::Value -> protobuf type). PiperOrigin-RevId: 616024509 --- extensions/protobuf/BUILD | 46 +- extensions/protobuf/internal/BUILD | 57 + extensions/protobuf/internal/message.cc | 2383 ++++++++++++++++++ extensions/protobuf/internal/message.h | 164 ++ extensions/protobuf/internal/message_test.cc | 43 + extensions/protobuf/value.h | 186 ++ extensions/protobuf/value_end_to_end_test.cc | 943 +++++++ extensions/protobuf/value_test.cc | 422 +++- 8 files changed, 4234 insertions(+), 10 deletions(-) create mode 100644 extensions/protobuf/internal/message.cc create mode 100644 extensions/protobuf/internal/message.h create mode 100644 extensions/protobuf/internal/message_test.cc create mode 100644 extensions/protobuf/value_end_to_end_test.cc diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index d9040ccf1..1dccd0a16 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -182,15 +182,25 @@ cc_test( cc_library( name = "value", - srcs = ["value.cc"], - hdrs = ["value.h"], + srcs = [ + "value.cc", + ], + hdrs = [ + "value.h", + ], deps = [ "//common:casting", "//common:value", + "//extensions/protobuf/internal:duration", "//extensions/protobuf/internal:enum", + "//extensions/protobuf/internal:message", + "//extensions/protobuf/internal:struct", + "//extensions/protobuf/internal:timestamp", + "//extensions/protobuf/internal:wrappers", "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -202,13 +212,45 @@ cc_test( name = "value_test", srcs = ["value_test.cc"], deps = [ + ":memory_manager", ":value", "//common:casting", + "//common:memory", "//common:value", + "//common:value_kind", "//common:value_testing", + "//internal:proto_matchers", "//internal:testing", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) + +cc_test( + name = "value_end_to_end_test", + srcs = ["value_end_to_end_test.cc"], + deps = [ + ":runtime_adapter", + ":value", + "//common:memory", + "//common:value", + "//common:value_testing", + "//internal:testing", + "//parser", + "//runtime", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/extensions/protobuf/internal/BUILD b/extensions/protobuf/internal/BUILD index 41c5bb409..6b235c3f8 100644 --- a/extensions/protobuf/internal/BUILD +++ b/extensions/protobuf/internal/BUILD @@ -119,6 +119,63 @@ cc_library( ], ) +cc_library( + name = "message", + srcs = ["message.cc"], + hdrs = ["message.h"], + deps = [ + ":any", + ":duration", + ":json", + ":map_reflection", + ":qualify", + ":struct", + ":timestamp", + ":wrappers", + "//base:attributes", + "//base/internal:message_wrapper", + "//common:any", + "//common:casting", + "//common:json", + "//common:memory", + "//common:native_type", + "//common:type", + "//common:value", + "//common:value_kind", + "//common/internal:reference_count", + "//extensions/protobuf:json", + "//extensions/protobuf:memory_manager", + "//internal:align", + "//internal:casts", + "//internal:new", + "//internal:status_macros", + "//runtime:runtime_options", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "message_test", + srcs = ["message_test.cc"], + deps = [ + ":message", + "//internal:testing", + "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", + ], +) + cc_library( name = "reflection", srcs = ["reflection.cc"], diff --git a/extensions/protobuf/internal/message.cc b/extensions/protobuf/internal/message.cc new file mode 100644 index 000000000..d4697896b --- /dev/null +++ b/extensions/protobuf/internal/message.cc @@ -0,0 +1,2383 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/internal/message.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "base/attribute.h" +#include "base/internal/message_wrapper.h" +#include "common/any.h" +#include "common/casting.h" +#include "common/internal/reference_count.h" +#include "common/json.h" +#include "common/memory.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_manager.h" +#include "extensions/protobuf/internal/any.h" +#include "extensions/protobuf/internal/duration.h" +#include "extensions/protobuf/internal/json.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "extensions/protobuf/internal/qualify.h" +#include "extensions/protobuf/internal/struct.h" +#include "extensions/protobuf/internal/timestamp.h" +#include "extensions/protobuf/internal/wrappers.h" +#include "extensions/protobuf/json.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/align.h" +#include "internal/casts.h" +#include "internal/new.h" +#include "internal/status_macros.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/message_differencer.h" + +namespace cel { + +// Forward declare Value interfaces for implementing type traits. +namespace extensions::protobuf_internal { +namespace { +class PooledParsedProtoStructValueInterface; +class ParsedProtoListValueInterface; +class ParsedProtoMapValueInterface; +} // namespace +} // namespace extensions::protobuf_internal + +template <> +struct NativeTypeTraits< + extensions::protobuf_internal::PooledParsedProtoStructValueInterface> { + static bool SkipDestructor(const extensions::protobuf_internal:: + PooledParsedProtoStructValueInterface&) { + return true; + } +}; + +template <> +struct NativeTypeTraits< + extensions::protobuf_internal::ParsedProtoListValueInterface> { + static bool SkipDestructor( + const extensions::protobuf_internal::ParsedProtoListValueInterface&) { + return true; + } +}; + +template <> +struct NativeTypeTraits< + extensions::protobuf_internal::ParsedProtoMapValueInterface> { + static bool SkipDestructor( + const extensions::protobuf_internal::ParsedProtoMapValueInterface&) { + return true; + } +}; + +namespace extensions::protobuf_internal { + +namespace { + +struct DefaultArenaDeleter { + template + void operator()(T* message) const { + if (arena == nullptr) { + delete message; + } + } + + google::protobuf::Arena* arena = nullptr; +}; + +template +using ArenaUniquePtr = std::unique_ptr; + +absl::StatusOr>> NewProtoMessage( + absl::Nonnull pool, + absl::Nonnull factory, absl::string_view name, + google::protobuf::Arena* arena) { + const auto* desc = pool->FindMessageTypeByName(name); + if (ABSL_PREDICT_FALSE(desc == nullptr)) { + return absl::NotFoundError( + absl::StrCat("descriptor missing: `", name, "`")); + } + const auto* proto = factory->GetPrototype(desc); + if (ABSL_PREDICT_FALSE(proto == nullptr)) { + return absl::NotFoundError(absl::StrCat("prototype missing: `", name, "`")); + } + return ArenaUniquePtr(proto->New(arena), + DefaultArenaDeleter{arena}); +} + +absl::Status ProtoMapKeyTypeMismatch(google::protobuf::FieldDescriptor::CppType expected, + google::protobuf::FieldDescriptor::CppType got) { + if (ABSL_PREDICT_FALSE(got != expected)) { + return absl::InternalError( + absl::StrCat("protocol buffer map key type mismatch: ", + google::protobuf::FieldDescriptor::CppTypeName(expected), " vs ", + google::protobuf::FieldDescriptor::CppTypeName(got))); + } + return absl::OkStatus(); +} + +template +class AliasingValue : public T { + public: + template + explicit AliasingValue(Shared alias, Args&&... args) + : T(std::forward(args)...), alias_(std::move(alias)) {} + + private: + Shared alias_; +}; + +// ----------------------------------------------------------------------------- +// cel::Value -> google::protobuf::MapKey + +using ProtoMapKeyFromValueConverter = absl::Status (*)(ValueView, + google::protobuf::MapKey&); + +absl::Status ProtoBoolMapKeyFromValueConverter(ValueView value, + google::protobuf::MapKey& key) { + if (auto bool_value = As(value); bool_value) { + key.SetBoolValue(bool_value->NativeValue()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "bool").NativeValue(); +} + +absl::Status ProtoInt32MapKeyFromValueConverter(ValueView value, + google::protobuf::MapKey& key) { + if (auto int_value = As(value); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return absl::OutOfRangeError("int64 to int32_t overflow"); + } + key.SetInt32Value(static_cast(int_value->NativeValue())); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "int").NativeValue(); +} + +absl::Status ProtoInt64MapKeyFromValueConverter(ValueView value, + google::protobuf::MapKey& key) { + if (auto int_value = As(value); int_value) { + key.SetInt64Value(int_value->NativeValue()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "int").NativeValue(); +} + +absl::Status ProtoUInt32MapKeyFromValueConverter(ValueView value, + google::protobuf::MapKey& key) { + if (auto uint_value = As(value); uint_value) { + if (uint_value->NativeValue() > std::numeric_limits::max()) { + return absl::OutOfRangeError("uint64 to uint32_t overflow"); + } + key.SetUInt32Value(static_cast(uint_value->NativeValue())); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "uint").NativeValue(); +} + +absl::Status ProtoUInt64MapKeyFromValueConverter(ValueView value, + google::protobuf::MapKey& key) { + if (auto uint_value = As(value); uint_value) { + key.SetUInt64Value(uint_value->NativeValue()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "uint").NativeValue(); +} + +absl::Status ProtoStringMapKeyFromValueConverter(ValueView value, + google::protobuf::MapKey& key) { + if (auto string_value = As(value); string_value) { + key.SetStringValue(string_value->NativeString()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "string").NativeValue(); +} + +absl::StatusOr GetProtoMapKeyFromValueConverter( + google::protobuf::FieldDescriptor::CppType cpp_type) { + switch (cpp_type) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return ProtoBoolMapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return ProtoInt32MapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return ProtoInt64MapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return ProtoUInt32MapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return ProtoUInt64MapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + return ProtoStringMapKeyFromValueConverter; + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected protocol buffer map key type: ", + google::protobuf::FieldDescriptor::CppTypeName(cpp_type))); + } +} + +// ----------------------------------------------------------------------------- +// google::protobuf::MapKey -> cel::Value + +using ProtoMapKeyToValueConverter = + absl::StatusOr (*)(const google::protobuf::MapKey&, ValueManager&, Value&); + +absl::StatusOr ProtoBoolMapKeyToValueConverter( + const google::protobuf::MapKey& key, ValueManager&, Value&) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_BOOL, key.type())); + return BoolValueView{key.GetBoolValue()}; +} + +absl::StatusOr ProtoInt32MapKeyToValueConverter( + const google::protobuf::MapKey& key, ValueManager&, Value&) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_INT32, key.type())); + return IntValueView{key.GetInt32Value()}; +} + +absl::StatusOr ProtoInt64MapKeyToValueConverter( + const google::protobuf::MapKey& key, ValueManager&, Value&) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_INT64, key.type())); + return IntValueView{key.GetInt64Value()}; +} + +absl::StatusOr ProtoUInt32MapKeyToValueConverter( + const google::protobuf::MapKey& key, ValueManager&, Value&) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_UINT32, key.type())); + return UintValueView{key.GetUInt32Value()}; +} + +absl::StatusOr ProtoUInt64MapKeyToValueConverter( + const google::protobuf::MapKey& key, ValueManager&, Value&) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_UINT64, key.type())); + return UintValueView{key.GetUInt64Value()}; +} + +absl::StatusOr ProtoStringMapKeyToValueConverter( + const google::protobuf::MapKey& key, ValueManager& value_manager, Value&) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_STRING, key.type())); + return StringValueView{key.GetStringValue()}; +} + +absl::StatusOr GetProtoMapKeyToValueConverter( + absl::Nonnull field) { + ABSL_DCHECK(field->is_map()); + const auto* key_field = field->message_type()->map_key(); + switch (key_field->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return ProtoBoolMapKeyToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return ProtoInt32MapKeyToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return ProtoInt64MapKeyToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return ProtoUInt32MapKeyToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return ProtoUInt64MapKeyToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + return ProtoStringMapKeyToValueConverter; + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected protocol buffer map key type: ", + google::protobuf::FieldDescriptor::CppTypeName(key_field->cpp_type()))); + } +} + +// ----------------------------------------------------------------------------- +// cel::Value -> google::protobuf::MapValueRef + +using ProtoMapValueFromValueConverter = + absl::Status (*)(ValueView, absl::Nonnull, + google::protobuf::MapValueRef&); + +absl::Status ProtoBoolMapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_BOOL, value_ref.type())); + if (auto bool_value = As(value); bool_value) { + value_ref.SetBoolValue(bool_value->NativeValue()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "bool").NativeValue(); +} + +absl::Status ProtoInt32MapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_INT32, value_ref.type())); + if (auto int_value = As(value); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return absl::OutOfRangeError("int64 to int32_t overflow"); + } + value_ref.SetInt32Value(static_cast(int_value->NativeValue())); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "int").NativeValue(); +} + +absl::Status ProtoInt64MapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_INT64, value_ref.type())); + if (auto int_value = As(value); int_value) { + value_ref.SetInt64Value(int_value->NativeValue()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "int").NativeValue(); +} + +absl::Status ProtoUInt32MapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_UINT32, value_ref.type())); + if (auto uint_value = As(value); uint_value) { + if (uint_value->NativeValue() > std::numeric_limits::max()) { + return absl::OutOfRangeError("uint64 to uint32_t overflow"); + } + value_ref.SetUInt32Value(static_cast(uint_value->NativeValue())); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "uint").NativeValue(); +} + +absl::Status ProtoUInt64MapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_UINT64, value_ref.type())); + if (auto uint_value = As(value); uint_value) { + value_ref.SetUInt64Value(uint_value->NativeValue()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "uint").NativeValue(); +} + +absl::Status ProtoFloatMapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_FLOAT, value_ref.type())); + if (auto double_value = As(value); double_value) { + value_ref.SetFloatValue(double_value->NativeValue()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "double").NativeValue(); +} + +absl::Status ProtoDoubleMapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE, value_ref.type())); + if (auto double_value = As(value); double_value) { + value_ref.SetDoubleValue(double_value->NativeValue()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "double").NativeValue(); +} + +absl::Status ProtoBytesMapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_STRING, value_ref.type())); + if (auto bytes_value = As(value); bytes_value) { + value_ref.SetStringValue(bytes_value->NativeString()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "bytes").NativeValue(); +} + +absl::Status ProtoStringMapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_STRING, value_ref.type())); + if (auto string_value = As(value); string_value) { + value_ref.SetStringValue(string_value->NativeString()); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "string").NativeValue(); +} + +absl::Status ProtoNullMapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_ENUM, value_ref.type())); + if (InstanceOf(value) || InstanceOf(value)) { + value_ref.SetEnumValue(0); + return absl::OkStatus(); + } + return TypeConversionError(value.GetTypeName(), "google.protobuf.NullValue") + .NativeValue(); +} + +absl::Status ProtoEnumMapValueFromValueConverter( + ValueView value, absl::Nonnull field, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_ENUM, value_ref.type())); + if (auto int_value = As(value); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return absl::OutOfRangeError("int64 to int32_t overflow"); + } + const auto* enum_type = field->enum_type(); + if (const auto* enum_value = enum_type->FindValueByNumber( + static_cast(int_value->NativeValue())); + enum_value != nullptr) { + value_ref.SetEnumValue(enum_value->number()); + return absl::OkStatus(); + } + return absl::NotFoundError(absl::StrCat("enum `", enum_type->full_name(), + "` has no value with number ", + int_value->NativeValue())); + } + return TypeConversionError(value.GetTypeName(), "enum").NativeValue(); +} + +absl::Status ProtoMessageMapValueFromValueConverter( + ValueView value, absl::Nonnull, + google::protobuf::MapValueRef& value_ref) { + CEL_RETURN_IF_ERROR(ProtoMapKeyTypeMismatch( + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE, value_ref.type())); + return ProtoMessageFromValueImpl(value, value_ref.MutableMessageValue()); +} + +absl::StatusOr +GetProtoMapValueFromValueConverter( + absl::Nonnull field) { + ABSL_DCHECK(field->is_map()); + const auto* value_field = field->message_type()->map_value(); + switch (value_field->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return ProtoBoolMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return ProtoInt32MapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return ProtoInt64MapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return ProtoUInt32MapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return ProtoUInt64MapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return ProtoFloatMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return ProtoDoubleMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + if (value_field->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { + return ProtoBytesMapValueFromValueConverter; + } + return ProtoStringMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: + if (value_field->enum_type()->full_name() == + "google.protobuf.NullValue") { + return ProtoNullMapValueFromValueConverter; + } + return ProtoEnumMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + return ProtoMessageMapValueFromValueConverter; + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected protocol buffer map value type: ", + google::protobuf::FieldDescriptor::CppTypeName(value_field->cpp_type()))); + } +} + +// ----------------------------------------------------------------------------- +// google::protobuf::MapValueConstRef -> cel::Value + +using ProtoMapValueToValueConverter = absl::StatusOr (*)( + SharedView, absl::Nonnull, + const google::protobuf::MapValueConstRef&, ValueManager&, Value&); + +absl::StatusOr ProtoBoolMapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return BoolValueView{value_ref.GetBoolValue()}; +} + +absl::StatusOr ProtoInt32MapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return IntValueView{value_ref.GetInt32Value()}; +} + +absl::StatusOr ProtoInt64MapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + + return IntValueView{value_ref.GetInt64Value()}; +} + +absl::StatusOr ProtoUInt32MapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return UintValueView{value_ref.GetUInt32Value()}; +} + +absl::StatusOr ProtoUInt64MapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return UintValueView{value_ref.GetUInt64Value()}; +} + +absl::StatusOr ProtoFloatMapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return DoubleValueView{value_ref.GetFloatValue()}; +} + +absl::StatusOr ProtoDoubleMapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return DoubleValueView{value_ref.GetDoubleValue()}; +} + +absl::StatusOr ProtoBytesMapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return BytesValueView{value_ref.GetStringValue()}; +} + +absl::StatusOr ProtoStringMapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return StringValueView{value_ref.GetStringValue()}; +} + +absl::StatusOr ProtoNullMapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return NullValueView{}; +} + +absl::StatusOr ProtoEnumMapValueToValueConverter( + SharedView, absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + return IntValueView{value_ref.GetEnumValue()}; +} + +absl::StatusOr ProtoMessageMapValueToValueConverter( + SharedView alias, + absl::Nonnull field, + const google::protobuf::MapValueConstRef& value_ref, ValueManager& value_manager, + Value& value) { + // Caller validates that the field type is correct. + CEL_ASSIGN_OR_RETURN( + value, ProtoMessageToValueImpl(value_manager, Shared(alias), + &value_ref.GetMessageValue())); + return value; +} + +absl::StatusOr GetProtoMapValueToValueConverter( + absl::Nonnull field) { + ABSL_DCHECK(field->is_map()); + const auto* value_field = field->message_type()->map_value(); + switch (value_field->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return ProtoBoolMapValueToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return ProtoInt32MapValueToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return ProtoInt64MapValueToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return ProtoUInt32MapValueToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return ProtoUInt64MapValueToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return ProtoFloatMapValueToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return ProtoDoubleMapValueToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + if (value_field->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { + return ProtoBytesMapValueToValueConverter; + } + return ProtoStringMapValueToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: + if (value_field->enum_type()->full_name() == + "google.protobuf.NullValue") { + return ProtoNullMapValueToValueConverter; + } + return ProtoEnumMapValueToValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + return ProtoMessageMapValueToValueConverter; + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected protocol buffer map value type: ", + google::protobuf::FieldDescriptor::CppTypeName(value_field->cpp_type()))); + } +} + +// ----------------------------------------------------------------------------- +// repeated field -> Value + +using ProtoRepeatedFieldToValueAccessor = absl::StatusOr (*)( + SharedView, absl::Nonnull, + absl::Nonnull, + absl::Nonnull, int, ValueManager&, Value&); + +absl::StatusOr ProtoBoolRepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager&, Value&) { + return BoolValueView{reflection->GetRepeatedBool(*message, field, index)}; +} + +absl::StatusOr ProtoInt32RepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager&, Value&) { + return IntValueView{reflection->GetRepeatedInt32(*message, field, index)}; +} + +absl::StatusOr ProtoInt64RepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager&, Value&) { + return IntValueView{reflection->GetRepeatedInt64(*message, field, index)}; +} + +absl::StatusOr ProtoUInt32RepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager&, Value&) { + return UintValueView{reflection->GetRepeatedUInt32(*message, field, index)}; +} + +absl::StatusOr ProtoUInt64RepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager&, Value&) { + return UintValueView{reflection->GetRepeatedUInt64(*message, field, index)}; +} + +absl::StatusOr ProtoFloatRepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager&, Value&) { + return DoubleValueView{reflection->GetRepeatedFloat(*message, field, index)}; +} + +absl::StatusOr ProtoDoubleRepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager&, Value&) { + return DoubleValueView{reflection->GetRepeatedDouble(*message, field, index)}; +} + +absl::StatusOr ProtoBytesRepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager& value_manager, Value& value) { + value = BytesValue{reflection->GetRepeatedString(*message, field, index)}; + return value; +} + +absl::StatusOr ProtoStringRepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager& value_manager, Value& value) { + value = value_manager.CreateUncheckedStringValue( + reflection->GetRepeatedString(*message, field, index)); + return value; +} + +absl::StatusOr ProtoNullRepeatedFieldToValueAccessor( + SharedView, absl::Nonnull, + absl::Nonnull, + absl::Nonnull, int, ValueManager&, Value&) { + return NullValueView{}; +} + +absl::StatusOr ProtoEnumRepeatedFieldToValueAccessor( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager& value_manager, Value& value) { + return IntValueView{reflection->GetRepeatedEnumValue(*message, field, index)}; +} + +absl::StatusOr ProtoMessageRepeatedFieldToValueAccessor( + SharedView aliased, + absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, int index, + ValueManager& value_manager, Value& value) { + const auto& field_value = + reflection->GetRepeatedMessage(*message, field, index); + CEL_ASSIGN_OR_RETURN( + value, ProtoMessageToValueImpl(value_manager, Shared(aliased), + &field_value)); + return value; +} + +absl::StatusOr +GetProtoRepeatedFieldToValueAccessor( + absl::Nonnull field) { + ABSL_DCHECK(!field->is_map()); + ABSL_DCHECK(field->is_repeated()); + switch (field->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return ProtoBoolRepeatedFieldToValueAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return ProtoInt32RepeatedFieldToValueAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return ProtoInt64RepeatedFieldToValueAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return ProtoUInt32RepeatedFieldToValueAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return ProtoUInt64RepeatedFieldToValueAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return ProtoFloatRepeatedFieldToValueAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return ProtoDoubleRepeatedFieldToValueAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + if (field->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { + return ProtoBytesRepeatedFieldToValueAccessor; + } + return ProtoStringRepeatedFieldToValueAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return ProtoNullRepeatedFieldToValueAccessor; + } + return ProtoEnumRepeatedFieldToValueAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + return ProtoMessageRepeatedFieldToValueAccessor; + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected protocol buffer repeated field type: ", + google::protobuf::FieldDescriptor::CppTypeName(field->cpp_type()))); + } +} + +// ----------------------------------------------------------------------------- +// field -> Value + +absl::StatusOr ProtoBoolFieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, ValueManager&, + Value&) { + return BoolValueView{reflection->GetBool(*message, field)}; +} + +absl::StatusOr ProtoInt32FieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, ValueManager&, + Value&) { + return IntValueView{reflection->GetInt32(*message, field)}; +} + +absl::StatusOr ProtoInt64FieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, ValueManager&, + Value&) { + return IntValueView{reflection->GetInt64(*message, field)}; +} + +absl::StatusOr ProtoUInt32FieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, ValueManager&, + Value&) { + return UintValueView{reflection->GetUInt32(*message, field)}; +} + +absl::StatusOr ProtoUInt64FieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, ValueManager&, + Value&) { + return UintValueView{reflection->GetUInt64(*message, field)}; +} + +absl::StatusOr ProtoFloatFieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, ValueManager&, + Value&) { + return DoubleValueView{reflection->GetFloat(*message, field)}; +} + +absl::StatusOr ProtoDoubleFieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, ValueManager&, + Value&) { + return DoubleValueView{reflection->GetDouble(*message, field)}; +} + +absl::StatusOr ProtoBytesFieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, + ValueManager& value_manager, Value& value) { + value = BytesValue{reflection->GetString(*message, field)}; + return value; +} + +absl::StatusOr ProtoStringFieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, + ValueManager& value_manager, Value& value) { + value = StringValue{reflection->GetString(*message, field)}; + return value; +} + +absl::StatusOr ProtoNullFieldToValue( + SharedView, absl::Nonnull, + absl::Nonnull, + absl::Nonnull, ValueManager&, Value&) { + return NullValueView{}; +} + +absl::StatusOr ProtoEnumFieldToValue( + SharedView, absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, ValueManager&, + Value&) { + return IntValueView{reflection->GetEnumValue(*message, field)}; +} + +bool IsWrapperType(absl::Nonnull descriptor) { + switch (descriptor->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return true; + default: + return false; + } +} + +absl::StatusOr ProtoMessageFieldToValue( + SharedView aliased, + absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, + ValueManager& value_manager, Value& value, + ProtoWrapperTypeOptions wrapper_type_options) { + if (wrapper_type_options == ProtoWrapperTypeOptions::kUnsetNull && + IsWrapperType(field->message_type()) && + !reflection->HasField(*message, field)) { + return NullValueView{}; + } + const auto& field_value = reflection->GetMessage(*message, field); + CEL_ASSIGN_OR_RETURN( + value, ProtoMessageToValueImpl(value_manager, Shared(aliased), + &field_value)); + return value; +} + +absl::StatusOr ProtoMapFieldToValue( + SharedView aliased, + absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, + ValueManager& value_manager, Value& value); + +absl::StatusOr ProtoRepeatedFieldToValue( + SharedView aliased, + absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, + ValueManager& value_manager, Value& value); + +absl::StatusOr ProtoFieldToValue( + SharedView aliased, + absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, + ValueManager& value_manager, Value& value, + ProtoWrapperTypeOptions wrapper_type_options) { + if (field->is_map()) { + return ProtoMapFieldToValue(aliased, message, reflection, field, + value_manager, value); + } + if (field->is_repeated()) { + return ProtoRepeatedFieldToValue(aliased, message, reflection, field, + value_manager, value); + } + switch (field->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return ProtoBoolFieldToValue(aliased, message, reflection, field, + value_manager, value); + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return ProtoInt32FieldToValue(aliased, message, reflection, field, + value_manager, value); + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return ProtoInt64FieldToValue(aliased, message, reflection, field, + value_manager, value); + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return ProtoUInt32FieldToValue(aliased, message, reflection, field, + value_manager, value); + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return ProtoUInt64FieldToValue(aliased, message, reflection, field, + value_manager, value); + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return ProtoFloatFieldToValue(aliased, message, reflection, field, + value_manager, value); + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return ProtoDoubleFieldToValue(aliased, message, reflection, field, + value_manager, value); + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + if (field->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { + return ProtoBytesFieldToValue(aliased, message, reflection, field, + value_manager, value); + } + return ProtoStringFieldToValue(aliased, message, reflection, field, + value_manager, value); + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return ProtoNullFieldToValue(aliased, message, reflection, field, + value_manager, value); + } + return ProtoEnumFieldToValue(aliased, message, reflection, field, + value_manager, value); + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + return ProtoMessageFieldToValue(aliased, message, reflection, field, + value_manager, value, + wrapper_type_options); + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected protocol buffer repeated field type: ", + google::protobuf::FieldDescriptor::CppTypeName(field->cpp_type()))); + } +} + +absl::Status ProtoMessageCopyUsingSerialization( + google::protobuf::MessageLite* to, const google::protobuf::MessageLite* from) { + ABSL_DCHECK_EQ(to->GetTypeName(), from->GetTypeName()); + absl::Cord serialized; + if (!from->SerializePartialToCord(&serialized)) { + return absl::UnknownError( + absl::StrCat("failed to serialize `", from->GetTypeName(), "`")); + } + if (!to->ParsePartialFromCord(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parse `", to->GetTypeName(), "`")); + } + return absl::OkStatus(); +} + +bool IsValidFieldNumber(int64_t number) { + return ABSL_PREDICT_TRUE(number > 0 || + number < std::numeric_limits::max()); +} + +class ParsedProtoListElementIterator final : public ValueIterator { + public: + ParsedProtoListElementIterator( + Shared aliasing, const google::protobuf::Message& message, + absl::Nonnull field, + ProtoRepeatedFieldToValueAccessor field_to_value_accessor) + : aliasing_(std::move(aliasing)), + message_(message), + field_(field), + field_to_value_accessor_(field_to_value_accessor), + size_(GetReflectionOrDie(message_)->FieldSize(message, field_)) {} + + bool HasNext() override { return index_ < size_; } + + absl::StatusOr Next(ValueManager& value_manager, + Value& scratch) override { + CEL_ASSIGN_OR_RETURN(auto element, + field_to_value_accessor_( + aliasing_, &message_, GetReflectionOrDie(message_), + field_, index_, value_manager, scratch)); + ++index_; + return element; + } + + private: + Shared aliasing_; + const google::protobuf::Message& message_; + absl::Nonnull field_; + ProtoRepeatedFieldToValueAccessor field_to_value_accessor_; + const int size_; + int index_ = 0; +}; + +class ParsedProtoListValueInterface + : public ParsedListValueInterface, + public EnableSharedFromThis { + public: + ParsedProtoListValueInterface( + const google::protobuf::Message& message, + absl::Nonnull field, + ProtoRepeatedFieldToValueAccessor field_to_value_accessor) + : message_(message), + field_(field), + field_to_value_accessor_(field_to_value_accessor) {} + + std::string DebugString() const final { + google::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.SetUseUtf8StringEscaping(true); + std::string buffer; + buffer.push_back('['); + std::string output; + const int count = GetReflectionOrDie(message_)->FieldSize(message_, field_); + for (int index = 0; index < count; ++index) { + if (index != 0) { + buffer.append(", "); + } + printer.PrintFieldValueToString(message_, field_, index, &output); + buffer.append(output); + } + buffer.push_back(']'); + return buffer; + } + + absl::StatusOr ConvertToJsonArray( + AnyToJsonConverter& converter) const final { + return ProtoRepeatedFieldToJsonArray( + converter, GetReflectionOrDie(message_), message_, field_); + } + + size_t Size() const final { + return static_cast( + GetReflectionOrDie(message_)->FieldSize(message_, field_)); + } + + absl::Status ForEach(ValueManager& value_manager, + ForEachCallback callback) const final { + const auto size = Size(); + Value element_scratch; + for (size_t index = 0; index < size; ++index) { + CEL_ASSIGN_OR_RETURN( + auto element, + field_to_value_accessor_( + shared_from_this(), &message_, GetReflectionOrDie(message_), + field_, static_cast(index), value_manager, element_scratch)); + CEL_ASSIGN_OR_RETURN(auto ok, callback(element)); + if (!ok) { + break; + } + } + return absl::OkStatus(); + } + + absl::Status ForEach(ValueManager& value_manager, + ForEachWithIndexCallback callback) const final { + const auto size = Size(); + Value element_scratch; + for (size_t index = 0; index < size; ++index) { + CEL_ASSIGN_OR_RETURN( + auto element, + field_to_value_accessor_( + shared_from_this(), &message_, GetReflectionOrDie(message_), + field_, static_cast(index), value_manager, element_scratch)); + CEL_ASSIGN_OR_RETURN(auto ok, callback(index, element)); + if (!ok) { + break; + } + } + return absl::OkStatus(); + } + + absl::StatusOr> NewIterator( + ValueManager& value_manager) const final { + return std::make_unique( + shared_from_this(), message_, field_, field_to_value_accessor_); + } + + private: + absl::StatusOr GetImpl(ValueManager& value_manager, size_t index, + Value& scratch) const final { + return field_to_value_accessor_( + shared_from_this(), &message_, GetReflectionOrDie(message_), field_, + static_cast(index), value_manager, scratch); + } + + NativeTypeId GetNativeTypeId() const final { + return NativeTypeId::For(); + } + + const google::protobuf::Message& message_; + absl::Nonnull field_; + ProtoRepeatedFieldToValueAccessor field_to_value_accessor_; +}; + +class ParsedProtoMapKeyIterator final : public ValueIterator { + public: + ParsedProtoMapKeyIterator(const google::protobuf::Message& message, + absl::Nonnull field, + ProtoMapKeyToValueConverter map_key_to_value) + : begin_(MapBegin(*GetReflectionOrDie(message), message, *field)), + end_(MapEnd(*GetReflectionOrDie(message), message, *field)), + map_key_to_value_(map_key_to_value) {} + + bool HasNext() override { return begin_ != end_; } + + absl::StatusOr Next(ValueManager& value_manager, + Value& scratch) override { + CEL_ASSIGN_OR_RETURN( + auto key, map_key_to_value_(begin_.GetKey(), value_manager, scratch)); + ++begin_; + return key; + } + + private: + google::protobuf::MapIterator begin_; + google::protobuf::MapIterator end_; + ProtoMapKeyToValueConverter map_key_to_value_; +}; + +class ParsedProtoMapValueInterface + : public ParsedMapValueInterface, + public EnableSharedFromThis { + public: + ParsedProtoMapValueInterface( + const google::protobuf::Message& message, + absl::Nonnull field, + ProtoMapKeyFromValueConverter map_key_from_value, + ProtoMapKeyToValueConverter map_key_to_value, + ProtoMapValueToValueConverter map_value_to_value) + : message_(message), + field_(field), + map_key_from_value_(map_key_from_value), + map_key_to_value_(map_key_to_value), + map_value_to_value_(map_value_to_value) {} + + std::string DebugString() const final { + google::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.SetUseUtf8StringEscaping(true); + std::string buffer; + buffer.push_back('{'); + std::string output; + const auto* reflection = GetReflectionOrDie(message_); + const auto* map_key = field_->message_type()->map_key(); + const auto* map_value = field_->message_type()->map_value(); + const int count = reflection->FieldSize(message_, field_); + for (int index = 0; index < count; ++index) { + if (index != 0) { + buffer.append(", "); + } + const auto& entry = + reflection->GetRepeatedMessage(message_, field_, index); + printer.PrintFieldValueToString(entry, map_key, -1, &output); + buffer.append(output); + buffer.append(": "); + printer.PrintFieldValueToString(entry, map_value, -1, &output); + buffer.append(output); + } + buffer.push_back('}'); + return buffer; + } + + absl::StatusOr ConvertToJsonObject( + AnyToJsonConverter& converter) const final { + return ProtoMapFieldToJsonObject(converter, GetReflectionOrDie(message_), + message_, field_); + } + + size_t Size() const final { + return static_cast(protobuf_internal::MapSize( + *GetReflectionOrDie(message_), message_, *field_)); + } + + absl::StatusOr ListKeys(ValueManager& value_manager, + ListValue& scratch) const final { + CEL_ASSIGN_OR_RETURN(auto builder, + value_manager.NewListValueBuilder(ListTypeView{})); + builder->Reserve(Size()); + auto begin = MapBegin(*GetReflectionOrDie(message_), message_, *field_); + auto end = MapEnd(*GetReflectionOrDie(message_), message_, *field_); + Value key_scratch; + while (begin != end) { + CEL_ASSIGN_OR_RETURN( + auto key, + map_key_to_value_(begin.GetKey(), value_manager, key_scratch)); + CEL_RETURN_IF_ERROR(builder->Add(Value{key})); + ++begin; + } + return std::move(*builder).Build(); + } + + absl::Status ForEach(ValueManager& value_manager, + ForEachCallback callback) const final { + auto begin = MapBegin(*GetReflectionOrDie(message_), message_, *field_); + auto end = MapEnd(*GetReflectionOrDie(message_), message_, *field_); + Value key_scratch; + Value value_scratch; + while (begin != end) { + CEL_ASSIGN_OR_RETURN( + auto key, + map_key_to_value_(begin.GetKey(), value_manager, key_scratch)); + CEL_ASSIGN_OR_RETURN( + auto value, + map_value_to_value_(shared_from_this(), field_, begin.GetValueRef(), + value_manager, value_scratch)); + CEL_ASSIGN_OR_RETURN(auto ok, callback(key, value)); + if (!ok) { + break; + } + ++begin; + } + return absl::OkStatus(); + } + + absl::StatusOr> NewIterator( + ValueManager& value_manager) const final { + return std::make_unique(message_, field_, + map_key_to_value_); + } + + private: + absl::StatusOr> FindImpl( + ValueManager& value_manager, ValueView key, Value& scratch) const final { + google::protobuf::MapKey map_key; + CEL_RETURN_IF_ERROR(map_key_from_value_(key, map_key)); + google::protobuf::MapValueConstRef map_value; + if (!LookupMapValue(*GetReflectionOrDie(message_), message_, *field_, + map_key, &map_value)) { + return absl::nullopt; + } + CEL_ASSIGN_OR_RETURN( + auto value, map_value_to_value_(shared_from_this(), + field_->message_type()->map_value(), + map_value, value_manager, scratch)); + return value; + } + + absl::StatusOr HasImpl(ValueManager& value_manager, + ValueView key) const final { + google::protobuf::MapKey map_key; + CEL_RETURN_IF_ERROR(map_key_from_value_(key, map_key)); + return ContainsMapKey(*GetReflectionOrDie(message_), message_, *field_, + map_key); + } + + NativeTypeId GetNativeTypeId() const final { + return NativeTypeId::For(); + } + + const google::protobuf::Message& message_; + absl::Nonnull field_; + ProtoMapKeyFromValueConverter map_key_from_value_; + ProtoMapKeyToValueConverter map_key_to_value_; + ProtoMapValueToValueConverter map_value_to_value_; +}; + +class ParsedProtoQualifyState final : public ProtoQualifyState { + public: + ParsedProtoQualifyState(const google::protobuf::Message* message, + const google::protobuf::Descriptor* descriptor, + const google::protobuf::Reflection* reflection, + Shared alias, ValueManager& value_manager) + : ProtoQualifyState(message, descriptor, reflection), + alias_(std::move(alias)), + value_manager_(value_manager) {} + + absl::optional& result() { return result_; } + + private: + void SetResultFromError(absl::Status status, + cel::MemoryManagerRef memory_manager) override { + result_ = ErrorValue{std::move(status)}; + } + + void SetResultFromBool(bool value) override { result_ = BoolValue{value}; } + + absl::Status SetResultFromField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager) override { + Value scratch; + CEL_ASSIGN_OR_RETURN( + auto result, + ProtoFieldToValue(alias_, message, message->GetReflection(), field, + value_manager_, scratch, unboxing_option)); + result_ = Value{result}; + return absl::OkStatus(); + } + + absl::Status SetResultFromRepeatedField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, + int index, cel::MemoryManagerRef memory_manager) override { + CEL_ASSIGN_OR_RETURN(auto accessor, + GetProtoRepeatedFieldToValueAccessor(field)); + Value scratch; + CEL_ASSIGN_OR_RETURN(auto result, + (*accessor)(alias_, message, message->GetReflection(), + field, index, value_manager_, scratch)); + result_ = Value{result}; + return absl::OkStatus(); + } + + absl::Status SetResultFromMapField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, + const google::protobuf::MapValueConstRef& value, + cel::MemoryManagerRef memory_manager) override { + CEL_ASSIGN_OR_RETURN(auto converter, + GetProtoMapValueToValueConverter(field)); + Value scratch; + CEL_ASSIGN_OR_RETURN(auto result, (*converter)(alias_, field, value, + value_manager_, scratch)); + result_ = Value{result}; + return absl::OkStatus(); + } + + Shared alias_; + ValueManager& value_manager_; + absl::optional result_; +}; + +class ParsedProtoStructValueInterface; + +const ParsedProtoStructValueInterface* AsParsedProtoStructValue( + ParsedStructValueView value); + +const ParsedProtoStructValueInterface* AsParsedProtoStructValue( + ValueView value) { + if (auto parsed_struct_value = As(value); + parsed_struct_value) { + return AsParsedProtoStructValue(*parsed_struct_value); + } + return nullptr; +} + +class ParsedProtoStructValueInterface + : public ParsedStructValueInterface, + public EnableSharedFromThis { + public: + absl::string_view GetTypeName() const final { + return message().GetDescriptor()->full_name(); + } + + std::string DebugString() const final { return message().DebugString(); } + + // `GetSerializedSize` determines the serialized byte size that would result + // from serialization, without performing the serialization. If this value + // does not support serialization, `FAILED_PRECONDITION` is returned. + absl::StatusOr GetSerializedSize(AnyToJsonConverter&) const final { + return message().ByteSizeLong(); + } + + // `SerializeTo` serializes this value and appends it to `value`. If this + // value does not support serialization, `FAILED_PRECONDITION` is returned. + absl::Status SerializeTo(AnyToJsonConverter&, absl::Cord& value) const final { + if (!message().SerializePartialToCord(&value)) { + return absl::InternalError( + absl::StrCat("failed to serialize ", GetTypeName())); + } + return absl::OkStatus(); + } + + absl::StatusOr GetTypeUrl(absl::string_view prefix) const final { + return MakeTypeUrlWithPrefix(prefix, GetTypeName()); + } + + absl::StatusOr ConvertToJson( + AnyToJsonConverter& value_manager) const final { + return ProtoMessageToJson(value_manager, message()); + } + + bool IsZeroValue() const final { return message().ByteSizeLong() == 0; } + + absl::StatusOr GetFieldByName( + ValueManager& value_manager, absl::string_view name, Value& scratch, + ProtoWrapperTypeOptions unboxing_options) const final { + const auto* desc = message().GetDescriptor(); + const auto* field_desc = desc->FindFieldByName(name); + if (ABSL_PREDICT_FALSE(field_desc == nullptr)) { + scratch = NoSuchFieldError(name); + return scratch; + } + return GetField(value_manager, field_desc, scratch, unboxing_options); + } + + absl::StatusOr GetFieldByNumber( + ValueManager& value_manager, int64_t number, Value& scratch, + ProtoWrapperTypeOptions unboxing_options) const final { + if (!IsValidFieldNumber(number)) { + scratch = NoSuchFieldError(absl::StrCat(number)); + return scratch; + } + const auto* desc = message().GetDescriptor(); + const auto* field_desc = desc->FindFieldByNumber(static_cast(number)); + if (ABSL_PREDICT_FALSE(field_desc == nullptr)) { + scratch = NoSuchFieldError(absl::StrCat(number)); + return scratch; + } + return GetField(value_manager, field_desc, scratch, unboxing_options); + } + + absl::StatusOr HasFieldByName(absl::string_view name) const final { + const auto* desc = message().GetDescriptor(); + const auto* field_desc = desc->FindFieldByName(name); + if (ABSL_PREDICT_FALSE(field_desc == nullptr)) { + return NoSuchFieldError(name).NativeValue(); + } + return HasField(field_desc); + } + + absl::StatusOr HasFieldByNumber(int64_t number) const final { + if (!IsValidFieldNumber(number)) { + return NoSuchFieldError(absl::StrCat(number)).NativeValue(); + } + const auto* desc = message().GetDescriptor(); + const auto* field_desc = desc->FindFieldByNumber(static_cast(number)); + if (ABSL_PREDICT_FALSE(field_desc == nullptr)) { + return NoSuchFieldError(absl::StrCat(number)).NativeValue(); + } + return HasField(field_desc); + } + + absl::Status ForEachField(ValueManager& value_manager, + ForEachFieldCallback callback) const final { + std::vector fields; + const auto* reflection = message().GetReflection(); + reflection->ListFields(message(), &fields); + Value value_scratch; + for (const auto* field : fields) { + CEL_ASSIGN_OR_RETURN( + auto value, + ProtoFieldToValue(shared_from_this(), &message(), reflection, field, + value_manager, value_scratch, + ProtoWrapperTypeOptions::kUnsetProtoDefault)); + CEL_ASSIGN_OR_RETURN(auto ok, callback(field->name(), value)); + if (!ok) { + break; + } + } + return absl::OkStatus(); + } + + absl::StatusOr> Qualify( + ValueManager& value_manager, absl::Span qualifiers, + bool presence_test, Value& scratch) const final { + if (ABSL_PREDICT_FALSE(qualifiers.empty())) { + return absl::InvalidArgumentError("invalid select qualifier path."); + } + auto memory_manager = value_manager.GetMemoryManager(); + ParsedProtoQualifyState qualify_state(&message(), message().GetDescriptor(), + message().GetReflection(), + shared_from_this(), value_manager); + for (int i = 0; i < qualifiers.size() - 1; i++) { + const auto& qualifier = qualifiers[i]; + CEL_RETURN_IF_ERROR( + qualify_state.ApplySelectQualifier(qualifier, memory_manager)); + if (qualify_state.result().has_value()) { + scratch = std::move(qualify_state.result()).value(); + return std::pair{ValueView{scratch}, + scratch.Is() ? -1 : i + 1}; + } + } + const auto& last_qualifier = qualifiers.back(); + if (presence_test) { + CEL_RETURN_IF_ERROR( + qualify_state.ApplyLastQualifierHas(last_qualifier, memory_manager)); + } else { + CEL_RETURN_IF_ERROR( + qualify_state.ApplyLastQualifierGet(last_qualifier, memory_manager)); + } + scratch = std::move(qualify_state.result()).value(); + return std::pair{ValueView{scratch}, -1}; + } + + virtual const google::protobuf::Message& message() const = 0; + + protected: + Type GetTypeImpl(TypeManager& type_manager) const final { + return type_manager.CreateStructType(message().GetTypeName()); + } + + private: + absl::StatusOr EqualImpl(ValueManager& value_manager, + ParsedStructValueView other, + Value& scratch) const final { + if (const auto* parsed_proto_struct_value = AsParsedProtoStructValue(other); + parsed_proto_struct_value) { + const auto& lhs_message = message(); + const auto& rhs_message = parsed_proto_struct_value->message(); + if (lhs_message.GetDescriptor() == rhs_message.GetDescriptor()) { + return BoolValueView{ + google::protobuf::util::MessageDifferencer::Equals(lhs_message, rhs_message)}; + } + } + return ParsedStructValueInterface::EqualImpl(value_manager, other, scratch); + } + + NativeTypeId GetNativeTypeId() const final { + return NativeTypeId::For(); + } + + absl::StatusOr HasField( + absl::Nonnull field_desc) const { + const auto* reflect = message().GetReflection(); + if (field_desc->is_map() || field_desc->is_repeated()) { + return reflect->FieldSize(message(), field_desc) > 0; + } + return reflect->HasField(message(), field_desc); + } + + absl::StatusOr GetField( + ValueManager& value_manager, + absl::Nonnull field_desc, Value& scratch, + ProtoWrapperTypeOptions unboxing_options) const { + return ProtoFieldToValue(shared_from_this(), &message(), + message().GetReflection(), field_desc, + value_manager, scratch, unboxing_options); + } +}; + +const ParsedProtoStructValueInterface* AsParsedProtoStructValue( + ParsedStructValueView value) { + return NativeTypeId::Of(value) == + NativeTypeId::For() + ? cel::internal::down_cast( + value.operator->()) + : nullptr; +} + +class PooledParsedProtoStructValueInterface final + : public ParsedProtoStructValueInterface { + public: + explicit PooledParsedProtoStructValueInterface( + absl::Nonnull message) + : message_(message) {} + + const google::protobuf::Message& message() const override { return *message_; } + + private: + absl::Nonnull message_; +}; + +class AliasingParsedProtoStructValueInterface final + : public ParsedProtoStructValueInterface { + public: + explicit AliasingParsedProtoStructValueInterface( + absl::Nonnull message, Shared alias) + : message_(message), alias_(std::move(alias)) {} + + const google::protobuf::Message& message() const override { return *message_; } + + private: + absl::Nonnull message_; + Shared alias_; +}; + +// Reference counted `ParsedProtoStructValueInterface`. Used when we know the +// concrete message type. +class ReffedStaticParsedProtoStructValueInterface final + : public ParsedProtoStructValueInterface, + public common_internal::ReferenceCount { + public: + explicit ReffedStaticParsedProtoStructValueInterface(size_t size) + : size_(size) {} + + const google::protobuf::Message& message() const override { + return *static_cast( + reinterpret_cast( + reinterpret_cast(this) + MessageOffset())); + } + + static size_t MessageOffset() { + return internal::AlignUp( + sizeof(ReffedStaticParsedProtoStructValueInterface), + __STDCPP_DEFAULT_NEW_ALIGNMENT__); + } + + private: + void Finalize() noexcept override { + reinterpret_cast(reinterpret_cast(this) + + MessageOffset()) + ->~MessageLite(); + } + + void Delete() noexcept override { + void* address = this; + const auto size = MessageOffset() + size_; + this->~ReffedStaticParsedProtoStructValueInterface(); + internal::SizedDelete(address, size); + } + + const size_t size_; +}; + +// Reference counted `ParsedProtoStructValueInterface`. Used when we do not know +// the concrete message type. +class ReffedDynamicParsedProtoStructValueInterface final + : public ParsedProtoStructValueInterface, + public common_internal::ReferenceCount { + public: + explicit ReffedDynamicParsedProtoStructValueInterface( + absl::Nonnull message) + : message_(message) {} + + const google::protobuf::Message& message() const override { return *message_; } + + private: + void Finalize() noexcept override { delete message_; } + + void Delete() noexcept override { delete this; } + + absl::Nonnull message_; +}; + +void ProtoMessageDestruct(void* object) { + static_cast(object)->~Message(); +} + +absl::StatusOr ProtoMapFieldToValue( + SharedView aliased, + absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, + ValueManager& value_manager, Value& value) { + ABSL_DCHECK(field->is_map()); + CEL_ASSIGN_OR_RETURN(auto map_key_from_value, + GetProtoMapKeyFromValueConverter( + field->message_type()->map_key()->cpp_type())); + CEL_ASSIGN_OR_RETURN(auto map_key_to_value, + GetProtoMapKeyToValueConverter(field)); + CEL_ASSIGN_OR_RETURN(auto map_value_to_value, + GetProtoMapValueToValueConverter(field)); + if (!aliased) { + value = ParsedMapValue{value_manager.GetMemoryManager() + .MakeShared( + *message, field, map_key_from_value, + map_key_to_value, map_value_to_value)}; + } else { + value = ParsedMapValue{ + value_manager.GetMemoryManager() + .MakeShared>( + Shared(aliased), *message, field, + map_key_from_value, map_key_to_value, map_value_to_value)}; + } + return value; +} + +absl::StatusOr ProtoRepeatedFieldToValue( + SharedView aliased, + absl::Nonnull message, + absl::Nonnull reflection, + absl::Nonnull field, + ValueManager& value_manager, Value& value) { + ABSL_DCHECK(!field->is_map()); + ABSL_DCHECK(field->is_repeated()); + CEL_ASSIGN_OR_RETURN(auto repeated_field_to_value, + GetProtoRepeatedFieldToValueAccessor(field)); + if (!aliased) { + value = ParsedListValue{value_manager.GetMemoryManager() + .MakeShared( + *message, field, repeated_field_to_value)}; + } else { + value = ParsedListValue{ + value_manager.GetMemoryManager() + .MakeShared>( + Shared(aliased), *message, field, + repeated_field_to_value)}; + } + return value; +} + +absl::StatusOr> WellKnownProtoMessageToValue( + ValueFactory& value_factory, const TypeReflector& type_reflector, + absl::Nonnull message) { + const auto* desc = message->GetDescriptor(); + if (ABSL_PREDICT_FALSE(desc == nullptr)) { + return absl::nullopt; + } + switch (desc->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicFloatValueProto(*message)); + return DoubleValue{value}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicDoubleValueProto(*message)); + return DoubleValue{value}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicInt32ValueProto(*message)); + return IntValue{value}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicInt64ValueProto(*message)); + return IntValue{value}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicUInt32ValueProto(*message)); + return UintValue{value}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicUInt64ValueProto(*message)); + return UintValue{value}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicStringValueProto(*message)); + return StringValue{std::move(value)}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicBytesValueProto(*message)); + return BytesValue{std::move(value)}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicBoolValueProto(*message)); + return BoolValue{value}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: { + CEL_ASSIGN_OR_RETURN(auto any, UnwrapDynamicAnyProto(*message)); + CEL_ASSIGN_OR_RETURN(auto value, + type_reflector.DeserializeValue( + value_factory, any.type_url(), any.value())); + if (!value) { + return absl::NotFoundError( + absl::StrCat("unable to find deserializer for ", any.type_url())); + } + return std::move(value).value(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicDurationProto(*message)); + return DurationValue{value}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: { + CEL_ASSIGN_OR_RETURN(auto value, UnwrapDynamicTimestampProto(*message)); + return TimestampValue{value}; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: { + CEL_ASSIGN_OR_RETURN(auto value, DynamicValueProtoToJson(*message)); + return value_factory.CreateValueFromJson(std::move(value)); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: { + CEL_ASSIGN_OR_RETURN(auto value, DynamicListValueProtoToJson(*message)); + return value_factory.CreateValueFromJson(std::move(value)); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: { + CEL_ASSIGN_OR_RETURN(auto value, DynamicStructProtoToJson(*message)); + return value_factory.CreateValueFromJson(std::move(value)); + } + default: + return absl::nullopt; + } +} + +absl::StatusOr> WellKnownProtoMessageToValue( + ValueManager& value_manager, + absl::Nonnull message) { + return WellKnownProtoMessageToValue(value_manager, + value_manager.type_provider(), message); +} + +absl::Status ProtoMessageCopy( + absl::Nonnull to_message, + absl::Nonnull to_descriptor, + absl::Nonnull from_message) { + CEL_ASSIGN_OR_RETURN(const auto* from_descriptor, + GetDescriptor(*from_message)); + if (to_descriptor == from_descriptor) { + // Same. + to_message->CopyFrom(*from_message); + return absl::OkStatus(); + } + if (to_descriptor->full_name() == from_descriptor->full_name()) { + // Same type, different descriptors. + return ProtoMessageCopyUsingSerialization(to_message, from_message); + } + return TypeConversionError(from_descriptor->full_name(), + to_descriptor->full_name()) + .NativeValue(); +} + +absl::Status ProtoMessageCopy( + absl::Nonnull to_message, + absl::Nonnull to_descriptor, + absl::Nonnull from_message) { + const auto& from_type_name = from_message->GetTypeName(); + if (from_type_name == to_descriptor->full_name()) { + return ProtoMessageCopyUsingSerialization(to_message, from_message); + } + return TypeConversionError(from_type_name, to_descriptor->full_name()) + .NativeValue(); +} + +} // namespace + +absl::StatusOr> GetDescriptor( + const google::protobuf::Message& message) { + const auto* desc = message.GetDescriptor(); + if (ABSL_PREDICT_FALSE(desc == nullptr)) { + return absl::InvalidArgumentError( + absl::StrCat(message.GetTypeName(), " is missing descriptor")); + } + return desc; +} + +absl::StatusOr> GetReflection( + const google::protobuf::Message& message) { + const auto* reflect = message.GetReflection(); + if (ABSL_PREDICT_FALSE(reflect == nullptr)) { + return absl::InvalidArgumentError( + absl::StrCat(message.GetTypeName(), " is missing reflection")); + } + return reflect; +} + +absl::Nonnull GetReflectionOrDie( + const google::protobuf::Message& message) { + const auto* reflection = message.GetReflection(); + ABSL_CHECK(reflection != nullptr) // Crash OK + << message.GetTypeName() << " is missing reflection"; + return reflection; +} + +absl::StatusOr ProtoMessageToValueImpl( + ValueManager& value_manager, absl::Nonnull message, + size_t size, size_t align, + absl::Nonnull arena_copy_construct, + absl::Nonnull copy_construct) { + ABSL_DCHECK_GT(size, 0); + ABSL_DCHECK(absl::has_single_bit(align)); + { + CEL_ASSIGN_OR_RETURN(auto well_known, + WellKnownProtoMessageToValue(value_manager, message)); + if (well_known) { + return std::move(well_known).value(); + } + } + auto memory_manager = value_manager.GetMemoryManager(); + if (auto* arena = ProtoMemoryManagerArena(memory_manager); arena != nullptr) { + auto* copied_message = (*arena_copy_construct)(arena, message); + return ParsedStructValue{ + memory_manager.MakeShared( + copied_message)}; + } + switch (memory_manager.memory_management()) { + case MemoryManagement::kPooling: { + auto* copied_message = + (*copy_construct)(memory_manager.Allocate(size, align), message); + memory_manager.OwnCustomDestructor(copied_message, &ProtoMessageDestruct); + return ParsedStructValue{ + memory_manager.MakeShared( + copied_message)}; + } + case MemoryManagement::kReferenceCounting: { + auto* block = static_cast(memory_manager.Allocate( + ReffedStaticParsedProtoStructValueInterface::MessageOffset() + size, + __STDCPP_DEFAULT_NEW_ALIGNMENT__)); + auto* message_address = + block + ReffedStaticParsedProtoStructValueInterface::MessageOffset(); + auto* copied_message = (*copy_construct)(message_address, message); + ABSL_DCHECK_EQ(reinterpret_cast(message_address), + reinterpret_cast(copied_message)); + auto* message_value = ::new (static_cast(block)) + ReffedStaticParsedProtoStructValueInterface(size); + common_internal::SetReferenceCountForThat(*message_value, message_value); + return ParsedStructValue{common_internal::MakeShared( + common_internal::kAdoptRef, message_value, message_value)}; + } + } +} + +absl::StatusOr ProtoMessageToValueImpl( + ValueManager& value_manager, absl::Nonnull message, + size_t size, size_t align, + absl::Nonnull arena_move_construct, + absl::Nonnull move_construct) { + ABSL_DCHECK_GT(size, 0); + ABSL_DCHECK(absl::has_single_bit(align)); + { + CEL_ASSIGN_OR_RETURN(auto well_known, + WellKnownProtoMessageToValue(value_manager, message)); + if (well_known) { + return std::move(well_known).value(); + } + } + auto memory_manager = value_manager.GetMemoryManager(); + if (auto* arena = ProtoMemoryManagerArena(memory_manager); arena != nullptr) { + auto* moved_message = (*arena_move_construct)(arena, message); + return ParsedStructValue{ + memory_manager.MakeShared( + moved_message)}; + } + switch (memory_manager.memory_management()) { + case MemoryManagement::kPooling: { + auto* moved_message = + (*move_construct)(memory_manager.Allocate(size, align), message); + memory_manager.OwnCustomDestructor(moved_message, &ProtoMessageDestruct); + return ParsedStructValue{ + memory_manager.MakeShared( + moved_message)}; + } + case MemoryManagement::kReferenceCounting: { + auto* block = static_cast(memory_manager.Allocate( + ReffedStaticParsedProtoStructValueInterface::MessageOffset() + size, + __STDCPP_DEFAULT_NEW_ALIGNMENT__)); + auto* message_address = + block + ReffedStaticParsedProtoStructValueInterface::MessageOffset(); + auto* moved_message = (*move_construct)(message_address, message); + ABSL_DCHECK_EQ(reinterpret_cast(message_address), + reinterpret_cast(moved_message)); + auto* message_value = ::new (static_cast(block)) + ReffedStaticParsedProtoStructValueInterface(size); + common_internal::SetReferenceCountForThat(*message_value, message_value); + return ParsedStructValue{common_internal::MakeShared( + common_internal::kAdoptRef, message_value, message_value)}; + } + } +} + +absl::StatusOr ProtoMessageToValueImpl( + ValueManager& value_manager, Shared aliased, + absl::Nonnull message) { + { + CEL_ASSIGN_OR_RETURN(auto well_known, + WellKnownProtoMessageToValue(value_manager, message)); + if (well_known) { + return std::move(well_known).value(); + } + } + auto memory_manager = value_manager.GetMemoryManager(); + switch (memory_manager.memory_management()) { + case MemoryManagement::kPooling: { + if (!aliased) { + // `message` is indirectly owned by something on an arena. The user is + // responsible for ensuring they are the same arena or that `message` + // outlives the resulting value. + return ParsedStructValue{ + memory_manager.MakeShared( + message)}; + } + // `message` is indirectly owned by something reference counted. The + // destructor of the implementation will decrement the reference count. + return ParsedStructValue{ + memory_manager.MakeShared( + message, std::move(aliased))}; + } + case MemoryManagement::kReferenceCounting: { + if (!aliased) { + // `message` is indirectly owned by something on an arena, and we want + // to create a reference counted value. Unfortunately we have no way of + // ensuring the arena outlives the resulting reference counted value. So + // we need to perform a copy. + auto* copied_message = message->New(); + copied_message->CopyFrom(*message); + return ParsedStructValue{ + memory_manager + .MakeShared( + copied_message)}; + } + return ParsedStructValue{ + memory_manager.MakeShared( + message, std::move(aliased))}; + } + } +} + +absl::StatusOr> ProtoMessageFromValueImpl( + ValueView value, absl::Nonnull pool, + absl::Nonnull factory, google::protobuf::Arena* arena) { + switch (value.kind()) { + case ValueKind::kNull: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.Value", arena)); + CEL_RETURN_IF_ERROR(DynamicValueProtoFromJson(kJsonNull, *message)); + return message.release(); + } + case ValueKind::kBool: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.BoolValue", arena)); + CEL_RETURN_IF_ERROR(WrapDynamicBoolValueProto( + Cast(value).NativeValue(), *message)); + return message.release(); + } + case ValueKind::kInt: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.Int64Value", arena)); + CEL_RETURN_IF_ERROR(WrapDynamicInt64ValueProto( + Cast(value).NativeValue(), *message)); + return message.release(); + } + case ValueKind::kUint: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.UInt64Value", arena)); + CEL_RETURN_IF_ERROR(WrapDynamicUInt64ValueProto( + Cast(value).NativeValue(), *message)); + return message.release(); + } + case ValueKind::kDouble: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.DoubleValue", arena)); + CEL_RETURN_IF_ERROR(WrapDynamicDoubleValueProto( + Cast(value).NativeValue(), *message)); + return message.release(); + } + case ValueKind::kString: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.StringValue", arena)); + CEL_RETURN_IF_ERROR(WrapDynamicStringValueProto( + Cast(value).NativeCord(), *message)); + return message.release(); + } + case ValueKind::kBytes: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.BytesValue", arena)); + CEL_RETURN_IF_ERROR(WrapDynamicBytesValueProto( + Cast(value).NativeCord(), *message)); + return message.release(); + } + case ValueKind::kStruct: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, value.GetTypeName(), arena)); + ProtoAnyToJsonConverter converter(pool, factory); + CEL_ASSIGN_OR_RETURN(auto serialized, value.Serialize(converter)); + if (!message->ParsePartialFromCord(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parse `", message->GetTypeName(), "`")); + } + return message.release(); + } + case ValueKind::kDuration: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.Duration", arena)); + CEL_RETURN_IF_ERROR(WrapDynamicDurationProto( + Cast(value).NativeValue(), *message)); + return message.release(); + } + case ValueKind::kTimestamp: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.Timestamp", arena)); + CEL_RETURN_IF_ERROR(WrapDynamicTimestampProto( + Cast(value).NativeValue(), *message)); + return message.release(); + } + case ValueKind::kList: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.ListValue", arena)); + ProtoAnyToJsonConverter converter(pool, factory); + CEL_ASSIGN_OR_RETURN( + auto json, Cast(value).ConvertToJsonArray(converter)); + CEL_RETURN_IF_ERROR(DynamicListValueProtoFromJson(json, *message)); + return message.release(); + } + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN( + auto message, + NewProtoMessage(pool, factory, "google.protobuf.Struct", arena)); + ProtoAnyToJsonConverter converter(pool, factory); + CEL_ASSIGN_OR_RETURN( + auto json, Cast(value).ConvertToJsonObject(converter)); + CEL_RETURN_IF_ERROR(DynamicStructProtoFromJson(json, *message)); + return message.release(); + } + default: + break; + } + return TypeConversionError(value.GetTypeName(), "*message*").NativeValue(); +} + +absl::Status ProtoMessageFromValueImpl( + ValueView value, absl::Nonnull pool, + absl::Nonnull factory, + absl::Nonnull message) { + CEL_ASSIGN_OR_RETURN(const auto* to_desc, GetDescriptor(*message)); + switch (to_desc->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: { + if (auto double_value = As(value); double_value) { + return WrapDynamicFloatValueProto( + static_cast(double_value->NativeValue()), *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: { + if (auto double_value = As(value); double_value) { + return WrapDynamicDoubleValueProto( + static_cast(double_value->NativeValue()), *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: { + if (auto int_value = As(value); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return absl::OutOfRangeError("int64 to int32_t overflow"); + } + return WrapDynamicInt32ValueProto( + static_cast(int_value->NativeValue()), *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: { + if (auto int_value = As(value); int_value) { + return WrapDynamicInt64ValueProto(int_value->NativeValue(), *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: { + if (auto uint_value = As(value); uint_value) { + if (uint_value->NativeValue() > std::numeric_limits::max()) { + return absl::OutOfRangeError("uint64 to uint32_t overflow"); + } + return WrapDynamicUInt32ValueProto( + static_cast(uint_value->NativeValue()), *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: { + if (auto uint_value = As(value); uint_value) { + return WrapDynamicUInt64ValueProto(uint_value->NativeValue(), *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: { + if (auto string_value = As(value); string_value) { + return WrapDynamicStringValueProto(string_value->NativeCord(), + *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: { + if (auto bytes_value = As(value); bytes_value) { + return WrapDynamicBytesValueProto(bytes_value->NativeCord(), *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: { + if (auto bool_value = As(value); bool_value) { + return WrapDynamicBoolValueProto(bool_value->NativeValue(), *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: { + ProtoAnyToJsonConverter converter(pool, factory); + CEL_ASSIGN_OR_RETURN(auto any, value.ConvertToAny(converter)); + return WrapDynamicAnyProto(any, *message); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: { + if (auto duration_value = As(value); duration_value) { + return WrapDynamicDurationProto(duration_value->NativeValue(), + *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: { + if (auto timestamp_value = As(value); + timestamp_value) { + return WrapDynamicTimestampProto(timestamp_value->NativeValue(), + *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: { + ProtoAnyToJsonConverter converter(pool, factory); + CEL_ASSIGN_OR_RETURN(auto json, value.ConvertToJson(converter)); + return DynamicValueProtoFromJson(json, *message); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: { + ProtoAnyToJsonConverter converter(pool, factory); + CEL_ASSIGN_OR_RETURN(auto json, value.ConvertToJson(converter)); + if (absl::holds_alternative(json)) { + return DynamicListValueProtoFromJson(absl::get(json), + *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: { + ProtoAnyToJsonConverter converter(pool, factory); + CEL_ASSIGN_OR_RETURN(auto json, value.ConvertToJson(converter)); + if (absl::holds_alternative(json)) { + return DynamicStructProtoFromJson(absl::get(json), + *message); + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()) + .NativeValue(); + } + default: + break; + } + + // Not a well known type. + + // Deal with legacy values. + if (auto legacy_value = As(value); + legacy_value) { + if ((legacy_value->message_ptr() & base_internal::kMessageWrapperTagMask) == + base_internal::kMessageWrapperTagMessageValue) { + // Full. + const auto* from_message = reinterpret_cast( + legacy_value->message_ptr() & base_internal::kMessageWrapperPtrMask); + return ProtoMessageCopy(message, to_desc, from_message); + } else { + // Lite. + // Only thing we can do is check type names, which is gross because proto + // returns `std::string`. + const auto* from_message = reinterpret_cast( + legacy_value->message_ptr() & base_internal::kMessageWrapperPtrMask); + return ProtoMessageCopy(message, to_desc, from_message); + } + } + + // Deal with modern values. + if (const auto* parsed_proto_struct_value = AsParsedProtoStructValue(value); + parsed_proto_struct_value) { + return ProtoMessageCopy(message, to_desc, + &parsed_proto_struct_value->message()); + } + + return TypeConversionError(value.GetTypeName(), message->GetTypeName()) + .NativeValue(); +} + +absl::StatusOr ProtoMessageToValueImpl( + ValueFactory& value_factory, const TypeReflector& type_reflector, + absl::Nonnull prototype, + const absl::Cord& serialized) { + auto memory_manager = value_factory.GetMemoryManager(); + auto* arena = ProtoMemoryManagerArena(value_factory.GetMemoryManager()); + auto message = ArenaUniquePtr(prototype->New(arena), + DefaultArenaDeleter{arena}); + if (!message->ParsePartialFromCord(serialized)) { + return absl::InvalidArgumentError( + absl::StrCat("failed to parse `", prototype->GetTypeName(), "`")); + } + { + CEL_ASSIGN_OR_RETURN(auto well_known, + WellKnownProtoMessageToValue( + value_factory, type_reflector, message.get())); + if (well_known) { + return std::move(well_known).value(); + } + } + switch (memory_manager.memory_management()) { + case MemoryManagement::kPooling: + return ParsedStructValue{ + memory_manager.MakeShared( + message.release())}; + case MemoryManagement::kReferenceCounting: + return ParsedStructValue{ + memory_manager + .MakeShared( + message.release())}; + } +} + +} // namespace extensions::protobuf_internal + +} // namespace cel diff --git a/extensions/protobuf/internal/message.h b/extensions/protobuf/internal/message.h new file mode 100644 index 000000000..1de3ee0d9 --- /dev/null +++ b/extensions/protobuf/internal/message.h @@ -0,0 +1,164 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_MESSAGE_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_MESSAGE_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "common/memory.h" +#include "common/type_reflector.h" +#include "common/value.h" +#include "common/value_factory.h" +#include "common/value_manager.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions::protobuf_internal { + +template +inline constexpr bool IsProtoMessage = + std::conjunction_v>, + std::negation>>; + +template +struct ProtoMessageTraits { + static_assert(IsProtoMessage); + + static absl::Nonnull ArenaCopyConstruct( + absl::Nonnull arena, + absl::Nonnull from) { + if constexpr (google::protobuf::Arena::is_arena_constructable::value) { + return google::protobuf::Arena::Create(arena, + *google::protobuf::DynamicCastToGenerated(from)); + } else { + auto* to = google::protobuf::Arena::Create(arena); + *to = *google::protobuf::DynamicCastToGenerated(from); + return to; + } + } + + static absl::Nonnull CopyConstruct( + absl::Nonnull address, + absl::Nonnull from) { + return ::new (address) T(*google::protobuf::DynamicCastToGenerated(from)); + } + + static absl::Nonnull ArenaMoveConstruct( + absl::Nonnull arena, + absl::Nonnull from) { + if constexpr (google::protobuf::Arena::is_arena_constructable::value) { + return google::protobuf::Arena::Create( + arena, std::move(*google::protobuf::DynamicCastToGenerated(from))); + } else { + auto* to = google::protobuf::Arena::Create(arena); + *to = std::move(*google::protobuf::DynamicCastToGenerated(from)); + return to; + } + } + + static absl::Nonnull MoveConstruct( + absl::Nonnull address, absl::Nonnull from) { + return ::new (address) + T(std::move(*google::protobuf::DynamicCastToGenerated(from))); + } +}; + +// Get the `google::protobuf::Descriptor` from `google::protobuf::Message`, or return an error if it +// is `nullptr`. This should be extremely rare. +absl::StatusOr> GetDescriptor( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// Get the `google::protobuf::Reflection` from `google::protobuf::Message`, or return an error if it +// is `nullptr`. This should be extremely rare. +absl::StatusOr> GetReflection( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// Get the `google::protobuf::Reflection` from `google::protobuf::Message`, or abort. +// Should only be used when it is guaranteed `google::protobuf::Message` has reflection. +ABSL_ATTRIBUTE_PURE_FUNCTION absl::Nonnull +GetReflectionOrDie( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND); + +using ProtoMessageArenaCopyConstructor = absl::Nonnull (*)( + absl::Nonnull, absl::Nonnull); + +using ProtoMessageCopyConstructor = absl::Nonnull (*)( + absl::Nonnull, absl::Nonnull); + +using ProtoMessageArenaMoveConstructor = absl::Nonnull (*)( + absl::Nonnull, absl::Nonnull); + +using ProtoMessageMoveConstructor = absl::Nonnull (*)( + absl::Nonnull, absl::Nonnull); + +// Adapts a protocol buffer message to a value, copying it. +absl::StatusOr ProtoMessageToValueImpl( + ValueManager& value_manager, absl::Nonnull message, + size_t size, size_t align, + absl::Nonnull arena_copy_construct, + absl::Nonnull copy_construct); + +// Adapts a protocol buffer message to a value, moving it if possible. +absl::StatusOr ProtoMessageToValueImpl( + ValueManager& value_manager, absl::Nonnull message, + size_t size, size_t align, + absl::Nonnull arena_move_construct, + absl::Nonnull move_construct); + +// Aliasing conversion. Assumes `aliased` is the owner of `message`. +absl::StatusOr ProtoMessageToValueImpl( + ValueManager& value_manager, Shared aliased, + absl::Nonnull message); + +// Adapts a serialized protocol buffer message to a value. `prototype` should be +// the prototype message returned from the message factory. +absl::StatusOr ProtoMessageToValueImpl( + ValueFactory& value_factory, const TypeReflector& type_reflector, + absl::Nonnull prototype, + const absl::Cord& serialized); + +// Converts a value to a protocol buffer message. +absl::StatusOr> ProtoMessageFromValueImpl( + ValueView value, absl::Nonnull pool, + absl::Nonnull factory, google::protobuf::Arena* arena); +inline absl::StatusOr> +ProtoMessageFromValueImpl(ValueView value, google::protobuf::Arena* arena) { + return ProtoMessageFromValueImpl( + value, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), arena); +} + +// Converts a value to a specific protocol buffer message. +absl::Status ProtoMessageFromValueImpl( + ValueView value, absl::Nonnull pool, + absl::Nonnull factory, + absl::Nonnull message); +inline absl::Status ProtoMessageFromValueImpl( + ValueView value, absl::Nonnull message) { + return ProtoMessageFromValueImpl( + value, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), message); +} + +} // namespace cel::extensions::protobuf_internal + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_MESSAGE_H_ diff --git a/extensions/protobuf/internal/message_test.cc b/extensions/protobuf/internal/message_test.cc new file mode 100644 index 000000000..b217b4f0c --- /dev/null +++ b/extensions/protobuf/internal/message_test.cc @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/internal/message.h" + +#include "internal/testing.h" +#include "proto/test/v1/proto2/test_all_types.pb.h" + +namespace cel::extensions::protobuf_internal { +namespace { + +using ::google::api::expr::test::v1::proto2::TestAllTypes; +using testing::NotNull; +using cel::internal::IsOkAndHolds; + +TEST(GetDescriptor, NotNull) { + TestAllTypes message; + EXPECT_THAT(GetDescriptor(message), IsOkAndHolds(NotNull())); +} + +TEST(GetReflection, NotNull) { + TestAllTypes message; + EXPECT_THAT(GetReflection(message), IsOkAndHolds(NotNull())); +} + +TEST(GetReflectionOrDie, DoesNotDie) { + TestAllTypes message; + EXPECT_THAT(GetReflectionOrDie(message), NotNull()); +} + +} // namespace +} // namespace cel::extensions::protobuf_internal diff --git a/extensions/protobuf/value.h b/extensions/protobuf/value.h index 2321087c6..9672214ea 100644 --- a/extensions/protobuf/value.h +++ b/extensions/protobuf/value.h @@ -11,25 +11,47 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// Utilities for wrapping and unwrapping cel::Values representing protobuf +// message types. +// +// Handles adapting well-known types to their corresponding CEL representation +// (see https://protobuf.dev/reference/protobuf/google.protobuf/). #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ +#include #include +#include +#include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "common/value.h" #include "common/value_factory.h" +#include "common/value_manager.h" +#include "extensions/protobuf/internal/duration.h" #include "extensions/protobuf/internal/enum.h" +#include "extensions/protobuf/internal/message.h" +#include "extensions/protobuf/internal/struct.h" +#include "extensions/protobuf/internal/timestamp.h" +#include "extensions/protobuf/internal/wrappers.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/generated_enum_reflection.h" +#include "google/protobuf/message.h" namespace cel::extensions { +// Adapt a protobuf enum value to cel:Value. template std::enable_if_t, absl::StatusOr> ProtoEnumToValue(ValueFactory&, T value, @@ -41,6 +63,7 @@ ProtoEnumToValue(ValueFactory&, T value, return IntValueView{static_cast(value)}; } +// Adapt a protobuf enum value to cel:Value. template std::enable_if_t, absl::StatusOr> ProtoEnumToValue(ValueFactory&, T value) { @@ -50,9 +73,12 @@ ProtoEnumToValue(ValueFactory&, T value) { return IntValue{static_cast(value)}; } +// Adapt a cel::Value representing a protobuf enum to the normalized enum value, +// given the enum descriptor. absl::StatusOr ProtoEnumFromValue( ValueView value, absl::Nonnull desc); +// Adapt a cel::Value representing a protobuf enum to the normalized enum value. template std::enable_if_t, absl::StatusOr> ProtoEnumFromValue(ValueView value) { @@ -62,6 +88,166 @@ ProtoEnumFromValue(ValueView value) { return static_cast(enum_value); } +// Adapt a protobuf message to a cel::Value. +// +// Handles unwrapping message types with special meanings in CEL (WKTs). +// +// T value must be a protobuf message class. +template +std::enable_if_t, absl::StatusOr> +ProtoMessageToValue(ValueManager& value_manager, T&& value) { + using Tp = std::decay_t; + if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedBoolValueProto( + std::forward(value))); + return BoolValue{result}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedInt32ValueProto( + std::forward(value))); + return IntValue{result}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedInt64ValueProto( + std::forward(value))); + return IntValue{result}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedUInt32ValueProto( + std::forward(value))); + return UintValue{result}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedUInt64ValueProto( + std::forward(value))); + return UintValue{result}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedFloatValueProto( + std::forward(value))); + return DoubleValue{result}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedDoubleValueProto( + std::forward(value))); + return DoubleValue{result}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedBytesValueProto( + std::forward(value))); + return BytesValue{std::move(result)}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedStringValueProto( + std::forward(value))); + return StringValue{std::move(result)}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedDurationProto( + std::forward(value))); + return DurationValue{result}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::UnwrapGeneratedTimestampProto( + std::forward(value))); + return TimestampValue{result}; + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN( + auto result, + protobuf_internal::GeneratedValueProtoToJson(std::forward(value))); + return value_manager.CreateValueFromJson(std::move(result)); + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN(auto result, + protobuf_internal::GeneratedListValueProtoToJson( + std::forward(value))); + return value_manager.CreateListValueFromJsonArray(std::move(result)); + } else if constexpr (std::is_same_v) { + CEL_ASSIGN_OR_RETURN( + auto result, + protobuf_internal::GeneratedStructProtoToJson(std::forward(value))); + return value_manager.CreateMapValueFromJsonObject(std::move(result)); + } else { + auto dispatcher = absl::Overload( + [&](Tp&& m) { + return protobuf_internal::ProtoMessageToValueImpl( + value_manager, &m, sizeof(T), alignof(T), + &protobuf_internal::ProtoMessageTraits::ArenaMoveConstruct, + &protobuf_internal::ProtoMessageTraits::MoveConstruct); + }, + [&](const Tp& m) { + return protobuf_internal::ProtoMessageToValueImpl( + value_manager, &m, sizeof(T), alignof(T), + &protobuf_internal::ProtoMessageTraits::ArenaCopyConstruct, + &protobuf_internal::ProtoMessageTraits::CopyConstruct); + }); + return dispatcher(std::forward(value)); + } +} + +// Adapt a protobuf message to a cel::Value. +// +// Handles unwrapping message types with special meanings in CEL (WKTs). +// +// T value must be a protobuf message class. +template +std::enable_if_t, + absl::StatusOr> +ProtoMessageToValue(ValueManager& value_manager, T&& value, + Value& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN( + scratch, ProtoMessageToValue(value_manager, std::forward(value))); + return scratch; +} + +// Extract a protobuf message from a CEL value. +// +// Handles unwrapping message types with special meanings in CEL (WKTs). +// +// T value must be a protobuf message class. +template +std::enable_if_t, absl::Status> +ProtoMessageFromValue(ValueView value, T& message, + absl::Nonnull pool, + absl::Nonnull factory) { + return protobuf_internal::ProtoMessageFromValueImpl(value, pool, factory, + &message); +} + +// Extract a protobuf message from a CEL value. +// +// Handles unwrapping message types with special meanings in CEL (WKTs). +// +// T value must be a protobuf message class. +template +std::enable_if_t, absl::Status> +ProtoMessageFromValue(ValueView value, T& message) { + return protobuf_internal::ProtoMessageFromValueImpl(value, &message); +} + +// Extract a protobuf message from a CEL value. +// +// Handles unwrapping message types with special meanings in CEL (WKTs). +// +// T value must be a protobuf message class. +inline absl::StatusOr> ProtoMessageFromValue( + ValueView value, absl::Nullable arena) { + return protobuf_internal::ProtoMessageFromValueImpl(value, arena); +} + +// Extract a protobuf message from a CEL value. +// +// Handles unwrapping message types with special meanings in CEL (WKTs). +// +// T value must be a protobuf message class. +inline absl::StatusOr> ProtoMessageFromValue( + ValueView value, absl::Nullable arena, + absl::Nonnull pool, + absl::Nonnull factory) { + return protobuf_internal::ProtoMessageFromValueImpl(value, pool, factory, + arena); +} + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ diff --git a/extensions/protobuf/value_end_to_end_test.cc b/extensions/protobuf/value_end_to_end_test.cc new file mode 100644 index 000000000..4cc642295 --- /dev/null +++ b/extensions/protobuf/value_end_to_end_test.cc @@ -0,0 +1,943 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Functional tests for protobuf backed CEL structs in the default runtime. + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/memory.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "extensions/protobuf/value.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "google/protobuf/text_format.h" + +namespace cel::extensions { +namespace { + +using ::cel::test::BoolValueIs; +using ::cel::test::BytesValueIs; +using ::cel::test::DoubleValueIs; +using ::cel::test::DurationValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::IsNullValue; +using ::cel::test::ListValueIs; +using ::cel::test::MapValueIs; +using ::cel::test::StringValueIs; +using ::cel::test::StructValueIs; +using ::cel::test::TimestampValueIs; +using ::cel::test::UintValueIs; +using ::cel::test::ValueMatcher; +using ::google::api::expr::v1alpha1::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::test::v1::proto3::TestAllTypes; +using testing::_; +using testing::AnyOf; +using testing::HasSubstr; +using cel::internal::StatusIs; + +struct TestCase { + std::string name; + std::string expr; + std::string msg_textproto; + ValueMatcher matcher; +}; + +std::ostream& operator<<(std::ostream& out, const TestCase& tc) { + return out << tc.name; +} + +class ProtobufValueEndToEndTest + : public common_internal::ThreadCompatibleValueTest { + using Base = common_internal::ThreadCompatibleValueTest; + + public: + ProtobufValueEndToEndTest() = default; + + using Base::ToString; + + protected: + const TestCase& test_case() const { return std::get<1>(GetParam()); } +}; + +TEST_P(ProtobufValueEndToEndTest, Runner) { + auto tc_name = + ::testing::UnitTest::GetInstance()->current_test_info()->name(); + + if (memory_management() == MemoryManagement::kReferenceCounting && + absl::StrContains(tc_name, "map_") && + absl::StrContains(tc_name, "compre")) { + GTEST_SKIP() + << " TODO(uncreated-issue/66): key view use after free for comprehension"; + } + TestAllTypes message; + + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(test_case().msg_textproto, &message)); + + ASSERT_OK_AND_ASSIGN(Value value, + ProtoMessageToValue(value_manager(), message)); + + Activation activation; + activation.InsertOrAssignValue("msg", std::move(value)); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder(opts)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(test_case().expr)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + ASSERT_OK_AND_ASSIGN(Value result, + program->Evaluate(activation, value_manager())); + + EXPECT_THAT(result, test_case().matcher); +} + +INSTANTIATE_TEST_SUITE_P( + Scalars, ProtobufValueEndToEndTest, + testing::Combine(testing::Values(MemoryManagement::kPooling, + MemoryManagement::kReferenceCounting), + testing::ValuesIn(std::vector{ + {"single_int64", "msg.single_int64", + R"pb( + single_int64: 42 + )pb", + IntValueIs(42)}, + {"single_int32", "msg.single_int32", + R"pb( + single_int32: 42 + )pb", + IntValueIs(42)}, + {"single_uint64", "msg.single_uint64", + R"pb( + single_uint64: 42 + )pb", + UintValueIs(42)}, + {"single_uint32", "msg.single_uint32", + R"pb( + single_uint32: 42 + )pb", + UintValueIs(42)}, + {"single_sint64", "msg.single_sint64", + R"pb( + single_sint64: 42 + )pb", + IntValueIs(42)}, + {"single_sint32", "msg.single_sint32", + R"pb( + single_sint32: 42 + )pb", + IntValueIs(42)}, + {"single_fixed64", "msg.single_fixed64", + R"pb( + single_fixed64: 42 + )pb", + UintValueIs(42)}, + {"single_fixed32", "msg.single_fixed32", + R"pb( + single_fixed32: 42 + )pb", + UintValueIs(42)}, + {"single_sfixed64", "msg.single_sfixed64", + R"pb( + single_sfixed64: 42 + )pb", + IntValueIs(42)}, + {"single_sfixed32", "msg.single_sfixed32", + R"pb( + single_sfixed32: 42 + )pb", + IntValueIs(42)}, + {"single_float", "msg.single_float", + R"pb( + single_float: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"single_double", "msg.single_double", + R"pb( + single_double: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"single_bool", "msg.single_bool", + R"pb( + single_bool: true + )pb", + BoolValueIs(true)}, + {"single_string", "msg.single_string", + R"pb( + single_string: "Hello 😀" + )pb", + StringValueIs("Hello 😀")}, + {"single_bytes", "msg.single_bytes", + R"pb( + single_bytes: "Hello" + )pb", + BytesValueIs("Hello")}, + {"wkt_duration", "msg.single_duration", + R"pb( + single_duration { seconds: 10 } + )pb", + DurationValueIs(absl::Seconds(10))}, + {"wkt_timestamp", "msg.single_timestamp", + R"pb( + single_timestamp { seconds: 10 } + )pb", + TimestampValueIs(absl::FromUnixSeconds(10))}, + {"wkt_int64", "msg.single_int64_wrapper", + R"pb( + single_int64_wrapper { value: -20 } + )pb", + IntValueIs(-20)}, + {"wkt_int32", "msg.single_int32_wrapper", + R"pb( + single_int32_wrapper { value: -10 } + )pb", + IntValueIs(-10)}, + {"wkt_uint64", "msg.single_uint64_wrapper", + R"pb( + single_uint64_wrapper { value: 10 } + )pb", + UintValueIs(10)}, + {"wkt_uint32", "msg.single_uint32_wrapper", + R"pb( + single_uint32_wrapper { value: 11 } + )pb", + UintValueIs(11)}, + {"wkt_float", "msg.single_float_wrapper", + R"pb( + single_float_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_double", "msg.single_double_wrapper", + R"pb( + single_double_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_bool", "msg.single_bool_wrapper", + R"pb( + single_bool_wrapper { value: false } + )pb", + BoolValueIs(false)}, + {"wkt_string", "msg.single_string_wrapper", + R"pb( + single_string_wrapper { value: "abcd" } + )pb", + StringValueIs("abcd")}, + {"wkt_bytes", "msg.single_bytes_wrapper", + R"pb( + single_bytes_wrapper { value: "abcd" } + )pb", + BytesValueIs("abcd")}, + {"wkt_null", "msg.null_value", + R"pb( + null_value: NULL_VALUE + )pb", + IsNullValue()}, + {"message_field", "msg.standalone_message", + R"pb( + standalone_message { bb: 2 } + )pb", + StructValueIs(_)}, + {"single_enum", "msg.standalone_enum", + R"pb( + standalone_enum: BAR + )pb", + // BAR + IntValueIs(1)}})), + ProtobufValueEndToEndTest::ToString); + +INSTANTIATE_TEST_SUITE_P( + Repeated, ProtobufValueEndToEndTest, + testing::Combine( + testing::Values(MemoryManagement::kPooling, + MemoryManagement::kReferenceCounting), + testing::ValuesIn(std::vector{ + {"repeated_int64", "msg.repeated_int64[0]", + R"pb( + repeated_int64: 42 + )pb", + IntValueIs(42)}, + {"repeated_int32", "msg.repeated_int32[0]", + R"pb( + repeated_int32: 42 + )pb", + IntValueIs(42)}, + {"repeated_uint64", "msg.repeated_uint64[0]", + R"pb( + repeated_uint64: 42 + )pb", + UintValueIs(42)}, + {"repeated_uint32", "msg.repeated_uint32[0]", + R"pb( + repeated_uint32: 42 + )pb", + UintValueIs(42)}, + {"repeated_sint64", "msg.repeated_sint64[0]", + R"pb( + repeated_sint64: 42 + )pb", + IntValueIs(42)}, + {"repeated_sint32", "msg.repeated_sint32[0]", + R"pb( + repeated_sint32: 42 + )pb", + IntValueIs(42)}, + {"repeated_fixed64", "msg.repeated_fixed64[0]", + R"pb( + repeated_fixed64: 42 + )pb", + UintValueIs(42)}, + {"repeated_fixed32", "msg.repeated_fixed32[0]", + R"pb( + repeated_fixed32: 42 + )pb", + UintValueIs(42)}, + {"repeated_sfixed64", "msg.repeated_sfixed64[0]", + R"pb( + repeated_sfixed64: 42 + )pb", + IntValueIs(42)}, + {"repeated_sfixed32", "msg.repeated_sfixed32[0]", + R"pb( + repeated_sfixed32: 42 + )pb", + IntValueIs(42)}, + {"repeated_float", "msg.repeated_float[0]", + R"pb( + repeated_float: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"repeated_double", "msg.repeated_double[0]", + R"pb( + repeated_double: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"repeated_bool", "msg.repeated_bool[0]", + R"pb( + repeated_bool: true + )pb", + BoolValueIs(true)}, + {"repeated_string", "msg.repeated_string[0]", + R"pb( + repeated_string: "Hello 😀" + )pb", + StringValueIs("Hello 😀")}, + {"repeated_bytes", "msg.repeated_bytes[0]", + R"pb( + repeated_bytes: "Hello" + )pb", + BytesValueIs("Hello")}, + {"wkt_duration", "msg.repeated_duration[0]", + R"pb( + repeated_duration { seconds: 10 } + )pb", + DurationValueIs(absl::Seconds(10))}, + {"wkt_timestamp", "msg.repeated_timestamp[0]", + R"pb( + repeated_timestamp { seconds: 10 } + )pb", + TimestampValueIs(absl::FromUnixSeconds(10))}, + {"wkt_int64", "msg.repeated_int64_wrapper[0]", + R"pb( + repeated_int64_wrapper { value: -20 } + )pb", + IntValueIs(-20)}, + {"wkt_int32", "msg.repeated_int32_wrapper[0]", + R"pb( + repeated_int32_wrapper { value: -10 } + )pb", + IntValueIs(-10)}, + {"wkt_uint64", "msg.repeated_uint64_wrapper[0]", + R"pb( + repeated_uint64_wrapper { value: 10 } + )pb", + UintValueIs(10)}, + {"wkt_uint32", "msg.repeated_uint32_wrapper[0]", + R"pb( + repeated_uint32_wrapper { value: 11 } + )pb", + UintValueIs(11)}, + {"wkt_float", "msg.repeated_float_wrapper[0]", + R"pb( + repeated_float_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_double", "msg.repeated_double_wrapper[0]", + R"pb( + repeated_double_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_bool", "msg.repeated_bool_wrapper[0]", + R"pb( + + repeated_bool_wrapper { value: false } + )pb", + BoolValueIs(false)}, + {"wkt_string", "msg.repeated_string_wrapper[0]", + R"pb( + repeated_string_wrapper { value: "abcd" } + )pb", + StringValueIs("abcd")}, + {"wkt_bytes", "msg.repeated_bytes_wrapper[0]", + R"pb( + repeated_bytes_wrapper { value: "abcd" } + )pb", + BytesValueIs("abcd")}, + {"wkt_null", "msg.repeated_null_value[0]", + R"pb( + repeated_null_value: NULL_VALUE + )pb", + IsNullValue()}, + {"message_field", "msg.repeated_nested_message[0]", + R"pb( + repeated_nested_message { bb: 42 } + )pb", + StructValueIs(_)}, + {"repeated_enum", "msg.repeated_nested_enum[0]", + R"pb( + repeated_nested_enum: BAR + )pb", + // BAR + IntValueIs(1)}, + // Implements CEL list interface + {"repeated_size", "msg.repeated_int64.size()", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + IntValueIs(2)}, + {"in_repeated", "42 in msg.repeated_int64", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + BoolValueIs(true)}, + {"in_repeated_false", "44 in msg.repeated_int64", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + BoolValueIs(false)}, + {"repeated_compre_exists", "msg.repeated_int64.exists(x, x > 42)", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + BoolValueIs(true)}, + {"repeated_compre_map", "msg.repeated_int64.map(x, x * 2)[0]", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + IntValueIs(84)}, + })), + ProtobufValueEndToEndTest::ToString); + +INSTANTIATE_TEST_SUITE_P( + Maps, ProtobufValueEndToEndTest, + testing::Combine( + testing::Values(MemoryManagement::kPooling, + MemoryManagement::kReferenceCounting), + testing::ValuesIn(std::vector{ + {"map_bool_int64", "msg.map_bool_int64[false]", + R"pb( + map_bool_int64 { key: false value: 42 } + )pb", + IntValueIs(42)}, + {"map_bool_int32", "msg.map_bool_int32[false]", + R"pb( + map_bool_int32 { key: false value: 42 } + )pb", + IntValueIs(42)}, + {"map_bool_uint64", "msg.map_bool_uint64[false]", + R"pb( + map_bool_uint64 { key: false value: 42 } + )pb", + UintValueIs(42)}, + {"map_bool_uint32", "msg.map_bool_uint32[false]", + R"pb( + map_bool_uint32 { key: false, value: 42 } + )pb", + UintValueIs(42)}, + {"map_bool_float", "msg.map_bool_float[false]", + R"pb( + map_bool_float { key: false value: 4.25 } + )pb", + DoubleValueIs(4.25)}, + {"map_bool_double", "msg.map_bool_double[false]", + R"pb( + map_bool_double { key: false value: 4.25 } + )pb", + DoubleValueIs(4.25)}, + {"map_bool_bool", "msg.map_bool_bool[false]", + R"pb( + map_bool_bool { key: false value: true } + )pb", + BoolValueIs(true)}, + {"map_bool_string", "msg.map_bool_string[false]", + R"pb( + map_bool_string { key: false value: "Hello 😀" } + )pb", + StringValueIs("Hello 😀")}, + {"map_bool_bytes", "msg.map_bool_bytes[false]", + R"pb( + map_bool_bytes { key: false value: "Hello" } + )pb", + BytesValueIs("Hello")}, + {"wkt_duration", "msg.map_bool_duration[false]", + R"pb( + map_bool_duration { + key: false + value { seconds: 10 } + } + )pb", + DurationValueIs(absl::Seconds(10))}, + {"wkt_timestamp", "msg.map_bool_timestamp[false]", + R"pb( + map_bool_timestamp { + key: false + value { seconds: 10 } + } + )pb", + TimestampValueIs(absl::FromUnixSeconds(10))}, + {"wkt_int64", "msg.map_bool_int64_wrapper[false]", + R"pb( + map_bool_int64_wrapper { + key: false + value { value: -20 } + } + )pb", + IntValueIs(-20)}, + {"wkt_int32", "msg.map_bool_int32_wrapper[false]", + R"pb( + map_bool_int32_wrapper { + key: false + value { value: -10 } + } + )pb", + IntValueIs(-10)}, + {"wkt_uint64", "msg.map_bool_uint64_wrapper[false]", + R"pb( + map_bool_uint64_wrapper { + key: false + value { value: 10 } + } + )pb", + UintValueIs(10)}, + {"wkt_uint32", "msg.map_bool_uint32_wrapper[false]", + R"pb( + map_bool_uint32_wrapper { + key: false + value { value: 11 } + } + )pb", + UintValueIs(11)}, + {"wkt_float", "msg.map_bool_float_wrapper[false]", + R"pb( + map_bool_float_wrapper { + key: false + value { value: 10.25 } + } + )pb", + DoubleValueIs(10.25)}, + {"wkt_double", "msg.map_bool_double_wrapper[false]", + R"pb( + map_bool_double_wrapper { + key: false + value { value: 10.25 } + } + )pb", + DoubleValueIs(10.25)}, + {"wkt_bool", "msg.map_bool_bool_wrapper[false]", + R"pb( + map_bool_bool_wrapper { + key: false + value { value: false } + } + )pb", + BoolValueIs(false)}, + {"wkt_string", "msg.map_bool_string_wrapper[false]", + R"pb( + map_bool_string_wrapper { + key: false + value { value: "abcd" } + } + )pb", + StringValueIs("abcd")}, + {"wkt_bytes", "msg.map_bool_bytes_wrapper[false]", + R"pb( + map_bool_bytes_wrapper { + key: false + value { value: "abcd" } + } + )pb", + BytesValueIs("abcd")}, + {"wkt_null", "msg.map_bool_null_value[false]", + R"pb( + map_bool_null_value { key: false value: NULL_VALUE } + )pb", + IsNullValue()}, + {"message_field", "msg.map_bool_message[false]", + R"pb( + map_bool_message { + key: false + value { bb: 42 } + } + )pb", + StructValueIs(_)}, + {"map_bool_enum", "msg.map_bool_enum[false]", + R"pb( + map_bool_enum { key: false value: BAR } + )pb", + // BAR + IntValueIs(1)}, + // Simplified for remaining key types. + {"map_int32_int64", "msg.map_int32_int64[42]", + R"pb( + map_int32_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_int64_int64", "msg.map_int64_int64[42]", + R"pb( + map_int64_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_uint32_int64", "msg.map_uint32_int64[42u]", + R"pb( + map_uint32_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_uint64_int64", "msg.map_uint64_int64[42u]", + R"pb( + map_uint64_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_string_int64", "msg.map_string_int64['key1']", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + IntValueIs(-42)}, + // Implements CEL map + {"in_map_int64_true", "42 in msg.map_int64_int64", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + BoolValueIs(true)}, + {"in_map_int64_false", "44 in msg.map_int64_int64", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + BoolValueIs(false)}, + {"int_map_int64_compre_exists", + "msg.map_int64_int64.exists(key, key > 42)", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + BoolValueIs(true)}, + {"int_map_int64_compre_map", + "msg.map_int64_int64.map(key, key + 20)[0]", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + + IntValueIs(AnyOf(62, 63))}, + {"map_string_key_not_found", "msg.map_string_int64['key2']", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Key not found in map")))}, + // TODO(uncreated-issue/66): with heterogeneous lookups enabled, this + // should just be no such key. + // Add support for convertible double keys. + {"map_int32_out_of_range", "msg.map_int32_int64[0x1FFFFFFFF]", + R"pb( + map_int32_int64 { key: 10 value: -42 } + )pb", + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange, + HasSubstr("int64 to int32_t overflow")))}, + {"map_uint32_out_of_range", "msg.map_uint32_int64[0x1FFFFFFFFu]", + R"pb( + map_uint32_int64 { key: 10 value: -42 } + )pb", + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange, + HasSubstr("uint64 to uint32_t overflow")))}})), + ProtobufValueEndToEndTest::ToString); + +MATCHER_P(CelSizeIs, size, "") { return arg.Size() == size; } + +INSTANTIATE_TEST_SUITE_P( + JsonWrappers, ProtobufValueEndToEndTest, + testing::Combine( + testing::Values(MemoryManagement::kPooling, + MemoryManagement::kReferenceCounting), + testing::ValuesIn(std::vector{ + {"single_struct", "msg.single_struct", + R"pb( + single_struct { + fields { + key: "field1" + value { null_value: NULL_VALUE } + } + } + )pb", + MapValueIs(CelSizeIs(1))}, + {"single_struct_null_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { null_value: NULL_VALUE } + } + } + )pb", + IsNullValue()}, + {"single_struct_number_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { number_value: 10.25 } + } + } + )pb", + DoubleValueIs(10.25)}, + {"single_struct_bool_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_string_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { string_value: "abcd" } + } + } + )pb", + StringValueIs("abcd")}, + {"single_struct_struct_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { + struct_value { + fields { + key: "field2", + value: { null_value: NULL_VALUE } + } + } + } + } + } + )pb", + MapValueIs(CelSizeIs(1))}, + {"single_struct_list_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { list_value { values { null_value: NULL_VALUE } } } + } + } + )pb", + ListValueIs(CelSizeIs(1))}, + {"single_list_value", "msg.list_value", + R"pb( + list_value { values { null_value: NULL_VALUE } } + )pb", + ListValueIs(CelSizeIs(1))}, + {"single_list_value_index_null", "msg.list_value[0]", + R"pb( + list_value { values { null_value: NULL_VALUE } } + )pb", + IsNullValue()}, + {"single_list_value_index_number", "msg.list_value[0]", + R"pb( + list_value { values { number_value: 10.25 } } + )pb", + DoubleValueIs(10.25)}, + {"single_list_value_index_string", "msg.list_value[0]", + R"pb( + list_value { values { string_value: "abc" } } + )pb", + StringValueIs("abc")}, + {"single_list_value_index_bool", "msg.list_value[0]", + R"pb( + list_value { values { bool_value: false } } + )pb", + BoolValueIs(false)}, + {"single_list_value_index_struct", "msg.list_value[0]", + R"pb( + list_value { + values { + struct_value { + fields { + key: "field1" + value { null_value: NULL_VALUE } + } + } + } + } + )pb", + MapValueIs(CelSizeIs(1))}, + {"single_list_value_index_list", "msg.list_value[0]", + R"pb( + list_value { + values { list_value { values { null_value: NULL_VALUE } } } + } + )pb", + ListValueIs(CelSizeIs(1))}, + {"single_json_value_null", "msg.single_value", + R"pb( + single_value { null_value: NULL_VALUE } + )pb", + IsNullValue()}, + {"single_json_value_number", "msg.single_value", + R"pb( + single_value { number_value: 13.25 } + )pb", + DoubleValueIs(13.25)}, + {"single_json_value_string", "msg.single_value", + R"pb( + single_value { string_value: "abcd" } + )pb", + StringValueIs("abcd")}, + {"single_json_value_bool", "msg.single_value", + R"pb( + single_value { bool_value: false } + )pb", + BoolValueIs(false)}, + {"single_json_value_struct", "msg.single_value", + R"pb( + single_value { struct_value {} } + )pb", + MapValueIs(CelSizeIs(0))}, + {"single_json_value_list", "msg.single_value", + R"pb( + single_value { list_value {} } + )pb", + ListValueIs(CelSizeIs(0))}, + })), + ProtobufValueEndToEndTest::ToString); + +// TODO(uncreated-issue/66): any support needs the reflection impl for looking up the +// type name and corresponding deserializer (outside of the WKTs which are +// special cased). +INSTANTIATE_TEST_SUITE_P( + Any, ProtobufValueEndToEndTest, + testing::Combine( + testing::Values(MemoryManagement::kPooling, + MemoryManagement::kReferenceCounting), + testing::ValuesIn(std::vector{ + {"single_any_wkt_int64", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.Int64Value] { value: 42 } + } + )pb", + IntValueIs(42)}, + {"single_any_wkt_int32", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.Int32Value] { value: 42 } + } + )pb", + IntValueIs(42)}, + {"single_any_wkt_uint64", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.UInt64Value] { value: 42 } + } + )pb", + UintValueIs(42)}, + {"single_any_wkt_uint32", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.UInt32Value] { value: 42 } + } + )pb", + UintValueIs(42)}, + {"single_any_wkt_double", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.DoubleValue] { + value: 30.5 + } + } + )pb", + DoubleValueIs(30.5)}, + {"single_any_wkt_string", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.StringValue] { + value: "abcd" + } + } + )pb", + StringValueIs("abcd")}, + + {"repeated_any_wkt_string", "msg.repeated_any[0]", + R"pb( + repeated_any { + [type.googleapis.com/google.protobuf.StringValue] { + value: "abcd" + } + } + )pb", + StringValueIs("abcd")}, + {"map_int64_any_wkt_string", "msg.map_int64_any[0]", + R"pb( + map_int64_any { + key: 0 + value { + [type.googleapis.com/google.protobuf.StringValue] { + value: "abcd" + } + } + } + )pb", + StringValueIs("abcd")}, + })), + ProtobufValueEndToEndTest::ToString); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/protobuf/value_test.cc b/extensions/protobuf/value_test.cc index 50e7b1122..f73baf7a8 100644 --- a/extensions/protobuf/value_test.cc +++ b/extensions/protobuf/value_test.cc @@ -14,25 +14,72 @@ #include "extensions/protobuf/value.h" +#include +#include +#include +#include + +#include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" #include "common/casting.h" +#include "common/memory.h" #include "common/value.h" +#include "common/value_kind.h" #include "common/value_testing.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/proto_matchers.h" #include "internal/testing.h" #include "proto/test/v1/proto2/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" namespace cel::extensions { namespace { +using ::cel::internal::test::EqualsProto; +using ::cel::test::BoolValueIs; +using ::cel::test::BytesValueIs; +using ::cel::test::DoubleValueIs; +using ::cel::test::DurationValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::StringValueIs; +using ::cel::test::StructValueFieldHas; +using ::cel::test::StructValueFieldIs; +using ::cel::test::TimestampValueIs; +using ::cel::test::UintValueIs; +using ::cel::test::ValueKindIs; using ::google::api::expr::test::v1::proto2::TestAllTypes; using testing::Eq; +using testing::IsTrue; +using testing::Pointee; +using cel::internal::IsOk; using cel::internal::IsOkAndHolds; using cel::internal::StatusIs; -class ProtoValueTest : public common_internal::ThreadCompatibleValueTest<> {}; +template +T ParseTextOrDie(absl::string_view text) { + T proto; + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text, &proto)); + return proto; +} + +class ProtoValueTest : public common_internal::ThreadCompatibleValueTest<> { + protected: + MemoryManager NewThreadCompatiblePoolingMemoryManager() override { + return ProtoMemoryManager(); + } +}; + +class ProtoValueWrapTest : public ProtoValueTest {}; -TEST_P(ProtoValueTest, ProtoEnumToValue) { +TEST_P(ProtoValueWrapTest, ProtoEnumToValue) { ASSERT_OK_AND_ASSIGN( auto enum_value, ProtoEnumToValue(value_factory(), @@ -44,7 +91,367 @@ TEST_P(ProtoValueTest, ProtoEnumToValue) { ASSERT_THAT(Cast(enum_value).NativeValue(), Eq(1)); } -TEST_P(ProtoValueTest, ProtoEnumFromValue) { +TEST_P(ProtoValueWrapTest, ProtoBoolValueToValue) { + google::protobuf::BoolValue message; + message.set_value(true); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(BoolValueIs(Eq(true)))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(BoolValueIs(Eq(true)))); +} + +TEST_P(ProtoValueWrapTest, ProtoInt32ValueToValue) { + google::protobuf::Int32Value message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(IntValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(IntValueIs(Eq(1)))); +} + +TEST_P(ProtoValueWrapTest, ProtoInt64ValueToValue) { + google::protobuf::Int64Value message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(IntValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(IntValueIs(Eq(1)))); +} + +TEST_P(ProtoValueWrapTest, ProtoUInt32ValueToValue) { + google::protobuf::UInt32Value message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(UintValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(UintValueIs(Eq(1)))); +} + +TEST_P(ProtoValueWrapTest, ProtoUInt64ValueToValue) { + google::protobuf::UInt64Value message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(UintValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(UintValueIs(Eq(1)))); +} + +TEST_P(ProtoValueWrapTest, ProtoFloatValueToValue) { + google::protobuf::FloatValue message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(DoubleValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(DoubleValueIs(Eq(1)))); +} + +TEST_P(ProtoValueWrapTest, ProtoDoubleValueToValue) { + google::protobuf::DoubleValue message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(DoubleValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(DoubleValueIs(Eq(1)))); +} + +TEST_P(ProtoValueWrapTest, ProtoBytesValueToValue) { + google::protobuf::BytesValue message; + message.set_value("foo"); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(BytesValueIs(Eq("foo")))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(BytesValueIs(Eq("foo")))); +} + +TEST_P(ProtoValueWrapTest, ProtoStringValueToValue) { + google::protobuf::StringValue message; + message.set_value("foo"); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(StringValueIs(Eq("foo")))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(StringValueIs(Eq("foo")))); +} + +TEST_P(ProtoValueWrapTest, ProtoDurationToValue) { + google::protobuf::Duration message; + message.set_seconds(1); + message.set_nanos(1); + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(DurationValueIs( + Eq(absl::Seconds(1) + absl::Nanoseconds(1))))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(DurationValueIs( + Eq(absl::Seconds(1) + absl::Nanoseconds(1))))); +} + +TEST_P(ProtoValueWrapTest, ProtoTimestampToValue) { + google::protobuf::Timestamp message; + message.set_seconds(1); + message.set_nanos(1); + EXPECT_THAT( + ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(TimestampValueIs( + Eq(absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1))))); + EXPECT_THAT( + ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(TimestampValueIs( + Eq(absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1))))); +} + +TEST_P(ProtoValueWrapTest, ProtoMessageToValue) { + TestAllTypes message; + EXPECT_THAT(ProtoMessageToValue(value_manager(), message), + IsOkAndHolds(ValueKindIs(Eq(ValueKind::kStruct)))); + EXPECT_THAT(ProtoMessageToValue(value_manager(), std::move(message)), + IsOkAndHolds(ValueKindIs(Eq(ValueKind::kStruct)))); +} + +TEST_P(ProtoValueWrapTest, GetFieldByName) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(value_manager(), ParseTextOrDie( + R"pb(single_int32: 1, + single_int64: 1 + single_uint32: 1 + single_uint64: 1 + single_float: 1 + single_double: 1 + single_bool: true + single_string: "foo" + single_bytes: "foo")pb"))); + EXPECT_THAT(value, StructValueIs(StructValueFieldIs( + &value_manager(), "single_int32", IntValueIs(Eq(1))))); + EXPECT_THAT(value, + StructValueIs(StructValueFieldHas("single_int32", IsTrue()))); + EXPECT_THAT(value, StructValueIs(StructValueFieldIs( + &value_manager(), "single_int64", IntValueIs(Eq(1))))); + EXPECT_THAT(value, + StructValueIs(StructValueFieldHas("single_int64", IsTrue()))); + EXPECT_THAT( + value, StructValueIs(StructValueFieldIs(&value_manager(), "single_uint32", + UintValueIs(Eq(1))))); + EXPECT_THAT(value, + StructValueIs(StructValueFieldHas("single_uint32", IsTrue()))); + EXPECT_THAT( + value, StructValueIs(StructValueFieldIs(&value_manager(), "single_uint64", + UintValueIs(Eq(1))))); + EXPECT_THAT(value, + StructValueIs(StructValueFieldHas("single_uint64", IsTrue()))); +} + +INSTANTIATE_TEST_SUITE_P(ProtoValueTest, ProtoValueWrapTest, + testing::Values(MemoryManagement::kPooling, + MemoryManagement::kReferenceCounting), + ProtoValueTest::ToString); + +struct DefaultArenaDeleter { + template + void operator()(T* message) const { + if (arena == nullptr) { + delete message; + } + } + + google::protobuf::Arena* arena = nullptr; +}; + +template +using ArenaUniquePtr = std::unique_ptr; + +template +ArenaUniquePtr WrapArenaUnique(T* message) { + return ArenaUniquePtr(message, DefaultArenaDeleter{message->GetArena()}); +} + +template +absl::StatusOr> WrapArenaUnique(absl::StatusOr message) { + if (!message.ok()) { + return message.status(); + } + return WrapArenaUnique(*message); +} + +class ProtoValueUnwrapTest : public ProtoValueTest {}; + +TEST_P(ProtoValueUnwrapTest, ProtoBoolValueFromValue) { + google::protobuf::BoolValue message; + EXPECT_THAT(ProtoMessageFromValue(BoolValueView{true}, message), IsOk()); + EXPECT_EQ(message.value(), true); + + EXPECT_THAT( + WrapArenaUnique(ProtoMessageFromValue( + BoolValueView{true}, ProtoMemoryManagerArena(memory_manager()))), + IsOkAndHolds(Pointee(EqualsProto(message)))); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoInt32ValueFromValue) { + google::protobuf::Int32Value message; + EXPECT_THAT(ProtoMessageFromValue(IntValueView{1}, message), IsOk()); + EXPECT_EQ(message.value(), 1); + EXPECT_THAT( + ProtoMessageFromValue( + IntValueView{ + static_cast(std::numeric_limits::max()) + 1}, + message), + StatusIs(absl::StatusCode::kOutOfRange)); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoInt64ValueFromValue) { + google::protobuf::Int64Value message; + EXPECT_THAT(ProtoMessageFromValue(IntValueView{1}, message), IsOk()); + EXPECT_EQ(message.value(), true); + + EXPECT_THAT(WrapArenaUnique(ProtoMessageFromValue( + IntValueView{1}, ProtoMemoryManagerArena(memory_manager()))), + IsOkAndHolds(Pointee(EqualsProto(message)))); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoUInt32ValueFromValue) { + google::protobuf::UInt32Value message; + EXPECT_THAT(ProtoMessageFromValue(UintValueView{1}, message), IsOk()); + EXPECT_EQ(message.value(), 1); + EXPECT_THAT( + ProtoMessageFromValue( + UintValueView{ + static_cast(std::numeric_limits::max()) + 1}, + message), + StatusIs(absl::StatusCode::kOutOfRange)); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoUInt64ValueFromValue) { + google::protobuf::UInt64Value message; + EXPECT_THAT(ProtoMessageFromValue(UintValueView{1}, message), IsOk()); + EXPECT_EQ(message.value(), 1); + + EXPECT_THAT(WrapArenaUnique(ProtoMessageFromValue( + UintValueView{1}, ProtoMemoryManagerArena(memory_manager()))), + IsOkAndHolds(Pointee(EqualsProto(message)))); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoFloatValueFromValue) { + google::protobuf::FloatValue message; + EXPECT_THAT(ProtoMessageFromValue(DoubleValueView{1}, message), IsOk()); + EXPECT_EQ(message.value(), 1); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoDoubleValueFromValue) { + google::protobuf::DoubleValue message; + EXPECT_THAT(ProtoMessageFromValue(DoubleValueView{1}, message), IsOk()); + EXPECT_EQ(message.value(), 1); + + EXPECT_THAT( + WrapArenaUnique(ProtoMessageFromValue( + DoubleValueView{1}, ProtoMemoryManagerArena(memory_manager()))), + IsOkAndHolds(Pointee(EqualsProto(message)))); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoBytesValueFromValue) { + google::protobuf::BytesValue message; + EXPECT_THAT(ProtoMessageFromValue(BytesValueView{"foo"}, message), IsOk()); + EXPECT_EQ(message.value(), "foo"); + + EXPECT_THAT( + WrapArenaUnique(ProtoMessageFromValue( + BytesValueView{"foo"}, ProtoMemoryManagerArena(memory_manager()))), + IsOkAndHolds(Pointee(EqualsProto(message)))); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoStringValueFromValue) { + google::protobuf::StringValue message; + EXPECT_THAT(ProtoMessageFromValue(StringValueView{"foo"}, message), IsOk()); + EXPECT_EQ(message.value(), "foo"); + + EXPECT_THAT( + WrapArenaUnique(ProtoMessageFromValue( + StringValueView{"foo"}, ProtoMemoryManagerArena(memory_manager()))), + IsOkAndHolds(Pointee(EqualsProto(message)))); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoDurationFromValue) { + google::protobuf::Duration message; + EXPECT_THAT( + ProtoMessageFromValue( + DurationValueView{absl::Seconds(1) + absl::Nanoseconds(1)}, message), + IsOk()); + EXPECT_EQ(message.seconds(), 1); + EXPECT_EQ(message.nanos(), 1); + + EXPECT_THAT(WrapArenaUnique(ProtoMessageFromValue( + DurationValueView{absl::Seconds(1) + absl::Nanoseconds(1)}, + ProtoMemoryManagerArena(memory_manager()))), + IsOkAndHolds(Pointee(EqualsProto(message)))); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoTimestampFromValue) { + google::protobuf::Timestamp message; + EXPECT_THAT(ProtoMessageFromValue( + TimestampValueView{absl::UnixEpoch() + absl::Seconds(1) + + absl::Nanoseconds(1)}, + message), + IsOk()); + EXPECT_EQ(message.seconds(), 1); + EXPECT_EQ(message.nanos(), 1); + + EXPECT_THAT(WrapArenaUnique(ProtoMessageFromValue( + TimestampValueView{absl::UnixEpoch() + absl::Seconds(1) + + absl::Nanoseconds(1)}, + ProtoMemoryManagerArena(memory_manager()))), + IsOkAndHolds(Pointee(EqualsProto(message)))); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoValueFromValue) { + google::protobuf::Value message; + EXPECT_THAT(ProtoMessageFromValue(NullValueView{}, message), IsOk()); + EXPECT_TRUE(message.has_null_value()); + EXPECT_THAT(ProtoMessageFromValue(BoolValueView{true}, message), IsOk()); + EXPECT_EQ(message.bool_value(), true); + EXPECT_THAT(ProtoMessageFromValue(DoubleValueView{1}, message), IsOk()); + EXPECT_EQ(message.number_value(), 1); + EXPECT_THAT(ProtoMessageFromValue(ListValueView{}, message), IsOk()); + EXPECT_TRUE(message.has_list_value()); + EXPECT_TRUE(message.list_value().values().empty()); + EXPECT_THAT(ProtoMessageFromValue(MapValueView{}, message), IsOk()); + EXPECT_TRUE(message.has_struct_value()); + EXPECT_TRUE(message.struct_value().fields().empty()); + + EXPECT_THAT(ProtoMessageFromValue(UnknownValueView{}, message), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_P(ProtoValueUnwrapTest, ProtoEnumFromValue) { EXPECT_THAT(ProtoEnumFromValue(NullValueView{}), IsOkAndHolds(Eq(google::protobuf::NULL_VALUE))); EXPECT_THAT( @@ -61,11 +468,10 @@ TEST_P(ProtoValueTest, ProtoEnumFromValue) { StatusIs(absl::StatusCode::kInvalidArgument)); } -INSTANTIATE_TEST_SUITE_P( - ProtoValueTest, ProtoValueTest, - ::testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - ProtoValueTest::ToString); +INSTANTIATE_TEST_SUITE_P(ProtoValueTest, ProtoValueUnwrapTest, + testing::Values(MemoryManagement::kPooling, + MemoryManagement::kReferenceCounting), + ProtoValueTest::ToString); } // namespace } // namespace cel::extensions