Skip to content

[NaaS] Add cel list extension function in the NaaS evaluator. #1582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions extensions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,12 @@ cc_library(
srcs = ["lists_functions.cc"],
hdrs = ["lists_functions.h"],
deps = [
"//checker:type_checker_builder",
"//checker/internal:builtins_arena",
"//common:decl",
"//common:expr",
"//common:operators",
"//common:type",
"//common:value",
"//common:value_kind",
"//internal:status_macros",
Expand All @@ -360,6 +364,7 @@ cc_library(
"//runtime:function_registry",
"//runtime:runtime_options",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
Expand All @@ -377,9 +382,13 @@ cc_test(
srcs = ["lists_functions_test.cc"],
deps = [
":lists_functions",
"//checker:standard_library",
"//checker:validation_result",
"//common:source",
"//common:value",
"//common:value_testing",
"//compiler",
"//compiler:compiler_factory",
"//extensions/protobuf:runtime_adapter",
"//internal:testing",
"//internal:testing_descriptor_pool",
Expand Down
73 changes: 73 additions & 0 deletions extensions/lists_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <vector>

#include "absl/base/macros.h"
#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
Expand All @@ -29,8 +30,12 @@
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "checker/internal/builtins_arena.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/expr.h"
#include "common/operators.h"
#include "common/type.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "internal/status_macros.h"
Expand All @@ -48,6 +53,8 @@
namespace cel::extensions {
namespace {

using ::cel::checker_internal::BuiltinsArena;

// Slow distinct() implementation that uses Equal() to compare values in O(n^2).
absl::Status ListDistinctHeterogeneousImpl(
const ListValue& list,
Expand Down Expand Up @@ -525,6 +532,68 @@ absl::Status RegisterListSortFunction(FunctionRegistry& registry) {
return absl::OkStatus();
}

const Type& ListIntType() {
static absl::NoDestructor<Type> kInstance(
ListType(BuiltinsArena(), IntType()));
return *kInstance;
}

const Type& ListTypeParamType() {
static absl::NoDestructor<Type> kInstance(
ListType(BuiltinsArena(), TypeParamType("T")));
return *kInstance;
}

absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder) {
CEL_ASSIGN_OR_RETURN(
FunctionDecl distinct_decl,
MakeFunctionDecl("distinct", MakeMemberOverloadDecl(
"list_distinct", ListTypeParamType(),
ListTypeParamType())));

CEL_ASSIGN_OR_RETURN(
FunctionDecl flatten_decl,
MakeFunctionDecl(
"flatten",
MakeMemberOverloadDecl("list_flatten_int", ListType(), ListType(),
IntType()),
MakeMemberOverloadDecl("list_flatten", ListType(), ListType())));

CEL_ASSIGN_OR_RETURN(
FunctionDecl range_decl,
MakeFunctionDecl(
"lists.range",
MakeOverloadDecl("list_range", ListIntType(), IntType())));

CEL_ASSIGN_OR_RETURN(
FunctionDecl reverse_decl,
MakeFunctionDecl(
"reverse", MakeMemberOverloadDecl("list_reverse", ListTypeParamType(),
ListTypeParamType())));

CEL_ASSIGN_OR_RETURN(
FunctionDecl slice_decl,
MakeFunctionDecl(
"slice",
MakeMemberOverloadDecl("list_slice", ListTypeParamType(),
ListTypeParamType(), IntType(), IntType())));
// TODO(uncreated-issue/83): Update to specific decls for sortable types.
CEL_ASSIGN_OR_RETURN(
FunctionDecl sort_decl,
MakeFunctionDecl("sort",
MakeMemberOverloadDecl("list_sort", ListTypeParamType(),
ListTypeParamType())));

CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(distinct_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(flatten_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(range_decl)));
// MergeFunction is used to combine with the reverse function
// defined in strings extension.
CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(slice_decl)));
return builder.AddFunction(std::move(sort_decl));
}

} // namespace

absl::Status RegisterListsFunctions(FunctionRegistry& registry,
Expand All @@ -545,4 +614,8 @@ absl::Status RegisterListsMacros(MacroRegistry& registry,
return registry.RegisterMacros(lists_macros());
}

CheckerLibrary ListsCheckerLibrary() {
return {.id = "cel.lib.ext.lists", .configure = RegisterListsCheckerDecls};
}

} // namespace cel::extensions
18 changes: 18 additions & 0 deletions extensions/lists_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_LISTS_FUNCTIONS_H_

#include "absl/status/status.h"
#include "checker/type_checker_builder.h"
#include "parser/macro_registry.h"
#include "parser/options.h"
#include "runtime/function_registry.h"
Expand Down Expand Up @@ -46,6 +47,23 @@ absl::Status RegisterListsFunctions(FunctionRegistry& registry,
absl::Status RegisterListsMacros(MacroRegistry& registry,
const ParserOptions& options);

// Type check declarations for the lists extension library.
// Provides decls for the following functions:
//
// lists.range(n: int) -> list(int)
//
// <list(T)>.distinct() -> list(T)
//
// <list(dyn)>.flatten() -> list(dyn)
// <list(dyn)>.flatten(limit: int) -> list(dyn)
//
// <list(T)>.reverse() -> list(T)
//
// <list(T)>.sort() -> list(T)
//
// <list(T)>.slice(start: int, end: int) -> list(T)
CheckerLibrary ListsCheckerLibrary();

} // namespace cel::extensions

#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_
88 changes: 85 additions & 3 deletions extensions/lists_functions_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "checker/standard_library.h"
#include "checker/validation_result.h"
#include "common/source.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
Expand All @@ -38,17 +43,19 @@
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel::extensions {
namespace {
using ::cel::expr::Expr;
using ::cel::expr::ParsedExpr;
using ::cel::expr::SourceInfo;

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::test::ErrorValueIs;
using ::cel::expr::Expr;
using ::cel::expr::ParsedExpr;
using ::cel::expr::SourceInfo;
using ::testing::HasSubstr;
using ::testing::ValuesIn;

struct TestInfo {
std::string expr;
Expand Down Expand Up @@ -273,5 +280,80 @@ TEST(ListsFunctionsTest, ListSortByMacroParseError) {
HasSubstr("sortBy can only be applied to")));
}

struct ListCheckerTestCase {
const std::string expr;
bool is_valid;
};

class ListsCheckerLibraryTest
: public ::testing::TestWithParam<ListCheckerTestCase> {
public:
void SetUp() override {
// Arrange: Configure the compiler.
// Add the lists checker library to the compiler builder.
ASSERT_OK_AND_ASSIGN(std::unique_ptr<CompilerBuilder> compiler_builder,
NewCompilerBuilder(descriptor_pool_));
ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk());
ASSERT_THAT(compiler_builder->AddLibrary(ListsCheckerLibrary()), IsOk());
ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build());
}

const google::protobuf::DescriptorPool* descriptor_pool_ =
internal::GetTestingDescriptorPool();
std::unique_ptr<Compiler> compiler_;
};

TEST_P(ListsCheckerLibraryTest, ListsFunctionsTypeCheckerSuccess) {
// Act & Assert: Compile the expression and validate the result.
ASSERT_OK_AND_ASSIGN(ValidationResult result,
compiler_->Compile(GetParam().expr));
EXPECT_EQ(result.IsValid(), GetParam().is_valid);
}

// Returns a vector of test cases for the ListsCheckerLibraryTest.
// Returns both positive and negative test cases for the lists functions.
std::vector<ListCheckerTestCase> createListsCheckerParams() {
return {
// lists.distinct()
{R"([1,2,3,4,4].distinct() == [1,2,3,4])", true},
{R"('abc'.distinct() == [1,2,3,4])", false},
{R"([1,2,3,4,4].distinct() == 'abc')", false},
{R"([1,2,3,4,4].distinct(1) == [1,2,3,4])", false},
// lists.flatten()
{R"([1,2,3,4].flatten() == [1,2,3,4])", true},
{R"([1,2,3,4].flatten(1) == [1,2,3,4])", true},
{R"('abc'.flatten() == [1,2,3,4])", false},
{R"([1,2,3,4].flatten() == 'abc')", false},
{R"('abc'.flatten(1) == [1,2,3,4])", false},
{R"([1,2,3,4].flatten('abc') == [1,2,3,4])", false},
{R"([1,2,3,4].flatten(1) == 'abc')", false},
// lists.range()
{R"(lists.range(4) == [0,1,2,3])", true},
{R"(lists.range('abc') == [])", false},
{R"(lists.range(4) == 'abc')", false},
{R"(lists.range(4, 4) == [0,1,2,3])", false},
// lists.reverse()
{R"([1,2,3,4].reverse() == [4,3,2,1])", true},
{R"('abc'.reverse() == [])", false},
{R"([1,2,3,4].reverse() == 'abc')", false},
{R"([1,2,3,4].reverse(1) == [4,3,2,1])", false},
// lists.slice()
{R"([1,2,3,4].slice(0, 4) == [1,2,3,4])", true},
{R"('abc'.slice(0, 4) == [1,2,3,4])", false},
{R"([1,2,3,4].slice('abc', 4) == [1,2,3,4])", false},
{R"([1,2,3,4].slice(0, 'abc') == [1,2,3,4])", false},
{R"([1,2,3,4].slice(0, 4) == 'abc')", false},
{R"([1,2,3,4].slice(0, 2, 3) == [1,2,3,4])", false},
// lists.sort()
{R"([1,2,3,4].sort() == [1,2,3,4])", true},
{R"('abc'.sort() == [])", false},
{R"([1,2,3,4].sort() == 'abc')", false},
{R"([1,2,3,4].sort(2) == [1,2,3,4])", false},
};
}

INSTANTIATE_TEST_SUITE_P(ListsCheckerLibraryTest, ListsCheckerLibraryTest,
ValuesIn(createListsCheckerParams()));

} // namespace
} // namespace cel::extensions
4 changes: 3 additions & 1 deletion extensions/strings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,9 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder) {
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(upper_ascii_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(format_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(quote_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(reverse_decl)));
// MergeFunction is used to combine with the reverse function
// defined in cel.lib.ext.lists extension.
CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl)));

return absl::OkStatus();
}
Expand Down