Skip to content

Add CompilerLibraries for lists and strings extensions. #1579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion extensions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,13 @@ cc_library(
"//common:type",
"//common:value",
"//common:value_kind",
"//compiler",
"//internal:status_macros",
"//parser:macro",
"//parser:macro_expr_factory",
"//parser:macro_registry",
"//parser:options",
"//parser:parser_interface",
"//runtime:function_adapter",
"//runtime:function_registry",
"//runtime:runtime_options",
Expand All @@ -369,6 +371,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:optional",
Expand All @@ -382,13 +385,13 @@ cc_test(
srcs = ["lists_functions_test.cc"],
deps = [
":lists_functions",
"//checker:standard_library",
"//checker:validation_result",
"//common:source",
"//common:value",
"//common:value_testing",
"//compiler",
"//compiler:compiler_factory",
"//compiler:standard_library",
"//extensions/protobuf:runtime_adapter",
"//internal:testing",
"//internal:testing_descriptor_pool",
Expand All @@ -404,6 +407,7 @@ cc_test(
"//runtime:standard_runtime_builder_factory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/strings:string_view",
"@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
"@com_google_protobuf//:protobuf",
],
Expand Down Expand Up @@ -494,6 +498,7 @@ cc_library(
"//common:decl",
"//common:type",
"//common:value",
"//compiler",
"//eval/public:cel_function_registry",
"//eval/public:cel_options",
"//internal:status_macros",
Expand Down Expand Up @@ -524,6 +529,7 @@ cc_test(
"//common:decl",
"//common:value",
"//compiler:compiler_factory",
"//compiler:standard_library",
"//extensions/protobuf:runtime_adapter",
"//internal:testing",
"//internal:testing_descriptor_pool",
Expand Down
64 changes: 55 additions & 9 deletions extensions/lists_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

Expand All @@ -26,6 +27,7 @@
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
Expand All @@ -38,11 +40,13 @@
#include "common/type.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "compiler/compiler.h"
#include "internal/status_macros.h"
#include "parser/macro.h"
#include "parser/macro_expr_factory.h"
#include "parser/macro_registry.h"
#include "parser/options.h"
#include "parser/parser_interface.h"
#include "runtime/function_adapter.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
Expand All @@ -55,6 +59,15 @@ namespace {

using ::cel::checker_internal::BuiltinsArena;

absl::Span<const cel::Type> SortableTypes() {
static const Type kTypes[]{cel::IntType(), cel::UintType(),
cel::DoubleType(), cel::BoolType(),
cel::DurationType(), cel::TimestampType(),
cel::StringType(), cel::BytesType()};

return kTypes;
}

// Slow distinct() implementation that uses Equal() to compare values in O(n^2).
absl::Status ListDistinctHeterogeneousImpl(
const ListValue& list,
Expand Down Expand Up @@ -577,21 +590,50 @@ absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder) {
"slice",
MakeMemberOverloadDecl("list_slice", ListTypeParamType(),
ListTypeParamType(), IntType(), IntType())));
// TODO(uncreated-issue/83): Update to specific decls for sortable types.
CEL_ASSIGN_OR_RETURN(
FunctionDecl sort_decl,
MakeFunctionDecl("sort",
MakeMemberOverloadDecl("list_sort", ListTypeParamType(),
ListTypeParamType())));

static const absl::NoDestructor<std::vector<Type>> kSortableListTypes([] {
std::vector<Type> instance;
instance.reserve(SortableTypes().size());
for (const Type& type : SortableTypes()) {
instance.push_back(ListType(BuiltinsArena(), type));
}
return instance;
}());

FunctionDecl sort_decl;
sort_decl.set_name("sort");
FunctionDecl sort_by_key_decl;
sort_by_key_decl.set_name("@sortByAssociatedKeys");

for (const Type& list_type : *kSortableListTypes) {
std::string elem_type_name(list_type.AsList()->GetElement().name());

CEL_RETURN_IF_ERROR(sort_decl.AddOverload(MakeMemberOverloadDecl(
absl::StrCat("list_", elem_type_name, "_sort"), list_type, list_type)));
CEL_RETURN_IF_ERROR(sort_by_key_decl.AddOverload(MakeMemberOverloadDecl(
absl::StrCat("list_", elem_type_name, "_sortByAssociatedKeys"),
ListTypeParamType(), ListTypeParamType(), list_type)));
}

CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_by_key_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(distinct_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(flatten_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(range_decl)));
// MergeFunction is used to combine with the reverse function
// defined in strings extension.
CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(slice_decl)));
return builder.AddFunction(std::move(sort_decl));
return absl::OkStatus();
}

std::vector<Macro> lists_macros() { return {ListSortByMacro()}; }

absl::Status ConfigureParser(ParserBuilder& builder) {
for (const Macro& macro : lists_macros()) {
CEL_RETURN_IF_ERROR(builder.AddMacro(macro));
}
return absl::OkStatus();
}

} // namespace
Expand All @@ -607,8 +649,6 @@ absl::Status RegisterListsFunctions(FunctionRegistry& registry,
return absl::OkStatus();
}

