diff --git a/common/expr_factory.h b/common/expr_factory.h index fd483bc5e..14ce271be 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -330,6 +330,35 @@ class ExprFactory { return expr; } + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewComprehension(ExprId id, IterVar iter_var, IterVar2 iter_var2, + IterRange iter_range, AccuVar accu_var, + AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + Expr expr; + expr.set_id(id); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var(std::move(iter_var)); + comprehension_expr.set_iter_var2(std::move(iter_var2)); + comprehension_expr.set_iter_range(std::move(iter_range)); + comprehension_expr.set_accu_var(std::move(accu_var)); + comprehension_expr.set_accu_init(std::move(accu_init)); + comprehension_expr.set_loop_condition(std::move(loop_condition)); + comprehension_expr.set_loop_step(std::move(loop_step)); + comprehension_expr.set_result(std::move(result)); + return expr; + } + private: friend class MacroExprFactory; friend class ParserMacroExprFactory; diff --git a/common/values/error_value.h b/common/values/error_value.h index 577675776..02380d575 100644 --- a/common/values/error_value.h +++ b/common/values/error_value.h @@ -156,6 +156,15 @@ bool IsNoSuchField(const ErrorValue& value); bool IsNoSuchKey(const ErrorValue& value); +class ErrorValueReturn final { + public: + ErrorValueReturn() = default; + + ErrorValue operator()(absl::Status status) const { + return ErrorValue(std::move(status)); + } +}; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_ERROR_VALUE_H_ diff --git a/conformance/BUILD b/conformance/BUILD index aca4c2795..fe46e8c0a 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -72,6 +72,8 @@ cc_library( "//eval/public:cel_value", "//eval/public:transform_utility", "//extensions:bindings_ext", + "//extensions:comprehensions_v2_functions", + "//extensions:comprehensions_v2_macros", "//extensions:encoders", "//extensions:math_ext", "//extensions:math_ext_macros", diff --git a/conformance/service.cc b/conformance/service.cc index d2a8bffed..278462bc8 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -60,6 +60,8 @@ #include "eval/public/cel_value.h" #include "eval/public/transform_utility.h" #include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2_functions.h" +#include "extensions/comprehensions_v2_macros.h" #include "extensions/encoders.h" #include "extensions/math_ext.h" #include "extensions/math_ext_macros.h" @@ -255,6 +257,8 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, options.enable_optional_syntax = enable_optional_syntax; cel::MacroRegistry macros; CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options)); + CEL_RETURN_IF_ERROR( + cel::extensions::RegisterComprehensionsV2Macros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options)); @@ -333,6 +337,8 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor()); CEL_RETURN_IF_ERROR( RegisterBuiltinFunctions(builder->GetRegistry(), options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions( + builder->GetRegistry(), options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions( builder->GetRegistry(), options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterStringsFunctions( @@ -503,6 +509,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { type_registry, cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor())); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions( + builder.function_registry(), options)); CEL_RETURN_IF_ERROR(cel::extensions::EnableOptionalTypes(builder)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions( builder.function_registry(), options)); diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index fcb4d5c44..435660c5a 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -323,12 +323,45 @@ const cel::ast_internal::Expr* GetOptimizableListAppendOperand( return &GetOptimizableListAppendCall(comprehension)->args()[1]; } +bool IsOptimizableMapInsert( + const cel::ast_internal::Comprehension* comprehension) { + if (comprehension->iter_var().empty() || comprehension->iter_var2().empty()) { + return false; + } + absl::string_view accu_var = comprehension->accu_var(); + if (accu_var.empty() || !comprehension->has_result() || + !comprehension->result().has_ident_expr() || + comprehension->result().ident_expr().name() != accu_var) { + return false; + } + if (!comprehension->accu_init().has_map_expr()) { + return false; + } + if (!comprehension->loop_step().has_call_expr()) { + return false; + } + const auto* call_expr = &comprehension->loop_step().call_expr(); + + if (call_expr->function() == cel::builtin::kTernary && + call_expr->args().size() == 3) { + if (!call_expr->args()[1].has_call_expr()) { + return false; + } + call_expr = &(call_expr->args()[1].call_expr()); + } + return call_expr->function() == "cel.@mapInsert" && + call_expr->args().size() == 3 && + call_expr->args()[0].has_ident_expr() && + call_expr->args()[0].ident_expr().name() == accu_var; +} + bool IsBind(const cel::ast_internal::Comprehension* comprehension) { static constexpr absl::string_view kUnusedIterVar = "#unused"; return comprehension->loop_condition().const_expr().has_bool_value() && comprehension->loop_condition().const_expr().bool_value() == false && comprehension->iter_var() == kUnusedIterVar && + comprehension->iter_var2().empty() && comprehension->iter_range().has_list_expr() && comprehension->iter_range().list_expr().elements().empty(); } @@ -342,7 +375,7 @@ class ComprehensionVisitor { public: explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting, bool is_trivial, size_t iter_slot, - size_t accu_slot) + size_t iter2_slot, size_t accu_slot) : visitor_(visitor), next_step_(nullptr), cond_step_(nullptr), @@ -350,6 +383,7 @@ class ComprehensionVisitor { is_trivial_(is_trivial), accu_init_extracted_(false), iter_slot_(iter_slot), + iter2_slot_(iter2_slot), accu_slot_(accu_slot) {} void PreVisit(const cel::ast_internal::Expr* expr); @@ -383,6 +417,7 @@ class ComprehensionVisitor { bool is_trivial_; bool accu_init_extracted_; size_t iter_slot_; + size_t iter2_slot_; size_t accu_slot_; }; @@ -599,6 +634,10 @@ class FlatExprVisitor : public cel::AstVisitor { } return {static_cast(record.iter_slot), -1}; } + if (record.iter_var2_in_scope && + record.comprehension->iter_var2() == path) { + return {static_cast(record.iter2_slot), -1}; + } if (record.accu_var_in_scope && record.comprehension->accu_var() == path) { int slot = record.accu_slot; @@ -1083,7 +1122,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeComprehensionRecursive( const cel::ast_internal::Expr* expr, const cel::ast_internal::Comprehension* comprehension, size_t iter_slot, - size_t accu_slot) { + size_t iter2_slot, size_t accu_slot) { if (options_.max_recursion_depth == 0) { return; } @@ -1136,7 +1175,8 @@ class FlatExprVisitor : public cel::AstVisitor { } auto step = CreateDirectComprehensionStep( - iter_slot, accu_slot, range_plan->ExtractRecursiveProgram().step, + iter_slot, iter2_slot, accu_slot, + range_plan->ExtractRecursiveProgram().step, accu_plan->ExtractRecursiveProgram().step, loop_plan->ExtractRecursiveProgram().step, condition_plan->ExtractRecursiveProgram().step, @@ -1247,6 +1287,7 @@ class FlatExprVisitor : public cel::AstVisitor { } const auto& accu_var = comprehension.accu_var(); const auto& iter_var = comprehension.iter_var(); + const auto& iter_var2 = comprehension.iter_var2(); ValidateOrError(!accu_var.empty(), "Invalid comprehension: 'accu_var' must not be empty"); ValidateOrError(!iter_var.empty(), @@ -1254,6 +1295,12 @@ class FlatExprVisitor : public cel::AstVisitor { ValidateOrError( accu_var != iter_var, "Invalid comprehension: 'accu_var' must not be the same as 'iter_var'"); + ValidateOrError(accu_var != iter_var2, + "Invalid comprehension: 'accu_var' must not be the same as " + "'iter_var2'"); + ValidateOrError(iter_var2 != iter_var, + "Invalid comprehension: 'iter_var2' must not be the same " + "as 'iter_var'"); ValidateOrError(comprehension.has_accu_init(), "Invalid comprehension: 'accu_init' must be set"); ValidateOrError(comprehension.has_loop_condition(), @@ -1263,15 +1310,20 @@ class FlatExprVisitor : public cel::AstVisitor { ValidateOrError(comprehension.has_result(), "Invalid comprehension: 'result' must be set"); - size_t iter_slot, accu_slot, slot_count; + size_t iter_slot, iter2_slot, accu_slot, slot_count; bool is_bind = IsBind(&comprehension); if (is_bind) { - accu_slot = iter_slot = index_manager_.ReserveSlots(1); + accu_slot = iter_slot = iter2_slot = index_manager_.ReserveSlots(1); slot_count = 1; - } else { - iter_slot = index_manager_.ReserveSlots(2); + } else if (comprehension.iter_var2().empty()) { + iter_slot = iter2_slot = index_manager_.ReserveSlots(2); accu_slot = iter_slot + 1; slot_count = 2; + } else { + iter_slot = index_manager_.ReserveSlots(3); + iter2_slot = iter_slot + 1; + accu_slot = iter2_slot + 1; + slot_count = 3; } // If this is in the scope of an optimized bind accu-init, account the slots // to the outermost bind-init scope. @@ -1289,16 +1341,20 @@ class FlatExprVisitor : public cel::AstVisitor { } comprehension_stack_.push_back( - {&expr, &comprehension, iter_slot, accu_slot, slot_count, + {&expr, &comprehension, iter_slot, iter2_slot, accu_slot, slot_count, /*subexpression=*/-1, + /*.is_optimizable_list_append=*/ IsOptimizableListAppend(&comprehension, options_.enable_comprehension_list_append), - is_bind, + /*.is_optimizable_map_insert=*/IsOptimizableMapInsert(&comprehension), + /*.is_optimizable_bind=*/is_bind, /*.iter_var_in_scope=*/false, + /*.iter_var2_in_scope=*/false, /*.accu_var_in_scope=*/false, /*.in_accu_init=*/false, - std::make_unique( - this, options_.short_circuiting, is_bind, iter_slot, accu_slot)}); + std::make_unique(this, options_.short_circuiting, + is_bind, iter_slot, iter2_slot, + accu_slot)}); comprehension_stack_.back().visitor->PreVisit(&expr); } @@ -1341,30 +1397,35 @@ class FlatExprVisitor : public cel::AstVisitor { case cel::ITER_RANGE: { record.in_accu_init = false; record.iter_var_in_scope = false; + record.iter_var2_in_scope = false; record.accu_var_in_scope = false; break; } case cel::ACCU_INIT: { record.in_accu_init = true; record.iter_var_in_scope = false; + record.iter_var2_in_scope = false; record.accu_var_in_scope = false; break; } case cel::LOOP_CONDITION: { record.in_accu_init = false; record.iter_var_in_scope = true; + record.iter_var2_in_scope = true; record.accu_var_in_scope = true; break; } case cel::LOOP_STEP: { record.in_accu_init = false; record.iter_var_in_scope = true; + record.iter_var2_in_scope = true; record.accu_var_in_scope = true; break; } case cel::RESULT: { record.in_accu_init = false; record.iter_var_in_scope = false; + record.iter_var2_in_scope = false; record.accu_var_in_scope = true; break; } @@ -1468,6 +1529,21 @@ class FlatExprVisitor : public cel::AstVisitor { return; } + if (!comprehension_stack_.empty()) { + const ComprehensionStackRecord& comprehension = + comprehension_stack_.back(); + if (comprehension.is_optimizable_map_insert) { + if (&(comprehension.comprehension->accu_init()) == &expr) { + if (options_.max_recursion_depth != 0) { + SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); + return; + } + AddStep(CreateMutableMapStep(expr.id())); + return; + } + } + } + auto status_or_resolved_fields = ResolveCreateStructFields(struct_expr, expr.id()); if (!status_or_resolved_fields.ok()) { @@ -1672,13 +1748,16 @@ class FlatExprVisitor : public cel::AstVisitor { const cel::ast_internal::Expr* expr; const cel::ast_internal::Comprehension* comprehension; size_t iter_slot; + size_t iter2_slot; size_t accu_slot; size_t slot_count; // -1 indicates this shouldn't be used. int subexpression; bool is_optimizable_list_append; + bool is_optimizable_map_insert; bool is_optimizable_bind; bool iter_var_in_scope; + bool iter_var2_in_scope; bool accu_var_in_scope; bool in_accu_init; std::unique_ptr visitor; @@ -2011,20 +2090,24 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( case cel::ITER_RANGE: { // post process iter_range to list its keys if it's a map // and initialize the loop index. - visitor_->AddStep(CreateComprehensionInitStep(expr->id())); + if (iter_slot_ == iter2_slot_) { + visitor_->AddStep(CreateComprehensionInitStep(expr->id())); + } else { + visitor_->AddStep(CreateComprehensionInitStep2(expr->id())); + } break; } case cel::ACCU_INIT: { next_step_pos_ = visitor_->GetCurrentIndex(); - next_step_ = - new ComprehensionNextStep(iter_slot_, accu_slot_, expr->id()); + next_step_ = new ComprehensionNextStep(iter_slot_, iter2_slot_, + accu_slot_, expr->id()); visitor_->AddStep(std::unique_ptr(next_step_)); break; } case cel::LOOP_CONDITION: { cond_step_pos_ = visitor_->GetCurrentIndex(); - cond_step_ = new ComprehensionCondStep(iter_slot_, accu_slot_, - short_circuiting_, expr->id()); + cond_step_ = new ComprehensionCondStep( + iter_slot_, iter2_slot_, accu_slot_, short_circuiting_, expr->id()); visitor_->AddStep(std::unique_ptr(cond_step_)); break; } @@ -2049,7 +2132,13 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( break; } case cel::RESULT: { - visitor_->AddStep(CreateComprehensionFinishStep(accu_slot_, expr->id())); + if (iter_slot_ == iter2_slot_) { + visitor_->AddStep( + CreateComprehensionFinishStep(accu_slot_, expr->id())); + } else { + visitor_->AddStep( + CreateComprehensionFinishStep2(accu_slot_, expr->id())); + } CEL_ASSIGN_OR_RETURN( int jump_from_next, @@ -2097,8 +2186,8 @@ void ComprehensionVisitor::PostVisit(const cel::ast_internal::Expr* expr) { accu_slot_); return; } - visitor_->MaybeMakeComprehensionRecursive(expr, &expr->comprehension_expr(), - iter_slot_, accu_slot_); + visitor_->MaybeMakeComprehensionRecursive( + expr, &expr->comprehension_expr(), iter_slot_, iter2_slot_, accu_slot_); } // Flattens the expression table into the end of the mainline expression vector diff --git a/eval/eval/BUILD b/eval/eval/BUILD index d7769f22f..1e39efd9b 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -476,10 +476,12 @@ cc_library( "//eval/internal:errors", "//eval/public:cel_attribute", "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index 75e723e17..c34121b09 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -5,10 +5,12 @@ #include #include +#include "absl/base/attributes.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "base/attribute.h" #include "base/kind.h" @@ -27,6 +29,7 @@ namespace google::api::expr::runtime { namespace { +using ::cel::AttributeQualifier; using ::cel::BoolValue; using ::cel::Cast; using ::cel::InstanceOf; @@ -35,8 +38,25 @@ using ::cel::ListValue; using ::cel::MapValue; using ::cel::UnknownValue; using ::cel::Value; +using ::cel::ValueKind; using ::cel::runtime_internal::CreateNoMatchingOverloadError; +AttributeQualifier AttributeQualifierFromValue(const Value& v) { + switch (v->kind()) { + case ValueKind::kString: + return AttributeQualifier::OfString(v.GetString().ToString()); + case ValueKind::kInt64: + return AttributeQualifier::OfInt(v.GetInt().NativeValue()); + case ValueKind::kUint64: + return AttributeQualifier::OfUint(v.GetUint().NativeValue()); + case ValueKind::kBool: + return AttributeQualifier::OfBool(v.GetBool().NativeValue()); + default: + // Non-matching qualifier. + return AttributeQualifier(); + } +} + class ComprehensionFinish : public ExpressionStepBase { public: ComprehensionFinish(size_t accu_slot, int64_t expr_id); @@ -65,6 +85,30 @@ absl::Status ComprehensionFinish::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +class ComprehensionFinish2 final : public ExpressionStepBase { + public: + ComprehensionFinish2(size_t accu_slot, int64_t expr_id) + : ExpressionStepBase(expr_id), accu_slot_(accu_slot) {} + + // Stack changes of ComprehensionFinish. + // + // Stack size before: 4. + // Stack size after: 1. + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(4)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + Value result = frame->value_stack().Peek(); + frame->value_stack().Pop(4); + frame->value_stack().Push(std::move(result)); + frame->comprehension_slots().ClearSlot(accu_slot_); + return absl::OkStatus(); + } + + private: + size_t accu_slot_; +}; + class ComprehensionInitStep : public ExpressionStepBase { public: explicit ComprehensionInitStep(int64_t expr_id) @@ -124,10 +168,49 @@ absl::Status ComprehensionInitStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +class ComprehensionInitStep2 final : public ExpressionStepBase { + public: + explicit ComprehensionInitStep2(int64_t expr_id) + : ExpressionStepBase(expr_id, false) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(1)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + + const auto& range = frame->value_stack().Peek(); + switch (range.kind()) { + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN( + Value keys, ProjectKeysImpl(*frame, range.GetMap(), + frame->value_stack().PeekAttribute())); + frame->value_stack().Push(std::move(keys)); + } break; + case ValueKind::kList: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + frame->value_stack().Push(range); + break; + default: + frame->value_stack().PopAndPush(frame->value_factory().CreateErrorValue( + CreateNoMatchingOverloadError(""))); + break; + } + + // Initialize current index. + // Error handling for wrong range type is deferred until the 'Next' step + // to simplify the number of jumps. + frame->value_stack().Push(frame->value_factory().CreateIntValue(-1)); + return absl::OkStatus(); + } +}; + class ComprehensionDirectStep : public DirectExpressionStep { public: explicit ComprehensionDirectStep( - size_t iter_slot, size_t accu_slot, + size_t iter_slot, size_t iter2_slot, size_t accu_slot, std::unique_ptr range, std::unique_ptr accu_init, std::unique_ptr loop_step, @@ -136,6 +219,7 @@ class ComprehensionDirectStep : public DirectExpressionStep { int64_t expr_id) : DirectExpressionStep(expr_id), iter_slot_(iter_slot), + iter2_slot_(iter2_slot), accu_slot_(accu_slot), range_(std::move(range)), accu_init_(std::move(accu_init)), @@ -143,11 +227,22 @@ class ComprehensionDirectStep : public DirectExpressionStep { condition_(std::move(condition_step)), result_step_(std::move(result_step)), shortcircuiting_(shortcircuiting) {} + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, - AttributeTrail& trail) const override; + AttributeTrail& trail) const final { + return iter_slot_ == iter2_slot_ ? Evaluate1(frame, result, trail) + : Evaluate2(frame, result, trail); + } private: + absl::Status Evaluate1(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const; + + absl::Status Evaluate2(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const; + size_t iter_slot_; + size_t iter2_slot_; size_t accu_slot_; std::unique_ptr range_; std::unique_ptr accu_init_; @@ -158,9 +253,9 @@ class ComprehensionDirectStep : public DirectExpressionStep { bool shortcircuiting_; }; -absl::Status ComprehensionDirectStep::Evaluate(ExecutionFrameBase& frame, - Value& result, - AttributeTrail& trail) const { +absl::Status ComprehensionDirectStep::Evaluate1(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { cel::Value range; AttributeTrail range_attr; CEL_RETURN_IF_ERROR(range_->Evaluate(frame, range, range_attr)); @@ -257,6 +352,180 @@ absl::Status ComprehensionDirectStep::Evaluate(ExecutionFrameBase& frame, return absl::OkStatus(); } +absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { + cel::Value iter2_range; + AttributeTrail range_attr; + CEL_RETURN_IF_ERROR(range_->Evaluate(frame, iter2_range, range_attr)); + + absl::optional iter2_range_map; + cel::Value iter_range; + if (iter2_range.IsMap()) { + iter2_range_map = iter2_range.GetMap(); + CEL_ASSIGN_OR_RETURN(iter_range, + ProjectKeysImpl(frame, *iter2_range_map, range_attr)); + } else { + iter_range = iter2_range; + } + + switch (iter_range.kind()) { + case cel::ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case cel::ValueKind::kUnknown: + result = iter_range; + return absl::OkStatus(); + case cel::ValueKind::kList: + break; + default: + result = cel::ErrorValue(CreateNoMatchingOverloadError("")); + return absl::OkStatus(); + } + + const auto& iter_range_list = iter_range.GetList(); + + Value accu_init; + AttributeTrail accu_init_attr; + CEL_RETURN_IF_ERROR(accu_init_->Evaluate(frame, accu_init, accu_init_attr)); + + frame.comprehension_slots().Set(accu_slot_, std::move(accu_init), + accu_init_attr); + ComprehensionSlots::Slot* accu_slot = + frame.comprehension_slots().Get(accu_slot_); + ABSL_DCHECK(accu_slot != nullptr); + + frame.comprehension_slots().Set(iter_slot_); + ComprehensionSlots::Slot* iter_slot = + frame.comprehension_slots().Get(iter_slot_); + ABSL_DCHECK(iter_slot != nullptr); + + frame.comprehension_slots().Set(iter2_slot_); + ComprehensionSlots::Slot* iter2_slot = + frame.comprehension_slots().Get(iter2_slot_); + ABSL_DCHECK(iter2_slot != nullptr); + + Value condition; + AttributeTrail condition_attr; + bool should_skip_result = false; + if (iter2_range_map) { + CEL_RETURN_IF_ERROR(iter2_range_map->ForEach( + frame.value_manager(), + [&](const Value& k, const Value& v) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(frame.IncrementIterations()); + // Evaluate loop condition first. + CEL_RETURN_IF_ERROR( + condition_->Evaluate(frame, condition, condition_attr)); + + if (condition.kind() == cel::ValueKind::kError || + condition.kind() == cel::ValueKind::kUnknown) { + result = std::move(condition); + should_skip_result = true; + return false; + } + if (condition.kind() != cel::ValueKind::kBool) { + result = frame.value_manager().CreateErrorValue( + CreateNoMatchingOverloadError("")); + should_skip_result = true; + return false; + } + if (shortcircuiting_ && !Cast(condition).NativeValue()) { + return false; + } + + iter_slot->value = k; + if (frame.unknown_processing_enabled()) { + iter_slot->attribute = + range_attr.Step(AttributeQualifierFromValue(k)); + if (frame.attribute_utility().CheckForUnknownExact( + iter_slot->attribute)) { + iter_slot->value = frame.attribute_utility().CreateUnknownSet( + iter_slot->attribute.attribute()); + } + } + + iter2_slot->value = v; + if (frame.unknown_processing_enabled()) { + iter2_slot->attribute = + range_attr.Step(AttributeQualifierFromValue(v)); + if (frame.attribute_utility().CheckForUnknownExact( + iter2_slot->attribute)) { + iter2_slot->value = frame.attribute_utility().CreateUnknownSet( + iter2_slot->attribute.attribute()); + } + } + + CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, accu_slot->value, + accu_slot->attribute)); + + return true; + })); + } else { + CEL_RETURN_IF_ERROR(iter_range_list.ForEach( + frame.value_manager(), + [&](size_t index, const Value& v) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(frame.IncrementIterations()); + // Evaluate loop condition first. + CEL_RETURN_IF_ERROR( + condition_->Evaluate(frame, condition, condition_attr)); + + if (condition.kind() == cel::ValueKind::kError || + condition.kind() == cel::ValueKind::kUnknown) { + result = std::move(condition); + should_skip_result = true; + return false; + } + if (condition.kind() != cel::ValueKind::kBool) { + result = frame.value_manager().CreateErrorValue( + CreateNoMatchingOverloadError("")); + should_skip_result = true; + return false; + } + if (shortcircuiting_ && !Cast(condition).NativeValue()) { + return false; + } + + iter_slot->value = IntValue(index); + if (frame.unknown_processing_enabled()) { + iter_slot->attribute = + range_attr.Step(CelAttributeQualifier::OfInt(index)); + if (frame.attribute_utility().CheckForUnknownExact( + iter_slot->attribute)) { + iter_slot->value = frame.attribute_utility().CreateUnknownSet( + iter_slot->attribute.attribute()); + } + } + + iter2_slot->value = v; + if (frame.unknown_processing_enabled()) { + iter2_slot->attribute = + range_attr.Step(AttributeQualifierFromValue(v)); + if (frame.attribute_utility().CheckForUnknownExact( + iter2_slot->attribute)) { + iter2_slot->value = frame.attribute_utility().CreateUnknownSet( + iter2_slot->attribute.attribute()); + } + } + + CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, accu_slot->value, + accu_slot->attribute)); + + return true; + })); + } + + frame.comprehension_slots().ClearSlot(iter_slot_); + frame.comprehension_slots().ClearSlot(iter2_slot_); + // Error state is already set to the return value, just clean up. + if (should_skip_result) { + frame.comprehension_slots().ClearSlot(accu_slot_); + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(result_step_->Evaluate(frame, result, trail)); + frame.comprehension_slots().ClearSlot(accu_slot_); + return absl::OkStatus(); +} + } // namespace // Stack variables during comprehension evaluation: @@ -276,10 +545,12 @@ absl::Status ComprehensionDirectStep::Evaluate(ExecutionFrameBase& frame, // 8. result (dep) 2 -> 3 // 9. ComprehensionFinish 3 -> 1 -ComprehensionNextStep::ComprehensionNextStep(size_t iter_slot, size_t accu_slot, - int64_t expr_id) +ComprehensionNextStep::ComprehensionNextStep(size_t iter_slot, + size_t iter2_slot, + size_t accu_slot, int64_t expr_id) : ExpressionStepBase(expr_id, false), iter_slot_(iter_slot), + iter2_slot_(iter2_slot), accu_slot_(accu_slot) {} void ComprehensionNextStep::set_jump_offset(int offset) { @@ -307,7 +578,7 @@ void ComprehensionNextStep::set_error_jump_offset(int offset) { // // Stack on error: // 0. error -absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionNextStep::Evaluate1(ExecutionFrame* frame) const { enum { POS_ITER_RANGE, POS_CURRENT_INDEX, @@ -386,11 +657,148 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } -ComprehensionCondStep::ComprehensionCondStep(size_t iter_slot, size_t accu_slot, +absl::Status ComprehensionNextStep::Evaluate2(ExecutionFrame* frame) const { + enum { + POS_ITER2_RANGE, // Map or same as POS_ITER_RANGE. + POS_ITER_RANGE, + POS_CURRENT_INDEX, + POS_LOOP_STEP_ACCU, + }; + constexpr int kStackSize = 4; + if (!frame->value_stack().HasEnough(kStackSize)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + absl::Span state = frame->value_stack().GetSpan(kStackSize); + + const cel::Value& iter2_range = state[POS_ITER2_RANGE]; + absl::optional iter2_range_map; + switch (iter2_range.kind()) { + case ValueKind::kMap: + iter2_range_map = iter2_range.GetMap(); + break; + case ValueKind::kList: + break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + // Leave it on the stack. + frame->value_stack().PopAndPush(kStackSize, std::move(iter2_range)); + return frame->JumpTo(error_jump_offset_); + default: + frame->value_stack().PopAndPush( + kStackSize, frame->value_factory().CreateErrorValue( + CreateNoMatchingOverloadError(""))); + return frame->JumpTo(error_jump_offset_); + } + + // Get range from the stack. + const cel::Value& iter_range = state[POS_ITER_RANGE]; + switch (iter_range.kind()) { + case ValueKind::kList: + break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + frame->value_stack().PopAndPush(kStackSize, std::move(iter_range)); + return frame->JumpTo(error_jump_offset_); + default: + frame->value_stack().PopAndPush( + kStackSize, frame->value_factory().CreateErrorValue( + CreateNoMatchingOverloadError(""))); + return frame->JumpTo(error_jump_offset_); + } + ListValue iter_range_list = iter_range.GetList(); + + // Get the current index off the stack. + const cel::Value& current_index_value = state[POS_CURRENT_INDEX]; + if (!current_index_value.IsInt()) { + return absl::InternalError(absl::StrCat( + "ComprehensionNextStep: want int, got ", + cel::KindToString(ValueKindToKind(current_index_value.kind())))); + } + CEL_RETURN_IF_ERROR(frame->IncrementIterations()); + + int64_t next_index = current_index_value.GetInt().NativeValue() + 1; + + frame->comprehension_slots().Set(accu_slot_, state[POS_LOOP_STEP_ACCU]); + + CEL_ASSIGN_OR_RETURN(auto iter_range_list_size, iter_range_list.Size()); + + if (next_index >= static_cast(iter_range_list_size)) { + // Make sure the iter var is out of scope. + frame->comprehension_slots().ClearSlot(iter_slot_); + frame->comprehension_slots().ClearSlot(iter2_slot_); + // pop loop step + frame->value_stack().Pop(1); + // jump to result production step + return frame->JumpTo(jump_offset_); + } + + AttributeTrail iter_range_trail; + if (frame->enable_unknowns()) { + iter_range_trail = + frame->value_stack().GetAttributeSpan(kStackSize)[POS_ITER_RANGE].Step( + cel::AttributeQualifier::OfInt(next_index)); + } + + Value current_iter_var; + if (frame->enable_unknowns() && + frame->attribute_utility().CheckForUnknown(iter_range_trail, + /*use_partial=*/false)) { + current_iter_var = frame->attribute_utility().CreateUnknownSet( + iter_range_trail.attribute()); + } else { + CEL_ASSIGN_OR_RETURN(current_iter_var, + iter_range_list.Get(frame->value_factory(), + static_cast(next_index))); + } + + AttributeTrail iter2_range_trail; + Value current_iter_var2; + if (iter2_range_map) { + AttributeTrail iter2_range_trail; + if (frame->enable_unknowns()) { + iter2_range_trail = + frame->value_stack() + .GetAttributeSpan(kStackSize)[POS_ITER2_RANGE] + .Step(AttributeQualifierFromValue(current_iter_var)); + } + if (frame->enable_unknowns() && + frame->attribute_utility().CheckForUnknown(iter2_range_trail, + /*use_partial=*/false)) { + current_iter_var2 = frame->attribute_utility().CreateUnknownSet( + iter2_range_trail.attribute()); + } else { + CEL_ASSIGN_OR_RETURN( + current_iter_var2, + iter2_range_map->Get(frame->value_manager(), current_iter_var)); + } + } else { + iter2_range_trail = iter_range_trail; + current_iter_var2 = current_iter_var; + current_iter_var = IntValue(next_index); + } + + // pop loop step + // pop old current_index + // push new current_index + frame->value_stack().PopAndPush( + 2, frame->value_factory().CreateIntValue(next_index)); + frame->comprehension_slots().Set(iter_slot_, std::move(current_iter_var), + std::move(iter_range_trail)); + frame->comprehension_slots().Set(iter2_slot_, std::move(current_iter_var2), + std::move(iter2_range_trail)); + return absl::OkStatus(); +} + +ComprehensionCondStep::ComprehensionCondStep(size_t iter_slot, + size_t iter2_slot, + size_t accu_slot, bool shortcircuiting, int64_t expr_id) : ExpressionStepBase(expr_id, false), iter_slot_(iter_slot), + iter2_slot_(iter2_slot), accu_slot_(accu_slot), shortcircuiting_(shortcircuiting) {} @@ -412,7 +820,7 @@ void ComprehensionCondStep::set_error_jump_offset(int offset) { // Stack size before: 3. // Stack size after: 2. // Stack size on error: 1. -absl::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionCondStep::Evaluate1(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(3)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } @@ -440,8 +848,47 @@ absl::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +// Check the break condition for the comprehension. +// +// If the condition is false jump to the `result` subexpression. +// If not a bool, clear stack and jump past the result expression. +// Otherwise, continue to the accumulate step. +// Stack changes by ComprehensionCondStep. +// +// Stack size before: 4. +// Stack size after: 3. +// Stack size on error: 1. +absl::Status ComprehensionCondStep::Evaluate2(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(4)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + auto& loop_condition_value = frame->value_stack().Peek(); + if (!loop_condition_value->Is()) { + if (loop_condition_value->Is() || + loop_condition_value->Is()) { + frame->value_stack().PopAndPush(4, std::move(loop_condition_value)); + } else { + frame->value_stack().PopAndPush( + 4, frame->value_factory().CreateErrorValue( + CreateNoMatchingOverloadError(""))); + } + // The error jump skips the ComprehensionFinish clean-up step, so we + // need to update the iteration variable stack here. + frame->comprehension_slots().ClearSlot(iter_slot_); + frame->comprehension_slots().ClearSlot(iter2_slot_); + frame->comprehension_slots().ClearSlot(accu_slot_); + return frame->JumpTo(error_jump_offset_); + } + bool loop_condition = loop_condition_value.GetBool().NativeValue(); + frame->value_stack().Pop(1); // loop_condition + if (!loop_condition && shortcircuiting_) { + return frame->JumpTo(jump_offset_); + } + return absl::OkStatus(); +} + std::unique_ptr CreateDirectComprehensionStep( - size_t iter_slot, size_t accu_slot, + size_t iter_slot, size_t iter2_slot, size_t accu_slot, std::unique_ptr range, std::unique_ptr accu_init, std::unique_ptr loop_step, @@ -449,7 +896,7 @@ std::unique_ptr CreateDirectComprehensionStep( std::unique_ptr result_step, bool shortcircuiting, int64_t expr_id) { return std::make_unique( - iter_slot, accu_slot, std::move(range), std::move(accu_init), + iter_slot, iter2_slot, accu_slot, std::move(range), std::move(accu_init), std::move(loop_step), std::move(condition_step), std::move(result_step), shortcircuiting, expr_id); } @@ -463,4 +910,13 @@ std::unique_ptr CreateComprehensionInitStep(int64_t expr_id) { return std::make_unique(expr_id); } +std::unique_ptr CreateComprehensionFinishStep2( + size_t accu_slot, int64_t expr_id) { + return std::make_unique(accu_slot, expr_id); +} + +std::unique_ptr CreateComprehensionInitStep2(int64_t expr_id) { + return std::make_unique(expr_id); +} + } // namespace google::api::expr::runtime diff --git a/eval/eval/comprehension_step.h b/eval/eval/comprehension_step.h index c0fc78aa0..b0b8397f2 100644 --- a/eval/eval/comprehension_step.h +++ b/eval/eval/comprehension_step.h @@ -14,15 +14,23 @@ namespace google::api::expr::runtime { class ComprehensionNextStep : public ExpressionStepBase { public: - ComprehensionNextStep(size_t iter_slot, size_t accu_slot, int64_t expr_id); + ComprehensionNextStep(size_t iter_slot, size_t iter2_slot, size_t accu_slot, + int64_t expr_id); void set_jump_offset(int offset); void set_error_jump_offset(int offset); - absl::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const final { + return iter_slot_ == iter2_slot_ ? Evaluate1(frame) : Evaluate2(frame); + } private: + absl::Status Evaluate1(ExecutionFrame* frame) const; + + absl::Status Evaluate2(ExecutionFrame* frame) const; + size_t iter_slot_; + size_t iter2_slot_; size_t accu_slot_; int jump_offset_; int error_jump_offset_; @@ -30,16 +38,23 @@ class ComprehensionNextStep : public ExpressionStepBase { class ComprehensionCondStep : public ExpressionStepBase { public: - ComprehensionCondStep(size_t iter_slot, size_t accu_slot, + ComprehensionCondStep(size_t iter_slot, size_t iter2_slot, size_t accu_slot, bool shortcircuiting, int64_t expr_id); void set_jump_offset(int offset); void set_error_jump_offset(int offset); - absl::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const final { + return iter_slot_ == iter2_slot_ ? Evaluate1(frame) : Evaluate2(frame); + } private: + absl::Status Evaluate1(ExecutionFrame* frame) const; + + absl::Status Evaluate2(ExecutionFrame* frame) const; + size_t iter_slot_; + size_t iter2_slot_; size_t accu_slot_; int jump_offset_; int error_jump_offset_; @@ -48,7 +63,7 @@ class ComprehensionCondStep : public ExpressionStepBase { // Creates a step for executing a comprehension. std::unique_ptr CreateDirectComprehensionStep( - size_t iter_slot, size_t accu_slot, + size_t iter_slot, size_t iter2_slot, size_t accu_slot, std::unique_ptr range, std::unique_ptr accu_init, std::unique_ptr loop_step, @@ -66,6 +81,16 @@ std::unique_ptr CreateComprehensionFinishStep(size_t accu_slot, // context for the comprehension. std::unique_ptr CreateComprehensionInitStep(int64_t expr_id); +// Creates a cleanup step for the comprehension. +// Removes the comprehension context then pushes the 'result' sub expression to +// the top of the stack. +std::unique_ptr CreateComprehensionFinishStep2(size_t accu_slot, + int64_t expr_id); + +// Creates a step that checks that the input is iterable and sets up the loop +// context for the comprehension. +std::unique_ptr CreateComprehensionInitStep2(int64_t expr_id); + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_STEP_H_ diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index 2fd513ee7..08e1ff039 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -302,7 +302,7 @@ TEST_F(DirectComprehensionTest, PropagateRangeNonOkStatus) { .WillByDefault(Return(absl::InternalError("test range error"))); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/std::move(range_step), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), @@ -331,7 +331,7 @@ TEST_F(DirectComprehensionTest, PropagateAccuInitNonOkStatus) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/std::move(accu_init), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), @@ -360,7 +360,7 @@ TEST_F(DirectComprehensionTest, PropagateLoopNonOkStatus) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), @@ -389,7 +389,7 @@ TEST_F(DirectComprehensionTest, PropagateConditionNonOkStatus) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), @@ -418,7 +418,7 @@ TEST_F(DirectComprehensionTest, PropagateResultNonOkStatus) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), @@ -451,7 +451,7 @@ TEST_F(DirectComprehensionTest, Shortcircuit) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), @@ -484,7 +484,7 @@ TEST_F(DirectComprehensionTest, IterationLimit) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), @@ -517,7 +517,7 @@ TEST_F(DirectComprehensionTest, Exhaustive) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), diff --git a/eval/eval/create_map_step.cc b/eval/eval/create_map_step.cc index 3d8d86729..f205dd4b0 100644 --- a/eval/eval/create_map_step.cc +++ b/eval/eval/create_map_step.cc @@ -46,6 +46,7 @@ using ::cel::MapValueBuilderPtr; using ::cel::UnknownValue; using ::cel::Value; using ::cel::common_internal::NewMapValueBuilder; +using ::cel::common_internal::NewMutableMapValue; // `CreateStruct` implementation for map. class CreateStructStepForMap final : public ExpressionStepBase { @@ -231,6 +232,30 @@ absl::Status DirectCreateMapStep::Evaluate( return absl::OkStatus(); } +class MutableMapStep final : public ExpressionStep { + public: + explicit MutableMapStep(int64_t expr_id) : ExpressionStep(expr_id) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + frame->value_stack().Push(cel::ParsedMapValue( + NewMutableMapValue(frame->memory_manager().arena()))); + return absl::OkStatus(); + } +}; + +class DirectMutableMapStep final : public DirectExpressionStep { + public: + explicit DirectMutableMapStep(int64_t expr_id) + : DirectExpressionStep(expr_id) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + result = cel::ParsedMapValue( + NewMutableMapValue(frame.value_manager().GetMemoryManager().arena())); + return absl::OkStatus(); + } +}; + } // namespace std::unique_ptr CreateDirectCreateMapStep( @@ -248,4 +273,14 @@ absl::StatusOr> CreateCreateStructStepForMap( std::move(optional_indices)); } +absl::StatusOr> CreateMutableMapStep( + int64_t expr_id) { + return std::make_unique(expr_id); +} + +std::unique_ptr CreateDirectMutableMapStep( + int64_t expr_id) { + return std::make_unique(expr_id); +} + } // namespace google::api::expr::runtime diff --git a/eval/eval/create_map_step.h b/eval/eval/create_map_step.h index f9be4be0c..cf5e94644 100644 --- a/eval/eval/create_map_step.h +++ b/eval/eval/create_map_step.h @@ -40,6 +40,20 @@ absl::StatusOr> CreateCreateStructStepForMap( size_t entry_count, absl::flat_hash_set optional_indices, int64_t expr_id); +// Factory method for CreateMap which constructs a mutable map. +// +// This is intended for the map construction step is generated for a +// map-building comprehension (rather than a user authored expression). +absl::StatusOr> CreateMutableMapStep( + int64_t expr_id); + +// Factory method for CreateMap which constructs a mutable map. +// +// This is intended for the map construction step is generated for a +// map-building comprehension (rather than a user authored expression). +std::unique_ptr CreateDirectMutableMapStep( + int64_t expr_id); + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ diff --git a/extensions/BUILD b/extensions/BUILD index ae34b1194..4e256991f 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -370,3 +370,92 @@ cc_test( "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) + +cc_library( + name = "comprehensions_v2_functions", + srcs = ["comprehensions_v2_functions.cc"], + hdrs = ["comprehensions_v2_functions.h"], + deps = [ + "//common:value", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "comprehensions_v2_functions_test", + srcs = ["comprehensions_v2_functions_test.cc"], + deps = [ + ":bindings_ext", + ":comprehensions_v2_functions", + ":comprehensions_v2_macros", + ":strings", + "//common:memory", + "//common:source", + "//common:value", + "//common:value_testing", + "//extensions/protobuf:runtime_adapter", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:macro_registry", + "//parser:options", + "//parser:standard_macros", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:reference_resolver", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "comprehensions_v2_macros", + srcs = ["comprehensions_v2_macros.cc"], + hdrs = ["comprehensions_v2_macros.h"], + deps = [ + "//common:expr", + "//common:operators", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "//parser:options", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "comprehensions_v2_macros_test", + srcs = ["comprehensions_v2_macros_test.cc"], + deps = [ + ":comprehensions_v2_macros", + "//common:source", + "//internal:testing", + "//parser", + "//parser:macro_registry", + "//parser:options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) diff --git a/extensions/comprehensions_v2_functions.cc b/extensions/comprehensions_v2_functions.cc new file mode 100644 index 000000000..4202eef8d --- /dev/null +++ b/extensions/comprehensions_v2_functions.cc @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_functions.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "common/values/map_value_builder.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +namespace { + +absl::StatusOr MapInsert(ValueManager& value_manager, + const MapValue& map, const Value& key, + const Value& value) { + if (auto mutable_map_value = common_internal::AsMutableMapValue(map); + mutable_map_value) { + // Fast path, runtime has given us a mutable map. We can mutate it directly + // and return it. + CEL_RETURN_IF_ERROR(mutable_map_value->Put(key, value)) + .With(ErrorValueReturn()); + return map; + } + // Slow path, we have to make a copy. + auto builder = common_internal::NewMapValueBuilder( + value_manager.GetMemoryManager().arena()); + if (auto size = map.Size(); size.ok()) { + builder->Reserve(*size + 1); + } else { + size.IgnoreError(); + } + CEL_RETURN_IF_ERROR( + map.ForEach(value_manager, + [&builder](const Value& key, + const Value& value) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(builder->Put(key, value)); + return true; + })) + .With(ErrorValueReturn()); + CEL_RETURN_IF_ERROR(builder->Put(key, value)).With(ErrorValueReturn()); + return std::move(*builder).Build(); +} + +} // namespace + +absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry, + const RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(registry.Register( + VariadicFunctionAdapter, MapValue, Value, Value>:: + CreateDescriptor("cel.@mapInsert", /*receiver_style=*/false), + VariadicFunctionAdapter, MapValue, Value, + Value>::WrapFunction(&MapInsert))); + return absl::OkStatus(); +} + +absl::Status RegisterComprehensionsV2Functions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options) { + return RegisterComprehensionsV2Functions( + registry->InternalGetRegistry(), + google::api::expr::runtime::ConvertToRuntimeOptions(options)); +} + +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2_functions.h b/extensions/comprehensions_v2_functions.h new file mode 100644 index 000000000..8f99780a2 --- /dev/null +++ b/extensions/comprehensions_v2_functions.h @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +// Register comprehension v2 functions. +absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry, + const RuntimeOptions& options); +absl::Status RegisterComprehensionsV2Functions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ diff --git a/extensions/comprehensions_v2_functions_test.cc b/extensions/comprehensions_v2_functions_test.cc new file mode 100644 index 000000000..53e6761ab --- /dev/null +++ b/extensions/comprehensions_v2_functions_test.cc @@ -0,0 +1,240 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_functions.h" + +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/memory.h" +#include "common/source.h" +#include "common/type_reflector.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "common/values/legacy_value_manager.h" +#include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2_macros.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "extensions/strings.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/standard_macros.h" +#include "runtime/activation.h" +#include "runtime/optional_types.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::cel::test::BoolValueIs; +using ::google::api::expr::parser::EnrichedParse; +using ::testing::TestWithParam; + +struct ComprehensionsV2FunctionsTestCase { + std::string expression; +}; + +class ComprehensionsV2FunctionsTest + : public TestWithParam { + public: + void SetUp() override { + RuntimeOptions options; + options.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT(RegisterStringsFunctions(builder.function_registry(), options), + IsOk()); + ASSERT_THAT( + RegisterComprehensionsV2Functions(builder.function_registry(), options), + IsOk()); + ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build()); + } + + absl::StatusOr Parse(absl::string_view text) { + CEL_ASSIGN_OR_RETURN(auto source, NewSource(text)); + + ParserOptions options; + options.enable_optional_syntax = true; + + MacroRegistry registry; + CEL_RETURN_IF_ERROR(RegisterStandardMacros(registry, options)); + CEL_RETURN_IF_ERROR(RegisterComprehensionsV2Macros(registry, options)); + CEL_RETURN_IF_ERROR(RegisterBindingsMacros(registry, options)); + + CEL_ASSIGN_OR_RETURN(auto result, + EnrichedParse(*source, registry, options)); + return result.parsed_expr(); + } + + protected: + std::unique_ptr runtime_; +}; + +TEST_P(ComprehensionsV2FunctionsTest, Basic) { + ASSERT_OK_AND_ASSIGN(auto ast, Parse(GetParam().expression)); + ASSERT_OK_AND_ASSIGN(auto program, + ProtobufRuntimeAdapter::CreateProgram(*runtime_, ast)); + google::protobuf::Arena arena; + Activation activation; + common_internal::LegacyValueManager value_manager( + MemoryManager::Pooling(&arena), TypeReflector::Builtin()); + std::cout << absl::StrCat(ast) << std::endl; + EXPECT_THAT( + (program->Trace( + activation, + [](int64_t id, const Value& value, ValueManager&) -> absl::Status { + std::cout << id << " " << value.DebugString() << std::endl; + return absl::OkStatus(); + }, + value_manager)), + IsOkAndHolds(BoolValueIs(true))) + << GetParam().expression; +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionsV2FunctionsTest, ComprehensionsV2FunctionsTest, + ::testing::ValuesIn({ + // list.all() + {.expression = "[1, 2, 3, 4].all(i, v, i < 5 && v > 0)"}, + {.expression = "[1, 2, 3, 4].all(i, v, i < v)"}, + {.expression = "[1, 2, 3, 4].all(i, v, i > v) == false"}, + { + .expression = + R"cel(cel.bind(listA, [1, 2, 3, 4], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v, listB[?i].hasValue() && listB[i] == v))))cel", + }, + { + .expression = + R"cel(cel.bind(listA, [1, 2, 3, 4, 5, 6], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v, listB[?i].hasValue() && listB[i] == v))) == false)cel", + }, + // list.exists() + { + .expression = + R"cel(cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.exists(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false))))cel", + }, + // list.existsOne() + { + .expression = + R"cel(cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.existsOne(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false))))cel", + }, + { + .expression = + R"cel(cel.bind(l, ['hello', 'goodbye', 'hello!', 'goodbye'], l.existsOne(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next == "goodbye").orValue(false))) == false)cel", + }, + // list.transformList() + { + .expression = + R"cel(['Hello', 'world'].transformList(i, v, "[" + string(i) + "]" + v.lowerAscii()) == ["[0]hello", "[1]world"])cel", + }, + { + .expression = + R"cel(['hello', 'world'].transformList(i, v, v.startsWith('greeting'), "[" + string(i) + "]" + v) == [])cel", + }, + { + .expression = + R"cel([1, 2, 3].transformList(indexVar, valueVar, (indexVar * valueVar) + valueVar) == [1, 4, 9])cel", + }, + { + .expression = + R"cel([1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0, (indexVar * valueVar) + valueVar) == [1, 9])cel", + }, + // map.transformMap() + { + .expression = + R"cel(['Hello', 'world'].transformMap(i, v, [v.lowerAscii()]) == {0: ['hello'], 1: ['world']})cel", + }, + { + .expression = + R"cel([1, 2, 3].transformMap(indexVar, valueVar, (indexVar * valueVar) + valueVar) == {0: 1, 1: 4, 2: 9})cel", + }, + { + .expression = + R"cel([1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0, (indexVar * valueVar) + valueVar) == {0: 1, 2: 9})cel", + }, + // map.all() + { + .expression = + R"cel({'hello': 'world', 'hello!': 'world'}.all(k, v, k.startsWith('hello') && v == 'world'))cel", + }, + { + .expression = + R"cel({'hello': 'world', 'hello!': 'worlds'}.all(k, v, k.startsWith('hello') && v.endsWith('world')) == false)cel", + }, + // map.exists() + { + .expression = + R"cel({'hello': 'world', 'hello!': 'worlds'}.exists(k, v, k.startsWith('hello') && v.endsWith('world')))cel", + }, + // map.existsOne() + { + .expression = + R"cel({'hello': 'world', 'hello!': 'worlds'}.existsOne(k, v, k.startsWith('hello') && v.endsWith('world')))cel", + }, + { + .expression = + R"cel({'hello': 'world', 'hello!': 'wow, world'}.existsOne(k, v, k.startsWith('hello') && v.endsWith('world')) == false)cel", + }, + // map.transformList() + { + .expression = + R"cel({'Hello': 'world'}.transformList(k, v, k.lowerAscii() + "=" + v) == ["hello=world"])cel", + }, + { + .expression = + R"cel({'hello': 'world'}.transformList(k, v, k.startsWith('greeting'), k + "=" + v) == [])cel", + }, + { + .expression = + R"cel({'farewell': 'goodbye', 'greeting': 'hello'}.transformList(k, _, k) == ['farewell', 'greeting'])cel", + }, + { + .expression = + R"cel({'greeting': 'hello', 'farewell': 'goodbye'}.transformList(_, v, v) == ['goodbye', 'hello'])cel", + }, + // map.transformMap() + { + .expression = + R"cel({'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, k + ", " + v + "!") == {'hello': 'hello, world!', 'goodbye': 'goodbye, cruel world!'})cel", + }, + { + .expression = + R"cel({'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, v.startsWith('world'), k + ", " + v + "!") == {'hello': 'hello, world!'})cel", + }, + })); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2_macros.cc b/extensions/comprehensions_v2_macros.cc new file mode 100644 index 000000000..04793ad39 --- /dev/null +++ b/extensions/comprehensions_v2_macros.cc @@ -0,0 +1,488 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_macros.h" + +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "common/operators.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel::extensions { + +namespace { + +using ::google::api::expr::common::CelOperator; + +absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("all() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], "all() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], "all() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "all() second variable must be different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("all() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("all() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto init = factory.NewBoolConst(true); + auto condition = + factory.NewCall(CelOperator::NOT_STRICTLY_FALSE, factory.NewAccuIdent()); + auto step = factory.NewCall(CelOperator::LOGICAL_AND, factory.NewAccuIdent(), + std::move(args[2])); + auto result = factory.NewAccuIdent(); + return factory.NewComprehension( + args[0].ident_expr().name(), args[1].ident_expr().name(), + std::move(target), kAccumulatorVariableName, std::move(init), + std::move(condition), std::move(step), std::move(result)); +} + +Macro MakeAllMacro2() { + auto status_or_macro = Macro::Receiver(CelOperator::ALL, 3, ExpandAllMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("exists() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], "exists() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], "exists() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "exists() second variable must be different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("exists() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("exists() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto init = factory.NewBoolConst(false); + auto condition = factory.NewCall( + CelOperator::NOT_STRICTLY_FALSE, + factory.NewCall(CelOperator::LOGICAL_NOT, factory.NewAccuIdent())); + auto step = factory.NewCall(CelOperator::LOGICAL_OR, factory.NewAccuIdent(), + std::move(args[2])); + auto result = factory.NewAccuIdent(); + return factory.NewComprehension( + args[0].ident_expr().name(), args[1].ident_expr().name(), + std::move(target), kAccumulatorVariableName, std::move(init), + std::move(condition), std::move(step), std::move(result)); +} + +Macro MakeExistsMacro2() { + auto status_or_macro = + Macro::Receiver(CelOperator::EXISTS, 3, ExpandExistsMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("existsOne() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], "existsOne() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "existsOne() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "existsOne() second variable must be different " + "from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("existsOne() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("existsOne() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto init = factory.NewIntConst(0); + auto condition = factory.NewBoolConst(true); + auto step = + factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + factory.NewCall(CelOperator::ADD, factory.NewAccuIdent(), + factory.NewIntConst(1)), + factory.NewAccuIdent()); + auto result = factory.NewCall(CelOperator::EQUALS, factory.NewAccuIdent(), + factory.NewIntConst(1)); + return factory.NewComprehension( + args[0].ident_expr().name(), args[1].ident_expr().name(), + std::move(target), kAccumulatorVariableName, std::move(init), + std::move(condition), std::move(step), std::move(result)); +} + +Macro MakeExistsOneMacro2() { + auto status_or_macro = Macro::Receiver("existsOne", 3, ExpandExistsOneMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandFilterMacro2(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("filter() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], "filter() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], "filter() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "filter() second variable must be different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("filter() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("filter() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto name = args[0].ident_expr().name(); + auto name2 = args[1].ident_expr().name(); + auto init = factory.NewList(); + auto condition = factory.NewBoolConst(true); + auto step = factory.NewCall( + CelOperator::ADD, factory.NewAccuIdent(), + factory.NewList(factory.NewListElement(std::move(args[1])))); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension( + name, name2, std::move(target), kAccumulatorVariableName, std::move(init), + std::move(condition), std::move(step), factory.NewAccuIdent()); +} + +Macro MakeFilterMacro2() { + auto status_or_macro = + Macro::Receiver(CelOperator::FILTER, 3, ExpandFilterMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("transformList() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], + "transformList() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "transformList() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformList() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformList() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformList() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto iter_var = args[0].ident_expr().name(); + auto iter_var2 = args[1].ident_expr().name(); + auto step = factory.NewCall( + CelOperator::ADD, factory.NewAccuIdent(), + factory.NewList(factory.NewListElement(std::move(args[2])))); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), kAccumulatorVariableName, + factory.NewList(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformList3Macro() { + auto status_or_macro = + Macro::Receiver("transformList", 3, ExpandTransformList3Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 4) { + return factory.ReportError("transformList() requires 4 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], + "transformList() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "transformList() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformList() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformList() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformList() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto iter_var = args[0].ident_expr().name(); + auto iter_var2 = args[1].ident_expr().name(); + auto step = factory.NewCall( + CelOperator::ADD, factory.NewAccuIdent(), + factory.NewList(factory.NewListElement(std::move(args[3])))); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), kAccumulatorVariableName, + factory.NewList(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformList4Macro() { + auto status_or_macro = + Macro::Receiver("transformList", 4, ExpandTransformList4Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("transformMap() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], + "transformMap() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "transformMap() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformMap() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transforMap() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformMap() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto iter_var = args[0].ident_expr().name(); + auto iter_var2 = args[1].ident_expr().name(); + auto step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(), + std::move(args[0]), std::move(args[2])); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), kAccumulatorVariableName, + factory.NewMap(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformMap3Macro() { + auto status_or_macro = + Macro::Receiver("transformMap", 3, ExpandTransformMap3Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 4) { + return factory.ReportError("transformMap() requires 4 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], + "transformMap() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "transformMap() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformMap() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformMap() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformMap() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto iter_var = args[0].ident_expr().name(); + auto iter_var2 = args[1].ident_expr().name(); + auto step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(), + std::move(args[0]), std::move(args[3])); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), kAccumulatorVariableName, + factory.NewMap(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformMap4Macro() { + auto status_or_macro = + Macro::Receiver("transformMap", 4, ExpandTransformMap4Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +const Macro& AllMacro2() { + static const absl::NoDestructor macro(MakeAllMacro2()); + return *macro; +} + +const Macro& ExistsMacro2() { + static const absl::NoDestructor macro(MakeExistsMacro2()); + return *macro; +} + +const Macro& ExistsOneMacro2() { + static const absl::NoDestructor macro(MakeExistsOneMacro2()); + return *macro; +} + +const Macro& FilterMacro2() { + static const absl::NoDestructor macro(MakeFilterMacro2()); + return *macro; +} + +const Macro& TransformList3Macro() { + static const absl::NoDestructor macro(MakeTransformList3Macro()); + return *macro; +} + +const Macro& TransformList4Macro() { + static const absl::NoDestructor macro(MakeTransformList4Macro()); + return *macro; +} + +const Macro& TransformMap3Macro() { + static const absl::NoDestructor macro(MakeTransformMap3Macro()); + return *macro; +} + +const Macro& TransformMap4Macro() { + static const absl::NoDestructor macro(MakeTransformMap4Macro()); + return *macro; +} + +} // namespace + +// Registers the macros defined by the comprehension v2 extension. +absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry, + const ParserOptions&) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(AllMacro2())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(ExistsMacro2())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(ExistsOneMacro2())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(FilterMacro2())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(TransformList3Macro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(TransformList4Macro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(TransformMap3Macro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(TransformMap4Macro())); + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2_macros.h b/extensions/comprehensions_v2_macros.h new file mode 100644 index 000000000..3b2bfd577 --- /dev/null +++ b/extensions/comprehensions_v2_macros.h @@ -0,0 +1,30 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_ + +#include "absl/status/status.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel::extensions { + +// Registers the macros defined by the comprehension v2 extension. +absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry, + const ParserOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_ diff --git a/extensions/comprehensions_v2_macros_test.cc b/extensions/comprehensions_v2_macros_test.cc new file mode 100644 index 000000000..4fa07123f --- /dev/null +++ b/extensions/comprehensions_v2_macros_test.cc @@ -0,0 +1,230 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_macros.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/source.h" +#include "internal/testing.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::google::api::expr::parser::EnrichedParse; +using ::testing::HasSubstr; + +struct ComprehensionsV2MacrosTestCase { + std::string expression; + std::string error; +}; + +using ComprehensionsV2MacrosTest = + ::testing::TestWithParam; + +TEST_P(ComprehensionsV2MacrosTest, Basic) { + const auto& test_param = GetParam(); + ASSERT_OK_AND_ASSIGN(auto source, NewSource(test_param.expression)); + + MacroRegistry registry; + ASSERT_THAT(RegisterComprehensionsV2Macros(registry, ParserOptions()), + IsOk()); + + EXPECT_THAT(EnrichedParse(*source, registry), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_param.error))); +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionsV2MacrosTest, ComprehensionsV2MacrosTest, + ::testing::ValuesIn({ + { + .expression = "[].all(__result__, v, v == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].all(i, __result__, i == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].all(e, e, e == e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].all(foo.bar, e, true)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].all(e, foo.bar, true)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].exists(__result__, v, v == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].exists(i, __result__, i == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].exists(e, e, e == e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].exists(foo.bar, e, true)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].exists(e, foo.bar, true)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].existsOne(__result__, v, v == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].existsOne(i, __result__, i == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].existsOne(e, e, e == e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].existsOne(foo.bar, e, true)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].existsOne(e, foo.bar, true)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].filter(__result__, v, v == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].filter(i, __result__, i == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].filter(e, e, e == e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].filter(foo.bar, e, true)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].filter(e, foo.bar, true)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].transformList(__result__, v, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(i, __result__, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(e, e, e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].transformList(foo.bar, e, e)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].transformList(e, foo.bar, e)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].transformList(__result__, v, v == 0, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(i, __result__, i == 0, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(e, e, e == e, e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].transformList(foo.bar, e, true, e)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].transformList(e, foo.bar, true, e)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(__result__, v, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(k, __result__, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(e, e, e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "{}.transformMap(foo.bar, e, e)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(e, foo.bar, e)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(__result__, v, v == 0, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(k, __result__, k == 0, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(e, e, e == e, e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "{}.transformMap(foo.bar, e, true, e)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(e, foo.bar, true, e)", + .error = "second variable name must be a simple identifier", + }, + })); + +} // namespace +} // namespace cel::extensions diff --git a/parser/macro_expr_factory.h b/parser/macro_expr_factory.h index e84e8be7a..291bccdb0 100644 --- a/parser/macro_expr_factory.h +++ b/parser/macro_expr_factory.h @@ -255,6 +255,27 @@ class MacroExprFactory : protected ExprFactory { std::move(loop_step), std::move(result)); } + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_var2), + std::move(iter_range), std::move(accu_var), + std::move(accu_init), std::move(loop_condition), + std::move(loop_step), std::move(result)); + } + ABSL_MUST_USE_RESULT virtual Expr ReportError(absl::string_view message) = 0; ABSL_MUST_USE_RESULT virtual Expr ReportErrorAt(