Skip to content

Commit

Permalink
Update type assignment widening behavior to more closely follow the '…
Browse files Browse the repository at this point in the history
…MoreGeneral' check in the Go and Java implementations.

PiperOrigin-RevId: 693412150
  • Loading branch information
jnthntatum authored and copybara-github committed Nov 5, 2024
1 parent e8fdff4 commit 9310c49
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 16 deletions.
1 change: 1 addition & 0 deletions checker/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ cc_test(
"//internal:status_macros",
"//internal:testing",
"//internal:testing_descriptor_pool",
"//testutil:baseline_tests",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down
64 changes: 61 additions & 3 deletions checker/internal/type_checker_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "testutil/baseline_tests.h"
#include "cel/expr/conformance/proto2/test_all_types.pb.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/arena.h"
Expand Down Expand Up @@ -221,11 +222,18 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull<google::protobuf::Arena*> are

FunctionDecl ternary_op;
ternary_op.set_name("_?_:_");
CEL_RETURN_IF_ERROR(eq_op.AddOverload(MakeOverloadDecl(
CEL_RETURN_IF_ERROR(ternary_op.AddOverload(MakeOverloadDecl(
"conditional",
/*return_type=*/
TypeParamType("A"), BoolType{}, TypeParamType("A"), TypeParamType("A"))));

FunctionDecl index_op;
index_op.set_name("_[_]");
CEL_RETURN_IF_ERROR(index_op.AddOverload(MakeOverloadDecl(
"index",
/*return_type=*/
TypeParamType("A"), ListType(arena, TypeParamType("A")), IntType())));

FunctionDecl to_int;
to_int.set_name("int");
CEL_RETURN_IF_ERROR(to_int.AddOverload(
Expand Down Expand Up @@ -268,6 +276,7 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull<google::protobuf::Arena*> are
env.InsertFunctionIfAbsent(std::move(to_int));
env.InsertFunctionIfAbsent(std::move(eq_op));
env.InsertFunctionIfAbsent(std::move(ternary_op));
env.InsertFunctionIfAbsent(std::move(index_op));
env.InsertFunctionIfAbsent(std::move(to_dyn));
env.InsertFunctionIfAbsent(std::move(to_type));
env.InsertFunctionIfAbsent(std::move(to_duration));
Expand Down Expand Up @@ -1543,7 +1552,8 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) {
const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
EXPECT_THAT(ast_impl.type_map(),
Contains(Pair(ast_impl.root_expr().id(),
Eq(test_case.expected_result_type))));
Eq(test_case.expected_result_type))))
<< cel::test::FormatBaselineAst(*checked_ast);
}

INSTANTIATE_TEST_SUITE_P(
Expand Down Expand Up @@ -2039,11 +2049,59 @@ INSTANTIATE_TEST_SUITE_P(
.expr = "[1, 2, test_msg.single_int64_wrapper, dyn(1)]",
.expected_result_type = AstType(ast_internal::ListType(
std::make_unique<AstType>(ast_internal::DynamicType())))},

CheckedExprTestCase{
.expr = "[null, test_msg][0]",
.expected_result_type = AstType(ast_internal::MessageType(
"cel.expr.conformance.proto3.TestAllTypes"))},
CheckedExprTestCase{
.expr = "[{'k': dyn(1)}, {dyn('k'): 1}][0]",
// Ambiguous type resolution, but we prefer the first option.
.expected_result_type = AstType(ast_internal::MapType(
std::make_unique<AstType>(ast_internal::PrimitiveType::kString),
std::make_unique<AstType>(ast_internal::DynamicType())))},
CheckedExprTestCase{
.expr = "[{'k': 1}, {dyn('k'): 1}][0]",
.expected_result_type = AstType(ast_internal::MapType(
std::make_unique<AstType>(ast_internal::DynamicType()),
std::make_unique<AstType>(
ast_internal::PrimitiveType::kInt64)))},
CheckedExprTestCase{
.expr = "[{dyn('k'): 1}, {'k': 1}][0]",
.expected_result_type = AstType(ast_internal::MapType(
std::make_unique<AstType>(ast_internal::DynamicType()),
std::make_unique<AstType>(
ast_internal::PrimitiveType::kInt64)))},
CheckedExprTestCase{
.expr = "[{'k': 1}, {'k': dyn(1)}][0]",
.expected_result_type = AstType(ast_internal::MapType(
std::make_unique<AstType>(ast_internal::PrimitiveType::kString),
std::make_unique<AstType>(ast_internal::DynamicType())))},
CheckedExprTestCase{
.expr = "[{'k': 1}, {dyn('k'): dyn(1)}][0]",
.expected_result_type = AstType(ast_internal::MapType(
std::make_unique<AstType>(ast_internal::DynamicType()),
std::make_unique<AstType>(ast_internal::DynamicType())))},
CheckedExprTestCase{
.expr =
"[{'k': 1.0}, {dyn('k'): test_msg.single_int64_wrapper}][0]",
.expected_result_type = AstType(ast_internal::DynamicType())},
CheckedExprTestCase{
.expr = "test_msg.single_int64",
.expected_result_type =
AstType(ast_internal::PrimitiveType::kInt64),
},
CheckedExprTestCase{
.expr = "[[1], {1: 2u}][0]",
.expected_result_type = AstType(ast_internal::DynamicType()),
},
CheckedExprTestCase{
.expr = "[{1: 2u}, [1]][0]",
.expected_result_type = AstType(ast_internal::DynamicType()),
},
CheckedExprTestCase{
.expr = "[test_msg.single_int64_wrapper,"
" test_msg.single_string_wrapper][0]",
.expected_result_type = AstType(ast_internal::DynamicType()),
}));

class StrictNullAssignmentTest
Expand Down
103 changes: 90 additions & 13 deletions checker/internal/type_inference_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,28 +261,30 @@ bool TypeInferenceContext::IsAssignableInternal(
prospective_substitutions);
}