std::vector<Macro> lists_macros() { return {ListSortByMacro()}; }

absl::Status RegisterListsMacros(MacroRegistry& registry,
const ParserOptions&) {
return registry.RegisterMacros(lists_macros());
Expand All @@ -618,4 +658,10 @@ CheckerLibrary ListsCheckerLibrary() {
return {.id = "cel.lib.ext.lists", .configure = RegisterListsCheckerDecls};
}

CompilerLibrary ListsCompilerLibrary() {
auto lib = CompilerLibrary::FromCheckerLibrary(ListsCheckerLibrary());
lib.configure_parser = ConfigureParser;
return lib;
}

} // namespace cel::extensions
23 changes: 22 additions & 1 deletion extensions/lists_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "absl/status/status.h"
#include "checker/type_checker_builder.h"
#include "compiler/compiler.h"
#include "parser/macro_registry.h"
#include "parser/options.h"
#include "runtime/function_registry.h"
Expand Down Expand Up @@ -59,11 +60,31 @@ absl::Status RegisterListsMacros(MacroRegistry& registry,
//
// <list(T)>.reverse() -> list(T)
//
// <list(T)>.sort() -> list(T)
// <list(T_)>.sort() -> list(T_) where T_ is partially orderable
//
// <list(T)>.slice(start: int, end: int) -> list(T)
CheckerLibrary ListsCheckerLibrary();

// Provides decls for the following functions:
//
// lists.range(n: int) -> list(int)
//
// <list(T)>.distinct() -> list(T)
//
// <list(dyn)>.flatten() -> list(dyn)
// <list(dyn)>.flatten(limit: int) -> list(dyn)
//
// <list(T)>.reverse() -> list(T)
//
// <list(T_)>.sort() -> list(T_) where T_ is partially orderable
//
// <list(T)>.slice(start: int, end: int) -> list(T)
//
// and the following macros:
//
// <list(T)>.sortBy(<element name>, <element key expression>)
CompilerLibrary ListsCompilerLibrary();

} // namespace cel::extensions

#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_
102 changes: 62 additions & 40 deletions extensions/lists_functions_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "checker/standard_library.h"
#include "absl/strings/string_view.h"
#include "checker/validation_result.h"
#include "common/source.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
Expand All @@ -43,7 +44,6 @@
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel::extensions {
namespace {
Expand Down Expand Up @@ -281,8 +281,8 @@ TEST(ListsFunctionsTest, ListSortByMacroParseError) {
}

struct ListCheckerTestCase {
const std::string expr;
bool is_valid;
std::string expr;
std::string error_substr;
};

class ListsCheckerLibraryTest
Expand All @@ -291,64 +291,86 @@ class ListsCheckerLibraryTest
void SetUp() override {
// Arrange: Configure the compiler.
// Add the lists checker library to the compiler builder.
ASSERT_OK_AND_ASSIGN(std::unique_ptr<CompilerBuilder> compiler_builder,
NewCompilerBuilder(descriptor_pool_));
ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk());
ASSERT_THAT(compiler_builder->AddLibrary(ListsCheckerLibrary()), IsOk());
ASSERT_OK_AND_ASSIGN(
std::unique_ptr<CompilerBuilder> compiler_builder,
NewCompilerBuilder(internal::GetTestingDescriptorPool()));
ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()),
IsOk());
ASSERT_THAT(compiler_builder->AddLibrary(ListsCompilerLibrary()), IsOk());
compiler_builder->GetCheckerBuilder().set_container(
"cel.expr.conformance.proto3");
ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build());
}

const google::protobuf::DescriptorPool* descriptor_pool_ =
internal::GetTestingDescriptorPool();
std::unique_ptr<Compiler> compiler_;
};

TEST_P(ListsCheckerLibraryTest, ListsFunctionsTypeCheckerSuccess) {
// Act & Assert: Compile the expression and validate the result.
ASSERT_OK_AND_ASSIGN(ValidationResult result,
compiler_->Compile(GetParam().expr));
EXPECT_EQ(result.IsValid(), GetParam().is_valid);
absl::string_view error_substr = GetParam().error_substr;
EXPECT_EQ(result.IsValid(), error_substr.empty());

if (!error_substr.empty()) {
EXPECT_THAT(result.FormatError(), HasSubstr(error_substr));
}
}

