diff --git a/bazel/deps.bzl b/bazel/deps.bzl index 926399a92..58168ab19 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -141,10 +141,10 @@ def cel_spec_deps(): url = "https://github.com/bazelbuild/rules_python/releases/download/0.33.2/rules_python-0.33.2.tar.gz", ) - CEL_SPEC_GIT_SHA = "5299974f1c69103e4bb4eec48f7d9b24413ca3c7" # Jul 19, 2024 + CEL_SPEC_GIT_SHA = "5ed294fa64206016a37db2986dab942c80a65e4b" # Aug 16, 2024 http_archive( name = "com_google_cel_spec", - sha256 = "2beb97d2d8fff4db659f0633d0e432fdb7d328fe9c39061eb142af5dbb6eaea0", + sha256 = "926abf84cde8c05ce99700caee5786bc7d8aeec77185fd669bce27df455a1215", strip_prefix = "cel-spec-" + CEL_SPEC_GIT_SHA, urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_GIT_SHA + ".zip"], ) diff --git a/conformance/BUILD b/conformance/BUILD index 6f8d2dd1a..d1683997b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -45,6 +45,7 @@ cc_library( hdrs = ["service.h"], deps = [ ":value_conversion", + "//common:expr", "//common:memory", "//common:source", "//common:value", @@ -67,6 +68,8 @@ cc_library( "//extensions/protobuf:value", "//internal:status_macros", "//parser", + "//parser:macro", + "//parser:macro_expr_factory", "//parser:macro_registry", "//parser:options", "//parser:standard_macros", @@ -82,6 +85,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", @@ -141,6 +146,7 @@ _ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/timestamps.textproto", "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", "@com_google_cel_spec//tests/simple:testdata/wrappers.textproto", + "@com_google_cel_spec//tests/simple:testdata/block_ext.textproto", ] _TESTS_TO_SKIP_MODERN = [ @@ -209,6 +215,13 @@ _TESTS_TO_SKIP_MODERN = [ # TODO: Add missing conversion function "conversions/bool", + + # cel.@block + "block_ext/basic/multiple_macros_1", + "block_ext/basic/multiple_macros_2", + "block_ext/basic/multiple_macros_3", + "block_ext/basic/nested_macros", + "block_ext/basic/nested_macros_2", ] _TESTS_TO_SKIP_MODERN_DASHBOARD = [ @@ -276,6 +289,17 @@ _TESTS_TO_SKIP_LEGACY = [ # TODO: Add missing conversion function "conversions/bool", + + # cel.@block + "block_ext/basic/optional_list", + "block_ext/basic/optional_map", + "block_ext/basic/optional_map_chained", + "block_ext/basic/optional_message", + "block_ext/basic/multiple_macros_1", + "block_ext/basic/multiple_macros_2", + "block_ext/basic/multiple_macros_3", + "block_ext/basic/nested_macros", + "block_ext/basic/nested_macros_2", ] _TESTS_TO_SKIP_LEGACY_DASHBOARD = [ diff --git a/conformance/service.cc b/conformance/service.cc index 0b994d9f6..3c9e5aa63 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -14,6 +14,7 @@ #include "conformance/service.h" +#include #include #include #include @@ -33,7 +34,11 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" #include "common/memory.h" #include "common/source.h" #include "common/value.h" @@ -56,6 +61,8 @@ #include "extensions/protobuf/type_reflector.h" #include "extensions/strings.h" #include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "parser/parser.h" @@ -90,6 +97,103 @@ namespace google::api::expr::runtime { namespace { +bool IsCelNamespace(const cel::Expr& target) { + return target.has_ident_expr() && target.ident_expr().name() == "cel"; +} + +absl::optional CelBlockMacroExpander(cel::MacroExprFactory& factory, + cel::Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + cel::Expr& bindings_arg = args[0]; + if (!bindings_arg.has_list_expr()) { + return factory.ReportErrorAt( + bindings_arg, "cel.block requires the first arg to be a list literal"); + } + return factory.NewCall("cel.@block", args); +} + +absl::optional CelIndexMacroExpander(cel::MacroExprFactory& factory, + cel::Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + cel::Expr& index_arg = args[0]; + if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + int64_t index = index_arg.const_expr().int_value(); + if (index < 0) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + return factory.NewIdent(absl::StrCat("@index", index)); +} + +absl::optional CelIterVarMacroExpander( + cel::MacroExprFactory& factory, cel::Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + cel::Expr& index_arg = args[0]; + if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { + return factory.ReportErrorAt( + index_arg, + "cel.iterVar requires a single non-negative int constant arg"); + } + int64_t index = index_arg.const_expr().int_value(); + if (index < 0) { + return factory.ReportErrorAt( + index_arg, + "cel.iterVar requires a single non-negative int constant arg"); + } + return factory.NewIdent(absl::StrCat("@c:", index)); +} + +absl::optional CelAccuVarMacroExpander( + cel::MacroExprFactory& factory, cel::Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + cel::Expr& index_arg = args[0]; + if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { + return factory.ReportErrorAt( + index_arg, + "cel.accuVar requires a single non-negative int constant arg"); + } + int64_t index = index_arg.const_expr().int_value(); + if (index < 0) { + return factory.ReportErrorAt( + index_arg, + "cel.accuVar requires a single non-negative int constant arg"); + } + return factory.NewIdent(absl::StrCat("@x:", index)); +} + +absl::Status RegisterCelBlockMacros(cel::MacroRegistry& registry) { + CEL_ASSIGN_OR_RETURN(auto block_macro, + cel::Macro::Receiver("block", 2, CelBlockMacroExpander)); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(block_macro)); + CEL_ASSIGN_OR_RETURN(auto index_macro, + cel::Macro::Receiver("index", 1, CelIndexMacroExpander)); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(index_macro)); + CEL_ASSIGN_OR_RETURN( + auto iter_var_macro, + cel::Macro::Receiver("iterVar", 1, CelIterVarMacroExpander)); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(iter_var_macro)); + CEL_ASSIGN_OR_RETURN( + auto accu_var_macro, + cel::Macro::Receiver("accuVar", 1, CelAccuVarMacroExpander)); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(accu_var_macro)); + return absl::OkStatus(); +} + google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); } @@ -126,6 +230,7 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options)); + CEL_RETURN_IF_ERROR(RegisterCelBlockMacros(macros)); CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(request.cel_source(), request.source_location())); CEL_ASSIGN_OR_RETURN(auto parsed_expr, diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 74d03c57a..15d8c7f8e 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -129,6 +129,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 6bd17ce74..e7ad07800 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -38,8 +38,10 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/types/variant.h" @@ -301,6 +303,10 @@ bool IsBind(const cel::ast_internal::Comprehension* comprehension) { comprehension->iter_range().list_expr().elements().empty(); } +bool IsBlock(const cel::ast_internal::Call* call) { + return call->function() == "cel.@block"; +} + // Visitor for Comprehension expressions. class ComprehensionVisitor { public: @@ -397,7 +403,6 @@ class FlatExprVisitor : public cel::AstVisitor { value_factory_(value_factory), progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), - parent_expr_(nullptr), options_(options), program_optimizers_(std::move(program_optimizers)), issue_collector_(issue_collector), @@ -416,9 +421,14 @@ class FlatExprVisitor : public cel::AstVisitor { resume_from_suppressed_branch_ = &expr; } - program_builder_.EnterSubexpression(&expr); + if (block_.has_value()) { + BlockInfo& block = *block_; + if (block.in && block.bindings_set.contains(&expr)) { + block.current_binding = &expr; + } + } - parent_expr_ = &expr; + program_builder_.EnterSubexpression(&expr); for (const std::unique_ptr& optimizer : program_optimizers_) { @@ -462,6 +472,20 @@ class FlatExprVisitor : public cel::AstVisitor { SetProgressStatusError( MaybeExtractSubexpression(&expr, comprehension_stack_.back())); } + + if (block_.has_value()) { + BlockInfo& block = *block_; + if (block.current_binding == &expr) { + int index = program_builder_.ExtractSubexpression(&expr); + if (index == -1) { + SetProgressStatusError( + absl::InvalidArgumentError("failed to extract subexpression")); + return; + } + block.subexpressions[block.current_index++] = index; + block.current_binding = nullptr; + } + } } void PostVisitConst(const cel::ast_internal::Expr& expr, @@ -499,6 +523,47 @@ class FlatExprVisitor : public cel::AstVisitor { // If lazy evaluation enabled and ided as a lazy expression, // subexpression and slot will be set. SlotLookupResult LookupSlot(absl::string_view path) { + if (block_.has_value()) { + const BlockInfo& block = *block_; + if (block.in) { + absl::string_view index_suffix = path; + if (absl::ConsumePrefix(&index_suffix, "@index")) { + size_t index; + if (!absl::SimpleAtoi(index_suffix, &index)) { + SetProgressStatusError( + issue_collector_.AddIssue(RuntimeIssue::CreateWarning( + absl::InvalidArgumentError("bad @index")))); + return {-1, -1}; + } + if (index >= block.size) { + SetProgressStatusError( + issue_collector_.AddIssue(RuntimeIssue::CreateWarning( + absl::InvalidArgumentError(absl::StrCat( + "invalid @index greater than number of bindings: ", + index, " >= ", block.size))))); + return {-1, -1}; + } + if (index >= block.current_index) { + SetProgressStatusError( + issue_collector_.AddIssue(RuntimeIssue::CreateWarning( + absl::InvalidArgumentError(absl::StrCat( + "@index references current or future binding: ", index, + " >= ", block.current_index))))); + return {-1, -1}; + } + return {static_cast(block.index + index), + block.subexpressions[index]}; + } + if (absl::ConsumePrefix(&index_suffix, "@c:") || + absl::ConsumePrefix(&index_suffix, "@x:")) { + SetProgressStatusError(issue_collector_.AddIssue( + RuntimeIssue::CreateWarning(absl::InvalidArgumentError( + "support is not yet implemented for CSE generated @c: or @x: " + "comprehension variables")))); + return {-1, -1}; + } + } + } if (!comprehension_stack_.empty()) { for (int i = comprehension_stack_.size() - 1; i >= 0; i--) { const ComprehensionStackRecord& record = comprehension_stack_[i]; @@ -740,6 +805,55 @@ class FlatExprVisitor : public cel::AstVisitor { call_expr.has_target() && call_expr.args().size() == 1) { cond_visitor = std::make_unique( this, BinaryCond::kOptionalOrValue, options_.short_circuiting); + } else if (IsBlock(&call_expr)) { + // cel.@block + if (block_.has_value()) { + // There can only be one for now. + SetProgressStatusError( + absl::InvalidArgumentError("multiple cel.@block are not allowed")); + return; + } + block_ = BlockInfo(); + BlockInfo& block = *block_; + block.in = true; + if (call_expr.args().empty()) { + SetProgressStatusError(absl::InvalidArgumentError( + "malformed cel.@block: missing list of bound expressions")); + return; + } + if (call_expr.args().size() != 2) { + SetProgressStatusError(absl::InvalidArgumentError( + "malformed cel.@block: missing bound expression")); + return; + } + if (!call_expr.args()[0].has_list_expr()) { + SetProgressStatusError( + absl::InvalidArgumentError("malformed cel.@block: first argument " + "is not a list of bound expressions")); + return; + } + const auto& list_expr = call_expr.args().front().list_expr(); + block.size = list_expr.elements().size(); + if (block.size == 0) { + SetProgressStatusError(absl::InvalidArgumentError( + "malformed cel.@block: list of bound expressions is empty")); + return; + } + block.bindings_set.reserve(block.size); + for (const auto& list_expr_element : list_expr.elements()) { + if (list_expr_element.optional()) { + SetProgressStatusError( + absl::InvalidArgumentError("malformed cel.@block: list of bound " + "expressions contains an optional")); + return; + } + block.bindings_set.insert(&list_expr_element.expr()); + } + block.index = index_manager().ReserveSlots(block.size); + block.expr = &expr; + block.bindings = &call_expr.args()[0]; + block.bound = &call_expr.args()[1]; + block.subexpressions.resize(block.size, -1); } else { return; } @@ -1049,6 +1163,16 @@ class FlatExprVisitor : public cel::AstVisitor { return; } + if (block_.has_value()) { + BlockInfo& block = *block_; + if (block.expr == &expr) { + block.in = false; + index_manager().ReleaseSlots(block.size); + AddStep(CreateClearSlotsStep(block.index, block.size, -1)); + return; + } + } + // Establish the search criteria for a given function. absl::string_view function = call_expr.function(); @@ -1260,6 +1384,15 @@ class FlatExprVisitor : public cel::AstVisitor { if (!progress_status_.ok()) { return; } + + if (block_.has_value()) { + BlockInfo& block = *block_; + if (block.bindings == &expr) { + // Do nothing, this is the cel.@block bindings list. + return; + } + } + if (!comprehension_stack_.empty()) { const ComprehensionStackRecord& comprehension = comprehension_stack_.back(); @@ -1514,6 +1647,36 @@ class FlatExprVisitor : public cel::AstVisitor { std::unique_ptr visitor; }; + struct BlockInfo { + // True if we are currently visiting the `cel.@block` node or any of its + // children. + bool in = false; + // Pointer to the `cel.@block` node. + const cel::ast_internal::Expr* expr = nullptr; + // Pointer to the `cel.@block` bindings, that is the first argument to the + // function. + const cel::ast_internal::Expr* bindings = nullptr; + // Set of pointers to the elements of `bindings` above. + absl::flat_hash_set bindings_set; + // Pointer to the `cel.@block` bound expression, that is the second argument + // to the function. + const cel::ast_internal::Expr* bound = nullptr; + // The number of entries in the `cel.@block`. + size_t size = 0; + // Starting slot index for `cel.@block`. We occupy he slot indices `index` + // through `index + size + (var_size * 2)`. + size_t index = 0; + // The current slot index we are processing, any index references must be + // less than this to be valid. + size_t current_index = 0; + // Pointer to the current `cel.@block` being processed, that is one of the + // elements within the first argument. + const cel::ast_internal::Expr* current_binding = nullptr; + // Mapping between block indices and their subexpressions, fixed size with + // exactly `size` elements. Unprocessed indices are set to `-1`. + std::vector subexpressions; + }; + bool PlanningSuppressed() const { return resume_from_suppressed_branch_ != nullptr; } @@ -1593,10 +1756,6 @@ class FlatExprVisitor : public cel::AstVisitor { // field is used as marker suppressing CelExpression creation for SELECTs. const cel::ast_internal::Expr* resolved_select_expr_; - // Used for assembling a temporary tree mapping program segments - // to source expr nodes. - const cel::ast_internal::Expr* parent_expr_; - const cel::RuntimeOptions& options_; std::vector comprehension_stack_; @@ -1610,6 +1769,7 @@ class FlatExprVisitor : public cel::AstVisitor { IndexManager index_manager_; bool enable_optional_types_; + absl::optional block_; }; void BinaryCondVisitor::PreVisit(const cel::ast_internal::Expr* expr) { diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 24753829c..cc56ccfe7 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -58,7 +58,8 @@ bool IsSpecialFunction(absl::string_view function_name) { function_name == cel::builtin::kOr || function_name == cel::builtin::kIndex || function_name == cel::builtin::kTernary || - function_name == kOptionalOr || function_name == kOptionalOrValue; + function_name == kOptionalOr || function_name == kOptionalOrValue || + function_name == "cel.@block"; } bool OverloadExists(const Resolver& resolver, absl::string_view name,