Skip to content

Commit

Permalink
Update assignability checks for lists and maps to consider all elemen…
Browse files Browse the repository at this point in the history
…ts before accepting

new inferred types.

PiperOrigin-RevId: 693403660
  • Loading branch information
jnthntatum authored and copybara-github committed Nov 5, 2024
1 parent 834c7fd commit e8fdff4
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 17 deletions.
1 change: 0 additions & 1 deletion checker/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ cc_test(
"//common:expr",
"//common:source",
"//common:type",
"//extensions/protobuf:value",
"//internal:status_macros",
"//internal:testing",
"//internal:testing_descriptor_pool",
Expand Down
21 changes: 19 additions & 2 deletions checker/internal/type_checker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) {
Type overall_value_type =
inference_context_->InstantiateTypeParams(TypeParamType("V"));

auto assignability_context = inference_context_->CreateAssignabilityContext();
for (const auto& entry : map.entries()) {
const Expr* key = &entry.key();
Type key_type = GetTypeOrDyn(key);
Expand All @@ -593,10 +594,17 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) {
inference_context_->FinalizeType(key_type).DebugString())));
}

if (!inference_context_->IsAssignable(key_type, overall_key_type)) {
if (!assignability_context.IsAssignable(key_type, overall_key_type)) {
overall_key_type = DynType();
}
}

if (!overall_key_type.IsDyn()) {
assignability_context.UpdateInferredTypeAssignments();
}

assignability_context.Reset();
for (const auto& entry : map.entries()) {
const Expr* value = &entry.value();
Type value_type = GetTypeOrDyn(value);
if (entry.optional()) {
Expand All @@ -613,6 +621,10 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) {
}
}

if (!overall_value_type.IsDyn()) {
assignability_context.UpdateInferredTypeAssignments();
}

types_[&expr] = inference_context_->FullySubstitute(
MapType(arena_, overall_key_type, overall_value_type));
}
Expand All @@ -622,6 +634,7 @@ void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) {

Type overall_elem_type =
inference_context_->InstantiateTypeParams(TypeParamType("E"));
auto assignability_context = inference_context_->CreateAssignabilityContext();
for (const auto& element : list.elements()) {
const Expr* value = &element.expr();
Type value_type = GetTypeOrDyn(value);
Expand All @@ -635,11 +648,15 @@ void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) {
}
}

if (!inference_context_->IsAssignable(value_type, overall_elem_type)) {
if (!assignability_context.IsAssignable(value_type, overall_elem_type)) {
overall_elem_type = DynType();
}
}

if (!overall_elem_type.IsDyn()) {
assignability_context.UpdateInferredTypeAssignments();
}

types_[&expr] =
inference_context_->FullySubstitute(ListType(arena_, overall_elem_type));
}
Expand Down
29 changes: 29 additions & 0 deletions checker/internal/type_checker_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2017,6 +2017,35 @@ INSTANTIATE_TEST_SUITE_P(
.expected_result_type = AstType(ast_internal::DynamicType()),
}));

INSTANTIATE_TEST_SUITE_P(
TypeInferences, GenericMessagesTest,
::testing::Values(
CheckedExprTestCase{
.expr = "[1, test_msg.single_int64_wrapper]",
.expected_result_type = AstType(ast_internal::ListType(
std::make_unique<AstType>(ast_internal::PrimitiveTypeWrapper(
ast_internal::PrimitiveType::kInt64))))},
CheckedExprTestCase{
.expr = "[1, 2, test_msg.single_int64_wrapper]",
.expected_result_type = AstType(ast_internal::ListType(
std::make_unique<AstType>(ast_internal::PrimitiveTypeWrapper(
ast_internal::PrimitiveType::kInt64))))},
CheckedExprTestCase{
.expr = "[test_msg.single_int64_wrapper, 1]",
.expected_result_type = AstType(ast_internal::ListType(
std::make_unique<AstType>(ast_internal::PrimitiveTypeWrapper(
ast_internal::PrimitiveType::kInt64))))},
CheckedExprTestCase{
.expr = "[1, 2, test_msg.single_int64_wrapper, dyn(1)]",
.expected_result_type = AstType(ast_internal::ListType(
std::make_unique<AstType>(ast_internal::DynamicType())))},

CheckedExprTestCase{
.expr = "test_msg.single_int64",
.expected_result_type =
AstType(ast_internal::PrimitiveType::kInt64),
}));

