diff --git a/extensions/BUILD b/extensions/BUILD index e4502033e..16f2c0be4 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -355,11 +355,13 @@ cc_library( "//common:type", "//common:value", "//common:value_kind", + "//compiler", "//internal:status_macros", "//parser:macro", "//parser:macro_expr_factory", "//parser:macro_registry", "//parser:options", + "//parser:parser_interface", "//runtime:function_adapter", "//runtime:function_registry", "//runtime:runtime_options", @@ -369,6 +371,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", @@ -382,13 +385,13 @@ cc_test( srcs = ["lists_functions_test.cc"], deps = [ ":lists_functions", - "//checker:standard_library", "//checker:validation_result", "//common:source", "//common:value", "//common:value_testing", "//compiler", "//compiler:compiler_factory", + "//compiler:standard_library", "//extensions/protobuf:runtime_adapter", "//internal:testing", "//internal:testing_descriptor_pool", @@ -404,6 +407,7 @@ cc_test( "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -494,6 +498,7 @@ cc_library( "//common:decl", "//common:type", "//common:value", + "//compiler", "//eval/public:cel_function_registry", "//eval/public:cel_options", "//internal:status_macros", @@ -524,6 +529,7 @@ cc_test( "//common:decl", "//common:value", "//compiler:compiler_factory", + "//compiler:standard_library", "//extensions/protobuf:runtime_adapter", "//internal:testing", "//internal:testing_descriptor_pool", diff --git a/extensions/lists_functions.cc b/extensions/lists_functions.cc index 0c5f64850..0d1b6e317 100644 --- a/extensions/lists_functions.cc +++ b/extensions/lists_functions.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -26,6 +27,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -38,11 +40,13 @@ #include "common/type.h" #include "common/value.h" #include "common/value_kind.h" +#include "compiler/compiler.h" #include "internal/status_macros.h" #include "parser/macro.h" #include "parser/macro_expr_factory.h" #include "parser/macro_registry.h" #include "parser/options.h" +#include "parser/parser_interface.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" @@ -55,6 +59,15 @@ namespace { using ::cel::checker_internal::BuiltinsArena; +absl::Span SortableTypes() { + static const Type kTypes[]{cel::IntType(), cel::UintType(), + cel::DoubleType(), cel::BoolType(), + cel::DurationType(), cel::TimestampType(), + cel::StringType(), cel::BytesType()}; + + return kTypes; +} + // Slow distinct() implementation that uses Equal() to compare values in O(n^2). absl::Status ListDistinctHeterogeneousImpl( const ListValue& list, @@ -577,13 +590,33 @@ absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder) { "slice", MakeMemberOverloadDecl("list_slice", ListTypeParamType(), ListTypeParamType(), IntType(), IntType()))); - // TODO(uncreated-issue/83): Update to specific decls for sortable types. - CEL_ASSIGN_OR_RETURN( - FunctionDecl sort_decl, - MakeFunctionDecl("sort", - MakeMemberOverloadDecl("list_sort", ListTypeParamType(), - ListTypeParamType()))); + static const absl::NoDestructor> kSortableListTypes([] { + std::vector instance; + instance.reserve(SortableTypes().size()); + for (const Type& type : SortableTypes()) { + instance.push_back(ListType(BuiltinsArena(), type)); + } + return instance; + }()); + + FunctionDecl sort_decl; + sort_decl.set_name("sort"); + FunctionDecl sort_by_key_decl; + sort_by_key_decl.set_name("@sortByAssociatedKeys"); + + for (const Type& list_type : *kSortableListTypes) { + std::string elem_type_name(list_type.AsList()->GetElement().name()); + + CEL_RETURN_IF_ERROR(sort_decl.AddOverload(MakeMemberOverloadDecl( + absl::StrCat("list_", elem_type_name, "_sort"), list_type, list_type))); + CEL_RETURN_IF_ERROR(sort_by_key_decl.AddOverload(MakeMemberOverloadDecl( + absl::StrCat("list_", elem_type_name, "_sortByAssociatedKeys"), + ListTypeParamType(), ListTypeParamType(), list_type))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_by_key_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(distinct_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(flatten_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(range_decl))); @@ -591,7 +624,16 @@ absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder) { // defined in strings extension. CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(slice_decl))); - return builder.AddFunction(std::move(sort_decl)); + return absl::OkStatus(); +} + +std::vector lists_macros() { return {ListSortByMacro()}; } + +absl::Status ConfigureParser(ParserBuilder& builder) { + for (const Macro& macro : lists_macros()) { + CEL_RETURN_IF_ERROR(builder.AddMacro(macro)); + } + return absl::OkStatus(); } } // namespace @@ -607,8 +649,6 @@ absl::Status RegisterListsFunctions(FunctionRegistry& registry, return absl::OkStatus(); } -std::vector lists_macros() { return {ListSortByMacro()}; } - absl::Status RegisterListsMacros(MacroRegistry& registry, const ParserOptions&) { return registry.RegisterMacros(lists_macros()); @@ -618,4 +658,10 @@ CheckerLibrary ListsCheckerLibrary() { return {.id = "cel.lib.ext.lists", .configure = RegisterListsCheckerDecls}; } +CompilerLibrary ListsCompilerLibrary() { + auto lib = CompilerLibrary::FromCheckerLibrary(ListsCheckerLibrary()); + lib.configure_parser = ConfigureParser; + return lib; +} + } // namespace cel::extensions diff --git a/extensions/lists_functions.h b/extensions/lists_functions.h index 979360762..a2931e438 100644 --- a/extensions/lists_functions.h +++ b/extensions/lists_functions.h @@ -17,6 +17,7 @@ #include "absl/status/status.h" #include "checker/type_checker_builder.h" +#include "compiler/compiler.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "runtime/function_registry.h" @@ -59,11 +60,31 @@ absl::Status RegisterListsMacros(MacroRegistry& registry, // // .reverse() -> list(T) // -// .sort() -> list(T) +// .sort() -> list(T_) where T_ is partially orderable // // .slice(start: int, end: int) -> list(T) CheckerLibrary ListsCheckerLibrary(); +// Provides decls for the following functions: +// +// lists.range(n: int) -> list(int) +// +// .distinct() -> list(T) +// +// .flatten() -> list(dyn) +// .flatten(limit: int) -> list(dyn) +// +// .reverse() -> list(T) +// +// .sort() -> list(T_) where T_ is partially orderable +// +// .slice(start: int, end: int) -> list(T) +// +// and the following macros: +// +// .sortBy(, ) +CompilerLibrary ListsCompilerLibrary(); + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_ diff --git a/extensions/lists_functions_test.cc b/extensions/lists_functions_test.cc index 7255e9071..cd8a930e4 100644 --- a/extensions/lists_functions_test.cc +++ b/extensions/lists_functions_test.cc @@ -22,13 +22,14 @@ #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" -#include "checker/standard_library.h" +#include "absl/strings/string_view.h" #include "checker/validation_result.h" #include "common/source.h" #include "common/value.h" #include "common/value_testing.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -43,7 +44,6 @@ #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" namespace cel::extensions { namespace { @@ -281,8 +281,8 @@ TEST(ListsFunctionsTest, ListSortByMacroParseError) { } struct ListCheckerTestCase { - const std::string expr; - bool is_valid; + std::string expr; + std::string error_substr; }; class ListsCheckerLibraryTest @@ -291,15 +291,17 @@ class ListsCheckerLibraryTest void SetUp() override { // Arrange: Configure the compiler. // Add the lists checker library to the compiler builder. - ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler_builder, - NewCompilerBuilder(descriptor_pool_)); - ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); - ASSERT_THAT(compiler_builder->AddLibrary(ListsCheckerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), + IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(ListsCompilerLibrary()), IsOk()); + compiler_builder->GetCheckerBuilder().set_container( + "cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build()); } - const google::protobuf::DescriptorPool* descriptor_pool_ = - internal::GetTestingDescriptorPool(); std::unique_ptr compiler_; }; @@ -307,7 +309,12 @@ TEST_P(ListsCheckerLibraryTest, ListsFunctionsTypeCheckerSuccess) { // Act & Assert: Compile the expression and validate the result. ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler_->Compile(GetParam().expr)); - EXPECT_EQ(result.IsValid(), GetParam().is_valid); + absl::string_view error_substr = GetParam().error_substr; + EXPECT_EQ(result.IsValid(), error_substr.empty()); + + if (!error_substr.empty()) { + EXPECT_THAT(result.FormatError(), HasSubstr(error_substr)); + } } // Returns a vector of test cases for the ListsCheckerLibraryTest. @@ -315,40 +322,55 @@ TEST_P(ListsCheckerLibraryTest, ListsFunctionsTypeCheckerSuccess) { std::vector createListsCheckerParams() { return { // lists.distinct() - {R"([1,2,3,4,4].distinct() == [1,2,3,4])", true}, - {R"('abc'.distinct() == [1,2,3,4])", false}, - {R"([1,2,3,4,4].distinct() == 'abc')", false}, - {R"([1,2,3,4,4].distinct(1) == [1,2,3,4])", false}, + {R"([1,2,3,4,4].distinct() == [1,2,3,4])"}, + {R"('abc'.distinct() == [1,2,3,4])", + "no matching overload for 'distinct'"}, + {R"([1,2,3,4,4].distinct() == 'abc')", "no matching overload for '_==_'"}, + {R"([1,2,3,4,4].distinct(1) == [1,2,3,4])", "undeclared reference"}, // lists.flatten() - {R"([1,2,3,4].flatten() == [1,2,3,4])", true}, - {R"([1,2,3,4].flatten(1) == [1,2,3,4])", true}, - {R"('abc'.flatten() == [1,2,3,4])", false}, - {R"([1,2,3,4].flatten() == 'abc')", false}, - {R"('abc'.flatten(1) == [1,2,3,4])", false}, - {R"([1,2,3,4].flatten('abc') == [1,2,3,4])", false}, - {R"([1,2,3,4].flatten(1) == 'abc')", false}, + {R"([1,2,3,4].flatten() == [1,2,3,4])"}, + {R"([1,2,3,4].flatten(1) == [1,2,3,4])"}, + {R"('abc'.flatten() == [1,2,3,4])", "no matching overload for 'flatten'"}, + {R"([1,2,3,4].flatten() == 'abc')", "no matching overload for '_==_'"}, + {R"('abc'.flatten(1) == [1,2,3,4])", + "no matching overload for 'flatten'"}, + {R"([1,2,3,4].flatten('abc') == [1,2,3,4])", + "no matching overload for 'flatten'"}, + {R"([1,2,3,4].flatten(1) == 'abc')", "no matching overload"}, // lists.range() - {R"(lists.range(4) == [0,1,2,3])", true}, - {R"(lists.range('abc') == [])", false}, - {R"(lists.range(4) == 'abc')", false}, - {R"(lists.range(4, 4) == [0,1,2,3])", false}, + {R"(lists.range(4) == [0,1,2,3])"}, + {R"(lists.range('abc') == [])", "no matching overload for 'lists.range'"}, + {R"(lists.range(4) == 'abc')", "no matching overload for '_==_'"}, + {R"(lists.range(4, 4) == [0,1,2,3])", "undeclared reference"}, // lists.reverse() - {R"([1,2,3,4].reverse() == [4,3,2,1])", true}, - {R"('abc'.reverse() == [])", false}, - {R"([1,2,3,4].reverse() == 'abc')", false}, - {R"([1,2,3,4].reverse(1) == [4,3,2,1])", false}, + {R"([1,2,3,4].reverse() == [4,3,2,1])"}, + {R"('abc'.reverse() == [])", "no matching overload for 'reverse'"}, + {R"([1,2,3,4].reverse() == 'abc')", "no matching overload for '_==_'"}, + {R"([1,2,3,4].reverse(1) == [4,3,2,1])", "undeclared reference"}, // lists.slice() - {R"([1,2,3,4].slice(0, 4) == [1,2,3,4])", true}, - {R"('abc'.slice(0, 4) == [1,2,3,4])", false}, - {R"([1,2,3,4].slice('abc', 4) == [1,2,3,4])", false}, - {R"([1,2,3,4].slice(0, 'abc') == [1,2,3,4])", false}, - {R"([1,2,3,4].slice(0, 4) == 'abc')", false}, - {R"([1,2,3,4].slice(0, 2, 3) == [1,2,3,4])", false}, + {R"([1,2,3,4].slice(0, 4) == [1,2,3,4])"}, + {R"('abc'.slice(0, 4) == [1,2,3,4])", "no matching overload for 'slice'"}, + {R"([1,2,3,4].slice('abc', 4) == [1,2,3,4])", + "no matching overload for 'slice'"}, + {R"([1,2,3,4].slice(0, 'abc') == [1,2,3,4])", + "no matching overload for 'slice'"}, + {R"([1,2,3,4].slice(0, 4) == 'abc')", "no matching overload for '_==_'"}, + {R"([1,2,3,4].slice(0, 2, 3) == [1,2,3,4])", "undeclared reference"}, // lists.sort() - {R"([1,2,3,4].sort() == [1,2,3,4])", true}, - {R"('abc'.sort() == [])", false}, - {R"([1,2,3,4].sort() == 'abc')", false}, - {R"([1,2,3,4].sort(2) == [1,2,3,4])", false}, + {R"([1,2,3,4].sort() == [1,2,3,4])"}, + {R"([TestAllTypes{}, TestAllTypes{}].sort() == [])", + "no matching overload for 'sort'"}, + {R"('abc'.sort() == [])", "no matching overload for 'sort'"}, + {R"([1,2,3,4].sort() == 'abc')", "no matching overload for '_==_'"}, + {R"([1,2,3,4].sort(2) == [1,2,3,4])", "undeclared reference"}, + // sortBy macro + {R"([1,2,3,4].sortBy(x, -x) == [4,3,2,1])"}, + {R"([TestAllTypes{}, TestAllTypes{}].sortBy(x, x) == [])", + "no matching overload for '@sortByAssociatedKeys'"}, + {R"( + [TestAllTypes{single_int64: 2}, TestAllTypes{single_int64: 1}] + .sortBy(x, x.single_int64) == + [TestAllTypes{single_int64: 1}, TestAllTypes{single_int64: 2}])"}, }; } diff --git a/extensions/strings.h b/extensions/strings.h index 44f4a997e..c5b7d1d63 100644 --- a/extensions/strings.h +++ b/extensions/strings.h @@ -17,6 +17,7 @@ #include "absl/status/status.h" #include "checker/type_checker_builder.h" +#include "compiler/compiler.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/function_registry.h" @@ -34,6 +35,10 @@ absl::Status RegisterStringsFunctions( CheckerLibrary StringsCheckerLibrary(); +inline CompilerLibrary StringsCompilerLibrary() { + return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary()); +} + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_STRINGS_H_ diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index 8a4ddfbb3..e2eb5e71f 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -26,6 +26,7 @@ #include "common/decl.h" #include "common/value.h" #include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -288,8 +289,8 @@ TEST_P(StringsCheckerLibraryTest, TypeChecks) { const std::string& expr = GetParam(); ASSERT_OK_AND_ASSIGN( auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); - ASSERT_THAT(builder->AddLibrary(StringsCheckerLibrary()), IsOk()); - ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build());