// Maybe widen a prospective type binding if it is a member of a union type.
// This enables things like `true ? 1 : single_int64_wrapper` to promote
// the left hand side of the ternary to an int wrapper.
// This is a bit restricted to encourage more specific type -> type var
// assignments.
// Maybe widen a prospective type binding if another potential binding is
// more general and admits the previous binding.
if (
// Checking assignability to a specific type var
// that has a prospective type assignment.
to.kind() == TypeKind::kTypeParam &&
prospective_substitutions.contains(to.AsTypeParam()->name()) &&
// from is a more general type that to and accepts the current
// prospective binding for to.
IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) {
prospective_substitutions[to.AsTypeParam()->name()] = from_subs;
return true;
prospective_substitutions.contains(to.AsTypeParam()->name())) {
auto prospective_subs_cpy(prospective_substitutions);
if (CompareGenerality(from_subs, to_subs, prospective_subs_cpy) ==
RelativeGenerality::kMoreGeneral) {
if (IsAssignableInternal(to_subs, from_subs, prospective_subs_cpy) &&
!OccursWithin(to.name(), from_subs, prospective_subs_cpy)) {
prospective_subs_cpy[to.AsTypeParam()->name()] = from_subs;
prospective_substitutions = prospective_subs_cpy;
return true;
// otherwise, continue with normal assignability check.
}
}
}

// Type is as concrete as it can be under current substitutions.
if (absl::optional<Type> wrapped_type = WrapperToPrimitive(to_subs);
wrapped_type.has_value()) {
return IsAssignableInternal(NullType(), from_subs,
prospective_substitutions) ||
return from_subs.IsNull() ||
IsAssignableInternal(*wrapped_type, from_subs,
prospective_substitutions);
}
Expand Down Expand Up @@ -364,6 +366,81 @@ Type TypeInferenceContext::Substitute(
return subs;
}

