Skip to content

Commit

Permalink
Initial minimal implementation of Comprehensions V2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 692194361
  • Loading branch information
jcking authored and copybara-github committed Nov 5, 2024
1 parent 9310c49 commit 6177277
Show file tree
Hide file tree
Showing 19 changed files with 1,931 additions and 44 deletions.
29 changes: 29 additions & 0 deletions common/expr_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,35 @@ class ExprFactory {
return expr;
}

template <typename IterVar, typename IterVar2, typename IterRange,
typename AccuVar, typename AccuInit, typename LoopCondition,
typename LoopStep, typename Result,
typename = std::enable_if_t<IsStringLike<IterVar>::value>,
typename = std::enable_if_t<IsStringLike<IterVar2>::value>,
typename = std::enable_if_t<IsExprLike<IterRange>::value>,
typename = std::enable_if_t<IsStringLike<AccuVar>::value>,
typename = std::enable_if_t<IsExprLike<AccuInit>::value>,
typename = std::enable_if_t<IsExprLike<LoopStep>::value>,
typename = std::enable_if_t<IsExprLike<LoopCondition>::value>,
typename = std::enable_if_t<IsExprLike<Result>::value>>
Expr NewComprehension(ExprId id, IterVar iter_var, IterVar2 iter_var2,
IterRange iter_range, AccuVar accu_var,
AccuInit accu_init, LoopCondition loop_condition,
LoopStep loop_step, Result result) {
Expr expr;
expr.set_id(id);
auto& comprehension_expr = expr.mutable_comprehension_expr();
comprehension_expr.set_iter_var(std::move(iter_var));
comprehension_expr.set_iter_var2(std::move(iter_var2));
comprehension_expr.set_iter_range(std::move(iter_range));
comprehension_expr.set_accu_var(std::move(accu_var));
comprehension_expr.set_accu_init(std::move(accu_init));
comprehension_expr.set_loop_condition(std::move(loop_condition));
comprehension_expr.set_loop_step(std::move(loop_step));
comprehension_expr.set_result(std::move(result));
return expr;
}

private:
friend class MacroExprFactory;
friend class ParserMacroExprFactory;
Expand Down
9 changes: 9 additions & 0 deletions common/values/error_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ bool IsNoSuchField(const ErrorValue& value);

bool IsNoSuchKey(const ErrorValue& value);

class ErrorValueReturn final {
public:
ErrorValueReturn() = default;

ErrorValue operator()(absl::Status status) const {
return ErrorValue(std::move(status));
}
};

} // namespace cel

#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_ERROR_VALUE_H_
2 changes: 2 additions & 0 deletions conformance/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ cc_library(
"//eval/public:cel_value",
"//eval/public:transform_utility",
"//extensions:bindings_ext",
"//extensions:comprehensions_v2_functions",
"//extensions:comprehensions_v2_macros",
"//extensions:encoders",
"//extensions:math_ext",
"//extensions:math_ext_macros",
Expand Down
8 changes: 8 additions & 0 deletions conformance/service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
#include "eval/public/cel_value.h"
#include "eval/public/transform_utility.h"
#include "extensions/bindings_ext.h"
#include "extensions/comprehensions_v2_functions.h"
#include "extensions/comprehensions_v2_macros.h"
#include "extensions/encoders.h"
#include "extensions/math_ext.h"
#include "extensions/math_ext_macros.h"
Expand Down Expand Up @@ -255,6 +257,8 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request,
options.enable_optional_syntax = enable_optional_syntax;
cel::MacroRegistry macros;
CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options));
CEL_RETURN_IF_ERROR(
cel::extensions::RegisterComprehensionsV2Macros(macros, options));
CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options));
CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options));
CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options));
Expand Down Expand Up @@ -333,6 +337,8 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor());
CEL_RETURN_IF_ERROR(
RegisterBuiltinFunctions(builder->GetRegistry(), options));
CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions(
builder->GetRegistry(), options));
CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions(
builder->GetRegistry(), options));
CEL_RETURN_IF_ERROR(cel::extensions::RegisterStringsFunctions(
Expand Down Expand Up @@ -503,6 +509,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
type_registry,
cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor()));

CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions(
builder.function_registry(), options));
CEL_RETURN_IF_ERROR(cel::extensions::EnableOptionalTypes(builder));
CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions(
builder.function_registry(), options));
Expand Down
127 changes: 108 additions & 19 deletions eval/compiler/flat_expr_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,12 +323,45 @@ const cel::ast_internal::Expr* GetOptimizableListAppendOperand(
return &GetOptimizableListAppendCall(comprehension)->args()[1];
}

bool IsOptimizableMapInsert(
const cel::ast_internal::Comprehension* comprehension) {
if (comprehension->iter_var().empty() || comprehension->iter_var2().empty()) {
return false;
}
absl::string_view accu_var = comprehension->accu_var();
if (accu_var.empty() || !comprehension->has_result() ||
!comprehension->result().has_ident_expr() ||
comprehension->result().ident_expr().name() != accu_var) {
return false;
}
if (!comprehension->accu_init().has_map_expr()) {
return false;
}
if (!comprehension->loop_step().has_call_expr()) {
return false;
}
const auto* call_expr = &comprehension->loop_step().call_expr();

if (call_expr->function() == cel::builtin::kTernary &&
call_expr->args().size() == 3) {
if (!call_expr->args()[1].has_call_expr()) {
return false;
}
call_expr = &(call_expr->args()[1].call_expr());
}
return call_expr->function() == "cel.@mapInsert" &&
call_expr->args().size() == 3 &&
call_expr->args()[0].has_ident_expr() &&
call_expr->args()[0].ident_expr().name() == accu_var;
}

bool IsBind(const cel::ast_internal::Comprehension* comprehension) {
static constexpr absl::string_view kUnusedIterVar = "#unused";

return comprehension->loop_condition().const_expr().has_bool_value() &&
comprehension->loop_condition().const_expr().bool_value() == false &&
comprehension->iter_var() == kUnusedIterVar &&
comprehension->iter_var2().empty() &&
comprehension->iter_range().has_list_expr() &&
comprehension->iter_range().list_expr().elements().empty();
}
Expand All @@ -342,14 +375,15 @@ class ComprehensionVisitor {
public:
explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting,
bool is_trivial, size_t iter_slot,
size_t accu_slot)
size_t iter2_slot, size_t accu_slot)
: visitor_(visitor),
next_step_(nullptr),
cond_step_(nullptr),
short_circuiting_(short_circuiting),
is_trivial_(is_trivial),
accu_init_extracted_(false),
iter_slot_(iter_slot),
iter2_slot_(iter2_slot),
accu_slot_(accu_slot) {}

void PreVisit(const cel::ast_internal::Expr* expr);
Expand Down Expand Up @@ -383,6 +417,7 @@ class ComprehensionVisitor {
bool is_trivial_;
bool accu_init_extracted_;
size_t iter_slot_;
size_t iter2_slot_;
size_t accu_slot_;
};

