diff --git a/conformance/service.cc b/conformance/service.cc index a6d90c0b0..a0593971d 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -481,8 +481,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { if (enable_optimizations_) { CEL_RETURN_IF_ERROR(cel::extensions::EnableConstantFolding( - builder, constant_memory_manager_, - google::protobuf::MessageFactory::generated_factory())); + builder, google::protobuf::MessageFactory::generated_factory())); } CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver( builder, cel::ReferenceResolverEnabled::kAlways)); @@ -528,7 +527,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { void Check(const conformance::v1alpha1::CheckRequest& request, conformance::v1alpha1::CheckResponse& response) override { - auto status = DoCheck(&constant_arena_, request, response); + auto status = DoCheck(&arena_, request, response); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -614,10 +613,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { bool enable_optimizations) : options_(options), use_arena_(use_arena), - enable_optimizations_(enable_optimizations), - constant_memory_manager_( - use_arena_ ? ProtoMemoryManagerRef(&constant_arena_) - : cel::MemoryManagerRef::ReferenceCounting()) {} + enable_optimizations_(enable_optimizations) {} static absl::Status DoCheck( google::protobuf::Arena* arena, const conformance::v1alpha1::CheckRequest& request, @@ -733,8 +729,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { RuntimeOptions options_; bool use_arena_; bool enable_optimizations_; - Arena constant_arena_; - cel::MemoryManagerRef constant_memory_manager_; + Arena arena_; }; } // namespace diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 396cca677..d31e90451 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -36,7 +36,9 @@ cc_library( "//internal:casts", "//runtime:runtime_options", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", @@ -46,6 +48,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -56,7 +59,6 @@ cc_test( ":flat_expr_builder_extensions", ":resolver", "//base/ast_internal:expr", - "//common:casting", "//common:memory", "//common:native_type", "//common:value", @@ -71,8 +73,12 @@ cc_test( "//runtime:runtime_options", "//runtime:type_registry", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) @@ -115,7 +121,6 @@ cc_library( "//eval/eval:shadowable_value_step", "//eval/eval:ternary_step", "//eval/eval:trace_step", - "//eval/public:cel_type_registry", "//internal:status_macros", "//runtime:function_registry", "//runtime:runtime_issue", @@ -123,13 +128,14 @@ cc_library( "//runtime:type_registry", "//runtime/internal:convert_constant", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -137,6 +143,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -155,6 +162,7 @@ cc_test( ":qualified_reference_resolver", "//base:function", "//base:function_descriptor", + "//common:value", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", @@ -174,13 +182,13 @@ cc_test( "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", - "//extensions/protobuf:memory_manager", "//internal:proto_file_util", "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", "//parser", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -213,6 +221,7 @@ cc_test( "//internal:testing", "//parser", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", @@ -236,14 +245,20 @@ cc_library( "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/public:cel_expression", + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", "//extensions/protobuf:ast_converters", "//internal:status_macros", "//runtime:runtime_issue", "//runtime:runtime_options", + "//runtime/internal:runtime_env", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], @@ -270,12 +285,12 @@ cc_test( "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/public/testing:matchers", "//extensions:bindings_ext", - "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", "//parser", "//parser:macro", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -303,7 +318,7 @@ cc_library( "//base:kind", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", - "//common:allocator", + "//common:memory", "//common:value", "//eval/eval:const_value_step", "//eval/eval:evaluator_core", @@ -331,8 +346,6 @@ cc_test( "//base:ast", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", - "//common:memory", - "//common:type", "//common:value", "//eval/eval:const_value_step", "//eval/eval:create_list_step", @@ -342,12 +355,16 @@ cc_test( "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", + "//internal:testing_descriptor_pool", "//parser", "//runtime:function_registry", "//runtime:runtime_issue", "//runtime:runtime_options", "//runtime:type_registry", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -390,14 +407,12 @@ cc_library( hdrs = ["resolver.h"], deps = [ "//base:kind", - "//common:memory", "//common:type", "//common:value", "//internal:status_macros", "//runtime:function_overload_reference", "//runtime:function_registry", "//runtime:type_registry", - "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -445,17 +460,16 @@ cc_test( ], deps = [ ":cel_expression_builder_flat_impl", - ":flat_expr_builder", + "//base:builtins", "//eval/public:activation", "//eval/public:cel_attribute", - "//eval/public:cel_builtins", "//eval/public:cel_expression", - "//eval/public:cel_options", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", - "//internal:status_macros", "//internal:testing", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -533,6 +547,9 @@ cc_test( "//parser", "//runtime:runtime_issue", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", @@ -581,7 +598,6 @@ cc_test( ":instrumentation", ":regex_precompilation_optimization", "//base/ast_internal:ast_impl", - "//common:type", "//common:value", "//eval/eval:evaluator_core", "//extensions/protobuf:ast_converters", @@ -594,6 +610,9 @@ cc_test( "//runtime:runtime_options", "//runtime:standard_functions", "//runtime:type_registry", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", diff --git a/eval/compiler/cel_expression_builder_flat_impl.h b/eval/compiler/cel_expression_builder_flat_impl.h index 98efc4b74..7b09b7879 100644 --- a/eval/compiler/cel_expression_builder_flat_impl.h +++ b/eval/compiler/cel_expression_builder_flat_impl.h @@ -24,11 +24,17 @@ #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "base/ast.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/public/cel_expression.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -37,13 +43,16 @@ namespace google::api::expr::runtime { // Builds instances of CelExpressionFlatImpl. class CelExpressionBuilderFlatImpl : public CelExpressionBuilder { public: - explicit CelExpressionBuilderFlatImpl(const cel::RuntimeOptions& options) - : flat_expr_builder_(GetRegistry()->InternalGetRegistry(), - *GetTypeRegistry(), options) {} + CelExpressionBuilderFlatImpl( + absl::Nonnull> env, + const cel::RuntimeOptions& options) + : env_(std::move(env)), flat_expr_builder_(env_, options) { + ABSL_DCHECK(env_->IsInitialized()); + } - CelExpressionBuilderFlatImpl() - : flat_expr_builder_(GetRegistry()->InternalGetRegistry(), - *GetTypeRegistry()) {} + explicit CelExpressionBuilderFlatImpl( + absl::Nonnull> env) + : CelExpressionBuilderFlatImpl(std::move(env), cel::RuntimeOptions()) {} absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, @@ -64,15 +73,32 @@ class CelExpressionBuilderFlatImpl : public CelExpressionBuilder { FlatExprBuilder& flat_expr_builder() { return flat_expr_builder_; } void set_container(std::string container) override { - CelExpressionBuilder::set_container(container); flat_expr_builder_.set_container(std::move(container)); } + // CelFunction registry. Extension function should be registered with it + // prior to expression creation. + CelFunctionRegistry* GetRegistry() const override { + return &env_->legacy_function_registry; + } + + // CEL Type registry. Provides a means to resolve the CEL built-in types to + // CelValue instances, and to extend the set of types and enums known to + // expressions by registering them ahead of time. + CelTypeRegistry* GetTypeRegistry() const override { + return &env_->legacy_type_registry; + } + + absl::string_view container() const override { + return flat_expr_builder_.container(); + } + private: absl::StatusOr> CreateExpressionImpl( std::unique_ptr converted_ast, std::vector* warnings) const; + absl::Nonnull> env_; FlatExprBuilder flat_expr_builder_; }; diff --git a/eval/compiler/cel_expression_builder_flat_impl_test.cc b/eval/compiler/cel_expression_builder_flat_impl_test.cc index c70a04396..46212128b 100644 --- a/eval/compiler/cel_expression_builder_flat_impl_test.cc +++ b/eval/compiler/cel_expression_builder_flat_impl_test.cc @@ -44,11 +44,11 @@ #include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "eval/public/testing/matchers.h" #include "extensions/bindings_ext.h" -#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/macro.h" #include "parser/parser.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" @@ -62,6 +62,7 @@ namespace { using ::absl_testing::StatusIs; using ::cel::expr::conformance::proto3::NestedTestAllTypes; using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::CheckedExpr; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; @@ -78,7 +79,7 @@ using ::testing::NotNull; TEST(CelExpressionBuilderFlatImplTest, Error) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid empty expression"))); @@ -87,7 +88,7 @@ TEST(CelExpressionBuilderFlatImplTest, Error) { TEST(CelExpressionBuilderFlatImplTest, ParsedExpr) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, @@ -167,7 +168,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveImpl) { google::protobuf::Arena arena; // Unbounded. options.max_recursion_depth = -1; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(SetupBuilder(builder)); @@ -195,13 +196,12 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) { // Unbounded. options.max_recursion_depth = -1; options.enable_comprehension_list_append = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(SetupBuilder(builder)); builder.flat_expr_builder().AddProgramOptimizer( - cel::runtime_internal::CreateConstantFoldingOptimizer( - cel::extensions::ProtoMemoryManagerRef(&arena))); + cel::runtime_internal::CreateConstantFoldingOptimizer()); builder.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options.regex_max_program_size)); @@ -232,7 +232,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveTraceSupport) { // Unbounded. options.max_recursion_depth = -1; options.enable_recursive_tracing = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(SetupBuilder(builder)); @@ -261,7 +261,7 @@ TEST_P(RecursivePlanTest, Disabled) { google::protobuf::Arena arena; // disabled. options.max_recursion_depth = 0; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(SetupBuilder(builder)); @@ -343,7 +343,7 @@ TEST(CelExpressionBuilderFlatImplTest, ParsedExprWithWarnings) { cel::RuntimeOptions options; options.fail_on_warnings = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector warnings; ASSERT_OK_AND_ASSIGN( @@ -367,7 +367,7 @@ TEST(CelExpressionBuilderFlatImplTest, CheckedExpr) { checked_expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); checked_expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, @@ -387,7 +387,7 @@ TEST(CelExpressionBuilderFlatImplTest, CheckedExprWithWarnings) { cel::RuntimeOptions options; options.fail_on_warnings = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector warnings; ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index faf0b0387..73eccad0e 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -29,7 +29,7 @@ #include "base/builtins.h" #include "base/kind.h" #include "base/type_provider.h" -#include "common/allocator.h" +#include "common/memory.h" #include "common/value.h" #include "common/value_manager.h" #include "eval/compiler/flat_expr_builder_extensions.h" @@ -39,6 +39,7 @@ #include "internal/status_macros.h" #include "runtime/activation.h" #include "runtime/internal/convert_constant.h" +#include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel::runtime_internal { @@ -73,13 +74,18 @@ using ::google::api::expr::runtime::Resolver; class ConstantFoldingExtension : public ProgramOptimizer { public: ConstantFoldingExtension( - Allocator<> allocator, - absl::Nullable message_factory, + absl::Nullable> shared_arena, + absl::Nonnull arena, + absl::Nullable> + shared_message_factory, + absl::Nonnull message_factory, const TypeProvider& type_provider) - : memory_manager_(allocator), + : shared_arena_(std::move(shared_arena)), + arena_(arena), + shared_message_factory_(std::move(shared_message_factory)), + message_factory_(message_factory), state_(kDefaultStackLimit, kComprehensionSlotCount, type_provider, - MemoryManager(allocator)), - message_factory_(message_factory) {} + MemoryManager::Pooling(arena)) {} absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context, const Expr& node) override; @@ -99,12 +105,15 @@ class ConstantFoldingExtension : public ProgramOptimizer { // if the comprehension variables are only used in a const way. static constexpr size_t kComprehensionSlotCount = 0; - MemoryManager memory_manager_; + absl::Nullable> shared_arena_; + ABSL_ATTRIBUTE_UNUSED + absl::Nonnull arena_; + absl::Nullable> + shared_message_factory_; + ABSL_ATTRIBUTE_UNUSED + absl::Nonnull message_factory_; Activation empty_; FlatExpressionEvaluatorState state_; - // Not yet used, will be in future. - ABSL_ATTRIBUTE_UNUSED - absl::Nullable message_factory_; std::vector is_const_; }; @@ -254,13 +263,29 @@ absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, } // namespace ProgramOptimizerFactory CreateConstantFoldingOptimizer( - Allocator<> allocator, - absl::Nullable message_factory) { - return [allocator, message_factory](PlannerContext& ctx, const AstImpl&) - -> absl::StatusOr> { - return std::make_unique( - allocator, message_factory, ctx.value_factory().type_provider()); - }; + absl::Nullable> arena, + absl::Nullable> message_factory) { + return + [shared_arena = std::move(arena), + shared_message_factory = std::move(message_factory)]( + PlannerContext& context, + const AstImpl&) -> absl::StatusOr> { + // If one was explicitly provided during planning or none was explicitly + // provided during configuration, request one from the planning context. + // Otherwise use the one provided during configuration. + absl::Nonnull arena = + context.HasExplicitArena() || shared_arena == nullptr + ? context.MutableArena() + : shared_arena.get(); + absl::Nonnull message_factory = + context.HasExplicitMessageFactory() || + shared_message_factory == nullptr + ? context.MutableMessageFactory() + : shared_message_factory.get(); + return std::make_unique( + shared_arena, arena, shared_message_factory, message_factory, + context.type_reflector()); + }; } } // namespace cel::runtime_internal diff --git a/eval/compiler/constant_folding.h b/eval/compiler/constant_folding.h index a69df01a3..532ba2b4b 100644 --- a/eval/compiler/constant_folding.h +++ b/eval/compiler/constant_folding.h @@ -15,9 +15,11 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ +#include + #include "absl/base/nullability.h" -#include "common/allocator.h" #include "eval/compiler/flat_expr_builder_extensions.h" +#include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel::runtime_internal { @@ -31,8 +33,9 @@ namespace cel::runtime_internal { // extension. google::api::expr::runtime::ProgramOptimizerFactory CreateConstantFoldingOptimizer( - Allocator<> allocator, - absl::Nullable message_factory = nullptr); + absl::Nullable> arena = nullptr, + absl::Nullable> message_factory = + nullptr); } // namespace cel::runtime_internal diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index 7aafa7442..fcecf1297 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -18,17 +18,14 @@ #include #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/ast.h" #include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" -#include "common/memory.h" -#include "common/type_factory.h" -#include "common/type_manager.h" #include "common/value.h" -#include "common/value_manager.h" #include "common/values/legacy_value_manager.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" @@ -40,9 +37,12 @@ #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" #include "parser/parser.h" #include "runtime/function_registry.h" #include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" @@ -58,6 +58,7 @@ using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::Expr; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::runtime_internal::IssueCollector; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::CreateConstValueStep; @@ -74,16 +75,20 @@ using ::testing::SizeIs; class UpdatedConstantFoldingTest : public testing::Test { public: UpdatedConstantFoldingTest() - : value_factory_(ProtoMemoryManagerRef(&arena_), + : env_(NewTestingRuntimeEnv()), + function_registry_(env_->function_registry), + type_registry_(env_->type_registry), + value_factory_(ProtoMemoryManagerRef(&arena_), type_registry_.GetComposedTypeProvider()), issue_collector_(RuntimeIssue::Severity::kError), resolver_("", function_registry_, type_registry_, value_factory_, type_registry_.resolveable_enums()) {} protected: + absl::Nonnull> env_; google::protobuf::Arena arena_; - cel::FunctionRegistry function_registry_; - cel::TypeRegistry type_registry_; + cel::FunctionRegistry& function_registry_; + cel::TypeRegistry& type_registry_; cel::common_internal::LegacyValueManager value_factory_; cel::RuntimeOptions options_; IssueCollector issue_collector_; @@ -143,12 +148,12 @@ TEST_F(UpdatedConstantFoldingTest, SkipsTernary) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&call); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -205,12 +210,12 @@ TEST_F(UpdatedConstantFoldingTest, SkipsOr) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&call); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -264,12 +269,12 @@ TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&call); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -320,12 +325,12 @@ TEST_F(UpdatedConstantFoldingTest, CreatesList) { program_builder.ExitSubexpression(&create_list); // Insert the list creation step - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -377,12 +382,12 @@ TEST_F(UpdatedConstantFoldingTest, CreatesMap) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&create_map); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -435,12 +440,12 @@ TEST_F(UpdatedConstantFoldingTest, CreatesInvalidMap) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&create_map); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -494,12 +499,12 @@ TEST_F(UpdatedConstantFoldingTest, ErrorsOnUnexpectedOrder) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&call); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act / Assert ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 1bd7c205b..fcb4d5c44 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -82,6 +82,7 @@ #include "runtime/internal/issue_collector.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -2128,23 +2129,20 @@ std::vector FlattenExpressionTable( absl::StatusOr FlatExprBuilder::CreateExpressionImpl( std::unique_ptr ast, std::vector* issues) const { - // These objects are expected to remain scoped to one build call -- references - // to them shouldn't be persisted in any part of the result expression. - cel::common_internal::LegacyValueManager value_factory( - cel::MemoryManagerRef::ReferenceCounting(), - type_registry_.GetComposedTypeProvider()); - RuntimeIssue::Severity max_severity = options_.fail_on_warnings ? RuntimeIssue::Severity::kWarning : RuntimeIssue::Severity::kError; IssueCollector issue_collector(max_severity); Resolver resolver(container_, function_registry_, type_registry_, - value_factory, type_registry_.resolveable_enums(), + type_registry_.GetComposedTypeProvider(), + type_registry_.resolveable_enums(), options_.enable_qualified_type_identifiers); + std::shared_ptr arena; ProgramBuilder program_builder; - PlannerContext extension_context(resolver, options_, value_factory, - issue_collector, program_builder); + PlannerContext extension_context(env_, resolver, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector, program_builder, arena); auto& ast_impl = AstImpl::CastFromPublicAst(*ast); @@ -2166,6 +2164,11 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( } } + // These objects are expected to remain scoped to one build call -- references + // to them shouldn't be persisted in any part of the result expression. + cel::common_internal::LegacyValueManager value_factory( + cel::MemoryManagerRef::ReferenceCounting(), + type_registry_.GetComposedTypeProvider()); FlatExprVisitor visitor(resolver, options_, std::move(optimizers), ast_impl.reference_map(), value_factory, issue_collector, program_builder, extension_context, @@ -2187,9 +2190,15 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( std::vector subexpressions = FlattenExpressionTable(program_builder, execution_path); + if (arena != nullptr && arena->SpaceUsed() == 0) { + // Arena was requested but no memory was used. Destroy it. + arena.reset(); + } + return FlatExpression(std::move(execution_path), std::move(subexpressions), visitor.slot_count(), - type_registry_.GetComposedTypeProvider(), options_); + type_registry_.GetComposedTypeProvider(), options_, + std::move(arena)); } } // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index f1081d5c4..eafb58781 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -22,12 +22,14 @@ #include #include +#include "absl/base/nullability.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "base/ast.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/cel_type_registry.h" #include "runtime/function_registry.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" @@ -38,29 +40,28 @@ namespace google::api::expr::runtime { // Builds instances of CelExpressionFlatImpl. class FlatExprBuilder { public: - FlatExprBuilder(const cel::FunctionRegistry& function_registry, - const CelTypeRegistry& type_registry, - const cel::RuntimeOptions& options) - : options_(options), + FlatExprBuilder( + absl::Nonnull> + env, + const cel::RuntimeOptions& options) + : env_(std::move(env)), + options_(options), container_(options.container), - function_registry_(function_registry), - type_registry_(type_registry.InternalGetModernRegistry()) {} - - FlatExprBuilder(const cel::FunctionRegistry& function_registry, - const cel::TypeRegistry& type_registry, - const cel::RuntimeOptions& options) - : options_(options), + function_registry_(env_->function_registry), + type_registry_(env_->type_registry) {} + + FlatExprBuilder( + absl::Nonnull> + env, + const cel::FunctionRegistry& function_registry, + const cel::TypeRegistry& type_registry, + const cel::RuntimeOptions& options) + : env_(std::move(env)), + options_(options), container_(options.container), function_registry_(function_registry), type_registry_(type_registry) {} - // Create a flat expr builder with defaulted options. - FlatExprBuilder(const cel::FunctionRegistry& function_registry, - const CelTypeRegistry& type_registry) - : options_(cel::RuntimeOptions()), - function_registry_(function_registry), - type_registry_(type_registry.InternalGetModernRegistry()) {} - void AddAstTransform(std::unique_ptr transform) { ast_transforms_.push_back(std::move(transform)); } @@ -73,12 +74,16 @@ class FlatExprBuilder { container_ = std::move(container); } + absl::string_view container() const { return container_; } + // TODO: Add overload for cref AST. At the moment, all the users // can pass ownership of a freshly converted AST. absl::StatusOr CreateExpressionImpl( std::unique_ptr ast, std::vector* issues) const; + const cel::runtime_internal::RuntimeEnv& env() const { return *env_; } + const cel::RuntimeOptions& options() const { return options_; } // Called by `cel::extensions::EnableOptionalTypes` to indicate that special @@ -86,6 +91,8 @@ class FlatExprBuilder { void enable_optional_types() { enable_optional_types_ = true; } private: + const absl::Nonnull> + env_; cel::RuntimeOptions options_; std::string container_; bool enable_optional_types_ = false; diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc index 4b9ff2b8c..9d46d8dd8 100644 --- a/eval/compiler/flat_expr_builder_comprehensions_test.cc +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -34,6 +34,7 @@ #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" #include "parser/parser.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" @@ -43,6 +44,7 @@ namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::CheckedExpr; using ::cel::expr::ParsedExpr; using ::testing::HasSubstr; @@ -66,7 +68,7 @@ class CelExpressionBuilderFlatImplComprehensionsTest TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, NestedComp) { cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[1, 2].filter(x, [3, 4].all(y, x < y))")); @@ -84,7 +86,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, NestedComp) { TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, MapComp) { cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[1, 2].map(x, x * 2)")); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -105,7 +107,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, MapComp) { TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneTrue) { cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[7].exists_one(a, a == 7)")); @@ -122,7 +124,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneTrue) { TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneFalse) { cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[7, 7].exists_one(a, a == 7)")); @@ -140,7 +142,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneFalse) { TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ListCompWithUnknowns) { cel::RuntimeOptions options = GetRuntimeOptions(); options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("items.exists(i, i < 0)")); @@ -203,7 +205,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, })pb", &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -256,7 +258,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -300,7 +302,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -357,7 +359,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -425,7 +427,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -472,7 +474,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -524,7 +526,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -571,7 +573,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -614,7 +616,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr)); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT( diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index 10f5513ce..f7d46de0e 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -27,6 +27,7 @@ #include #include +#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" @@ -39,6 +40,7 @@ #include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" #include "common/native_type.h" +#include "common/type_reflector.h" #include "common/value.h" #include "common/value_manager.h" #include "eval/compiler/resolver.h" @@ -47,7 +49,10 @@ #include "eval/eval/trace_step.h" #include "internal/casts.h" #include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" namespace google::api::expr::runtime { @@ -321,16 +326,35 @@ const Subclass* TryDowncastDirectStep(const DirectExpressionStep* step) { // Class representing FlatExpr internals exposed to extensions. class PlannerContext { public: - explicit PlannerContext( + PlannerContext( + std::shared_ptr environment, const Resolver& resolver, const cel::RuntimeOptions& options, - cel::ValueManager& value_factory, + cel::ValueManager& value_manager, cel::runtime_internal::IssueCollector& issue_collector, - ProgramBuilder& program_builder) - : resolver_(resolver), - value_factory_(value_factory), + ProgramBuilder& program_builder, + std::shared_ptr& arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::shared_ptr message_factory = nullptr) + : PlannerContext(std::move(environment), resolver, options, + value_manager.type_provider(), issue_collector, + program_builder, arena, std::move(message_factory)) {} + + PlannerContext( + std::shared_ptr environment, + const Resolver& resolver, const cel::RuntimeOptions& options, + const cel::TypeReflector& type_reflector, + cel::runtime_internal::IssueCollector& issue_collector, + ProgramBuilder& program_builder, + std::shared_ptr& arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::shared_ptr message_factory = nullptr) + : environment_(std::move(environment)), + resolver_(resolver), + type_reflector_(type_reflector), options_(options), issue_collector_(issue_collector), - program_builder_(program_builder) {} + program_builder_(program_builder), + arena_(arena), + explicit_arena_(arena_ != nullptr), + message_factory_(std::move(message_factory)) {} ProgramBuilder& program_builder() { return program_builder_; } @@ -374,18 +398,42 @@ class PlannerContext { std::unique_ptr step); const Resolver& resolver() const { return resolver_; } - cel::ValueManager& value_factory() const { return value_factory_; } + const cel::TypeReflector& type_reflector() const { return type_reflector_; } const cel::RuntimeOptions& options() const { return options_; } cel::runtime_internal::IssueCollector& issue_collector() { return issue_collector_; } + // Returns `true` if an arena was explicitly provided during planning. + bool HasExplicitArena() const { return explicit_arena_; } + + absl::Nonnull MutableArena() { + if (!explicit_arena_ && arena_ == nullptr) { + arena_ = std::make_shared(); + } + ABSL_DCHECK(arena_ != nullptr); + return arena_.get(); + } + + // Returns `true` if a message factory was explicitly provided during + // planning. + bool HasExplicitMessageFactory() const { return message_factory_ != nullptr; } + + absl::Nonnull MutableMessageFactory() { + return HasExplicitMessageFactory() ? message_factory_.get() + : environment_->MutableMessageFactory(); + } + private: + const std::shared_ptr environment_; const Resolver& resolver_; - cel::ValueManager& value_factory_; + const cel::TypeReflector& type_reflector_; const cel::RuntimeOptions& options_; cel::runtime_internal::IssueCollector& issue_collector_; ProgramBuilder& program_builder_; + std::shared_ptr& arena_; + const bool explicit_arena_; + const std::shared_ptr message_factory_; }; // Interface for Ast Transforms. diff --git a/eval/compiler/flat_expr_builder_extensions_test.cc b/eval/compiler/flat_expr_builder_extensions_test.cc index 1374cdfbf..c3b22c5ca 100644 --- a/eval/compiler/flat_expr_builder_extensions_test.cc +++ b/eval/compiler/flat_expr_builder_extensions_test.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "eval/compiler/flat_expr_builder_extensions.h" +#include #include +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "base/ast_internal/expr.h" @@ -31,9 +33,12 @@ #include "internal/testing.h" #include "runtime/function_registry.h" #include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { @@ -42,6 +47,8 @@ using ::absl_testing::StatusIs; using ::cel::RuntimeIssue; using ::cel::ast_internal::Expr; using ::cel::runtime_internal::IssueCollector; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::testing::ElementsAre; using ::testing::IsEmpty; using ::testing::Optional; @@ -51,8 +58,9 @@ using Subexpression = ProgramBuilder::Subexpression; class PlannerContextTest : public testing::Test { public: PlannerContextTest() - : type_registry_(), - function_registry_(), + : env_(NewTestingRuntimeEnv()), + type_registry_(env_->type_registry), + function_registry_(env_->function_registry), value_factory_(cel::MemoryManagerRef::ReferenceCounting(), type_registry_.GetComposedTypeProvider()), resolver_("", function_registry_, type_registry_, value_factory_, @@ -60,8 +68,9 @@ class PlannerContextTest : public testing::Test { issue_collector_(RuntimeIssue::Severity::kError) {} protected: - cel::TypeRegistry type_registry_; - cel::FunctionRegistry function_registry_; + absl::Nonnull> env_; + cel::TypeRegistry& type_registry_; + cel::FunctionRegistry& function_registry_; cel::RuntimeOptions options_; cel::common_internal::LegacyValueManager value_factory_; Resolver resolver_; @@ -117,8 +126,9 @@ TEST_F(PlannerContextTest, GetPlan) { ASSERT_OK_AND_ASSIGN( auto step_ptrs, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(step_ptrs.b))); @@ -142,8 +152,9 @@ TEST_F(PlannerContextTest, ReplacePlan) { ASSERT_OK_AND_ASSIGN( auto step_ptrs, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(step_ptrs.b), UniquePtrHolds(step_ptrs.c), @@ -172,8 +183,9 @@ TEST_F(PlannerContextTest, ExtractPlan) { ASSERT_OK_AND_ASSIGN(auto plan_steps, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_TRUE(context.IsSubplanInspectable(a)); EXPECT_TRUE(context.IsSubplanInspectable(b)); @@ -191,8 +203,9 @@ TEST_F(PlannerContextTest, ExtractFailsOnReplacedNode) { ASSERT_OK(InitSimpleTree(a, b, c, value_factory_, program_builder).status()); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); ASSERT_OK(context.ReplaceSubplan(a, {})); @@ -208,8 +221,9 @@ TEST_F(PlannerContextTest, ReplacePlanUpdatesParent) { ASSERT_OK_AND_ASSIGN(auto plan_steps, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_TRUE(context.IsSubplanInspectable(a)); @@ -229,8 +243,9 @@ TEST_F(PlannerContextTest, ReplacePlanUpdatesSibling) { ASSERT_OK_AND_ASSIGN(auto plan_steps, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); ExecutionPath new_b; @@ -263,8 +278,9 @@ TEST_F(PlannerContextTest, ReplacePlanFailsOnUpdatedNode) { ASSERT_OK_AND_ASSIGN(auto plan_steps, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(plan_steps.b), UniquePtrHolds(plan_steps.c), @@ -289,8 +305,9 @@ TEST_F(PlannerContextTest, AddSubplanStep) { const ExpressionStep* b2_step_ptr = b2_step.get(); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); ASSERT_OK(context.AddSubplanStep(b, std::move(b2_step))); @@ -315,8 +332,9 @@ TEST_F(PlannerContextTest, AddSubplanStepFailsOnUnknownNode) { ASSERT_OK_AND_ASSIGN(auto b2_step, CreateConstValueStep(value_factory_.GetNullValue(), -1)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_THAT(context.GetSubplan(d), IsEmpty()); diff --git a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc index b7bed3655..afe7c5f9f 100644 --- a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc +++ b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc @@ -2,27 +2,28 @@ // produce expressions with the same outputs. #include -#include "google/protobuf/text_format.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" +#include "base/builtins.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" -#include "eval/compiler/flat_expr_builder.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" -#include "eval/public/cel_builtins.h" #include "eval/public/cel_expression.h" -#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" -#include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" namespace google::api::expr::runtime { namespace { +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::Expr; using ::testing::Eq; using ::testing::SizeIs; @@ -104,7 +105,8 @@ class ShortCircuitingTest : public testing::TestWithParam { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; } - auto result = std::make_unique(options); + auto result = std::make_unique( + NewTestingRuntimeEnv(), options); return result; } }; @@ -114,7 +116,7 @@ TEST_P(ShortCircuitingTest, BasicAnd) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kAnd), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr)); auto builder = GetBuilder(); activation.InsertValue("var1", CelValue::CreateBool(true)); @@ -142,7 +144,7 @@ TEST_P(ShortCircuitingTest, BasicOr) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kOr), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr)); auto builder = GetBuilder(); activation.InsertValue("var1", CelValue::CreateBool(false)); @@ -170,7 +172,7 @@ TEST_P(ShortCircuitingTest, ErrorAnd) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kAnd), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr)); auto builder = GetBuilder(); absl::Status error = absl::InternalError("error"); @@ -200,7 +202,7 @@ TEST_P(ShortCircuitingTest, ErrorOr) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kOr), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr)); auto builder = GetBuilder(); absl::Status error = absl::InternalError("error"); @@ -230,7 +232,7 @@ TEST_P(ShortCircuitingTest, UnknownAnd) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kAnd), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr)); auto builder = GetBuilder(/* enable_unknowns=*/true); absl::Status error = absl::InternalError("error"); @@ -262,7 +264,7 @@ TEST_P(ShortCircuitingTest, UnknownOr) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kOr), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr)); auto builder = GetBuilder(/* enable_unknowns=*/true); absl::Status error = absl::InternalError("error"); diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 488f81a8d..ad0664777 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -33,6 +33,7 @@ #include "absl/types/span.h" #include "base/function.h" #include "base/function_descriptor.h" +#include "common/value.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/qualified_reference_resolver.h" @@ -55,12 +56,12 @@ #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" -#include "extensions/protobuf/memory_manager.h" #include "internal/proto_file_util.h" #include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/descriptor.h" @@ -75,9 +76,9 @@ namespace { using ::absl_testing::StatusIs; using ::cel::Value; using ::cel::expr::conformance::proto3::TestAllTypes; -using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::internal::test::EqualsProto; using ::cel::internal::test::ReadBinaryProtoFromFile; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::CheckedExpr; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; @@ -150,7 +151,7 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value"); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK( builder.GetRegistry()->Register(std::make_unique())); @@ -172,7 +173,7 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { TEST(FlatExprBuilderTest, ExprUnset) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid empty expression"))); @@ -181,7 +182,7 @@ TEST(FlatExprBuilderTest, ExprUnset) { TEST(FlatExprBuilderTest, ConstValueUnset) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); // Create an empty constant expression to ensure that it triggers an error. expr.mutable_const_expr(); @@ -193,7 +194,7 @@ TEST(FlatExprBuilderTest, ConstValueUnset) { TEST(FlatExprBuilderTest, MapKeyValueUnset) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); // Don't set either the key or the value for the map creation step. auto* entry = expr.mutable_struct_expr()->add_entries(); @@ -211,7 +212,7 @@ TEST(FlatExprBuilderTest, MapKeyValueUnset) { TEST(FlatExprBuilderTest, MessageFieldValueUnset) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -235,7 +236,7 @@ TEST(FlatExprBuilderTest, MessageFieldValueUnset) { TEST(FlatExprBuilderTest, BinaryCallTooManyArguments) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); auto* call = expr.mutable_call_expr(); call->set_function(builtin::kAnd); @@ -261,7 +262,7 @@ TEST(FlatExprBuilderTest, TernaryCallTooManyArguments) { { cel::RuntimeOptions options; options.short_circuiting = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -272,7 +273,7 @@ TEST(FlatExprBuilderTest, TernaryCallTooManyArguments) { { cel::RuntimeOptions options; options.short_circuiting = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -294,7 +295,7 @@ TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { cel::RuntimeOptions options; options.fail_on_warnings = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector warnings; // Concat function not registered. @@ -338,7 +339,7 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { { cel::RuntimeOptions options; options.short_circuiting = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count1 = 0; @@ -361,7 +362,7 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { { cel::RuntimeOptions options; options.short_circuiting = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count1 = 0; @@ -409,7 +410,7 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { { cel::RuntimeOptions options; options.short_circuiting = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count = 0; @@ -427,7 +428,7 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { { cel::RuntimeOptions options; options.short_circuiting = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count = 0; @@ -446,7 +447,7 @@ TEST(FlatExprBuilderTest, IdentExprUnsetName) { // An empty ident without the name set should error. google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -462,7 +463,7 @@ TEST(FlatExprBuilderTest, SelectExprUnsetField) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -474,7 +475,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuVar) { SourceInfo source_info; // An empty ident without the name set should error. google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -489,7 +490,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetIterVar) { comprehension_expr{accu_var: "a"} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -506,7 +507,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuInit) { iter_var: "b"} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -526,7 +527,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { }} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -549,7 +550,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { }} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -575,7 +576,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { }} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -625,7 +626,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -657,7 +658,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); builder.set_container(".bad"); @@ -673,7 +674,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupport) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); using FunctionAdapterT = FunctionAdapter; @@ -703,7 +704,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupport) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportWithContainer) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("XOr(a, b)")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("ext"); @@ -733,7 +734,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportWithContainer) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrder) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); @@ -760,7 +761,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrder) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderParentContainer) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); @@ -787,7 +788,7 @@ TEST(FlatExprBuilderTest, TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderExplicitGlobal) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(".c.d.Get()")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); @@ -813,7 +814,7 @@ TEST(FlatExprBuilderTest, TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderReceiverCall) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("e.Get()")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); @@ -842,7 +843,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); cel::RuntimeOptions options; options.fail_on_warnings = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector build_warnings; builder.set_container("ext"); using FunctionAdapterT = FunctionAdapter; @@ -888,7 +889,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); @@ -948,7 +949,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -1017,7 +1018,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); builder.set_container("com.foo"); @@ -1085,7 +1086,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -1150,13 +1151,12 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); google::protobuf::Arena arena; - auto memory_manager = ProtoMemoryManagerRef(&arena); builder.flat_expr_builder().AddProgramOptimizer( - cel::runtime_internal::CreateConstantFoldingOptimizer(memory_manager)); + cel::runtime_internal::CreateConstantFoldingOptimizer()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); @@ -1239,7 +1239,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1310,7 +1310,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1362,7 +1362,7 @@ TEST(FlatExprBuilderTest, ComprehensionBudget) { cel::RuntimeOptions options; options.comprehension_max_iterations = 1; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1392,7 +1392,7 @@ TEST(FlatExprBuilderTest, SimpleEnumTest) { cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1414,7 +1414,7 @@ TEST(FlatExprBuilderTest, SimpleEnumIdentTest) { Expr* cur_expr = &expr; cur_expr->mutable_ident_expr()->set_name(enum_name); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1431,7 +1431,7 @@ TEST(FlatExprBuilderTest, ContainerStringFormat) { SourceInfo source_info; expr.mutable_ident_expr()->set_name("ident"); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.set_container(""); ASSERT_OK(builder.CreateExpression(&expr, &source_info)); @@ -1469,7 +1469,7 @@ void EvalExpressionWithEnum(absl::string_view enum_name, cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); builder.GetTypeRegistry()->Register(TestEnum_descriptor()); builder.set_container(std::string(container)); @@ -1552,7 +1552,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1596,7 +1596,7 @@ TEST(FlatExprBuilderTest, RepeatedFieldPresence) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1639,7 +1639,7 @@ absl::Status RunTernaryExpression(CelValue selector, CelValue value1, auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value2"); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); CEL_ASSIGN_OR_RETURN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1668,7 +1668,7 @@ TEST(FlatExprBuilderTest, Ternary) { auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value1"); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1768,7 +1768,7 @@ TEST(FlatExprBuilderTest, EmptyCallList) { SourceInfo source_info; auto call_expr = expr.mutable_call_expr(); call_expr->set_function(op); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); auto build = builder.CreateExpression(&expr, &source_info); ASSERT_FALSE(build.ok()); @@ -1782,7 +1782,7 @@ TEST(FlatExprBuilderTest, HeterogeneousListsAllowed) { parser::Parse("[17, 'seventeen']")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), @@ -1812,7 +1812,7 @@ TEST(FlatExprBuilderTest, NullUnboxingEnabled) { parser::Parse("message.int32_wrapper_value")); cel::RuntimeOptions options; options.enable_empty_wrapper_null_unboxing = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -1833,7 +1833,7 @@ TEST(FlatExprBuilderTest, TypeResolve) { parser::Parse("type(message) == runtime.TestMessage")); cel::RuntimeOptions options; options.enable_qualified_type_identifiers = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1861,7 +1861,7 @@ TEST(FlatExprBuilderTest, AnyPackingList) { parser::Parse("TestAllTypes{single_any: [1, 2, 3]}")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1896,7 +1896,7 @@ TEST(FlatExprBuilderTest, AnyPackingNestedNumbers) { parser::Parse("TestAllTypes{single_any: [1, 2.3]}")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1929,7 +1929,7 @@ TEST(FlatExprBuilderTest, AnyPackingInt) { parser::Parse("TestAllTypes{single_any: 1}")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1961,7 +1961,7 @@ TEST(FlatExprBuilderTest, AnyPackingMap) { parser::Parse("TestAllTypes{single_any: {'key': 'value'}}")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1996,7 +1996,7 @@ TEST(FlatExprBuilderTest, NullUnboxingDisabled) { parser::Parse("message.int32_wrapper_value")); cel::RuntimeOptions options; options.enable_empty_wrapper_null_unboxing = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -2016,7 +2016,7 @@ TEST(FlatExprBuilderTest, HeterogeneousEqualityEnabled) { parser::Parse("{1: 2, 2u: 3}[1.0]")); cel::RuntimeOptions options; options.enable_heterogeneous_equality = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -2034,7 +2034,7 @@ TEST(FlatExprBuilderTest, HeterogeneousEqualityDisabled) { parser::Parse("{1: 2, 2u: 3}[1.0]")); cel::RuntimeOptions options; options.enable_heterogeneous_equality = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -2056,7 +2056,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // This time, the message is unknown. We only have the proto as data, we did // not link the generated message, so it's not included in the generated pool. - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -2079,7 +2079,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // This time, the message is *known*. We are using a custom descriptor pool // that has been primed with the relevant message. - CelExpressionBuilderFlatImpl builder2; + CelExpressionBuilderFlatImpl builder2(NewTestingRuntimeEnv()); builder2.GetTypeRegistry()->RegisterTypeProvider( std::make_unique(&desc_pool, &message_factory)); @@ -2121,7 +2121,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForSelect) { // The since this is access only, the evaluator will work with message duck // typing. - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -2170,7 +2170,7 @@ TEST_P(CustomDescriptorPoolTest, TestType) { ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique(&descriptor_pool, &message_factory)); @@ -2408,7 +2408,7 @@ TEST(FlatExprBuilderTest, BlockBadIndex) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("bad @index"))); @@ -2430,7 +2430,7 @@ TEST(FlatExprBuilderTest, OutOfRangeBlockIndex) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2451,7 +2451,7 @@ TEST(FlatExprBuilderTest, EarlyBlockIndex) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2466,7 +2466,7 @@ TEST(FlatExprBuilderTest, OutOfScopeCSE) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2482,7 +2482,7 @@ TEST(FlatExprBuilderTest, BlockMissingBindings) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2503,7 +2503,7 @@ TEST(FlatExprBuilderTest, BlockMissingExpression) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2524,7 +2524,7 @@ TEST(FlatExprBuilderTest, BlockNotListOfBoundExpressions) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2546,7 +2546,7 @@ TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs( @@ -2574,7 +2574,7 @@ TEST(FlatExprBuilderTest, BlockOptionalListOfBoundExpressions) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2608,7 +2608,7 @@ TEST(FlatExprBuilderTest, BlockNested) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, diff --git a/eval/compiler/instrumentation_test.cc b/eval/compiler/instrumentation_test.cc index beb94fe2c..78b2ba59b 100644 --- a/eval/compiler/instrumentation_test.cc +++ b/eval/compiler/instrumentation_test.cc @@ -15,14 +15,15 @@ #include "eval/compiler/instrumentation.h" #include +#include #include #include #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "base/ast_internal/ast_impl.h" -#include "common/type.h" #include "common/value.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" @@ -34,6 +35,8 @@ #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/function_registry.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" #include "runtime/standard_functions.h" @@ -45,6 +48,8 @@ namespace { using ::cel::IntValue; using ::cel::Value; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; @@ -54,7 +59,10 @@ using ::testing::UnorderedElementsAre; class InstrumentationTest : public ::testing::Test { public: InstrumentationTest() - : managed_value_factory_( + : env_(NewTestingRuntimeEnv()), + function_registry_(env_->function_registry), + type_registry_(env_->type_registry), + managed_value_factory_( type_registry_.GetComposedTypeProvider(), cel::extensions::ProtoMemoryManagerRef(&arena_)) {} void SetUp() override { @@ -62,9 +70,10 @@ class InstrumentationTest : public ::testing::Test { } protected: + absl::Nonnull> env_; cel::RuntimeOptions options_; - cel::FunctionRegistry function_registry_; - cel::TypeRegistry type_registry_; + cel::FunctionRegistry& function_registry_; + cel::TypeRegistry& type_registry_; google::protobuf::Arena arena_; cel::ManagedValueFactory managed_value_factory_; }; @@ -76,7 +85,7 @@ MATCHER_P(IsIntValue, expected, "") { } TEST_F(InstrumentationTest, Basic) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); std::vector expr_ids; Instrumentation expr_id_recorder = @@ -114,7 +123,7 @@ TEST_F(InstrumentationTest, Basic) { } TEST_F(InstrumentationTest, BasicWithConstFolding) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); absl::flat_hash_map expr_id_to_value; Instrumentation expr_id_recorder = [&expr_id_to_value]( @@ -124,8 +133,7 @@ TEST_F(InstrumentationTest, BasicWithConstFolding) { return absl::OkStatus(); }; builder.AddProgramOptimizer( - cel::runtime_internal::CreateConstantFoldingOptimizer( - managed_value_factory_.get().GetMemoryManager())); + cel::runtime_internal::CreateConstantFoldingOptimizer()); builder.AddProgramOptimizer(CreateInstrumentationExtension( [=](const cel::ast_internal::AstImpl&) -> Instrumentation { return expr_id_recorder; @@ -161,7 +169,7 @@ TEST_F(InstrumentationTest, BasicWithConstFolding) { } TEST_F(InstrumentationTest, AndShortCircuit) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); std::vector expr_ids; Instrumentation expr_id_recorder = @@ -206,7 +214,7 @@ TEST_F(InstrumentationTest, AndShortCircuit) { } TEST_F(InstrumentationTest, OrShortCircuit) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); std::vector expr_ids; Instrumentation expr_id_recorder = @@ -251,7 +259,7 @@ TEST_F(InstrumentationTest, OrShortCircuit) { } TEST_F(InstrumentationTest, Ternary) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); std::vector expr_ids; Instrumentation expr_id_recorder = @@ -304,7 +312,7 @@ TEST_F(InstrumentationTest, Ternary) { } TEST_F(InstrumentationTest, OptimizedStepsNotEvaluated) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); builder.AddProgramOptimizer(CreateRegexPrecompilationExtension(0)); @@ -340,7 +348,7 @@ TEST_F(InstrumentationTest, OptimizedStepsNotEvaluated) { } TEST_F(InstrumentationTest, NoopSkipped) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); builder.AddProgramOptimizer(CreateInstrumentationExtension( [=](const cel::ast_internal::AstImpl&) -> Instrumentation { diff --git a/eval/compiler/regex_precompilation_optimization_test.cc b/eval/compiler/regex_precompilation_optimization_test.cc index 65d2d9058..2a6341a44 100644 --- a/eval/compiler/regex_precompilation_optimization_test.cc +++ b/eval/compiler/regex_precompilation_optimization_test.cc @@ -21,6 +21,7 @@ #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "base/ast_internal/ast_impl.h" #include "common/memory.h" @@ -38,6 +39,8 @@ #include "internal/testing.h" #include "parser/parser.h" #include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_issue.h" #include "google/protobuf/arena.h" @@ -46,6 +49,8 @@ namespace { using ::cel::RuntimeIssue; using ::cel::runtime_internal::IssueCollector; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; @@ -54,7 +59,9 @@ namespace exprpb = cel::expr; class RegexPrecompilationExtensionTest : public testing::TestWithParam { public: RegexPrecompilationExtensionTest() - : type_registry_(*builder_.GetTypeRegistry()), + : env_(NewTestingRuntimeEnv()), + builder_(env_), + type_registry_(*builder_.GetTypeRegistry()), function_registry_(*builder_.GetRegistry()), value_factory_(cel::MemoryManagerRef::ReferenceCounting(), type_registry_.GetTypeProvider()), @@ -88,6 +95,7 @@ class RegexPrecompilationExtensionTest : public testing::TestWithParam { }; } + absl::Nonnull> env_; CelExpressionBuilderFlatImpl builder_; CelTypeRegistry& type_registry_; CelFunctionRegistry& function_registry_; @@ -106,8 +114,9 @@ TEST_P(RegexPrecompilationExtensionTest, SmokeTest) { ProgramBuilder program_builder; cel::ast_internal::AstImpl ast_impl; ast_impl.set_is_checked(true); - PlannerContext context(resolver_, runtime_options_, value_factory_, - issue_collector_, program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, runtime_options_, value_factory_, + issue_collector_, program_builder, arena); ASSERT_OK_AND_ASSIGN(std::unique_ptr optimizer, factory(context, ast_impl)); @@ -209,8 +218,7 @@ class RegexConstFoldInteropTest : public RegexPrecompilationExtensionTest { public: RegexConstFoldInteropTest() : RegexPrecompilationExtensionTest() { builder_.flat_expr_builder().AddProgramOptimizer( - cel::runtime_internal::CreateConstantFoldingOptimizer( - cel::MemoryManagerRef::ReferenceCounting())); + cel::runtime_internal::CreateConstantFoldingOptimizer()); } protected: diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index d2f0ae184..ddbdf1be7 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -15,12 +15,10 @@ #include "eval/compiler/resolver.h" #include -#include #include #include #include -#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -30,10 +28,9 @@ #include "absl/strings/strip.h" #include "absl/types/optional.h" #include "base/kind.h" -#include "common/memory.h" #include "common/type.h" +#include "common/type_reflector.h" #include "common/value.h" -#include "common/value_manager.h" #include "internal/status_macros.h" #include "runtime/function_overload_reference.h" #include "runtime/function_registry.h" @@ -41,18 +38,20 @@ namespace google::api::expr::runtime { +using ::cel::IntValue; +using ::cel::TypeValue; using ::cel::Value; Resolver::Resolver( absl::string_view container, const cel::FunctionRegistry& function_registry, - const cel::TypeRegistry&, cel::ValueManager& value_factory, + const cel::TypeRegistry&, const cel::TypeReflector& type_reflector, const absl::flat_hash_map& resolveable_enums, bool resolve_qualified_type_identifiers) : namespace_prefixes_(), enum_value_map_(), function_registry_(function_registry), - value_factory_(value_factory), + type_reflector_(type_reflector), resolveable_enums_(resolveable_enums), resolve_qualified_type_identifiers_(resolve_qualified_type_identifiers) { // The constructor for the registry determines the set of possible namespace @@ -85,7 +84,7 @@ Resolver::Resolver( for (const auto& enumerator : enum_type.enumerators) { auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", enumerator.name); - enum_value_map_[key] = value_factory.CreateIntValue(enumerator.number); + enum_value_map_[key] = IntValue(enumerator.number); } } } @@ -127,9 +126,9 @@ absl::optional Resolver::FindConstant(absl::string_view name, // to do so is configured in the expression builder. If the type name is // not qualified, then it too may be returned as a constant value. if (resolve_qualified_type_identifiers_ || !absl::StrContains(name, ".")) { - auto type_value = value_factory_.FindType(name); + auto type_value = type_reflector_.FindType(name); if (type_value.ok() && type_value->has_value()) { - return value_factory_.CreateTypeValue(**type_value); + return TypeValue(**type_value); } } } @@ -179,7 +178,7 @@ Resolver::FindType(absl::string_view name, int64_t expr_id) const { auto qualified_names = FullyQualifiedNames(name, expr_id); for (auto& qualified_name : qualified_names) { CEL_ASSIGN_OR_RETURN(auto maybe_type, - value_factory_.FindType(qualified_name)); + type_reflector_.FindType(qualified_name)); if (maybe_type.has_value()) { return std::make_pair(std::move(qualified_name), std::move(*maybe_type)); } diff --git a/eval/compiler/resolver.h b/eval/compiler/resolver.h index 2d164cb14..ee0e55ce1 100644 --- a/eval/compiler/resolver.h +++ b/eval/compiler/resolver.h @@ -25,6 +25,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/kind.h" +#include "common/type_reflector.h" #include "common/value.h" #include "common/value_manager.h" #include "runtime/function_overload_reference.h" @@ -47,6 +48,18 @@ class Resolver { absl::string_view container, const cel::FunctionRegistry& function_registry, const cel::TypeRegistry& type_registry, cel::ValueManager& value_factory, + const absl::flat_hash_map& + resolveable_enums, + bool resolve_qualified_type_identifiers = true) + : Resolver(container, function_registry, type_registry, + value_factory.type_provider(), resolveable_enums, + resolve_qualified_type_identifiers) {} + + Resolver( + absl::string_view container, + const cel::FunctionRegistry& function_registry, + const cel::TypeRegistry& type_registry, + const cel::TypeReflector& type_reflector, const absl::flat_hash_map& resolveable_enums, bool resolve_qualified_type_identifiers = true); @@ -89,7 +102,7 @@ class Resolver { std::vector namespace_prefixes_; absl::flat_hash_map enum_value_map_; const cel::FunctionRegistry& function_registry_; - cel::ValueManager& value_factory_; + const cel::TypeReflector& type_reflector_; const absl::flat_hash_map& resolveable_enums_; diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 62c67c0e9..d7769f22f 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -60,6 +60,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/utility", + "@com_google_protobuf//:protobuf", ], ) @@ -541,6 +542,7 @@ cc_test( "//internal:testing", "//runtime:activation", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index b654d92b7..468a06634 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -17,7 +17,6 @@ #include #include -#include #include #include #include @@ -44,6 +43,7 @@ #include "runtime/managed_value_factory.h" #include "runtime/runtime.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -369,23 +369,27 @@ class FlatExpression { // value creation in evaluation FlatExpression(ExecutionPath path, size_t comprehension_slots_size, const cel::TypeProvider& type_provider, - const cel::RuntimeOptions& options) + const cel::RuntimeOptions& options, + absl::Nullable> arena = nullptr) : path_(std::move(path)), subexpressions_({path_}), comprehension_slots_size_(comprehension_slots_size), type_provider_(type_provider), - options_(options) {} + options_(options), + arena_(std::move(arena)) {} FlatExpression(ExecutionPath path, std::vector subexpressions, size_t comprehension_slots_size, const cel::TypeProvider& type_provider, - const cel::RuntimeOptions& options) + const cel::RuntimeOptions& options, + absl::Nullable> arena = nullptr) : path_(std::move(path)), subexpressions_(std::move(subexpressions)), comprehension_slots_size_(comprehension_slots_size), type_provider_(type_provider), - options_(options) {} + options_(options), + arena_(std::move(arena)) {} // Move-only FlatExpression(FlatExpression&&) = default; @@ -429,6 +433,9 @@ class FlatExpression { size_t comprehension_slots_size_; const cel::TypeProvider& type_provider_; cel::RuntimeOptions options_; + // Arena used during planning phase, may hold constant values so should be + // kept alive. + absl::Nullable> arena_; }; } // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 7b4404af1..da15f4b4e 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -15,6 +15,7 @@ #include "extensions/protobuf/memory_manager.h" #include "internal/testing.h" #include "runtime/activation.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -23,6 +24,7 @@ using ::cel::IntValue; using ::cel::TypeProvider; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::interop_internal::CreateIntValue; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::Expr; using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::testing::_; @@ -173,7 +175,7 @@ TEST(EvaluatorCoreTest, TraceTest) { cel::RuntimeOptions options; options.short_circuiting = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); diff --git a/eval/public/BUILD b/eval/public/BUILD index be7b3a1c8..6142f3fa4 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -573,23 +573,21 @@ cc_library( ":cel_function", ":cel_options", "//base:kind", - "//base/ast_internal:ast_impl", "//common:memory", "//eval/compiler:cel_expression_builder_flat_impl", "//eval/compiler:comprehension_vulnerability_check", "//eval/compiler:constant_folding", "//eval/compiler:flat_expr_builder", - "//eval/compiler:flat_expr_builder_extensions", "//eval/compiler:qualified_reference_resolver", "//eval/compiler:regex_precompilation_optimization", "//eval/public/structs:protobuf_descriptor_type_provider", "//extensions:select_optimization", - "//extensions/protobuf:memory_manager", - "//internal:proto_util", + "//internal:noop_delete", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index cc061a7ea..436a85752 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -17,28 +17,28 @@ #include "eval/public/cel_expr_builder_factory.h" #include +#include +#include "absl/base/nullability.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "base/ast_internal/ast_impl.h" #include "base/kind.h" #include "common/memory.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/compiler/comprehension_vulnerability_check.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" -#include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/qualified_reference_resolver.h" #include "eval/compiler/regex_precompilation_optimization.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function.h" #include "eval/public/cel_options.h" #include "eval/public/structs/protobuf_descriptor_type_provider.h" -#include "extensions/protobuf/memory_manager.h" #include "extensions/select_optimization.h" -#include "internal/proto_util.h" +#include "internal/noop_delete.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -47,25 +47,12 @@ namespace google::api::expr::runtime { namespace { using ::cel::MemoryManagerRef; -using ::cel::ast_internal::AstImpl; using ::cel::extensions::CreateSelectOptimizationProgramOptimizer; using ::cel::extensions::kCelAttribute; using ::cel::extensions::kCelHasField; -using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::extensions::SelectOptimizationAstUpdater; using ::cel::runtime_internal::CreateConstantFoldingOptimizer; -using ::google::api::expr::internal::ValidateStandardMessageTypes; - -// Adapter for a raw arena* pointer. Manages a MemoryManager object for the -// constant folding extension. -struct ArenaBackedConstfoldingFactory { - MemoryManagerRef memory_manager; - - absl::StatusOr> operator()( - PlannerContext& ctx, const AstImpl& ast) const { - return CreateConstantFoldingOptimizer(memory_manager)(ctx, ast); - } -}; +using ::cel::runtime_internal::RuntimeEnv; } // namespace @@ -78,15 +65,27 @@ std::unique_ptr CreateCelExpressionBuilder( "CreateCelExpressionBuilder"; return nullptr; } - if (auto s = ValidateStandardMessageTypes(*descriptor_pool); !s.ok()) { - ABSL_LOG(WARNING) << "Failed to validate standard message types: " - << s.ToString(); // NOLINT: OSS compatibility - return nullptr; - } cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); - auto builder = - std::make_unique(runtime_options); + absl::Nullable> + shared_message_factory; + if (message_factory != nullptr) { + shared_message_factory = std::shared_ptr( + message_factory, + cel::internal::NoopDeleteFor()); + } + auto env = std::make_shared( + std::shared_ptr( + descriptor_pool, + cel::internal::NoopDeleteFor()), + shared_message_factory); + if (auto status = env->Initialize(); !status.ok()) { + ABSL_LOG(ERROR) << "Failed to validate standard message types: " + << status.ToString(); // NOLINT: OSS compatibility + return nullptr; + } + auto builder = std::make_unique( + std::move(env), runtime_options); builder->GetTypeRegistry() ->InternalGetModernRegistry() @@ -109,9 +108,15 @@ std::unique_ptr CreateCelExpressionBuilder( } if (options.constant_folding) { + std::shared_ptr shared_arena; + if (options.constant_arena != nullptr) { + shared_arena = std::shared_ptr( + options.constant_arena, + cel::internal::NoopDeleteFor()); + } builder->flat_expr_builder().AddProgramOptimizer( - ArenaBackedConstfoldingFactory{ - ProtoMemoryManagerRef(options.constant_arena)}); + CreateConstantFoldingOptimizer(std::move(shared_arena), + std::move(shared_message_factory))); } if (options.enable_regex_precompilation) { diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 98b58aa98..3f52ad60d 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -5,7 +5,6 @@ #include #include #include -#include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" @@ -76,10 +75,7 @@ class CelExpression { // it built. class CelExpressionBuilder { public: - CelExpressionBuilder() - : func_registry_(std::make_unique()), - type_registry_(std::make_unique()), - container_("") {} + CelExpressionBuilder() = default; virtual ~CelExpressionBuilder() = default; @@ -129,23 +125,16 @@ class CelExpressionBuilder { // CelFunction registry. Extension function should be registered with it // prior to expression creation. - CelFunctionRegistry* GetRegistry() const { return func_registry_.get(); } + virtual CelFunctionRegistry* GetRegistry() const = 0; // CEL Type registry. Provides a means to resolve the CEL built-in types to // CelValue instances, and to extend the set of types and enums known to // expressions by registering them ahead of time. - CelTypeRegistry* GetTypeRegistry() const { return type_registry_.get(); } + virtual CelTypeRegistry* GetTypeRegistry() const = 0; - virtual void set_container(std::string container) { - container_ = std::move(container); - } - - absl::string_view container() const { return container_; } + virtual void set_container(std::string container) = 0; - private: - std::unique_ptr func_registry_; - std::unique_ptr type_registry_; - std::string container_; + virtual absl::string_view container() const = 0; }; } // namespace google::api::expr::runtime diff --git a/eval/tests/modern_benchmark_test.cc b/eval/tests/modern_benchmark_test.cc index 81cf91ef0..22233210a 100644 --- a/eval/tests/modern_benchmark_test.cc +++ b/eval/tests/modern_benchmark_test.cc @@ -102,8 +102,7 @@ std::unique_ptr StandardRuntimeOrDie( break; case ConstFoldingEnabled::kYes: ABSL_CHECK(arena != nullptr); - ABSL_CHECK_OK(extensions::EnableConstantFolding( - *builder, ProtoMemoryManagerRef(arena))); + ABSL_CHECK_OK(extensions::EnableConstantFolding(*builder)); break; } diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 2e34096e0..2083d3b82 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -152,7 +152,7 @@ Expr MakeSelectPathExpr( absl::optional GetSelectInstruction( const StructType& runtime_type, PlannerContext& planner_context, absl::string_view field_name) { - auto field_or = planner_context.value_factory() + auto field_or = planner_context.type_reflector() .FindStructTypeFieldByName(runtime_type, field_name) .value_or(absl::nullopt); if (field_or.has_value()) { @@ -515,7 +515,7 @@ class RewriterImpl : public AstRewriterBase { } absl::optional GetRuntimeType(absl::string_view type_name) { - return planner_context_.value_factory().FindType(type_name).value_or( + return planner_context_.type_reflector().FindType(type_name).value_or( absl::nullopt); } diff --git a/runtime/BUILD b/runtime/BUILD index c453afb89..1d8c3dbfc 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -213,8 +213,10 @@ cc_library( deps = [ ":runtime_builder", ":runtime_options", + ":type_registry", "//internal:noop_delete", "//internal:status_macros", + "//runtime/internal:runtime_env", "//runtime/internal:runtime_impl", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", @@ -304,11 +306,10 @@ cc_library( deps = [ ":runtime", ":runtime_builder", - "//common:allocator", - "//common:memory", "//common:native_type", "//eval/compiler:constant_folding", "//internal:casts", + "//internal:noop_delete", "//internal:status_macros", "//runtime/internal:runtime_friend_access", "//runtime/internal:runtime_impl", @@ -339,6 +340,7 @@ cc_test( "//internal:testing_descriptor_pool", "//parser", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", @@ -384,6 +386,7 @@ cc_test( "//internal:testing_descriptor_pool", "//parser", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", diff --git a/runtime/constant_folding.cc b/runtime/constant_folding.cc index 57ead8096..0174ef267 100644 --- a/runtime/constant_folding.cc +++ b/runtime/constant_folding.cc @@ -14,20 +14,24 @@ #include "runtime/constant_folding.h" -#include "absl/base/macros.h" +#include +#include + +#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "common/allocator.h" #include "common/native_type.h" #include "eval/compiler/constant_folding.h" #include "internal/casts.h" +#include "internal/noop_delete.h" #include "internal/status_macros.h" #include "runtime/internal/runtime_friend_access.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" +#include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel::extensions { @@ -37,44 +41,122 @@ using ::cel::internal::down_cast; using ::cel::runtime_internal::RuntimeFriendAccess; using ::cel::runtime_internal::RuntimeImpl; -absl::StatusOr RuntimeImplFromBuilder(RuntimeBuilder& builder) { +absl::StatusOr> RuntimeImplFromBuilder( + RuntimeBuilder& builder ABSL_ATTRIBUTE_LIFETIME_BOUND) { Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder); - if (RuntimeFriendAccess::RuntimeTypeId(runtime) != NativeTypeId::For()) { return absl::UnimplementedError( "constant folding only supported on the default cel::Runtime " "implementation."); } + return down_cast(&runtime); +} - RuntimeImpl& runtime_impl = down_cast(runtime); - - return &runtime_impl; +absl::Status EnableConstantFoldingImpl( + RuntimeBuilder& builder, + absl::Nullable> arena, + absl::Nullable> message_factory) { + CEL_ASSIGN_OR_RETURN(absl::Nonnull runtime_impl, + RuntimeImplFromBuilder(builder)); + if (arena != nullptr) { + runtime_impl->environment().KeepAlive(arena); + } + if (message_factory != nullptr) { + runtime_impl->environment().KeepAlive(message_factory); + } + runtime_impl->expr_builder().AddProgramOptimizer( + runtime_internal::CreateConstantFoldingOptimizer( + std::move(arena), std::move(message_factory))); + return absl::OkStatus(); } } // namespace +absl::Status EnableConstantFolding(RuntimeBuilder& builder) { + return EnableConstantFoldingImpl(builder, nullptr, nullptr); +} + absl::Status EnableConstantFolding(RuntimeBuilder& builder, - Allocator<> allocator) { - CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl, - RuntimeImplFromBuilder(builder)); - ABSL_ASSERT(runtime_impl != nullptr); - runtime_impl->expr_builder().AddProgramOptimizer( - runtime_internal::CreateConstantFoldingOptimizer(allocator, nullptr)); - return absl::OkStatus(); + absl::Nonnull arena) { + ABSL_DCHECK(arena != nullptr); + return EnableConstantFoldingImpl( + builder, + std::shared_ptr(arena, + internal::NoopDeleteFor()), + nullptr); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena) { + ABSL_DCHECK(arena != nullptr); + return EnableConstantFoldingImpl(builder, std::move(arena), nullptr); } absl::Status EnableConstantFolding( - RuntimeBuilder& builder, Allocator<> allocator, + RuntimeBuilder& builder, absl::Nonnull message_factory) { ABSL_DCHECK(message_factory != nullptr); - CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl, - RuntimeImplFromBuilder(builder)); - ABSL_ASSERT(runtime_impl != nullptr); - runtime_impl->expr_builder().AddProgramOptimizer( - runtime_internal::CreateConstantFoldingOptimizer(allocator, - message_factory)); - return absl::OkStatus(); + return EnableConstantFoldingImpl( + builder, nullptr, + std::shared_ptr( + message_factory, internal::NoopDeleteFor())); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> message_factory) { + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl(builder, nullptr, + std::move(message_factory)); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl::Nonnull arena, + absl::Nonnull message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, + std::shared_ptr(arena, + internal::NoopDeleteFor()), + std::shared_ptr( + message_factory, internal::NoopDeleteFor())); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl::Nonnull arena, + absl::Nonnull> message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, + std::shared_ptr(arena, + internal::NoopDeleteFor()), + std::move(message_factory)); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena, + absl::Nonnull message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, std::move(arena), + std::shared_ptr( + message_factory, internal::NoopDeleteFor())); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena, + absl::Nonnull> message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl(builder, std::move(arena), + std::move(message_factory)); } } // namespace cel::extensions diff --git a/runtime/constant_folding.h b/runtime/constant_folding.h index be5cf6044..58cd4cfd0 100644 --- a/runtime/constant_folding.h +++ b/runtime/constant_folding.h @@ -15,10 +15,12 @@ #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_ +#include + #include "absl/base/nullability.h" #include "absl/status/status.h" -#include "common/allocator.h" #include "runtime/runtime_builder.h" +#include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel::extensions { @@ -26,20 +28,44 @@ namespace cel::extensions { // Enable constant folding in the runtime being built. // // Constant folding eagerly evaluates sub-expressions with all constant inputs -// at plan time to simplify the resulting program. User extensions functions are -// executed if they are eagerly bound. +// at plan time to simplify the resulting program. User functions are executed +// if they are eagerly bound. // -// The underlying implementation of `allocator` must outlive the resulting -// runtime and any programs it creates. +// The provided, the `google::protobuf::Arena` must outlive the resulting runtime +// and any program it creates. Otherwise the runtime will create one as needed +// during planning for each program, unless one is explicitly provided during +// planning. // -// The provided `google::protobuf::MessageFactory` must outlive the resulting runtime and -// any program it creates. Failure to pass a message factory may result in -// certain optimizations being disabled. +// The provided, the `google::protobuf::MessageFactory` must outlive the resulting runtime +// and any program it creates. Otherwise the runtime will create one as needed +// and use it for all planning and the resulting programs created from the +// runtime, unless one is explicitly provided during planning or evaluation. +absl::Status EnableConstantFolding(RuntimeBuilder& builder); absl::Status EnableConstantFolding(RuntimeBuilder& builder, - Allocator<> allocator); + absl::Nonnull arena); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> message_factory); absl::Status EnableConstantFolding( - RuntimeBuilder& builder, Allocator<> allocator, + RuntimeBuilder& builder, absl::Nonnull arena, absl::Nonnull message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl::Nonnull arena, + absl::Nonnull> message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena, + absl::Nonnull message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena, + absl::Nonnull> message_factory); } // namespace cel::extensions diff --git a/runtime/constant_folding_test.cc b/runtime/constant_folding_test.cc index af3010b62..f579cb400 100644 --- a/runtime/constant_folding_test.cc +++ b/runtime/constant_folding_test.cc @@ -20,6 +20,7 @@ #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "base/function_adapter.h" @@ -38,6 +39,7 @@ namespace cel::extensions { namespace { +using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; @@ -87,10 +89,9 @@ TEST_P(ConstantFoldingExtTest, Runner) { return StringValue::Concat(f, prefix, value); }, builder.function_registry()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); - ASSERT_OK( - EnableConstantFolding(builder, MemoryManagerRef::ReferenceCounting())); + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD index 69b8a8e3e..9e5078ccd 100644 --- a/runtime/internal/BUILD +++ b/runtime/internal/BUILD @@ -47,11 +47,31 @@ cc_library( ], ) +cc_library( + name = "runtime_env", + srcs = ["runtime_env.cc"], + hdrs = ["runtime_env.h"], + deps = [ + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", + "//internal:noop_delete", + "//internal:well_known_types", + "//runtime:function_registry", + "//runtime:type_registry", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "runtime_impl", srcs = ["runtime_impl.cc"], hdrs = ["runtime_impl.h"], deps = [ + ":runtime_env", "//base:ast", "//base:data", "//common:native_type", @@ -73,7 +93,6 @@ cc_library( "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", - "@com_google_protobuf//:protobuf", ], ) @@ -156,3 +175,16 @@ cc_test( "@com_google_absl//absl/time", ], ) + +cc_library( + name = "runtime_env_testing", + testonly = True, + srcs = ["runtime_env_testing.cc"], + hdrs = ["runtime_env_testing.h"], + deps = [ + ":runtime_env", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + ], +) diff --git a/runtime/internal/runtime_env.cc b/runtime/internal/runtime_env.cc new file mode 100644 index 000000000..dbe78d538 --- /dev/null +++ b/runtime/internal/runtime_env.cc @@ -0,0 +1,74 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/runtime_env.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/synchronization/mutex.h" +#include "internal/noop_delete.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +RuntimeEnv::KeepAlives::~KeepAlives() { + while (!deque.empty()) { + deque.pop_back(); + } +} + +absl::Nonnull RuntimeEnv::MutableMessageFactory() + const { + absl::Nullable shared_message_factory = + message_factory_ptr.load(std::memory_order_relaxed); + if (shared_message_factory != nullptr) { + return shared_message_factory; + } + absl::MutexLock lock(&message_factory_mutex); + shared_message_factory = message_factory_ptr.load(std::memory_order_relaxed); + if (shared_message_factory == nullptr) { + if (descriptor_pool.get() == google::protobuf::DescriptorPool::generated_pool()) { + // Using the generated descriptor pool, just use the generated message + // factory. + message_factory = std::shared_ptr( + google::protobuf::MessageFactory::generated_factory(), + internal::NoopDeleteFor()); + } else { + auto dynamic_message_factory = + std::make_shared(); + // Ensure we do not delegate to the generated factory, if the default + // every changes. We prefer being hermetic. + dynamic_message_factory->SetDelegateToGeneratedFactory(false); + message_factory = std::move(dynamic_message_factory); + } + shared_message_factory = message_factory.get(); + message_factory_ptr.store(shared_message_factory, + std::memory_order_seq_cst); + } + return shared_message_factory; +} + +void RuntimeEnv::KeepAlive(std::shared_ptr keep_alive) { + if (keep_alive == nullptr) { + return; + } + keep_alives.deque.push_back(std::move(keep_alive)); +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/runtime_env.h b/runtime/internal/runtime_env.h new file mode 100644 index 000000000..e0ab566b1 --- /dev/null +++ b/runtime/internal/runtime_env.h @@ -0,0 +1,133 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "internal/well_known_types.h" +#include "runtime/function_registry.h" +#include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +// Shared state used by the runtime during creation, configuration, planning, +// and evaluation. Passed around via `std::shared_ptr`. +// +// TODO: Make this a class. +struct RuntimeEnv final { + explicit RuntimeEnv( + absl::Nonnull> + descriptor_pool, + absl::Nullable> message_factory = + nullptr) + : descriptor_pool(std::move(descriptor_pool)), + message_factory(std::move(message_factory)), + type_registry(legacy_type_registry.InternalGetModernRegistry()), + function_registry(legacy_function_registry.InternalGetRegistry()) { + if (this->message_factory != nullptr) { + message_factory_ptr.store(this->message_factory.get(), + std::memory_order_seq_cst); + } + } + + // Not copyable or moveable. + RuntimeEnv(const RuntimeEnv&) = delete; + RuntimeEnv(RuntimeEnv&&) = delete; + RuntimeEnv& operator=(const RuntimeEnv&) = delete; + RuntimeEnv& operator=(RuntimeEnv&&) = delete; + + // Ideally the environment would already be initialized, but things are a bit + // awkward. This should only be called once immediately after construction. + absl::Status Initialize() { + return well_known_types.Initialize(descriptor_pool.get()); + } + + bool IsInitialized() const { return well_known_types.IsInitialized(); } + + ABSL_ATTRIBUTE_UNUSED + const absl::Nonnull> + descriptor_pool; + + private: + // These fields deal with a message factory that is lazily initialized as + // needed. This might be called during the planning phase of an expression or + // during evaluation. We want the ability to get the message factory when it + // is already created to be cheap, so we use an atomic and a mutex for the + // slow path. + // + // Do not access any of these fields directly, use member functions. + mutable absl::Mutex message_factory_mutex; + mutable absl::Nullable> + message_factory ABSL_GUARDED_BY(message_factory_mutex); + // std::atomic> is not really a simple atomic, so we + // avoid it. + mutable std::atomic> + message_factory_ptr = nullptr; + + struct KeepAlives final { + KeepAlives() = default; + + ~KeepAlives(); + + // Not copyable or moveable. + KeepAlives(const KeepAlives&) = delete; + KeepAlives(KeepAlives&&) = delete; + KeepAlives& operator=(const KeepAlives&) = delete; + KeepAlives& operator=(KeepAlives&&) = delete; + + std::deque> deque; + }; + + KeepAlives keep_alives; + + public: + // Because of legacy shenanigans, we use shared_ptr here. For legacy, this is + // an unowned shared_ptr (a noop deleter) pointing to the modern equivalent + // which is a member of the legacy variant. + google::api::expr::runtime::CelTypeRegistry legacy_type_registry; + google::api::expr::runtime::CelFunctionRegistry legacy_function_registry; + TypeRegistry& type_registry; + FunctionRegistry& function_registry; + + well_known_types::Reflection well_known_types; + + absl::Nonnull MutableMessageFactory() const + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + // Not thread safe. Adds `keep_alive` to a list owned by this environment + // and ensures it survives at least as long as this environment. Keep alives + // are released in reverse order of their registration. This mimics normal + // destructor rules of members. + // + // IMPORTANT: This should only be when building the runtime, and not after. + void KeepAlive(std::shared_ptr keep_alive); +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_ diff --git a/runtime/internal/runtime_env_testing.cc b/runtime/internal/runtime_env_testing.cc new file mode 100644 index 000000000..ae7dd0ab9 --- /dev/null +++ b/runtime/internal/runtime_env_testing.cc @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/runtime_env_testing.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/internal/runtime_env.h" + +namespace cel::runtime_internal { + +absl::Nonnull> NewTestingRuntimeEnv() { + auto env = + std::make_shared(internal::GetSharedTestingDescriptorPool()); + ABSL_CHECK_OK(env->Initialize()); // Crash OK + return env; +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/runtime_env_testing.h b/runtime/internal/runtime_env_testing.h new file mode 100644 index 000000000..1645ce4dd --- /dev/null +++ b/runtime/internal/runtime_env_testing.h @@ -0,0 +1,29 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_ + +#include + +#include "absl/base/nullability.h" +#include "runtime/internal/runtime_env.h" + +namespace cel::runtime_internal { + +absl::Nonnull> NewTestingRuntimeEnv(); + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_ diff --git a/runtime/internal/runtime_impl.h b/runtime/internal/runtime_impl.h index 4dc2fe929..0c4972fcf 100644 --- a/runtime/internal/runtime_impl.h +++ b/runtime/internal/runtime_impl.h @@ -28,48 +28,51 @@ #include "eval/compiler/flat_expr_builder.h" #include "internal/well_known_types.h" #include "runtime/function_registry.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" -#include "google/protobuf/descriptor.h" namespace cel::runtime_internal { class RuntimeImpl : public Runtime { public: - struct Environment { - ABSL_ATTRIBUTE_UNUSED - absl::Nonnull> - descriptor_pool; - TypeRegistry type_registry; - FunctionRegistry function_registry; - well_known_types::Reflection well_known_types; - }; + using Environment = RuntimeEnv; RuntimeImpl(absl::Nonnull> environment, const RuntimeOptions& options) : environment_(std::move(environment)), - expr_builder_(environment_->function_registry, - environment_->type_registry, options) { + expr_builder_(environment_, options) { ABSL_DCHECK(environment_->well_known_types.IsInitialized()); } - TypeRegistry& type_registry() { return environment_->type_registry; } - const TypeRegistry& type_registry() const { + TypeRegistry& type_registry() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return environment_->type_registry; + } + const TypeRegistry& type_registry() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return environment_->type_registry; } - FunctionRegistry& function_registry() { + FunctionRegistry& function_registry() ABSL_ATTRIBUTE_LIFETIME_BOUND { return environment_->function_registry; } - const FunctionRegistry& function_registry() const { + const FunctionRegistry& function_registry() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { return environment_->function_registry; } - const well_known_types::Reflection& well_known_types() const { + const well_known_types::Reflection& well_known_types() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { return environment_->well_known_types; } + Environment& environment() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *environment_; + } + const Environment& environment() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *environment_; + } + // implement Runtime absl::StatusOr> CreateProgram( std::unique_ptr ast, @@ -84,7 +87,8 @@ class RuntimeImpl : public Runtime { } // exposed for extensions access - google::api::expr::runtime::FlatExprBuilder& expr_builder() { + google::api::expr::runtime::FlatExprBuilder& expr_builder() + ABSL_ATTRIBUTE_LIFETIME_BOUND { return expr_builder_; } diff --git a/runtime/regex_precompilation_test.cc b/runtime/regex_precompilation_test.cc index b5da4aa4e..5cbdb291c 100644 --- a/runtime/regex_precompilation_test.cc +++ b/runtime/regex_precompilation_test.cc @@ -20,6 +20,7 @@ #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "base/function_adapter.h" @@ -39,6 +40,7 @@ namespace cel::extensions { namespace { +using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; @@ -89,9 +91,9 @@ TEST_P(RegexPrecompilationTest, Basic) { return StringValue::Concat(f, prefix, value); }, builder.function_registry()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); - ASSERT_OK(EnableRegexPrecompilation(builder)); + ASSERT_THAT(EnableRegexPrecompilation(builder), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); @@ -136,11 +138,10 @@ TEST_P(RegexPrecompilationTest, WithConstantFolding) { return StringValue::Concat(f, prefix, value); }, builder.function_registry()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); - ASSERT_OK( - EnableConstantFolding(builder, MemoryManagerRef::ReferenceCounting())); - ASSERT_OK(EnableRegexPrecompilation(builder)); + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); + ASSERT_THAT(EnableRegexPrecompilation(builder), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); diff --git a/runtime/runtime_builder_factory.cc b/runtime/runtime_builder_factory.cc index 34e16b03a..cdfb0058f 100644 --- a/runtime/runtime_builder_factory.cc +++ b/runtime/runtime_builder_factory.cc @@ -22,13 +22,16 @@ #include "absl/status/statusor.h" #include "internal/noop_delete.h" #include "internal/status_macros.h" +#include "runtime/internal/runtime_env.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" +#include "runtime/type_registry.h" #include "google/protobuf/descriptor.h" namespace cel { +using ::cel::runtime_internal::RuntimeEnv; using ::cel::runtime_internal::RuntimeImpl; absl::StatusOr CreateRuntimeBuilder( @@ -51,10 +54,8 @@ absl::StatusOr CreateRuntimeBuilder( // TODO: add API for attaching an issue listener (replacing the // vector overloads). ABSL_DCHECK(descriptor_pool != nullptr); - auto environment = std::make_shared(); - environment->descriptor_pool = std::move(descriptor_pool); - CEL_RETURN_IF_ERROR(environment->well_known_types.Initialize( - environment->descriptor_pool.get())); + auto environment = std::make_shared(std::move(descriptor_pool)); + CEL_RETURN_IF_ERROR(environment->Initialize()); auto runtime_impl = std::make_unique(std::move(environment), options); runtime_impl->expr_builder().set_container(options.container);