TypeInferenceContext::RelativeGenerality
TypeInferenceContext::CompareGenerality(
const Type& from, const Type& to,
const SubstitutionMap& prospective_substitutions) const {
Type from_subs = Substitute(from, prospective_substitutions);
Type to_subs = Substitute(to, prospective_substitutions);

if (from_subs == to_subs) {
return RelativeGenerality::kEquivalent;
}

if (IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) {
return RelativeGenerality::kMoreGeneral;
}

if (IsUnionType(to_subs)) {
return RelativeGenerality::kLessGeneral;
}

if (enable_legacy_null_assignment_ && IsLegacyNullable(from_subs) &&
to_subs.IsNull()) {
return RelativeGenerality::kMoreGeneral;
}

// Not a polytype. Check if it is a parameterized type and all parameters are
// equivalent and at least one is more general.
if (from_subs.IsList() && to_subs.IsList()) {
return CompareGenerality(from_subs.AsList()->GetElement(),
to_subs.AsList()->GetElement(),
prospective_substitutions);
}

if (from_subs.IsMap() && to_subs.IsMap()) {
RelativeGenerality key_generality =
CompareGenerality(from_subs.AsMap()->GetKey(),
to_subs.AsMap()->GetKey(), prospective_substitutions);
RelativeGenerality value_generality = CompareGenerality(
from_subs.AsMap()->GetValue(), to_subs.AsMap()->GetValue(),
prospective_substitutions);
if (key_generality == RelativeGenerality::kLessGeneral ||
value_generality == RelativeGenerality::kLessGeneral) {
return RelativeGenerality::kLessGeneral;
}
if (key_generality == RelativeGenerality::kMoreGeneral ||
value_generality == RelativeGenerality::kMoreGeneral) {
return RelativeGenerality::kMoreGeneral;
}
return RelativeGenerality::kEquivalent;
}

if (from_subs.IsOpaque() && to_subs.IsOpaque() &&
from_subs.AsOpaque()->name() == to_subs.AsOpaque()->name() &&
from_subs.AsOpaque()->GetParameters().size() ==
to_subs.AsOpaque()->GetParameters().size()) {
RelativeGenerality max_generality = RelativeGenerality::kEquivalent;
for (int i = 0; i < from_subs.AsOpaque()->GetParameters().size(); ++i) {
RelativeGenerality generality = CompareGenerality(
from_subs.AsOpaque()->GetParameters()[i],
to_subs.AsOpaque()->GetParameters()[i], prospective_substitutions);
if (generality == RelativeGenerality::kLessGeneral) {
return RelativeGenerality::kLessGeneral;
}
if (generality == RelativeGenerality::kMoreGeneral) {
max_generality = RelativeGenerality::kMoreGeneral;
}
}
return max_generality;
}

// Default not comparable. Since we ruled out polytypes, they should be
// equivalent for the purposes of deciding the most general eligible
// substitution.
return RelativeGenerality::kEquivalent;
}