Expand Down Expand Up @@ -599,6 +634,10 @@ class FlatExprVisitor : public cel::AstVisitor {
}
return {static_cast<int>(record.iter_slot), -1};
}
if (record.iter_var2_in_scope &&
record.comprehension->iter_var2() == path) {
return {static_cast<int>(record.iter2_slot), -1};
}
if (record.accu_var_in_scope &&
record.comprehension->accu_var() == path) {
int slot = record.accu_slot;
Expand Down Expand Up @@ -1083,7 +1122,7 @@ class FlatExprVisitor : public cel::AstVisitor {
void MaybeMakeComprehensionRecursive(
const cel::ast_internal::Expr* expr,
const cel::ast_internal::Comprehension* comprehension, size_t iter_slot,
size_t accu_slot) {
size_t iter2_slot, size_t accu_slot) {
if (options_.max_recursion_depth == 0) {
return;
}
Expand Down Expand Up @@ -1136,7 +1175,8 @@ class FlatExprVisitor : public cel::AstVisitor {
}

auto step = CreateDirectComprehensionStep(
iter_slot, accu_slot, range_plan->ExtractRecursiveProgram().step,
iter_slot, iter2_slot, accu_slot,
range_plan->ExtractRecursiveProgram().step,
accu_plan->ExtractRecursiveProgram().step,
loop_plan->ExtractRecursiveProgram().step,
condition_plan->ExtractRecursiveProgram().step,
Expand Down Expand Up @@ -1247,13 +1287,20 @@ class FlatExprVisitor : public cel::AstVisitor {
}
const auto& accu_var = comprehension.accu_var();
const auto& iter_var = comprehension.iter_var();
const auto& iter_var2 = comprehension.iter_var2();
ValidateOrError(!accu_var.empty(),
"Invalid comprehension: 'accu_var' must not be empty");
ValidateOrError(!iter_var.empty(),
"Invalid comprehension: 'iter_var' must not be empty");
ValidateOrError(
accu_var != iter_var,
"Invalid comprehension: 'accu_var' must not be the same as 'iter_var'");
ValidateOrError(accu_var != iter_var2,
"Invalid comprehension: 'accu_var' must not be the same as "
"'iter_var2'");
ValidateOrError(iter_var2 != iter_var,
"Invalid comprehension: 'iter_var2' must not be the same "
"as 'iter_var'");
ValidateOrError(comprehension.has_accu_init(),
"Invalid comprehension: 'accu_init' must be set");
ValidateOrError(comprehension.has_loop_condition(),
Expand All @@ -1263,15 +1310,20 @@ class FlatExprVisitor : public cel::AstVisitor {
ValidateOrError(comprehension.has_result(),
"Invalid comprehension: 'result' must be set");

size_t iter_slot, accu_slot, slot_count;
size_t iter_slot, iter2_slot, accu_slot, slot_count;
bool is_bind = IsBind(&comprehension);
if (is_bind) {
accu_slot = iter_slot = index_manager_.ReserveSlots(1);
accu_slot = iter_slot = iter2_slot = index_manager_.ReserveSlots(1);
slot_count = 1;
} else {
iter_slot = index_manager_.ReserveSlots(2);
} else if (comprehension.iter_var2().empty()) {
iter_slot = iter2_slot = index_manager_.ReserveSlots(2);
accu_slot = iter_slot + 1;
slot_count = 2;
} else {
iter_slot = index_manager_.ReserveSlots(3);
iter2_slot = iter_slot + 1;
accu_slot = iter2_slot + 1;
slot_count = 3;
}
// If this is in the scope of an optimized bind accu-init, account the slots
// to the outermost bind-init scope.
Expand All @@ -1289,16 +1341,20 @@ class FlatExprVisitor : public cel::AstVisitor {
}

comprehension_stack_.push_back(
{&expr, &comprehension, iter_slot, accu_slot, slot_count,
{&expr, &comprehension, iter_slot, iter2_slot, accu_slot, slot_count,
/*subexpression=*/-1,
/*.is_optimizable_list_append=*/
IsOptimizableListAppend(&comprehension,
options_.enable_comprehension_list_append),
is_bind,
/*.is_optimizable_map_insert=*/IsOptimizableMapInsert(&comprehension),
/*.is_optimizable_bind=*/is_bind,
/*.iter_var_in_scope=*/false,
/*.iter_var2_in_scope=*/false,
/*.accu_var_in_scope=*/false,
/*.in_accu_init=*/false,
std::make_unique<ComprehensionVisitor>(
this, options_.short_circuiting, is_bind, iter_slot, accu_slot)});
std::make_unique<ComprehensionVisitor>(this, options_.short_circuiting,
is_bind, iter_slot, iter2_slot,
accu_slot)});
comprehension_stack_.back().visitor->PreVisit(&expr);
}

Expand Down Expand Up @@ -1341,30 +1397,35 @@ class FlatExprVisitor : public cel::AstVisitor {
case cel::ITER_RANGE: {
record.in_accu_init = false;
record.iter_var_in_scope = false;
record.iter_var2_in_scope = false;
record.accu_var_in_scope = false;
break;
}
case cel::ACCU_INIT: {
record.in_accu_init = true;
record.iter_var_in_scope = false;
record.iter_var2_in_scope = false;
record.accu_var_in_scope = false;
break;
}
case cel::LOOP_CONDITION: {
record.in_accu_init = false;
record.iter_var_in_scope = true;
record.iter_var2_in_scope = true;
record.accu_var_in_scope = true;
break;
}
case cel::LOOP_STEP: {
record.in_accu_init = false;
record.iter_var_in_scope = true;
record.iter_var2_in_scope = true;
record.accu_var_in_scope = true;
break;
}
case cel::RESULT: {
record.in_accu_init = false;
record.iter_var_in_scope = false;
record.iter_var2_in_scope = false;
record.accu_var_in_scope = true;
break;
}
Expand Down Expand Up @@ -1468,6 +1529,21 @@ class FlatExprVisitor : public cel::AstVisitor {
return;
}

if (!comprehension_stack_.empty()) {
const ComprehensionStackRecord& comprehension =
comprehension_stack_.back();
if (comprehension.is_optimizable_map_insert) {
if (&(comprehension.comprehension->accu_init()) == &expr) {
if (options_.max_recursion_depth != 0) {
SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1);
return;
}
AddStep(CreateMutableMapStep(expr.id()));
return;
}
}
}

auto status_or_resolved_fields =
ResolveCreateStructFields(struct_expr, expr.id());
if (!status_or_resolved_fields.ok()) {
Expand Down Expand Up @@ -1672,13 +1748,16 @@ class FlatExprVisitor : public cel::AstVisitor {
const cel::ast_internal::Expr* expr;
const cel::ast_internal::Comprehension* comprehension;
size_t iter_slot;
size_t iter2_slot;
size_t accu_slot;
size_t slot_count;
// -1 indicates this shouldn't be used.
int subexpression;
bool is_optimizable_list_append;
bool is_optimizable_map_insert;
bool is_optimizable_bind;
bool iter_var_in_scope;
bool iter_var2_in_scope;
bool accu_var_in_scope;
bool in_accu_init;
std::unique_ptr<ComprehensionVisitor> visitor;
Expand Down Expand Up @@ -2011,20 +2090,24 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault(
case cel::ITER_RANGE: {
// post process iter_range to list its keys if it's a map
// and initialize the loop index.
visitor_->AddStep(CreateComprehensionInitStep(expr->id()));
if (iter_slot_ == iter2_slot_) {
visitor_->AddStep(CreateComprehensionInitStep(expr->id()));
} else {
visitor_->AddStep(CreateComprehensionInitStep2(expr->id()));
}
break;
}
case cel::ACCU_INIT: {
next_step_pos_ = visitor_->GetCurrentIndex();
next_step_ =
new ComprehensionNextStep(iter_slot_, accu_slot_, expr->id());
next_step_ = new ComprehensionNextStep(iter_slot_, iter2_slot_,
accu_slot_, expr->id());
visitor_->AddStep(std::unique_ptr<ExpressionStep>(next_step_));
break;
}
case cel::LOOP_CONDITION: {
cond_step_pos_ = visitor_->GetCurrentIndex();
cond_step_ = new ComprehensionCondStep(iter_slot_, accu_slot_,
short_circuiting_, expr->id());
cond_step_ = new ComprehensionCondStep(
iter_slot_, iter2_slot_, accu_slot_, short_circuiting_, expr->id());
visitor_->AddStep(std::unique_ptr<ExpressionStep>(cond_step_));
break;
}
Expand All @@ -2049,7 +2132,13 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault(
break;
}
case cel::RESULT: {
visitor_->AddStep(CreateComprehensionFinishStep(accu_slot_, expr->id()));
if (iter_slot_ == iter2_slot_) {
visitor_->AddStep(
CreateComprehensionFinishStep(accu_slot_, expr->id()));
} else {
visitor_->AddStep(
CreateComprehensionFinishStep2(accu_slot_, expr->id()));
}

CEL_ASSIGN_OR_RETURN(
int jump_from_next,
Expand Down Expand Up @@ -2097,8 +2186,8 @@ void ComprehensionVisitor::PostVisit(const cel::ast_internal::Expr* expr) {
accu_slot_);
return;
}
visitor_->MaybeMakeComprehensionRecursive(expr, &expr->comprehension_expr(),
iter_slot_, accu_slot_);
visitor_->MaybeMakeComprehensionRecursive(
expr, &expr->comprehension_expr(), iter_slot_, iter2_slot_, accu_slot_);
}

// Flattens the expression table into the end of the mainline expression vector
Expand Down
Loading

0 comments on commit 6177277

Please sign in to comment.