class StrictNullAssignmentTest
: public testing::TestWithParam<CheckedExprTestCase> {};

Expand Down
34 changes: 25 additions & 9 deletions checker/internal/type_inference_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,15 +261,6 @@ bool TypeInferenceContext::IsAssignableInternal(
prospective_substitutions);
}

// Type is as concrete as it can be under current substitutions.
if (absl::optional<Type> wrapped_type = WrapperToPrimitive(to_subs);
wrapped_type.has_value()) {
return IsAssignableInternal(NullType(), from_subs,
prospective_substitutions) ||
IsAssignableInternal(*wrapped_type, from_subs,
prospective_substitutions);
}

// Maybe widen a prospective type binding if it is a member of a union type.
// This enables things like `true ? 1 : single_int64_wrapper` to promote
// the left hand side of the ternary to an int wrapper.
Expand All @@ -287,6 +278,15 @@ bool TypeInferenceContext::IsAssignableInternal(
return true;
}

// Type is as concrete as it can be under current substitutions.
if (absl::optional<Type> wrapped_type = WrapperToPrimitive(to_subs);
wrapped_type.has_value()) {
return IsAssignableInternal(NullType(), from_subs,
prospective_substitutions) ||
IsAssignableInternal(*wrapped_type, from_subs,
prospective_substitutions);
}

// Wrapper types are assignable to their corresponding primitive type (
// somewhat similar to auto unboxing). This is a bit odd with CEL's null_type,
// but there isn't a dedicated syntax for narrowing from the nullable.
Expand Down Expand Up @@ -538,4 +538,20 @@ Type TypeInferenceContext::FullySubstitute(const Type& type,
}
}

bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from,
const Type& to) {
return inference_context_.IsAssignableInternal(from, to,
prospective_substitutions_);
}

void TypeInferenceContext::AssignabilityContext::
UpdateInferredTypeAssignments() {
inference_context_.UpdateTypeParameterBindings(
std::move(prospective_substitutions_));
}

void TypeInferenceContext::AssignabilityContext::Reset() {
prospective_substitutions_.clear();
}

} // namespace cel::checker_internal
62 changes: 57 additions & 5 deletions checker/internal/type_inference_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,68 @@ class TypeInferenceContext {
std::vector<OverloadDecl> overloads;
};

private:
// Alias for a map from type var name to the type it is bound to.
//
// Used for prospective substitutions during type inference to make progress
// without affecting final assigned types.
using SubstitutionMap = absl::flat_hash_map<absl::string_view, Type>;

public:
// Helper class for managing several dependent type assignability checks.
//
// Note: while allowed, updating multiple AssignabilityContexts concurrently
// can lead to inconsistencies in the final type bindings.
class AssignabilityContext {
public:
// Checks if `from` is assignable to `to` with the current type
// substitutions and any additional prospective substitutions in the parent
// inference context.
bool IsAssignable(const Type& from, const Type& to);

// Applies any prospective type assignments to the parent inference context.
//
// This should only be called after all assignability checks have completed.
//
// Leaves the AssignabilityContext in the starting state (i.e. no
// prospective substitutions).
void UpdateInferredTypeAssignments();

// Return the AssignabilityContext to the starting state (i.e. no
// prospective substitutions).
void Reset();

private:
explicit AssignabilityContext(TypeInferenceContext& inference_context)
: inference_context_(inference_context) {}

AssignabilityContext(const AssignabilityContext&) = delete;
AssignabilityContext& operator=(const AssignabilityContext&) = delete;
AssignabilityContext(AssignabilityContext&&) = delete;
AssignabilityContext& operator=(AssignabilityContext&&) = delete;

friend class TypeInferenceContext;

TypeInferenceContext& inference_context_;
SubstitutionMap prospective_substitutions_;
};

