Skip to content

Commit b77b374

Browse files
committed
Improve match subject inference
1 parent e089abc commit b77b374

File tree

4 files changed

+261
-37
lines changed

4 files changed

+261
-37
lines changed

mypy/checker.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5673,7 +5673,10 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
56735673
# capture variable may depend on multiple patterns (it
56745674
# will be a union of all capture types). This pass ignores
56755675
# guard expressions.
5676-
pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns]
5676+
pattern_types = [
5677+
self.pattern_checker.accept(p, subject_type, [unwrapped_subject])
5678+
for p in s.patterns
5679+
]
56775680
type_maps: list[TypeMap] = [t.captures for t in pattern_types]
56785681
inferred_types = self.infer_variable_types_from_type_maps(type_maps)
56795682

@@ -5683,7 +5686,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
56835686
current_subject_type = self.expr_checker.narrow_type_from_binder(
56845687
named_subject, subject_type
56855688
)
5686-
pattern_type = self.pattern_checker.accept(p, current_subject_type)
5689+
pattern_type = self.pattern_checker.accept(
5690+
p, current_subject_type, [unwrapped_subject]
5691+
)
56875692
with self.binder.frame_context(can_skip=True, fall_through=2):
56885693
if b.is_unreachable or isinstance(
56895694
get_proper_type(pattern_type.type), UninhabitedType

mypy/checkpattern.py

Lines changed: 116 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,25 @@
1010
from mypy.checkmember import analyze_member_access
1111
from mypy.expandtype import expand_type_by_instance
1212
from mypy.join import join_types
13-
from mypy.literals import literal_hash
13+
from mypy.literals import Key, literal_hash
1414
from mypy.maptype import map_instance_to_supertype
1515
from mypy.meet import narrow_declared_type
1616
from mypy.messages import MessageBuilder
17-
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TempNode, TypeAlias, Var
17+
from mypy.nodes import (
18+
ARG_POS,
19+
Context,
20+
Expression,
21+
IndexExpr,
22+
IntExpr,
23+
ListExpr,
24+
MemberExpr,
25+
NameExpr,
26+
TempNode,
27+
TupleExpr,
28+
TypeAlias,
29+
UnaryExpr,
30+
Var,
31+
)
1832
from mypy.options import Options
1933
from mypy.patterns import (
2034
AsPattern,
@@ -98,10 +112,8 @@ class PatternChecker(PatternVisitor[PatternType]):
98112
msg: MessageBuilder
99113
# Currently unused
100114
plugin: Plugin
101-
# The expression being matched against the pattern
102-
subject: Expression
103-
104-
subject_type: Type
115+
# The expressions being matched against the (sub)pattern
116+
subject_context: list[list[Expression]]
105117
# Type of the subject to check the (sub)pattern against
106118
type_context: list[Type]
107119
# Types that match against self instead of their __match_args__ if used as a class pattern
@@ -120,24 +132,28 @@ def __init__(
120132
self.msg = msg
121133
self.plugin = plugin
122134

135+
self.subject_context = []
123136
self.type_context = []
124137
self.self_match_types = self.generate_types_from_names(self_match_type_names)
125138
self.non_sequence_match_types = self.generate_types_from_names(
126139
non_sequence_match_type_names
127140
)
128141
self.options = options
129142

130-
def accept(self, o: Pattern, type_context: Type) -> PatternType:
143+
def accept(self, o: Pattern, type_context: Type, subject: list[Expression]) -> PatternType:
144+
self.subject_context.append(subject)
131145
self.type_context.append(type_context)
132146
result = o.accept(self)
147+
self.subject_context.pop()
133148
self.type_context.pop()
134149

135150
return result
136151

137152
def visit_as_pattern(self, o: AsPattern) -> PatternType:
153+
current_subject = self.subject_context[-1]
138154
current_type = self.type_context[-1]
139155
if o.pattern is not None:
140-
pattern_type = self.accept(o.pattern, current_type)
156+
pattern_type = self.accept(o.pattern, current_type, current_subject)
141157
typ, rest_type, type_map = pattern_type
142158
else:
143159
typ, rest_type, type_map = current_type, UninhabitedType(), {}
@@ -152,14 +168,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
152168
return PatternType(typ, rest_type, type_map)
153169

154170
def visit_or_pattern(self, o: OrPattern) -> PatternType:
171+
current_subject = self.subject_context[-1]
155172
current_type = self.type_context[-1]
156173

157174
#
158175
# Check all the subpatterns
159176
#
160-
pattern_types = []
177+
pattern_types: list[PatternType] = []
161178
for pattern in o.patterns:
162-
pattern_type = self.accept(pattern, current_type)
179+
pattern_type = self.accept(pattern, current_type, current_subject)
163180
pattern_types.append(pattern_type)
164181
if not is_uninhabited(pattern_type.type):
165182
current_type = pattern_type.rest_type
@@ -175,28 +192,42 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
175192
#
176193
# Check the capture types
177194
#
178-
capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list)
195+
capture_types: dict[Var, dict[Key | None, list[tuple[Expression, Type]]]] = defaultdict(
196+
lambda: defaultdict(list)
197+
)
198+
capture_expr_keys: set[Key | None] = set()
179199
# Collect captures from the first subpattern
180200
for expr, typ in pattern_types[0].captures.items():
181-
node = get_var(expr)
182-
capture_types[node].append((expr, typ))
201+
if (node := get_var(expr)) is None:
202+
continue
203+
key = literal_hash(expr)
204+
capture_types[node][key].append((expr, typ))
205+
if isinstance(expr, NameExpr):
206+
capture_expr_keys.add(key)
183207

184208
# Check if other subpatterns capture the same names
185209
for i, pattern_type in enumerate(pattern_types[1:]):
186-
vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
187-
if capture_types.keys() != vars:
210+
vars = {
211+
literal_hash(expr) for expr in pattern_type.captures if isinstance(expr, NameExpr)
212+
}
213+
if capture_expr_keys != vars:
214+
# Only fail for directly captured names (with NameExpr)
188215
self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
189216
for expr, typ in pattern_type.captures.items():
190-
node = get_var(expr)
191-
capture_types[node].append((expr, typ))
217+
if (node := get_var(expr)) is None:
218+
continue
219+
key = literal_hash(expr)
220+
capture_types[node][key].append((expr, typ))
192221

193222
captures: dict[Expression, Type] = {}
194-
for capture_list in capture_types.values():
195-
typ = UninhabitedType()
196-
for _, other in capture_list:
197-
typ = make_simplified_union([typ, other])
223+
for expressions in capture_types.values():
224+
for key, capture_list in expressions.items():
225+
if other_types := [entry[1] for entry in capture_list]:
226+
typ = make_simplified_union(other_types)
227+
else:
228+
typ = UninhabitedType()
198229

199-
captures[capture_list[0][0]] = typ
230+
captures[capture_list[0][0]] = typ
200231

201232
union_type = make_simplified_union(types)
202233
return PatternType(union_type, current_type, captures)
@@ -289,12 +320,37 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
289320
contracted_inner_types = self.contract_starred_pattern_types(
290321
inner_types, star_position, required_patterns
291322
)
292-
for p, t in zip(o.patterns, contracted_inner_types):
293-
pattern_type = self.accept(p, t)
323+
current_subjects: list[list[Expression]] = [[] for _ in range(len(contracted_inner_types))]
324+
end_pos = len(contracted_inner_types) if star_position is None else star_position
325+
for subject in self.subject_context[-1]:
326+
if isinstance(subject, (ListExpr, TupleExpr)):
327+
# For list and tuple expressions, lookup expression in items
328+
for i in range(end_pos):
329+
if i < len(subject.items):
330+
current_subjects[i].append(subject.items[i])
331+
if star_position is not None:
332+
for i in range(star_position + 1, len(contracted_inner_types)):
333+
offset = len(contracted_inner_types) - i
334+
if offset <= len(subject.items):
335+
current_subjects[i].append(subject.items[-offset])
336+
else:
337+
# Support x[0], x[1], ... lookup until wildcard
338+
for i in range(end_pos):
339+
current_subjects[i].append(IndexExpr(subject, IntExpr(i)))
340+
# For everything after wildcard use x[-2], x[-1]
341+
for i in range((star_position or -1) + 1, len(contracted_inner_types)):
342+
offset = len(contracted_inner_types) - i
343+
current_subjects[i].append(IndexExpr(subject, UnaryExpr("-", IntExpr(offset))))
344+
for p, t, s in zip(o.patterns, contracted_inner_types, current_subjects):
345+
pattern_type = self.accept(p, t, s)
294346
typ, rest, type_map = pattern_type
295347
contracted_new_inner_types.append(typ)
296348
contracted_rest_inner_types.append(rest)
297349
self.update_type_map(captures, type_map)
350+
if s:
351+
self.update_type_map(
352+
captures, {subject: typ for subject in s}, fail_multiple_assignments=False
353+
)
298354

299355
new_inner_types = self.expand_starred_pattern_types(
300356
contracted_new_inner_types, star_position, len(inner_types), unpack_index is not None
@@ -478,11 +534,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
478534
if inner_type is None:
479535
can_match = False
480536
inner_type = self.chk.named_type("builtins.object")
481-
pattern_type = self.accept(value, inner_type)
537+
current_subjects: list[Expression] = [
538+
IndexExpr(s, key) for s in self.subject_context[-1]
539+
]
540+
pattern_type = self.accept(value, inner_type, current_subjects)
482541
if is_uninhabited(pattern_type.type):
483542
can_match = False
484543
else:
485544
self.update_type_map(captures, pattern_type.captures)
545+
if current_subjects:
546+
self.update_type_map(
547+
captures, {subject: pattern_type.type for subject in current_subjects}
548+
)
486549

487550
if o.rest is not None:
488551
mapping = self.chk.named_type("typing.Mapping")
@@ -590,7 +653,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
590653
if self.should_self_match(typ):
591654
if len(o.positionals) > 1:
592655
self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
593-
pattern_type = self.accept(o.positionals[0], narrowed_type)
656+
pattern_type = self.accept(o.positionals[0], narrowed_type, [])
594657
if not is_uninhabited(pattern_type.type):
595658
return PatternType(
596659
pattern_type.type,
@@ -690,11 +753,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
690753
elif keyword is not None:
691754
new_type = self.chk.add_any_attribute_to_type(new_type, keyword)
692755

693-
inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
756+
current_subjects: list[Expression] = []
757+
if keyword is not None:
758+
current_subjects = [MemberExpr(s, keyword) for s in self.subject_context[-1]]
759+
inner_type, inner_rest_type, inner_captures = self.accept(
760+
pattern, key_type, current_subjects
761+
)
694762
if is_uninhabited(inner_type):
695763
can_match = False
696764
else:
697765
self.update_type_map(captures, inner_captures)
766+
if current_subjects:
767+
self.update_type_map(
768+
captures, {subject: inner_type for subject in current_subjects}
769+
)
698770
if not is_uninhabited(inner_rest_type):
699771
rest_type = current_type
700772

@@ -743,17 +815,22 @@ def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
743815
return types
744816

745817
def update_type_map(
746-
self, original_type_map: dict[Expression, Type], extra_type_map: dict[Expression, Type]
818+
self,
819+
original_type_map: dict[Expression, Type],
820+
extra_type_map: dict[Expression, Type],
821+
fail_multiple_assignments: bool = True,
747822
) -> None:
748823
# Calculating this would not be needed if TypeMap directly used literal hashes instead of
749824
# expressions, as suggested in the TODO above its definition
750825
already_captured = {literal_hash(expr) for expr in original_type_map}
751826
for expr, typ in extra_type_map.items():
752827
if literal_hash(expr) in already_captured:
753-
node = get_var(expr)
754-
self.msg.fail(
755-
message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
756-
)
828+
if (node := get_var(expr)) is None:
829+
continue
830+
if fail_multiple_assignments:
831+
self.msg.fail(
832+
message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
833+
)
757834
else:
758835
original_type_map[expr] = typ
759836

@@ -805,12 +882,17 @@ def get_match_arg_names(typ: TupleType) -> list[str | None]:
805882
return args
806883

807884

808-
def get_var(expr: Expression) -> Var:
885+
def get_var(expr: Expression) -> Var | None:
809886
"""
810887
Warning: this is only true for expressions captured by a match statement.
811888
Don't call it from anywhere else
812889
"""
813-
assert isinstance(expr, NameExpr), expr
890+
if isinstance(expr, MemberExpr):
891+
return get_var(expr.expr)
892+
if isinstance(expr, IndexExpr):
893+
return get_var(expr.base)
894+
if not isinstance(expr, NameExpr):
895+
return None
814896
node = expr.node
815897
assert isinstance(node, Var), node
816898
return node

mypy/literals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def visit_set_expr(self, e: SetExpr) -> Key | None:
228228
return self.seq_expr(e, "Set")
229229

230230
def visit_index_expr(self, e: IndexExpr) -> Key | None:
231-
if literal(e.index) == LITERAL_YES:
231+
if literal(e.index) != LITERAL_NO:
232232
return ("Index", literal_hash(e.base), literal_hash(e.index))
233233
return None
234234

0 commit comments

Comments
 (0)