diff --git a/extensions/BUILD b/extensions/BUILD index 5439d1a2d..e4502033e 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -347,8 +347,12 @@ cc_library( srcs = ["lists_functions.cc"], hdrs = ["lists_functions.h"], deps = [ + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", "//common:expr", "//common:operators", + "//common:type", "//common:value", "//common:value_kind", "//internal:status_macros", @@ -360,6 +364,7 @@ cc_library( "//runtime:function_registry", "//runtime:runtime_options", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -377,9 +382,13 @@ cc_test( srcs = ["lists_functions_test.cc"], deps = [ ":lists_functions", + "//checker:standard_library", + "//checker:validation_result", "//common:source", "//common:value", "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", "//extensions/protobuf:runtime_adapter", "//internal:testing", "//internal:testing_descriptor_pool", diff --git a/extensions/lists_functions.cc b/extensions/lists_functions.cc index 04fe553ec..0c5f64850 100644 --- a/extensions/lists_functions.cc +++ b/extensions/lists_functions.cc @@ -21,6 +21,7 @@ #include #include "absl/base/macros.h" +#include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" @@ -29,8 +30,12 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" #include "common/expr.h" #include "common/operators.h" +#include "common/type.h" #include "common/value.h" #include "common/value_kind.h" #include "internal/status_macros.h" @@ -48,6 +53,8 @@ namespace cel::extensions { namespace { +using ::cel::checker_internal::BuiltinsArena; + // Slow distinct() implementation that uses Equal() to compare values in O(n^2). absl::Status ListDistinctHeterogeneousImpl( const ListValue& list, @@ -525,6 +532,68 @@ absl::Status RegisterListSortFunction(FunctionRegistry& registry) { return absl::OkStatus(); } +const Type& ListIntType() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), IntType())); + return *kInstance; +} + +const Type& ListTypeParamType() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), TypeParamType("T"))); + return *kInstance; +} + +absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder) { + CEL_ASSIGN_OR_RETURN( + FunctionDecl distinct_decl, + MakeFunctionDecl("distinct", MakeMemberOverloadDecl( + "list_distinct", ListTypeParamType(), + ListTypeParamType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl flatten_decl, + MakeFunctionDecl( + "flatten", + MakeMemberOverloadDecl("list_flatten_int", ListType(), ListType(), + IntType()), + MakeMemberOverloadDecl("list_flatten", ListType(), ListType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl range_decl, + MakeFunctionDecl( + "lists.range", + MakeOverloadDecl("list_range", ListIntType(), IntType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl reverse_decl, + MakeFunctionDecl( + "reverse", MakeMemberOverloadDecl("list_reverse", ListTypeParamType(), + ListTypeParamType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl slice_decl, + MakeFunctionDecl( + "slice", + MakeMemberOverloadDecl("list_slice", ListTypeParamType(), + ListTypeParamType(), IntType(), IntType()))); + // TODO(uncreated-issue/83): Update to specific decls for sortable types. + CEL_ASSIGN_OR_RETURN( + FunctionDecl sort_decl, + MakeFunctionDecl("sort", + MakeMemberOverloadDecl("list_sort", ListTypeParamType(), + ListTypeParamType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(distinct_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(flatten_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(range_decl))); + // MergeFunction is used to combine with the reverse function + // defined in strings extension. + CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(slice_decl))); + return builder.AddFunction(std::move(sort_decl)); +} + } // namespace absl::Status RegisterListsFunctions(FunctionRegistry& registry, @@ -545,4 +614,8 @@ absl::Status RegisterListsMacros(MacroRegistry& registry, return registry.RegisterMacros(lists_macros()); } +CheckerLibrary ListsCheckerLibrary() { + return {.id = "cel.lib.ext.lists", .configure = RegisterListsCheckerDecls}; +} + } // namespace cel::extensions diff --git a/extensions/lists_functions.h b/extensions/lists_functions.h index d10f63a42..979360762 100644 --- a/extensions/lists_functions.h +++ b/extensions/lists_functions.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_LISTS_FUNCTIONS_H_ #include "absl/status/status.h" +#include "checker/type_checker_builder.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "runtime/function_registry.h" @@ -46,6 +47,23 @@ absl::Status RegisterListsFunctions(FunctionRegistry& registry, absl::Status RegisterListsMacros(MacroRegistry& registry, const ParserOptions& options); +// Type check declarations for the lists extension library. +// Provides decls for the following functions: +// +// lists.range(n: int) -> list(int) +// +// .distinct() -> list(T) +// +// .flatten() -> list(dyn) +// .flatten(limit: int) -> list(dyn) +// +// .reverse() -> list(T) +// +// .sort() -> list(T) +// +// .slice(start: int, end: int) -> list(T) +CheckerLibrary ListsCheckerLibrary(); + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_ diff --git a/extensions/lists_functions_test.cc b/extensions/lists_functions_test.cc index 00cb11a63..7255e9071 100644 --- a/extensions/lists_functions_test.cc +++ b/extensions/lists_functions_test.cc @@ -17,13 +17,18 @@ #include #include #include +#include #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" #include "common/source.h" #include "common/value.h" #include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -38,17 +43,19 @@ #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::extensions { namespace { -using ::cel::expr::Expr; -using ::cel::expr::ParsedExpr; -using ::cel::expr::SourceInfo; using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::test::ErrorValueIs; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; using ::testing::HasSubstr; +using ::testing::ValuesIn; struct TestInfo { std::string expr; @@ -273,5 +280,80 @@ TEST(ListsFunctionsTest, ListSortByMacroParseError) { HasSubstr("sortBy can only be applied to"))); } +struct ListCheckerTestCase { + const std::string expr; + bool is_valid; +}; + +class ListsCheckerLibraryTest + : public ::testing::TestWithParam { + public: + void SetUp() override { + // Arrange: Configure the compiler. + // Add the lists checker library to the compiler builder. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler_builder, + NewCompilerBuilder(descriptor_pool_)); + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(ListsCheckerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build()); + } + + const google::protobuf::DescriptorPool* descriptor_pool_ = + internal::GetTestingDescriptorPool(); + std::unique_ptr compiler_; +}; + +TEST_P(ListsCheckerLibraryTest, ListsFunctionsTypeCheckerSuccess) { + // Act & Assert: Compile the expression and validate the result. + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler_->Compile(GetParam().expr)); + EXPECT_EQ(result.IsValid(), GetParam().is_valid); +} + +// Returns a vector of test cases for the ListsCheckerLibraryTest. +// Returns both positive and negative test cases for the lists functions. +std::vector createListsCheckerParams() { + return { + // lists.distinct() + {R"([1,2,3,4,4].distinct() == [1,2,3,4])", true}, + {R"('abc'.distinct() == [1,2,3,4])", false}, + {R"([1,2,3,4,4].distinct() == 'abc')", false}, + {R"([1,2,3,4,4].distinct(1) == [1,2,3,4])", false}, + // lists.flatten() + {R"([1,2,3,4].flatten() == [1,2,3,4])", true}, + {R"([1,2,3,4].flatten(1) == [1,2,3,4])", true}, + {R"('abc'.flatten() == [1,2,3,4])", false}, + {R"([1,2,3,4].flatten() == 'abc')", false}, + {R"('abc'.flatten(1) == [1,2,3,4])", false}, + {R"([1,2,3,4].flatten('abc') == [1,2,3,4])", false}, + {R"([1,2,3,4].flatten(1) == 'abc')", false}, + // lists.range() + {R"(lists.range(4) == [0,1,2,3])", true}, + {R"(lists.range('abc') == [])", false}, + {R"(lists.range(4) == 'abc')", false}, + {R"(lists.range(4, 4) == [0,1,2,3])", false}, + // lists.reverse() + {R"([1,2,3,4].reverse() == [4,3,2,1])", true}, + {R"('abc'.reverse() == [])", false}, + {R"([1,2,3,4].reverse() == 'abc')", false}, + {R"([1,2,3,4].reverse(1) == [4,3,2,1])", false}, + // lists.slice() + {R"([1,2,3,4].slice(0, 4) == [1,2,3,4])", true}, + {R"('abc'.slice(0, 4) == [1,2,3,4])", false}, + {R"([1,2,3,4].slice('abc', 4) == [1,2,3,4])", false}, + {R"([1,2,3,4].slice(0, 'abc') == [1,2,3,4])", false}, + {R"([1,2,3,4].slice(0, 4) == 'abc')", false}, + {R"([1,2,3,4].slice(0, 2, 3) == [1,2,3,4])", false}, + // lists.sort() + {R"([1,2,3,4].sort() == [1,2,3,4])", true}, + {R"('abc'.sort() == [])", false}, + {R"([1,2,3,4].sort() == 'abc')", false}, + {R"([1,2,3,4].sort(2) == [1,2,3,4])", false}, + }; +} + +INSTANTIATE_TEST_SUITE_P(ListsCheckerLibraryTest, ListsCheckerLibraryTest, + ValuesIn(createListsCheckerParams())); + } // namespace } // namespace cel::extensions diff --git a/extensions/strings.cc b/extensions/strings.cc index a6792cbaa..c30985080 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -398,7 +398,9 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder) { CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(upper_ascii_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(format_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(quote_decl))); - CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(reverse_decl))); + // MergeFunction is used to combine with the reverse function + // defined in cel.lib.ext.lists extension. + CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl))); return absl::OkStatus(); }