From 9310c4910e598362695930f0e11b7f278f714755 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 5 Nov 2024 11:05:34 -0800 Subject: [PATCH] Update type assignment widening behavior to more closely follow the 'MoreGeneral' check in the Go and Java implementations. PiperOrigin-RevId: 693412150 --- checker/internal/BUILD | 1 + checker/internal/type_checker_impl_test.cc | 64 ++++++++++- checker/internal/type_inference_context.cc | 103 +++++++++++++++--- checker/internal/type_inference_context.h | 19 ++++ .../internal/type_inference_context_test.cc | 60 ++++++++++ checker/optional_test.cc | 12 ++ 6 files changed, 243 insertions(+), 16 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 2fbbf47d2..68ea74f4f 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -165,6 +165,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", + "//testutil:baseline_tests", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index d4eb2c1a3..d64e22cc3 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -45,6 +45,7 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" +#include "testutil/baseline_tests.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" @@ -221,11 +222,18 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull are FunctionDecl ternary_op; ternary_op.set_name("_?_:_"); - CEL_RETURN_IF_ERROR(eq_op.AddOverload(MakeOverloadDecl( + CEL_RETURN_IF_ERROR(ternary_op.AddOverload(MakeOverloadDecl( "conditional", /*return_type=*/ TypeParamType("A"), BoolType{}, TypeParamType("A"), TypeParamType("A")))); + FunctionDecl index_op; + index_op.set_name("_[_]"); + CEL_RETURN_IF_ERROR(index_op.AddOverload(MakeOverloadDecl( + "index", + /*return_type=*/ + TypeParamType("A"), ListType(arena, TypeParamType("A")), IntType()))); + FunctionDecl to_int; to_int.set_name("int"); CEL_RETURN_IF_ERROR(to_int.AddOverload( @@ -268,6 +276,7 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull are env.InsertFunctionIfAbsent(std::move(to_int)); env.InsertFunctionIfAbsent(std::move(eq_op)); env.InsertFunctionIfAbsent(std::move(ternary_op)); + env.InsertFunctionIfAbsent(std::move(index_op)); env.InsertFunctionIfAbsent(std::move(to_dyn)); env.InsertFunctionIfAbsent(std::move(to_type)); env.InsertFunctionIfAbsent(std::move(to_duration)); @@ -1543,7 +1552,8 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) { const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); EXPECT_THAT(ast_impl.type_map(), Contains(Pair(ast_impl.root_expr().id(), - Eq(test_case.expected_result_type)))); + Eq(test_case.expected_result_type)))) + << cel::test::FormatBaselineAst(*checked_ast); } INSTANTIATE_TEST_SUITE_P( @@ -2039,11 +2049,59 @@ INSTANTIATE_TEST_SUITE_P( .expr = "[1, 2, test_msg.single_int64_wrapper, dyn(1)]", .expected_result_type = AstType(ast_internal::ListType( std::make_unique(ast_internal::DynamicType())))}, - + CheckedExprTestCase{ + .expr = "[null, test_msg][0]", + .expected_result_type = AstType(ast_internal::MessageType( + "cel.expr.conformance.proto3.TestAllTypes"))}, + CheckedExprTestCase{ + .expr = "[{'k': dyn(1)}, {dyn('k'): 1}][0]", + // Ambiguous type resolution, but we prefer the first option. + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique(ast_internal::DynamicType())))}, + CheckedExprTestCase{ + .expr = "[{'k': 1}, {dyn('k'): 1}][0]", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::DynamicType()), + std::make_unique( + ast_internal::PrimitiveType::kInt64)))}, + CheckedExprTestCase{ + .expr = "[{dyn('k'): 1}, {'k': 1}][0]", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::DynamicType()), + std::make_unique( + ast_internal::PrimitiveType::kInt64)))}, + CheckedExprTestCase{ + .expr = "[{'k': 1}, {'k': dyn(1)}][0]", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique(ast_internal::DynamicType())))}, + CheckedExprTestCase{ + .expr = "[{'k': 1}, {dyn('k'): dyn(1)}][0]", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::DynamicType()), + std::make_unique(ast_internal::DynamicType())))}, + CheckedExprTestCase{ + .expr = + "[{'k': 1.0}, {dyn('k'): test_msg.single_int64_wrapper}][0]", + .expected_result_type = AstType(ast_internal::DynamicType())}, CheckedExprTestCase{ .expr = "test_msg.single_int64", .expected_result_type = AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "[[1], {1: 2u}][0]", + .expected_result_type = AstType(ast_internal::DynamicType()), + }, + CheckedExprTestCase{ + .expr = "[{1: 2u}, [1]][0]", + .expected_result_type = AstType(ast_internal::DynamicType()), + }, + CheckedExprTestCase{ + .expr = "[test_msg.single_int64_wrapper," + " test_msg.single_string_wrapper][0]", + .expected_result_type = AstType(ast_internal::DynamicType()), })); class StrictNullAssignmentTest diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 4c4900058..19d59daec 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -261,28 +261,30 @@ bool TypeInferenceContext::IsAssignableInternal( prospective_substitutions); } - // Maybe widen a prospective type binding if it is a member of a union type. - // This enables things like `true ? 1 : single_int64_wrapper` to promote - // the left hand side of the ternary to an int wrapper. - // This is a bit restricted to encourage more specific type -> type var - // assignments. + // Maybe widen a prospective type binding if another potential binding is + // more general and admits the previous binding. if ( // Checking assignability to a specific type var // that has a prospective type assignment. to.kind() == TypeKind::kTypeParam && - prospective_substitutions.contains(to.AsTypeParam()->name()) && - // from is a more general type that to and accepts the current - // prospective binding for to. - IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) { - prospective_substitutions[to.AsTypeParam()->name()] = from_subs; - return true; + prospective_substitutions.contains(to.AsTypeParam()->name())) { + auto prospective_subs_cpy(prospective_substitutions); + if (CompareGenerality(from_subs, to_subs, prospective_subs_cpy) == + RelativeGenerality::kMoreGeneral) { + if (IsAssignableInternal(to_subs, from_subs, prospective_subs_cpy) && + !OccursWithin(to.name(), from_subs, prospective_subs_cpy)) { + prospective_subs_cpy[to.AsTypeParam()->name()] = from_subs; + prospective_substitutions = prospective_subs_cpy; + return true; + // otherwise, continue with normal assignability check. + } + } } // Type is as concrete as it can be under current substitutions. if (absl::optional wrapped_type = WrapperToPrimitive(to_subs); wrapped_type.has_value()) { - return IsAssignableInternal(NullType(), from_subs, - prospective_substitutions) || + return from_subs.IsNull() || IsAssignableInternal(*wrapped_type, from_subs, prospective_substitutions); } @@ -364,6 +366,81 @@ Type TypeInferenceContext::Substitute( return subs; } +TypeInferenceContext::RelativeGenerality +TypeInferenceContext::CompareGenerality( + const Type& from, const Type& to, + const SubstitutionMap& prospective_substitutions) const { + Type from_subs = Substitute(from, prospective_substitutions); + Type to_subs = Substitute(to, prospective_substitutions); + + if (from_subs == to_subs) { + return RelativeGenerality::kEquivalent; + } + + if (IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) { + return RelativeGenerality::kMoreGeneral; + } + + if (IsUnionType(to_subs)) { + return RelativeGenerality::kLessGeneral; + } + + if (enable_legacy_null_assignment_ && IsLegacyNullable(from_subs) && + to_subs.IsNull()) { + return RelativeGenerality::kMoreGeneral; + } + + // Not a polytype. Check if it is a parameterized type and all parameters are + // equivalent and at least one is more general. + if (from_subs.IsList() && to_subs.IsList()) { + return CompareGenerality(from_subs.AsList()->GetElement(), + to_subs.AsList()->GetElement(), + prospective_substitutions); + } + + if (from_subs.IsMap() && to_subs.IsMap()) { + RelativeGenerality key_generality = + CompareGenerality(from_subs.AsMap()->GetKey(), + to_subs.AsMap()->GetKey(), prospective_substitutions); + RelativeGenerality value_generality = CompareGenerality( + from_subs.AsMap()->GetValue(), to_subs.AsMap()->GetValue(), + prospective_substitutions); + if (key_generality == RelativeGenerality::kLessGeneral || + value_generality == RelativeGenerality::kLessGeneral) { + return RelativeGenerality::kLessGeneral; + } + if (key_generality == RelativeGenerality::kMoreGeneral || + value_generality == RelativeGenerality::kMoreGeneral) { + return RelativeGenerality::kMoreGeneral; + } + return RelativeGenerality::kEquivalent; + } + + if (from_subs.IsOpaque() && to_subs.IsOpaque() && + from_subs.AsOpaque()->name() == to_subs.AsOpaque()->name() && + from_subs.AsOpaque()->GetParameters().size() == + to_subs.AsOpaque()->GetParameters().size()) { + RelativeGenerality max_generality = RelativeGenerality::kEquivalent; + for (int i = 0; i < from_subs.AsOpaque()->GetParameters().size(); ++i) { + RelativeGenerality generality = CompareGenerality( + from_subs.AsOpaque()->GetParameters()[i], + to_subs.AsOpaque()->GetParameters()[i], prospective_substitutions); + if (generality == RelativeGenerality::kLessGeneral) { + return RelativeGenerality::kLessGeneral; + } + if (generality == RelativeGenerality::kMoreGeneral) { + max_generality = RelativeGenerality::kMoreGeneral; + } + } + return max_generality; + } + + // Default not comparable. Since we ruled out polytypes, they should be + // equivalent for the purposes of deciding the most general eligible + // substitution. + return RelativeGenerality::kEquivalent; +} + bool TypeInferenceContext::OccursWithin( absl::string_view var_name, const Type& type, const SubstitutionMap& substitutions) const { diff --git a/checker/internal/type_inference_context.h b/checker/internal/type_inference_context.h index 3b1939d2b..898af657f 100644 --- a/checker/internal/type_inference_context.h +++ b/checker/internal/type_inference_context.h @@ -160,6 +160,15 @@ class TypeInferenceContext { absl::string_view name; }; + // Relative generality between two types. + enum class RelativeGenerality { + kMoreGeneral, + // Note: kLessGeneral does not imply it is definitely more specific, only + // that we cannot determine if equivalent or more general. + kLessGeneral, + kEquivalent, + }; + absl::string_view NewTypeVar(absl::string_view name = "") { next_type_parameter_id_++; auto inserted = type_parameter_bindings_.insert( @@ -190,6 +199,16 @@ class TypeInferenceContext { bool IsAssignableWithConstraints(const Type& from, const Type& to, SubstitutionMap& prospective_substitutions); + // Relative generality of `from` as compared to `to` with the current type + // substitutions and any additional prospective substitutions. + // + // Generality is only defined as a partial ordering. Some types are + // incomparable. However we only need to know if a type is definitely more + // general or not. + RelativeGenerality CompareGenerality( + const Type& from, const Type& to, + const SubstitutionMap& prospective_substitutions) const; + Type Substitute(const Type& type, const SubstitutionMap& substitutions) const; bool OccursWithin(absl::string_view var_name, const Type& type, diff --git a/checker/internal/type_inference_context_test.cc b/checker/internal/type_inference_context_test.cc index bc9513574..93543c82d 100644 --- a/checker/internal/type_inference_context_test.cc +++ b/checker/internal/type_inference_context_test.cc @@ -737,6 +737,66 @@ TEST(TypeInferenceContextTest, AssignabilityContext) { IsTypeKind(TypeKind::kIntWrapper)); } +TEST(TypeInferenceContextTest, AssignabilityContextAbstractType) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, IntType()), + list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, DynType()), + list_of_a_instance.AsList()->GetElement())); + + assignability_context.UpdateInferredTypeAssignments(); + } + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + ASSERT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kOpaque)); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(), + "optional_type"); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(), + ElementsAre(IsTypeKind(TypeKind::kDyn))); +} + +TEST(TypeInferenceContextTest, AssignabilityContextAbstractTypeWrapper) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, IntType()), + list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, IntWrapperType()), + list_of_a_instance.AsList()->GetElement())); + + assignability_context.UpdateInferredTypeAssignments(); + } + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + ASSERT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kOpaque)); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(), + "optional_type"); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(), + ElementsAre(IsTypeKind(TypeKind::kIntWrapper))); +} + TEST(TypeInferenceContextTest, AssignabilityContextNotApplied) { google::protobuf::Arena arena; TypeInferenceContext context(&arena); diff --git a/checker/optional_test.cc b/checker/optional_test.cc index ae4383883..126225668 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -227,6 +227,18 @@ INSTANTIATE_TEST_SUITE_P( new AstType(ast_internal::PrimitiveType::kString)))))}, TestCase{"['v1', ?'v2']", _, "expected type 'optional_type' but found 'string'"}, + TestCase{"[optional.of(dyn('1')), optional.of('2')][0]", + IsOptionalType(AstType(ast_internal::DynamicType()))}, + TestCase{"[optional.of('1'), optional.of(dyn('2'))][0]", + IsOptionalType(AstType(ast_internal::DynamicType()))}, + TestCase{"[{1: optional.of(1)}, {1: optional.of(dyn(1))}][0][1]", + IsOptionalType(AstType(ast_internal::DynamicType()))}, + TestCase{"[{1: optional.of(dyn(1))}, {1: optional.of(1)}][0][1]", + IsOptionalType(AstType(ast_internal::DynamicType()))}, + TestCase{"[optional.of('1'), optional.of(2)][0]", + Eq(AstType(ast_internal::DynamicType()))}, + TestCase{"['v1', ?'v2']", _, + "expected type 'optional_type' but found 'string'"}, TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_int64: " "optional.of(1)}", Eq(AstType(ast_internal::MessageType(