diff --git a/base/BUILD b/base/BUILD index 0c23b859c..a4e635741 100644 --- a/base/BUILD +++ b/base/BUILD @@ -332,8 +332,10 @@ cc_library( ":function", ":function_descriptor", ":handle", + ":kind", "//base/internal:function_adapter", "//internal:status_macros", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -356,6 +358,7 @@ cc_test( "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/time", ], ) diff --git a/base/function_adapter.h b/base/function_adapter.h index 95ca93d84..2a13652a5 100644 --- a/base/function_adapter.h +++ b/base/function_adapter.h @@ -22,16 +22,20 @@ #include #include +#include +#include "absl/functional/bind_front.h" #include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/function.h" #include "base/function_descriptor.h" #include "base/handle.h" #include "base/internal/function_adapter.h" +#include "base/kind.h" #include "base/value.h" #include "internal/status_macros.h" @@ -51,7 +55,85 @@ template struct AdaptedTypeTraits { using AssignableType = const T*; - static const T& ToArg(AssignableType v) { return *ABSL_DIE_IF_NULL(v); } + static std::reference_wrapper ToArg(AssignableType v) { + return *ABSL_DIE_IF_NULL(v); // Crash OK + } +}; + +template +struct KindAdderImpl; + +template +struct KindAdderImpl { + static void AddTo(std::vector& args) { + args.push_back(AdaptedKind()); + KindAdderImpl::AddTo(args); + } +}; + +template <> +struct KindAdderImpl<> { + static void AddTo(std::vector& args) {} +}; + +template +struct KindAdder { + static std::vector Kinds() { + std::vector args; + KindAdderImpl::AddTo(args); + return args; + } +}; + +template +struct ApplyReturnType { + using type = absl::StatusOr; +}; + +template +struct ApplyReturnType> { + using type = absl::StatusOr; +}; + +template +struct IndexerImpl { + using type = typename IndexerImpl::type; +}; + +template +struct IndexerImpl<0, Arg, Args...> { + using type = Arg; +}; + +template +struct Indexer { + static_assert(N < sizeof...(Args) && N >= 0); + using type = typename IndexerImpl::type; +}; + +template +struct ApplyHelper { + template + static typename ApplyReturnType::type Apply( + Op&& op, absl::Span> input) { + constexpr int idx = sizeof...(Args) - N; + using Arg = typename Indexer::type; + using ArgTraits = internal::AdaptedTypeTraits; + typename ArgTraits::AssignableType arg_i; + CEL_RETURN_IF_ERROR(internal::HandleToAdaptedVisitor{input[idx]}(&arg_i)); + + return ApplyHelper::template Apply( + absl::bind_front(std::forward(op), ArgTraits::ToArg(arg_i)), input); + } +}; + +template +struct ApplyHelper<0, Args...> { + template + static typename ApplyReturnType::type Apply( + Op&& op, absl::Span> input) { + return op(); + } }; } // namespace internal @@ -223,6 +305,54 @@ class UnaryFunctionAdapter { }; }; +// Generic adapter class for generating CEL extension functions from an +// n-argument function. Prefer using the Binary and Unary versions. They are +// simpler and cover most use cases. +// +// See documentation for Binary Function adapter for general recommendations. +template +class VariadicFunctionAdapter { + public: + using FunctionType = std::function; + + static std::unique_ptr WrapFunction(FunctionType fn) { + return std::make_unique(std::move(fn)); + } + + static FunctionDescriptor CreateDescriptor(absl::string_view name, + bool receiver_style, + bool is_strict = true) { + return FunctionDescriptor(name, receiver_style, + internal::KindAdder::Kinds(), is_strict); + } + + private: + class VariadicFunctionImpl : public cel::Function { + public: + explicit VariadicFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} + + absl::StatusOr> Invoke( + const FunctionEvaluationContext& context, + absl::Span> args) const override { + if (args.size() != sizeof...(Args)) { + return absl::InvalidArgumentError( + absl::StrCat("unexpected number of arguments for variadic(", + sizeof...(Args), ") function")); + } + + CEL_ASSIGN_OR_RETURN( + T result, + (internal::ApplyHelper::template Apply( + absl::bind_front(fn_, std::ref(context.value_factory())), args))); + return internal::AdaptedToHandleVisitor{context.value_factory()}( + std::move(result)); + } + + private: + FunctionType fn_; + }; +}; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_ADAPTER_H_ diff --git a/base/function_adapter_test.cc b/base/function_adapter_test.cc index 124e18999..3929533e1 100644 --- a/base/function_adapter_test.cc +++ b/base/function_adapter_test.cc @@ -20,6 +20,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "base/function.h" #include "base/function_descriptor.h" @@ -27,11 +28,14 @@ #include "base/kind.h" #include "base/memory.h" #include "base/type_factory.h" +#include "base/type_manager.h" #include "base/type_provider.h" +#include "base/value.h" #include "base/value_factory.h" #include "base/values/bool_value.h" #include "base/values/bytes_value.h" #include "base/values/double_value.h" +#include "base/values/duration_value.h" #include "base/values/int_value.h" #include "base/values/timestamp_value.h" #include "base/values/uint_value.h" @@ -42,6 +46,7 @@ namespace { using testing::ElementsAre; using testing::HasSubstr; +using testing::IsEmpty; using cel::internal::StatusIs; class FunctionAdapterTest : public ::testing::Test { @@ -719,5 +724,103 @@ TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorNonStrict) { EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny, Kind::kAny)); } +TEST_F(FunctionAdapterTest, VariadicFunctionAdapterCreateDescriptor0Args) { + FunctionDescriptor desc = + VariadicFunctionAdapter>>::CreateDescriptor( + "ZeroArgs", false); + + EXPECT_EQ(desc.name(), "ZeroArgs"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), IsEmpty()); +} + +TEST_F(FunctionAdapterTest, VariadicFunctionAdapterWrapFunction0Args) { + std::unique_ptr fn = + VariadicFunctionAdapter>>::WrapFunction( + [](ValueFactory& value_factory) { + return value_factory.CreateStringValue("abc"); + }); + + ASSERT_OK_AND_ASSIGN(auto result, fn->Invoke(test_context(), {})); + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.As()->ToString(), "abc"); +} + +TEST_F(FunctionAdapterTest, VariadicFunctionAdapterCreateDescriptor3Args) { + FunctionDescriptor desc = VariadicFunctionAdapter< + absl::StatusOr>, int64_t, bool, + const StringValue&>::CreateDescriptor("MyFormatter", false); + + EXPECT_EQ(desc.name(), "MyFormatter"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), + ElementsAre(Kind::kInt64, Kind::kBool, Kind::kString)); +} + +TEST_F(FunctionAdapterTest, VariadicFunctionAdapterWrapFunction3Args) { + std::unique_ptr fn = VariadicFunctionAdapter< + absl::StatusOr>, int64_t, bool, + const StringValue&>::WrapFunction([](ValueFactory& value_factory, + int64_t int_val, bool bool_val, + const StringValue& string_val) + -> absl::StatusOr> { + return value_factory.CreateStringValue( + absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), "_", + string_val.ToString())); + }); + + std::vector> args{value_factory().CreateIntValue(42), + value_factory().CreateBoolValue(false)}; + ASSERT_OK_AND_ASSIGN(args.emplace_back(), + value_factory().CreateStringValue("abcd")); + ASSERT_OK_AND_ASSIGN(auto result, fn->Invoke(test_context(), args)); + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.As()->ToString(), "42_false_abcd"); +} + +TEST_F(FunctionAdapterTest, + VariadicFunctionAdapterWrapFunction3ArgsBadArgType) { + std::unique_ptr fn = VariadicFunctionAdapter< + absl::StatusOr>, int64_t, bool, + const StringValue&>::WrapFunction([](ValueFactory& value_factory, + int64_t int_val, bool bool_val, + const StringValue& string_val) + -> absl::StatusOr> { + return value_factory.CreateStringValue( + absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), "_", + string_val.ToString())); + }); + + std::vector> args{value_factory().CreateIntValue(42), + value_factory().CreateBoolValue(false)}; + ASSERT_OK_AND_ASSIGN(args.emplace_back(), + value_factory().CreateTimestampValue(absl::UnixEpoch())); + EXPECT_THAT(fn->Invoke(test_context(), args), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expected string value"))); +} + +TEST_F(FunctionAdapterTest, + VariadicFunctionAdapterWrapFunction3ArgsBadArgCount) { + std::unique_ptr fn = VariadicFunctionAdapter< + absl::StatusOr>, int64_t, bool, + const StringValue&>::WrapFunction([](ValueFactory& value_factory, + int64_t int_val, bool bool_val, + const StringValue& string_val) + -> absl::StatusOr> { + return value_factory.CreateStringValue( + absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), "_", + string_val.ToString())); + }); + + std::vector> args{value_factory().CreateIntValue(42), + value_factory().CreateBoolValue(false)}; + EXPECT_THAT(fn->Invoke(test_context(), args), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unexpected number of arguments"))); +} + } // namespace } // namespace cel diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 212a52fc6..97a3b3626 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -268,3 +268,40 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "bind_proto_to_activation", + srcs = ["bind_proto_to_activation.cc"], + hdrs = ["bind_proto_to_activation.h"], + deps = [ + ":data", + "//base:data", + "//base:handle", + "//internal:status_macros", + "//runtime:activation", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "bind_proto_to_activation_test", + srcs = ["bind_proto_to_activation_test.cc"], + deps = [ + ":bind_proto_to_activation", + ":data", + ":memory_manager", + "//base:data", + "//base:handle", + "//base:memory", + "//internal:testing", + "//runtime:activation", + "//runtime:managed_value_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/extensions/protobuf/bind_proto_to_activation.cc b/extensions/protobuf/bind_proto_to_activation.cc new file mode 100644 index 000000000..a33d070b1 --- /dev/null +++ b/extensions/protobuf/bind_proto_to_activation.cc @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/bind_proto_to_activation.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "base/handle.h" +#include "base/types/struct_type.h" +#include "base/value.h" +#include "base/value_factory.h" +#include "extensions/protobuf/value.h" +#include "internal/status_macros.h" +#include "runtime/activation.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { +namespace { + +absl::StatusOr ShouldBindField( + const google::protobuf::FieldDescriptor* field_desc, const StructValue& struct_value, + BindProtoUnsetFieldBehavior unset_field_behavior, + ValueFactory& value_factory) { + if (unset_field_behavior == BindProtoUnsetFieldBehavior::kBindDefaultValue || + field_desc->is_repeated()) { + return true; + } + return struct_value.HasFieldByNumber(value_factory.type_manager(), + field_desc->number()); +} + +} // namespace + +absl::Status BindProtoToActivation( + const google::protobuf::Message& context, ValueFactory& value_factory, + Activation& activation, BindProtoUnsetFieldBehavior unset_field_behavior) { + CEL_ASSIGN_OR_RETURN(Handle parent, + ProtoValue::Create(value_factory, context)); + + if (!parent->Is()) { + return absl::InvalidArgumentError( + absl::StrCat("context is a well-known type: ", context.GetTypeName())); + } + + const google::protobuf::Descriptor* desc = context.GetDescriptor(); + + if (desc == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("context missing descriptor: ", context.GetTypeName())); + } + const StructValue& struct_value = parent->As(); + for (int i = 0; i < desc->field_count(); i++) { + const google::protobuf::FieldDescriptor* field_desc = desc->field(i); + CEL_ASSIGN_OR_RETURN(bool should_bind, + ShouldBindField(field_desc, struct_value, + unset_field_behavior, value_factory)); + if (!should_bind) { + continue; + } + + CEL_ASSIGN_OR_RETURN( + Handle field, + struct_value.GetFieldByNumber(value_factory, field_desc->number())); + + activation.InsertOrAssignValue(field_desc->name(), std::move(field)); + } + + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/protobuf/bind_proto_to_activation.h b/extensions/protobuf/bind_proto_to_activation.h new file mode 100644 index 000000000..a73d38899 --- /dev/null +++ b/extensions/protobuf/bind_proto_to_activation.h @@ -0,0 +1,77 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_BIND_PROTO_TO_ACTIVATION_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_BIND_PROTO_TO_ACTIVATION_H_ + +#include "absl/status/status.h" +#include "base/value_factory.h" +#include "runtime/activation.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +// Option for handling unset fields on the context proto. +enum class BindProtoUnsetFieldBehavior { + // Bind the message defined default or zero value. + kBindDefaultValue, + // Skip binding unset fields, no value is bound for the corresponding + // variable. + kSkip +}; + +// Utility method, that takes a protobuf Message and interprets it as a +// namespace, binding its fields to Activation. This is often referred to as a +// context message. +// +// Field names and values become respective names and values of parameters +// bound to the Activation object. +// Example: +// Assume we have a protobuf message of type: +// message Person { +// int age = 1; +// string name = 2; +// } +// +// The sample code snippet will look as follows: +// +// Person person; +// person.set_name("John Doe"); +// person.age(42); +// +// CEL_RETURN_IF_ERROR(BindProtoToActivation(person, value_factory, +// activation)); +// +// After this snippet, activation will have two parameters bound: +// "name", with string value of "John Doe" +// "age", with int value of 42. +// +// The default behavior for unset fields is to skip them. E.g. if the name field +// is not set on the Person message, it will not be bound in to the activation. +// BindProtoUnsetFieldBehavior::kBindDefault, will bind the cc proto api default +// for the field (either an explicit default value or a type specific default). +// +// For repeated fields, an unset field is bound as an empty list. +// +// The input message is not copied, it must remain valid as long as the +// activation. +absl::Status BindProtoToActivation( + const google::protobuf::Message& context, ValueFactory& value_factory, + Activation& activation, + BindProtoUnsetFieldBehavior unset_field_behavior = + BindProtoUnsetFieldBehavior::kSkip); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_BIND_PROTO_TO_ACTIVATION_H_ diff --git a/extensions/protobuf/bind_proto_to_activation_test.cc b/extensions/protobuf/bind_proto_to_activation_test.cc new file mode 100644 index 000000000..68b59db09 --- /dev/null +++ b/extensions/protobuf/bind_proto_to_activation_test.cc @@ -0,0 +1,248 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/bind_proto_to_activation.h" + +#include "google/protobuf/wrappers.pb.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "base/handle.h" +#include "base/memory.h" +#include "base/value.h" +#include "base/values/int_value.h" +#include "base/values/list_value.h" +#include "base/values/map_value.h" +#include "extensions/protobuf/memory_manager.h" +#include "extensions/protobuf/type_provider.h" +#include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/managed_value_factory.h" +#include "proto/test/v1/proto2/test_all_types.pb.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::google::api::expr::test::v1::proto2::TestAllTypes; +using testing::HasSubstr; +using testing::Optional; +using cel::internal::IsOkAndHolds; +using cel::internal::StatusIs; + +enum class MemoryManagerOption { kGlobal, kArena }; + +class BindProtoToActivationTest + : public ::testing::TestWithParam { + public: + BindProtoToActivationTest() : proto_memory_manager_(&arena_) {} + cel::MemoryManager& memory_manager() { + return GetParam() == MemoryManagerOption::kGlobal ? MemoryManager::Global() + : proto_memory_manager_; + } + + private: + google::protobuf::Arena arena_; + ProtoMemoryManager proto_memory_manager_; +}; + +MATCHER_P(IsIntValue, value, "") { + const Handle& handle = arg; + + return handle->Is() && handle->As().value() == value; +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivation) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_OK( + BindProtoToActivation(test_all_types, value_factory.get(), activation)); + + EXPECT_THAT(activation.FindVariable(value_factory.get(), "single_int64"), + + IsOkAndHolds(Optional(IsIntValue(123)))); +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivationWktUnsupported) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + google::protobuf::Int64Value int64_value; + int64_value.set_value(123); + Activation activation; + + EXPECT_THAT( + BindProtoToActivation(int64_value, value_factory.get(), activation), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("google.protobuf.Int64Value"))); +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivationSkip) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_OK(BindProtoToActivation(test_all_types, value_factory.get(), + activation, + BindProtoUnsetFieldBehavior::kSkip)); + + EXPECT_THAT(activation.FindVariable(value_factory.get(), "single_int32"), + + IsOkAndHolds(absl::nullopt)); +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivationDefault) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_OK( + BindProtoToActivation(test_all_types, value_factory.get(), activation, + BindProtoUnsetFieldBehavior::kBindDefaultValue)); + + EXPECT_THAT(activation.FindVariable(value_factory.get(), "single_int32"), + + IsOkAndHolds(Optional(IsIntValue(-32)))); +} + +MATCHER_P(IsListValueOfSize, size, "") { + const Handle& handle = arg; + + return handle->Is() && handle->As().size() == size; +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivationRepeated) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + TestAllTypes test_all_types; + test_all_types.add_repeated_int64(123); + test_all_types.add_repeated_int64(456); + test_all_types.add_repeated_int64(789); + + Activation activation; + + ASSERT_OK( + BindProtoToActivation(test_all_types, value_factory.get(), activation)); + + EXPECT_THAT(activation.FindVariable(value_factory.get(), "repeated_int64"), + IsOkAndHolds(Optional(IsListValueOfSize(3)))); + ASSERT_OK_AND_ASSIGN( + auto variable, + activation.FindVariable(value_factory.get(), "repeated_int64")); +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivationRepeatedEmpty) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_OK( + BindProtoToActivation(test_all_types, value_factory.get(), activation)); + + EXPECT_THAT(activation.FindVariable(value_factory.get(), "repeated_int32"), + IsOkAndHolds(Optional(IsListValueOfSize(0)))); +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivationRepeatedComplex) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + TestAllTypes test_all_types; + auto* nested = test_all_types.add_repeated_nested_message(); + nested->set_bb(123); + nested = test_all_types.add_repeated_nested_message(); + nested->set_bb(456); + nested = test_all_types.add_repeated_nested_message(); + nested->set_bb(789); + Activation activation; + + ASSERT_OK( + BindProtoToActivation(test_all_types, value_factory.get(), activation)); + + EXPECT_THAT( + activation.FindVariable(value_factory.get(), "repeated_nested_message"), + + IsOkAndHolds(Optional(IsListValueOfSize(3)))); +} + +MATCHER_P(IsMapValueOfSize, size, "") { + const Handle& handle = arg; + + return handle->Is() && handle->As().size() == size; +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivationMap) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + TestAllTypes test_all_types; + (*test_all_types.mutable_map_int64_int64())[1] = 2; + (*test_all_types.mutable_map_int64_int64())[2] = 4; + + Activation activation; + + ASSERT_OK( + BindProtoToActivation(test_all_types, value_factory.get(), activation)); + + EXPECT_THAT(activation.FindVariable(value_factory.get(), "map_int64_int64"), + + IsOkAndHolds(Optional(IsMapValueOfSize(2)))); +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivationMapEmpty) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_OK( + BindProtoToActivation(test_all_types, value_factory.get(), activation)); + + EXPECT_THAT(activation.FindVariable(value_factory.get(), "map_int32_int32"), + + IsOkAndHolds(Optional(IsMapValueOfSize(0)))); +} + +TEST_P(BindProtoToActivationTest, BindProtoToActivationMapComplex) { + ProtoTypeProvider provider; + ManagedValueFactory value_factory(provider, memory_manager()); + TestAllTypes test_all_types; + TestAllTypes::NestedMessage value; + value.set_bb(42); + (*test_all_types.mutable_map_int64_message())[1] = value; + (*test_all_types.mutable_map_int64_message())[2] = value; + + Activation activation; + + ASSERT_OK( + BindProtoToActivation(test_all_types, value_factory.get(), activation)); + + EXPECT_THAT(activation.FindVariable(value_factory.get(), "map_int64_message"), + + IsOkAndHolds(Optional(IsMapValueOfSize(2)))); +} + +INSTANTIATE_TEST_SUITE_P(Runner, BindProtoToActivationTest, + ::testing::Values(MemoryManagerOption::kGlobal, + MemoryManagerOption::kArena)); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index 9e85c1be7..95e661157 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -2291,7 +2291,8 @@ absl::StatusOr> ParsedProtoStructValue::GetFieldByName( auto field_type, type()->FindFieldByName(value_factory.type_manager(), name)); if (ABSL_PREDICT_FALSE(!field_type)) { - return runtime_internal::CreateNoSuchFieldError(name); + return value_factory.CreateErrorValue( + runtime_internal::CreateNoSuchFieldError(name)); } return GetField(value_factory, *field_type); } @@ -2302,7 +2303,8 @@ absl::StatusOr> ParsedProtoStructValue::GetFieldByNumber( auto field_type, type()->FindFieldByNumber(value_factory.type_manager(), number)); if (ABSL_PREDICT_FALSE(!field_type)) { - return runtime_internal::CreateNoSuchFieldError(absl::StrCat(number)); + return value_factory.CreateErrorValue( + runtime_internal::CreateNoSuchFieldError(absl::StrCat(number))); } return GetField(value_factory, *field_type); } diff --git a/extensions/protobuf/type_provider_end_to_end_test.cc b/extensions/protobuf/type_provider_end_to_end_test.cc index e63b984a1..a61c75d39 100644 --- a/extensions/protobuf/type_provider_end_to_end_test.cc +++ b/extensions/protobuf/type_provider_end_to_end_test.cc @@ -317,6 +317,13 @@ INSTANTIATE_TEST_SUITE_P( )cel", absl::InvalidArgumentError( "type conversion error from int to string")}, + {"no_such_field", + R"cel( + TestAllTypes{ + single_int64: 32 + }.unknown_field + )cel", + absl::NotFoundError("no_such_field : unknown_field")}, }), ::testing::Values(MemoryManagerKind::kGlobal, MemoryManagerKind::kProto)),