// Returns a vector of test cases for the ListsCheckerLibraryTest.
// Returns both positive and negative test cases for the lists functions.
std::vector<ListCheckerTestCase> createListsCheckerParams() {
return {
// lists.distinct()
{R"([1,2,3,4,4].distinct() == [1,2,3,4])", true},
{R"('abc'.distinct() == [1,2,3,4])", false},
{R"([1,2,3,4,4].distinct() == 'abc')", false},
{R"([1,2,3,4,4].distinct(1) == [1,2,3,4])", false},
{R"([1,2,3,4,4].distinct() == [1,2,3,4])"},
{R"('abc'.distinct() == [1,2,3,4])",
"no matching overload for 'distinct'"},
{R"([1,2,3,4,4].distinct() == 'abc')", "no matching overload for '_==_'"},
{R"([1,2,3,4,4].distinct(1) == [1,2,3,4])", "undeclared reference"},
// lists.flatten()
{R"([1,2,3,4].flatten() == [1,2,3,4])", true},
{R"([1,2,3,4].flatten(1) == [1,2,3,4])", true},
{R"('abc'.flatten() == [1,2,3,4])", false},
{R"([1,2,3,4].flatten() == 'abc')", false},
{R"('abc'.flatten(1) == [1,2,3,4])", false},
{R"([1,2,3,4].flatten('abc') == [1,2,3,4])", false},
{R"([1,2,3,4].flatten(1) == 'abc')", false},
{R"([1,2,3,4].flatten() == [1,2,3,4])"},
{R"([1,2,3,4].flatten(1) == [1,2,3,4])"},
{R"('abc'.flatten() == [1,2,3,4])", "no matching overload for 'flatten'"},
{R"([1,2,3,4].flatten() == 'abc')", "no matching overload for '_==_'"},
{R"('abc'.flatten(1) == [1,2,3,4])",
"no matching overload for 'flatten'"},
{R"([1,2,3,4].flatten('abc') == [1,2,3,4])",
"no matching overload for 'flatten'"},
{R"([1,2,3,4].flatten(1) == 'abc')", "no matching overload"},
// lists.range()
{R"(lists.range(4) == [0,1,2,3])", true},
{R"(lists.range('abc') == [])", false},
{R"(lists.range(4) == 'abc')", false},
{R"(lists.range(4, 4) == [0,1,2,3])", false},
{R"(lists.range(4) == [0,1,2,3])"},
{R"(lists.range('abc') == [])", "no matching overload for 'lists.range'"},
{R"(lists.range(4) == 'abc')", "no matching overload for '_==_'"},
{R"(lists.range(4, 4) == [0,1,2,3])", "undeclared reference"},
// lists.reverse()
{R"([1,2,3,4].reverse() == [4,3,2,1])", true},
{R"('abc'.reverse() == [])", false},
{R"([1,2,3,4].reverse() == 'abc')", false},
{R"([1,2,3,4].reverse(1) == [4,3,2,1])", false},
{R"([1,2,3,4].reverse() == [4,3,2,1])"},
{R"('abc'.reverse() == [])", "no matching overload for 'reverse'"},
{R"([1,2,3,4].reverse() == 'abc')", "no matching overload for '_==_'"},
{R"([1,2,3,4].reverse(1) == [4,3,2,1])", "undeclared reference"},
// lists.slice()
{R"([1,2,3,4].slice(0, 4) == [1,2,3,4])", true},
{R"('abc'.slice(0, 4) == [1,2,3,4])", false},
{R"([1,2,3,4].slice('abc', 4) == [1,2,3,4])", false},
{R"([1,2,3,4].slice(0, 'abc') == [1,2,3,4])", false},
{R"([1,2,3,4].slice(0, 4) == 'abc')", false},
{R"([1,2,3,4].slice(0, 2, 3) == [1,2,3,4])", false},
{R"([1,2,3,4].slice(0, 4) == [1,2,3,4])"},
{R"('abc'.slice(0, 4) == [1,2,3,4])", "no matching overload for 'slice'"},
{R"([1,2,3,4].slice('abc', 4) == [1,2,3,4])",
"no matching overload for 'slice'"},
{R"([1,2,3,4].slice(0, 'abc') == [1,2,3,4])",
"no matching overload for 'slice'"},
{R"([1,2,3,4].slice(0, 4) == 'abc')", "no matching overload for '_==_'"},
{R"([1,2,3,4].slice(0, 2, 3) == [1,2,3,4])", "undeclared reference"},
// lists.sort()
{R"([1,2,3,4].sort() == [1,2,3,4])", true},
{R"('abc'.sort() == [])", false},
{R"([1,2,3,4].sort() == 'abc')", false},
{R"([1,2,3,4].sort(2) == [1,2,3,4])", false},
{R"([1,2,3,4].sort() == [1,2,3,4])"},
{R"([TestAllTypes{}, TestAllTypes{}].sort() == [])",
"no matching overload for 'sort'"},
{R"('abc'.sort() == [])", "no matching overload for 'sort'"},
{R"([1,2,3,4].sort() == 'abc')", "no matching overload for '_==_'"},
{R"([1,2,3,4].sort(2) == [1,2,3,4])", "undeclared reference"},
// sortBy macro
{R"([1,2,3,4].sortBy(x, -x) == [4,3,2,1])"},
{R"([TestAllTypes{}, TestAllTypes{}].sortBy(x, x) == [])",
"no matching overload for '@sortByAssociatedKeys'"},
{R"(
[TestAllTypes{single_int64: 2}, TestAllTypes{single_int64: 1}]
.sortBy(x, x.single_int64) ==
[TestAllTypes{single_int64: 1}, TestAllTypes{single_int64: 2}])"},
};
}

Expand Down
Loading