Skip to content

Commit 9310c49

Browse files
jnthntatumcopybara-github
authored andcommitted
Update type assignment widening behavior to more closely follow the 'MoreGeneral' check in the Go and Java implementations.
PiperOrigin-RevId: 693412150
1 parent e8fdff4 commit 9310c49

File tree

6 files changed

+243
-16
lines changed

6 files changed

+243
-16
lines changed

checker/internal/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ cc_test(
165165
"//internal:status_macros",
166166
"//internal:testing",
167167
"//internal:testing_descriptor_pool",
168+
"//testutil:baseline_tests",
168169
"@com_google_absl//absl/base:no_destructor",
169170
"@com_google_absl//absl/base:nullability",
170171
"@com_google_absl//absl/container:flat_hash_set",

checker/internal/type_checker_impl_test.cc

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "internal/status_macros.h"
4646
#include "internal/testing.h"
4747
#include "internal/testing_descriptor_pool.h"
48+
#include "testutil/baseline_tests.h"
4849
#include "cel/expr/conformance/proto2/test_all_types.pb.h"
4950
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
5051
#include "google/protobuf/arena.h"
@@ -221,11 +222,18 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull<google::protobuf::Arena*> are
221222

222223
FunctionDecl ternary_op;
223224
ternary_op.set_name("_?_:_");
224-
CEL_RETURN_IF_ERROR(eq_op.AddOverload(MakeOverloadDecl(
225+
CEL_RETURN_IF_ERROR(ternary_op.AddOverload(MakeOverloadDecl(
225226
"conditional",
226227
/*return_type=*/
227228
TypeParamType("A"), BoolType{}, TypeParamType("A"), TypeParamType("A"))));
228229

230+
FunctionDecl index_op;
231+
index_op.set_name("_[_]");
232+
CEL_RETURN_IF_ERROR(index_op.AddOverload(MakeOverloadDecl(
233+
"index",
234+
/*return_type=*/
235+
TypeParamType("A"), ListType(arena, TypeParamType("A")), IntType())));
236+
229237
FunctionDecl to_int;
230238
to_int.set_name("int");
231239
CEL_RETURN_IF_ERROR(to_int.AddOverload(
@@ -268,6 +276,7 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull<google::protobuf::Arena*> are
268276
env.InsertFunctionIfAbsent(std::move(to_int));
269277
env.InsertFunctionIfAbsent(std::move(eq_op));
270278
env.InsertFunctionIfAbsent(std::move(ternary_op));
279+
env.InsertFunctionIfAbsent(std::move(index_op));
271280
env.InsertFunctionIfAbsent(std::move(to_dyn));
272281
env.InsertFunctionIfAbsent(std::move(to_type));
273282
env.InsertFunctionIfAbsent(std::move(to_duration));
@@ -1543,7 +1552,8 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) {
15431552
const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
15441553
EXPECT_THAT(ast_impl.type_map(),
15451554
Contains(Pair(ast_impl.root_expr().id(),
1546-
Eq(test_case.expected_result_type))));
1555+
Eq(test_case.expected_result_type))))
1556+
<< cel::test::FormatBaselineAst(*checked_ast);
15471557
}
15481558

15491559
INSTANTIATE_TEST_SUITE_P(
@@ -2039,11 +2049,59 @@ INSTANTIATE_TEST_SUITE_P(
20392049
.expr = "[1, 2, test_msg.single_int64_wrapper, dyn(1)]",
20402050
.expected_result_type = AstType(ast_internal::ListType(
20412051
std::make_unique<AstType>(ast_internal::DynamicType())))},
2042-
2052+
CheckedExprTestCase{
2053+
.expr = "[null, test_msg][0]",
2054+
.expected_result_type = AstType(ast_internal::MessageType(
2055+
"cel.expr.conformance.proto3.TestAllTypes"))},
2056+
CheckedExprTestCase{
2057+
.expr = "[{'k': dyn(1)}, {dyn('k'): 1}][0]",
2058+
// Ambiguous type resolution, but we prefer the first option.
2059+
.expected_result_type = AstType(ast_internal::MapType(
2060+
std::make_unique<AstType>(ast_internal::PrimitiveType::kString),
2061+
std::make_unique<AstType>(ast_internal::DynamicType())))},
2062+
CheckedExprTestCase{
2063+
.expr = "[{'k': 1}, {dyn('k'): 1}][0]",
2064+
.expected_result_type = AstType(ast_internal::MapType(
2065+
std::make_unique<AstType>(ast_internal::DynamicType()),
2066+
std::make_unique<AstType>(
2067+
ast_internal::PrimitiveType::kInt64)))},
2068+
CheckedExprTestCase{
2069+
.expr = "[{dyn('k'): 1}, {'k': 1}][0]",
2070+
.expected_result_type = AstType(ast_internal::MapType(
2071+
std::make_unique<AstType>(ast_internal::DynamicType()),
2072+
std::make_unique<AstType>(
2073+
ast_internal::PrimitiveType::kInt64)))},
2074+
CheckedExprTestCase{
2075+
.expr = "[{'k': 1}, {'k': dyn(1)}][0]",
2076+
.expected_result_type = AstType(ast_internal::MapType(
2077+
std::make_unique<AstType>(ast_internal::PrimitiveType::kString),
2078+
std::make_unique<AstType>(ast_internal::DynamicType())))},
2079+
CheckedExprTestCase{
2080+
.expr = "[{'k': 1}, {dyn('k'): dyn(1)}][0]",
2081+
.expected_result_type = AstType(ast_internal::MapType(
2082+
std::make_unique<AstType>(ast_internal::DynamicType()),
2083+
std::make_unique<AstType>(ast_internal::DynamicType())))},
2084+
CheckedExprTestCase{
2085+
.expr =
2086+
"[{'k': 1.0}, {dyn('k'): test_msg.single_int64_wrapper}][0]",
2087+
.expected_result_type = AstType(ast_internal::DynamicType())},
20432088
CheckedExprTestCase{
20442089
.expr = "test_msg.single_int64",
20452090
.expected_result_type =
20462091
AstType(ast_internal::PrimitiveType::kInt64),
2092+
},
2093+
CheckedExprTestCase{
2094+
.expr = "[[1], {1: 2u}][0]",
2095+
.expected_result_type = AstType(ast_internal::DynamicType()),
2096+
},
2097+
CheckedExprTestCase{
2098+
.expr = "[{1: 2u}, [1]][0]",
2099+
.expected_result_type = AstType(ast_internal::DynamicType()),
2100+
},
2101+
CheckedExprTestCase{
2102+
.expr = "[test_msg.single_int64_wrapper,"
2103+
" test_msg.single_string_wrapper][0]",
2104+
.expected_result_type = AstType(ast_internal::DynamicType()),
20472105
}));
20482106

20492107
class StrictNullAssignmentTest

checker/internal/type_inference_context.cc

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -261,28 +261,30 @@ bool TypeInferenceContext::IsAssignableInternal(
261261
prospective_substitutions);
262262
}
263263

