Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: add support for cloning cel::Expr. #1237

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/functional:overload",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
Expand Down
126 changes: 126 additions & 0 deletions common/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,124 @@

#include "common/expr.h"

#include <vector>

#include "absl/base/no_destructor.h"
#include "absl/functional/overload.h"
#include "absl/types/variant.h"
#include "common/constant.h"

namespace cel {

namespace {

struct CopyStackRecord {
const Expr* src;
Expr* dst;
};

void CopyNode(CopyStackRecord element, std::vector<CopyStackRecord>& stack) {
const Expr* src = element.src;
Expr* dst = element.dst;
dst->set_id(src->id());
absl::visit(
absl::Overload(
[](const UnspecifiedExpr&) {},
[=](const IdentExpr& i) {
dst->mutable_ident_expr().set_name(i.name());
},
[=](const Constant& c) { dst->mutable_const_expr() = c; },
[&](const SelectExpr& s) {
dst->mutable_select_expr().set_field(s.field());
dst->mutable_select_expr().set_test_only(s.test_only());

if (s.has_operand()) {
stack.push_back({&s.operand(),
&dst->mutable_select_expr().mutable_operand()});
}
},
[&](const CallExpr& c) {
dst->mutable_call_expr().set_function(c.function());
if (c.has_target()) {
stack.push_back(
{&c.target(), &dst->mutable_call_expr().mutable_target()});
}
dst->mutable_call_expr().mutable_args().resize(c.args().size());
for (int i = 0; i < dst->mutable_call_expr().mutable_args().size();
++i) {
stack.push_back(
{&c.args()[i], &dst->mutable_call_expr().mutable_args()[i]});
}
},
[&](const ListExpr& c) {
auto& dst_list = dst->mutable_list_expr();
dst_list.mutable_elements().resize(c.elements().size());
for (int i = 0; i < src->list_expr().elements().size(); ++i) {
dst_list.mutable_elements()[i].set_optional(
c.elements()[i].optional());
stack.push_back({&c.elements()[i].expr(),
&dst_list.mutable_elements()[i].mutable_expr()});
}
},
[&](const StructExpr& s) {
auto& dst_struct = dst->mutable_struct_expr();
dst_struct.mutable_fields().resize(s.fields().size());
dst_struct.set_name(s.name());
for (int i = 0; i < s.fields().size(); ++i) {
dst_struct.mutable_fields()[i].set_optional(
s.fields()[i].optional());
dst_struct.mutable_fields()[i].set_name(s.fields()[i].name());
dst_struct.mutable_fields()[i].set_id(s.fields()[i].id());
stack.push_back(
{&s.fields()[i].value(),
&dst_struct.mutable_fields()[i].mutable_value()});
}
},
[&](const MapExpr& c) {
auto& dst_map = dst->mutable_map_expr();
dst_map.mutable_entries().resize(c.entries().size());
for (int i = 0; i < c.entries().size(); ++i) {
dst_map.mutable_entries()[i].set_optional(
c.entries()[i].optional());
dst_map.mutable_entries()[i].set_id(c.entries()[i].id());
stack.push_back({&c.entries()[i].key(),
&dst_map.mutable_entries()[i].mutable_key()});
stack.push_back({&c.entries()[i].value(),
&dst_map.mutable_entries()[i].mutable_value()});
}
},
[&](const ComprehensionExpr& c) {
auto& dst_comprehension = dst->mutable_comprehension_expr();
dst_comprehension.set_iter_var(c.iter_var());
dst_comprehension.set_iter_var2(c.iter_var2());
dst_comprehension.set_accu_var(c.accu_var());
if (c.has_accu_init()) {
stack.push_back(
{&c.accu_init(), &dst_comprehension.mutable_accu_init()});
}
if (c.has_iter_range()) {
stack.push_back(
{&c.iter_range(), &dst_comprehension.mutable_iter_range()});
}
if (c.has_loop_condition()) {
stack.push_back({&c.loop_condition(),
&dst_comprehension.mutable_loop_condition()});
}
if (c.has_loop_step()) {
stack.push_back(
{&c.loop_step(), &dst_comprehension.mutable_loop_step()});
}
if (c.has_result()) {
stack.push_back(
{&c.result(), &dst_comprehension.mutable_result()});
}
}

),
src->kind());
}
} // namespace

const UnspecifiedExpr& UnspecifiedExpr::default_instance() {
static const absl::NoDestructor<UnspecifiedExpr> instance;
return *instance;
Expand Down Expand Up @@ -63,4 +177,16 @@ const Expr& Expr::default_instance() {
return *instance;
}

Expr CloneExpr(const Expr& expr) {
Expr result;
std::vector<CopyStackRecord> stack;
stack.push_back({&expr, &result});
while (!stack.empty()) {
CopyStackRecord element = stack.back();
stack.pop_back();
CopyNode(element, stack);
}
return result;
}

} // namespace cel
3 changes: 3 additions & 0 deletions common/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class ComprehensionExpr;

inline constexpr absl::string_view kAccumulatorVariableName = "__result__";

// Returns a deep copy of the given expression node.
Expr CloneExpr(const Expr& expr);

bool operator==(const Expr& lhs, const Expr& rhs);

inline bool operator!=(const Expr& lhs, const Expr& rhs) {
Expand Down
1 change: 1 addition & 0 deletions extensions/protobuf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ cc_test(
":ast_converters",
"//base/ast_internal:ast_impl",
"//base/ast_internal:expr",
"//common:expr",
"//internal:proto_matchers",
"//internal:testing",
"//parser",
Expand Down
18 changes: 18 additions & 0 deletions extensions/protobuf/ast_converters_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "absl/types/variant.h"
#include "base/ast_internal/ast_impl.h"
#include "base/ast_internal/expr.h"
#include "common/expr.h"
#include "internal/proto_matchers.h"
#include "internal/testing.h"
#include "parser/options.h"
Expand Down Expand Up @@ -801,6 +802,23 @@ TEST_P(ConversionRoundTripTest, ParsedExprCopyable) {
IsOkAndHolds(EqualsProto(parsed_expr)));
}

TEST_P(ConversionRoundTripTest, ExprClonable) {
ASSERT_OK_AND_ASSIGN(ParsedExprPb parsed_expr,
Parse(GetParam().expr, "<input>", options_));

ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast,
CreateAstFromParsedExpr(parsed_expr));

auto& impl = ast_internal::AstImpl::CastFromPublicAst(*ast);
impl.root_expr() = CloneExpr(impl.root_expr());

EXPECT_THAT(CreateCheckedExprFromAst(impl),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("AST is not type-checked")));
EXPECT_THAT(CreateParsedExprFromAst(impl),
IsOkAndHolds(EqualsProto(parsed_expr)));
}

TEST_P(ConversionRoundTripTest, CheckedExprCopyable) {
ASSERT_OK_AND_ASSIGN(ParsedExprPb parsed_expr,
Parse(GetParam().expr, "<input>", options_));
Expand Down