explicit TypeInferenceContext(google::protobuf::Arena* arena,
bool enable_legacy_null_assignment = true)
: arena_(arena),
enable_legacy_null_assignment_(enable_legacy_null_assignment) {}

// Creates a new AssignabilityContext for the current inference context.
//
// This is intended for managing several dependent type assignability checks
// that should only be added to the final type bindings if all checks succeed.
//
// Note: while allowed, updating multiple AssignabilityContexts concurrently
// can lead to inconsistencies in the final type bindings.
AssignabilityContext CreateAssignabilityContext()
ABSL_ATTRIBUTE_LIFETIME_BOUND {
return AssignabilityContext(*this);
}
// Resolves any remaining type parameters in the given type to a concrete
// type or dyn.
Type FinalizeType(const Type& type) const {
Expand Down Expand Up @@ -98,11 +155,6 @@ class TypeInferenceContext {
}

private:
// Alias for a map from type var name to the type it is bound to.
//
// Used for prospective substitutions during type inference.
using SubstitutionMap = absl::flat_hash_map<absl::string_view, Type>;

struct TypeVar {
absl::optional<Type> type;
absl::string_view name;
Expand Down
75 changes: 75 additions & 0 deletions checker/internal/type_inference_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -711,5 +711,80 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithInferredTypeType) {
ElementsAre(IsTypeKind(TypeKind::kInt)));
}

TEST(TypeInferenceContextTest, AssignabilityContext) {
google::protobuf::Arena arena;
TypeInferenceContext context(&arena);

Type list_of_a = ListType(&arena, TypeParamType("A"));

Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

{
auto assignability_context = context.CreateAssignabilityContext();
EXPECT_TRUE(assignability_context.IsAssignable(
IntType(), list_of_a_instance.AsList()->GetElement()));
EXPECT_TRUE(assignability_context.IsAssignable(
IntType(), list_of_a_instance.AsList()->GetElement()));
EXPECT_TRUE(assignability_context.IsAssignable(
IntWrapperType(), list_of_a_instance.AsList()->GetElement()));

assignability_context.UpdateInferredTypeAssignments();
}
Type resolved_type = context.FinalizeType(list_of_a_instance);

ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
EXPECT_THAT(resolved_type.AsList()->GetElement(),
IsTypeKind(TypeKind::kIntWrapper));
}

TEST(TypeInferenceContextTest, AssignabilityContextNotApplied) {
google::protobuf::Arena arena;
TypeInferenceContext context(&arena);

Type list_of_a = ListType(&arena, TypeParamType("A"));

Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

{
auto assignability_context = context.CreateAssignabilityContext();
EXPECT_TRUE(assignability_context.IsAssignable(
IntType(), list_of_a_instance.AsList()->GetElement()));
EXPECT_TRUE(assignability_context.IsAssignable(
IntType(), list_of_a_instance.AsList()->GetElement()));
EXPECT_TRUE(assignability_context.IsAssignable(
IntWrapperType(), list_of_a_instance.AsList()->GetElement()));
}

Type resolved_type = context.FinalizeType(list_of_a_instance);

ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
EXPECT_THAT(resolved_type.AsList()->GetElement(), IsTypeKind(TypeKind::kDyn));
}

TEST(TypeInferenceContextTest, AssignabilityContextReset) {
google::protobuf::Arena arena;
TypeInferenceContext context(&arena);

Type list_of_a = ListType(&arena, TypeParamType("A"));

Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

{
auto assignability_context = context.CreateAssignabilityContext();
EXPECT_TRUE(assignability_context.IsAssignable(
IntType(), list_of_a_instance.AsList()->GetElement()));
assignability_context.Reset();
EXPECT_TRUE(assignability_context.IsAssignable(
DoubleType(), list_of_a_instance.AsList()->GetElement()));
assignability_context.UpdateInferredTypeAssignments();
}

Type resolved_type = context.FinalizeType(list_of_a_instance);

ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
EXPECT_THAT(resolved_type.AsList()->GetElement(),
IsTypeKind(TypeKind::kDouble));
}

} // namespace
} // namespace cel::checker_internal

0 comments on commit e8fdff4

Please sign in to comment.