264-
// Maybe widen a prospective type binding if it is a member of a union type.
265-
// This enables things like `true ? 1 : single_int64_wrapper` to promote
266-
// the left hand side of the ternary to an int wrapper.
267-
// This is a bit restricted to encourage more specific type -> type var
268-
// assignments.
264+
// Maybe widen a prospective type binding if another potential binding is
265+
// more general and admits the previous binding.
269266
if (
270267
// Checking assignability to a specific type var
271268
// that has a prospective type assignment.
272269
to.kind() == TypeKind::kTypeParam &&
273-
prospective_substitutions.contains(to.AsTypeParam()->name()) &&
274-
// from is a more general type that to and accepts the current
275-
// prospective binding for to.
276-
IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) {
277-
prospective_substitutions[to.AsTypeParam()->name()] = from_subs;
278-
return true;
270+
prospective_substitutions.contains(to.AsTypeParam()->name())) {
271+
auto prospective_subs_cpy(prospective_substitutions);
272+
if (CompareGenerality(from_subs, to_subs, prospective_subs_cpy) ==
273+
RelativeGenerality::kMoreGeneral) {
274+
if (IsAssignableInternal(to_subs, from_subs, prospective_subs_cpy) &&
275+
!OccursWithin(to.name(), from_subs, prospective_subs_cpy)) {
276+
prospective_subs_cpy[to.AsTypeParam()->name()] = from_subs;
277+
prospective_substitutions = prospective_subs_cpy;
278+
return true;
279+
// otherwise, continue with normal assignability check.
280+
}
281+
}
279282
}
280283

281284
// Type is as concrete as it can be under current substitutions.
282285
if (absl::optional<Type> wrapped_type = WrapperToPrimitive(to_subs);
283286
wrapped_type.has_value()) {
284-
return IsAssignableInternal(NullType(), from_subs,
285-
prospective_substitutions) ||
287+
return from_subs.IsNull() ||
286288
IsAssignableInternal(*wrapped_type, from_subs,
287289
prospective_substitutions);
288290
}
@@ -364,6 +366,81 @@ Type TypeInferenceContext::Substitute(
364366
return subs;
365367
}
366368