bool TypeInferenceContext::OccursWithin(
absl::string_view var_name, const Type& type,
const SubstitutionMap& substitutions) const {
Expand Down
19 changes: 19 additions & 0 deletions checker/internal/type_inference_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ class TypeInferenceContext {
absl::string_view name;
};

// Relative generality between two types.
enum class RelativeGenerality {
kMoreGeneral,
// Note: kLessGeneral does not imply it is definitely more specific, only
// that we cannot determine if equivalent or more general.
kLessGeneral,
kEquivalent,
};

absl::string_view NewTypeVar(absl::string_view name = "") {
next_type_parameter_id_++;
auto inserted = type_parameter_bindings_.insert(
Expand Down Expand Up @@ -190,6 +199,16 @@ class TypeInferenceContext {
bool IsAssignableWithConstraints(const Type& from, const Type& to,
SubstitutionMap& prospective_substitutions);

// Relative generality of `from` as compared to `to` with the current type
// substitutions and any additional prospective substitutions.
//
// Generality is only defined as a partial ordering. Some types are
// incomparable. However we only need to know if a type is definitely more
// general or not.
RelativeGenerality CompareGenerality(
const Type& from, const Type& to,
const SubstitutionMap& prospective_substitutions) const;

Type Substitute(const Type& type, const SubstitutionMap& substitutions) const;

bool OccursWithin(absl::string_view var_name, const Type& type,
Expand Down
60 changes: 60 additions & 0 deletions checker/internal/type_inference_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,66 @@ TEST(TypeInferenceContextTest, AssignabilityContext) {
IsTypeKind(TypeKind::kIntWrapper));
}

TEST(TypeInferenceContextTest, AssignabilityContextAbstractType) {
google::protobuf::Arena arena;
TypeInferenceContext context(&arena);

Type list_of_a = ListType(&arena, TypeParamType("A"));

Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

{
auto assignability_context = context.CreateAssignabilityContext();
EXPECT_TRUE(assignability_context.IsAssignable(
OptionalType(&arena, IntType()),
list_of_a_instance.AsList()->GetElement()));
EXPECT_TRUE(assignability_context.IsAssignable(
OptionalType(&arena, DynType()),
list_of_a_instance.AsList()->GetElement()));

assignability_context.UpdateInferredTypeAssignments();
}
Type resolved_type = context.FinalizeType(list_of_a_instance);

ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
ASSERT_THAT(resolved_type.AsList()->GetElement(),
IsTypeKind(TypeKind::kOpaque));
EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(),
"optional_type");
EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(),
ElementsAre(IsTypeKind(TypeKind::kDyn)));
}

TEST(TypeInferenceContextTest, AssignabilityContextAbstractTypeWrapper) {
google::protobuf::Arena arena;
TypeInferenceContext context(&arena);

Type list_of_a = ListType(&arena, TypeParamType("A"));

Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

{
auto assignability_context = context.CreateAssignabilityContext();
EXPECT_TRUE(assignability_context.IsAssignable(
OptionalType(&arena, IntType()),
list_of_a_instance.AsList()->GetElement()));
EXPECT_TRUE(assignability_context.IsAssignable(
OptionalType(&arena, IntWrapperType()),
list_of_a_instance.AsList()->GetElement()));

assignability_context.UpdateInferredTypeAssignments();
}
Type resolved_type = context.FinalizeType(list_of_a_instance);

ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
ASSERT_THAT(resolved_type.AsList()->GetElement(),
IsTypeKind(TypeKind::kOpaque));
EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(),
"optional_type");
EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(),
ElementsAre(IsTypeKind(TypeKind::kIntWrapper)));
}

TEST(TypeInferenceContextTest, AssignabilityContextNotApplied) {
google::protobuf::Arena arena;
TypeInferenceContext context(&arena);
Expand Down
12 changes: 12 additions & 0 deletions checker/optional_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,18 @@ INSTANTIATE_TEST_SUITE_P(
new AstType(ast_internal::PrimitiveType::kString)))))},
TestCase{"['v1', ?'v2']", _,
"expected type 'optional_type<string>' but found 'string'"},
TestCase{"[optional.of(dyn('1')), optional.of('2')][0]",
IsOptionalType(AstType(ast_internal::DynamicType()))},
TestCase{"[optional.of('1'), optional.of(dyn('2'))][0]",
IsOptionalType(AstType(ast_internal::DynamicType()))},
TestCase{"[{1: optional.of(1)}, {1: optional.of(dyn(1))}][0][1]",
IsOptionalType(AstType(ast_internal::DynamicType()))},
TestCase{"[{1: optional.of(dyn(1))}, {1: optional.of(1)}][0][1]",
IsOptionalType(AstType(ast_internal::DynamicType()))},
TestCase{"[optional.of('1'), optional.of(2)][0]",
Eq(AstType(ast_internal::DynamicType()))},
TestCase{"['v1', ?'v2']", _,
"expected type 'optional_type<string>' but found 'string'"},
TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_int64: "
"optional.of(1)}",
Eq(AstType(ast_internal::MessageType(
Expand Down

0 comments on commit 9310c49

Please sign in to comment.