369+
TypeInferenceContext::RelativeGenerality
370+
TypeInferenceContext::CompareGenerality(
371+
const Type& from, const Type& to,
372+
const SubstitutionMap& prospective_substitutions) const {
373+
Type from_subs = Substitute(from, prospective_substitutions);
374+
Type to_subs = Substitute(to, prospective_substitutions);
375+
376+
if (from_subs == to_subs) {
377+
return RelativeGenerality::kEquivalent;
378+
}
379+
380+
if (IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) {
381+
return RelativeGenerality::kMoreGeneral;
382+
}
383+
384+
if (IsUnionType(to_subs)) {
385+
return RelativeGenerality::kLessGeneral;
386+
}
387+
388+
if (enable_legacy_null_assignment_ && IsLegacyNullable(from_subs) &&
389+
to_subs.IsNull()) {
390+
return RelativeGenerality::kMoreGeneral;
391+
}
392+
393+
// Not a polytype. Check if it is a parameterized type and all parameters are
394+
// equivalent and at least one is more general.
395+
if (from_subs.IsList() && to_subs.IsList()) {
396+
return CompareGenerality(from_subs.AsList()->GetElement(),
397+
to_subs.AsList()->GetElement(),
398+
prospective_substitutions);
399+
}
400+
401+
if (from_subs.IsMap() && to_subs.IsMap()) {
402+
RelativeGenerality key_generality =
403+
CompareGenerality(from_subs.AsMap()->GetKey(),
404+
to_subs.AsMap()->GetKey(), prospective_substitutions);
405+
RelativeGenerality value_generality = CompareGenerality(
406+
from_subs.AsMap()->GetValue(), to_subs.AsMap()->GetValue(),
407+
prospective_substitutions);
408+
if (key_generality == RelativeGenerality::kLessGeneral ||
409+
value_generality == RelativeGenerality::kLessGeneral) {
410+
return RelativeGenerality::kLessGeneral;
411+
}
412+
if (key_generality == RelativeGenerality::kMoreGeneral ||
413+
value_generality == RelativeGenerality::kMoreGeneral) {
414+
return RelativeGenerality::kMoreGeneral;
415+
}
416+
return RelativeGenerality::kEquivalent;
417+
}
418+
419+
if (from_subs.IsOpaque() && to_subs.IsOpaque() &&
420+
from_subs.AsOpaque()->name() == to_subs.AsOpaque()->name() &&
421+
from_subs.AsOpaque()->GetParameters().size() ==
422+
to_subs.AsOpaque()->GetParameters().size()) {
423+
RelativeGenerality max_generality = RelativeGenerality::kEquivalent;
424+
for (int i = 0; i < from_subs.AsOpaque()->GetParameters().size(); ++i) {
425+
RelativeGenerality generality = CompareGenerality(
426+
from_subs.AsOpaque()->GetParameters()[i],
427+
to_subs.AsOpaque()->GetParameters()[i], prospective_substitutions);
428+
if (generality == RelativeGenerality::kLessGeneral) {
429+
return RelativeGenerality::kLessGeneral;
430+
}
431+
if (generality == RelativeGenerality::kMoreGeneral) {
432+
max_generality = RelativeGenerality::kMoreGeneral;
433+
}
434+
}
435+
return max_generality;
436+
}
437+
438+
// Default not comparable. Since we ruled out polytypes, they should be
439+
// equivalent for the purposes of deciding the most general eligible
440+
// substitution.
441+
return RelativeGenerality::kEquivalent;
442+
}
443+
367444
bool TypeInferenceContext::OccursWithin(
368445
absl::string_view var_name, const Type& type,
369446
const SubstitutionMap& substitutions) const {

checker/internal/type_inference_context.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,15 @@ class TypeInferenceContext {
160160
absl::string_view name;
161161
};
162162

163+
// Relative generality between two types.
164+
enum class RelativeGenerality {
165+
kMoreGeneral,
166+
// Note: kLessGeneral does not imply it is definitely more specific, only
167+
// that we cannot determine if equivalent or more general.
168+
kLessGeneral,
169+
kEquivalent,
170+
};
171+
163172
absl::string_view NewTypeVar(absl::string_view name = "") {
164173
next_type_parameter_id_++;
165174
auto inserted = type_parameter_bindings_.insert(
@@ -190,6 +199,16 @@ class TypeInferenceContext {
190199
bool IsAssignableWithConstraints(const Type& from, const Type& to,
191200
SubstitutionMap& prospective_substitutions);
192201

202+
// Relative generality of `from` as compared to `to` with the current type
203+
// substitutions and any additional prospective substitutions.
204+
//
205+
// Generality is only defined as a partial ordering. Some types are
206+
// incomparable. However we only need to know if a type is definitely more
207+
// general or not.
208+
RelativeGenerality CompareGenerality(
209+
const Type& from, const Type& to,
210+
const SubstitutionMap& prospective_substitutions) const;
211+
193212
Type Substitute(const Type& type, const SubstitutionMap& substitutions) const;
194213

195214
bool OccursWithin(absl::string_view var_name, const Type& type,

checker/internal/type_inference_context_test.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,66 @@ TEST(TypeInferenceContextTest, AssignabilityContext) {
737737
IsTypeKind(TypeKind::kIntWrapper));
738738
}
739739

740+
TEST(TypeInferenceContextTest, AssignabilityContextAbstractType) {
741+
google::protobuf::Arena arena;
742+
TypeInferenceContext context(&arena);
743+
744+
Type list_of_a = ListType(&arena, TypeParamType("A"));
745+
746+
Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);
747+
748+
{
749+
auto assignability_context = context.CreateAssignabilityContext();
750+
EXPECT_TRUE(assignability_context.IsAssignable(
751+
OptionalType(&arena, IntType()),
752+
list_of_a_instance.AsList()->GetElement()));
753+
EXPECT_TRUE(assignability_context.IsAssignable(
754+
OptionalType(&arena, DynType()),
755+
list_of_a_instance.AsList()->GetElement()));
756+
757+
assignability_context.UpdateInferredTypeAssignments();
758+
}
759+
Type resolved_type = context.FinalizeType(list_of_a_instance);
760+
761+
ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
762+
ASSERT_THAT(resolved_type.AsList()->GetElement(),
763+
IsTypeKind(TypeKind::kOpaque));
764+
EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(),
765+
"optional_type");
766+
EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(),
767+
ElementsAre(IsTypeKind(TypeKind::kDyn)));
768+
}
769+
770+
TEST(TypeInferenceContextTest, AssignabilityContextAbstractTypeWrapper) {
771+
google::protobuf::Arena arena;
772+
TypeInferenceContext context(&arena);
773+
774+
Type list_of_a = ListType(&arena, TypeParamType("A"));
775+
776+
Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);
777+
778+
{
779+
auto assignability_context = context.CreateAssignabilityContext();
780+
EXPECT_TRUE(assignability_context.IsAssignable(
781+
OptionalType(&arena, IntType()),
782+
list_of_a_instance.AsList()->GetElement()));
783+
EXPECT_TRUE(assignability_context.IsAssignable(
784+
OptionalType(&arena, IntWrapperType()),
785+
list_of_a_instance.AsList()->GetElement()));
786+
787+
assignability_context.UpdateInferredTypeAssignments();
788+
}
789+
Type resolved_type = context.FinalizeType(list_of_a_instance);
790+
791+
ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
792+
ASSERT_THAT(resolved_type.AsList()->GetElement(),
793+
IsTypeKind(TypeKind::kOpaque));
794+
EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(),
795+
"optional_type");
796+
EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(),
797+
ElementsAre(IsTypeKind(TypeKind::kIntWrapper)));
798+
}
799+
740800
TEST(TypeInferenceContextTest, AssignabilityContextNotApplied) {
741801
google::protobuf::Arena arena;
742802
TypeInferenceContext context(&arena);

checker/optional_test.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,18 @@ INSTANTIATE_TEST_SUITE_P(
227227
new AstType(ast_internal::PrimitiveType::kString)))))},
228228
TestCase{"['v1', ?'v2']", _,
229229
"expected type 'optional_type<string>' but found 'string'"},
230+
TestCase{"[optional.of(dyn('1')), optional.of('2')][0]",
231+
IsOptionalType(AstType(ast_internal::DynamicType()))},
232+
TestCase{"[optional.of('1'), optional.of(dyn('2'))][0]",
233+
IsOptionalType(AstType(ast_internal::DynamicType()))},
234+
TestCase{"[{1: optional.of(1)}, {1: optional.of(dyn(1))}][0][1]",
235+
IsOptionalType(AstType(ast_internal::DynamicType()))},
236+
TestCase{"[{1: optional.of(dyn(1))}, {1: optional.of(1)}][0][1]",
237+
IsOptionalType(AstType(ast_internal::DynamicType()))},
238+
TestCase{"[optional.of('1'), optional.of(2)][0]",
239+
Eq(AstType(ast_internal::DynamicType()))},
240+
TestCase{"['v1', ?'v2']", _,
241+
"expected type 'optional_type<string>' but found 'string'"},
230242
TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_int64: "
231243
"optional.of(1)}",
232244
Eq(AstType(ast_internal::MessageType(

0 commit comments

Comments
 (0)